Diffstat (limited to 'contrib/llvm/lib/Analysis')
56 files changed, 8746 insertions, 4516 deletions
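For reference on the first file below: the AliasAnalysis.cpp hunks replace the old !L->isUnordered() guard, which also rejects volatile non-atomic accesses, with an explicit ordering comparison, so volatile-but-not-atomic loads and stores can still reach the alias query. A minimal compilable sketch of the new predicate against the public LLVM headers of this release (the helper name is hypothetical; the real hunk inlines the check):

#include "llvm/IR/Instructions.h"
#include "llvm/Support/AtomicOrdering.h"

using namespace llvm;

// Hypothetical helper, not part of the patch: returns true when a load
// must be treated as MRI_ModRef regardless of aliasing.  The old code
// used !L->isUnordered(), which is also true for volatile non-atomic
// loads; the new check only bails out for orderings strictly stronger
// than Unordered.
static bool mustBeConservativeForModRef(const LoadInst *L) {
  return isStrongerThan(L->getOrdering(), AtomicOrdering::Unordered);
}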
diff --git a/contrib/llvm/lib/Analysis/AliasAnalysis.cpp b/contrib/llvm/lib/Analysis/AliasAnalysis.cpp index 84da76be98bb..4c29aeaa622f 100644 --- a/contrib/llvm/lib/Analysis/AliasAnalysis.cpp +++ b/contrib/llvm/lib/Analysis/AliasAnalysis.cpp @@ -332,8 +332,8 @@ FunctionModRefBehavior AAResults::getModRefBehavior(const Function *F) { ModRefInfo AAResults::getModRefInfo(const LoadInst *L, const MemoryLocation &Loc) { - // Be conservative in the face of volatile/atomic. - if (!L->isUnordered()) + // Be conservative in the face of atomic. + if (isStrongerThan(L->getOrdering(), AtomicOrdering::Unordered)) return MRI_ModRef; // If the load address doesn't alias the given address, it doesn't read @@ -347,8 +347,8 @@ ModRefInfo AAResults::getModRefInfo(const LoadInst *L, ModRefInfo AAResults::getModRefInfo(const StoreInst *S, const MemoryLocation &Loc) { - // Be conservative in the face of volatile/atomic. - if (!S->isUnordered()) + // Be conservative in the face of atomic. + if (isStrongerThan(S->getOrdering(), AtomicOrdering::Unordered)) return MRI_ModRef; if (Loc.Ptr) { @@ -367,6 +367,14 @@ ModRefInfo AAResults::getModRefInfo(const StoreInst *S, return MRI_Mod; } +ModRefInfo AAResults::getModRefInfo(const FenceInst *S, const MemoryLocation &Loc) { + // If we know that the location is a constant memory location, the fence + // cannot modify this location. + if (Loc.Ptr && pointsToConstantMemory(Loc)) + return MRI_Ref; + return MRI_ModRef; +} + ModRefInfo AAResults::getModRefInfo(const VAArgInst *V, const MemoryLocation &Loc) { @@ -689,7 +697,7 @@ AAResults llvm::createLegacyPMAAResults(Pass &P, Function &F, bool llvm::isNoAliasCall(const Value *V) { if (auto CS = ImmutableCallSite(V)) - return CS.paramHasAttr(0, Attribute::NoAlias); + return CS.hasRetAttr(Attribute::NoAlias); return false; } diff --git a/contrib/llvm/lib/Analysis/AliasSetTracker.cpp b/contrib/llvm/lib/Analysis/AliasSetTracker.cpp index 701b0e1a5925..16b711a69ec3 100644 --- a/contrib/llvm/lib/Analysis/AliasSetTracker.cpp +++ b/contrib/llvm/lib/Analysis/AliasSetTracker.cpp @@ -199,9 +199,10 @@ bool AliasSet::aliasesPointer(const Value *Ptr, uint64_t Size, // Check the unknown instructions... if (!UnknownInsts.empty()) { for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) - if (AA.getModRefInfo(UnknownInsts[i], - MemoryLocation(Ptr, Size, AAInfo)) != MRI_NoModRef) - return true; + if (auto *Inst = getUnknownInst(i)) + if (AA.getModRefInfo(Inst, MemoryLocation(Ptr, Size, AAInfo)) != + MRI_NoModRef) + return true; } return false; @@ -217,10 +218,12 @@ bool AliasSet::aliasesUnknownInst(const Instruction *Inst, return false; for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) { - ImmutableCallSite C1(getUnknownInst(i)), C2(Inst); - if (!C1 || !C2 || AA.getModRefInfo(C1, C2) != MRI_NoModRef || - AA.getModRefInfo(C2, C1) != MRI_NoModRef) - return true; + if (auto *Inst = getUnknownInst(i)) { + ImmutableCallSite C1(Inst), C2(Inst); + if (!C1 || !C2 || AA.getModRefInfo(C1, C2) != MRI_NoModRef || + AA.getModRefInfo(C2, C1) != MRI_NoModRef) + return true; + } } for (iterator I = begin(), E = end(); I != E; ++I) @@ -471,7 +474,8 @@ void AliasSetTracker::add(const AliasSetTracker &AST) { // If there are any call sites in the alias set, add them to this AST. for (unsigned i = 0, e = AS.UnknownInsts.size(); i != e; ++i) - add(AS.UnknownInsts[i]); + if (auto *Inst = AS.getUnknownInst(i)) + add(Inst); // Loop over all of the pointers in this alias set. 
for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI) { @@ -489,19 +493,6 @@ void AliasSetTracker::add(const AliasSetTracker &AST) { // dangling pointers to deleted instructions. // void AliasSetTracker::deleteValue(Value *PtrVal) { - // If this is a call instruction, remove the callsite from the appropriate - // AliasSet (if present). - if (Instruction *Inst = dyn_cast<Instruction>(PtrVal)) { - if (Inst->mayReadOrWriteMemory()) { - // Scan all the alias sets to see if this call site is contained. - for (iterator I = begin(), E = end(); I != E;) { - iterator Cur = I++; - if (!Cur->Forward) - Cur->removeUnknownInst(*this, Inst); - } - } - } - // First, look up the PointerRec for this pointer. PointerMapType::iterator I = PointerMap.find_as(PtrVal); if (I == PointerMap.end()) return; // Noop @@ -633,7 +624,8 @@ void AliasSet::print(raw_ostream &OS) const { OS << "\n " << UnknownInsts.size() << " Unknown instructions: "; for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) { if (i) OS << ", "; - UnknownInsts[i]->printAsOperand(OS); + if (auto *I = getUnknownInst(i)) + I->printAsOperand(OS); } } OS << "\n"; diff --git a/contrib/llvm/lib/Analysis/Analysis.cpp b/contrib/llvm/lib/Analysis/Analysis.cpp index 0e7cf402cdb5..0e0b5c92a918 100644 --- a/contrib/llvm/lib/Analysis/Analysis.cpp +++ b/contrib/llvm/lib/Analysis/Analysis.cpp @@ -57,6 +57,7 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeLazyBranchProbabilityInfoPassPass(Registry); initializeLazyBlockFrequencyInfoPassPass(Registry); initializeLazyValueInfoWrapperPassPass(Registry); + initializeLazyValueInfoPrinterPass(Registry); initializeLintPass(Registry); initializeLoopInfoWrapperPassPass(Registry); initializeMemDepPrinterPass(Registry); @@ -78,6 +79,8 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeTypeBasedAAWrapperPassPass(Registry); initializeScopedNoAliasAAWrapperPassPass(Registry); initializeLCSSAVerificationPassPass(Registry); + initializeMemorySSAWrapperPassPass(Registry); + initializeMemorySSAPrinterLegacyPassPass(Registry); } void LLVMInitializeAnalysis(LLVMPassRegistryRef R) { diff --git a/contrib/llvm/lib/Analysis/AssumptionCache.cpp b/contrib/llvm/lib/Analysis/AssumptionCache.cpp index 5851594700a4..0468c794e81d 100644 --- a/contrib/llvm/lib/Analysis/AssumptionCache.cpp +++ b/contrib/llvm/lib/Analysis/AssumptionCache.cpp @@ -24,15 +24,21 @@ using namespace llvm; using namespace llvm::PatternMatch; -SmallVector<WeakVH, 1> &AssumptionCache::getOrInsertAffectedValues(Value *V) { +static cl::opt<bool> + VerifyAssumptionCache("verify-assumption-cache", cl::Hidden, + cl::desc("Enable verification of assumption cache"), + cl::init(false)); + +SmallVector<WeakTrackingVH, 1> & +AssumptionCache::getOrInsertAffectedValues(Value *V) { // Try using find_as first to avoid creating extra value handles just for the // purpose of doing the lookup. 
auto AVI = AffectedValues.find_as(V); if (AVI != AffectedValues.end()) return AVI->second; - auto AVIP = AffectedValues.insert({ - AffectedValueCallbackVH(V, this), SmallVector<WeakVH, 1>()}); + auto AVIP = AffectedValues.insert( + {AffectedValueCallbackVH(V, this), SmallVector<WeakTrackingVH, 1>()}); return AVIP.first->second; } @@ -47,9 +53,11 @@ void AssumptionCache::updateAffectedValues(CallInst *CI) { } else if (auto *I = dyn_cast<Instruction>(V)) { Affected.push_back(I); - if (I->getOpcode() == Instruction::BitCast || - I->getOpcode() == Instruction::PtrToInt) { - auto *Op = I->getOperand(0); + // Peek through unary operators to find the source of the condition. + Value *Op; + if (match(I, m_BitCast(m_Value(Op))) || + match(I, m_PtrToInt(m_Value(Op))) || + match(I, m_Not(m_Value(Op)))) { if (isa<Instruction>(Op) || isa<Argument>(Op)) Affected.push_back(Op); } @@ -229,7 +237,13 @@ AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) { } void AssumptionCacheTracker::verifyAnalysis() const { -#ifndef NDEBUG + // FIXME: In the long term the verifier should not be controllable with a + // flag. We should either fix all passes to correctly update the assumption + // cache and enable the verifier unconditionally or somehow arrange for the + // assumption list to be updated automatically by passes. + if (!VerifyAssumptionCache) + return; + SmallPtrSet<const CallInst *, 4> AssumptionSet; for (const auto &I : AssumptionCaches) { for (auto &VH : I.second->assumptions()) @@ -238,11 +252,10 @@ void AssumptionCacheTracker::verifyAnalysis() const { for (const BasicBlock &B : cast<Function>(*I.first)) for (const Instruction &II : B) - if (match(&II, m_Intrinsic<Intrinsic::assume>())) - assert(AssumptionSet.count(cast<CallInst>(&II)) && - "Assumption in scanned function not in cache"); + if (match(&II, m_Intrinsic<Intrinsic::assume>()) && + !AssumptionSet.count(cast<CallInst>(&II))) + report_fatal_error("Assumption in scanned function not in cache"); } -#endif } AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) { diff --git a/contrib/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/contrib/llvm/lib/Analysis/BasicAliasAnalysis.cpp index c8d057949493..f743cb234c45 100644 --- a/contrib/llvm/lib/Analysis/BasicAliasAnalysis.cpp +++ b/contrib/llvm/lib/Analysis/BasicAliasAnalysis.cpp @@ -17,13 +17,13 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/Analysis/AssumptionCache.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" @@ -36,6 +36,7 @@ #include "llvm/IR/Operator.h" #include "llvm/Pass.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" #include <algorithm> #define DEBUG_TYPE "basicaa" @@ -127,7 +128,9 @@ static uint64_t getObjectSize(const Value *V, const DataLayout &DL, const TargetLibraryInfo &TLI, bool RoundToAlign = false) { uint64_t Size; - if (getObjectSize(V, Size, DL, &TLI, RoundToAlign)) + ObjectSizeOpts Opts; + Opts.RoundToAlign = RoundToAlign; + if (getObjectSize(V, Size, DL, &TLI, Opts)) return Size; return MemoryLocation::UnknownSize; } @@ -635,7 +638,7 @@ FunctionModRefBehavior BasicAAResult::getModRefBehavior(const Function 
*F) { /// Returns true if this is a writeonly (i.e Mod only) parameter. static bool isWriteOnlyParam(ImmutableCallSite CS, unsigned ArgIdx, const TargetLibraryInfo &TLI) { - if (CS.paramHasAttr(ArgIdx + 1, Attribute::WriteOnly)) + if (CS.paramHasAttr(ArgIdx, Attribute::WriteOnly)) return true; // We can bound the aliasing properties of memset_pattern16 just as we can @@ -644,9 +647,9 @@ static bool isWriteOnlyParam(ImmutableCallSite CS, unsigned ArgIdx, // whenever possible. // FIXME Consider handling this in InferFunctionAttr.cpp together with other // attributes. - LibFunc::Func F; + LibFunc F; if (CS.getCalledFunction() && TLI.getLibFunc(*CS.getCalledFunction(), F) && - F == LibFunc::memset_pattern16 && TLI.has(F)) + F == LibFunc_memset_pattern16 && TLI.has(F)) if (ArgIdx == 0) return true; @@ -664,10 +667,10 @@ ModRefInfo BasicAAResult::getArgModRefInfo(ImmutableCallSite CS, if (isWriteOnlyParam(CS, ArgIdx, TLI)) return MRI_Mod; - if (CS.paramHasAttr(ArgIdx + 1, Attribute::ReadOnly)) + if (CS.paramHasAttr(ArgIdx, Attribute::ReadOnly)) return MRI_Ref; - if (CS.paramHasAttr(ArgIdx + 1, Attribute::ReadNone)) + if (CS.paramHasAttr(ArgIdx, Attribute::ReadNone)) return MRI_NoModRef; return AAResultBase::getArgModRefInfo(CS, ArgIdx); @@ -680,8 +683,11 @@ static bool isIntrinsicCall(ImmutableCallSite CS, Intrinsic::ID IID) { #ifndef NDEBUG static const Function *getParent(const Value *V) { - if (const Instruction *inst = dyn_cast<Instruction>(V)) + if (const Instruction *inst = dyn_cast<Instruction>(V)) { + if (!inst->getParent()) + return nullptr; return inst->getParent()->getParent(); + } if (const Argument *arg = dyn_cast<Argument>(V)) return arg->getParent(); @@ -749,7 +755,11 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, // as an argument, and itself doesn't capture it. if (!isa<Constant>(Object) && CS.getInstruction() != Object && isNonEscapingLocalObject(Object)) { - bool PassedAsArg = false; + + // Optimistically assume that call doesn't touch Object and check this + // assumption in the following loop. + ModRefInfo Result = MRI_NoModRef; + unsigned OperandNo = 0; for (auto CI = CS.data_operands_begin(), CE = CS.data_operands_end(); CI != CE; ++CI, ++OperandNo) { @@ -761,20 +771,38 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, OperandNo < CS.getNumArgOperands() && !CS.isByValArgument(OperandNo))) continue; + // Call doesn't access memory through this operand, so we don't care + // if it aliases with Object. + if (CS.doesNotAccessMemory(OperandNo)) + continue; + // If this is a no-capture pointer argument, see if we can tell that it - // is impossible to alias the pointer we're checking. If not, we have to - // assume that the call could touch the pointer, even though it doesn't - // escape. + // is impossible to alias the pointer we're checking. AliasResult AR = getBestAAResults().alias(MemoryLocation(*CI), MemoryLocation(Object)); - if (AR) { - PassedAsArg = true; - break; + + // Operand doesnt alias 'Object', continue looking for other aliases + if (AR == NoAlias) + continue; + // Operand aliases 'Object', but call doesn't modify it. Strengthen + // initial assumption and keep looking in case if there are more aliases. + if (CS.onlyReadsMemory(OperandNo)) { + Result = static_cast<ModRefInfo>(Result | MRI_Ref); + continue; } + // Operand aliases 'Object' but call only writes into it. 
+ if (CS.doesNotReadMemory(OperandNo)) { + Result = static_cast<ModRefInfo>(Result | MRI_Mod); + continue; + } + // This operand aliases 'Object' and call reads and writes into it. + Result = MRI_ModRef; + break; } - if (!PassedAsArg) - return MRI_NoModRef; + // Early return if we improved mod ref information + if (Result != MRI_ModRef) + return Result; } // If the CallSite is to malloc or calloc, we can assume that it doesn't @@ -784,7 +812,7 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, // well. Or alternatively, replace all of this with inaccessiblememonly once // that's implemented fully. auto *Inst = CS.getInstruction(); - if (isMallocLikeFn(Inst, &TLI) || isCallocLikeFn(Inst, &TLI)) { + if (isMallocOrCallocLikeFn(Inst, &TLI)) { // Be conservative if the accessed pointer may alias the allocation - // fallback to the generic handling below. if (getBestAAResults().alias(MemoryLocation(Inst), Loc) == NoAlias) @@ -900,10 +928,9 @@ static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, uint64_t V2Size, const DataLayout &DL) { - assert(GEP1->getPointerOperand()->stripPointerCasts() == - GEP2->getPointerOperand()->stripPointerCasts() && - GEP1->getPointerOperand()->getType() == - GEP2->getPointerOperand()->getType() && + assert(GEP1->getPointerOperand()->stripPointerCastsAndBarriers() == + GEP2->getPointerOperand()->stripPointerCastsAndBarriers() && + GEP1->getPointerOperandType() == GEP2->getPointerOperandType() && "Expected GEPs with the same pointer operand"); // Try to determine whether GEP1 and GEP2 index through arrays, into structs, @@ -1161,10 +1188,9 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // If we know the two GEPs are based off of the exact same pointer (and not // just the same underlying object), see if that tells us anything about // the resulting pointers. - if (GEP1->getPointerOperand()->stripPointerCasts() == - GEP2->getPointerOperand()->stripPointerCasts() && - GEP1->getPointerOperand()->getType() == - GEP2->getPointerOperand()->getType()) { + if (GEP1->getPointerOperand()->stripPointerCastsAndBarriers() == + GEP2->getPointerOperand()->stripPointerCastsAndBarriers() && + GEP1->getPointerOperandType() == GEP2->getPointerOperandType()) { AliasResult R = aliasSameBasePointerGEPs(GEP1, V1Size, GEP2, V2Size, DL); // If we couldn't find anything interesting, don't abandon just yet. if (R != MayAlias) @@ -1261,9 +1287,9 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // give up if we can't determine conditions that hold for every cycle: const Value *V = DecompGEP1.VarIndices[i].V; - bool SignKnownZero, SignKnownOne; - ComputeSignBit(const_cast<Value *>(V), SignKnownZero, SignKnownOne, DL, - 0, &AC, nullptr, DT); + KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, DT); + bool SignKnownZero = Known.isNonNegative(); + bool SignKnownOne = Known.isNegative(); // Zero-extension widens the variable, and so forces the sign // bit to zero. @@ -1478,8 +1504,8 @@ AliasResult BasicAAResult::aliasCheck(const Value *V1, uint64_t V1Size, return NoAlias; // Strip off any casts if they exist. - V1 = V1->stripPointerCasts(); - V2 = V2->stripPointerCasts(); + V1 = V1->stripPointerCastsAndBarriers(); + V2 = V2->stripPointerCastsAndBarriers(); // If V1 or V2 is undef, the result is NoAlias because we can always pick a // value for undef that aliases nothing in the program. 
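The BasicAliasAnalysis.cpp hunk above replaces the single PassedAsArg flag with a per-operand accumulation of mod/ref bits for calls that take a non-escaping local object. The key property is that MRI_Ref | MRI_Mod == MRI_ModRef, so the answer can be strengthened one operand at a time and the walk can stop as soon as nothing better than MRI_ModRef is possible. A self-contained sketch of that lattice walk, using a stand-in enum with the same bit layout as LLVM's ModRefInfo and hypothetical per-operand facts in place of the alias query and call-site attributes:

#include <cassert>

// Stand-in for LLVM's ModRefInfo with the same bit layout (Ref = 1,
// Mod = 2, ModRef = Ref | Mod).
enum ModRefKind { MRI_NoModRef = 0, MRI_Ref = 1, MRI_Mod = 2, MRI_ModRef = 3 };

// Hypothetical per-operand facts standing in for the alias query and the
// readonly/writeonly call-site attributes used in the real code.
struct OperandFact {
  bool AliasesObject;
  bool OnlyReads;
  bool OnlyWrites;
};

static ModRefKind accumulate(const OperandFact *Facts, int N) {
  ModRefKind Result = MRI_NoModRef;              // optimistic start
  for (int i = 0; i < N; ++i) {
    if (!Facts[i].AliasesObject)
      continue;                                  // cannot touch Object
    if (Facts[i].OnlyReads) {
      Result = ModRefKind(Result | MRI_Ref);     // at least a read
      continue;
    }
    if (Facts[i].OnlyWrites) {
      Result = ModRefKind(Result | MRI_Mod);     // at least a write
      continue;
    }
    return MRI_ModRef;                           // reads and writes
  }
  return Result;
}

int main() {
  OperandFact Facts[] = {{false, false, false}, {true, true, false}};
  assert(accumulate(Facts, 2) == MRI_Ref);       // one read-only aliasing operand
  return 0;
}

Returning MRI_ModRef from inside the loop is equivalent to the break-then-check in the real hunk: once an operand both reads and writes the object, no later operand can improve the answer.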
diff --git a/contrib/llvm/lib/Analysis/BlockFrequencyInfo.cpp b/contrib/llvm/lib/Analysis/BlockFrequencyInfo.cpp index 4cdbe4d0fcf6..07a2a9229fd5 100644 --- a/contrib/llvm/lib/Analysis/BlockFrequencyInfo.cpp +++ b/contrib/llvm/lib/Analysis/BlockFrequencyInfo.cpp @@ -26,7 +26,6 @@ using namespace llvm; #define DEBUG_TYPE "block-freq" -#ifndef NDEBUG static cl::opt<GVDAGType> ViewBlockFreqPropagationDAG( "view-block-freq-propagation-dags", cl::Hidden, cl::desc("Pop up a window to show a dag displaying how block " @@ -55,8 +54,29 @@ cl::opt<unsigned> "is no less than the max frequency of the " "function multiplied by this percent.")); +// Command line option to turn on CFG dot dump after profile annotation. +cl::opt<bool> + PGOViewCounts("pgo-view-counts", cl::init(false), cl::Hidden, + cl::desc("A boolean option to show CFG dag with " + "block profile counts and branch probabilities " + "right after PGO profile annotation step. The " + "profile counts are computed using branch " + "probabilities from the runtime profile data and " + "block frequency propagation algorithm. To view " + "the raw counts from the profile, use option " + "-pgo-view-raw-counts instead. To limit graph " + "display to only one function, use filtering option " + "-view-bfi-func-name.")); + namespace llvm { +static GVDAGType getGVDT() { + + if (PGOViewCounts) + return GVDT_Count; + return ViewBlockFreqPropagationDAG; +} + template <> struct GraphTraits<BlockFrequencyInfo *> { typedef const BasicBlock *NodeRef; @@ -89,8 +109,7 @@ struct DOTGraphTraits<BlockFrequencyInfo *> : public BFIDOTGTraitsBase { std::string getNodeLabel(const BasicBlock *Node, const BlockFrequencyInfo *Graph) { - return BFIDOTGTraitsBase::getNodeLabel(Node, Graph, - ViewBlockFreqPropagationDAG); + return BFIDOTGTraitsBase::getNodeLabel(Node, Graph, getGVDT()); } std::string getNodeAttributes(const BasicBlock *Node, @@ -107,7 +126,6 @@ struct DOTGraphTraits<BlockFrequencyInfo *> : public BFIDOTGTraitsBase { }; } // end namespace llvm -#endif BlockFrequencyInfo::BlockFrequencyInfo() {} @@ -132,19 +150,26 @@ BlockFrequencyInfo &BlockFrequencyInfo::operator=(BlockFrequencyInfo &&RHS) { // template instantiated which is not available in the header. BlockFrequencyInfo::~BlockFrequencyInfo() {} +bool BlockFrequencyInfo::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. + auto PAC = PA.getChecker<BlockFrequencyAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + void BlockFrequencyInfo::calculate(const Function &F, const BranchProbabilityInfo &BPI, const LoopInfo &LI) { if (!BFI) BFI.reset(new ImplType); BFI->calculate(F, BPI, LI); -#ifndef NDEBUG if (ViewBlockFreqPropagationDAG != GVDT_None && (ViewBlockFreqFuncName.empty() || F.getName().equals(ViewBlockFreqFuncName))) { view(); } -#endif } BlockFrequency BlockFrequencyInfo::getBlockFreq(const BasicBlock *BB) const { @@ -171,16 +196,32 @@ void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) { BFI->setBlockFreq(BB, Freq); } +void BlockFrequencyInfo::setBlockFreqAndScale( + const BasicBlock *ReferenceBB, uint64_t Freq, + SmallPtrSetImpl<BasicBlock *> &BlocksToScale) { + assert(BFI && "Expected analysis to be available"); + // Use 128 bits APInt to avoid overflow. 
+ APInt NewFreq(128, Freq); + APInt OldFreq(128, BFI->getBlockFreq(ReferenceBB).getFrequency()); + APInt BBFreq(128, 0); + for (auto *BB : BlocksToScale) { + BBFreq = BFI->getBlockFreq(BB).getFrequency(); + // Multiply first by NewFreq and then divide by OldFreq + // to minimize loss of precision. + BBFreq *= NewFreq; + // udiv is an expensive operation in the general case. If this ends up being + // a hot spot, one of the options proposed in + // https://reviews.llvm.org/D28535#650071 could be used to avoid this. + BBFreq = BBFreq.udiv(OldFreq); + BFI->setBlockFreq(BB, BBFreq.getLimitedValue()); + } + BFI->setBlockFreq(ReferenceBB, Freq); +} + /// Pop up a ghostview window with the current block frequency propagation /// rendered using dot. void BlockFrequencyInfo::view() const { -// This code is only for debugging. -#ifndef NDEBUG ViewGraph(const_cast<BlockFrequencyInfo *>(this), "BlockFrequencyDAGs"); -#else - errs() << "BlockFrequencyInfo::view is only available in debug builds on " - "systems with Graphviz or gv!\n"; -#endif // NDEBUG } const Function *BlockFrequencyInfo::getFunction() const { diff --git a/contrib/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/contrib/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp index 9850e02fca22..e5d8c3347c16 100644 --- a/contrib/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp +++ b/contrib/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -28,7 +28,9 @@ ScaledNumber<uint64_t> BlockMass::toScaled() const { return ScaledNumber<uint64_t>(getMass() + 1, -64); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void BlockMass::dump() const { print(dbgs()); } +#endif static char getHexDigit(int N) { assert(N < 16); diff --git a/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp index 3eabb780398c..267e19adfe4d 100644 --- a/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp +++ b/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -58,19 +58,12 @@ char BranchProbabilityInfoWrapperPass::ID = 0; static const uint32_t LBH_TAKEN_WEIGHT = 124; static const uint32_t LBH_NONTAKEN_WEIGHT = 4; -/// \brief Unreachable-terminating branch taken weight. +/// \brief Unreachable-terminating branch taken probability. /// -/// This is the weight for a branch being taken to a block that terminates +/// This is the probability for a branch being taken to a block that terminates /// (eventually) in unreachable. These are predicted as unlikely as possible. -static const uint32_t UR_TAKEN_WEIGHT = 1; - -/// \brief Unreachable-terminating branch not-taken weight. -/// -/// This is the weight for a branch not being taken toward a block that -/// terminates (eventually) in unreachable. Such a branch is essentially never -/// taken. Set the weight to an absurdly high value so that nested loops don't -/// easily subsume it. -static const uint32_t UR_NONTAKEN_WEIGHT = 1024*1024 - 1; +/// All reachable probability will equally share the remaining part. +static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1); /// \brief Weight for a branch taken going into a cold block. /// @@ -108,11 +101,9 @@ static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1; /// instruction. This is essentially never taken. static const uint32_t IH_NONTAKEN_WEIGHT = 1; -/// \brief Calculate edge weights for successors lead to unreachable. -/// -/// Predict that a successor which leads necessarily to an -/// unreachable-terminated block as extremely unlikely. 
-bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { +/// \brief Add \p BB to PostDominatedByUnreachable set if applicable. +void +BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) { const TerminatorInst *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 0) { if (isa<UnreachableInst>(TI) || @@ -122,39 +113,85 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { // never execute. BB->getTerminatingDeoptimizeCall()) PostDominatedByUnreachable.insert(BB); - return false; + return; + } + + // If the terminator is an InvokeInst, check only the normal destination block + // as the unwind edge of InvokeInst is also very unlikely taken. + if (auto *II = dyn_cast<InvokeInst>(TI)) { + if (PostDominatedByUnreachable.count(II->getNormalDest())) + PostDominatedByUnreachable.insert(BB); + return; } + for (auto *I : successors(BB)) + // If any of successor is not post dominated then BB is also not. + if (!PostDominatedByUnreachable.count(I)) + return; + + PostDominatedByUnreachable.insert(BB); +} + +/// \brief Add \p BB to PostDominatedByColdCall set if applicable. +void +BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) { + assert(!PostDominatedByColdCall.count(BB)); + const TerminatorInst *TI = BB->getTerminator(); + if (TI->getNumSuccessors() == 0) + return; + + // If all of successor are post dominated then BB is also done. + if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) { + return PostDominatedByColdCall.count(SuccBB); + })) { + PostDominatedByColdCall.insert(BB); + return; + } + + // If the terminator is an InvokeInst, check only the normal destination + // block as the unwind edge of InvokeInst is also very unlikely taken. + if (auto *II = dyn_cast<InvokeInst>(TI)) + if (PostDominatedByColdCall.count(II->getNormalDest())) { + PostDominatedByColdCall.insert(BB); + return; + } + + // Otherwise, if the block itself contains a cold function, add it to the + // set of blocks post-dominated by a cold call. + for (auto &I : *BB) + if (const CallInst *CI = dyn_cast<CallInst>(&I)) + if (CI->hasFnAttr(Attribute::Cold)) { + PostDominatedByColdCall.insert(BB); + return; + } +} + +/// \brief Calculate edge weights for successors lead to unreachable. +/// +/// Predict that a successor which leads necessarily to an +/// unreachable-terminated block as extremely unlikely. +bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { + const TerminatorInst *TI = BB->getTerminator(); + assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); + + // Return false here so that edge weights for InvokeInst could be decided + // in calcInvokeHeuristics(). + if (isa<InvokeInst>(TI)) + return false; + SmallVector<unsigned, 4> UnreachableEdges; SmallVector<unsigned, 4> ReachableEdges; - for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) if (PostDominatedByUnreachable.count(*I)) UnreachableEdges.push_back(I.getSuccessorIndex()); else ReachableEdges.push_back(I.getSuccessorIndex()); - } - // If all successors are in the set of blocks post-dominated by unreachable, - // this block is too. - if (UnreachableEdges.size() == TI->getNumSuccessors()) - PostDominatedByUnreachable.insert(BB); - - // Skip probabilities if this block has a single successor or if all were - // reachable. 
- if (TI->getNumSuccessors() == 1 || UnreachableEdges.empty()) + // Skip probabilities if all were reachable. + if (UnreachableEdges.empty()) return false; - // If the terminator is an InvokeInst, check only the normal destination block - // as the unwind edge of InvokeInst is also very unlikely taken. - if (auto *II = dyn_cast<InvokeInst>(TI)) - if (PostDominatedByUnreachable.count(II->getNormalDest())) { - PostDominatedByUnreachable.insert(BB); - // Return false here so that edge weights for InvokeInst could be decided - // in calcInvokeHeuristics(). - return false; - } - if (ReachableEdges.empty()) { BranchProbability Prob(1, UnreachableEdges.size()); for (unsigned SuccIdx : UnreachableEdges) @@ -162,12 +199,10 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { return true; } - auto UnreachableProb = BranchProbability::getBranchProbability( - UR_TAKEN_WEIGHT, (UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) * - uint64_t(UnreachableEdges.size())); - auto ReachableProb = BranchProbability::getBranchProbability( - UR_NONTAKEN_WEIGHT, - (UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) * uint64_t(ReachableEdges.size())); + auto UnreachableProb = UR_TAKEN_PROB; + auto ReachableProb = + (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) / + ReachableEdges.size(); for (unsigned SuccIdx : UnreachableEdges) setEdgeProbability(BB, SuccIdx, UnreachableProb); @@ -178,11 +213,12 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { } // Propagate existing explicit probabilities from either profile data or -// 'expect' intrinsic processing. +// 'expect' intrinsic processing. Examine metadata against unreachable +// heuristic. The probability of the edge coming to unreachable block is +// set to min of metadata and unreachable heuristic. bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { const TerminatorInst *TI = BB->getTerminator(); - if (TI->getNumSuccessors() == 1) - return false; + assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI)) return false; @@ -203,6 +239,8 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { // be scaled to fit in 32 bits. uint64_t WeightSum = 0; SmallVector<uint32_t, 2> Weights; + SmallVector<unsigned, 2> UnreachableIdxs; + SmallVector<unsigned, 2> ReachableIdxs; Weights.reserve(TI->getNumSuccessors()); for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) { ConstantInt *Weight = @@ -213,6 +251,10 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { "Too many bits for uint32_t"); Weights.push_back(Weight->getZExtValue()); WeightSum += Weights.back(); + if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1))) + UnreachableIdxs.push_back(i - 1); + else + ReachableIdxs.push_back(i - 1); } assert(Weights.size() == TI->getNumSuccessors() && "Checked above"); @@ -221,22 +263,49 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { uint64_t ScalingFactor = (WeightSum > UINT32_MAX) ? 
WeightSum / UINT32_MAX + 1 : 1; - WeightSum = 0; - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { - Weights[i] /= ScalingFactor; - WeightSum += Weights[i]; + if (ScalingFactor > 1) { + WeightSum = 0; + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + Weights[i] /= ScalingFactor; + WeightSum += Weights[i]; + } } + assert(WeightSum <= UINT32_MAX && + "Expected weights to scale down to 32 bits"); - if (WeightSum == 0) { - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - setEdgeProbability(BB, i, {1, e}); - } else { + if (WeightSum == 0 || ReachableIdxs.size() == 0) { for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - setEdgeProbability(BB, i, {Weights[i], static_cast<uint32_t>(WeightSum)}); + Weights[i] = 1; + WeightSum = TI->getNumSuccessors(); } - assert(WeightSum <= UINT32_MAX && - "Expected weights to scale down to 32 bits"); + // Set the probability. + SmallVector<BranchProbability, 2> BP; + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) }); + + // Examine the metadata against unreachable heuristic. + // If the unreachable heuristic is more strong then we use it for this edge. + if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) { + auto ToDistribute = BranchProbability::getZero(); + auto UnreachableProb = UR_TAKEN_PROB; + for (auto i : UnreachableIdxs) + if (UnreachableProb < BP[i]) { + ToDistribute += BP[i] - UnreachableProb; + BP[i] = UnreachableProb; + } + + // If we modified the probability of some edges then we must distribute + // the difference between reachable blocks. + if (ToDistribute > BranchProbability::getZero()) { + BranchProbability PerEdge = ToDistribute / ReachableIdxs.size(); + for (auto i : ReachableIdxs) + BP[i] += PerEdge; + } + } + + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + setEdgeProbability(BB, i, BP[i]); return true; } @@ -251,7 +320,11 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { /// Return false, otherwise. bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) { const TerminatorInst *TI = BB->getTerminator(); - if (TI->getNumSuccessors() == 0) + assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); + + // Return false here so that edge weights for InvokeInst could be decided + // in calcInvokeHeuristics(). + if (isa<InvokeInst>(TI)) return false; // Determine which successors are post-dominated by a cold block. @@ -263,34 +336,8 @@ bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) { else NormalEdges.push_back(I.getSuccessorIndex()); - // If all successors are in the set of blocks post-dominated by cold calls, - // this block is in the set post-dominated by cold calls. - if (ColdEdges.size() == TI->getNumSuccessors()) - PostDominatedByColdCall.insert(BB); - else { - // Otherwise, if the block itself contains a cold function, add it to the - // set of blocks postdominated by a cold call. - assert(!PostDominatedByColdCall.count(BB)); - for (BasicBlock::const_iterator I = BB->begin(), E = BB->end(); I != E; ++I) - if (const CallInst *CI = dyn_cast<CallInst>(I)) - if (CI->hasFnAttr(Attribute::Cold)) { - PostDominatedByColdCall.insert(BB); - break; - } - } - - if (auto *II = dyn_cast<InvokeInst>(TI)) { - // If the terminator is an InvokeInst, consider only the normal destination - // block. 
- if (PostDominatedByColdCall.count(II->getNormalDest())) - PostDominatedByColdCall.insert(BB); - // Return false here so that edge weights for InvokeInst could be decided - // in calcInvokeHeuristics(). - return false; - } - - // Skip probabilities if this block has a single successor. - if (TI->getNumSuccessors() == 1 || ColdEdges.empty()) + // Skip probabilities if no cold edges. + if (ColdEdges.empty()) return false; if (NormalEdges.empty()) { @@ -671,10 +718,15 @@ void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI) { // the successors of a block iteratively. for (auto BB : post_order(&F.getEntryBlock())) { DEBUG(dbgs() << "Computing probabilities for " << BB->getName() << "\n"); - if (calcUnreachableHeuristics(BB)) + updatePostDominatedByUnreachable(BB); + updatePostDominatedByColdCall(BB); + // If there is no at least two successors, no sense to set probability. + if (BB->getTerminator()->getNumSuccessors() < 2) continue; if (calcMetadataWeights(BB)) continue; + if (calcUnreachableHeuristics(BB)) + continue; if (calcColdCallHeuristics(BB)) continue; if (calcLoopBranchHeuristics(BB, LI)) diff --git a/contrib/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp b/contrib/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp index e48ff230f43c..ddd5123d0eff 100644 --- a/contrib/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp +++ b/contrib/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp @@ -307,7 +307,7 @@ class CFLAndersAAResult::FunctionInfo { public: FunctionInfo(const Function &, const SmallVectorImpl<Value *> &, - const ReachabilitySet &, AliasAttrMap); + const ReachabilitySet &, const AliasAttrMap &); bool mayAlias(const Value *, uint64_t, const Value *, uint64_t) const; const AliasSummary &getAliasSummary() const { return Summary; } @@ -470,7 +470,7 @@ static void populateExternalAttributes( CFLAndersAAResult::FunctionInfo::FunctionInfo( const Function &Fn, const SmallVectorImpl<Value *> &RetVals, - const ReachabilitySet &ReachSet, AliasAttrMap AMap) { + const ReachabilitySet &ReachSet, const AliasAttrMap &AMap) { populateAttrMap(AttrMap, AMap); populateExternalAttributes(Summary.RetParamAttributes, Fn, RetVals, AMap); populateAliasMap(AliasMap, ReachSet); diff --git a/contrib/llvm/lib/Analysis/CFLGraph.h b/contrib/llvm/lib/Analysis/CFLGraph.h index e526e0e16aa7..54782b6bd4ad 100644 --- a/contrib/llvm/lib/Analysis/CFLGraph.h +++ b/contrib/llvm/lib/Analysis/CFLGraph.h @@ -210,6 +210,11 @@ template <typename CFLAA> class CFLGraphBuilder { void addDerefEdge(Value *From, Value *To, bool IsRead) { assert(From != nullptr && To != nullptr); + // FIXME: This is subtly broken, due to how we model some instructions + // (e.g. extractvalue, extractelement) as loads. Since those take + // non-pointer operands, we'll entirely skip adding edges for those. + // + // addAssignEdge seems to have a similar issue with insertvalue, etc. if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy()) return; addNode(From); @@ -400,8 +405,7 @@ template <typename CFLAA> class CFLGraphBuilder { // TODO: address other common library functions such as realloc(), // strdup(), // etc. 
- if (isMallocLikeFn(Inst, &TLI) || isCallocLikeFn(Inst, &TLI) || - isFreeCall(Inst, &TLI)) + if (isMallocOrCallocLikeFn(Inst, &TLI) || isFreeCall(Inst, &TLI)) return; // TODO: Add support for noalias args/all the other fun function @@ -430,7 +434,7 @@ template <typename CFLAA> class CFLGraphBuilder { if (Inst->getType()->isPointerTy()) { auto *Fn = CS.getCalledFunction(); - if (Fn == nullptr || !Fn->doesNotAlias(0)) + if (Fn == nullptr || !Fn->returnDoesNotAlias()) // No need to call addNode() since we've added Inst at the // beginning of this function and we know it is not a global. Graph.addAttr(InstantiatedValue{Inst, 0}, getAttrUnknown()); @@ -541,6 +545,7 @@ template <typename CFLAA> class CFLGraphBuilder { case Instruction::ExtractValue: { auto *Ptr = CE->getOperand(0); addLoadEdge(Ptr, CE); + break; } case Instruction::ShuffleVector: { auto *From1 = CE->getOperand(0); diff --git a/contrib/llvm/lib/Analysis/CGSCCPassManager.cpp b/contrib/llvm/lib/Analysis/CGSCCPassManager.cpp index 054bdc45ad67..9d4521221f47 100644 --- a/contrib/llvm/lib/Analysis/CGSCCPassManager.cpp +++ b/contrib/llvm/lib/Analysis/CGSCCPassManager.cpp @@ -117,6 +117,7 @@ bool CGSCCAnalysisManagerModuleProxy::Result::invalidate( PA.allAnalysesInSetPreserved<AllAnalysesOn<LazyCallGraph::SCC>>(); // Ok, we have a graph, so we can propagate the invalidation down into it. + G->buildRefSCCs(); for (auto &RC : G->postorder_ref_sccs()) for (auto &C : RC) { Optional<PreservedAnalyses> InnerPA; @@ -273,9 +274,9 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( // demoted edges. SmallVector<Constant *, 16> Worklist; SmallPtrSet<Constant *, 16> Visited; - SmallPtrSet<Function *, 16> RetainedEdges; - SmallSetVector<Function *, 4> PromotedRefTargets; - SmallSetVector<Function *, 4> DemotedCallTargets; + SmallPtrSet<Node *, 16> RetainedEdges; + SmallSetVector<Node *, 4> PromotedRefTargets; + SmallSetVector<Node *, 4> DemotedCallTargets; // First walk the function and handle all called functions. We do this first // because if there is a single call edge, whether there are ref edges is @@ -284,7 +285,8 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( if (auto CS = CallSite(&I)) if (Function *Callee = CS.getCalledFunction()) if (Visited.insert(Callee).second && !Callee->isDeclaration()) { - const Edge *E = N.lookup(*Callee); + Node &CalleeN = *G.lookup(*Callee); + Edge *E = N->lookup(CalleeN); // FIXME: We should really handle adding new calls. While it will // make downstream usage more complex, there is no fundamental // limitation and it will allow passes within the CGSCC to be a bit @@ -293,9 +295,9 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( assert(E && "No function transformations should introduce *new* " "call edges! Any new calls should be modeled as " "promoted existing ref edges!"); - RetainedEdges.insert(Callee); + RetainedEdges.insert(&CalleeN); if (!E->isCall()) - PromotedRefTargets.insert(Callee); + PromotedRefTargets.insert(&CalleeN); } // Now walk all references. @@ -306,24 +308,25 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( Worklist.push_back(C); LazyCallGraph::visitReferences(Worklist, Visited, [&](Function &Referee) { - const Edge *E = N.lookup(Referee); + Node &RefereeN = *G.lookup(Referee); + Edge *E = N->lookup(RefereeN); // FIXME: Similarly to new calls, we also currently preclude // introducing new references. See above for details. assert(E && "No function transformations should introduce *new* ref " "edges! 
Any new ref edges would require IPO which " "function passes aren't allowed to do!"); - RetainedEdges.insert(&Referee); + RetainedEdges.insert(&RefereeN); if (E->isCall()) - DemotedCallTargets.insert(&Referee); + DemotedCallTargets.insert(&RefereeN); }); // First remove all of the edges that are no longer present in this function. // We have to build a list of dead targets first and then remove them as the // data structures will all be invalidated by removing them. SmallVector<PointerIntPair<Node *, 1, Edge::Kind>, 4> DeadTargets; - for (Edge &E : N) - if (!RetainedEdges.count(&E.getFunction())) - DeadTargets.push_back({E.getNode(), E.getKind()}); + for (Edge &E : *N) + if (!RetainedEdges.count(&E.getNode())) + DeadTargets.push_back({&E.getNode(), E.getKind()}); for (auto DeadTarget : DeadTargets) { Node &TargetN = *DeadTarget.getPointer(); bool IsCall = DeadTarget.getInt() == Edge::Call; @@ -397,9 +400,8 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( // Next demote all the call edges that are now ref edges. This helps make // the SCCs small which should minimize the work below as we don't want to // form cycles that this would break. - for (Function *RefTarget : DemotedCallTargets) { - Node &TargetN = *G.lookup(*RefTarget); - SCC &TargetC = *G.lookupSCC(TargetN); + for (Node *RefTarget : DemotedCallTargets) { + SCC &TargetC = *G.lookupSCC(*RefTarget); RefSCC &TargetRC = TargetC.getOuterRefSCC(); // The easy case is when the target RefSCC is not this RefSCC. This is @@ -407,10 +409,10 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( if (&TargetRC != RC) { assert(RC->isAncestorOf(TargetRC) && "Cannot potentially form RefSCC cycles here!"); - RC->switchOutgoingEdgeToRef(N, TargetN); + RC->switchOutgoingEdgeToRef(N, *RefTarget); if (DebugLogging) dbgs() << "Switch outgoing call edge to a ref edge from '" << N - << "' to '" << TargetN << "'\n"; + << "' to '" << *RefTarget << "'\n"; continue; } @@ -418,7 +420,7 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( // some SCCs. if (C != &TargetC) { // For separate SCCs this is trivial. - RC->switchTrivialInternalEdgeToRef(N, TargetN); + RC->switchTrivialInternalEdgeToRef(N, *RefTarget); continue; } @@ -430,14 +432,13 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( // structure is changed. AM.invalidate(*C, PreservedAnalyses::none()); // Now update the call graph. - C = incorporateNewSCCRange(RC->switchInternalEdgeToRef(N, TargetN), G, - N, C, AM, UR, DebugLogging); + C = incorporateNewSCCRange(RC->switchInternalEdgeToRef(N, *RefTarget), G, N, + C, AM, UR, DebugLogging); } // Now promote ref edges into call edges. - for (Function *CallTarget : PromotedRefTargets) { - Node &TargetN = *G.lookup(*CallTarget); - SCC &TargetC = *G.lookupSCC(TargetN); + for (Node *CallTarget : PromotedRefTargets) { + SCC &TargetC = *G.lookupSCC(*CallTarget); RefSCC &TargetRC = TargetC.getOuterRefSCC(); // The easy case is when the target RefSCC is not this RefSCC. 
This is @@ -445,22 +446,22 @@ LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( if (&TargetRC != RC) { assert(RC->isAncestorOf(TargetRC) && "Cannot potentially form RefSCC cycles here!"); - RC->switchOutgoingEdgeToCall(N, TargetN); + RC->switchOutgoingEdgeToCall(N, *CallTarget); if (DebugLogging) dbgs() << "Switch outgoing ref edge to a call edge from '" << N - << "' to '" << TargetN << "'\n"; + << "' to '" << *CallTarget << "'\n"; continue; } if (DebugLogging) dbgs() << "Switch an internal ref edge to a call edge from '" << N - << "' to '" << TargetN << "'\n"; + << "' to '" << *CallTarget << "'\n"; // Otherwise we are switching an internal ref edge to a call edge. This // may merge away some SCCs, and we add those to the UpdateResult. We also // need to make sure to update the worklist in the event SCCs have moved // before the current one in the post-order sequence. auto InitialSCCIndex = RC->find(*C) - RC->begin(); - auto InvalidatedSCCs = RC->switchInternalEdgeToCall(N, TargetN); + auto InvalidatedSCCs = RC->switchInternalEdgeToCall(N, *CallTarget); if (!InvalidatedSCCs.empty()) { C = &TargetC; assert(G.lookupSCC(N) == C && "Failed to update current SCC!"); diff --git a/contrib/llvm/lib/Analysis/CallGraph.cpp b/contrib/llvm/lib/Analysis/CallGraph.cpp index 458b7bfae959..ff5242f69a1b 100644 --- a/contrib/llvm/lib/Analysis/CallGraph.cpp +++ b/contrib/llvm/lib/Analysis/CallGraph.cpp @@ -21,23 +21,18 @@ using namespace llvm; // CallGraph::CallGraph(Module &M) - : M(M), Root(nullptr), ExternalCallingNode(getOrInsertFunction(nullptr)), + : M(M), ExternalCallingNode(getOrInsertFunction(nullptr)), CallsExternalNode(llvm::make_unique<CallGraphNode>(nullptr)) { // Add every function to the call graph. for (Function &F : M) addToCallGraph(&F); - - // If we didn't find a main function, use the external call graph node - if (!Root) - Root = ExternalCallingNode; } CallGraph::CallGraph(CallGraph &&Arg) - : M(Arg.M), FunctionMap(std::move(Arg.FunctionMap)), Root(Arg.Root), + : M(Arg.M), FunctionMap(std::move(Arg.FunctionMap)), ExternalCallingNode(Arg.ExternalCallingNode), CallsExternalNode(std::move(Arg.CallsExternalNode)) { Arg.FunctionMap.clear(); - Arg.Root = nullptr; Arg.ExternalCallingNode = nullptr; } @@ -57,21 +52,9 @@ CallGraph::~CallGraph() { void CallGraph::addToCallGraph(Function *F) { CallGraphNode *Node = getOrInsertFunction(F); - // If this function has external linkage, anything could call it. - if (!F->hasLocalLinkage()) { - ExternalCallingNode->addCalledFunction(CallSite(), Node); - - // Found the entry point? - if (F->getName() == "main") { - if (Root) // Found multiple external mains? Don't pick one. - Root = ExternalCallingNode; - else - Root = Node; // Found a main, keep track of it! - } - } - - // If this function has its address taken, anything could call it. - if (F->hasAddressTaken()) + // If this function has external linkage or has its address taken, anything + // could call it. + if (!F->hasLocalLinkage() || F->hasAddressTaken()) ExternalCallingNode->addCalledFunction(CallSite(), Node); // If this function is not defined in this translation unit, it could call @@ -96,13 +79,6 @@ void CallGraph::addToCallGraph(Function *F) { } void CallGraph::print(raw_ostream &OS) const { - OS << "CallGraph Root is: "; - if (Function *F = Root->getFunction()) - OS << F->getName() << "\n"; - else { - OS << "<<null function: 0x" << Root << ">>\n"; - } - // Print in a deterministic order by sorting CallGraphNodes by name. 
We do // this here to avoid slowing down the non-printing fast path. @@ -125,8 +101,9 @@ void CallGraph::print(raw_ostream &OS) const { CN->print(OS); } -LLVM_DUMP_METHOD -void CallGraph::dump() const { print(dbgs()); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void CallGraph::dump() const { print(dbgs()); } +#endif // removeFunctionFromModule - Unlink the function from this module, returning // it. Because this removes the function from the module, the call graph node @@ -194,8 +171,9 @@ void CallGraphNode::print(raw_ostream &OS) const { OS << '\n'; } -LLVM_DUMP_METHOD -void CallGraphNode::dump() const { print(dbgs()); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void CallGraphNode::dump() const { print(dbgs()); } +#endif /// removeCallEdgeFor - This method removes the edge in the node for the /// specified call site. Note that this method takes linear time, so it @@ -307,8 +285,10 @@ void CallGraphWrapperPass::print(raw_ostream &OS, const Module *) const { G->print(OS); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void CallGraphWrapperPass::dump() const { print(dbgs(), nullptr); } +#endif namespace { struct CallGraphPrinterLegacyPass : public ModulePass { diff --git a/contrib/llvm/lib/Analysis/CallGraphSCCPass.cpp b/contrib/llvm/lib/Analysis/CallGraphSCCPass.cpp index 9cef78144150..5896e6e0902f 100644 --- a/contrib/llvm/lib/Analysis/CallGraphSCCPass.cpp +++ b/contrib/llvm/lib/Analysis/CallGraphSCCPass.cpp @@ -204,7 +204,7 @@ bool CGPassManager::RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG, // Get the set of call sites currently in the function. for (CallGraphNode::iterator I = CGN->begin(), E = CGN->end(); I != E; ) { // If this call site is null, then the function pass deleted the call - // entirely and the WeakVH nulled it out. + // entirely and the WeakTrackingVH nulled it out. if (!I->first || // If we've already seen this call site, then the FunctionPass RAUW'd // one call with another, which resulted in two "uses" in the edge @@ -347,7 +347,8 @@ bool CGPassManager::RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG, DevirtualizedCall = true; // After scanning this function, if we still have entries in callsites, then - // they are dangling pointers. WeakVH should save us for this, so abort if + // they are dangling pointers. WeakTrackingVH should save us for this, so + // abort if // this happens. 
assert(CallSites.empty() && "Dangling pointers found in call sites map"); @@ -476,10 +477,8 @@ bool CGPassManager::runOnModule(Module &M) { if (DevirtualizedCall) DEBUG(dbgs() << " CGSCCPASSMGR: Stopped iteration after " << Iteration << " times, due to -max-cg-scc-iterations\n"); - - if (Iteration > MaxSCCIterations) - MaxSCCIterations = Iteration; - + + MaxSCCIterations.updateMax(Iteration); } Changed |= doFinalization(CG); return Changed; @@ -609,16 +608,28 @@ namespace { } bool runOnSCC(CallGraphSCC &SCC) override { - Out << Banner; + auto PrintBannerOnce = [&] () { + static bool BannerPrinted = false; + if (BannerPrinted) + return; + Out << Banner; + BannerPrinted = true; + }; for (CallGraphNode *CGN : SCC) { if (CGN->getFunction()) { - if (isFunctionInPrintList(CGN->getFunction()->getName())) + if (isFunctionInPrintList(CGN->getFunction()->getName())) { + PrintBannerOnce(); CGN->getFunction()->print(Out); - } else + } + } else if (llvm::isFunctionInPrintList("*")) { + PrintBannerOnce(); Out << "\nPrinting <null> Function\n"; + } } return false; } + + StringRef getPassName() const override { return "Print CallGraph IR"; } }; } // end anonymous namespace. diff --git a/contrib/llvm/lib/Analysis/ConstantFolding.cpp b/contrib/llvm/lib/Analysis/ConstantFolding.cpp index 73867279abe4..a906770dbb34 100644 --- a/contrib/llvm/lib/Analysis/ConstantFolding.cpp +++ b/contrib/llvm/lib/Analysis/ConstantFolding.cpp @@ -42,6 +42,7 @@ #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include <cassert> #include <cerrno> @@ -686,25 +687,21 @@ Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, Constant *Op1, // bits. if (Opc == Instruction::And) { - unsigned BitWidth = DL.getTypeSizeInBits(Op0->getType()->getScalarType()); - APInt KnownZero0(BitWidth, 0), KnownOne0(BitWidth, 0); - APInt KnownZero1(BitWidth, 0), KnownOne1(BitWidth, 0); - computeKnownBits(Op0, KnownZero0, KnownOne0, DL); - computeKnownBits(Op1, KnownZero1, KnownOne1, DL); - if ((KnownOne1 | KnownZero0).isAllOnesValue()) { + KnownBits Known0 = computeKnownBits(Op0, DL); + KnownBits Known1 = computeKnownBits(Op1, DL); + if ((Known1.One | Known0.Zero).isAllOnesValue()) { // All the bits of Op0 that the 'and' could be masking are already zero. return Op0; } - if ((KnownOne0 | KnownZero1).isAllOnesValue()) { + if ((Known0.One | Known1.Zero).isAllOnesValue()) { // All the bits of Op1 that the 'and' could be masking are already zero. 
return Op1; } - APInt KnownZero = KnownZero0 | KnownZero1; - APInt KnownOne = KnownOne0 & KnownOne1; - if ((KnownZero | KnownOne).isAllOnesValue()) { - return ConstantInt::get(Op0->getType(), KnownOne); - } + Known0.Zero |= Known1.Zero; + Known0.One &= Known1.One; + if (Known0.isConstant()) + return ConstantInt::get(Op0->getType(), Known0.getConstant()); } // If the constant expr is something like &A[123] - &A[4].f, fold this into a @@ -1058,8 +1055,8 @@ ConstantFoldConstantImpl(const Constant *C, const DataLayout &DL, if (It == FoldedOps.end()) { if (auto *FoldedC = ConstantFoldConstantImpl(NewC, DL, TLI, FoldedOps)) { - NewC = FoldedC; FoldedOps.insert({NewC, FoldedC}); + NewC = FoldedC; } else { FoldedOps.insert({NewC, NewC}); } @@ -1173,7 +1170,9 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate, const DataLayout &DL, const TargetLibraryInfo *TLI) { // fold: icmp (inttoptr x), null -> icmp x, 0 + // fold: icmp null, (inttoptr x) -> icmp 0, x // fold: icmp (ptrtoint x), 0 -> icmp x, null + // fold: icmp 0, (ptrtoint x) -> icmp null, x // fold: icmp (inttoptr x), (inttoptr y) -> icmp trunc/zext x, trunc/zext y // fold: icmp (ptrtoint x), (ptrtoint y) -> icmp x, y // @@ -1243,6 +1242,11 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate, Predicate == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; return ConstantFoldBinaryOpOperands(OpC, LHS, RHS, DL); } + } else if (isa<ConstantExpr>(Ops1)) { + // If RHS is a constant expression, but the left side isn't, swap the + // operands and try again. + Predicate = ICmpInst::getSwappedPredicate((ICmpInst::Predicate)Predicate); + return ConstantFoldCompareInstOperands(Predicate, Ops1, Ops0, DL, TLI); } return ConstantExpr::getCompare(Predicate, Ops0, Ops1); @@ -1401,7 +1405,7 @@ bool llvm::canConstantFoldCallTo(const Function *F) { return true; default: return false; - case 0: break; + case Intrinsic::not_intrinsic: break; } if (!F->hasName()) @@ -1438,6 +1442,36 @@ bool llvm::canConstantFoldCallTo(const Function *F) { Name == "sinf" || Name == "sinhf" || Name == "sqrtf"; case 't': return Name == "tan" || Name == "tanh" || Name == "tanf" || Name == "tanhf"; + case '_': + + // Check for various function names that get used for the math functions + // when the header files are preprocessed with the macro + // __FINITE_MATH_ONLY__ enabled. + // The '12' here is the length of the shortest name that can match. + // We need to check the size before looking at Name[1] and Name[2] + // so we may as well check a limit that will eliminate mismatches. + if (Name.size() < 12 || Name[1] != '_') + return false; + switch (Name[2]) { + default: + return false; + case 'a': + return Name == "__acos_finite" || Name == "__acosf_finite" || + Name == "__asin_finite" || Name == "__asinf_finite" || + Name == "__atan2_finite" || Name == "__atan2f_finite"; + case 'c': + return Name == "__cosh_finite" || Name == "__coshf_finite"; + case 'e': + return Name == "__exp_finite" || Name == "__expf_finite" || + Name == "__exp2_finite" || Name == "__exp2f_finite"; + case 'l': + return Name == "__log_finite" || Name == "__logf_finite" || + Name == "__log10_finite" || Name == "__log10f_finite"; + case 'p': + return Name == "__pow_finite" || Name == "__powf_finite"; + case 's': + return Name == "__sinh_finite" || Name == "__sinhf_finite"; + } } } @@ -1518,9 +1552,9 @@ Constant *ConstantFoldSSEConvertToInt(const APFloat &Val, bool roundTowardZero, bool isExact = false; APFloat::roundingMode mode = roundTowardZero? 
APFloat::rmTowardZero : APFloat::rmNearestTiesToEven; - APFloat::opStatus status = Val.convertToInteger(&UIntVal, ResultWidth, - /*isSigned=*/true, mode, - &isExact); + APFloat::opStatus status = + Val.convertToInteger(makeMutableArrayRef(UIntVal), ResultWidth, + /*isSigned=*/true, mode, &isExact); if (status != APFloat::opOK && (!roundTowardZero || status != APFloat::opInexact)) return nullptr; @@ -1630,94 +1664,108 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, return ConstantFoldFP(sin, V, Ty); case Intrinsic::cos: return ConstantFoldFP(cos, V, Ty); + case Intrinsic::sqrt: + return ConstantFoldFP(sqrt, V, Ty); } if (!TLI) return nullptr; - switch (Name[0]) { + char NameKeyChar = Name[0]; + if (Name[0] == '_' && Name.size() > 2 && Name[1] == '_') + NameKeyChar = Name[2]; + + switch (NameKeyChar) { case 'a': - if ((Name == "acos" && TLI->has(LibFunc::acos)) || - (Name == "acosf" && TLI->has(LibFunc::acosf))) + if ((Name == "acos" && TLI->has(LibFunc_acos)) || + (Name == "acosf" && TLI->has(LibFunc_acosf)) || + (Name == "__acos_finite" && TLI->has(LibFunc_acos_finite)) || + (Name == "__acosf_finite" && TLI->has(LibFunc_acosf_finite))) return ConstantFoldFP(acos, V, Ty); - else if ((Name == "asin" && TLI->has(LibFunc::asin)) || - (Name == "asinf" && TLI->has(LibFunc::asinf))) + else if ((Name == "asin" && TLI->has(LibFunc_asin)) || + (Name == "asinf" && TLI->has(LibFunc_asinf)) || + (Name == "__asin_finite" && TLI->has(LibFunc_asin_finite)) || + (Name == "__asinf_finite" && TLI->has(LibFunc_asinf_finite))) return ConstantFoldFP(asin, V, Ty); - else if ((Name == "atan" && TLI->has(LibFunc::atan)) || - (Name == "atanf" && TLI->has(LibFunc::atanf))) + else if ((Name == "atan" && TLI->has(LibFunc_atan)) || + (Name == "atanf" && TLI->has(LibFunc_atanf))) return ConstantFoldFP(atan, V, Ty); break; case 'c': - if ((Name == "ceil" && TLI->has(LibFunc::ceil)) || - (Name == "ceilf" && TLI->has(LibFunc::ceilf))) + if ((Name == "ceil" && TLI->has(LibFunc_ceil)) || + (Name == "ceilf" && TLI->has(LibFunc_ceilf))) return ConstantFoldFP(ceil, V, Ty); - else if ((Name == "cos" && TLI->has(LibFunc::cos)) || - (Name == "cosf" && TLI->has(LibFunc::cosf))) + else if ((Name == "cos" && TLI->has(LibFunc_cos)) || + (Name == "cosf" && TLI->has(LibFunc_cosf))) return ConstantFoldFP(cos, V, Ty); - else if ((Name == "cosh" && TLI->has(LibFunc::cosh)) || - (Name == "coshf" && TLI->has(LibFunc::coshf))) + else if ((Name == "cosh" && TLI->has(LibFunc_cosh)) || + (Name == "coshf" && TLI->has(LibFunc_coshf)) || + (Name == "__cosh_finite" && TLI->has(LibFunc_cosh_finite)) || + (Name == "__coshf_finite" && TLI->has(LibFunc_coshf_finite))) return ConstantFoldFP(cosh, V, Ty); break; case 'e': - if ((Name == "exp" && TLI->has(LibFunc::exp)) || - (Name == "expf" && TLI->has(LibFunc::expf))) + if ((Name == "exp" && TLI->has(LibFunc_exp)) || + (Name == "expf" && TLI->has(LibFunc_expf)) || + (Name == "__exp_finite" && TLI->has(LibFunc_exp_finite)) || + (Name == "__expf_finite" && TLI->has(LibFunc_expf_finite))) return ConstantFoldFP(exp, V, Ty); - if ((Name == "exp2" && TLI->has(LibFunc::exp2)) || - (Name == "exp2f" && TLI->has(LibFunc::exp2f))) + if ((Name == "exp2" && TLI->has(LibFunc_exp2)) || + (Name == "exp2f" && TLI->has(LibFunc_exp2f)) || + (Name == "__exp2_finite" && TLI->has(LibFunc_exp2_finite)) || + (Name == "__exp2f_finite" && TLI->has(LibFunc_exp2f_finite))) // Constant fold exp2(x) as pow(2,x) in case the host doesn't have a // C99 library. 
return ConstantFoldBinaryFP(pow, 2.0, V, Ty); break; case 'f': - if ((Name == "fabs" && TLI->has(LibFunc::fabs)) || - (Name == "fabsf" && TLI->has(LibFunc::fabsf))) + if ((Name == "fabs" && TLI->has(LibFunc_fabs)) || + (Name == "fabsf" && TLI->has(LibFunc_fabsf))) return ConstantFoldFP(fabs, V, Ty); - else if ((Name == "floor" && TLI->has(LibFunc::floor)) || - (Name == "floorf" && TLI->has(LibFunc::floorf))) + else if ((Name == "floor" && TLI->has(LibFunc_floor)) || + (Name == "floorf" && TLI->has(LibFunc_floorf))) return ConstantFoldFP(floor, V, Ty); break; case 'l': - if ((Name == "log" && V > 0 && TLI->has(LibFunc::log)) || - (Name == "logf" && V > 0 && TLI->has(LibFunc::logf))) + if ((Name == "log" && V > 0 && TLI->has(LibFunc_log)) || + (Name == "logf" && V > 0 && TLI->has(LibFunc_logf)) || + (Name == "__log_finite" && V > 0 && + TLI->has(LibFunc_log_finite)) || + (Name == "__logf_finite" && V > 0 && + TLI->has(LibFunc_logf_finite))) return ConstantFoldFP(log, V, Ty); - else if ((Name == "log10" && V > 0 && TLI->has(LibFunc::log10)) || - (Name == "log10f" && V > 0 && TLI->has(LibFunc::log10f))) + else if ((Name == "log10" && V > 0 && TLI->has(LibFunc_log10)) || + (Name == "log10f" && V > 0 && TLI->has(LibFunc_log10f)) || + (Name == "__log10_finite" && V > 0 && + TLI->has(LibFunc_log10_finite)) || + (Name == "__log10f_finite" && V > 0 && + TLI->has(LibFunc_log10f_finite))) return ConstantFoldFP(log10, V, Ty); - else if (IntrinsicID == Intrinsic::sqrt && - (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())) { - if (V >= -0.0) - return ConstantFoldFP(sqrt, V, Ty); - else { - // Unlike the sqrt definitions in C/C++, POSIX, and IEEE-754 - which - // all guarantee or favor returning NaN - the square root of a - // negative number is not defined for the LLVM sqrt intrinsic. - // This is because the intrinsic should only be emitted in place of - // libm's sqrt function when using "no-nans-fp-math". 
- return UndefValue::get(Ty); - } - } break; case 'r': - if ((Name == "round" && TLI->has(LibFunc::round)) || - (Name == "roundf" && TLI->has(LibFunc::roundf))) + if ((Name == "round" && TLI->has(LibFunc_round)) || + (Name == "roundf" && TLI->has(LibFunc_roundf))) return ConstantFoldFP(round, V, Ty); + break; case 's': - if ((Name == "sin" && TLI->has(LibFunc::sin)) || - (Name == "sinf" && TLI->has(LibFunc::sinf))) + if ((Name == "sin" && TLI->has(LibFunc_sin)) || + (Name == "sinf" && TLI->has(LibFunc_sinf))) return ConstantFoldFP(sin, V, Ty); - else if ((Name == "sinh" && TLI->has(LibFunc::sinh)) || - (Name == "sinhf" && TLI->has(LibFunc::sinhf))) + else if ((Name == "sinh" && TLI->has(LibFunc_sinh)) || + (Name == "sinhf" && TLI->has(LibFunc_sinhf)) || + (Name == "__sinh_finite" && TLI->has(LibFunc_sinh_finite)) || + (Name == "__sinhf_finite" && TLI->has(LibFunc_sinhf_finite))) return ConstantFoldFP(sinh, V, Ty); - else if ((Name == "sqrt" && V >= 0 && TLI->has(LibFunc::sqrt)) || - (Name == "sqrtf" && V >= 0 && TLI->has(LibFunc::sqrtf))) + else if ((Name == "sqrt" && V >= 0 && TLI->has(LibFunc_sqrt)) || + (Name == "sqrtf" && V >= 0 && TLI->has(LibFunc_sqrtf))) return ConstantFoldFP(sqrt, V, Ty); break; case 't': - if ((Name == "tan" && TLI->has(LibFunc::tan)) || - (Name == "tanf" && TLI->has(LibFunc::tanf))) + if ((Name == "tan" && TLI->has(LibFunc_tan)) || + (Name == "tanf" && TLI->has(LibFunc_tanf))) return ConstantFoldFP(tan, V, Ty); - else if ((Name == "tanh" && TLI->has(LibFunc::tanh)) || - (Name == "tanhf" && TLI->has(LibFunc::tanhf))) + else if ((Name == "tanh" && TLI->has(LibFunc_tanh)) || + (Name == "tanhf" && TLI->has(LibFunc_tanhf))) return ConstantFoldFP(tanh, V, Ty); break; default: @@ -1767,6 +1815,7 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), /*roundTowardZero=*/false, Ty); + LLVM_FALLTHROUGH; case Intrinsic::x86_sse_cvttss2si: case Intrinsic::x86_sse_cvttss2si64: case Intrinsic::x86_sse2_cvttsd2si: @@ -1779,7 +1828,8 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, } if (isa<UndefValue>(Operands[0])) { - if (IntrinsicID == Intrinsic::bswap) + if (IntrinsicID == Intrinsic::bswap || + IntrinsicID == Intrinsic::bitreverse) return Operands[0]; return nullptr; } @@ -1822,14 +1872,18 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, if (!TLI) return nullptr; - if ((Name == "pow" && TLI->has(LibFunc::pow)) || - (Name == "powf" && TLI->has(LibFunc::powf))) + if ((Name == "pow" && TLI->has(LibFunc_pow)) || + (Name == "powf" && TLI->has(LibFunc_powf)) || + (Name == "__pow_finite" && TLI->has(LibFunc_pow_finite)) || + (Name == "__powf_finite" && TLI->has(LibFunc_powf_finite))) return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty); - if ((Name == "fmod" && TLI->has(LibFunc::fmod)) || - (Name == "fmodf" && TLI->has(LibFunc::fmodf))) + if ((Name == "fmod" && TLI->has(LibFunc_fmod)) || + (Name == "fmodf" && TLI->has(LibFunc_fmodf))) return ConstantFoldBinaryFP(fmod, Op1V, Op2V, Ty); - if ((Name == "atan2" && TLI->has(LibFunc::atan2)) || - (Name == "atan2f" && TLI->has(LibFunc::atan2f))) + if ((Name == "atan2" && TLI->has(LibFunc_atan2)) || + (Name == "atan2f" && TLI->has(LibFunc_atan2f)) || + (Name == "__atan2_finite" && TLI->has(LibFunc_atan2_finite)) || + (Name == "__atan2f_finite" && TLI->has(LibFunc_atan2f_finite))) return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty); } 
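Both the one-argument ConstantFoldFP calls and the two-argument ConstantFoldBinaryFP calls above evaluate the host's libm at compile time; those helpers are defined elsewhere in ConstantFolding.cpp and are not part of this hunk. A minimal, self-contained sketch of that evaluation pattern (the function name, the exact set of checked FP exceptions, and the error handling here are illustrative assumptions, not the in-tree helpers):

#include <cerrno>
#include <cfenv>
#include <cmath>
#include <cstdio>
#include <optional>

// Evaluate a two-argument libm function on the host and refuse to "fold" when
// the call reports an error, mirroring the caution a compile-time folder must
// apply before materializing a ConstantFP. Purely illustrative.
static std::optional<double>
foldBinaryFPOnHost(double (*NativeFn)(double, double), double V, double W) {
  errno = 0;
  std::feclearexcept(FE_ALL_EXCEPT);
  double Result = NativeFn(V, W);
  if (errno != 0 || std::fetestexcept(FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW))
    return std::nullopt;   // not safe to evaluate at compile time
  return Result;           // a real folder would wrap this in a ConstantFP
}

int main() {
  // The exp2 case above folds exp2(x) as pow(2, x); the same idea on the host:
  if (auto Folded = foldBinaryFPOnHost(
          static_cast<double (*)(double, double)>(std::pow), 2.0, 10.0))
    std::printf("folded to %g\n", *Folded);   // prints "folded to 1024"
}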
else if (auto *Op2C = dyn_cast<ConstantInt>(Operands[1])) { if (IntrinsicID == Intrinsic::powi && Ty->isHalfTy()) @@ -2022,7 +2076,7 @@ bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { if (!F) return false; - LibFunc::Func Func; + LibFunc Func; if (!TLI || !TLI->getLibFunc(*F, Func)) return false; @@ -2030,20 +2084,20 @@ bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { if (ConstantFP *OpC = dyn_cast<ConstantFP>(CS.getArgOperand(0))) { const APFloat &Op = OpC->getValueAPF(); switch (Func) { - case LibFunc::logl: - case LibFunc::log: - case LibFunc::logf: - case LibFunc::log2l: - case LibFunc::log2: - case LibFunc::log2f: - case LibFunc::log10l: - case LibFunc::log10: - case LibFunc::log10f: + case LibFunc_logl: + case LibFunc_log: + case LibFunc_logf: + case LibFunc_log2l: + case LibFunc_log2: + case LibFunc_log2f: + case LibFunc_log10l: + case LibFunc_log10: + case LibFunc_log10f: return Op.isNaN() || (!Op.isZero() && !Op.isNegative()); - case LibFunc::expl: - case LibFunc::exp: - case LibFunc::expf: + case LibFunc_expl: + case LibFunc_exp: + case LibFunc_expf: // FIXME: These boundaries are slightly conservative. if (OpC->getType()->isDoubleTy()) return Op.compare(APFloat(-745.0)) != APFloat::cmpLessThan && @@ -2053,9 +2107,9 @@ bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { Op.compare(APFloat(88.0f)) != APFloat::cmpGreaterThan; break; - case LibFunc::exp2l: - case LibFunc::exp2: - case LibFunc::exp2f: + case LibFunc_exp2l: + case LibFunc_exp2: + case LibFunc_exp2f: // FIXME: These boundaries are slightly conservative. if (OpC->getType()->isDoubleTy()) return Op.compare(APFloat(-1074.0)) != APFloat::cmpLessThan && @@ -2065,17 +2119,17 @@ bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { Op.compare(APFloat(127.0f)) != APFloat::cmpGreaterThan; break; - case LibFunc::sinl: - case LibFunc::sin: - case LibFunc::sinf: - case LibFunc::cosl: - case LibFunc::cos: - case LibFunc::cosf: + case LibFunc_sinl: + case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_cosl: + case LibFunc_cos: + case LibFunc_cosf: return !Op.isInfinity(); - case LibFunc::tanl: - case LibFunc::tan: - case LibFunc::tanf: { + case LibFunc_tanl: + case LibFunc_tan: + case LibFunc_tanf: { // FIXME: Stop using the host math library. // FIXME: The computation isn't done in the right precision. Type *Ty = OpC->getType(); @@ -2086,23 +2140,23 @@ bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { break; } - case LibFunc::asinl: - case LibFunc::asin: - case LibFunc::asinf: - case LibFunc::acosl: - case LibFunc::acos: - case LibFunc::acosf: + case LibFunc_asinl: + case LibFunc_asin: + case LibFunc_asinf: + case LibFunc_acosl: + case LibFunc_acos: + case LibFunc_acosf: return Op.compare(APFloat(Op.getSemantics(), "-1")) != APFloat::cmpLessThan && Op.compare(APFloat(Op.getSemantics(), "1")) != APFloat::cmpGreaterThan; - case LibFunc::sinh: - case LibFunc::cosh: - case LibFunc::sinhf: - case LibFunc::coshf: - case LibFunc::sinhl: - case LibFunc::coshl: + case LibFunc_sinh: + case LibFunc_cosh: + case LibFunc_sinhf: + case LibFunc_coshf: + case LibFunc_sinhl: + case LibFunc_coshl: // FIXME: These boundaries are slightly conservative. 
if (OpC->getType()->isDoubleTy()) return Op.compare(APFloat(-710.0)) != APFloat::cmpLessThan && @@ -2112,9 +2166,9 @@ bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { Op.compare(APFloat(89.0f)) != APFloat::cmpGreaterThan; break; - case LibFunc::sqrtl: - case LibFunc::sqrt: - case LibFunc::sqrtf: + case LibFunc_sqrtl: + case LibFunc_sqrt: + case LibFunc_sqrtf: return Op.isNaN() || Op.isZero() || !Op.isNegative(); // FIXME: Add more functions: sqrt_finite, atanh, expm1, log1p, @@ -2133,9 +2187,9 @@ bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { const APFloat &Op1 = Op1C->getValueAPF(); switch (Func) { - case LibFunc::powl: - case LibFunc::pow: - case LibFunc::powf: { + case LibFunc_powl: + case LibFunc_pow: + case LibFunc_powf: { // FIXME: Stop using the host math library. // FIXME: The computation isn't done in the right precision. Type *Ty = Op0C->getType(); @@ -2149,9 +2203,9 @@ bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { break; } - case LibFunc::fmodl: - case LibFunc::fmod: - case LibFunc::fmodf: + case LibFunc_fmodl: + case LibFunc_fmod: + case LibFunc_fmodf: return Op0.isNaN() || Op1.isNaN() || (!Op0.isInfinity() && !Op1.isZero()); diff --git a/contrib/llvm/lib/Analysis/CostModel.cpp b/contrib/llvm/lib/Analysis/CostModel.cpp index 6b77397956cd..32bfea58bf9d 100644 --- a/contrib/llvm/lib/Analysis/CostModel.cpp +++ b/contrib/llvm/lib/Analysis/CostModel.cpp @@ -447,25 +447,25 @@ unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { case Instruction::Select: { const SelectInst *SI = cast<SelectInst>(I); Type *CondTy = SI->getCondition()->getType(); - return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy); + return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy, I); } case Instruction::ICmp: case Instruction::FCmp: { Type *ValTy = I->getOperand(0)->getType(); - return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy); + return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy, I->getType(), I); } case Instruction::Store: { const StoreInst *SI = cast<StoreInst>(I); Type *ValTy = SI->getValueOperand()->getType(); return TTI->getMemoryOpCost(I->getOpcode(), ValTy, - SI->getAlignment(), - SI->getPointerAddressSpace()); + SI->getAlignment(), + SI->getPointerAddressSpace(), I); } case Instruction::Load: { const LoadInst *LI = cast<LoadInst>(I); return TTI->getMemoryOpCost(I->getOpcode(), I->getType(), - LI->getAlignment(), - LI->getPointerAddressSpace()); + LI->getAlignment(), + LI->getPointerAddressSpace(), I); } case Instruction::ZExt: case Instruction::SExt: @@ -481,7 +481,7 @@ unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { case Instruction::BitCast: case Instruction::AddrSpaceCast: { Type *SrcTy = I->getOperand(0)->getType(); - return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy); + return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, I); } case Instruction::ExtractElement: { const ExtractElementInst * EEI = cast<ExtractElementInst>(I); @@ -542,9 +542,7 @@ unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { } case Instruction::Call: if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { - SmallVector<Value *, 4> Args; - for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J) - Args.push_back(II->getArgOperand(J)); + SmallVector<Value *, 4> Args(II->arg_operands()); FastMathFlags FMF; if (auto *FPMO = dyn_cast<FPMathOperator>(II)) diff --git 
a/contrib/llvm/lib/Analysis/DemandedBits.cpp b/contrib/llvm/lib/Analysis/DemandedBits.cpp index 688c1db534c1..8f808f3e7871 100644 --- a/contrib/llvm/lib/Analysis/DemandedBits.cpp +++ b/contrib/llvm/lib/Analysis/DemandedBits.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/Operator.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -72,8 +73,7 @@ static bool isAlwaysLive(Instruction *I) { void DemandedBits::determineLiveOperandBits( const Instruction *UserI, const Instruction *I, unsigned OperandNo, - const APInt &AOut, APInt &AB, APInt &KnownZero, APInt &KnownOne, - APInt &KnownZero2, APInt &KnownOne2) { + const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2) { unsigned BitWidth = AB.getBitWidth(); // We're called once per operand, but for some instructions, we need to @@ -85,16 +85,12 @@ void DemandedBits::determineLiveOperandBits( auto ComputeKnownBits = [&](unsigned BitWidth, const Value *V1, const Value *V2) { const DataLayout &DL = I->getModule()->getDataLayout(); - KnownZero = APInt(BitWidth, 0); - KnownOne = APInt(BitWidth, 0); - computeKnownBits(const_cast<Value *>(V1), KnownZero, KnownOne, DL, 0, - &AC, UserI, &DT); + Known = KnownBits(BitWidth); + computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT); if (V2) { - KnownZero2 = APInt(BitWidth, 0); - KnownOne2 = APInt(BitWidth, 0); - computeKnownBits(const_cast<Value *>(V2), KnownZero2, KnownOne2, DL, - 0, &AC, UserI, &DT); + Known2 = KnownBits(BitWidth); + computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT); } }; @@ -110,6 +106,9 @@ void DemandedBits::determineLiveOperandBits( // the output. AB = AOut.byteSwap(); break; + case Intrinsic::bitreverse: + AB = AOut.reverseBits(); + break; case Intrinsic::ctlz: if (OperandNo == 0) { // We need some output bits, so we need all bits of the @@ -117,7 +116,7 @@ void DemandedBits::determineLiveOperandBits( // known to be one. ComputeKnownBits(BitWidth, I, nullptr); AB = APInt::getHighBitsSet(BitWidth, - std::min(BitWidth, KnownOne.countLeadingZeros()+1)); + std::min(BitWidth, Known.countMaxLeadingZeros()+1)); } break; case Intrinsic::cttz: @@ -127,7 +126,7 @@ void DemandedBits::determineLiveOperandBits( // known to be one. ComputeKnownBits(BitWidth, I, nullptr); AB = APInt::getLowBitsSet(BitWidth, - std::min(BitWidth, KnownOne.countTrailingZeros()+1)); + std::min(BitWidth, Known.countMaxTrailingZeros()+1)); } break; } @@ -180,7 +179,7 @@ void DemandedBits::determineLiveOperandBits( // bits, then we must keep the highest input bit. if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt)) .getBoolValue()) - AB.setBit(BitWidth-1); + AB.setSignBit(); // If the shift is exact, then the low bits are not dead // (they must be zero). @@ -197,11 +196,11 @@ void DemandedBits::determineLiveOperandBits( // dead). if (OperandNo == 0) { ComputeKnownBits(BitWidth, I, UserI->getOperand(1)); - AB &= ~KnownZero2; + AB &= ~Known2.Zero; } else { if (!isa<Instruction>(UserI->getOperand(0))) ComputeKnownBits(BitWidth, UserI->getOperand(0), I); - AB &= ~(KnownZero & ~KnownZero2); + AB &= ~(Known.Zero & ~Known2.Zero); } break; case Instruction::Or: @@ -213,11 +212,11 @@ void DemandedBits::determineLiveOperandBits( // dead). 
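Several files in this diff (ConstantFolding.cpp earlier, DemandedBits.cpp here) switch from carrying separate KnownZero/KnownOne APInts to the KnownBits bundle from llvm/Support/KnownBits.h. A tiny standalone illustration of that struct, built against the LLVM headers and using only the members this patch itself relies on:

#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // Model an 8-bit value whose low nibble is known to be 0b0101 and whose
  // high nibble is completely unknown.
  llvm::KnownBits Known(8);
  Known.One = llvm::APInt(8, 0x05);   // bits known to be one
  Known.Zero = llvm::APInt(8, 0x0A);  // bits known to be zero
  // countMaxLeadingZeros() is what the ctlz handling above now queries.
  llvm::outs() << "max leading zeros: " << Known.countMaxLeadingZeros() << '\n';
  // isConstant()/getConstant() are what the ConstantFolding hunk now queries.
  llvm::outs() << "fully known: " << (Known.isConstant() ? "yes" : "no") << '\n';
  return 0;
}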
if (OperandNo == 0) { ComputeKnownBits(BitWidth, I, UserI->getOperand(1)); - AB &= ~KnownOne2; + AB &= ~Known2.One; } else { if (!isa<Instruction>(UserI->getOperand(0))) ComputeKnownBits(BitWidth, UserI->getOperand(0), I); - AB &= ~(KnownOne & ~KnownOne2); + AB &= ~(Known.One & ~Known2.One); } break; case Instruction::Xor: @@ -238,7 +237,7 @@ void DemandedBits::determineLiveOperandBits( if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(), AOut.getBitWidth() - BitWidth)) .getBoolValue()) - AB.setBit(BitWidth-1); + AB.setSignBit(); break; case Instruction::Select: if (OperandNo != 0) @@ -315,7 +314,7 @@ void DemandedBits::performAnalysis() { if (!UserI->getType()->isIntegerTy()) Visited.insert(UserI); - APInt KnownZero, KnownOne, KnownZero2, KnownOne2; + KnownBits Known, Known2; // Compute the set of alive bits for each operand. These are anded into the // existing set, if any, and if that changes the set of alive bits, the // operand is added to the work-list. @@ -332,8 +331,7 @@ void DemandedBits::performAnalysis() { // Bits of each operand that are used to compute alive bits of the // output are alive, all others are dead. determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, AB, - KnownZero, KnownOne, - KnownZero2, KnownOne2); + Known, Known2); } // If we've added to the set of alive bits (or the operand has not diff --git a/contrib/llvm/lib/Analysis/DependenceAnalysis.cpp b/contrib/llvm/lib/Analysis/DependenceAnalysis.cpp index a332a07ce864..e4d58bf1b4eb 100644 --- a/contrib/llvm/lib/Analysis/DependenceAnalysis.cpp +++ b/contrib/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -385,9 +385,9 @@ void DependenceInfo::Constraint::setAny(ScalarEvolution *NewSE) { Kind = Any; } - +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) // For debugging purposes. Dumps the constraint out to OS. -void DependenceInfo::Constraint::dump(raw_ostream &OS) const { +LLVM_DUMP_METHOD void DependenceInfo::Constraint::dump(raw_ostream &OS) const { if (isEmpty()) OS << " Empty\n"; else if (isAny()) @@ -403,6 +403,7 @@ void DependenceInfo::Constraint::dump(raw_ostream &OS) const { else llvm_unreachable("unknown constraint type in Constraint::dump"); } +#endif // Updates X with the intersection @@ -2983,7 +2984,7 @@ bool DependenceInfo::propagate(const SCEV *&Src, const SCEV *&Dst, SmallVectorImpl<Constraint> &Constraints, bool &Consistent) { bool Result = false; - for (int LI = Loops.find_first(); LI >= 0; LI = Loops.find_next(LI)) { + for (unsigned LI : Loops.set_bits()) { DEBUG(dbgs() << "\t Constraint[" << LI << "] is"); DEBUG(Constraints[LI].dump(dbgs())); if (Constraints[LI].isDistance()) @@ -3265,7 +3266,7 @@ bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst, // For debugging purposes, dump a small bit vector to dbgs(). 
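dumpSmallBitVector below, like the propagate() loop just rewritten, is converted from the explicit find_first()/find_next() walk to the set_bits() range. A small self-contained check of the equivalence:

#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::SmallBitVector BV(16);
  BV.set(1);
  BV.set(4);
  BV.set(9);

  // Old idiom: find_first()/find_next() return -1 when exhausted, so the
  // index has to be a signed int.
  for (int I = BV.find_first(); I >= 0; I = BV.find_next(I))
    llvm::outs() << I << ' ';
  llvm::outs() << '\n';

  // New idiom from this patch: same indices in the same order, no sentinel
  // handling, and the loop variable can stay unsigned.
  for (unsigned I : BV.set_bits())
    llvm::outs() << I << ' ';
  llvm::outs() << '\n';
}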
static void dumpSmallBitVector(SmallBitVector &BV) { dbgs() << "{"; - for (int VI = BV.find_first(); VI >= 0; VI = BV.find_next(VI)) { + for (unsigned VI : BV.set_bits()) { dbgs() << VI; if (BV.find_next(VI) >= 0) dbgs() << ' '; @@ -3505,7 +3506,7 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst, NewConstraint.setAny(SE); // test separable subscripts - for (int SI = Separable.find_first(); SI >= 0; SI = Separable.find_next(SI)) { + for (unsigned SI : Separable.set_bits()) { DEBUG(dbgs() << "testing subscript " << SI); switch (Pair[SI].Classification) { case Subscript::ZIV: @@ -3544,14 +3545,14 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst, SmallVector<Constraint, 4> Constraints(MaxLevels + 1); for (unsigned II = 0; II <= MaxLevels; ++II) Constraints[II].setAny(SE); - for (int SI = Coupled.find_first(); SI >= 0; SI = Coupled.find_next(SI)) { + for (unsigned SI : Coupled.set_bits()) { DEBUG(dbgs() << "testing subscript group " << SI << " { "); SmallBitVector Group(Pair[SI].Group); SmallBitVector Sivs(Pairs); SmallBitVector Mivs(Pairs); SmallBitVector ConstrainedLevels(MaxLevels + 1); SmallVector<Subscript *, 4> PairsInGroup; - for (int SJ = Group.find_first(); SJ >= 0; SJ = Group.find_next(SJ)) { + for (unsigned SJ : Group.set_bits()) { DEBUG(dbgs() << SJ << " "); if (Pair[SJ].Classification == Subscript::SIV) Sivs.set(SJ); @@ -3563,7 +3564,7 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst, DEBUG(dbgs() << "}\n"); while (Sivs.any()) { bool Changed = false; - for (int SJ = Sivs.find_first(); SJ >= 0; SJ = Sivs.find_next(SJ)) { + for (unsigned SJ : Sivs.set_bits()) { DEBUG(dbgs() << "testing subscript " << SJ << ", SIV\n"); // SJ is an SIV subscript that's part of the current coupled group unsigned Level; @@ -3587,7 +3588,7 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst, DEBUG(dbgs() << " propagating\n"); DEBUG(dbgs() << "\tMivs = "); DEBUG(dumpSmallBitVector(Mivs)); - for (int SJ = Mivs.find_first(); SJ >= 0; SJ = Mivs.find_next(SJ)) { + for (unsigned SJ : Mivs.set_bits()) { // SJ is an MIV subscript that's part of the current coupled group DEBUG(dbgs() << "\tSJ = " << SJ << "\n"); if (propagate(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops, @@ -3621,7 +3622,7 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst, } // test & propagate remaining RDIVs - for (int SJ = Mivs.find_first(); SJ >= 0; SJ = Mivs.find_next(SJ)) { + for (unsigned SJ : Mivs.set_bits()) { if (Pair[SJ].Classification == Subscript::RDIV) { DEBUG(dbgs() << "RDIV test\n"); if (testRDIV(Pair[SJ].Src, Pair[SJ].Dst, Result)) @@ -3634,7 +3635,7 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst, // test remaining MIVs // This code is temporary. // Better to somehow test all remaining subscripts simultaneously. 
- for (int SJ = Mivs.find_first(); SJ >= 0; SJ = Mivs.find_next(SJ)) { + for (unsigned SJ : Mivs.set_bits()) { if (Pair[SJ].Classification == Subscript::MIV) { DEBUG(dbgs() << "MIV test\n"); if (testMIV(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops, Result)) @@ -3646,9 +3647,8 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst, // update Result.DV from constraint vector DEBUG(dbgs() << " updating\n"); - for (int SJ = ConstrainedLevels.find_first(); SJ >= 0; - SJ = ConstrainedLevels.find_next(SJ)) { - if (SJ > (int)CommonLevels) + for (unsigned SJ : ConstrainedLevels.set_bits()) { + if (SJ > CommonLevels) break; updateDirection(Result.DV[SJ - 1], Constraints[SJ]); if (Result.DV[SJ - 1].Direction == Dependence::DVEntry::NONE) @@ -3858,7 +3858,7 @@ const SCEV *DependenceInfo::getSplitIteration(const Dependence &Dep, NewConstraint.setAny(SE); // test separable subscripts - for (int SI = Separable.find_first(); SI >= 0; SI = Separable.find_next(SI)) { + for (unsigned SI : Separable.set_bits()) { switch (Pair[SI].Classification) { case Subscript::SIV: { unsigned Level; @@ -3885,12 +3885,12 @@ const SCEV *DependenceInfo::getSplitIteration(const Dependence &Dep, SmallVector<Constraint, 4> Constraints(MaxLevels + 1); for (unsigned II = 0; II <= MaxLevels; ++II) Constraints[II].setAny(SE); - for (int SI = Coupled.find_first(); SI >= 0; SI = Coupled.find_next(SI)) { + for (unsigned SI : Coupled.set_bits()) { SmallBitVector Group(Pair[SI].Group); SmallBitVector Sivs(Pairs); SmallBitVector Mivs(Pairs); SmallBitVector ConstrainedLevels(MaxLevels + 1); - for (int SJ = Group.find_first(); SJ >= 0; SJ = Group.find_next(SJ)) { + for (unsigned SJ : Group.set_bits()) { if (Pair[SJ].Classification == Subscript::SIV) Sivs.set(SJ); else @@ -3898,7 +3898,7 @@ const SCEV *DependenceInfo::getSplitIteration(const Dependence &Dep, } while (Sivs.any()) { bool Changed = false; - for (int SJ = Sivs.find_first(); SJ >= 0; SJ = Sivs.find_next(SJ)) { + for (unsigned SJ : Sivs.set_bits()) { // SJ is an SIV subscript that's part of the current coupled group unsigned Level; const SCEV *SplitIter = nullptr; @@ -3913,7 +3913,7 @@ const SCEV *DependenceInfo::getSplitIteration(const Dependence &Dep, } if (Changed) { // propagate, possibly creating new SIVs and ZIVs - for (int SJ = Mivs.find_first(); SJ >= 0; SJ = Mivs.find_next(SJ)) { + for (unsigned SJ : Mivs.set_bits()) { // SJ is an MIV subscript that's part of the current coupled group if (propagate(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops, Constraints, Result.Consistent)) { diff --git a/contrib/llvm/lib/Analysis/DomPrinter.cpp b/contrib/llvm/lib/Analysis/DomPrinter.cpp index 7acfb41500d4..8abc0e7d0df9 100644 --- a/contrib/llvm/lib/Analysis/DomPrinter.cpp +++ b/contrib/llvm/lib/Analysis/DomPrinter.cpp @@ -80,6 +80,22 @@ struct DOTGraphTraits<PostDominatorTree*> }; } +void DominatorTree::viewGraph(const Twine &Name, const Twine &Title) { +#ifndef NDEBUG + ViewGraph(this, Name, false, Title); +#else + errs() << "DomTree dump not available, build with DEBUG\n"; +#endif // NDEBUG +} + +void DominatorTree::viewGraph() { +#ifndef NDEBUG + this->viewGraph("domtree", "Dominator Tree for function"); +#else + errs() << "DomTree dump not available, build with DEBUG\n"; +#endif // NDEBUG +} + namespace { struct DominatorTreeWrapperPassAnalysisGraphTraits { static DominatorTree *getGraph(DominatorTreeWrapperPass *DTWP) { diff --git a/contrib/llvm/lib/Analysis/DominanceFrontier.cpp b/contrib/llvm/lib/Analysis/DominanceFrontier.cpp index 15856c3f8b7a..5b6e2d0476e4 100644 --- 
a/contrib/llvm/lib/Analysis/DominanceFrontier.cpp +++ b/contrib/llvm/lib/Analysis/DominanceFrontier.cpp @@ -56,6 +56,16 @@ LLVM_DUMP_METHOD void DominanceFrontierWrapperPass::dump() const { } #endif +/// Handle invalidation explicitly. +bool DominanceFrontier::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. + auto PAC = PA.getChecker<DominanceFrontierAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + AnalysisKey DominanceFrontierAnalysis::Key; DominanceFrontier DominanceFrontierAnalysis::run(Function &F, diff --git a/contrib/llvm/lib/Analysis/EHPersonalities.cpp b/contrib/llvm/lib/Analysis/EHPersonalities.cpp index ebf0a370b0b0..b12ae9884e3d 100644 --- a/contrib/llvm/lib/Analysis/EHPersonalities.cpp +++ b/contrib/llvm/lib/Analysis/EHPersonalities.cpp @@ -27,8 +27,10 @@ EHPersonality llvm::classifyEHPersonality(const Value *Pers) { return StringSwitch<EHPersonality>(F->getName()) .Case("__gnat_eh_personality", EHPersonality::GNU_Ada) .Case("__gxx_personality_v0", EHPersonality::GNU_CXX) + .Case("__gxx_personality_seh0",EHPersonality::GNU_CXX) .Case("__gxx_personality_sj0", EHPersonality::GNU_CXX_SjLj) .Case("__gcc_personality_v0", EHPersonality::GNU_C) + .Case("__gcc_personality_seh0",EHPersonality::GNU_C) .Case("__gcc_personality_sj0", EHPersonality::GNU_C_SjLj) .Case("__objc_personality_v0", EHPersonality::GNU_ObjC) .Case("_except_handler3", EHPersonality::MSVC_X86SEH) diff --git a/contrib/llvm/lib/Analysis/IVUsers.cpp b/contrib/llvm/lib/Analysis/IVUsers.cpp index a661b0101e6a..c30feb973e60 100644 --- a/contrib/llvm/lib/Analysis/IVUsers.cpp +++ b/contrib/llvm/lib/Analysis/IVUsers.cpp @@ -76,9 +76,8 @@ static bool isInteresting(const SCEV *S, const Instruction *I, const Loop *L, // An add is interesting if exactly one of its operands is interesting. if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { bool AnyInterestingYet = false; - for (SCEVAddExpr::op_iterator OI = Add->op_begin(), OE = Add->op_end(); - OI != OE; ++OI) - if (isInteresting(*OI, I, L, SE, LI)) { + for (const auto *Op : Add->operands()) + if (isInteresting(Op, I, L, SE, LI)) { if (AnyInterestingYet) return false; AnyInterestingYet = true; @@ -118,6 +117,50 @@ static bool isSimplifiedLoopNest(BasicBlock *BB, const DominatorTree *DT, return true; } +/// IVUseShouldUsePostIncValue - We have discovered a "User" of an IV expression +/// and now we need to decide whether the user should use the preinc or post-inc +/// value. If this user should use the post-inc version of the IV, return true. +/// +/// Choosing wrong here can break dominance properties (if we choose to use the +/// post-inc value when we cannot) or it can end up adding extra live-ranges to +/// the loop, resulting in reg-reg copies (if we use the pre-inc value when we +/// should use the post-inc value). +static bool IVUseShouldUsePostIncValue(Instruction *User, Value *Operand, + const Loop *L, DominatorTree *DT) { + // If the user is in the loop, use the preinc value. + if (L->contains(User)) + return false; + + BasicBlock *LatchBlock = L->getLoopLatch(); + if (!LatchBlock) + return false; + + // Ok, the user is outside of the loop. If it is dominated by the latch + // block, use the post-inc value. 
+ if (DT->dominates(LatchBlock, User->getParent())) + return true; + + // There is one case we have to be careful of: PHI nodes. These little guys + // can live in blocks that are not dominated by the latch block, but (since + // their uses occur in the predecessor block, not the block the PHI lives in) + // should still use the post-inc value. Check for this case now. + PHINode *PN = dyn_cast<PHINode>(User); + if (!PN || !Operand) + return false; // not a phi, not dominated by latch block. + + // Look at all of the uses of Operand by the PHI node. If any use corresponds + // to a block that is not dominated by the latch block, give up and use the + // preincremented value. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == Operand && + !DT->dominates(LatchBlock, PN->getIncomingBlock(i))) + return false; + + // Okay, all uses of Operand by PN are in predecessor blocks that really are + // dominated by the latch block. Use the post-incremented value. + return true; +} + /// AddUsersImpl - Inspect the specified instruction. If it is a /// reducible SCEV, recursively add its users to the IVUsesByStride set and /// return true. Otherwise, return false. @@ -208,10 +251,16 @@ bool IVUsers::AddUsersImpl(Instruction *I, // The regular return value here is discarded; instead of recording // it, we just recompute it when we need it. const SCEV *OriginalISE = ISE; - ISE = TransformForPostIncUse(NormalizeAutodetect, - ISE, User, I, - NewUse.PostIncLoops, - *SE, *DT); + + auto NormalizePred = [&](const SCEVAddRecExpr *AR) { + auto *L = AR->getLoop(); + bool Result = IVUseShouldUsePostIncValue(User, I, L, DT); + if (Result) + NewUse.PostIncLoops.insert(L); + return Result; + }; + + ISE = normalizeForPostIncUseIf(ISE, NormalizePred, *SE); // PostIncNormalization effectively simplifies the expression under // pre-increment assumptions. Those assumptions (no wrapping) might not @@ -219,8 +268,7 @@ bool IVUsers::AddUsersImpl(Instruction *I, // transformation is invertible. if (OriginalISE != ISE) { const SCEV *DenormalizedISE = - TransformForPostIncUse(Denormalize, ISE, User, I, - NewUse.PostIncLoops, *SE, *DT); + denormalizeForPostIncUse(ISE, NewUse.PostIncLoops, *SE); // If we normalized the expression, but denormalization doesn't give the // original one, discard this user. @@ -338,11 +386,8 @@ const SCEV *IVUsers::getReplacementExpr(const IVStrideUse &IU) const { /// getExpr - Return the expression for the use. 
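IVUseShouldUsePostIncValue above decides, per use, which of two values of the induction variable a user should be rewritten against. In plain C++ terms the distinction it is drawing is simply:

#include <cstdio>

int main() {
  int i = 0;
  for (; i != 4; ++i)
    std::printf("in-loop use sees %d\n", i);   // pre-incremented values 0..3
  // A user dominated by the loop latch only ever observes the incremented
  // value, which is why such users prefer the post-inc form of the IV.
  std::printf("after-loop use sees %d\n", i);  // post-incremented value 4
}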
const SCEV *IVUsers::getExpr(const IVStrideUse &IU) const { - return - TransformForPostIncUse(Normalize, getReplacementExpr(IU), - IU.getUser(), IU.getOperandValToReplace(), - const_cast<PostIncLoopSet &>(IU.getPostIncLoops()), - *SE, *DT); + return normalizeForPostIncUse(getReplacementExpr(IU), IU.getPostIncLoops(), + *SE); } static const SCEVAddRecExpr *findAddRecForLoop(const SCEV *S, const Loop *L) { @@ -353,9 +398,8 @@ static const SCEVAddRecExpr *findAddRecForLoop(const SCEV *S, const Loop *L) { } if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { - for (SCEVAddExpr::op_iterator I = Add->op_begin(), E = Add->op_end(); - I != E; ++I) - if (const SCEVAddRecExpr *AR = findAddRecForLoop(*I, L)) + for (const auto *Op : Add->operands()) + if (const SCEVAddRecExpr *AR = findAddRecForLoop(Op, L)) return AR; return nullptr; } diff --git a/contrib/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp b/contrib/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp index 3da33ac71421..ed233d201537 100644 --- a/contrib/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp +++ b/contrib/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp @@ -43,7 +43,7 @@ static cl::opt<unsigned> // The percent threshold for the direct-call target (this call site vs the // total call count) for it to be considered as the promotion target. static cl::opt<unsigned> - ICPPercentThreshold("icp-percent-threshold", cl::init(33), cl::Hidden, + ICPPercentThreshold("icp-percent-threshold", cl::init(30), cl::Hidden, cl::ZeroOrMore, cl::desc("The percentage threshold for the promotion")); diff --git a/contrib/llvm/lib/Analysis/InlineCost.cpp b/contrib/llvm/lib/Analysis/InlineCost.cpp index 4109049ecabc..77c87928728a 100644 --- a/contrib/llvm/lib/Analysis/InlineCost.cpp +++ b/contrib/llvm/lib/Analysis/InlineCost.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -48,11 +49,16 @@ static cl::opt<int> HintThreshold( "inlinehint-threshold", cl::Hidden, cl::init(325), cl::desc("Threshold for inlining functions with inline hint")); +static cl::opt<int> + ColdCallSiteThreshold("inline-cold-callsite-threshold", cl::Hidden, + cl::init(45), + cl::desc("Threshold for inlining cold callsites")); + // We introduce this threshold to help performance of instrumentation based // PGO before we actually hook up inliner with analysis passes such as BPI and // BFI. static cl::opt<int> ColdThreshold( - "inlinecold-threshold", cl::Hidden, cl::init(225), + "inlinecold-threshold", cl::Hidden, cl::init(45), cl::desc("Threshold for inlining functions with cold attribute")); static cl::opt<int> @@ -72,12 +78,18 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { /// Getter for the cache of @llvm.assume intrinsics. std::function<AssumptionCache &(Function &)> &GetAssumptionCache; + /// Getter for BlockFrequencyInfo + Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI; + /// Profile summary information. ProfileSummaryInfo *PSI; /// The called function. Function &F; + // Cache the DataLayout since we use it a lot. + const DataLayout &DL; + /// The candidate callsite being analyzed. Please do not use this to do /// analysis in the caller function; we want the inline cost query to be /// easily cacheable. Instead, use the cover function paramHasAttr. 
@@ -133,9 +145,11 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { void disableSROA(Value *V); void accumulateSROACost(DenseMap<Value *, int>::iterator CostIt, int InstructionCost); - bool isGEPOffsetConstant(GetElementPtrInst &GEP); + bool isGEPFree(GetElementPtrInst &GEP); bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset); bool simplifyCallSite(Function *F, CallSite CS); + template <typename Callable> + bool simplifyInstruction(Instruction &I, Callable Evaluate); ConstantInt *stripAndComputeInBoundsConstantOffsets(Value *&V); /// Return true if the given argument to the function being considered for @@ -202,9 +216,11 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { public: CallAnalyzer(const TargetTransformInfo &TTI, std::function<AssumptionCache &(Function &)> &GetAssumptionCache, + Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI, ProfileSummaryInfo *PSI, Function &Callee, CallSite CSArg, const InlineParams &Params) - : TTI(TTI), GetAssumptionCache(GetAssumptionCache), PSI(PSI), F(Callee), + : TTI(TTI), GetAssumptionCache(GetAssumptionCache), GetBFI(GetBFI), + PSI(PSI), F(Callee), DL(F.getParent()->getDataLayout()), CandidateCS(CSArg), Params(Params), Threshold(Params.DefaultThreshold), Cost(0), IsCallerRecursive(false), IsRecursiveCall(false), ExposesReturnsTwice(false), HasDynamicAlloca(false), @@ -286,23 +302,11 @@ void CallAnalyzer::accumulateSROACost(DenseMap<Value *, int>::iterator CostIt, SROACostSavings += InstructionCost; } -/// \brief Check whether a GEP's indices are all constant. -/// -/// Respects any simplified values known during the analysis of this callsite. -bool CallAnalyzer::isGEPOffsetConstant(GetElementPtrInst &GEP) { - for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I) - if (!isa<Constant>(*I) && !SimplifiedValues.lookup(*I)) - return false; - - return true; -} - /// \brief Accumulate a constant GEP offset into an APInt if possible. /// /// Returns false if unable to compute the offset for any reason. Respects any /// simplified values known during the analysis of this callsite. bool CallAnalyzer::accumulateGEPOffset(GEPOperator &GEP, APInt &Offset) { - const DataLayout &DL = F.getParent()->getDataLayout(); unsigned IntPtrWidth = DL.getPointerSizeInBits(); assert(IntPtrWidth == Offset.getBitWidth()); @@ -331,13 +335,27 @@ bool CallAnalyzer::accumulateGEPOffset(GEPOperator &GEP, APInt &Offset) { return true; } +/// \brief Use TTI to check whether a GEP is free. +/// +/// Respects any simplified values known during the analysis of this callsite. +bool CallAnalyzer::isGEPFree(GetElementPtrInst &GEP) { + SmallVector<Value *, 4> Indices; + for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I) + if (Constant *SimpleOp = SimplifiedValues.lookup(*I)) + Indices.push_back(SimpleOp); + else + Indices.push_back(*I); + return TargetTransformInfo::TCC_Free == + TTI.getGEPCost(GEP.getSourceElementType(), GEP.getPointerOperand(), + Indices); +} + bool CallAnalyzer::visitAlloca(AllocaInst &I) { // Check whether inlining will turn a dynamic alloca into a static // alloca and handle that case. 
if (I.isArrayAllocation()) { Constant *Size = SimplifiedValues.lookup(I.getArraySize()); if (auto *AllocSize = dyn_cast_or_null<ConstantInt>(Size)) { - const DataLayout &DL = F.getParent()->getDataLayout(); Type *Ty = I.getAllocatedType(); AllocatedSize = SaturatingMultiplyAdd( AllocSize->getLimitedValue(), DL.getTypeAllocSize(Ty), AllocatedSize); @@ -347,7 +365,6 @@ bool CallAnalyzer::visitAlloca(AllocaInst &I) { // Accumulate the allocated size. if (I.isStaticAlloca()) { - const DataLayout &DL = F.getParent()->getDataLayout(); Type *Ty = I.getAllocatedType(); AllocatedSize = SaturatingAdd(DL.getTypeAllocSize(Ty), AllocatedSize); } @@ -396,7 +413,7 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) { // Non-constant GEPs aren't folded, and disable SROA. if (SROACandidate) disableSROA(CostIt); - return false; + return isGEPFree(I); } // Add the result as a new mapping to Base + Offset. @@ -411,7 +428,15 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) { } } - if (isGEPOffsetConstant(I)) { + // Lambda to check whether a GEP's indices are all constant. + auto IsGEPOffsetConstant = [&](GetElementPtrInst &GEP) { + for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I) + if (!isa<Constant>(*I) && !SimplifiedValues.lookup(*I)) + return false; + return true; + }; + + if (IsGEPOffsetConstant(I)) { if (SROACandidate) SROAArgValues[&I] = SROAArg; @@ -422,19 +447,36 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) { // Variable GEPs will require math and will disable SROA. if (SROACandidate) disableSROA(CostIt); - return false; + return isGEPFree(I); +} + +/// Simplify \p I if its operands are constants and update SimplifiedValues. +/// \p Evaluate is a callable specific to instruction type that evaluates the +/// instruction when all the operands are constants. +template <typename Callable> +bool CallAnalyzer::simplifyInstruction(Instruction &I, Callable Evaluate) { + SmallVector<Constant *, 2> COps; + for (Value *Op : I.operands()) { + Constant *COp = dyn_cast<Constant>(Op); + if (!COp) + COp = SimplifiedValues.lookup(Op); + if (!COp) + return false; + COps.push_back(COp); + } + auto *C = Evaluate(COps); + if (!C) + return false; + SimplifiedValues[&I] = C; + return true; } bool CallAnalyzer::visitBitCast(BitCastInst &I) { // Propagate constants through bitcasts. - Constant *COp = dyn_cast<Constant>(I.getOperand(0)); - if (!COp) - COp = SimplifiedValues.lookup(I.getOperand(0)); - if (COp) - if (Constant *C = ConstantExpr::getBitCast(COp, I.getType())) { - SimplifiedValues[&I] = C; - return true; - } + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getBitCast(COps[0], I.getType()); + })) + return true; // Track base/offsets through casts std::pair<Value *, APInt> BaseAndOffset = @@ -455,19 +497,14 @@ bool CallAnalyzer::visitBitCast(BitCastInst &I) { bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) { // Propagate constants through ptrtoint. - Constant *COp = dyn_cast<Constant>(I.getOperand(0)); - if (!COp) - COp = SimplifiedValues.lookup(I.getOperand(0)); - if (COp) - if (Constant *C = ConstantExpr::getPtrToInt(COp, I.getType())) { - SimplifiedValues[&I] = C; - return true; - } + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getPtrToInt(COps[0], I.getType()); + })) + return true; // Track base/offset pairs when converted to a plain integer provided the // integer is large enough to represent the pointer. 
unsigned IntegerSize = I.getType()->getScalarSizeInBits(); - const DataLayout &DL = F.getParent()->getDataLayout(); if (IntegerSize >= DL.getPointerSizeInBits()) { std::pair<Value *, APInt> BaseAndOffset = ConstantOffsetPtrs.lookup(I.getOperand(0)); @@ -492,20 +529,15 @@ bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) { bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) { // Propagate constants through ptrtoint. - Constant *COp = dyn_cast<Constant>(I.getOperand(0)); - if (!COp) - COp = SimplifiedValues.lookup(I.getOperand(0)); - if (COp) - if (Constant *C = ConstantExpr::getIntToPtr(COp, I.getType())) { - SimplifiedValues[&I] = C; - return true; - } + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getIntToPtr(COps[0], I.getType()); + })) + return true; // Track base/offset pairs when round-tripped through a pointer without // modifications provided the integer is not too large. Value *Op = I.getOperand(0); unsigned IntegerSize = Op->getType()->getScalarSizeInBits(); - const DataLayout &DL = F.getParent()->getDataLayout(); if (IntegerSize <= DL.getPointerSizeInBits()) { std::pair<Value *, APInt> BaseAndOffset = ConstantOffsetPtrs.lookup(Op); if (BaseAndOffset.first) @@ -523,14 +555,10 @@ bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) { bool CallAnalyzer::visitCastInst(CastInst &I) { // Propagate constants through ptrtoint. - Constant *COp = dyn_cast<Constant>(I.getOperand(0)); - if (!COp) - COp = SimplifiedValues.lookup(I.getOperand(0)); - if (COp) - if (Constant *C = ConstantExpr::getCast(I.getOpcode(), COp, I.getType())) { - SimplifiedValues[&I] = C; - return true; - } + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getCast(I.getOpcode(), COps[0], I.getType()); + })) + return true; // Disable SROA in the face of arbitrary casts we don't whitelist elsewhere. disableSROA(I.getOperand(0)); @@ -540,16 +568,10 @@ bool CallAnalyzer::visitCastInst(CastInst &I) { bool CallAnalyzer::visitUnaryInstruction(UnaryInstruction &I) { Value *Operand = I.getOperand(0); - Constant *COp = dyn_cast<Constant>(Operand); - if (!COp) - COp = SimplifiedValues.lookup(Operand); - if (COp) { - const DataLayout &DL = F.getParent()->getDataLayout(); - if (Constant *C = ConstantFoldInstOperands(&I, COp, DL)) { - SimplifiedValues[&I] = C; - return true; - } - } + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantFoldInstOperands(&I, COps[0], DL); + })) + return true; // Disable any SROA on the argument to arbitrary unary operators. disableSROA(Operand); @@ -558,8 +580,7 @@ bool CallAnalyzer::visitUnaryInstruction(UnaryInstruction &I) { } bool CallAnalyzer::paramHasAttr(Argument *A, Attribute::AttrKind Attr) { - unsigned ArgNo = A->getArgNo(); - return CandidateCS.paramHasAttr(ArgNo + 1, Attr); + return CandidateCS.paramHasAttr(A->getArgNo(), Attr); } bool CallAnalyzer::isKnownNonNullInCallee(Value *V) { @@ -642,17 +663,34 @@ void CallAnalyzer::updateThreshold(CallSite CS, Function &Callee) { if (Callee.hasFnAttribute(Attribute::InlineHint)) Threshold = MaxIfValid(Threshold, Params.HintThreshold); if (PSI) { - uint64_t TotalWeight; - if (CS.getInstruction()->extractProfTotalWeight(TotalWeight) && - PSI->isHotCount(TotalWeight)) { - Threshold = MaxIfValid(Threshold, Params.HotCallSiteThreshold); - } else if (PSI->isFunctionEntryHot(&Callee)) { - // If callsite hotness can not be determined, we may still know - // that the callee is hot and treat it as a weaker hint for threshold - // increase. 
- Threshold = MaxIfValid(Threshold, Params.HintThreshold); - } else if (PSI->isFunctionEntryCold(&Callee)) { - Threshold = MinIfValid(Threshold, Params.ColdThreshold); + BlockFrequencyInfo *CallerBFI = GetBFI ? &((*GetBFI)(*Caller)) : nullptr; + // FIXME: After switching to the new passmanager, simplify the logic below + // by checking only the callsite hotness/coldness. The check for CallerBFI + // exists only because we do not have BFI available with the old PM. + // + // Use callee's hotness information only if we have no way of determining + // callsite's hotness information. Callsite hotness can be determined if + // sample profile is used (which adds hotness metadata to calls) or if + // caller's BlockFrequencyInfo is available. + if (CallerBFI || PSI->hasSampleProfile()) { + if (PSI->isHotCallSite(CS, CallerBFI)) { + DEBUG(dbgs() << "Hot callsite.\n"); + Threshold = Params.HotCallSiteThreshold.getValue(); + } else if (PSI->isColdCallSite(CS, CallerBFI)) { + DEBUG(dbgs() << "Cold callsite.\n"); + Threshold = MinIfValid(Threshold, Params.ColdCallSiteThreshold); + } + } else { + if (PSI->isFunctionEntryHot(&Callee)) { + DEBUG(dbgs() << "Hot callee.\n"); + // If callsite hotness can not be determined, we may still know + // that the callee is hot and treat it as a weaker hint for threshold + // increase. + Threshold = MaxIfValid(Threshold, Params.HintThreshold); + } else if (PSI->isFunctionEntryCold(&Callee)) { + DEBUG(dbgs() << "Cold callee.\n"); + Threshold = MinIfValid(Threshold, Params.ColdThreshold); + } } } } @@ -665,20 +703,10 @@ void CallAnalyzer::updateThreshold(CallSite CS, Function &Callee) { bool CallAnalyzer::visitCmpInst(CmpInst &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); // First try to handle simplified comparisons. 
- if (!isa<Constant>(LHS)) - if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) - LHS = SimpleLHS; - if (!isa<Constant>(RHS)) - if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) - RHS = SimpleRHS; - if (Constant *CLHS = dyn_cast<Constant>(LHS)) { - if (Constant *CRHS = dyn_cast<Constant>(RHS)) - if (Constant *C = - ConstantExpr::getCompare(I.getPredicate(), CLHS, CRHS)) { - SimplifiedValues[&I] = C; - return true; - } - } + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getCompare(I.getPredicate(), COps[0], COps[1]); + })) + return true; if (I.getOpcode() == Instruction::FCmp) return false; @@ -756,24 +784,18 @@ bool CallAnalyzer::visitSub(BinaryOperator &I) { bool CallAnalyzer::visitBinaryOperator(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - const DataLayout &DL = F.getParent()->getDataLayout(); - if (!isa<Constant>(LHS)) - if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) - LHS = SimpleLHS; - if (!isa<Constant>(RHS)) - if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) - RHS = SimpleRHS; - Value *SimpleV = nullptr; - if (auto FI = dyn_cast<FPMathOperator>(&I)) - SimpleV = - SimplifyFPBinOp(I.getOpcode(), LHS, RHS, FI->getFastMathFlags(), DL); - else - SimpleV = SimplifyBinOp(I.getOpcode(), LHS, RHS, DL); + auto Evaluate = [&](SmallVectorImpl<Constant *> &COps) { + Value *SimpleV = nullptr; + if (auto FI = dyn_cast<FPMathOperator>(&I)) + SimpleV = SimplifyFPBinOp(I.getOpcode(), COps[0], COps[1], + FI->getFastMathFlags(), DL); + else + SimpleV = SimplifyBinOp(I.getOpcode(), COps[0], COps[1], DL); + return dyn_cast_or_null<Constant>(SimpleV); + }; - if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) { - SimplifiedValues[&I] = C; + if (simplifyInstruction(I, Evaluate)) return true; - } // Disable any SROA on arguments to arbitrary, unsimplified binary operators. disableSROA(LHS); @@ -814,13 +836,10 @@ bool CallAnalyzer::visitStore(StoreInst &I) { bool CallAnalyzer::visitExtractValue(ExtractValueInst &I) { // Constant folding for extract value is trivial. - Constant *C = dyn_cast<Constant>(I.getAggregateOperand()); - if (!C) - C = SimplifiedValues.lookup(I.getAggregateOperand()); - if (C) { - SimplifiedValues[&I] = ConstantExpr::getExtractValue(C, I.getIndices()); + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getExtractValue(COps[0], I.getIndices()); + })) return true; - } // SROA can look through these but give them a cost. return false; @@ -828,17 +847,12 @@ bool CallAnalyzer::visitExtractValue(ExtractValueInst &I) { bool CallAnalyzer::visitInsertValue(InsertValueInst &I) { // Constant folding for insert value is trivial. - Constant *AggC = dyn_cast<Constant>(I.getAggregateOperand()); - if (!AggC) - AggC = SimplifiedValues.lookup(I.getAggregateOperand()); - Constant *InsertedC = dyn_cast<Constant>(I.getInsertedValueOperand()); - if (!InsertedC) - InsertedC = SimplifiedValues.lookup(I.getInsertedValueOperand()); - if (AggC && InsertedC) { - SimplifiedValues[&I] = - ConstantExpr::getInsertValue(AggC, InsertedC, I.getIndices()); + if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { + return ConstantExpr::getInsertValue(/*AggregateOperand*/ COps[0], + /*InsertedValueOperand*/ COps[1], + I.getIndices()); + })) return true; - } // SROA can look through these but give them a cost. return false; @@ -959,7 +973,8 @@ bool CallAnalyzer::visitCallSite(CallSite CS) { // out. Pretend to inline the function, with a custom threshold. 
auto IndirectCallParams = Params; IndirectCallParams.DefaultThreshold = InlineConstants::IndirectCallThreshold; - CallAnalyzer CA(TTI, GetAssumptionCache, PSI, *F, CS, IndirectCallParams); + CallAnalyzer CA(TTI, GetAssumptionCache, GetBFI, PSI, *F, CS, + IndirectCallParams); if (CA.analyzeCall(CS)) { // We were able to inline the indirect call! Subtract the cost from the // threshold to get the bonus we want to apply, but don't go below zero. @@ -995,22 +1010,68 @@ bool CallAnalyzer::visitSwitchInst(SwitchInst &SI) { if (isa<ConstantInt>(V)) return true; - // Otherwise, we need to accumulate a cost proportional to the number of - // distinct successor blocks. This fan-out in the CFG cannot be represented - // for free even if we can represent the core switch as a jumptable that - // takes a single instruction. + // Assume the most general case where the swith is lowered into + // either a jump table, bit test, or a balanced binary tree consisting of + // case clusters without merging adjacent clusters with the same + // destination. We do not consider the switches that are lowered with a mix + // of jump table/bit test/binary search tree. The cost of the switch is + // proportional to the size of the tree or the size of jump table range. // // NB: We convert large switches which are just used to initialize large phi // nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent // inlining those. It will prevent inlining in cases where the optimization // does not (yet) fire. - SmallPtrSet<BasicBlock *, 8> SuccessorBlocks; - SuccessorBlocks.insert(SI.getDefaultDest()); - for (auto I = SI.case_begin(), E = SI.case_end(); I != E; ++I) - SuccessorBlocks.insert(I.getCaseSuccessor()); - // Add cost corresponding to the number of distinct destinations. The first - // we model as free because of fallthrough. - Cost += (SuccessorBlocks.size() - 1) * InlineConstants::InstrCost; + + // Exit early for a large switch, assuming one case needs at least one + // instruction. + // FIXME: This is not true for a bit test, but ignore such case for now to + // save compile-time. + int64_t CostLowerBound = + std::min((int64_t)INT_MAX, + (int64_t)SI.getNumCases() * InlineConstants::InstrCost + Cost); + + if (CostLowerBound > Threshold) { + Cost = CostLowerBound; + return false; + } + + unsigned JumpTableSize = 0; + unsigned NumCaseCluster = + TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize); + + // If suitable for a jump table, consider the cost for the table size and + // branch to destination. + if (JumpTableSize) { + int64_t JTCost = (int64_t)JumpTableSize * InlineConstants::InstrCost + + 4 * InlineConstants::InstrCost; + Cost = std::min((int64_t)INT_MAX, JTCost + Cost); + return false; + } + + // Considering forming a binary search, we should find the number of nodes + // which is same as the number of comparisons when lowered. For a given + // number of clusters, n, we can define a recursive function, f(n), to find + // the number of nodes in the tree. The recursion is : + // f(n) = 1 + f(n/2) + f (n - n/2), when n > 3, + // and f(n) = n, when n <= 3. + // This will lead a binary tree where the leaf should be either f(2) or f(3) + // when n > 3. So, the number of comparisons from leaves should be n, while + // the number of non-leaf should be : + // 2^(log2(n) - 1) - 1 + // = 2^log2(n) * 2^-1 - 1 + // = n / 2 - 1. 
+ // Considering comparisons from leaf and non-leaf nodes, we can estimate the + // number of comparisons in a simple closed form : + // n + n / 2 - 1 = n * 3 / 2 - 1 + if (NumCaseCluster <= 3) { + // Suppose a comparison includes one compare and one conditional branch. + Cost += NumCaseCluster * 2 * InlineConstants::InstrCost; + return false; + } + int64_t ExpectedNumberOfCompare = 3 * (uint64_t)NumCaseCluster / 2 - 1; + uint64_t SwitchCost = + ExpectedNumberOfCompare * 2 * InlineConstants::InstrCost; + Cost = std::min((uint64_t)INT_MAX, SwitchCost + Cost); return false; } @@ -1098,19 +1159,10 @@ bool CallAnalyzer::analyzeBlock(BasicBlock *BB, // is expensive or the function has the "use-soft-float" attribute, this may // eventually become a library call. Treat the cost as such. if (I->getType()->isFloatingPointTy()) { - bool hasSoftFloatAttr = false; - // If the function has the "use-soft-float" attribute, mark it as // expensive. - if (F.hasFnAttribute("use-soft-float")) { - Attribute Attr = F.getFnAttribute("use-soft-float"); - StringRef Val = Attr.getValueAsString(); - if (Val == "true") - hasSoftFloatAttr = true; - } - if (TTI.getFPOpCost(I->getType()) == TargetTransformInfo::TCC_Expensive || - hasSoftFloatAttr) + (F.getFnAttribute("use-soft-float").getValueAsString() == "true")) Cost += InlineConstants::CallPenalty; } @@ -1155,7 +1207,6 @@ ConstantInt *CallAnalyzer::stripAndComputeInBoundsConstantOffsets(Value *&V) { if (!V->getType()->isPointerTy()) return nullptr; - const DataLayout &DL = F.getParent()->getDataLayout(); unsigned IntPtrWidth = DL.getPointerSizeInBits(); APInt Offset = APInt::getNullValue(IntPtrWidth); @@ -1212,7 +1263,6 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { FiftyPercentVectorBonus = 3 * Threshold / 2; TenPercentVectorBonus = 3 * Threshold / 4; - const DataLayout &DL = F.getParent()->getDataLayout(); // Track whether the post-inlining function would have more than one basic // block. A single basic block is often intended for inlining. Balloon the @@ -1225,36 +1275,10 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // the rest of the function body. Threshold += (SingleBBBonus + FiftyPercentVectorBonus); - // Give out bonuses per argument, as the instructions setting them up will - // be gone after inlining. - for (unsigned I = 0, E = CS.arg_size(); I != E; ++I) { - if (CS.isByValArgument(I)) { - // We approximate the number of loads and stores needed by dividing the - // size of the byval type by the target's pointer size. - PointerType *PTy = cast<PointerType>(CS.getArgument(I)->getType()); - unsigned TypeSize = DL.getTypeSizeInBits(PTy->getElementType()); - unsigned PointerSize = DL.getPointerSizeInBits(); - // Ceiling division. - unsigned NumStores = (TypeSize + PointerSize - 1) / PointerSize; - - // If it generates more than 8 stores it is likely to be expanded as an - // inline memcpy so we take that as an upper bound. Otherwise we assume - // one load and one store per word copied. - // FIXME: The maxStoresPerMemcpy setting from the target should be used - // here instead of a magic number of 8, but it's not available via - // DataLayout. - NumStores = std::min(NumStores, 8U); + // Give out bonuses for the callsite, as the instructions setting them up + // will be gone after inlining. + Cost -= getCallsiteCost(CS, DL); - Cost -= 2 * NumStores * InlineConstants::InstrCost; - } else { - // For non-byval arguments subtract off one instruction per call - // argument. 
- Cost -= InlineConstants::InstrCost; - } - } - // The call instruction also disappears after inlining. - Cost -= InlineConstants::InstrCost + InlineConstants::CallPenalty; - // If there is only one call of the function, and it has internal linkage, // the cost of inlining it drops dramatically. bool OnlyOneCallAndLocalLinkage = @@ -1371,7 +1395,7 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { Value *Cond = SI->getCondition(); if (ConstantInt *SimpleCond = dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) { - BBWorklist.insert(SI->findCaseValue(SimpleCond).getCaseSuccessor()); + BBWorklist.insert(SI->findCaseValue(SimpleCond)->getCaseSuccessor()); continue; } } @@ -1430,13 +1454,6 @@ LLVM_DUMP_METHOD void CallAnalyzer::dump() { } #endif -/// \brief Test that two functions either have or have not the given attribute -/// at the same time. -template <typename AttrKind> -static bool attributeMatches(Function *F1, Function *F2, AttrKind Attr) { - return F1->getFnAttribute(Attr) == F2->getFnAttribute(Attr); -} - /// \brief Test that there are no attribute conflicts between Caller and Callee /// that prevent inlining. static bool functionsHaveCompatibleAttributes(Function *Caller, @@ -1446,18 +1463,52 @@ static bool functionsHaveCompatibleAttributes(Function *Caller, AttributeFuncs::areInlineCompatible(*Caller, *Callee); } +int llvm::getCallsiteCost(CallSite CS, const DataLayout &DL) { + int Cost = 0; + for (unsigned I = 0, E = CS.arg_size(); I != E; ++I) { + if (CS.isByValArgument(I)) { + // We approximate the number of loads and stores needed by dividing the + // size of the byval type by the target's pointer size. + PointerType *PTy = cast<PointerType>(CS.getArgument(I)->getType()); + unsigned TypeSize = DL.getTypeSizeInBits(PTy->getElementType()); + unsigned PointerSize = DL.getPointerSizeInBits(); + // Ceiling division. + unsigned NumStores = (TypeSize + PointerSize - 1) / PointerSize; + + // If it generates more than 8 stores it is likely to be expanded as an + // inline memcpy so we take that as an upper bound. Otherwise we assume + // one load and one store per word copied. + // FIXME: The maxStoresPerMemcpy setting from the target should be used + // here instead of a magic number of 8, but it's not available via + // DataLayout. + NumStores = std::min(NumStores, 8U); + + Cost += 2 * NumStores * InlineConstants::InstrCost; + } else { + // For non-byval arguments subtract off one instruction per call + // argument. + Cost += InlineConstants::InstrCost; + } + } + // The call instruction also disappears after inlining. + Cost += InlineConstants::InstrCost + InlineConstants::CallPenalty; + return Cost; +} + InlineCost llvm::getInlineCost( CallSite CS, const InlineParams &Params, TargetTransformInfo &CalleeTTI, std::function<AssumptionCache &(Function &)> &GetAssumptionCache, + Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI, ProfileSummaryInfo *PSI) { return getInlineCost(CS, CS.getCalledFunction(), Params, CalleeTTI, - GetAssumptionCache, PSI); + GetAssumptionCache, GetBFI, PSI); } InlineCost llvm::getInlineCost( CallSite CS, Function *Callee, const InlineParams &Params, TargetTransformInfo &CalleeTTI, std::function<AssumptionCache &(Function &)> &GetAssumptionCache, + Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI, ProfileSummaryInfo *PSI) { // Cannot inline indirect calls. 
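Two of the cost estimates introduced above are easy to sanity-check with plain arithmetic: the byval bookkeeping now centralized in getCallsiteCost, and the closed-form comparison count visitSwitchInst uses for a balanced tree of case clusters. A standalone sketch; the InstrCost/CallPenalty numbers are the InlineConstants values as of this change, hard-coded here only so the snippet compiles on its own:

#include <algorithm>
#include <cstdint>
#include <cstdio>

static const int InstrCost = 5;    // InlineConstants::InstrCost at this revision
static const int CallPenalty = 25; // InlineConstants::CallPenalty at this revision

// Cost credited back for a single byval argument plus the call instruction,
// following the arithmetic in getCallsiteCost above.
static int byValCallsiteCost(unsigned TypeSizeInBits, unsigned PointerSizeInBits) {
  unsigned NumStores = (TypeSizeInBits + PointerSizeInBits - 1) / PointerSizeInBits;
  NumStores = std::min(NumStores, 8u);       // memcpy-expansion upper bound
  return 2 * (int)NumStores * InstrCost      // one load + one store per word
         + InstrCost + CallPenalty;          // the call itself disappears too
}

// The recursion quoted in visitSwitchInst for a balanced case-cluster tree.
static uint64_t clusterCompares(uint64_t N) {
  return N <= 3 ? N : 1 + clusterCompares(N / 2) + clusterCompares(N - N / 2);
}

int main() {
  // A 256-bit struct passed byval on a 64-bit target: four words copied.
  std::printf("byval callsite cost: %d\n", byValCallsiteCost(256, 64)); // 70
  // Compare the recursion against the closed-form estimate 3*N/2 - 1.
  for (uint64_t N : {4, 16, 100})
    std::printf("clusters=%llu recursion=%llu closed-form=%llu\n",
                (unsigned long long)N,
                (unsigned long long)clusterCompares(N),
                (unsigned long long)(3 * N / 2 - 1));
}

Running this shows the closed form matching the recursion exactly for small and power-of-two cluster counts and only drifting above it for irregular counts such as 100, which is why the comment in the patch calls it an estimate.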
@@ -1492,7 +1543,8 @@ InlineCost llvm::getInlineCost( DEBUG(llvm::dbgs() << " Analyzing call of " << Callee->getName() << "...\n"); - CallAnalyzer CA(CalleeTTI, GetAssumptionCache, PSI, *Callee, CS, Params); + CallAnalyzer CA(CalleeTTI, GetAssumptionCache, GetBFI, PSI, *Callee, CS, + Params); bool ShouldInline = CA.analyzeCall(CS); DEBUG(CA.dump()); @@ -1565,7 +1617,9 @@ InlineParams llvm::getInlineParams(int Threshold) { // Set the HotCallSiteThreshold knob from the -hot-callsite-threshold. Params.HotCallSiteThreshold = HotCallSiteThreshold; - // Set the OptMinSizeThreshold and OptSizeThreshold params only if the + // Set the ColdCallSiteThreshold knob from the -inline-cold-callsite-threshold. + Params.ColdCallSiteThreshold = ColdCallSiteThreshold; + // Set the OptMinSizeThreshold and OptSizeThreshold params only if the // -inlinehint-threshold commandline option is not explicitly given. If that // option is present, then its value applies even for callees with size and diff --git a/contrib/llvm/lib/Analysis/InstructionSimplify.cpp b/contrib/llvm/lib/Analysis/InstructionSimplify.cpp index 796e6e444980..66ac847455cd 100644 --- a/contrib/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/contrib/llvm/lib/Analysis/InstructionSimplify.cpp @@ -21,9 +21,12 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" @@ -34,6 +37,7 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Support/KnownBits.h" #include <algorithm> using namespace llvm; using namespace llvm::PatternMatch; @@ -45,49 +49,30 @@ enum { RecursionLimit = 3 }; STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumReassoc, "Number of reassociations"); -namespace { -struct Query { - const DataLayout &DL; - const TargetLibraryInfo *TLI; - const DominatorTree *DT; - AssumptionCache *AC; - const Instruction *CxtI; - - Query(const DataLayout &DL, const TargetLibraryInfo *tli, - const DominatorTree *dt, AssumptionCache *ac = nullptr, - const Instruction *cxti = nullptr) - : DL(DL), TLI(tli), DT(dt), AC(ac), CxtI(cxti) {} -}; -} // end anonymous namespace - -static Value *SimplifyAndInst(Value *, Value *, const Query &, unsigned); -static Value *SimplifyBinOp(unsigned, Value *, Value *, const Query &, +static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyFPBinOp(unsigned, Value *, Value *, const FastMathFlags &, - const Query &, unsigned); -static Value *SimplifyCmpInst(unsigned, Value *, Value *, const Query &, + const SimplifyQuery &, unsigned); +static Value *SimplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse); -static Value *SimplifyOrInst(Value *, Value *, const Query &, unsigned); -static Value *SimplifyXorInst(Value *, Value *, const Query &, unsigned); + const SimplifyQuery &Q, unsigned MaxRecurse); +static Value *SimplifyOrInst(Value *, Value *, const 
SimplifyQuery &, unsigned); +static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyCastInst(unsigned, Value *, Type *, - const Query &, unsigned); + const SimplifyQuery &, unsigned); -/// For a boolean type, or a vector of boolean type, return false, or -/// a vector with every element false, as appropriate for the type. +/// For a boolean type or a vector of boolean type, return false or a vector +/// with every element false. static Constant *getFalse(Type *Ty) { - assert(Ty->getScalarType()->isIntegerTy(1) && - "Expected i1 type or a vector of i1!"); - return Constant::getNullValue(Ty); + return ConstantInt::getFalse(Ty); } -/// For a boolean type, or a vector of boolean type, return true, or -/// a vector with every element true, as appropriate for the type. +/// For a boolean type or a vector of boolean type, return true or a vector +/// with every element true. static Constant *getTrue(Type *Ty) { - assert(Ty->getScalarType()->isIntegerTy(1) && - "Expected i1 type or a vector of i1!"); - return Constant::getAllOnesValue(Ty); + return ConstantInt::getTrue(Ty); } /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"? @@ -118,13 +103,8 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { return false; // If we have a DominatorTree then do a precise test. - if (DT) { - if (!DT->isReachableFromEntry(P->getParent())) - return true; - if (!DT->isReachableFromEntry(I->getParent())) - return false; + if (DT) return DT->dominates(I, P); - } // Otherwise, if the instruction is in the entry block and is not an invoke, // then it obviously dominates all phi nodes. @@ -140,10 +120,9 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { /// given by OpcodeToExpand, while "A" corresponds to LHS and "B op' C" to RHS. /// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)". /// Returns the simplified value, or null if no simplification was performed. -static Value *ExpandBinOp(unsigned Opcode, Value *LHS, Value *RHS, - unsigned OpcToExpand, const Query &Q, - unsigned MaxRecurse) { - Instruction::BinaryOps OpcodeToExpand = (Instruction::BinaryOps)OpcToExpand; +static Value *ExpandBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, + Instruction::BinaryOps OpcodeToExpand, + const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -199,9 +178,10 @@ static Value *ExpandBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// Generic simplifications for associative binary operations. /// Returns the simpler value, or null if none was found. -static Value *SimplifyAssociativeBinOp(unsigned Opc, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { - Instruction::BinaryOps Opcode = (Instruction::BinaryOps)Opc; +static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, + Value *LHS, Value *RHS, + const SimplifyQuery &Q, + unsigned MaxRecurse) { assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); // Recursion is always used, so bail out at once if we already hit the limit. @@ -298,8 +278,9 @@ static Value *SimplifyAssociativeBinOp(unsigned Opc, Value *LHS, Value *RHS, /// try to simplify the binop by seeing whether evaluating it on both branches /// of the select results in the same value. Returns the common value if so, /// otherwise returns null. 
-static Value *ThreadBinOpOverSelect(unsigned Opcode, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { +static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -370,7 +351,7 @@ static Value *ThreadBinOpOverSelect(unsigned Opcode, Value *LHS, Value *RHS, /// comparison by seeing whether both branches of the select result in the same /// value. Returns the common value if so, otherwise returns null. static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q, + Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) @@ -451,8 +432,9 @@ static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, /// try to simplify the binop by seeing whether evaluating it on the incoming /// phi values yields the same result for every value. If so returns the common /// value, otherwise returns null. -static Value *ThreadBinOpOverPHI(unsigned Opcode, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { +static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, const SimplifyQuery &Q, + unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -494,7 +476,7 @@ static Value *ThreadBinOpOverPHI(unsigned Opcode, Value *LHS, Value *RHS, /// yields the same result every time. If so returns the common result, /// otherwise returns null. static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; @@ -527,17 +509,26 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, return CommonValue; } +static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode, + Value *&Op0, Value *&Op1, + const SimplifyQuery &Q) { + if (auto *CLHS = dyn_cast<Constant>(Op0)) { + if (auto *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); + + // Canonicalize the constant to the RHS if this is a commutative operation. + if (Instruction::isCommutative(Opcode)) + std::swap(Op0, Op1); + } + return nullptr; +} + /// Given operands for an Add, see if we can fold the result. /// If not, this returns null. static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. 
- std::swap(Op0, Op1); - } + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) + return C; // X + undef -> undef if (match(Op1, m_Undef())) @@ -556,12 +547,20 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, return Y; // X + ~X -> -1 since ~X = -X-1 + Type *Ty = Op0->getType(); if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) - return Constant::getAllOnesValue(Op0->getType()); + return Constant::getAllOnesValue(Ty); + + // add nsw/nuw (xor Y, signmask), signmask --> Y + // The no-wrapping add guarantees that the top bit will be set by the add. + // Therefore, the xor must be clearing the already set sign bit of Y. + if ((isNSW || isNUW) && match(Op1, m_SignMask()) && + match(Op0, m_Xor(m_Value(Y), m_SignMask()))) + return Y; /// i1 add -> xor. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->getScalarType()->isIntegerTy(1)) if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -583,11 +582,8 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, } Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Query) { + return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query, RecursionLimit); } /// \brief Compute the base pointer and cumulative constant offsets for V. @@ -664,10 +660,9 @@ static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, /// Given operands for a Sub, see if we can fold the result. /// If not, this returns null. static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL); + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q)) + return C; // X - undef -> undef // undef - X -> undef @@ -688,11 +683,8 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, if (isNUW) return Op0; - unsigned BitWidth = Op1->getType()->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(Op1, KnownZero, KnownOne, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (KnownZero == ~APInt::getSignBit(BitWidth)) { + KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (Known.Zero.isMaxSignedValue()) { // Op1 is either 0 or the minimum signed value. If the sub is NSW, then // Op1 must be 0 because negating the minimum signed value is undefined. if (isNSW) @@ -779,7 +771,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, return ConstantExpr::getIntegerCast(Result, Op0->getType(), true); // i1 sub -> xor. 
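// Illustration only: brute-force check of the modular identity behind the new
// signmask fold, over i8. The real fold additionally requires the add to carry
// nsw or nuw flags, which this plain C++ loop does not model.
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned v = 0; v < 256; ++v) {
    uint8_t Y = static_cast<uint8_t>(v);
    uint8_t Folded = static_cast<uint8_t>((Y ^ 0x80u) + 0x80u);
    if (Folded != Y) {
      std::printf("mismatch at %u\n", v);
      return 1;
    }
  }
  std::printf("ok: (Y ^ signmask) + signmask == Y for all i8 Y\n");
}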
- if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->getScalarType()->isIntegerTy(1)) if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -796,24 +788,16 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, } Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); } /// Given operands for an FAdd, see if we can fold the result. If not, this /// returns null. static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) + return C; // fadd X, -0 ==> X if (match(Op1, m_NegZero())) @@ -845,11 +829,9 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL); - } + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; // fsub X, 0 ==> X if (match(Op1, m_Zero())) @@ -878,40 +860,28 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, } /// Given the operands for an FMul, see if we can fold the result -static Value *SimplifyFMulInst(Value *Op0, Value *Op1, - FastMathFlags FMF, - const Query &Q, - unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } +static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) + return C; - // fmul X, 1.0 ==> X - if (match(Op1, m_FPOne())) - return Op0; + // fmul X, 1.0 ==> X + if (match(Op1, m_FPOne())) + return Op0; - // fmul nnan nsz X, 0 ==> 0 - if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero())) - return Op1; + // fmul nnan nsz X, 0 ==> 0 + if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero())) + return Op1; - return nullptr; + return nullptr; } /// Given operands for a Mul, see if we can fold the result. /// If not, this returns null. 
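// Illustration only: the i1 folds touched here (sub and add both reduce to xor
// on a one-bit type) confirmed by exhaustive enumeration. Minimal sketch.
#include <cstdio>

int main() {
  for (int a = 0; a <= 1; ++a)
    for (int b = 0; b <= 1; ++b) {
      int Sub = (a - b) & 1; // i1 sub, reduced mod 2
      int Add = (a + b) & 1; // i1 add, reduced mod 2
      if (Sub != (a ^ b) || Add != (a ^ b)) {
        std::printf("mismatch at a=%d b=%d\n", a, b);
        return 1;
      }
    }
  std::printf("ok: i1 sub == i1 add == xor\n");
}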
-static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) + return C; // X * undef -> 0 if (match(Op1, m_Undef())) @@ -932,7 +902,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, return X; // i1 mul -> and. - if (MaxRecurse && Op0->getType()->isIntegerTy(1)) + if (MaxRecurse && Op0->getType()->getScalarType()->isIntegerTy(1)) if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse-1)) return V; @@ -964,77 +934,87 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, } Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit); } + Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); } Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFMulInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit); } -Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyMulInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyMulInst(Op0, Op1, Q, RecursionLimit); } -/// Given operands for an SDiv or UDiv, see if we can fold the result. -/// If not, this returns null. -static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); - - bool isSigned = Opcode == Instruction::SDiv; +/// Check for common or similar folds of integer division or integer remainder. +static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { + Type *Ty = Op0->getType(); // X / undef -> undef + // X % undef -> undef if (match(Op1, m_Undef())) return Op1; - // X / 0 -> undef, we don't need to preserve faults! + // X / 0 -> undef + // X % 0 -> undef + // We don't need to preserve faults! if (match(Op1, m_Zero())) - return UndefValue::get(Op1->getType()); + return UndefValue::get(Ty); + + // If any element of a constant divisor vector is zero, the whole op is undef. 
+ auto *Op1C = dyn_cast<Constant>(Op1); + if (Op1C && Ty->isVectorTy()) { + unsigned NumElts = Ty->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = Op1C->getAggregateElement(i); + if (Elt && Elt->isNullValue()) + return UndefValue::get(Ty); + } + } // undef / X -> 0 + // undef % X -> 0 if (match(Op0, m_Undef())) - return Constant::getNullValue(Op0->getType()); + return Constant::getNullValue(Ty); - // 0 / X -> 0, we don't need to preserve faults! + // 0 / X -> 0 + // 0 % X -> 0 if (match(Op0, m_Zero())) return Op0; + // X / X -> 1 + // X % X -> 0 + if (Op0 == Op1) + return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty); + // X / 1 -> X - if (match(Op1, m_One())) - return Op0; + // X % 1 -> 0 + // If this is a boolean op (single-bit element type), we can't have + // division-by-zero or remainder-by-zero, so assume the divisor is 1. + if (match(Op1, m_One()) || Ty->getScalarType()->isIntegerTy(1)) + return IsDiv ? Op0 : Constant::getNullValue(Ty); - if (Op0->getType()->isIntegerTy(1)) - // It can't be division by zero, hence it must be division by one. - return Op0; + return nullptr; +} - // X / X -> 1 - if (Op0 == Op1) - return ConstantInt::get(Op0->getType(), 1); +/// Given operands for an SDiv or UDiv, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; + + if (Value *V = simplifyDivRem(Op0, Op1, true)) + return V; + + bool isSigned = Opcode == Instruction::SDiv; // (X * Y) / Y -> X if the multiplication does not overflow. Value *X = nullptr, *Y = nullptr; @@ -1061,7 +1041,7 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (!isSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) && match(Op1, m_ConstantInt(C2))) { bool Overflow; - C1->getValue().umul_ov(C2->getValue(), Overflow); + (void)C1->getValue().umul_ov(C2->getValue(), Overflow); if (Overflow) return Constant::getNullValue(Op0->getType()); } @@ -1083,7 +1063,7 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, /// Given operands for an SDiv, see if we can fold the result. /// If not, this returns null. -static Value *SimplifySDivInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse)) return V; @@ -1091,17 +1071,13 @@ static Value *SimplifySDivInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifySDivInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifySDivInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a UDiv, see if we can fold the result. /// If not, this returns null. 
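// Illustration only: exhaustive i8 check of the plain integer identities that
// simplifyDivRem relies on. The undef and divide-by-zero cases above have no
// C++ analogue, so the divisor of zero is skipped here.
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned v = 1; v < 256; ++v) {
    uint8_t X = static_cast<uint8_t>(v);
    if (X / X != 1 || X % X != 0 || X / 1 != X || X % 1 != 0) {
      std::printf("mismatch at %u\n", v);
      return 1;
    }
  }
  std::printf("ok: X/X==1, X%%X==0, X/1==X, X%%1==0\n");
}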
-static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse)) return V; @@ -1119,16 +1095,15 @@ static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyUDivInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyUDivInst(Op0, Op1, Q, RecursionLimit); } static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &Q, unsigned) { + const SimplifyQuery &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) + return C; + // undef / X -> undef (the undef could be a snan). if (match(Op0, m_Undef())) return Op0; @@ -1166,49 +1141,19 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, } Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFDivInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit); } /// Given operands for an SRem or URem, see if we can fold the result. /// If not, this returns null. static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); - - // X % undef -> undef - if (match(Op1, m_Undef())) - return Op1; - - // undef % X -> 0 - if (match(Op0, m_Undef())) - return Constant::getNullValue(Op0->getType()); - - // 0 % X -> 0, we don't need to preserve faults! - if (match(Op0, m_Zero())) - return Op0; + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; - // X % 0 -> undef, we don't need to preserve faults! - if (match(Op1, m_Zero())) - return UndefValue::get(Op0->getType()); - - // X % 1 -> 0 - if (match(Op1, m_One())) - return Constant::getNullValue(Op0->getType()); - - if (Op0->getType()->isIntegerTy(1)) - // It can't be remainder by zero, hence it must be remainder by one. - return Constant::getNullValue(Op0->getType()); - - // X % X -> 0 - if (Op0 == Op1) - return Constant::getNullValue(Op0->getType()); + if (Value *V = simplifyDivRem(Op0, Op1, false)) + return V; // (X % Y) % Y -> X % Y if ((Opcode == Instruction::SRem && @@ -1234,7 +1179,7 @@ static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, /// Given operands for an SRem, see if we can fold the result. /// If not, this returns null. 
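// Illustration only: the rem-of-rem fold kept above, (X % Y) % Y -> X % Y,
// confirmed exhaustively for unsigned i8. Minimal sketch; srem behaves the
// same way for the values where it is defined.
#include <cstdio>

int main() {
  for (unsigned x = 0; x < 256; ++x)
    for (unsigned y = 1; y < 256; ++y)
      if ((x % y) % y != x % y) {
        std::printf("mismatch at x=%u y=%u\n", x, y);
        return 1;
      }
  std::printf("ok: (X %% Y) %% Y == X %% Y\n");
}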
-static Value *SimplifySRemInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse)) return V; @@ -1242,17 +1187,13 @@ static Value *SimplifySRemInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifySRemInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifySRemInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a URem, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRem(Instruction::URem, Op0, Op1, Q, MaxRecurse)) return V; @@ -1270,16 +1211,15 @@ static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyURemInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyURemInst(Op0, Op1, Q, RecursionLimit); } static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const Query &, unsigned) { + const SimplifyQuery &Q, unsigned) { + if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) + return C; + // undef % X -> undef (the undef could be a snan). if (match(Op0, m_Undef())) return Op0; @@ -1298,12 +1238,8 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, } Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFRemInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit); } /// Returns true if a shift by \c Amount always yields undef. @@ -1335,11 +1271,10 @@ static bool isUndefShift(Value *Amount) { /// Given operands for an Shl, LShr or AShr, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, - const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) - if (Constant *C1 = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); +static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) + return C; // 0 shift by X -> 0 if (match(Op0, m_Zero())) @@ -1367,18 +1302,14 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, // If any bits in the shift amount make that value greater than or equal to // the number of bits in the type, the shift is undefined. 
- unsigned BitWidth = Op1->getType()->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(Op1, KnownZero, KnownOne, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (KnownOne.getLimitedValue() >= BitWidth) + KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (Known.One.getLimitedValue() >= Known.getBitWidth()) return UndefValue::get(Op0->getType()); // If all valid bits in the shift amount are known zero, the first operand is // unchanged. - unsigned NumValidShiftBits = Log2_32_Ceil(BitWidth); - APInt ShiftAmountMask = APInt::getLowBitsSet(BitWidth, NumValidShiftBits); - if ((KnownZero & ShiftAmountMask) == ShiftAmountMask) + unsigned NumValidShiftBits = Log2_32_Ceil(Known.getBitWidth()); + if (Known.countMinTrailingZeros() >= NumValidShiftBits) return Op0; return nullptr; @@ -1386,8 +1317,8 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, /// \brief Given operands for an Shl, LShr or AShr, see if we can /// fold the result. If not, this returns null. -static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, - bool isExact, const Query &Q, +static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, bool isExact, const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse)) return V; @@ -1403,12 +1334,8 @@ static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, // The low bit cannot be shifted out of an exact shift if it is set. if (isExact) { - unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); - APInt Op0KnownZero(BitWidth, 0); - APInt Op0KnownOne(BitWidth, 0); - computeKnownBits(Op0, Op0KnownZero, Op0KnownOne, Q.DL, /*Depth=*/0, Q.AC, - Q.CxtI, Q.DT); - if (Op0KnownOne[0]) + KnownBits Op0Known = computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + if (Op0Known.One[0]) return Op0; } @@ -1418,7 +1345,7 @@ static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1, /// Given operands for an Shl, see if we can fold the result. /// If not, this returns null. static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, Q, MaxRecurse)) return V; @@ -1435,17 +1362,14 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, } Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); } /// Given operands for an LShr, see if we can fold the result. /// If not, this returns null. 
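// Illustration only: the reasoning behind the "all valid shift bits known
// zero" fold above. A width-W shift only needs the low ceil(log2(W)) bits of
// its amount, so if those are zero the amount is either 0 or at least W, and
// a shift by W or more is undefined, leaving "shift X, Amt ==> X" as legal.
#include <cstdio>

static unsigned log2Ceil(unsigned W) {
  unsigned Bits = 0;
  while ((1u << Bits) < W)
    ++Bits;
  return Bits;
}

int main() {
  const unsigned Width = 32;
  unsigned NumValidShiftBits = log2Ceil(Width); // 5 for i32
  // Amounts whose low 5 bits are clear: 0, 32, 64, ... -- either a no-op
  // shift or an undefined one.
  for (unsigned Amt = 0; Amt < 256; Amt += 1u << NumValidShiftBits)
    std::printf("amount %u: %s\n", Amt,
                Amt == 0 ? "no-op" : "would be undef (>= width)");
}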
static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q, MaxRecurse)) return V; @@ -1459,18 +1383,14 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, } Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyLShrInst(Op0, Op1, isExact, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyLShrInst(Op0, Op1, isExact, Q, RecursionLimit); } /// Given operands for an AShr, see if we can fold the result. /// If not, this returns null. static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q, MaxRecurse)) return V; @@ -1493,14 +1413,12 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, } Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyAShrInst(Op0, Op1, isExact, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyAShrInst(Op0, Op1, isExact, Q, RecursionLimit); } +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp, ICmpInst *UnsignedICmp, bool IsAnd) { Value *X, *Y; @@ -1569,29 +1487,75 @@ static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { /// Commuted variants are assumed to be handled by calling this function again /// with the parameters swapped. -static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { - if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) - return X; +static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { + ICmpInst::Predicate Pred0, Pred1; + Value *A ,*B; + if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || + !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) + return nullptr; - if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1)) - return X; + // We have (icmp Pred0, A, B) | (icmp Pred1, A, B). + // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we + // can eliminate Op0 from this 'or'. + if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) + return Op1; + + // Check for any combination of predicates that cover the entire range of + // possibilities. + if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || + (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) || + (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) || + (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE)) + return getTrue(Op0->getType()); + + return nullptr; +} + +/// Test if a pair of compares with a shared operand and 2 constants has an +/// empty set intersection, full set union, or if one compare is a superset of +/// the other. +static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool IsAnd) { + // Look for this pattern: {and/or} (icmp X, C0), (icmp X, C1)). 
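// Illustration only: the constant-range fold being introduced here, checked
// exhaustively over i8 using the examples from its comments. Minimal sketch;
// the real code uses ConstantRange intersection, union, and containment.
#include <cstdio>

int main() {
  for (int x = -128; x <= 127; ++x) {
    // Superset: (icmp sgt X, 4) && (icmp sgt X, 42) --> icmp sgt X, 42
    if ((x > 4 && x > 42) != (x > 42)) return 1;
    // Superset: (icmp sgt X, 4) || (icmp sgt X, 42) --> icmp sgt X, 4
    if ((x > 4 || x > 42) != (x > 4)) return 1;
    // Empty intersection: (icmp slt X, 3) && (icmp sgt X, 10) --> false
    if (x < 3 && x > 10) return 1;
    // Full union: (icmp slt X, 10) || (icmp sgt X, 3) --> true
    if (!(x < 10 || x > 3)) return 1;
  }
  std::printf("ok\n");
}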
+ if (Cmp0->getOperand(0) != Cmp1->getOperand(0)) + return nullptr; - // Look for this pattern: (icmp V, C0) & (icmp V, C1)). - Type *ITy = Op0->getType(); - ICmpInst::Predicate Pred0, Pred1; const APInt *C0, *C1; - Value *V; - if (match(Op0, m_ICmp(Pred0, m_Value(V), m_APInt(C0))) && - match(Op1, m_ICmp(Pred1, m_Specific(V), m_APInt(C1)))) { - // Make a constant range that's the intersection of the two icmp ranges. - // If the intersection is empty, we know that the result is false. - auto Range0 = ConstantRange::makeAllowedICmpRegion(Pred0, *C0); - auto Range1 = ConstantRange::makeAllowedICmpRegion(Pred1, *C1); - if (Range0.intersectWith(Range1).isEmptySet()) - return getFalse(ITy); - } + if (!match(Cmp0->getOperand(1), m_APInt(C0)) || + !match(Cmp1->getOperand(1), m_APInt(C1))) + return nullptr; + + auto Range0 = ConstantRange::makeExactICmpRegion(Cmp0->getPredicate(), *C0); + auto Range1 = ConstantRange::makeExactICmpRegion(Cmp1->getPredicate(), *C1); + + // For and-of-compares, check if the intersection is empty: + // (icmp X, C0) && (icmp X, C1) --> empty set --> false + if (IsAnd && Range0.intersectWith(Range1).isEmptySet()) + return getFalse(Cmp0->getType()); + + // For or-of-compares, check if the union is full: + // (icmp X, C0) || (icmp X, C1) --> full set --> true + if (!IsAnd && Range0.unionWith(Range1).isFullSet()) + return getTrue(Cmp0->getType()); + + // Is one range a superset of the other? + // If this is and-of-compares, take the smaller set: + // (icmp sgt X, 4) && (icmp sgt X, 42) --> icmp sgt X, 42 + // If this is or-of-compares, take the larger set: + // (icmp sgt X, 4) || (icmp sgt X, 42) --> icmp sgt X, 4 + if (Range0.contains(Range1)) + return IsAnd ? Cmp1 : Cmp0; + if (Range1.contains(Range0)) + return IsAnd ? Cmp0 : Cmp1; + return nullptr; +} + +static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { // (icmp (add V, C0), C1) & (icmp V, C0) + ICmpInst::Predicate Pred0, Pred1; + const APInt *C0, *C1; + Value *V; if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) return nullptr; @@ -1602,6 +1566,7 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (AddInst->getOperand(1) != Op1->getOperand(1)) return nullptr; + Type *ITy = Op0->getType(); bool isNSW = AddInst->hasNoSignedWrap(); bool isNUW = AddInst->hasNoUnsignedWrap(); @@ -1632,17 +1597,132 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { return nullptr; } +static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) + return X; + if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/true)) + return X; + + if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1)) + return X; + if (Value *X = simplifyAndOfICmpsWithSameOperands(Op1, Op0)) + return X; + + if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true)) + return X; + + if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1)) + return X; + if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0)) + return X; + + return nullptr; +} + +static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { + // (icmp (add V, C0), C1) | (icmp V, C0) + ICmpInst::Predicate Pred0, Pred1; + const APInt *C0, *C1; + Value *V; + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) + return nullptr; + + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) + return nullptr; + + auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + if (AddInst->getOperand(1) != 
Op1->getOperand(1)) + return nullptr; + + Type *ITy = Op0->getType(); + bool isNSW = AddInst->hasNoSignedWrap(); + bool isNUW = AddInst->hasNoUnsignedWrap(); + + const APInt Delta = *C1 - *C0; + if (C0->isStrictlyPositive()) { + if (Delta == 2) { + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + if (Delta == 1) { + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + } + if (C0->getBoolValue() && isNUW) { + if (Delta == 2) + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Delta == 1) + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + return nullptr; +} + +static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false)) + return X; + if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/false)) + return X; + + if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1)) + return X; + if (Value *X = simplifyOrOfICmpsWithSameOperands(Op1, Op0)) + return X; + + if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false)) + return X; + + if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1)) + return X; + if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0)) + return X; + + return nullptr; +} + +static Value *simplifyAndOrOfICmps(Value *Op0, Value *Op1, bool IsAnd) { + // Look through casts of the 'and' operands to find compares. + auto *Cast0 = dyn_cast<CastInst>(Op0); + auto *Cast1 = dyn_cast<CastInst>(Op1); + if (Cast0 && Cast1 && Cast0->getOpcode() == Cast1->getOpcode() && + Cast0->getSrcTy() == Cast1->getSrcTy()) { + Op0 = Cast0->getOperand(0); + Op1 = Cast1->getOperand(0); + } + + auto *Cmp0 = dyn_cast<ICmpInst>(Op0); + auto *Cmp1 = dyn_cast<ICmpInst>(Op1); + if (!Cmp0 || !Cmp1) + return nullptr; + + Value *V = + IsAnd ? simplifyAndOfICmps(Cmp0, Cmp1) : simplifyOrOfICmps(Cmp0, Cmp1); + if (!V) + return nullptr; + if (!Cast0) + return V; + + // If we looked through casts, we can only handle a constant simplification + // because we are not allowed to create a cast instruction here. + if (auto *C = dyn_cast<Constant>(V)) + return ConstantExpr::getCast(Cast0->getOpcode(), C, Cast0->getType()); + + return nullptr; +} + /// Given operands for an And, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q)) + return C; // X & undef -> 0 if (match(Op1, m_Undef())) @@ -1676,6 +1756,24 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, (A == Op0 || B == Op0)) return Op0; + // A mask that only clears known zeros of a shifted value is a no-op. 
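// Illustration only: brute-force i8 check of the shifted-mask no-op fold added
// here. When (~Mask >> ShAmt) is zero, the mask cannot clear any bit the shift
// has not already cleared. Minimal sketch; the lshr variant is symmetric.
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned m = 0; m < 256; ++m)
    for (unsigned sh = 0; sh < 8; ++sh) {
      uint8_t Mask = static_cast<uint8_t>(m);
      if (static_cast<uint8_t>(~Mask) >> sh != 0)
        continue; // Mask may clear a bit the shift did not already clear.
      for (unsigned x = 0; x < 256; ++x) {
        uint8_t Shl = static_cast<uint8_t>(x << sh);
        if ((Shl & Mask) != Shl) {
          std::printf("mismatch: x=%u sh=%u mask=%u\n", x, sh, m);
          return 1;
        }
      }
    }
  std::printf("ok: and (shl X, ShAmt), Mask == shl X, ShAmt\n");
}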
+ Value *X; + const APInt *Mask; + const APInt *ShAmt; + if (match(Op1, m_APInt(Mask))) { + // If all bits in the inverted and shifted mask are clear: + // and (shl X, ShAmt), Mask --> shl X, ShAmt + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShAmt))) && + (~(*Mask)).lshr(*ShAmt).isNullValue()) + return Op0; + + // If all bits in the inverted and shifted mask are clear: + // and (lshr X, ShAmt), Mask --> lshr X, ShAmt + if (match(Op0, m_LShr(m_Value(X), m_APInt(ShAmt))) && + (~(*Mask)).shl(*ShAmt).isNullValue()) + return Op0; + } + // A & (-A) = A if A is a power of two or zero. if (match(Op0, m_Neg(m_Specific(Op1))) || match(Op1, m_Neg(m_Specific(Op0)))) { @@ -1687,32 +1785,8 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, return Op1; } - if (auto *ICILHS = dyn_cast<ICmpInst>(Op0)) { - if (auto *ICIRHS = dyn_cast<ICmpInst>(Op1)) { - if (Value *V = SimplifyAndOfICmps(ICILHS, ICIRHS)) - return V; - if (Value *V = SimplifyAndOfICmps(ICIRHS, ICILHS)) - return V; - } - } - - // The compares may be hidden behind casts. Look through those and try the - // same folds as above. - auto *Cast0 = dyn_cast<CastInst>(Op0); - auto *Cast1 = dyn_cast<CastInst>(Op1); - if (Cast0 && Cast1 && Cast0->getOpcode() == Cast1->getOpcode() && - Cast0->getSrcTy() == Cast1->getSrcTy()) { - auto *Cmp0 = dyn_cast<ICmpInst>(Cast0->getOperand(0)); - auto *Cmp1 = dyn_cast<ICmpInst>(Cast1->getOperand(0)); - if (Cmp0 && Cmp1) { - Instruction::CastOps CastOpc = Cast0->getOpcode(); - Type *ResultType = Cast0->getType(); - if (auto *V = dyn_cast_or_null<Constant>(SimplifyAndOfICmps(Cmp0, Cmp1))) - return ConstantExpr::getCast(CastOpc, V, ResultType); - if (auto *V = dyn_cast_or_null<Constant>(SimplifyAndOfICmps(Cmp1, Cmp0))) - return ConstantExpr::getCast(CastOpc, V, ResultType); - } - } + if (Value *V = simplifyAndOrOfICmps(Op0, Op1, true)) + return V; // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, @@ -1746,105 +1820,16 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyAndInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); -} - -/// Commuted variants are assumed to be handled by calling this function again -/// with the parameters swapped. -static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { - ICmpInst::Predicate Pred0, Pred1; - Value *A ,*B; - if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || - !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) - return nullptr; - - // We have (icmp Pred0, A, B) | (icmp Pred1, A, B). - // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we - // can eliminate Op0 from this 'or'. - if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) - return Op1; - - // Check for any combination of predicates that cover the entire range of - // possibilities. - if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || - (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) || - (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) || - (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE)) - return getTrue(Op0->getType()); - - return nullptr; -} - -/// Commuted variants are assumed to be handled by calling this function again -/// with the parameters swapped. 
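// Illustration only: the "A & (-A) == A when A is a power of two or zero" fold
// retained above, checked exhaustively over i8. Minimal sketch.
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned v = 0; v < 256; ++v) {
    uint8_t A = static_cast<uint8_t>(v);
    bool PowerOfTwoOrZero = (A & (A - 1)) == 0;
    uint8_t Masked = A & static_cast<uint8_t>(-A);
    if (PowerOfTwoOrZero && Masked != A) {
      std::printf("mismatch at %u\n", v);
      return 1;
    }
  }
  std::printf("ok\n");
}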
-static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { - if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false)) - return X; - - if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1)) - return X; - - // (icmp (add V, C0), C1) | (icmp V, C0) - ICmpInst::Predicate Pred0, Pred1; - const APInt *C0, *C1; - Value *V; - if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) - return nullptr; - - if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) - return nullptr; - - auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); - if (AddInst->getOperand(1) != Op1->getOperand(1)) - return nullptr; - - Type *ITy = Op0->getType(); - bool isNSW = AddInst->hasNoSignedWrap(); - bool isNUW = AddInst->hasNoUnsignedWrap(); - - const APInt Delta = *C1 - *C0; - if (C0->isStrictlyPositive()) { - if (Delta == 2) { - if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) - return getTrue(ITy); - if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW) - return getTrue(ITy); - } - if (Delta == 1) { - if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE) - return getTrue(ITy); - if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW) - return getTrue(ITy); - } - } - if (C0->getBoolValue() && isNUW) { - if (Delta == 2) - if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) - return getTrue(ITy); - if (Delta == 1) - if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE) - return getTrue(ITy); - } - - return nullptr; +Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyAndInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for an Or, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q)) + return C; // X | undef -> -1 if (match(Op1, m_Undef())) @@ -1888,14 +1873,45 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, (A == Op0 || B == Op0)) return Constant::getAllOnesValue(Op0->getType()); - if (auto *ICILHS = dyn_cast<ICmpInst>(Op0)) { - if (auto *ICIRHS = dyn_cast<ICmpInst>(Op1)) { - if (Value *V = SimplifyOrOfICmps(ICILHS, ICIRHS)) - return V; - if (Value *V = SimplifyOrOfICmps(ICIRHS, ICILHS)) - return V; - } - } + // (A & ~B) | (A ^ B) -> (A ^ B) + // (~B & A) | (A ^ B) -> (A ^ B) + // (A & ~B) | (B ^ A) -> (B ^ A) + // (~B & A) | (B ^ A) -> (B ^ A) + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && + (match(Op0, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op0, m_c_And(m_Not(m_Specific(A)), m_Specific(B))))) + return Op1; + + // Commute the 'or' operands. 
+ // (A ^ B) | (A & ~B) -> (A ^ B) + // (A ^ B) | (~B & A) -> (A ^ B) + // (B ^ A) | (A & ~B) -> (B ^ A) + // (B ^ A) | (~B & A) -> (B ^ A) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + (match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B))))) + return Op0; + + // (A & B) | (~A ^ B) -> (~A ^ B) + // (B & A) | (~A ^ B) -> (~A ^ B) + // (A & B) | (B ^ ~A) -> (B ^ ~A) + // (B & A) | (B ^ ~A) -> (B ^ ~A) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + (match(Op1, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op1, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) + return Op1; + + // (~A ^ B) | (A & B) -> (~A ^ B) + // (~A ^ B) | (B & A) -> (~A ^ B) + // (B ^ ~A) | (A & B) -> (B ^ ~A) + // (B ^ ~A) | (B & A) -> (B ^ ~A) + if (match(Op1, m_And(m_Value(A), m_Value(B))) && + (match(Op0, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) + return Op0; + + if (Value *V = simplifyAndOrOfICmps(Op0, Op1, false)) + return V; // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, @@ -1914,37 +1930,27 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, MaxRecurse)) return V; - // (A & C)|(B & D) - Value *C = nullptr, *D = nullptr; - if (match(Op0, m_And(m_Value(A), m_Value(C))) && - match(Op1, m_And(m_Value(B), m_Value(D)))) { - ConstantInt *C1 = dyn_cast<ConstantInt>(C); - ConstantInt *C2 = dyn_cast<ConstantInt>(D); - if (C1 && C2 && (C1->getValue() == ~C2->getValue())) { + // (A & C1)|(B & C2) + const APInt *C1, *C2; + if (match(Op0, m_And(m_Value(A), m_APInt(C1))) && + match(Op1, m_And(m_Value(B), m_APInt(C2)))) { + if (*C1 == ~*C2) { // (A & C1)|(B & C2) // If we have: ((V + N) & C1) | (V & C2) // .. and C2 = ~C1 and C2 is 0+1+ and (N & C2) == 0 // replace with V+N. - Value *V1, *V2; - if ((C2->getValue() & (C2->getValue() + 1)) == 0 && // C2 == 0+1+ - match(A, m_Add(m_Value(V1), m_Value(V2)))) { + Value *N; + if (C2->isMask() && // C2 == 0+1+ + match(A, m_c_Add(m_Specific(B), m_Value(N)))) { // Add commutes, try both ways. - if (V1 == B && - MaskedValueIsZero(V2, C2->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return A; - if (V2 == B && - MaskedValueIsZero(V1, C2->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (MaskedValueIsZero(N, *C2, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return A; } // Or commutes, try both ways. - if ((C1->getValue() & (C1->getValue() + 1)) == 0 && - match(B, m_Add(m_Value(V1), m_Value(V2)))) { + if (C1->isMask() && + match(B, m_c_Add(m_Specific(A), m_Value(N)))) { // Add commutes, try both ways. - if (V1 == A && - MaskedValueIsZero(V2, C1->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) - return B; - if (V2 == A && - MaskedValueIsZero(V1, C1->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return B; } } @@ -1959,25 +1965,16 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyOrInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyOrInst(Op0, Op1, Q, RecursionLimit); } /// Given operands for a Xor, see if we can fold the result. 
/// If not, this returns null. -static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, +static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) - return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL); - - // Canonicalize the constant to the RHS. - std::swap(Op0, Op1); - } + if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q)) + return C; // A ^ undef -> undef if (match(Op1, m_Undef())) @@ -2013,14 +2010,11 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, return nullptr; } -Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyXorInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { + return ::SimplifyXorInst(Op0, Op1, Q, RecursionLimit); } + static Type *GetCompareTy(Value *Op) { return CmpInst::makeCmpResultType(Op->getType()); } @@ -2254,34 +2248,55 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, /// Fold an icmp when its operands have i1 scalar type. static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q) { + Value *RHS, const SimplifyQuery &Q) { Type *ITy = GetCompareTy(LHS); // The return type. Type *OpTy = LHS->getType(); // The operand type. if (!OpTy->getScalarType()->isIntegerTy(1)) return nullptr; - switch (Pred) { - default: - break; - case ICmpInst::ICMP_EQ: - // X == 1 -> X - if (match(RHS, m_One())) - return LHS; - break; - case ICmpInst::ICMP_NE: - // X != 0 -> X - if (match(RHS, m_Zero())) + // A boolean compared to true/false can be simplified in 14 out of the 20 + // (10 predicates * 2 constants) possible combinations. Cases not handled here + // require a 'not' of the LHS, so those must be transformed in InstCombine. 
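// Illustration only: the 14-of-20 claim in the comment above, checked by
// enumerating both i1 values, modelling an i1 as 0/1 unsigned and 0/-1 signed.
// Minimal sketch; several comparisons below are intentionally tautological.
#include <cstdio>

int main() {
  for (int x = 0; x <= 1; ++x) {
    unsigned ux = x;     // unsigned view of the i1 value
    int sx = x ? -1 : 0; // signed view of the i1 value
    bool X = x;
    // Compared against false (0):
    if ((ux != 0) != X || (ux > 0) != X || (sx < 0) != X) return 1;
    if ((ux < 0u) != false || (sx > 0) != false) return 1;
    if ((ux >= 0u) != true || (sx <= 0) != true) return 1;
    // Compared against true (unsigned 1, signed -1):
    if ((ux == 1) != X || (ux >= 1) != X || (sx <= -1) != X) return 1;
    if ((ux > 1) != false || (sx < -1) != false) return 1;
    if ((ux <= 1) != true || (sx >= -1) != true) return 1;
  }
  std::printf("ok: 14 of the 20 i1 compare forms fold to X, true, or false\n");
}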
+ if (match(RHS, m_Zero())) { + switch (Pred) { + case CmpInst::ICMP_NE: // X != 0 -> X + case CmpInst::ICMP_UGT: // X >u 0 -> X + case CmpInst::ICMP_SLT: // X <s 0 -> X return LHS; - break; - case ICmpInst::ICMP_UGT: - // X >u 0 -> X - if (match(RHS, m_Zero())) + + case CmpInst::ICMP_ULT: // X <u 0 -> false + case CmpInst::ICMP_SGT: // X >s 0 -> false + return getFalse(ITy); + + case CmpInst::ICMP_UGE: // X >=u 0 -> true + case CmpInst::ICMP_SLE: // X <=s 0 -> true + return getTrue(ITy); + + default: break; + } + } else if (match(RHS, m_One())) { + switch (Pred) { + case CmpInst::ICMP_EQ: // X == 1 -> X + case CmpInst::ICMP_UGE: // X >=u 1 -> X + case CmpInst::ICMP_SLE: // X <=s -1 -> X return LHS; + + case CmpInst::ICMP_UGT: // X >u 1 -> false + case CmpInst::ICMP_SLT: // X <s -1 -> false + return getFalse(ITy); + + case CmpInst::ICMP_ULE: // X <=u 1 -> true + case CmpInst::ICMP_SGE: // X >=s -1 -> true + return getTrue(ITy); + + default: break; + } + } + + switch (Pred) { + default: break; case ICmpInst::ICMP_UGE: - // X >=u 1 -> X - if (match(RHS, m_One())) - return LHS; if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) return getTrue(ITy); break; @@ -2296,16 +2311,6 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) return getTrue(ITy); break; - case ICmpInst::ICMP_SLT: - // X <s 0 -> X - if (match(RHS, m_Zero())) - return LHS; - break; - case ICmpInst::ICMP_SLE: - // X <=s -1 -> X - if (match(RHS, m_One())) - return LHS; - break; case ICmpInst::ICMP_ULE: if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) return getTrue(ITy); @@ -2317,12 +2322,11 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, /// Try hard to fold icmp with zero RHS because this is a common case. static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q) { + Value *RHS, const SimplifyQuery &Q) { if (!match(RHS, m_Zero())) return nullptr; Type *ITy = GetCompareTy(LHS); // The return type. 
- bool LHSKnownNonNegative, LHSKnownNegative; switch (Pred) { default: llvm_unreachable("Unknown ICmp predicate!"); @@ -2340,43 +2344,202 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(ITy); break; - case ICmpInst::ICMP_SLT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + case ICmpInst::ICMP_SLT: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getTrue(ITy); - if (LHSKnownNonNegative) + if (LHSKnown.isNonNegative()) return getFalse(ITy); break; - case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SLE: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getTrue(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (LHSKnown.isNonNegative() && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getFalse(ITy); break; - case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SGE: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getFalse(ITy); - if (LHSKnownNonNegative) + if (LHSKnown.isNonNegative()) return getTrue(ITy); break; - case ICmpInst::ICMP_SGT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SGT: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getFalse(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (LHSKnown.isNonNegative() && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(ITy); break; } + } return nullptr; } +/// Many binary operators with a constant operand have an easy-to-compute +/// range of outputs. This can be used to fold a comparison to always true or +/// always false. +static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { + unsigned Width = Lower.getBitWidth(); + const APInt *C; + switch (BO.getOpcode()) { + case Instruction::Add: + if (match(BO.getOperand(1), m_APInt(C)) && *C != 0) { + // FIXME: If we have both nuw and nsw, we should reduce the range further. + if (BO.hasNoUnsignedWrap()) { + // 'add nuw x, C' produces [C, UINT_MAX]. + Lower = *C; + } else if (BO.hasNoSignedWrap()) { + if (C->isNegative()) { + // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C]. + Lower = APInt::getSignedMinValue(Width); + Upper = APInt::getSignedMaxValue(Width) + *C + 1; + } else { + // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX]. + Lower = APInt::getSignedMinValue(Width) + *C; + Upper = APInt::getSignedMaxValue(Width) + 1; + } + } + } + break; + + case Instruction::And: + if (match(BO.getOperand(1), m_APInt(C))) + // 'and x, C' produces [0, C]. + Upper = *C + 1; + break; + + case Instruction::Or: + if (match(BO.getOperand(1), m_APInt(C))) + // 'or x, C' produces [C, UINT_MAX]. + Lower = *C; + break; + + case Instruction::AShr: + if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { + // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C]. 
+ Lower = APInt::getSignedMinValue(Width).ashr(*C); + Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + unsigned ShiftAmount = Width - 1; + if (*C != 0 && BO.isExact()) + ShiftAmount = C->countTrailingZeros(); + if (C->isNegative()) { + // 'ashr C, x' produces [C, C >> (Width-1)] + Lower = *C; + Upper = C->ashr(ShiftAmount) + 1; + } else { + // 'ashr C, x' produces [C >> (Width-1), C] + Lower = C->ashr(ShiftAmount); + Upper = *C + 1; + } + } + break; + + case Instruction::LShr: + if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { + // 'lshr x, C' produces [0, UINT_MAX >> C]. + Upper = APInt::getAllOnesValue(Width).lshr(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + // 'lshr C, x' produces [C >> (Width-1), C]. + unsigned ShiftAmount = Width - 1; + if (*C != 0 && BO.isExact()) + ShiftAmount = C->countTrailingZeros(); + Lower = C->lshr(ShiftAmount); + Upper = *C + 1; + } + break; + + case Instruction::Shl: + if (match(BO.getOperand(0), m_APInt(C))) { + if (BO.hasNoUnsignedWrap()) { + // 'shl nuw C, x' produces [C, C << CLZ(C)] + Lower = *C; + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw? + if (C->isNegative()) { + // 'shl nsw C, x' produces [C << CLO(C)-1, C] + unsigned ShiftAmount = C->countLeadingOnes() - 1; + Lower = C->shl(ShiftAmount); + Upper = *C + 1; + } else { + // 'shl nsw C, x' produces [C, C << CLZ(C)-1] + unsigned ShiftAmount = C->countLeadingZeros() - 1; + Lower = *C; + Upper = C->shl(ShiftAmount) + 1; + } + } + } + break; + + case Instruction::SDiv: + if (match(BO.getOperand(1), m_APInt(C))) { + APInt IntMin = APInt::getSignedMinValue(Width); + APInt IntMax = APInt::getSignedMaxValue(Width); + if (C->isAllOnesValue()) { + // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] + // where C != -1 and C != 0 and C != 1 + Lower = IntMin + 1; + Upper = IntMax + 1; + } else if (C->countLeadingZeros() < Width - 1) { + // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C] + // where C != -1 and C != 0 and C != 1 + Lower = IntMin.sdiv(*C); + Upper = IntMax.sdiv(*C); + if (Lower.sgt(Upper)) + std::swap(Lower, Upper); + Upper = Upper + 1; + assert(Upper != Lower && "Upper part of range has wrapped!"); + } + } else if (match(BO.getOperand(0), m_APInt(C))) { + if (C->isMinSignedValue()) { + // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. + Lower = *C; + Upper = Lower.lshr(1) + 1; + } else { + // 'sdiv C, x' produces [-|C|, |C|]. + Upper = C->abs() + 1; + Lower = (-Upper) + 1; + } + } + break; + + case Instruction::UDiv: + if (match(BO.getOperand(1), m_APInt(C)) && *C != 0) { + // 'udiv x, C' produces [0, UINT_MAX / C]. + Upper = APInt::getMaxValue(Width).udiv(*C) + 1; + } else if (match(BO.getOperand(0), m_APInt(C))) { + // 'udiv C, x' produces [0, C]. + Upper = *C + 1; + } + break; + + case Instruction::SRem: + if (match(BO.getOperand(1), m_APInt(C))) { + // 'srem x, C' produces (-|C|, |C|). + Upper = C->abs(); + Lower = (-Upper) + 1; + } + break; + + case Instruction::URem: + if (match(BO.getOperand(1), m_APInt(C))) + // 'urem x, C' produces [0, C). 
+ Upper = *C; + break; + + default: + break; + } +} + static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, Value *RHS) { const APInt *C; @@ -2390,114 +2553,12 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, if (RHS_CR.isFullSet()) return ConstantInt::getTrue(GetCompareTy(RHS)); - // Many binary operators with constant RHS have easy to compute constant - // range. Use them to check whether the comparison is a tautology. + // Find the range of possible values for binary operators. unsigned Width = C->getBitWidth(); APInt Lower = APInt(Width, 0); APInt Upper = APInt(Width, 0); - const APInt *C2; - if (match(LHS, m_URem(m_Value(), m_APInt(C2)))) { - // 'urem x, C2' produces [0, C2). - Upper = *C2; - } else if (match(LHS, m_SRem(m_Value(), m_APInt(C2)))) { - // 'srem x, C2' produces (-|C2|, |C2|). - Upper = C2->abs(); - Lower = (-Upper) + 1; - } else if (match(LHS, m_UDiv(m_APInt(C2), m_Value()))) { - // 'udiv C2, x' produces [0, C2]. - Upper = *C2 + 1; - } else if (match(LHS, m_UDiv(m_Value(), m_APInt(C2)))) { - // 'udiv x, C2' produces [0, UINT_MAX / C2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (*C2 != 0) - Upper = NegOne.udiv(*C2) + 1; - } else if (match(LHS, m_SDiv(m_APInt(C2), m_Value()))) { - if (C2->isMinSignedValue()) { - // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. - Lower = *C2; - Upper = Lower.lshr(1) + 1; - } else { - // 'sdiv C2, x' produces [-|C2|, |C2|]. - Upper = C2->abs() + 1; - Lower = (-Upper) + 1; - } - } else if (match(LHS, m_SDiv(m_Value(), m_APInt(C2)))) { - APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (C2->isAllOnesValue()) { - // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] - // where C2 != -1 and C2 != 0 and C2 != 1 - Lower = IntMin + 1; - Upper = IntMax + 1; - } else if (C2->countLeadingZeros() < Width - 1) { - // 'sdiv x, C2' produces [INT_MIN / C2, INT_MAX / C2] - // where C2 != -1 and C2 != 0 and C2 != 1 - Lower = IntMin.sdiv(*C2); - Upper = IntMax.sdiv(*C2); - if (Lower.sgt(Upper)) - std::swap(Lower, Upper); - Upper = Upper + 1; - assert(Upper != Lower && "Upper part of range has wrapped!"); - } - } else if (match(LHS, m_NUWShl(m_APInt(C2), m_Value()))) { - // 'shl nuw C2, x' produces [C2, C2 << CLZ(C2)] - Lower = *C2; - Upper = Lower.shl(Lower.countLeadingZeros()) + 1; - } else if (match(LHS, m_NSWShl(m_APInt(C2), m_Value()))) { - if (C2->isNegative()) { - // 'shl nsw C2, x' produces [C2 << CLO(C2)-1, C2] - unsigned ShiftAmount = C2->countLeadingOnes() - 1; - Lower = C2->shl(ShiftAmount); - Upper = *C2 + 1; - } else { - // 'shl nsw C2, x' produces [C2, C2 << CLZ(C2)-1] - unsigned ShiftAmount = C2->countLeadingZeros() - 1; - Lower = *C2; - Upper = C2->shl(ShiftAmount) + 1; - } - } else if (match(LHS, m_LShr(m_Value(), m_APInt(C2)))) { - // 'lshr x, C2' produces [0, UINT_MAX >> C2]. - APInt NegOne = APInt::getAllOnesValue(Width); - if (C2->ult(Width)) - Upper = NegOne.lshr(*C2) + 1; - } else if (match(LHS, m_LShr(m_APInt(C2), m_Value()))) { - // 'lshr C2, x' produces [C2 >> (Width-1), C2]. - unsigned ShiftAmount = Width - 1; - if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = C2->countTrailingZeros(); - Lower = C2->lshr(ShiftAmount); - Upper = *C2 + 1; - } else if (match(LHS, m_AShr(m_Value(), m_APInt(C2)))) { - // 'ashr x, C2' produces [INT_MIN >> C2, INT_MAX >> C2]. 
- APInt IntMin = APInt::getSignedMinValue(Width); - APInt IntMax = APInt::getSignedMaxValue(Width); - if (C2->ult(Width)) { - Lower = IntMin.ashr(*C2); - Upper = IntMax.ashr(*C2) + 1; - } - } else if (match(LHS, m_AShr(m_APInt(C2), m_Value()))) { - unsigned ShiftAmount = Width - 1; - if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact()) - ShiftAmount = C2->countTrailingZeros(); - if (C2->isNegative()) { - // 'ashr C2, x' produces [C2, C2 >> (Width-1)] - Lower = *C2; - Upper = C2->ashr(ShiftAmount) + 1; - } else { - // 'ashr C2, x' produces [C2 >> (Width-1), C2] - Lower = C2->ashr(ShiftAmount); - Upper = *C2 + 1; - } - } else if (match(LHS, m_Or(m_Value(), m_APInt(C2)))) { - // 'or x, C2' produces [C2, UINT_MAX]. - Lower = *C2; - } else if (match(LHS, m_And(m_Value(), m_APInt(C2)))) { - // 'and x, C2' produces [0, C2]. - Upper = *C2 + 1; - } else if (match(LHS, m_NUWAdd(m_Value(), m_APInt(C2)))) { - // 'add nuw x, C2' produces [C2, UINT_MAX]. - Lower = *C2; - } + if (auto *BO = dyn_cast<BinaryOperator>(LHS)) + setLimitsForBinOp(*BO, Lower, Upper); ConstantRange LHS_CR = Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true); @@ -2516,8 +2577,11 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, return nullptr; } +/// TODO: A large part of this logic is duplicated in InstCombine's +/// foldICmpBinOp(). We should be able to share that and avoid the code +/// duplication. static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q, + Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { Type *ITy = GetCompareTy(LHS); // The return type. @@ -2597,15 +2661,11 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return getTrue(ITy); if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { - bool RHSKnownNonNegative, RHSKnownNegative; - bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, Q.DL, 0, - Q.AC, Q.CxtI, Q.DT); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (RHSKnownNonNegative && YKnownNegative) + KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (RHSKnown.isNonNegative() && YKnown.isNegative()) return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); - if (RHSKnownNegative || YKnownNonNegative) + if (RHSKnown.isNegative() || YKnown.isNonNegative()) return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); } } @@ -2617,15 +2677,11 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return getFalse(ITy); if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) { - bool LHSKnownNonNegative, LHSKnownNegative; - bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, - Q.AC, Q.CxtI, Q.DT); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNonNegative && YKnownNegative) + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNonNegative() && YKnown.isNegative()) return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy); - if (LHSKnownNegative || YKnownNonNegative) + if (LHSKnown.isNegative() || YKnown.isNonNegative()) return Pred == ICmpInst::ICMP_SGT ? 
getFalse(ITy) : getTrue(ITy); } } @@ -2672,28 +2728,27 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp pred (urem X, Y), Y if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { - bool KnownNonNegative, KnownNegative; switch (Pred) { default: break; case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: return getFalse(ITy); case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2703,28 +2758,27 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp pred X, (urem Y, X) if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) { - bool KnownNonNegative, KnownNegative; switch (Pred) { default: break; case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: return getTrue(ITy); case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2780,7 +2834,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return ConstantInt::getTrue(RHS->getContext()); } } - if (CIVal->isSignBit() && *CI2Val == 1) { + if (CIVal->isSignMask() && *CI2Val == 1) { if (Pred == ICmpInst::ICMP_UGT) return ConstantInt::getFalse(RHS->getContext()); if (Pred == ICmpInst::ICMP_ULE) @@ -2796,10 +2850,19 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, break; case Instruction::UDiv: case Instruction::LShr: - if (ICmpInst::isSigned(Pred)) + if (ICmpInst::isSigned(Pred) || !LBO->isExact() || !RBO->isExact()) break; - LLVM_FALLTHROUGH; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; case Instruction::SDiv: + if (!ICmpInst::isEquality(Pred) || !LBO->isExact() || !RBO->isExact()) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; case Instruction::AShr: if (!LBO->isExact() || !RBO->isExact()) break; @@ -2827,7 +2890,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, /// Simplify integer comparisons where at least one operand of the compare /// matches an integer min/max idiom. 
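A minimal sketch, not part of the upstream change, looking back at the range-based folding a few hunks up: setLimitsForBinOp hands simplifyICmpWithConstant a conservative [Lower, Upper) range for the LHS, and the compare folds when that range sits entirely inside (or entirely outside) the region selected by the predicate and constant. The sketch below assumes LLVM of roughly this vintage and uses 'urem x, 4', whose range [0, 4) is one of the cases computed above.

    #include "llvm/IR/ConstantRange.h"
    #include "llvm/IR/InstrTypes.h"
    #include "llvm/Support/raw_ostream.h"
    using namespace llvm;

    int main() {
      unsigned Width = 8;
      ConstantRange LHS_CR(APInt(Width, 0), APInt(Width, 4)); // 'urem x, 4' lies in [0, 4)
      ConstantRange UltFour = ConstantRange::makeAllowedICmpRegion(
          CmpInst::ICMP_ULT, ConstantRange(APInt(Width, 4)));
      if (UltFour.contains(LHS_CR))
        outs() << "icmp ult (urem x, 4), 4 -> true\n";      // printed
      if (UltFour.inverse().contains(LHS_CR))
        outs() << "icmp ult (urem x, 4), 4 -> false\n";     // not printed
      return 0;
    }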
static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, - Value *RHS, const Query &Q, + Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { Type *ITy = GetCompareTy(LHS); // The return type. Value *A, *B; @@ -3031,7 +3094,7 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, /// Given operands for an ICmpInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); @@ -3064,8 +3127,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // If both operands have range metadata, use the metadata // to simplify the comparison. if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { - auto RHS_Instr = dyn_cast<Instruction>(RHS); - auto LHS_Instr = dyn_cast<Instruction>(LHS); + auto RHS_Instr = cast<Instruction>(RHS); + auto LHS_Instr = cast<Instruction>(LHS); if (RHS_Instr->getMetadata(LLVMContext::MD_range) && LHS_Instr->getMetadata(LLVMContext::MD_range)) { @@ -3302,12 +3365,9 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (ICmpInst::isEquality(Pred)) { const APInt *RHSVal; if (match(RHS, m_APInt(RHSVal))) { - unsigned BitWidth = RHSVal->getBitWidth(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, Q.DL, /*Depth=*/0, Q.AC, - Q.CxtI, Q.DT); - if (((LHSKnownZero & *RHSVal) != 0) || ((LHSKnownOne & ~(*RHSVal)) != 0)) + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.Zero.intersects(*RHSVal) || + !LHSKnown.One.isSubsetOf(*RHSVal)) return Pred == ICmpInst::ICMP_EQ ? ConstantInt::getFalse(ITy) : ConstantInt::getTrue(ITy); } @@ -3329,18 +3389,14 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyICmpInst(Predicate, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit); } /// Given operands for an FCmpInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, - FastMathFlags FMF, const Query &Q, + FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!"); @@ -3462,22 +3518,22 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, - FastMathFlags FMF, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, - Query(DL, TLI, DT, AC, CxtI), RecursionLimit); + FastMathFlags FMF, const SimplifyQuery &Q) { + return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); } /// See if V simplifies when its operand Op is replaced with RepOp. 
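A small sketch, not part of the upstream change: the equality fold in the hunk above rejects 'icmp eq' whenever a bit known to be zero in the LHS is set in the constant, or a bit known to be one is clear in it. The same APInt/KnownBits calls are exercised below, assuming LLVM of roughly this vintage; the concrete values are arbitrary.

    #include "llvm/ADT/APInt.h"
    #include "llvm/Support/KnownBits.h"
    #include "llvm/Support/raw_ostream.h"
    using namespace llvm;

    int main() {
      KnownBits LHSKnown(8);
      LHSKnown.One.setBit(0);          // the low bit of LHS is known to be 1...
      APInt RHSVal(8, 0x10);           // ...but the constant 16 has it clear
      bool NeverEqual = LHSKnown.Zero.intersects(RHSVal) ||
                        !LHSKnown.One.isSubsetOf(RHSVal);
      outs() << (NeverEqual ? "icmp eq folds to false\n" : "cannot fold\n");
      return 0;
    }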
static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, - const Query &Q, + const SimplifyQuery &Q, unsigned MaxRecurse) { // Trivial replacement. if (V == Op) return RepOp; + // We cannot replace a constant, and shouldn't even try. + if (isa<Constant>(Op)) + return nullptr; + auto *I = dyn_cast<Instruction>(V); if (!I) return nullptr; @@ -3620,7 +3676,7 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *TrueVal, /// Try to simplify a select instruction when its condition operand is an /// integer comparison. static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, - Value *FalseVal, const Query &Q, + Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) { ICmpInst::Predicate Pred; Value *CmpLHS, *CmpRHS; @@ -3699,7 +3755,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, /// Given operands for a SelectInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, - Value *FalseVal, const Query &Q, + Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) { // select true, X, Y -> X // select false, X, Y -> Y @@ -3715,9 +3771,9 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, return TrueVal; if (isa<UndefValue>(CondVal)) { // select undef, X, Y -> X or Y - if (isa<Constant>(TrueVal)) - return TrueVal; - return FalseVal; + if (isa<Constant>(FalseVal)) + return FalseVal; + return TrueVal; } if (isa<UndefValue>(TrueVal)) // select C, undef, X -> X return FalseVal; @@ -3732,18 +3788,14 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, } Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifySelectInst(Cond, TrueVal, FalseVal, - Query(DL, TLI, DT, AC, CxtI), RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Q, RecursionLimit); } /// Given operands for an GetElementPtrInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, - const Query &Q, unsigned) { + const SimplifyQuery &Q, unsigned) { // The type of the GEP pointer operand. unsigned AS = cast<PointerType>(Ops[0]->getType()->getScalarType())->getAddressSpace(); @@ -3757,6 +3809,8 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, Type *GEPTy = PointerType::get(LastType, AS); if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType())) + GEPTy = VectorType::get(GEPTy, VT->getNumElements()); if (isa<UndefValue>(Ops[0])) return UndefValue::get(GEPTy); @@ -3851,18 +3905,14 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, } Value *llvm::SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyGEPInst(SrcTy, Ops, - Query(DL, TLI, DT, AC, CxtI), RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyGEPInst(SrcTy, Ops, Q, RecursionLimit); } /// Given operands for an InsertValueInst, see if we can fold the result. /// If not, this returns null. 
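A minimal sketch, not part of the upstream change: the GEP hunk above extends the result-type computation so that a vector index operand also yields a vector of pointers, even when the pointer operand is scalar. The type computation assumed there is mirrored below for LLVM of roughly this vintage; the helper name gepResultTy and the concrete types are invented for illustration.

    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/Support/raw_ostream.h"
    using namespace llvm;

    // Start from the scalar result pointer type, then vectorize it if either
    // the pointer operand or the first index operand is a vector.
    static Type *gepResultTy(Type *ScalarPtrTy, Type *PtrOpTy, Type *IdxOpTy) {
      Type *GEPTy = ScalarPtrTy;
      if (auto *VT = dyn_cast<VectorType>(PtrOpTy))
        GEPTy = VectorType::get(GEPTy, VT->getNumElements());
      else if (auto *VT = dyn_cast<VectorType>(IdxOpTy))
        GEPTy = VectorType::get(GEPTy, VT->getNumElements());
      return GEPTy;
    }

    int main() {
      LLVMContext Ctx;
      Type *I32 = Type::getInt32Ty(Ctx);
      Type *I32Ptr = PointerType::get(I32, 0);
      Type *IdxVec = VectorType::get(Type::getInt64Ty(Ctx), 4);
      // gep i32* %p, <4 x i64> %i  ->  result type is <4 x i32*>
      gepResultTy(I32Ptr, I32Ptr, IdxVec)->print(outs());
      outs() << "\n";
      return 0;
    }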
static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, - ArrayRef<unsigned> Idxs, const Query &Q, + ArrayRef<unsigned> Idxs, const SimplifyQuery &Q, unsigned) { if (Constant *CAgg = dyn_cast<Constant>(Agg)) if (Constant *CVal = dyn_cast<Constant>(Val)) @@ -3888,18 +3938,16 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, return nullptr; } -Value *llvm::SimplifyInsertValueInst( - Value *Agg, Value *Val, ArrayRef<unsigned> Idxs, const DataLayout &DL, - const TargetLibraryInfo *TLI, const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyInsertValueInst(Agg, Val, Idxs, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val, + ArrayRef<unsigned> Idxs, + const SimplifyQuery &Q) { + return ::SimplifyInsertValueInst(Agg, Val, Idxs, Q, RecursionLimit); } /// Given operands for an ExtractValueInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, - const Query &, unsigned) { + const SimplifyQuery &, unsigned) { if (auto *CAgg = dyn_cast<Constant>(Agg)) return ConstantFoldExtractValueInstruction(CAgg, Idxs); @@ -3922,18 +3970,13 @@ static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, } Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, - AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyExtractValueInst(Agg, Idxs, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyExtractValueInst(Agg, Idxs, Q, RecursionLimit); } /// Given operands for an ExtractElementInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &, +static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &, unsigned) { if (auto *CVec = dyn_cast<Constant>(Vec)) { if (auto *CIdx = dyn_cast<Constant>(Idx)) @@ -3956,15 +3999,13 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &, return nullptr; } -Value *llvm::SimplifyExtractElementInst( - Value *Vec, Value *Idx, const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) { - return ::SimplifyExtractElementInst(Vec, Idx, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); +Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx, + const SimplifyQuery &Q) { + return ::SimplifyExtractElementInst(Vec, Idx, Q, RecursionLimit); } /// See if we can fold the given phi. If not, returns null. -static Value *SimplifyPHINode(PHINode *PN, const Query &Q) { +static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { // If all of the PHI's incoming values are the same then replace the PHI node // with the common value. 
Value *CommonValue = nullptr; @@ -3997,7 +4038,7 @@ static Value *SimplifyPHINode(PHINode *PN, const Query &Q) { } static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, - Type *Ty, const Query &Q, unsigned MaxRecurse) { + Type *Ty, const SimplifyQuery &Q, unsigned MaxRecurse) { if (auto *C = dyn_cast<Constant>(Op)) return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL); @@ -4031,12 +4072,141 @@ static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, } Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, - const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyCastInst(CastOpc, Op, Ty, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit); +} + +/// For the given destination element of a shuffle, peek through shuffles to +/// match a root vector source operand that contains that element in the same +/// vector lane (ie, the same mask index), so we can eliminate the shuffle(s). +static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, + int MaskVal, Value *RootVec, + unsigned MaxRecurse) { + if (!MaxRecurse--) + return nullptr; + + // Bail out if any mask value is undefined. That kind of shuffle may be + // simplified further based on demanded bits or other folds. + if (MaskVal == -1) + return nullptr; + + // The mask value chooses which source operand we need to look at next. + int InVecNumElts = Op0->getType()->getVectorNumElements(); + int RootElt = MaskVal; + Value *SourceOp = Op0; + if (MaskVal >= InVecNumElts) { + RootElt = MaskVal - InVecNumElts; + SourceOp = Op1; + } + + // If the source operand is a shuffle itself, look through it to find the + // matching root vector. + if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) { + return foldIdentityShuffles( + DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1), + SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse); + } + + // TODO: Look through bitcasts? What if the bitcast changes the vector element + // size? + + // The source operand is not a shuffle. Initialize the root vector value for + // this shuffle if that has not been done yet. + if (!RootVec) + RootVec = SourceOp; + + // Give up as soon as a source operand does not match the existing root value. + if (RootVec != SourceOp) + return nullptr; + + // The element must be coming from the same lane in the source vector + // (although it may have crossed lanes in intermediate shuffles). + if (RootElt != DestElt) + return nullptr; + + return RootVec; +} + +static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, + Type *RetTy, const SimplifyQuery &Q, + unsigned MaxRecurse) { + if (isa<UndefValue>(Mask)) + return UndefValue::get(RetTy); + + Type *InVecTy = Op0->getType(); + unsigned MaskNumElts = Mask->getType()->getVectorNumElements(); + unsigned InVecNumElts = InVecTy->getVectorNumElements(); + + SmallVector<int, 32> Indices; + ShuffleVectorInst::getShuffleMask(Mask, Indices); + assert(MaskNumElts == Indices.size() && + "Size of Indices not same as number of mask elements?"); + + // Canonicalization: If mask does not select elements from an input vector, + // replace that input vector with undef. 
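A plain-C++ sketch, not part of the upstream change: foldIdentityShuffles above walks one destination lane at a time through a chain of shuffles and succeeds only if every lane maps back to the same lane of a single root vector of the right type. The two-level check below captures that lane-tracking idea; all names are invented for illustration, and it ignores the second operand of the inner shuffle, which the real code also handles.

    #include <cstdio>
    #include <vector>

    // Mask semantics as in shufflevector: -1 is an undef lane, other values
    // index into the source of the shuffle.
    static bool chainIsIdentity(const std::vector<int> &Outer,
                                const std::vector<int> &Inner,
                                int InnerSrcLanes) {
      for (int DestElt = 0; DestElt != (int)Outer.size(); ++DestElt) {
        int M = Outer[DestElt];
        if (M < 0)                        // undef mask element: give up, like the real code
          return false;
        int Root = Inner[M];              // peek through the inner shuffle
        if (Root < 0 || Root >= InnerSrcLanes)
          return false;                   // undef, or not from the root source
        if (Root != DestElt)              // must stay in the same lane end to end
          return false;
      }
      return (int)Outer.size() == InnerSrcLanes; // result type must match the root
    }

    int main() {
      std::vector<int> Widen = {0, 1, -1, -1}; // <2 x T> widened to <4 x T>
      std::vector<int> Narrow = {0, 1};        // <4 x T> narrowed back to <2 x T>
      printf("%s\n", chainIsIdentity(Narrow, Widen, 2) ? "folds to the source"
                                                       : "does not fold");
      return 0;
    }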
+ bool MaskSelects0 = false, MaskSelects1 = false; + for (unsigned i = 0; i != MaskNumElts; ++i) { + if (Indices[i] == -1) + continue; + if ((unsigned)Indices[i] < InVecNumElts) + MaskSelects0 = true; + else + MaskSelects1 = true; + } + if (!MaskSelects0) + Op0 = UndefValue::get(InVecTy); + if (!MaskSelects1) + Op1 = UndefValue::get(InVecTy); + + auto *Op0Const = dyn_cast<Constant>(Op0); + auto *Op1Const = dyn_cast<Constant>(Op1); + + // If all operands are constant, constant fold the shuffle. + if (Op0Const && Op1Const) + return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask); + + // Canonicalization: if only one input vector is constant, it shall be the + // second one. + if (Op0Const && !Op1Const) { + std::swap(Op0, Op1); + ShuffleVectorInst::commuteShuffleMask(Indices, InVecNumElts); + } + + // A shuffle of a splat is always the splat itself. Legal if the shuffle's + // value type is same as the input vectors' type. + if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0)) + if (isa<UndefValue>(Op1) && RetTy == InVecTy && + OpShuf->getMask()->getSplatValue()) + return Op0; + + // Don't fold a shuffle with undef mask elements. This may get folded in a + // better way using demanded bits or other analysis. + // TODO: Should we allow this? + if (find(Indices, -1) != Indices.end()) + return nullptr; + + // Check if every element of this shuffle can be mapped back to the + // corresponding element of a single root vector. If so, we don't need this + // shuffle. This handles simple identity shuffles as well as chains of + // shuffles that may widen/narrow and/or move elements across lanes and back. + Value *RootVec = nullptr; + for (unsigned i = 0; i != MaskNumElts; ++i) { + // Note that recursion is limited for each vector element, so if any element + // exceeds the limit, this will fail to simplify. + RootVec = + foldIdentityShuffles(i, Op0, Op1, Indices[i], RootVec, MaxRecurse); + + // We can't replace a widening/narrowing shuffle with one of its operands. + if (!RootVec || RootVec->getType() != RetTy) + return nullptr; + } + return RootVec; +} + +/// Given operands for a ShuffleVectorInst, fold the result or return null. +Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, + Type *RetTy, const SimplifyQuery &Q) { + return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit); } //=== Helper functions for higher up the class hierarchy. @@ -4044,64 +4214,46 @@ Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. 
static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { switch (Opcode) { case Instruction::Add: - return SimplifyAddInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::FAdd: return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::Sub: - return SimplifySubInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::FSub: return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - - case Instruction::Mul: return SimplifyMulInst (LHS, RHS, Q, MaxRecurse); + case Instruction::Mul: + return SimplifyMulInst(LHS, RHS, Q, MaxRecurse); case Instruction::FMul: - return SimplifyFMulInst (LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::SDiv: return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); - case Instruction::UDiv: return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); + return SimplifyFMulInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::SDiv: + return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); + case Instruction::UDiv: + return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); case Instruction::FDiv: - return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); - case Instruction::SRem: return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); - case Instruction::URem: return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); + return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + case Instruction::SRem: + return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); + case Instruction::URem: + return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); case Instruction::FRem: - return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::Shl: - return SimplifyShlInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false, - Q, MaxRecurse); + return SimplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::LShr: - return SimplifyLShrInst(LHS, RHS, /*isExact*/false, Q, MaxRecurse); + return SimplifyLShrInst(LHS, RHS, false, Q, MaxRecurse); case Instruction::AShr: - return SimplifyAShrInst(LHS, RHS, /*isExact*/false, Q, MaxRecurse); - case Instruction::And: return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); - case Instruction::Or: return SimplifyOrInst (LHS, RHS, Q, MaxRecurse); - case Instruction::Xor: return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); + return SimplifyAShrInst(LHS, RHS, false, Q, MaxRecurse); + case Instruction::And: + return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Or: + return SimplifyOrInst(LHS, RHS, Q, MaxRecurse); + case Instruction::Xor: + return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); default: - if (Constant *CLHS = dyn_cast<Constant>(LHS)) - if (Constant *CRHS = dyn_cast<Constant>(RHS)) - return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); - - // If the operation is associative, try some generic simplifications. - if (Instruction::isAssociative(Opcode)) - if (Value *V = SimplifyAssociativeBinOp(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - // If the operation is with the result of a select instruction check whether - // operating on either branch of the select always yields the same value. 
- if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) - if (Value *V = ThreadBinOpOverSelect(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - // If the operation is with the result of a phi instruction, check whether - // operating on all incoming values of the phi always yields the same value. - if (isa<PHINode>(LHS) || isa<PHINode>(RHS)) - if (Value *V = ThreadBinOpOverPHI(Opcode, LHS, RHS, Q, MaxRecurse)) - return V; - - return nullptr; + llvm_unreachable("Unexpected opcode"); } } @@ -4110,7 +4262,7 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// In contrast to SimplifyBinOp, try to use FastMathFlag when folding the /// result. In case we don't need FastMathFlags, simply fall to SimplifyBinOp. static Value *SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const FastMathFlags &FMF, const Query &Q, + const FastMathFlags &FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { switch (Opcode) { case Instruction::FAdd: @@ -4127,36 +4279,26 @@ static Value *SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, } Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyBinOp(Opcode, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); } Value *llvm::SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const FastMathFlags &FMF, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyFPBinOp(Opcode, LHS, RHS, FMF, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + FastMathFlags FMF, const SimplifyQuery &Q) { + return ::SimplifyFPBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); } /// Given operands for a CmpInst, see if we can fold the result. 
static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate)) return SimplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse); return SimplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse); } Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyCmpInst(Predicate, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit); } static bool IsIdempotent(Intrinsic::ID ID) { @@ -4249,7 +4391,7 @@ static bool maskIsAllZeroOrUndef(Value *Mask) { template <typename IterTy> static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { Intrinsic::ID IID = F->getIntrinsicID(); unsigned NumOperands = std::distance(ArgBegin, ArgEnd); @@ -4267,6 +4409,7 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, case Intrinsic::fabs: { if (SignBitMustBeZero(*ArgBegin, Q.TLI)) return *ArgBegin; + return nullptr; } default: return nullptr; @@ -4296,19 +4439,21 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: { // X + undef -> undef - if (isa<UndefValue>(RHS)) + if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) return UndefValue::get(ReturnType); return nullptr; } case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: { + // 0 * X -> { 0, false } // X * 0 -> { 0, false } - if (match(RHS, m_Zero())) + if (match(LHS, m_Zero()) || match(RHS, m_Zero())) return Constant::getNullValue(ReturnType); + // undef * X -> { 0, false } // X * undef -> { 0, false } - if (match(RHS, m_Undef())) + if (match(LHS, m_Undef()) || match(RHS, m_Undef())) return Constant::getNullValue(ReturnType); return nullptr; @@ -4342,7 +4487,7 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, template <typename IterTy> static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, - const Query &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, unsigned MaxRecurse) { Type *Ty = V->getType(); if (PointerType *PTy = dyn_cast<PointerType>(Ty)) Ty = PTy->getElementType(); @@ -4377,177 +4522,164 @@ static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, } Value *llvm::SimplifyCall(Value *V, User::op_iterator ArgBegin, - User::op_iterator ArgEnd, const DataLayout &DL, - const TargetLibraryInfo *TLI, const DominatorTree *DT, - AssumptionCache *AC, const Instruction *CxtI) { - return ::SimplifyCall(V, ArgBegin, ArgEnd, Query(DL, TLI, DT, AC, CxtI), - RecursionLimit); + User::op_iterator ArgEnd, const SimplifyQuery &Q) { + return ::SimplifyCall(V, ArgBegin, ArgEnd, Q, RecursionLimit); } Value *llvm::SimplifyCall(Value *V, ArrayRef<Value *> Args, - const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC, - const Instruction *CxtI) { - return ::SimplifyCall(V, Args.begin(), Args.end(), - Query(DL, TLI, DT, AC, CxtI), RecursionLimit); + const SimplifyQuery &Q) { + return ::SimplifyCall(V, Args.begin(), Args.end(), Q, RecursionLimit); } /// See if we can compute a simplified version of this 
instruction. /// If not, this returns null. -Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, AssumptionCache *AC) { + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); Value *Result; switch (I->getOpcode()) { default: - Result = ConstantFoldInstruction(I, DL, TLI); + Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); break; case Instruction::FAdd: Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::Add: Result = SimplifyAddInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, - TLI, DT, AC, I); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); break; case Instruction::FSub: Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::Sub: Result = SimplifySubInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, - TLI, DT, AC, I); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); break; case Instruction::FMul: Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::Mul: - Result = - SimplifyMulInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::SDiv: - Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, - AC, I); + Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::UDiv: - Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, - AC, I); + Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::FDiv: Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::SRem: - Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, - AC, I); + Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::URem: - Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, - AC, I); + Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::FRem: Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + I->getFastMathFlags(), Q); break; case Instruction::Shl: Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, - TLI, DT, AC, I); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); break; case Instruction::LShr: Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->isExact(), DL, TLI, DT, - AC, I); + cast<BinaryOperator>(I)->isExact(), Q); break; case Instruction::AShr: Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->isExact(), DL, TLI, DT, - AC, I); + cast<BinaryOperator>(I)->isExact(), Q); break; case Instruction::And: - Result = - SimplifyAndInst(I->getOperand(0), 
I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::Or: - Result = - SimplifyOrInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::Xor: - Result = - SimplifyXorInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q); break; case Instruction::ICmp: - Result = - SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), I->getOperand(0), - I->getOperand(1), DL, TLI, DT, AC, I); + Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), + I->getOperand(0), I->getOperand(1), Q); break; case Instruction::FCmp: - Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT, AC, I); + Result = + SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0), + I->getOperand(1), I->getFastMathFlags(), Q); break; case Instruction::Select: Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), - I->getOperand(2), DL, TLI, DT, AC, I); + I->getOperand(2), Q); break; case Instruction::GetElementPtr: { - SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); + SmallVector<Value *, 8> Ops(I->op_begin(), I->op_end()); Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(), - Ops, DL, TLI, DT, AC, I); + Ops, Q); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast<InsertValueInst>(I); Result = SimplifyInsertValueInst(IV->getAggregateOperand(), IV->getInsertedValueOperand(), - IV->getIndices(), DL, TLI, DT, AC, I); + IV->getIndices(), Q); break; } case Instruction::ExtractValue: { auto *EVI = cast<ExtractValueInst>(I); Result = SimplifyExtractValueInst(EVI->getAggregateOperand(), - EVI->getIndices(), DL, TLI, DT, AC, I); + EVI->getIndices(), Q); break; } case Instruction::ExtractElement: { auto *EEI = cast<ExtractElementInst>(I); - Result = SimplifyExtractElementInst( - EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I); + Result = SimplifyExtractElementInst(EEI->getVectorOperand(), + EEI->getIndexOperand(), Q); + break; + } + case Instruction::ShuffleVector: { + auto *SVI = cast<ShuffleVectorInst>(I); + Result = SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), + SVI->getMask(), SVI->getType(), Q); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast<PHINode>(I), Query(DL, TLI, DT, AC, I)); + Result = SimplifyPHINode(cast<PHINode>(I), Q); break; case Instruction::Call: { CallSite CS(cast<CallInst>(I)); - Result = SimplifyCall(CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), DL, - TLI, DT, AC, I); + Result = SimplifyCall(CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), Q); break; } #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - Result = SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), - DL, TLI, DT, AC, I); + Result = + SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); + break; + case Instruction::Alloca: + // No simplifications for Alloca and it can't be constant folded. + Result = nullptr; break; } // In general, it is possible for computeKnownBits to determine all bits in a // value even when the operands are not all constants. 
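A minimal sketch, not part of the upstream change: the comment above and the hunk just below rely on computeKnownBits pinning down every bit of a value, in which case the new Known.isConstant()/getConstant() pair recovers the constant directly. This assumes LLVM of roughly this vintage; the bit pattern is arbitrary.

    #include "llvm/Support/KnownBits.h"
    #include "llvm/Support/raw_ostream.h"
    using namespace llvm;

    int main() {
      // Pretend the analysis proved: high nibble is zero, low nibble is one.
      KnownBits Known(8);
      Known.Zero.setBits(4, 8);
      Known.One.setBits(0, 4);
      if (Known.isConstant())                       // every bit is known
        outs() << "value must be " << Known.getConstant() << "\n"; // prints 15
      return 0;
    }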
if (!Result && I->getType()->isIntOrIntVectorTy()) { - unsigned BitWidth = I->getType()->getScalarSizeInBits(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(I, KnownZero, KnownOne, DL, /*Depth*/0, AC, I, DT); - if ((KnownZero | KnownOne).isAllOnesValue()) - Result = ConstantInt::get(I->getType(), KnownOne); + KnownBits Known = computeKnownBits(I, Q.DL, /*Depth*/ 0, Q.AC, I, Q.DT, ORE); + if (Known.isConstant()) + Result = ConstantInt::get(I->getType(), Known.getConstant()); } /// If called on unreachable code, the above logic may report that the @@ -4599,7 +4731,7 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, I = Worklist[Idx]; // See if this instruction simplifies. - SimpleV = SimplifyInstruction(I, DL, TLI, DT, AC); + SimpleV = SimplifyInstruction(I, {DL, TLI, DT, AC}); if (!SimpleV) continue; @@ -4638,3 +4770,31 @@ bool llvm::replaceAndRecursivelySimplify(Instruction *I, Value *SimpleV, assert(SimpleV && "Must provide a simplified value."); return replaceAndRecursivelySimplifyImpl(I, SimpleV, TLI, DT, AC); } + +namespace llvm { +const SimplifyQuery getBestSimplifyQuery(Pass &P, Function &F) { + auto *DTWP = P.getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; + auto *TLIWP = P.getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + auto *TLI = TLIWP ? &TLIWP->getTLI() : nullptr; + auto *ACWP = P.getAnalysisIfAvailable<AssumptionCacheTracker>(); + auto *AC = ACWP ? &ACWP->getAssumptionCache(F) : nullptr; + return {F.getParent()->getDataLayout(), TLI, DT, AC}; +} + +const SimplifyQuery getBestSimplifyQuery(LoopStandardAnalysisResults &AR, + const DataLayout &DL) { + return {DL, &AR.TLI, &AR.DT, &AR.AC}; +} + +template <class T, class... TArgs> +const SimplifyQuery getBestSimplifyQuery(AnalysisManager<T, TArgs...> &AM, + Function &F) { + auto *DT = AM.template getCachedResult<DominatorTreeAnalysis>(F); + auto *TLI = AM.template getCachedResult<TargetLibraryAnalysis>(F); + auto *AC = AM.template getCachedResult<AssumptionAnalysis>(F); + return {F.getParent()->getDataLayout(), TLI, DT, AC}; +} +template const SimplifyQuery getBestSimplifyQuery(AnalysisManager<Function> &, + Function &); +} diff --git a/contrib/llvm/lib/Analysis/IteratedDominanceFrontier.cpp b/contrib/llvm/lib/Analysis/IteratedDominanceFrontier.cpp index d1374acd963e..2a736ec0379c 100644 --- a/contrib/llvm/lib/Analysis/IteratedDominanceFrontier.cpp +++ b/contrib/llvm/lib/Analysis/IteratedDominanceFrontier.cpp @@ -64,10 +64,7 @@ void IDFCalculator<NodeTy>::calculate( BasicBlock *BB = Node->getBlock(); // Succ is the successor in the direction we are calculating IDF, so it is // successor for IDF, and predecessor for Reverse IDF. - for (auto SuccIter = GraphTraits<NodeTy>::child_begin(BB), - End = GraphTraits<NodeTy>::child_end(BB); - SuccIter != End; ++SuccIter) { - BasicBlock *Succ = *SuccIter; + for (auto *Succ : children<NodeTy>(BB)) { DomTreeNode *SuccNode = DT.getNode(Succ); // Quickly skip all CFG edges that are also dominator tree edges instead diff --git a/contrib/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp b/contrib/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp index 596b6fc1afb5..a8178ecc0a24 100644 --- a/contrib/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp +++ b/contrib/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp @@ -9,7 +9,7 @@ // // This is an alternative analysis pass to BlockFrequencyInfoWrapperPass. 
The // difference is that with this pass the block frequencies are not computed when -// the analysis pass is executed but rather when the BFI results is explicitly +// the analysis pass is executed but rather when the BFI result is explicitly // requested by the analysis client. // //===----------------------------------------------------------------------===// diff --git a/contrib/llvm/lib/Analysis/LazyCallGraph.cpp b/contrib/llvm/lib/Analysis/LazyCallGraph.cpp index f7cf8c6729f2..eef56815f2e0 100644 --- a/contrib/llvm/lib/Analysis/LazyCallGraph.cpp +++ b/contrib/llvm/lib/Analysis/LazyCallGraph.cpp @@ -18,26 +18,50 @@ #include "llvm/IR/PassManager.h" #include "llvm/Support/Debug.h" #include "llvm/Support/GraphWriter.h" +#include <utility> using namespace llvm; #define DEBUG_TYPE "lcg" +void LazyCallGraph::EdgeSequence::insertEdgeInternal(Node &TargetN, + Edge::Kind EK) { + EdgeIndexMap.insert({&TargetN, Edges.size()}); + Edges.emplace_back(TargetN, EK); +} + +void LazyCallGraph::EdgeSequence::setEdgeKind(Node &TargetN, Edge::Kind EK) { + Edges[EdgeIndexMap.find(&TargetN)->second].setKind(EK); +} + +bool LazyCallGraph::EdgeSequence::removeEdgeInternal(Node &TargetN) { + auto IndexMapI = EdgeIndexMap.find(&TargetN); + if (IndexMapI == EdgeIndexMap.end()) + return false; + + Edges[IndexMapI->second] = Edge(); + EdgeIndexMap.erase(IndexMapI); + return true; +} + static void addEdge(SmallVectorImpl<LazyCallGraph::Edge> &Edges, - DenseMap<Function *, int> &EdgeIndexMap, Function &F, - LazyCallGraph::Edge::Kind EK) { - if (!EdgeIndexMap.insert({&F, Edges.size()}).second) + DenseMap<LazyCallGraph::Node *, int> &EdgeIndexMap, + LazyCallGraph::Node &N, LazyCallGraph::Edge::Kind EK) { + if (!EdgeIndexMap.insert({&N, Edges.size()}).second) return; - DEBUG(dbgs() << " Added callable function: " << F.getName() << "\n"); - Edges.emplace_back(LazyCallGraph::Edge(F, EK)); + DEBUG(dbgs() << " Added callable function: " << N.getName() << "\n"); + Edges.emplace_back(LazyCallGraph::Edge(N, EK)); } -LazyCallGraph::Node::Node(LazyCallGraph &G, Function &F) - : G(&G), F(F), DFSNumber(0), LowLink(0) { - DEBUG(dbgs() << " Adding functions called by '" << F.getName() +LazyCallGraph::EdgeSequence &LazyCallGraph::Node::populateSlow() { + assert(!Edges && "Must not have already populated the edges for this node!"); + + DEBUG(dbgs() << " Adding functions called by '" << getName() << "' to the graph.\n"); + Edges = EdgeSequence(); + SmallVector<Constant *, 16> Worklist; SmallPtrSet<Function *, 4> Callees; SmallPtrSet<Constant *, 16> Visited; @@ -58,14 +82,15 @@ LazyCallGraph::Node::Node(LazyCallGraph &G, Function &F) // alias. Then a test of the address of the weak function against the new // strong definition's address would be an effective way to determine the // safety of optimizing a direct call edge. - for (BasicBlock &BB : F) + for (BasicBlock &BB : *F) for (Instruction &I : BB) { if (auto CS = CallSite(&I)) if (Function *Callee = CS.getCalledFunction()) if (!Callee->isDeclaration()) if (Callees.insert(Callee).second) { Visited.insert(Callee); - addEdge(Edges, EdgeIndexMap, *Callee, LazyCallGraph::Edge::Call); + addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(*Callee), + LazyCallGraph::Edge::Call); } for (Value *Op : I.operand_values()) @@ -78,50 +103,33 @@ LazyCallGraph::Node::Node(LazyCallGraph &G, Function &F) // function containing) operands to all of the instructions in the function. // Process them (recursively) collecting every function found. 
visitReferences(Worklist, Visited, [&](Function &F) { - addEdge(Edges, EdgeIndexMap, F, LazyCallGraph::Edge::Ref); + addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(F), + LazyCallGraph::Edge::Ref); }); -} - -void LazyCallGraph::Node::insertEdgeInternal(Function &Target, Edge::Kind EK) { - if (Node *N = G->lookup(Target)) - return insertEdgeInternal(*N, EK); - - EdgeIndexMap.insert({&Target, Edges.size()}); - Edges.emplace_back(Target, EK); -} -void LazyCallGraph::Node::insertEdgeInternal(Node &TargetN, Edge::Kind EK) { - EdgeIndexMap.insert({&TargetN.getFunction(), Edges.size()}); - Edges.emplace_back(TargetN, EK); + return *Edges; } -void LazyCallGraph::Node::setEdgeKind(Function &TargetF, Edge::Kind EK) { - Edges[EdgeIndexMap.find(&TargetF)->second].setKind(EK); +void LazyCallGraph::Node::replaceFunction(Function &NewF) { + assert(F != &NewF && "Must not replace a function with itself!"); + F = &NewF; } -void LazyCallGraph::Node::removeEdgeInternal(Function &Target) { - auto IndexMapI = EdgeIndexMap.find(&Target); - assert(IndexMapI != EdgeIndexMap.end() && - "Target not in the edge set for this caller?"); - - Edges[IndexMapI->second] = Edge(); - EdgeIndexMap.erase(IndexMapI); -} - -void LazyCallGraph::Node::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::Node::dump() const { dbgs() << *this << '\n'; } +#endif -LazyCallGraph::LazyCallGraph(Module &M) : NextDFSNumber(0) { +LazyCallGraph::LazyCallGraph(Module &M) { DEBUG(dbgs() << "Building CG for module: " << M.getModuleIdentifier() << "\n"); for (Function &F : M) - if (!F.isDeclaration() && !F.hasLocalLinkage()) - if (EntryIndexMap.insert({&F, EntryEdges.size()}).second) { - DEBUG(dbgs() << " Adding '" << F.getName() - << "' to entry set of the graph.\n"); - EntryEdges.emplace_back(F, Edge::Ref); - } + if (!F.isDeclaration() && !F.hasLocalLinkage()) { + DEBUG(dbgs() << " Adding '" << F.getName() + << "' to entry set of the graph.\n"); + addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(F), Edge::Ref); + } // Now add entry nodes for functions reachable via initializers to globals. 
SmallVector<Constant *, 16> Worklist; @@ -134,21 +142,15 @@ LazyCallGraph::LazyCallGraph(Module &M) : NextDFSNumber(0) { DEBUG(dbgs() << " Adding functions referenced by global initializers to the " "entry set.\n"); visitReferences(Worklist, Visited, [&](Function &F) { - addEdge(EntryEdges, EntryIndexMap, F, LazyCallGraph::Edge::Ref); + addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(F), + LazyCallGraph::Edge::Ref); }); - - for (const Edge &E : EntryEdges) - RefSCCEntryNodes.push_back(&E.getFunction()); } LazyCallGraph::LazyCallGraph(LazyCallGraph &&G) : BPA(std::move(G.BPA)), NodeMap(std::move(G.NodeMap)), - EntryEdges(std::move(G.EntryEdges)), - EntryIndexMap(std::move(G.EntryIndexMap)), SCCBPA(std::move(G.SCCBPA)), - SCCMap(std::move(G.SCCMap)), LeafRefSCCs(std::move(G.LeafRefSCCs)), - DFSStack(std::move(G.DFSStack)), - RefSCCEntryNodes(std::move(G.RefSCCEntryNodes)), - NextDFSNumber(G.NextDFSNumber) { + EntryEdges(std::move(G.EntryEdges)), SCCBPA(std::move(G.SCCBPA)), + SCCMap(std::move(G.SCCMap)), LeafRefSCCs(std::move(G.LeafRefSCCs)) { updateGraphPtrs(); } @@ -156,20 +158,18 @@ LazyCallGraph &LazyCallGraph::operator=(LazyCallGraph &&G) { BPA = std::move(G.BPA); NodeMap = std::move(G.NodeMap); EntryEdges = std::move(G.EntryEdges); - EntryIndexMap = std::move(G.EntryIndexMap); SCCBPA = std::move(G.SCCBPA); SCCMap = std::move(G.SCCMap); LeafRefSCCs = std::move(G.LeafRefSCCs); - DFSStack = std::move(G.DFSStack); - RefSCCEntryNodes = std::move(G.RefSCCEntryNodes); - NextDFSNumber = G.NextDFSNumber; updateGraphPtrs(); return *this; } -void LazyCallGraph::SCC::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::SCC::dump() const { dbgs() << *this << '\n'; } +#endif #ifndef NDEBUG void LazyCallGraph::SCC::verify() { @@ -184,8 +184,8 @@ void LazyCallGraph::SCC::verify() { "Must set DFS numbers to -1 when adding a node to an SCC!"); assert(N->LowLink == -1 && "Must set low link to -1 when adding a node to an SCC!"); - for (Edge &E : *N) - assert(E.getNode() && "Can't have an edge to a raw function!"); + for (Edge &E : **N) + assert(E.getNode() && "Can't have an unpopulated node!"); } } #endif @@ -195,10 +195,9 @@ bool LazyCallGraph::SCC::isParentOf(const SCC &C) const { return false; for (Node &N : *this) - for (Edge &E : N.calls()) - if (Node *CalleeN = E.getNode()) - if (OuterRefSCC->G->lookupSCC(*CalleeN) == &C) - return true; + for (Edge &E : N->calls()) + if (OuterRefSCC->G->lookupSCC(E.getNode()) == &C) + return true; // No edges found. 
return false; @@ -218,11 +217,8 @@ bool LazyCallGraph::SCC::isAncestorOf(const SCC &TargetC) const { do { const SCC &C = *Worklist.pop_back_val(); for (Node &N : C) - for (Edge &E : N.calls()) { - Node *CalleeN = E.getNode(); - if (!CalleeN) - continue; - SCC *CalleeC = G.lookupSCC(*CalleeN); + for (Edge &E : N->calls()) { + SCC *CalleeC = G.lookupSCC(E.getNode()); if (!CalleeC) continue; @@ -243,9 +239,11 @@ bool LazyCallGraph::SCC::isAncestorOf(const SCC &TargetC) const { LazyCallGraph::RefSCC::RefSCC(LazyCallGraph &G) : G(&G) {} -void LazyCallGraph::RefSCC::dump() const { +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::RefSCC::dump() const { dbgs() << *this << '\n'; } +#endif #ifndef NDEBUG void LazyCallGraph::RefSCC::verify() { @@ -279,10 +277,10 @@ void LazyCallGraph::RefSCC::verify() { for (int i = 0, Size = SCCs.size(); i < Size; ++i) { SCC &SourceSCC = *SCCs[i]; for (Node &N : SourceSCC) - for (Edge &E : N) { + for (Edge &E : *N) { if (!E.isCall()) continue; - SCC &TargetSCC = *G->lookupSCC(*E.getNode()); + SCC &TargetSCC = *G->lookupSCC(E.getNode()); if (&TargetSCC.getOuterRefSCC() == this) { assert(SCCIndices.find(&TargetSCC)->second <= i && "Edge between SCCs violates post-order relationship."); @@ -299,8 +297,8 @@ void LazyCallGraph::RefSCC::verify() { auto HasConnectingEdge = [&] { for (SCC &C : *ParentRC) for (Node &N : C) - for (Edge &E : N) - if (G->lookupRefSCC(*E.getNode()) == this) + for (Edge &E : *N) + if (G->lookupRefSCC(E.getNode()) == this) return true; return false; }; @@ -461,7 +459,7 @@ updatePostorderSequenceForEdgeInsertion( SmallVector<LazyCallGraph::SCC *, 1> LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { - assert(!SourceN[TargetN].isCall() && "Must start with a ref edge!"); + assert(!(*SourceN)[TargetN].isCall() && "Must start with a ref edge!"); SmallVector<SCC *, 1> DeletedSCCs; #ifndef NDEBUG @@ -477,7 +475,7 @@ LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { // If the two nodes are already part of the same SCC, we're also done as // we've just added more connectivity. if (&SourceSCC == &TargetSCC) { - SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); + SourceN->setEdgeKind(TargetN, Edge::Call); return DeletedSCCs; } @@ -490,7 +488,7 @@ LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { int SourceIdx = SCCIndices[&SourceSCC]; int TargetIdx = SCCIndices[&TargetSCC]; if (TargetIdx < SourceIdx) { - SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); + SourceN->setEdgeKind(TargetN, Edge::Call); return DeletedSCCs; } @@ -504,11 +502,9 @@ LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { ConnectedSet.insert(&SourceSCC); auto IsConnected = [&](SCC &C) { for (Node &N : C) - for (Edge &E : N.calls()) { - assert(E.getNode() && "Must have formed a node within an SCC!"); - if (ConnectedSet.count(G->lookupSCC(*E.getNode()))) + for (Edge &E : N->calls()) + if (ConnectedSet.count(G->lookupSCC(E.getNode()))) return true; - } return false; }; @@ -535,11 +531,10 @@ LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { do { SCC &C = *Worklist.pop_back_val(); for (Node &N : C) - for (Edge &E : N) { - assert(E.getNode() && "Must have formed a node within an SCC!"); + for (Edge &E : *N) { if (!E.isCall()) continue; - SCC &EdgeC = *G->lookupSCC(*E.getNode()); + SCC &EdgeC = *G->lookupSCC(E.getNode()); if (&EdgeC.getOuterRefSCC() != this) // Not in this RefSCC... 
continue; @@ -565,7 +560,7 @@ LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { // new cycles. We're done. if (MergeRange.begin() == MergeRange.end()) { // Now that the SCC structure is finalized, flip the kind to call. - SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); + SourceN->setEdgeKind(TargetN, Edge::Call); return DeletedSCCs; } @@ -600,7 +595,7 @@ LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { SCCIndices[C] -= IndexOffset; // Now that the SCC structure is finalized, flip the kind to call. - SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); + SourceN->setEdgeKind(TargetN, Edge::Call); // And we're done! return DeletedSCCs; @@ -608,7 +603,7 @@ LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { void LazyCallGraph::RefSCC::switchTrivialInternalEdgeToRef(Node &SourceN, Node &TargetN) { - assert(SourceN[TargetN].isCall() && "Must start with a call edge!"); + assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); #ifndef NDEBUG // In a debug build, verify the RefSCC is valid to start with and when this @@ -625,12 +620,12 @@ void LazyCallGraph::RefSCC::switchTrivialInternalEdgeToRef(Node &SourceN, "Source and Target must be in separate SCCs for this to be trivial!"); // Set the edge kind. - SourceN.setEdgeKind(TargetN.getFunction(), Edge::Ref); + SourceN->setEdgeKind(TargetN, Edge::Ref); } iterator_range<LazyCallGraph::RefSCC::iterator> LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) { - assert(SourceN[TargetN].isCall() && "Must start with a call edge!"); + assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); #ifndef NDEBUG // In a debug build, verify the RefSCC is valid to start with and when this @@ -650,7 +645,7 @@ LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) { "full CG update."); // Set the edge kind. - SourceN.setEdgeKind(TargetN.getFunction(), Edge::Ref); + SourceN->setEdgeKind(TargetN, Edge::Ref); // Otherwise we are removing a call edge from a single SCC. This may break // the cycle. In order to compute the new set of SCCs, we need to do a small @@ -665,7 +660,7 @@ LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) { // etc. SCC &OldSCC = TargetSCC; - SmallVector<std::pair<Node *, call_edge_iterator>, 16> DFSStack; + SmallVector<std::pair<Node *, EdgeSequence::call_iterator>, 16> DFSStack; SmallVector<Node *, 16> PendingSCCStack; SmallVector<SCC *, 4> NewSCCs; @@ -706,14 +701,14 @@ LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) { RootN->DFSNumber = RootN->LowLink = 1; int NextDFSNumber = 2; - DFSStack.push_back({RootN, RootN->call_begin()}); + DFSStack.push_back({RootN, (*RootN)->call_begin()}); do { Node *N; - call_edge_iterator I; + EdgeSequence::call_iterator I; std::tie(N, I) = DFSStack.pop_back_val(); - auto E = N->call_end(); + auto E = (*N)->call_end(); while (I != E) { - Node &ChildN = *I->getNode(); + Node &ChildN = I->getNode(); if (ChildN.DFSNumber == 0) { // We haven't yet visited this child, so descend, pushing the current // node onto the stack. 
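For reference, the DFS machinery in the switchInternalEdgeToRef hunk above is the same pattern the rest of this file uses: Tarjan's SCC algorithm driven by an explicit stack of {node, next-edge-iterator} pairs instead of recursion. The following is only a minimal, self-contained sketch of that pattern over a plain adjacency list, not LLVM code; the names Graph and findSCCs are invented for illustration.

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

using Graph = std::vector<std::vector<int>>;

// Computes strongly connected components with Tarjan's algorithm, replacing
// recursion by an explicit stack of {node, next-edge-index} pairs.
std::vector<std::vector<int>> findSCCs(const Graph &G) {
  const int N = static_cast<int>(G.size());
  int NextDFS = 1;
  std::vector<int> DFSNumber(N, 0), LowLink(N, 0);
  std::vector<std::vector<int>> SCCs;
  std::vector<std::pair<int, std::size_t>> DFSStack; // {node, next edge index}
  std::vector<int> Pending; // finished nodes not yet assigned to an SCC

  for (int Root = 0; Root < N; ++Root) {
    if (DFSNumber[Root] != 0)
      continue; // already reached from an earlier root
    DFSNumber[Root] = LowLink[Root] = NextDFS++;
    DFSStack.push_back({Root, 0});
    while (!DFSStack.empty()) {
      auto [U, I] = DFSStack.back();
      DFSStack.pop_back();
      while (I < G[U].size()) {
        int V = G[U][I];
        if (DFSNumber[V] == 0) {
          // Unvisited child: push U with I *not* advanced, so this edge is
          // re-examined (to merge the child's low link) once V finishes,
          // then continue the walk from V.
          DFSStack.push_back({U, I});
          DFSNumber[V] = LowLink[V] = NextDFS++;
          U = V;
          I = 0;
          continue;
        }
        if (DFSNumber[V] != -1) // -1 marks nodes already placed in an SCC.
          LowLink[U] = std::min(LowLink[U], LowLink[V]);
        ++I; // move to the next edge
      }
      // U's edges are exhausted; it is finished but not yet in an SCC.
      Pending.push_back(U);
      if (LowLink[U] != DFSNumber[U])
        continue; // not an SCC root; its parent merges LowLink[U] later
      // U is an SCC root: everything pending with a DFS number >= U's
      // belongs to U's component.
      std::vector<int> SCC;
      while (!Pending.empty() && DFSNumber[Pending.back()] >= DFSNumber[U]) {
        DFSNumber[Pending.back()] = -1;
        SCC.push_back(Pending.back());
        Pending.pop_back();
      }
      SCCs.push_back(std::move(SCC));
    }
  }
  return SCCs;
}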
@@ -723,8 +718,8 @@ LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) { "Found a node with 0 DFS number but already in an SCC!"); ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++; N = &ChildN; - I = N->call_begin(); - E = N->call_end(); + I = (*N)->call_begin(); + E = (*N)->call_end(); continue; } @@ -817,17 +812,19 @@ LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) { void LazyCallGraph::RefSCC::switchOutgoingEdgeToCall(Node &SourceN, Node &TargetN) { - assert(!SourceN[TargetN].isCall() && "Must start with a ref edge!"); + assert(!(*SourceN)[TargetN].isCall() && "Must start with a ref edge!"); assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); assert(G->lookupRefSCC(TargetN) != this && "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && "Target must be a descendant of the Source."); +#endif // Edges between RefSCCs are the same regardless of call or ref, so we can // just flip the edge here. - SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); + SourceN->setEdgeKind(TargetN, Edge::Call); #ifndef NDEBUG // Check that the RefSCC is still valid. @@ -837,17 +834,19 @@ void LazyCallGraph::RefSCC::switchOutgoingEdgeToCall(Node &SourceN, void LazyCallGraph::RefSCC::switchOutgoingEdgeToRef(Node &SourceN, Node &TargetN) { - assert(SourceN[TargetN].isCall() && "Must start with a call edge!"); + assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); assert(G->lookupRefSCC(TargetN) != this && "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && "Target must be a descendant of the Source."); +#endif // Edges between RefSCCs are the same regardless of call or ref, so we can // just flip the edge here. - SourceN.setEdgeKind(TargetN.getFunction(), Edge::Ref); + SourceN->setEdgeKind(TargetN, Edge::Ref); #ifndef NDEBUG // Check that the RefSCC is still valid. @@ -860,7 +859,7 @@ void LazyCallGraph::RefSCC::insertInternalRefEdge(Node &SourceN, assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC."); - SourceN.insertEdgeInternal(TargetN, Edge::Ref); + SourceN->insertEdgeInternal(TargetN, Edge::Ref); #ifndef NDEBUG // Check that the RefSCC is still valid. @@ -871,14 +870,16 @@ void LazyCallGraph::RefSCC::insertInternalRefEdge(Node &SourceN, void LazyCallGraph::RefSCC::insertOutgoingEdge(Node &SourceN, Node &TargetN, Edge::Kind EK) { // First insert it into the caller. - SourceN.insertEdgeInternal(TargetN, EK); + SourceN->insertEdgeInternal(TargetN, EK); assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); RefSCC &TargetC = *G->lookupRefSCC(TargetN); assert(&TargetC != this && "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS assert(TargetC.isDescendantOf(*this) && "Target must be a descendant of the Source."); +#endif // The only change required is to add this SCC to the parent set of the // callee. 
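The edge-switching hunks above all route edge queries and mutations through the node's lazily built EdgeSequence ((*SourceN)[TargetN], SourceN->setEdgeKind(...)) rather than through the Node itself. A rough client-side sketch of that shape, using only accessors that appear elsewhere in this diff (G.get, Node::populate, Edge::isCall, Edge::getFunction); printCallEdges itself is illustrative and not an LLVM API.

#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/IR/Function.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

// Walks one node's outgoing edges through its EdgeSequence.
static void printCallEdges(LazyCallGraph &G, Function &F) {
  LazyCallGraph::Node &N = G.get(F);
  // populate() scans the function body on first use and returns the
  // EdgeSequence; afterwards *N and N-> reach the same sequence.
  for (LazyCallGraph::Edge &E : N.populate())
    outs() << F.getName() << (E.isCall() ? " calls " : " refs  ")
           << E.getFunction().getName() << "\n";
}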
@@ -895,8 +896,10 @@ LazyCallGraph::RefSCC::insertIncomingRefEdge(Node &SourceN, Node &TargetN) { assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC."); RefSCC &SourceC = *G->lookupRefSCC(SourceN); assert(&SourceC != this && "Source must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS assert(SourceC.isDescendantOf(*this) && "Source must be a descendant of the Target."); +#endif SmallVector<RefSCC *, 1> DeletedRefSCCs; @@ -951,9 +954,8 @@ LazyCallGraph::RefSCC::insertIncomingRefEdge(Node &SourceN, Node &TargetN) { RefSCC &RC = *Worklist.pop_back_val(); for (SCC &C : RC) for (Node &N : C) - for (Edge &E : N) { - assert(E.getNode() && "Must have formed a node!"); - RefSCC &EdgeRC = *G->lookupRefSCC(*E.getNode()); + for (Edge &E : *N) { + RefSCC &EdgeRC = *G->lookupRefSCC(E.getNode()); if (G->getRefSCCIndex(EdgeRC) <= SourceIdx) // Not in the postorder sequence between source and target. continue; @@ -1003,10 +1005,8 @@ LazyCallGraph::RefSCC::insertIncomingRefEdge(Node &SourceN, Node &TargetN) { SCCIndices[&InnerC] = SCCIndex++; for (Node &N : InnerC) { G->SCCMap[&N] = &InnerC; - for (Edge &E : N) { - assert(E.getNode() && - "Cannot have a null node within a visited SCC!"); - RefSCC &ChildRC = *G->lookupRefSCC(*E.getNode()); + for (Edge &E : *N) { + RefSCC &ChildRC = *G->lookupRefSCC(E.getNode()); if (MergeSet.count(&ChildRC)) continue; ChildRC.Parents.erase(RC); @@ -1042,7 +1042,7 @@ LazyCallGraph::RefSCC::insertIncomingRefEdge(Node &SourceN, Node &TargetN) { // At this point we have a merged RefSCC with a post-order SCCs list, just // connect the nodes to form the new edge. - SourceN.insertEdgeInternal(TargetN, Edge::Ref); + SourceN->insertEdgeInternal(TargetN, Edge::Ref); // We return the list of SCCs which were merged so that callers can // invalidate any data they have associated with those SCCs. Note that these @@ -1069,15 +1069,16 @@ void LazyCallGraph::RefSCC::removeOutgoingEdge(Node &SourceN, Node &TargetN) { #endif // First remove it from the node. - SourceN.removeEdgeInternal(TargetN.getFunction()); + bool Removed = SourceN->removeEdgeInternal(TargetN); + (void)Removed; + assert(Removed && "Target not in the edge set for this caller?"); bool HasOtherEdgeToChildRC = false; bool HasOtherChildRC = false; for (SCC *InnerC : SCCs) { for (Node &N : *InnerC) { - for (Edge &E : N) { - assert(E.getNode() && "Cannot have a missing node in a visited SCC!"); - RefSCC &OtherChildRC = *G->lookupRefSCC(*E.getNode()); + for (Edge &E : *N) { + RefSCC &OtherChildRC = *G->lookupRefSCC(E.getNode()); if (&OtherChildRC == &TargetRC) { HasOtherEdgeToChildRC = true; break; @@ -1116,7 +1117,7 @@ void LazyCallGraph::RefSCC::removeOutgoingEdge(Node &SourceN, Node &TargetN) { SmallVector<LazyCallGraph::RefSCC *, 1> LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, Node &TargetN) { - assert(!SourceN[TargetN].isCall() && + assert(!(*SourceN)[TargetN].isCall() && "Cannot remove a call edge, it must first be made a ref edge"); #ifndef NDEBUG @@ -1127,7 +1128,9 @@ LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, Node &TargetN) { #endif // First remove the actual edge. - SourceN.removeEdgeInternal(TargetN.getFunction()); + bool Removed = SourceN->removeEdgeInternal(TargetN); + (void)Removed; + assert(Removed && "Target not in the edge set for this caller?"); // We return a list of the resulting *new* RefSCCs in post-order. 
SmallVector<RefSCC *, 1> Result; @@ -1186,7 +1189,7 @@ LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, Node &TargetN) { PostOrderMapping[&N] = Number; }; - SmallVector<std::pair<Node *, edge_iterator>, 4> DFSStack; + SmallVector<std::pair<Node *, EdgeSequence::iterator>, 4> DFSStack; SmallVector<Node *, 4> PendingRefSCCStack; do { assert(DFSStack.empty() && @@ -1205,18 +1208,18 @@ LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, Node &TargetN) { RootN->DFSNumber = RootN->LowLink = 1; int NextDFSNumber = 2; - DFSStack.push_back({RootN, RootN->begin()}); + DFSStack.push_back({RootN, (*RootN)->begin()}); do { Node *N; - edge_iterator I; + EdgeSequence::iterator I; std::tie(N, I) = DFSStack.pop_back_val(); - auto E = N->end(); + auto E = (*N)->end(); assert(N->DFSNumber != 0 && "We should always assign a DFS number " "before processing a node."); while (I != E) { - Node &ChildN = I->getNode(*G); + Node &ChildN = I->getNode(); if (ChildN.DFSNumber == 0) { // Mark that we should start at this child when next this node is the // top of the stack. We don't start at the next child to ensure this @@ -1226,8 +1229,8 @@ LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, Node &TargetN) { // Continue, resetting to the child node. ChildN.LowLink = ChildN.DFSNumber = NextDFSNumber++; N = &ChildN; - I = ChildN.begin(); - E = ChildN.end(); + I = ChildN->begin(); + E = ChildN->end(); continue; } if (ChildN.DFSNumber == -1) { @@ -1382,9 +1385,8 @@ LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, Node &TargetN) { #endif for (SCC *C : SCCs) for (Node &N : *C) { - for (Edge &E : N) { - assert(E.getNode() && "Cannot have a missing node in a visited SCC!"); - RefSCC &ChildRC = *G->lookupRefSCC(*E.getNode()); + for (Edge &E : *N) { + RefSCC &ChildRC = *G->lookupRefSCC(E.getNode()); if (&ChildRC == this) continue; ChildRC.Parents.insert(this); @@ -1408,9 +1410,8 @@ LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, Node &TargetN) { for (RefSCC *ParentRC : OldParents) for (SCC &ParentC : *ParentRC) for (Node &ParentN : ParentC) - for (Edge &E : ParentN) { - assert(E.getNode() && "Cannot have a missing node in a visited SCC!"); - RefSCC &RC = *G->lookupRefSCC(*E.getNode()); + for (Edge &E : *ParentN) { + RefSCC &RC = *G->lookupRefSCC(E.getNode()); if (&RC != ParentRC) RC.Parents.insert(ParentRC); } @@ -1448,8 +1449,10 @@ void LazyCallGraph::RefSCC::handleTrivialEdgeInsertion(Node &SourceN, return; } +#ifdef EXPENSIVE_CHECKS assert(TargetRC.isDescendantOf(*this) && "Target must be a descendant of the Source."); +#endif // The only change required is to add this RefSCC to the parent set of the // target. This is a set and so idempotent if the edge already existed. TargetRC.Parents.insert(this); @@ -1461,25 +1464,29 @@ void LazyCallGraph::RefSCC::insertTrivialCallEdge(Node &SourceN, // Check that the RefSCC is still valid when we finish. auto ExitVerifier = make_scope_exit([this] { verify(); }); - // Check that we aren't breaking some invariants of the SCC graph. +#ifdef EXPENSIVE_CHECKS + // Check that we aren't breaking some invariants of the SCC graph. Note that + // this is quadratic in the number of edges in the call graph! SCC &SourceC = *G->lookupSCC(SourceN); SCC &TargetC = *G->lookupSCC(TargetN); if (&SourceC != &TargetC) assert(SourceC.isAncestorOf(TargetC) && "Call edge is not trivial in the SCC graph!"); -#endif +#endif // EXPENSIVE_CHECKS +#endif // NDEBUG + // First insert it into the source or find the existing edge. 
- auto InsertResult = SourceN.EdgeIndexMap.insert( - {&TargetN.getFunction(), SourceN.Edges.size()}); + auto InsertResult = + SourceN->EdgeIndexMap.insert({&TargetN, SourceN->Edges.size()}); if (!InsertResult.second) { // Already an edge, just update it. - Edge &E = SourceN.Edges[InsertResult.first->second]; + Edge &E = SourceN->Edges[InsertResult.first->second]; if (E.isCall()) return; // Nothing to do! E.setKind(Edge::Call); } else { // Create the new edge. - SourceN.Edges.emplace_back(TargetN, Edge::Call); + SourceN->Edges.emplace_back(TargetN, Edge::Call); } // Now that we have the edge, handle the graph fallout. @@ -1491,39 +1498,75 @@ void LazyCallGraph::RefSCC::insertTrivialRefEdge(Node &SourceN, Node &TargetN) { // Check that the RefSCC is still valid when we finish. auto ExitVerifier = make_scope_exit([this] { verify(); }); +#ifdef EXPENSIVE_CHECKS // Check that we aren't breaking some invariants of the RefSCC graph. RefSCC &SourceRC = *G->lookupRefSCC(SourceN); RefSCC &TargetRC = *G->lookupRefSCC(TargetN); if (&SourceRC != &TargetRC) assert(SourceRC.isAncestorOf(TargetRC) && "Ref edge is not trivial in the RefSCC graph!"); -#endif +#endif // EXPENSIVE_CHECKS +#endif // NDEBUG + // First insert it into the source or find the existing edge. - auto InsertResult = SourceN.EdgeIndexMap.insert( - {&TargetN.getFunction(), SourceN.Edges.size()}); + auto InsertResult = + SourceN->EdgeIndexMap.insert({&TargetN, SourceN->Edges.size()}); if (!InsertResult.second) // Already an edge, we're done. return; // Create the new edge. - SourceN.Edges.emplace_back(TargetN, Edge::Ref); + SourceN->Edges.emplace_back(TargetN, Edge::Ref); // Now that we have the edge, handle the graph fallout. handleTrivialEdgeInsertion(SourceN, TargetN); } -void LazyCallGraph::insertEdge(Node &SourceN, Function &Target, Edge::Kind EK) { - assert(SCCMap.empty() && DFSStack.empty() && +void LazyCallGraph::RefSCC::replaceNodeFunction(Node &N, Function &NewF) { + Function &OldF = N.getFunction(); + +#ifndef NDEBUG + // Check that the RefSCC is still valid when we finish. + auto ExitVerifier = make_scope_exit([this] { verify(); }); + + assert(G->lookupRefSCC(N) == this && + "Cannot replace the function of a node outside this RefSCC."); + + assert(G->NodeMap.find(&NewF) == G->NodeMap.end() && + "Must not have already walked the new function!'"); + + // It is important that this replacement not introduce graph changes so we + // insist that the caller has already removed every use of the original + // function and that all uses of the new function correspond to existing + // edges in the graph. The common and expected way to use this is when + // replacing the function itself in the IR without changing the call graph + // shape and just updating the analysis based on that. + assert(&OldF != &NewF && "Cannot replace a function with itself!"); + assert(OldF.use_empty() && + "Must have moved all uses from the old function to the new!"); +#endif + + N.replaceFunction(NewF); + + // Update various call graph maps. 
+ G->NodeMap.erase(&OldF); + G->NodeMap[&NewF] = &N; +} + +void LazyCallGraph::insertEdge(Node &SourceN, Node &TargetN, Edge::Kind EK) { + assert(SCCMap.empty() && "This method cannot be called after SCCs have been formed!"); - return SourceN.insertEdgeInternal(Target, EK); + return SourceN->insertEdgeInternal(TargetN, EK); } -void LazyCallGraph::removeEdge(Node &SourceN, Function &Target) { - assert(SCCMap.empty() && DFSStack.empty() && +void LazyCallGraph::removeEdge(Node &SourceN, Node &TargetN) { + assert(SCCMap.empty() && "This method cannot be called after SCCs have been formed!"); - return SourceN.removeEdgeInternal(Target); + bool Removed = SourceN->removeEdgeInternal(TargetN); + (void)Removed; + assert(Removed && "Target not in the edge set for this caller?"); } void LazyCallGraph::removeDeadFunction(Function &F) { @@ -1532,19 +1575,6 @@ void LazyCallGraph::removeDeadFunction(Function &F) { assert(F.use_empty() && "This routine should only be called on trivially dead functions!"); - auto EII = EntryIndexMap.find(&F); - if (EII != EntryIndexMap.end()) { - EntryEdges[EII->second] = Edge(); - EntryIndexMap.erase(EII); - } - - // It's safe to just remove un-visited functions from the RefSCC entry list. - // FIXME: This is a linear operation which could become hot and benefit from - // an index map. - auto RENI = find(RefSCCEntryNodes, &F); - if (RENI != RefSCCEntryNodes.end()) - RefSCCEntryNodes.erase(RENI); - auto NI = NodeMap.find(&F); if (NI == NodeMap.end()) // Not in the graph at all! @@ -1553,22 +1583,16 @@ void LazyCallGraph::removeDeadFunction(Function &F) { Node &N = *NI->second; NodeMap.erase(NI); - if (SCCMap.empty() && DFSStack.empty()) { - // No SCC walk has begun, so removing this is fine and there is nothing + // Remove this from the entry edges if present. + EntryEdges.removeEdgeInternal(N); + + if (SCCMap.empty()) { + // No SCCs have been formed, so removing this is fine and there is nothing // else necessary at this point but clearing out the node. N.clear(); return; } - // Check that we aren't going to break the DFS walk. - assert(all_of(DFSStack, - [&N](const std::pair<Node *, edge_iterator> &Element) { - return Element.first != &N; - }) && - "Tried to remove a function currently in the DFS stack!"); - assert(find(PendingRefSCCStack, &N) == PendingRefSCCStack.end() && - "Tried to remove a function currently pending to add to a RefSCC!"); - // Cannot remove a function which has yet to be visited in the DFS walk, so // if we have a node at all then we must have an SCC and RefSCC. auto CI = SCCMap.find(&N); @@ -1583,13 +1607,19 @@ void LazyCallGraph::removeDeadFunction(Function &F) { // Validate these properties first. assert(C.size() == 1 && "Dead functions must be in a singular SCC"); assert(RC.size() == 1 && "Dead functions must be in a singular RefSCC"); - assert(RC.Parents.empty() && "Cannot have parents of a dead RefSCC!"); + + // Clean up any remaining reference edges. Note that we walk an unordered set + // here but are just removing and so the order doesn't matter. + for (RefSCC &ParentRC : RC.parents()) + for (SCC &ParentC : ParentRC) + for (Node &ParentN : ParentC) + if (ParentN) + ParentN->removeEdgeInternal(N); // Now remove this RefSCC from any parents sets and the leaf list. 
- for (Edge &E : N) - if (Node *TargetN = E.getNode()) - if (RefSCC *TargetRC = lookupRefSCC(*TargetN)) - TargetRC->Parents.erase(&RC); + for (Edge &E : *N) + if (RefSCC *TargetRC = lookupRefSCC(E.getNode())) + TargetRC->Parents.erase(&RC); // FIXME: This is a linear operation which could become hot and benefit from // an index map. auto LRI = find(LeafRefSCCs, &RC); @@ -1622,15 +1652,14 @@ void LazyCallGraph::updateGraphPtrs() { { SmallVector<Node *, 16> Worklist; for (Edge &E : EntryEdges) - if (Node *EntryN = E.getNode()) - Worklist.push_back(EntryN); + Worklist.push_back(&E.getNode()); while (!Worklist.empty()) { - Node *N = Worklist.pop_back_val(); - N->G = this; - for (Edge &E : N->Edges) - if (Node *TargetN = E.getNode()) - Worklist.push_back(TargetN); + Node &N = *Worklist.pop_back_val(); + N.G = this; + if (N) + for (Edge &E : *N) + Worklist.push_back(&E.getNode()); } } @@ -1647,34 +1676,18 @@ void LazyCallGraph::updateGraphPtrs() { } } -/// Build the internal SCCs for a RefSCC from a sequence of nodes. -/// -/// Appends the SCCs to the provided vector and updates the map with their -/// indices. Both the vector and map must be empty when passed into this -/// routine. -void LazyCallGraph::buildSCCs(RefSCC &RC, node_stack_range Nodes) { - assert(RC.SCCs.empty() && "Already built SCCs!"); - assert(RC.SCCIndices.empty() && "Already mapped SCC indices!"); - - for (Node *N : Nodes) { - assert(N->LowLink >= (*Nodes.begin())->LowLink && - "We cannot have a low link in an SCC lower than its root on the " - "stack!"); +template <typename RootsT, typename GetBeginT, typename GetEndT, + typename GetNodeT, typename FormSCCCallbackT> +void LazyCallGraph::buildGenericSCCs(RootsT &&Roots, GetBeginT &&GetBegin, + GetEndT &&GetEnd, GetNodeT &&GetNode, + FormSCCCallbackT &&FormSCC) { + typedef decltype(GetBegin(std::declval<Node &>())) EdgeItT; - // This node will go into the next RefSCC, clear out its DFS and low link - // as we scan. - N->DFSNumber = N->LowLink = 0; - } - - // Each RefSCC contains a DAG of the call SCCs. To build these, we do - // a direct walk of the call edges using Tarjan's algorithm. We reuse the - // internal storage as we won't need it for the outer graph's DFS any longer. - - SmallVector<std::pair<Node *, call_edge_iterator>, 16> DFSStack; + SmallVector<std::pair<Node *, EdgeItT>, 16> DFSStack; SmallVector<Node *, 16> PendingSCCStack; // Scan down the stack and DFS across the call edges. - for (Node *RootN : Nodes) { + for (Node *RootN : Roots) { assert(DFSStack.empty() && "Cannot begin a new root with a non-empty DFS stack!"); assert(PendingSCCStack.empty() && @@ -1690,25 +1703,23 @@ void LazyCallGraph::buildSCCs(RefSCC &RC, node_stack_range Nodes) { RootN->DFSNumber = RootN->LowLink = 1; int NextDFSNumber = 2; - DFSStack.push_back({RootN, RootN->call_begin()}); + DFSStack.push_back({RootN, GetBegin(*RootN)}); do { Node *N; - call_edge_iterator I; + EdgeItT I; std::tie(N, I) = DFSStack.pop_back_val(); - auto E = N->call_end(); + auto E = GetEnd(*N); while (I != E) { - Node &ChildN = *I->getNode(); + Node &ChildN = GetNode(I); if (ChildN.DFSNumber == 0) { // We haven't yet visited this child, so descend, pushing the current // node onto the stack. 
DFSStack.push_back({N, I}); - assert(!lookupSCC(ChildN) && - "Found a node with 0 DFS number but already in an SCC!"); ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++; N = &ChildN; - I = N->call_begin(); - E = N->call_end(); + I = GetBegin(*N); + E = GetEnd(*N); continue; } @@ -1750,20 +1761,93 @@ void LazyCallGraph::buildSCCs(RefSCC &RC, node_stack_range Nodes) { })); // Form a new SCC out of these nodes and then clear them off our pending // stack. - RC.SCCs.push_back(createSCC(RC, SCCNodes)); - for (Node &N : *RC.SCCs.back()) { - N.DFSNumber = N.LowLink = -1; - SCCMap[&N] = RC.SCCs.back(); - } + FormSCC(SCCNodes); PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end()); } while (!DFSStack.empty()); } +} + +/// Build the internal SCCs for a RefSCC from a sequence of nodes. +/// +/// Appends the SCCs to the provided vector and updates the map with their +/// indices. Both the vector and map must be empty when passed into this +/// routine. +void LazyCallGraph::buildSCCs(RefSCC &RC, node_stack_range Nodes) { + assert(RC.SCCs.empty() && "Already built SCCs!"); + assert(RC.SCCIndices.empty() && "Already mapped SCC indices!"); + + for (Node *N : Nodes) { + assert(N->LowLink >= (*Nodes.begin())->LowLink && + "We cannot have a low link in an SCC lower than its root on the " + "stack!"); + + // This node will go into the next RefSCC, clear out its DFS and low link + // as we scan. + N->DFSNumber = N->LowLink = 0; + } + + // Each RefSCC contains a DAG of the call SCCs. To build these, we do + // a direct walk of the call edges using Tarjan's algorithm. We reuse the + // internal storage as we won't need it for the outer graph's DFS any longer. + buildGenericSCCs( + Nodes, [](Node &N) { return N->call_begin(); }, + [](Node &N) { return N->call_end(); }, + [](EdgeSequence::call_iterator I) -> Node & { return I->getNode(); }, + [this, &RC](node_stack_range Nodes) { + RC.SCCs.push_back(createSCC(RC, Nodes)); + for (Node &N : *RC.SCCs.back()) { + N.DFSNumber = N.LowLink = -1; + SCCMap[&N] = RC.SCCs.back(); + } + }); // Wire up the SCC indices. for (int i = 0, Size = RC.SCCs.size(); i < Size; ++i) RC.SCCIndices[RC.SCCs[i]] = i; } +void LazyCallGraph::buildRefSCCs() { + if (EntryEdges.empty() || !PostOrderRefSCCs.empty()) + // RefSCCs are either non-existent or already built! + return; + + assert(RefSCCIndices.empty() && "Already mapped RefSCC indices!"); + + SmallVector<Node *, 16> Roots; + for (Edge &E : *this) + Roots.push_back(&E.getNode()); + + // The roots will be popped of a stack, so use reverse to get a less + // surprising order. This doesn't change any of the semantics anywhere. + std::reverse(Roots.begin(), Roots.end()); + + buildGenericSCCs( + Roots, + [](Node &N) { + // We need to populate each node as we begin to walk its edges. + N.populate(); + return N->begin(); + }, + [](Node &N) { return N->end(); }, + [](EdgeSequence::iterator I) -> Node & { return I->getNode(); }, + [this](node_stack_range Nodes) { + RefSCC *NewRC = createRefSCC(*this); + buildSCCs(*NewRC, Nodes); + connectRefSCC(*NewRC); + + // Push the new node into the postorder list and remember its position + // in the index map. 
+ bool Inserted = + RefSCCIndices.insert({NewRC, PostOrderRefSCCs.size()}).second; + (void)Inserted; + assert(Inserted && "Cannot already have this RefSCC in the index map!"); + PostOrderRefSCCs.push_back(NewRC); +#ifndef NDEBUG + NewRC->verify(); +#endif + }); +} + // FIXME: We should move callers of this to embed the parent linking and leaf // tracking into their DFS in order to remove a full walk of all edges. void LazyCallGraph::connectRefSCC(RefSCC &RC) { @@ -1773,10 +1857,8 @@ void LazyCallGraph::connectRefSCC(RefSCC &RC) { bool IsLeaf = true; for (SCC &C : RC) for (Node &N : C) - for (Edge &E : N) { - assert(E.getNode() && - "Cannot have a missing node in a visited part of the graph!"); - RefSCC &ChildRC = *lookupRefSCC(*E.getNode()); + for (Edge &E : *N) { + RefSCC &ChildRC = *lookupRefSCC(E.getNode()); if (&ChildRC == &RC) continue; ChildRC.Parents.insert(&RC); @@ -1788,113 +1870,13 @@ void LazyCallGraph::connectRefSCC(RefSCC &RC) { LeafRefSCCs.push_back(&RC); } -bool LazyCallGraph::buildNextRefSCCInPostOrder() { - if (DFSStack.empty()) { - Node *N; - do { - // If we've handled all candidate entry nodes to the SCC forest, we're - // done. - if (RefSCCEntryNodes.empty()) - return false; - - N = &get(*RefSCCEntryNodes.pop_back_val()); - } while (N->DFSNumber != 0); - - // Found a new root, begin the DFS here. - N->LowLink = N->DFSNumber = 1; - NextDFSNumber = 2; - DFSStack.push_back({N, N->begin()}); - } - - for (;;) { - Node *N; - edge_iterator I; - std::tie(N, I) = DFSStack.pop_back_val(); - - assert(N->DFSNumber > 0 && "We should always assign a DFS number " - "before placing a node onto the stack."); - - auto E = N->end(); - while (I != E) { - Node &ChildN = I->getNode(*this); - if (ChildN.DFSNumber == 0) { - // We haven't yet visited this child, so descend, pushing the current - // node onto the stack. - DFSStack.push_back({N, N->begin()}); - - assert(!SCCMap.count(&ChildN) && - "Found a node with 0 DFS number but already in an SCC!"); - ChildN.LowLink = ChildN.DFSNumber = NextDFSNumber++; - N = &ChildN; - I = N->begin(); - E = N->end(); - continue; - } - - // If the child has already been added to some child component, it - // couldn't impact the low-link of this parent because it isn't - // connected, and thus its low-link isn't relevant so skip it. - if (ChildN.DFSNumber == -1) { - ++I; - continue; - } - - // Track the lowest linked child as the lowest link for this node. - assert(ChildN.LowLink > 0 && "Must have a positive low-link number!"); - if (ChildN.LowLink < N->LowLink) - N->LowLink = ChildN.LowLink; - - // Move to the next edge. - ++I; - } - - // We've finished processing N and its descendents, put it on our pending - // SCC stack to eventually get merged into an SCC of nodes. - PendingRefSCCStack.push_back(N); - - // If this node is linked to some lower entry, continue walking up the - // stack. - if (N->LowLink != N->DFSNumber) { - assert(!DFSStack.empty() && - "We never found a viable root for an SCC to pop off!"); - continue; - } - - // Otherwise, form a new RefSCC from the top of the pending node stack. - int RootDFSNumber = N->DFSNumber; - // Find the range of the node stack by walking down until we pass the - // root DFS number. - auto RefSCCNodes = node_stack_range( - PendingRefSCCStack.rbegin(), - find_if(reverse(PendingRefSCCStack), [RootDFSNumber](const Node *N) { - return N->DFSNumber < RootDFSNumber; - })); - // Form a new RefSCC out of these nodes and then clear them off our pending - // stack. 
- RefSCC *NewRC = createRefSCC(*this); - buildSCCs(*NewRC, RefSCCNodes); - connectRefSCC(*NewRC); - PendingRefSCCStack.erase(RefSCCNodes.end().base(), - PendingRefSCCStack.end()); - - // Push the new node into the postorder list and return true indicating we - // successfully grew the postorder sequence by one. - bool Inserted = - RefSCCIndices.insert({NewRC, PostOrderRefSCCs.size()}).second; - (void)Inserted; - assert(Inserted && "Cannot already have this RefSCC in the index map!"); - PostOrderRefSCCs.push_back(NewRC); - return true; - } -} - AnalysisKey LazyCallGraphAnalysis::Key; LazyCallGraphPrinterPass::LazyCallGraphPrinterPass(raw_ostream &OS) : OS(OS) {} static void printNode(raw_ostream &OS, LazyCallGraph::Node &N) { OS << " Edges in function: " << N.getFunction().getName() << "\n"; - for (const LazyCallGraph::Edge &E : N) + for (LazyCallGraph::Edge &E : N.populate()) OS << " " << (E.isCall() ? "call" : "ref ") << " -> " << E.getFunction().getName() << "\n"; @@ -1929,6 +1911,7 @@ PreservedAnalyses LazyCallGraphPrinterPass::run(Module &M, for (Function &F : M) printNode(OS, G.get(F)); + G.buildRefSCCs(); for (LazyCallGraph::RefSCC &C : G.postorder_ref_sccs()) printRefSCC(OS, C); @@ -1941,7 +1924,7 @@ LazyCallGraphDOTPrinterPass::LazyCallGraphDOTPrinterPass(raw_ostream &OS) static void printNodeDOT(raw_ostream &OS, LazyCallGraph::Node &N) { std::string Name = "\"" + DOT::EscapeString(N.getFunction().getName()) + "\""; - for (const LazyCallGraph::Edge &E : N) { + for (LazyCallGraph::Edge &E : N.populate()) { OS << " " << Name << " -> \"" << DOT::EscapeString(E.getFunction().getName()) << "\""; if (!E.isCall()) // It is a ref edge. diff --git a/contrib/llvm/lib/Analysis/LazyValueInfo.cpp b/contrib/llvm/lib/Analysis/LazyValueInfo.cpp index d442310476cf..6a9ae6440ace 100644 --- a/contrib/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/contrib/llvm/lib/Analysis/LazyValueInfo.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/AssemblyAnnotationWriter.h" #include "llvm/IR/CFG.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" @@ -31,6 +32,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/FormattedStream.h" #include "llvm/Support/raw_ostream.h" #include <map> #include <stack> @@ -39,6 +41,10 @@ using namespace PatternMatch; #define DEBUG_TYPE "lazy-value-info" +// This is the number of worklist items we will process to try to discover an +// answer for a given value. +static const unsigned MaxProcessedPerValue = 500; + char LazyValueInfoWrapperPass::ID = 0; INITIALIZE_PASS_BEGIN(LazyValueInfoWrapperPass, "lazy-value-info", "Lazy Value Information Analysis", false, true) @@ -136,7 +142,7 @@ public: return Val; } - ConstantRange getConstantRange() const { + const ConstantRange &getConstantRange() const { assert(isConstantRange() && "Cannot get the constant-range of a non-constant-range!"); return Range; @@ -244,7 +250,7 @@ public: if (NewR.isFullSet()) markOverdefined(); else - markConstantRange(NewR); + markConstantRange(std::move(NewR)); } }; @@ -358,6 +364,7 @@ namespace { /// This is the cache kept by LazyValueInfo which /// maintains information about queries across the clients' queries. class LazyValueInfoCache { + friend class LazyValueInfoAnnotatedWriter; /// This is all of the cached block information for exactly one Value*. 
/// The entries are sorted by the BasicBlock* of the /// entries, allowing us to do a lookup with a binary search. @@ -366,22 +373,23 @@ namespace { struct ValueCacheEntryTy { ValueCacheEntryTy(Value *V, LazyValueInfoCache *P) : Handle(V, P) {} LVIValueHandle Handle; - SmallDenseMap<AssertingVH<BasicBlock>, LVILatticeVal, 4> BlockVals; + SmallDenseMap<PoisoningVH<BasicBlock>, LVILatticeVal, 4> BlockVals; }; - /// This is all of the cached information for all values, - /// mapped from Value* to key information. - DenseMap<Value *, std::unique_ptr<ValueCacheEntryTy>> ValueCache; - /// This tracks, on a per-block basis, the set of values that are /// over-defined at the end of that block. - typedef DenseMap<AssertingVH<BasicBlock>, SmallPtrSet<Value *, 4>> + typedef DenseMap<PoisoningVH<BasicBlock>, SmallPtrSet<Value *, 4>> OverDefinedCacheTy; - OverDefinedCacheTy OverDefinedCache; - /// Keep track of all blocks that we have ever seen, so we /// don't spend time removing unused blocks from our caches. - DenseSet<AssertingVH<BasicBlock> > SeenBlocks; + DenseSet<PoisoningVH<BasicBlock> > SeenBlocks; + + protected: + /// This is all of the cached information for all values, + /// mapped from Value* to key information. + DenseMap<Value *, std::unique_ptr<ValueCacheEntryTy>> ValueCache; + OverDefinedCacheTy OverDefinedCache; + public: void insertResult(Value *Val, BasicBlock *BB, const LVILatticeVal &Result) { @@ -435,6 +443,7 @@ namespace { return BBI->second; } + void printCache(Function &F, raw_ostream &OS); /// clear - Empty the cache. void clear() { SeenBlocks.clear(); @@ -458,16 +467,71 @@ namespace { }; } + +namespace { + + /// An assembly annotator class to print LazyValueCache information in + /// comments. + class LazyValueInfoAnnotatedWriter : public AssemblyAnnotationWriter { + const LazyValueInfoCache* LVICache; + + public: + LazyValueInfoAnnotatedWriter(const LazyValueInfoCache *L) : LVICache(L) {} + + virtual void emitBasicBlockStartAnnot(const BasicBlock *BB, + formatted_raw_ostream &OS) { + auto ODI = LVICache->OverDefinedCache.find(const_cast<BasicBlock*>(BB)); + if (ODI == LVICache->OverDefinedCache.end()) + return; + OS << "; OverDefined values for block are: \n"; + for (auto *V : ODI->second) + OS << ";" << *V << "\n"; + + // Find if there are latticevalues defined for arguments of the function. 
+ auto *F = const_cast<Function *>(BB->getParent()); + for (auto &Arg : F->args()) { + auto VI = LVICache->ValueCache.find_as(&Arg); + if (VI == LVICache->ValueCache.end()) + continue; + auto BBI = VI->second->BlockVals.find(const_cast<BasicBlock *>(BB)); + if (BBI != VI->second->BlockVals.end()) + OS << "; CachedLatticeValue for: '" << *VI->first << "' is: '" + << BBI->second << "'\n"; + } + } + + virtual void emitInstructionAnnot(const Instruction *I, + formatted_raw_ostream &OS) { + + auto VI = LVICache->ValueCache.find_as(const_cast<Instruction *>(I)); + if (VI == LVICache->ValueCache.end()) + return; + OS << "; CachedLatticeValues for: '" << *VI->first << "'\n"; + for (auto &BV : VI->second->BlockVals) { + OS << "; at beginning of BasicBlock: '"; + BV.first->printAsOperand(OS, false); + OS << "' LatticeVal: '" << BV.second << "' \n"; + } + } +}; +} + +void LazyValueInfoCache::printCache(Function &F, raw_ostream &OS) { + LazyValueInfoAnnotatedWriter Writer(this); + F.print(OS, &Writer); + +} + void LazyValueInfoCache::eraseValue(Value *V) { - SmallVector<AssertingVH<BasicBlock>, 4> ToErase; - for (auto &I : OverDefinedCache) { - SmallPtrSetImpl<Value *> &ValueSet = I.second; + for (auto I = OverDefinedCache.begin(), E = OverDefinedCache.end(); I != E;) { + // Copy and increment the iterator immediately so we can erase behind + // ourselves. + auto Iter = I++; + SmallPtrSetImpl<Value *> &ValueSet = Iter->second; ValueSet.erase(V); if (ValueSet.empty()) - ToErase.push_back(I.first); + OverDefinedCache.erase(Iter); } - for (auto &BB : ToErase) - OverDefinedCache.erase(BB); ValueCache.erase(V); } @@ -480,7 +544,7 @@ void LVIValueHandle::deleted() { void LazyValueInfoCache::eraseBlock(BasicBlock *BB) { // Shortcut if we have never seen this block. - DenseSet<AssertingVH<BasicBlock> >::iterator I = SeenBlocks.find(BB); + DenseSet<PoisoningVH<BasicBlock> >::iterator I = SeenBlocks.find(BB); if (I == SeenBlocks.end()) return; SeenBlocks.erase(I); @@ -563,7 +627,7 @@ namespace { /// This stack holds the state of the value solver during a query. /// It basically emulates the callstack of the naive /// recursive value lookup process. - std::stack<std::pair<BasicBlock*, Value*> > BlockValueStack; + SmallVector<std::pair<BasicBlock*, Value*>, 8> BlockValueStack; /// Keeps track of which block-value pairs are in BlockValueStack. DenseSet<std::pair<BasicBlock*, Value*> > BlockValueSet; @@ -576,7 +640,7 @@ namespace { DEBUG(dbgs() << "PUSH: " << *BV.second << " in " << BV.first->getName() << "\n"); - BlockValueStack.push(BV); + BlockValueStack.push_back(BV); return true; } @@ -598,13 +662,13 @@ namespace { bool solveBlockValuePHINode(LVILatticeVal &BBLV, PHINode *PN, BasicBlock *BB); bool solveBlockValueSelect(LVILatticeVal &BBLV, SelectInst *S, BasicBlock *BB); - bool solveBlockValueBinaryOp(LVILatticeVal &BBLV, Instruction *BBI, + bool solveBlockValueBinaryOp(LVILatticeVal &BBLV, BinaryOperator *BBI, BasicBlock *BB); - bool solveBlockValueCast(LVILatticeVal &BBLV, Instruction *BBI, + bool solveBlockValueCast(LVILatticeVal &BBLV, CastInst *CI, BasicBlock *BB); void intersectAssumeOrGuardBlockValueConstantRange(Value *Val, LVILatticeVal &BBLV, - Instruction *BBI); + Instruction *BBI); void solve(); @@ -629,6 +693,11 @@ namespace { TheCache.clear(); } + /// Printing the LazyValueInfoCache. + void printCache(Function &F, raw_ostream &OS) { + TheCache.printCache(F, OS); + } + /// This is part of the update interface to inform the cache /// that a block has been deleted. 
void eraseBlock(BasicBlock *BB) { @@ -646,24 +715,50 @@ namespace { } // end anonymous namespace void LazyValueInfoImpl::solve() { + SmallVector<std::pair<BasicBlock *, Value *>, 8> StartingStack( + BlockValueStack.begin(), BlockValueStack.end()); + + unsigned processedCount = 0; while (!BlockValueStack.empty()) { - std::pair<BasicBlock*, Value*> &e = BlockValueStack.top(); + processedCount++; + // Abort if we have to process too many values to get a result for this one. + // Because of the design of the overdefined cache currently being per-block + // to avoid naming-related issues (IE it wants to try to give different + // results for the same name in different blocks), overdefined results don't + // get cached globally, which in turn means we will often try to rediscover + // the same overdefined result again and again. Once something like + // PredicateInfo is used in LVI or CVP, we should be able to make the + // overdefined cache global, and remove this throttle. + if (processedCount > MaxProcessedPerValue) { + DEBUG(dbgs() << "Giving up on stack because we are getting too deep\n"); + // Fill in the original values + while (!StartingStack.empty()) { + std::pair<BasicBlock *, Value *> &e = StartingStack.back(); + TheCache.insertResult(e.second, e.first, + LVILatticeVal::getOverdefined()); + StartingStack.pop_back(); + } + BlockValueSet.clear(); + BlockValueStack.clear(); + return; + } + std::pair<BasicBlock *, Value *> e = BlockValueStack.back(); assert(BlockValueSet.count(e) && "Stack value should be in BlockValueSet!"); if (solveBlockValue(e.second, e.first)) { // The work item was completely processed. - assert(BlockValueStack.top() == e && "Nothing should have been pushed!"); + assert(BlockValueStack.back() == e && "Nothing should have been pushed!"); assert(TheCache.hasCachedValueInfo(e.second, e.first) && "Result should be in cache!"); DEBUG(dbgs() << "POP " << *e.second << " in " << e.first->getName() << " = " << TheCache.getCachedValueInfo(e.second, e.first) << "\n"); - BlockValueStack.pop(); + BlockValueStack.pop_back(); BlockValueSet.erase(e); } else { // More work needs to be done before revisiting. - assert(BlockValueStack.top() != e && "Stack should have been pushed!"); + assert(BlockValueStack.back() != e && "Stack should have been pushed!"); } } } @@ -754,12 +849,12 @@ bool LazyValueInfoImpl::solveBlockValueImpl(LVILatticeVal &Res, return true; } if (BBI->getType()->isIntegerTy()) { - if (isa<CastInst>(BBI)) - return solveBlockValueCast(Res, BBI, BB); - + if (auto *CI = dyn_cast<CastInst>(BBI)) + return solveBlockValueCast(Res, CI, BB); + BinaryOperator *BO = dyn_cast<BinaryOperator>(BBI); if (BO && isa<ConstantInt>(BO->getOperand(1))) - return solveBlockValueBinaryOp(Res, BBI, BB); + return solveBlockValueBinaryOp(Res, BO, BB); } DEBUG(dbgs() << " compute BB '" << BB->getName() @@ -825,7 +920,7 @@ bool LazyValueInfoImpl::solveBlockValueNonLocal(LVILatticeVal &BBLV, // value is overdefined. if (BB == &BB->getParent()->getEntryBlock()) { assert(isa<Argument>(Val) && "Unknown live-in to the entry block"); - // Bofore giving up, see if we can prove the pointer non-null local to + // Before giving up, see if we can prove the pointer non-null local to // this particular block. if (Val->getType()->isPointerTy() && (isKnownNonNull(Val) || isObjectDereferencedInBlock(Val, BB))) { @@ -839,13 +934,19 @@ bool LazyValueInfoImpl::solveBlockValueNonLocal(LVILatticeVal &BBLV, } // Loop over all of our predecessors, merging what we know from them into - // result. 
- bool EdgesMissing = false; + // result. If we encounter an unexplored predecessor, we eagerly explore it + // in a depth first manner. In practice, this has the effect of discovering + // paths we can't analyze eagerly without spending compile times analyzing + // other paths. This heuristic benefits from the fact that predecessors are + // frequently arranged such that dominating ones come first and we quickly + // find a path to function entry. TODO: We should consider explicitly + // canonicalizing to make this true rather than relying on this happy + // accident. for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { LVILatticeVal EdgeResult; - EdgesMissing |= !getEdgeValue(Val, *PI, BB, EdgeResult); - if (EdgesMissing) - continue; + if (!getEdgeValue(Val, *PI, BB, EdgeResult)) + // Explore that input, then return here + return false; Result.mergeIn(EdgeResult, DL); @@ -866,8 +967,6 @@ bool LazyValueInfoImpl::solveBlockValueNonLocal(LVILatticeVal &BBLV, return true; } } - if (EdgesMissing) - return false; // Return the merged value, which is more precise than 'overdefined'. assert(!Result.isOverdefined()); @@ -880,8 +979,8 @@ bool LazyValueInfoImpl::solveBlockValuePHINode(LVILatticeVal &BBLV, LVILatticeVal Result; // Start Undefined. // Loop over all of our predecessors, merging what we know from them into - // result. - bool EdgesMissing = false; + // result. See the comment about the chosen traversal order in + // solveBlockValueNonLocal; the same reasoning applies here. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { BasicBlock *PhiBB = PN->getIncomingBlock(i); Value *PhiVal = PN->getIncomingValue(i); @@ -889,9 +988,9 @@ bool LazyValueInfoImpl::solveBlockValuePHINode(LVILatticeVal &BBLV, // Note that we can provide PN as the context value to getEdgeValue, even // though the results will be cached, because PN is the value being used as // the cache key in the caller. - EdgesMissing |= !getEdgeValue(PhiVal, PhiBB, BB, EdgeResult, PN); - if (EdgesMissing) - continue; + if (!getEdgeValue(PhiVal, PhiBB, BB, EdgeResult, PN)) + // Explore that input, then return here + return false; Result.mergeIn(EdgeResult, DL); @@ -905,8 +1004,6 @@ bool LazyValueInfoImpl::solveBlockValuePHINode(LVILatticeVal &BBLV, return true; } } - if (EdgesMissing) - return false; // Return the merged value, which is more precise than 'overdefined'. assert(!Result.isOverdefined() && "Possible PHI in entry block?"); @@ -982,8 +1079,8 @@ bool LazyValueInfoImpl::solveBlockValueSelect(LVILatticeVal &BBLV, } if (TrueVal.isConstantRange() && FalseVal.isConstantRange()) { - ConstantRange TrueCR = TrueVal.getConstantRange(); - ConstantRange FalseCR = FalseVal.getConstantRange(); + const ConstantRange &TrueCR = TrueVal.getConstantRange(); + const ConstantRange &FalseCR = FalseVal.getConstantRange(); Value *LHS = nullptr; Value *RHS = nullptr; SelectPatternResult SPR = matchSelectPattern(SI, LHS, RHS); @@ -1071,9 +1168,9 @@ bool LazyValueInfoImpl::solveBlockValueSelect(LVILatticeVal &BBLV, } bool LazyValueInfoImpl::solveBlockValueCast(LVILatticeVal &BBLV, - Instruction *BBI, - BasicBlock *BB) { - if (!BBI->getOperand(0)->getType()->isSized()) { + CastInst *CI, + BasicBlock *BB) { + if (!CI->getOperand(0)->getType()->isSized()) { // Without knowing how wide the input is, we can't analyze it in any useful // way. 
BBLV = LVILatticeVal::getOverdefined(); @@ -1083,7 +1180,7 @@ bool LazyValueInfoImpl::solveBlockValueCast(LVILatticeVal &BBLV, // Filter out casts we don't know how to reason about before attempting to // recurse on our operand. This can cut a long search short if we know we're // not going to be able to get any useful information anways. - switch (BBI->getOpcode()) { + switch (CI->getOpcode()) { case Instruction::Trunc: case Instruction::SExt: case Instruction::ZExt: @@ -1100,44 +1197,43 @@ bool LazyValueInfoImpl::solveBlockValueCast(LVILatticeVal &BBLV, // Figure out the range of the LHS. If that fails, we still apply the // transfer rule on the full set since we may be able to locally infer // interesting facts. - if (!hasBlockValue(BBI->getOperand(0), BB)) - if (pushBlockValue(std::make_pair(BB, BBI->getOperand(0)))) + if (!hasBlockValue(CI->getOperand(0), BB)) + if (pushBlockValue(std::make_pair(BB, CI->getOperand(0)))) // More work to do before applying this transfer rule. return false; const unsigned OperandBitWidth = - DL.getTypeSizeInBits(BBI->getOperand(0)->getType()); + DL.getTypeSizeInBits(CI->getOperand(0)->getType()); ConstantRange LHSRange = ConstantRange(OperandBitWidth); - if (hasBlockValue(BBI->getOperand(0), BB)) { - LVILatticeVal LHSVal = getBlockValue(BBI->getOperand(0), BB); - intersectAssumeOrGuardBlockValueConstantRange(BBI->getOperand(0), LHSVal, - BBI); + if (hasBlockValue(CI->getOperand(0), BB)) { + LVILatticeVal LHSVal = getBlockValue(CI->getOperand(0), BB); + intersectAssumeOrGuardBlockValueConstantRange(CI->getOperand(0), LHSVal, + CI); if (LHSVal.isConstantRange()) LHSRange = LHSVal.getConstantRange(); } - const unsigned ResultBitWidth = - cast<IntegerType>(BBI->getType())->getBitWidth(); + const unsigned ResultBitWidth = CI->getType()->getIntegerBitWidth(); // NOTE: We're currently limited by the set of operations that ConstantRange // can evaluate symbolically. Enhancing that set will allows us to analyze // more definitions. - auto CastOp = (Instruction::CastOps) BBI->getOpcode(); - BBLV = LVILatticeVal::getRange(LHSRange.castOp(CastOp, ResultBitWidth)); + BBLV = LVILatticeVal::getRange(LHSRange.castOp(CI->getOpcode(), + ResultBitWidth)); return true; } bool LazyValueInfoImpl::solveBlockValueBinaryOp(LVILatticeVal &BBLV, - Instruction *BBI, + BinaryOperator *BO, BasicBlock *BB) { - assert(BBI->getOperand(0)->getType()->isSized() && + assert(BO->getOperand(0)->getType()->isSized() && "all operands to binary operators are sized"); // Filter out operators we don't know how to reason about before attempting to // recurse on our operand(s). This can cut a long search short if we know - // we're not going to be able to get any useful information anways. - switch (BBI->getOpcode()) { + // we're not going to be able to get any useful information anyways. + switch (BO->getOpcode()) { case Instruction::Add: case Instruction::Sub: case Instruction::Mul: @@ -1159,29 +1255,29 @@ bool LazyValueInfoImpl::solveBlockValueBinaryOp(LVILatticeVal &BBLV, // Figure out the range of the LHS. If that fails, use a conservative range, // but apply the transfer rule anyways. This lets us pick up facts from // expressions like "and i32 (call i32 @foo()), 32" - if (!hasBlockValue(BBI->getOperand(0), BB)) - if (pushBlockValue(std::make_pair(BB, BBI->getOperand(0)))) + if (!hasBlockValue(BO->getOperand(0), BB)) + if (pushBlockValue(std::make_pair(BB, BO->getOperand(0)))) // More work to do before applying this transfer rule. 
return false; const unsigned OperandBitWidth = - DL.getTypeSizeInBits(BBI->getOperand(0)->getType()); + DL.getTypeSizeInBits(BO->getOperand(0)->getType()); ConstantRange LHSRange = ConstantRange(OperandBitWidth); - if (hasBlockValue(BBI->getOperand(0), BB)) { - LVILatticeVal LHSVal = getBlockValue(BBI->getOperand(0), BB); - intersectAssumeOrGuardBlockValueConstantRange(BBI->getOperand(0), LHSVal, - BBI); + if (hasBlockValue(BO->getOperand(0), BB)) { + LVILatticeVal LHSVal = getBlockValue(BO->getOperand(0), BB); + intersectAssumeOrGuardBlockValueConstantRange(BO->getOperand(0), LHSVal, + BO); if (LHSVal.isConstantRange()) LHSRange = LHSVal.getConstantRange(); } - ConstantInt *RHS = cast<ConstantInt>(BBI->getOperand(1)); + ConstantInt *RHS = cast<ConstantInt>(BO->getOperand(1)); ConstantRange RHSRange = ConstantRange(RHS->getValue()); // NOTE: We're currently limited by the set of operations that ConstantRange // can evaluate symbolically. Enhancing that set will allows us to analyze // more definitions. - auto BinOp = (Instruction::BinaryOps) BBI->getOpcode(); + Instruction::BinaryOps BinOp = BO->getOpcode(); BBLV = LVILatticeVal::getRange(LHSRange.binaryOp(BinOp, RHSRange)); return true; } @@ -1333,14 +1429,14 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, unsigned BitWidth = Val->getType()->getIntegerBitWidth(); ConstantRange EdgesVals(BitWidth, DefaultCase/*isFullSet*/); - for (SwitchInst::CaseIt i : SI->cases()) { - ConstantRange EdgeVal(i.getCaseValue()->getValue()); + for (auto Case : SI->cases()) { + ConstantRange EdgeVal(Case.getCaseValue()->getValue()); if (DefaultCase) { // It is possible that the default destination is the destination of // some cases. There is no need to perform difference for those cases. - if (i.getCaseSuccessor() != BBTo) + if (Case.getCaseSuccessor() != BBTo) EdgesVals = EdgesVals.difference(EdgeVal); - } else if (i.getCaseSuccessor() == BBTo) + } else if (Case.getCaseSuccessor() == BBTo) EdgesVals = EdgesVals.unionWith(EdgeVal); } Result = LVILatticeVal::getRange(std::move(EdgesVals)); @@ -1352,8 +1448,8 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, /// \brief Compute the value of Val on the edge BBFrom -> BBTo or the value at /// the basic block if the edge does not constrain Val. bool LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom, - BasicBlock *BBTo, LVILatticeVal &Result, - Instruction *CxtI) { + BasicBlock *BBTo, LVILatticeVal &Result, + Instruction *CxtI) { // If already a constant, there is nothing to compute. if (Constant *VC = dyn_cast<Constant>(Val)) { Result = LVILatticeVal::get(VC); @@ -1503,6 +1599,18 @@ void LazyValueInfo::releaseMemory() { } } +bool LazyValueInfo::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + // We need to invalidate if we have either failed to preserve this analyses + // result directly or if any of its dependencies have been invalidated. 
+ auto PAC = PA.getChecker<LazyValueAnalysis>(); + if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || + (DT && Inv.invalidate<DominatorTreeAnalysis>(F, PA))) + return true; + + return false; +} + void LazyValueInfoWrapperPass::releaseMemory() { Info.releaseMemory(); } LazyValueInfo LazyValueAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { @@ -1510,7 +1618,7 @@ LazyValueInfo LazyValueAnalysis::run(Function &F, FunctionAnalysisManager &FAM) auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); - return LazyValueInfo(&AC, &TLI, DT); + return LazyValueInfo(&AC, &F.getParent()->getDataLayout(), &TLI, DT); } /// Returns true if we can statically tell that this value will never be a @@ -1540,7 +1648,7 @@ Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB, if (Result.isConstant()) return Result.getConstant(); if (Result.isConstantRange()) { - ConstantRange CR = Result.getConstantRange(); + const ConstantRange &CR = Result.getConstantRange(); if (const APInt *SingleVal = CR.getSingleElement()) return ConstantInt::get(V->getContext(), *SingleVal); } @@ -1577,7 +1685,7 @@ Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB, if (Result.isConstant()) return Result.getConstant(); if (Result.isConstantRange()) { - ConstantRange CR = Result.getConstantRange(); + const ConstantRange &CR = Result.getConstantRange(); if (const APInt *SingleVal = CR.getSingleElement()) return ConstantInt::get(V->getContext(), *SingleVal); } @@ -1603,7 +1711,7 @@ static LazyValueInfo::Tristate getPredicateResult(unsigned Pred, Constant *C, ConstantInt *CI = dyn_cast<ConstantInt>(C); if (!CI) return LazyValueInfo::Unknown; - ConstantRange CR = Result.getConstantRange(); + const ConstantRange &CR = Result.getConstantRange(); if (Pred == ICmpInst::ICMP_EQ) { if (!CR.contains(CI->getValue())) return LazyValueInfo::False; @@ -1780,3 +1888,40 @@ void LazyValueInfo::eraseBlock(BasicBlock *BB) { getImpl(PImpl, AC, &DL, DT).eraseBlock(BB); } } + + +void LazyValueInfo::printCache(Function &F, raw_ostream &OS) { + if (PImpl) { + getImpl(PImpl, AC, DL, DT).printCache(F, OS); + } +} + +namespace { +// Printer class for LazyValueInfo results. 
+class LazyValueInfoPrinter : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + LazyValueInfoPrinter() : FunctionPass(ID) { + initializeLazyValueInfoPrinterPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + AU.addRequired<LazyValueInfoWrapperPass>(); + } + + bool runOnFunction(Function &F) override { + dbgs() << "LVI for function '" << F.getName() << "':\n"; + auto &LVI = getAnalysis<LazyValueInfoWrapperPass>().getLVI(); + LVI.printCache(F, dbgs()); + return false; + } +}; +} + +char LazyValueInfoPrinter::ID = 0; +INITIALIZE_PASS_BEGIN(LazyValueInfoPrinter, "print-lazy-value-info", + "Lazy Value Info Printer Pass", false, false) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) +INITIALIZE_PASS_END(LazyValueInfoPrinter, "print-lazy-value-info", + "Lazy Value Info Printer Pass", false, false) diff --git a/contrib/llvm/lib/Analysis/Lint.cpp b/contrib/llvm/lib/Analysis/Lint.cpp index 2ca46b1fe872..e6391792bc23 100644 --- a/contrib/llvm/lib/Analysis/Lint.cpp +++ b/contrib/llvm/lib/Analysis/Lint.cpp @@ -70,6 +70,7 @@ #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include <cassert> @@ -533,11 +534,8 @@ static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, VectorType *VecTy = dyn_cast<VectorType>(V->getType()); if (!VecTy) { - unsigned BitWidth = V->getType()->getIntegerBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL, 0, AC, - dyn_cast<Instruction>(V), DT); - return KnownZero.isAllOnesValue(); + KnownBits Known = computeKnownBits(V, DL, 0, AC, dyn_cast<Instruction>(V), DT); + return Known.isZero(); } // Per-component check doesn't work with zeroinitializer @@ -550,15 +548,13 @@ static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, // For a vector, KnownZero will only be true if all values are zero, so check // this per component - unsigned BitWidth = VecTy->getElementType()->getIntegerBitWidth(); for (unsigned I = 0, N = VecTy->getNumElements(); I != N; ++I) { Constant *Elem = C->getAggregateElement(I); if (isa<UndefValue>(Elem)) return true; - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(Elem, KnownZero, KnownOne, DL); - if (KnownZero.isAllOnesValue()) + KnownBits Known = computeKnownBits(Elem, DL); + if (Known.isZero()) return true; } @@ -699,7 +695,7 @@ Value *Lint::findValueImpl(Value *V, bool OffsetOk, // As a last resort, try SimplifyInstruction or constant folding. 
if (Instruction *Inst = dyn_cast<Instruction>(V)) { - if (Value *W = SimplifyInstruction(Inst, *DL, TLI, DT, AC)) + if (Value *W = SimplifyInstruction(Inst, {*DL, TLI, DT, AC})) return findValueImpl(W, OffsetOk, Visited); } else if (auto *C = dyn_cast<Constant>(V)) { if (Value *W = ConstantFoldConstant(C, *DL, TLI)) diff --git a/contrib/llvm/lib/Analysis/Loads.cpp b/contrib/llvm/lib/Analysis/Loads.cpp index e46541e6538d..96799a459bfc 100644 --- a/contrib/llvm/lib/Analysis/Loads.cpp +++ b/contrib/llvm/lib/Analysis/Loads.cpp @@ -312,21 +312,26 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BasicBlock *ScanBB, BasicBlock::iterator &ScanFrom, unsigned MaxInstsToScan, - AliasAnalysis *AA, bool *IsLoadCSE) { - if (MaxInstsToScan == 0) - MaxInstsToScan = ~0U; - - Value *Ptr = Load->getPointerOperand(); - Type *AccessTy = Load->getType(); - - // We can never remove a volatile load - if (Load->isVolatile()) - return nullptr; - - // Anything stronger than unordered is currently unimplemented. + AliasAnalysis *AA, bool *IsLoad, + unsigned *NumScanedInst) { + // Don't CSE load that is volatile or anything stronger than unordered. if (!Load->isUnordered()) return nullptr; + return FindAvailablePtrLoadStore( + Load->getPointerOperand(), Load->getType(), Load->isAtomic(), ScanBB, + ScanFrom, MaxInstsToScan, AA, IsLoad, NumScanedInst); +} + +Value *llvm::FindAvailablePtrLoadStore(Value *Ptr, Type *AccessTy, + bool AtLeastAtomic, BasicBlock *ScanBB, + BasicBlock::iterator &ScanFrom, + unsigned MaxInstsToScan, + AliasAnalysis *AA, bool *IsLoadCSE, + unsigned *NumScanedInst) { + if (MaxInstsToScan == 0) + MaxInstsToScan = ~0U; + const DataLayout &DL = ScanBB->getModule()->getDataLayout(); // Try to get the store size for the type. @@ -344,6 +349,9 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, // Restore ScanFrom to expected value in case next test succeeds ScanFrom++; + if (NumScanedInst) + ++(*NumScanedInst); + // Don't scan huge blocks. if (MaxInstsToScan-- == 0) return nullptr; @@ -359,7 +367,7 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, // We can value forward from an atomic to a non-atomic, but not the // other way around. - if (LI->isAtomic() < Load->isAtomic()) + if (LI->isAtomic() < AtLeastAtomic) return nullptr; if (IsLoadCSE) @@ -378,7 +386,7 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, // We can value forward from an atomic to a non-atomic, but not the // other way around. - if (SI->isAtomic() < Load->isAtomic()) + if (SI->isAtomic() < AtLeastAtomic) return nullptr; if (IsLoadCSE) diff --git a/contrib/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/contrib/llvm/lib/Analysis/LoopAccessAnalysis.cpp index bf8007213097..4ba12583ff83 100644 --- a/contrib/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/contrib/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -135,21 +135,6 @@ bool VectorizerParams::isInterleaveForced() { return ::VectorizationInterleave.getNumOccurrences() > 0; } -void LoopAccessReport::emitAnalysis(const LoopAccessReport &Message, - const Loop *TheLoop, const char *PassName, - OptimizationRemarkEmitter &ORE) { - DebugLoc DL = TheLoop->getStartLoc(); - const Value *V = TheLoop->getHeader(); - if (const Instruction *I = Message.getInstr()) { - // If there is no debug location attached to the instruction, revert back to - // using the loop's. 
- if (I->getDebugLoc()) - DL = I->getDebugLoc(); - V = I->getParent(); - } - ORE.emitOptimizationRemarkAnalysis(PassName, DL, V, Message.str()); -} - Value *llvm::stripIntegerCast(Value *V) { if (auto *CI = dyn_cast<CastInst>(V)) if (CI->getOperand(0)->getType()->isIntegerTy()) @@ -172,11 +157,6 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE, // Strip casts. StrideVal = stripIntegerCast(StrideVal); - // Replace symbolic stride by one. - Value *One = ConstantInt::get(StrideVal->getType(), 1); - ValueToValueMap RewriteMap; - RewriteMap[StrideVal] = One; - ScalarEvolution *SE = PSE.getSE(); const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal)); const auto *CT = @@ -518,7 +498,7 @@ class AccessAnalysis { public: /// \brief Read or write access location. typedef PointerIntPair<Value *, 1, bool> MemAccessInfo; - typedef SmallPtrSet<MemAccessInfo, 8> MemAccessInfoSet; + typedef SmallVector<MemAccessInfo, 8> MemAccessInfoList; AccessAnalysis(const DataLayout &Dl, AliasAnalysis *AA, LoopInfo *LI, MemoryDepChecker::DepCandidates &DA, @@ -570,7 +550,7 @@ public: DepChecker.clearDependences(); } - MemAccessInfoSet &getDependenciesToCheck() { return CheckDeps; } + MemAccessInfoList &getDependenciesToCheck() { return CheckDeps; } private: typedef SetVector<MemAccessInfo> PtrAccessSet; @@ -584,8 +564,8 @@ private: const DataLayout &DL; - /// Set of accesses that need a further dependence check. - MemAccessInfoSet CheckDeps; + /// List of accesses that need a further dependence check. + MemAccessInfoList CheckDeps; /// Set of pointers that are read only. SmallPtrSet<Value*, 16> ReadOnlyPtr; @@ -842,7 +822,7 @@ void AccessAnalysis::processMemAccesses() { // there is no other write to the ptr - this is an optimization to // catch "a[i] = a[i] + " without having to do a dependence check). if ((IsWrite || IsReadOnlyPtr) && SetHasWrite) { - CheckDeps.insert(Access); + CheckDeps.push_back(Access); IsRTCheckAnalysisNeeded = true; } @@ -1205,6 +1185,73 @@ bool MemoryDepChecker::couldPreventStoreLoadForward(uint64_t Distance, return false; } +/// Given a non-constant (unknown) dependence-distance \p Dist between two +/// memory accesses, that have the same stride whose absolute value is given +/// in \p Stride, and that have the same type size \p TypeByteSize, +/// in a loop whose takenCount is \p BackedgeTakenCount, check if it is +/// possible to prove statically that the dependence distance is larger +/// than the range that the accesses will travel through the execution of +/// the loop. If so, return true; false otherwise. This is useful for +/// example in loops such as the following (PR31098): +/// for (i = 0; i < D; ++i) { +/// = out[i]; +/// out[i+D] = +/// } +static bool isSafeDependenceDistance(const DataLayout &DL, ScalarEvolution &SE, + const SCEV &BackedgeTakenCount, + const SCEV &Dist, uint64_t Stride, + uint64_t TypeByteSize) { + + // If we can prove that + // (**) |Dist| > BackedgeTakenCount * Step + // where Step is the absolute stride of the memory accesses in bytes, + // then there is no dependence. + // + // Ratioanle: + // We basically want to check if the absolute distance (|Dist/Step|) + // is >= the loop iteration count (or > BackedgeTakenCount). + // This is equivalent to the Strong SIV Test (Practical Dependence Testing, + // Section 4.2.1); Note, that for vectorization it is sufficient to prove + // that the dependence distance is >= VF; This is checked elsewhere. 
+ // But in some cases we can prune unknown dependence distances early, and + // even before selecting the VF, and without a runtime test, by comparing + // the distance against the loop iteration count. Since the vectorized code + // will be executed only if LoopCount >= VF, proving distance >= LoopCount + // also guarantees that distance >= VF. + // + const uint64_t ByteStride = Stride * TypeByteSize; + const SCEV *Step = SE.getConstant(BackedgeTakenCount.getType(), ByteStride); + const SCEV *Product = SE.getMulExpr(&BackedgeTakenCount, Step); + + const SCEV *CastedDist = &Dist; + const SCEV *CastedProduct = Product; + uint64_t DistTypeSize = DL.getTypeAllocSize(Dist.getType()); + uint64_t ProductTypeSize = DL.getTypeAllocSize(Product->getType()); + + // The dependence distance can be positive/negative, so we sign extend Dist; + // The multiplication of the absolute stride in bytes and the + // backdgeTakenCount is non-negative, so we zero extend Product. + if (DistTypeSize > ProductTypeSize) + CastedProduct = SE.getZeroExtendExpr(Product, Dist.getType()); + else + CastedDist = SE.getNoopOrSignExtend(&Dist, Product->getType()); + + // Is Dist - (BackedgeTakenCount * Step) > 0 ? + // (If so, then we have proven (**) because |Dist| >= Dist) + const SCEV *Minus = SE.getMinusSCEV(CastedDist, CastedProduct); + if (SE.isKnownPositive(Minus)) + return true; + + // Second try: Is -Dist - (BackedgeTakenCount * Step) > 0 ? + // (If so, then we have proven (**) because |Dist| >= -1*Dist) + const SCEV *NegDist = SE.getNegativeSCEV(CastedDist); + Minus = SE.getMinusSCEV(NegDist, CastedProduct); + if (SE.isKnownPositive(Minus)) + return true; + + return false; +} + /// \brief Check the dependence for two accesses with the same stride \p Stride. /// \p Distance is the positive distance and \p TypeByteSize is type size in /// bytes. @@ -1292,21 +1339,26 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, return Dependence::Unknown; } + Type *ATy = APtr->getType()->getPointerElementType(); + Type *BTy = BPtr->getType()->getPointerElementType(); + auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout(); + uint64_t TypeByteSize = DL.getTypeAllocSize(ATy); + uint64_t Stride = std::abs(StrideAPtr); const SCEVConstant *C = dyn_cast<SCEVConstant>(Dist); if (!C) { + if (TypeByteSize == DL.getTypeAllocSize(BTy) && + isSafeDependenceDistance(DL, *(PSE.getSE()), + *(PSE.getBackedgeTakenCount()), *Dist, Stride, + TypeByteSize)) + return Dependence::NoDep; + DEBUG(dbgs() << "LAA: Dependence because of non-constant distance\n"); ShouldRetryWithRuntimeCheck = true; return Dependence::Unknown; } - Type *ATy = APtr->getType()->getPointerElementType(); - Type *BTy = BPtr->getType()->getPointerElementType(); - auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout(); - uint64_t TypeByteSize = DL.getTypeAllocSize(ATy); - const APInt &Val = C->getAPInt(); int64_t Distance = Val.getSExtValue(); - uint64_t Stride = std::abs(StrideAPtr); // Attempt to prove strided accesses independent. 
if (std::abs(Distance) > 0 && Stride > 1 && ATy == BTy && @@ -1427,12 +1479,14 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, } bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, - MemAccessInfoSet &CheckDeps, + MemAccessInfoList &CheckDeps, const ValueToValueMap &Strides) { MaxSafeDepDistBytes = -1; - while (!CheckDeps.empty()) { - MemAccessInfo CurAccess = *CheckDeps.begin(); + SmallPtrSet<MemAccessInfo, 8> Visited; + for (MemAccessInfo CurAccess : CheckDeps) { + if (Visited.count(CurAccess)) + continue; // Get the relevant memory access set. EquivalenceClasses<MemAccessInfo>::iterator I = @@ -1446,7 +1500,7 @@ bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, // Check every access pair. while (AI != AE) { - CheckDeps.erase(*AI); + Visited.insert(*AI); EquivalenceClasses<MemAccessInfo>::member_iterator OI = std::next(AI); while (OI != AE) { // Check every accessing instruction pair in program order. @@ -1885,7 +1939,10 @@ expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop, Value *NewPtr = (Inst && TheLoop->contains(Inst)) ? Exp.expandCodeFor(Sc, PtrArithTy, Loc) : Ptr; - return {NewPtr, NewPtr}; + // We must return a half-open range, which means incrementing Sc. + const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy)); + Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc); + return {NewPtr, NewPtrPlusOne}; } else { Value *Start = nullptr, *End = nullptr; DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); diff --git a/contrib/llvm/lib/Analysis/LoopAnalysisManager.cpp b/contrib/llvm/lib/Analysis/LoopAnalysisManager.cpp index 5be3ee341c9c..e4a0f90b2f71 100644 --- a/contrib/llvm/lib/Analysis/LoopAnalysisManager.cpp +++ b/contrib/llvm/lib/Analysis/LoopAnalysisManager.cpp @@ -31,24 +31,10 @@ bool LoopAnalysisManagerFunctionProxy::Result::invalidate( FunctionAnalysisManager::Invalidator &Inv) { // First compute the sequence of IR units covered by this proxy. We will want // to visit this in postorder, but because this is a tree structure we can do - // this by building a preorder sequence and walking it in reverse. - SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist; - // Note that we want to walk the roots in reverse order because we will end - // up reversing the preorder sequence. However, it happens that the loop nest - // roots are in reverse order within the LoopInfo object. So we just walk - // forward here. - // FIXME: If we change the order of LoopInfo we will want to add a reverse - // here. - for (Loop *RootL : *LI) { - assert(PreOrderWorklist.empty() && - "Must start with an empty preorder walk worklist."); - PreOrderWorklist.push_back(RootL); - do { - Loop *L = PreOrderWorklist.pop_back_val(); - PreOrderWorklist.append(L->begin(), L->end()); - PreOrderLoops.push_back(L); - } while (!PreOrderWorklist.empty()); - } + // this by building a preorder sequence and walking it backwards. We also + // want siblings in forward program order to match the LoopPassManager so we + // get the preorder with siblings reversed. + SmallVector<Loop *, 4> PreOrderLoops = LI->getLoopsInReverseSiblingPreorder(); // If this proxy or the loop info is going to be invalidated, we also need // to clear all the keys coming from that analysis. 
We also completely blow @@ -145,7 +131,6 @@ LoopAnalysisManagerFunctionProxy::run(Function &F, PreservedAnalyses llvm::getLoopPassPreservedAnalyses() { PreservedAnalyses PA; - PA.preserve<AssumptionAnalysis>(); PA.preserve<DominatorTreeAnalysis>(); PA.preserve<LoopAnalysis>(); PA.preserve<LoopAnalysisManagerFunctionProxy>(); diff --git a/contrib/llvm/lib/Analysis/LoopInfo.cpp b/contrib/llvm/lib/Analysis/LoopInfo.cpp index f449ce94d57c..ff68810abb82 100644 --- a/contrib/llvm/lib/Analysis/LoopInfo.cpp +++ b/contrib/llvm/lib/Analysis/LoopInfo.cpp @@ -40,9 +40,9 @@ template class llvm::LoopInfoBase<BasicBlock, Loop>; // Always verify loopinfo if expensive checking is enabled. #ifdef EXPENSIVE_CHECKS -static bool VerifyLoopInfo = true; +bool llvm::VerifyLoopInfo = true; #else -static bool VerifyLoopInfo = false; +bool llvm::VerifyLoopInfo = false; #endif static cl::opt<bool,true> VerifyLoopInfoX("verify-loop-info", cl::location(VerifyLoopInfo), @@ -211,9 +211,11 @@ bool Loop::isSafeToClone() const { MDNode *Loop::getLoopID() const { MDNode *LoopID = nullptr; - if (isLoopSimplifyForm()) { - LoopID = getLoopLatch()->getTerminator()->getMetadata(LLVMContext::MD_loop); + if (BasicBlock *Latch = getLoopLatch()) { + LoopID = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop); } else { + assert(!getLoopLatch() && + "The loop should have no single latch at this point"); // Go through each predecessor of the loop header and check the // terminator for the metadata. BasicBlock *H = getHeader(); @@ -248,11 +250,12 @@ void Loop::setLoopID(MDNode *LoopID) const { assert(LoopID->getNumOperands() > 0 && "Loop ID needs at least one operand"); assert(LoopID->getOperand(0) == LoopID && "Loop ID should refer to itself"); - if (isLoopSimplifyForm()) { - getLoopLatch()->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID); + if (BasicBlock *Latch = getLoopLatch()) { + Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID); return; } + assert(!getLoopLatch() && "The loop should have no single latch at this point"); BasicBlock *H = getHeader(); for (BasicBlock *BB : this->blocks()) { TerminatorInst *TI = BB->getTerminator(); @@ -610,6 +613,15 @@ LoopInfo::LoopInfo(const DominatorTreeBase<BasicBlock> &DomTree) { analyze(DomTree); } +bool LoopInfo::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. + auto PAC = PA.getChecker<LoopAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + void LoopInfo::markAsRemoved(Loop *Unloop) { assert(!Unloop->isInvalid() && "Loop has already been removed"); Unloop->invalidate(); diff --git a/contrib/llvm/lib/Analysis/LoopPass.cpp b/contrib/llvm/lib/Analysis/LoopPass.cpp index 3f4a07942154..e988f6444a58 100644 --- a/contrib/llvm/lib/Analysis/LoopPass.cpp +++ b/contrib/llvm/lib/Analysis/LoopPass.cpp @@ -54,6 +54,8 @@ public: } return false; } + + StringRef getPassName() const override { return "Print Loop IR"; } }; char PrintLoopPassWrapper::ID = 0; @@ -71,30 +73,23 @@ LPPassManager::LPPassManager() CurrentLoop = nullptr; } -// Inset loop into loop nest (LoopInfo) and loop queue (LQ). -Loop &LPPassManager::addLoop(Loop *ParentLoop) { - // Create a new loop. LI will take ownership. - Loop *L = new Loop(); - - // Insert into the loop nest and the loop queue. - if (!ParentLoop) { +// Insert loop into loop nest (LoopInfo) and loop queue (LQ). 
+void LPPassManager::addLoop(Loop &L) { + if (!L.getParentLoop()) { // This is the top level loop. - LI->addTopLevelLoop(L); - LQ.push_front(L); - return *L; + LQ.push_front(&L); + return; } - ParentLoop->addChildLoop(L); // Insert L into the loop queue after the parent loop. for (auto I = LQ.begin(), E = LQ.end(); I != E; ++I) { - if (*I == L->getParentLoop()) { + if (*I == L.getParentLoop()) { // deque does not support insert after. ++I; - LQ.insert(I, 1, L); - break; + LQ.insert(I, 1, &L); + return; } } - return *L; } /// cloneBasicBlockSimpleAnalysis - Invoke cloneBasicBlockAnalysis hook for diff --git a/contrib/llvm/lib/Analysis/MemoryBuiltins.cpp b/contrib/llvm/lib/Analysis/MemoryBuiltins.cpp index 2d8274040d39..7983d62c2f7a 100644 --- a/contrib/llvm/lib/Analysis/MemoryBuiltins.cpp +++ b/contrib/llvm/lib/Analysis/MemoryBuiltins.cpp @@ -37,6 +37,7 @@ enum AllocType : uint8_t { CallocLike = 1<<2, // allocates + bzero ReallocLike = 1<<3, // reallocates StrDupLike = 1<<4, + MallocOrCallocLike = MallocLike | CallocLike, AllocLike = MallocLike | CallocLike | StrDupLike, AnyAlloc = AllocLike | ReallocLike }; @@ -50,35 +51,35 @@ struct AllocFnsTy { // FIXME: certain users need more information. E.g., SimplifyLibCalls needs to // know which functions are nounwind, noalias, nocapture parameters, etc. -static const std::pair<LibFunc::Func, AllocFnsTy> AllocationFnData[] = { - {LibFunc::malloc, {MallocLike, 1, 0, -1}}, - {LibFunc::valloc, {MallocLike, 1, 0, -1}}, - {LibFunc::Znwj, {OpNewLike, 1, 0, -1}}, // new(unsigned int) - {LibFunc::ZnwjRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow) - {LibFunc::Znwm, {OpNewLike, 1, 0, -1}}, // new(unsigned long) - {LibFunc::ZnwmRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned long, nothrow) - {LibFunc::Znaj, {OpNewLike, 1, 0, -1}}, // new[](unsigned int) - {LibFunc::ZnajRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow) - {LibFunc::Znam, {OpNewLike, 1, 0, -1}}, // new[](unsigned long) - {LibFunc::ZnamRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned long, nothrow) - {LibFunc::msvc_new_int, {OpNewLike, 1, 0, -1}}, // new(unsigned int) - {LibFunc::msvc_new_int_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow) - {LibFunc::msvc_new_longlong, {OpNewLike, 1, 0, -1}}, // new(unsigned long long) - {LibFunc::msvc_new_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned long long, nothrow) - {LibFunc::msvc_new_array_int, {OpNewLike, 1, 0, -1}}, // new[](unsigned int) - {LibFunc::msvc_new_array_int_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow) - {LibFunc::msvc_new_array_longlong, {OpNewLike, 1, 0, -1}}, // new[](unsigned long long) - {LibFunc::msvc_new_array_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned long long, nothrow) - {LibFunc::calloc, {CallocLike, 2, 0, 1}}, - {LibFunc::realloc, {ReallocLike, 2, 1, -1}}, - {LibFunc::reallocf, {ReallocLike, 2, 1, -1}}, - {LibFunc::strdup, {StrDupLike, 1, -1, -1}}, - {LibFunc::strndup, {StrDupLike, 2, 1, -1}} +static const std::pair<LibFunc, AllocFnsTy> AllocationFnData[] = { + {LibFunc_malloc, {MallocLike, 1, 0, -1}}, + {LibFunc_valloc, {MallocLike, 1, 0, -1}}, + {LibFunc_Znwj, {OpNewLike, 1, 0, -1}}, // new(unsigned int) + {LibFunc_ZnwjRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow) + {LibFunc_Znwm, {OpNewLike, 1, 0, -1}}, // new(unsigned long) + {LibFunc_ZnwmRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned long, nothrow) + {LibFunc_Znaj, {OpNewLike, 1, 0, -1}}, // new[](unsigned 
int) + {LibFunc_ZnajRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow) + {LibFunc_Znam, {OpNewLike, 1, 0, -1}}, // new[](unsigned long) + {LibFunc_ZnamRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned long, nothrow) + {LibFunc_msvc_new_int, {OpNewLike, 1, 0, -1}}, // new(unsigned int) + {LibFunc_msvc_new_int_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow) + {LibFunc_msvc_new_longlong, {OpNewLike, 1, 0, -1}}, // new(unsigned long long) + {LibFunc_msvc_new_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned long long, nothrow) + {LibFunc_msvc_new_array_int, {OpNewLike, 1, 0, -1}}, // new[](unsigned int) + {LibFunc_msvc_new_array_int_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow) + {LibFunc_msvc_new_array_longlong, {OpNewLike, 1, 0, -1}}, // new[](unsigned long long) + {LibFunc_msvc_new_array_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned long long, nothrow) + {LibFunc_calloc, {CallocLike, 2, 0, 1}}, + {LibFunc_realloc, {ReallocLike, 2, 1, -1}}, + {LibFunc_reallocf, {ReallocLike, 2, 1, -1}}, + {LibFunc_strdup, {StrDupLike, 1, -1, -1}}, + {LibFunc_strndup, {StrDupLike, 2, 1, -1}} // TODO: Handle "int posix_memalign(void **, size_t, size_t)" }; -static Function *getCalledFunction(const Value *V, bool LookThroughBitCast, - bool &IsNoBuiltin) { +static const Function *getCalledFunction(const Value *V, bool LookThroughBitCast, + bool &IsNoBuiltin) { // Don't care about intrinsics in this case. if (isa<IntrinsicInst>(V)) return nullptr; @@ -86,13 +87,13 @@ static Function *getCalledFunction(const Value *V, bool LookThroughBitCast, if (LookThroughBitCast) V = V->stripPointerCasts(); - CallSite CS(const_cast<Value*>(V)); + ImmutableCallSite CS(V); if (!CS.getInstruction()) return nullptr; IsNoBuiltin = CS.isNoBuiltin(); - Function *Callee = CS.getCalledFunction(); + const Function *Callee = CS.getCalledFunction(); if (!Callee || !Callee->isDeclaration()) return nullptr; return Callee; @@ -106,12 +107,12 @@ getAllocationDataForFunction(const Function *Callee, AllocType AllocTy, const TargetLibraryInfo *TLI) { // Make sure that the function is available. StringRef FnName = Callee->getName(); - LibFunc::Func TLIFn; + LibFunc TLIFn; if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn)) return None; const auto *Iter = find_if( - AllocationFnData, [TLIFn](const std::pair<LibFunc::Func, AllocFnsTy> &P) { + AllocationFnData, [TLIFn](const std::pair<LibFunc, AllocFnsTy> &P) { return P.first == TLIFn; }); @@ -183,7 +184,7 @@ static Optional<AllocFnsTy> getAllocationSize(const Value *V, static bool hasNoAliasAttr(const Value *V, bool LookThroughBitCast) { ImmutableCallSite CS(LookThroughBitCast ? V->stripPointerCasts() : V); - return CS && CS.paramHasAttr(AttributeSet::ReturnIndex, Attribute::NoAlias); + return CS && CS.hasRetAttr(Attribute::NoAlias); } @@ -220,6 +221,14 @@ bool llvm::isCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI, } /// \brief Tests if a value is a call or invoke to a library function that +/// allocates memory similiar to malloc or calloc. +bool llvm::isMallocOrCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI, + bool LookThroughBitCast) { + return getAllocationData(V, MallocOrCallocLike, TLI, + LookThroughBitCast).hasValue(); +} + +/// \brief Tests if a value is a call or invoke to a library function that /// allocates memory (either malloc, calloc, or strdup like). 
bool llvm::isAllocLikeFn(const Value *V, const TargetLibraryInfo *TLI, bool LookThroughBitCast) { @@ -333,33 +342,33 @@ const CallInst *llvm::isFreeCall(const Value *I, const TargetLibraryInfo *TLI) { return nullptr; StringRef FnName = Callee->getName(); - LibFunc::Func TLIFn; + LibFunc TLIFn; if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn)) return nullptr; unsigned ExpectedNumParams; - if (TLIFn == LibFunc::free || - TLIFn == LibFunc::ZdlPv || // operator delete(void*) - TLIFn == LibFunc::ZdaPv || // operator delete[](void*) - TLIFn == LibFunc::msvc_delete_ptr32 || // operator delete(void*) - TLIFn == LibFunc::msvc_delete_ptr64 || // operator delete(void*) - TLIFn == LibFunc::msvc_delete_array_ptr32 || // operator delete[](void*) - TLIFn == LibFunc::msvc_delete_array_ptr64) // operator delete[](void*) + if (TLIFn == LibFunc_free || + TLIFn == LibFunc_ZdlPv || // operator delete(void*) + TLIFn == LibFunc_ZdaPv || // operator delete[](void*) + TLIFn == LibFunc_msvc_delete_ptr32 || // operator delete(void*) + TLIFn == LibFunc_msvc_delete_ptr64 || // operator delete(void*) + TLIFn == LibFunc_msvc_delete_array_ptr32 || // operator delete[](void*) + TLIFn == LibFunc_msvc_delete_array_ptr64) // operator delete[](void*) ExpectedNumParams = 1; - else if (TLIFn == LibFunc::ZdlPvj || // delete(void*, uint) - TLIFn == LibFunc::ZdlPvm || // delete(void*, ulong) - TLIFn == LibFunc::ZdlPvRKSt9nothrow_t || // delete(void*, nothrow) - TLIFn == LibFunc::ZdaPvj || // delete[](void*, uint) - TLIFn == LibFunc::ZdaPvm || // delete[](void*, ulong) - TLIFn == LibFunc::ZdaPvRKSt9nothrow_t || // delete[](void*, nothrow) - TLIFn == LibFunc::msvc_delete_ptr32_int || // delete(void*, uint) - TLIFn == LibFunc::msvc_delete_ptr64_longlong || // delete(void*, ulonglong) - TLIFn == LibFunc::msvc_delete_ptr32_nothrow || // delete(void*, nothrow) - TLIFn == LibFunc::msvc_delete_ptr64_nothrow || // delete(void*, nothrow) - TLIFn == LibFunc::msvc_delete_array_ptr32_int || // delete[](void*, uint) - TLIFn == LibFunc::msvc_delete_array_ptr64_longlong || // delete[](void*, ulonglong) - TLIFn == LibFunc::msvc_delete_array_ptr32_nothrow || // delete[](void*, nothrow) - TLIFn == LibFunc::msvc_delete_array_ptr64_nothrow) // delete[](void*, nothrow) + else if (TLIFn == LibFunc_ZdlPvj || // delete(void*, uint) + TLIFn == LibFunc_ZdlPvm || // delete(void*, ulong) + TLIFn == LibFunc_ZdlPvRKSt9nothrow_t || // delete(void*, nothrow) + TLIFn == LibFunc_ZdaPvj || // delete[](void*, uint) + TLIFn == LibFunc_ZdaPvm || // delete[](void*, ulong) + TLIFn == LibFunc_ZdaPvRKSt9nothrow_t || // delete[](void*, nothrow) + TLIFn == LibFunc_msvc_delete_ptr32_int || // delete(void*, uint) + TLIFn == LibFunc_msvc_delete_ptr64_longlong || // delete(void*, ulonglong) + TLIFn == LibFunc_msvc_delete_ptr32_nothrow || // delete(void*, nothrow) + TLIFn == LibFunc_msvc_delete_ptr64_nothrow || // delete(void*, nothrow) + TLIFn == LibFunc_msvc_delete_array_ptr32_int || // delete[](void*, uint) + TLIFn == LibFunc_msvc_delete_array_ptr64_longlong || // delete[](void*, ulonglong) + TLIFn == LibFunc_msvc_delete_array_ptr32_nothrow || // delete[](void*, nothrow) + TLIFn == LibFunc_msvc_delete_array_ptr64_nothrow) // delete[](void*, nothrow) ExpectedNumParams = 2; else return nullptr; @@ -394,10 +403,8 @@ static APInt getSizeWithOverflow(const SizeOffsetType &Data) { /// If RoundToAlign is true, then Size is rounded up to the aligment of allocas, /// byval arguments, and global variables. 
bool llvm::getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, - const TargetLibraryInfo *TLI, bool RoundToAlign, - llvm::ObjSizeMode Mode) { - ObjectSizeOffsetVisitor Visitor(DL, TLI, Ptr->getContext(), - RoundToAlign, Mode); + const TargetLibraryInfo *TLI, ObjectSizeOpts Opts) { + ObjectSizeOffsetVisitor Visitor(DL, TLI, Ptr->getContext(), Opts); SizeOffsetType Data = Visitor.compute(const_cast<Value*>(Ptr)); if (!Visitor.bothKnown(Data)) return false; @@ -414,19 +421,23 @@ ConstantInt *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize, "ObjectSize must be a call to llvm.objectsize!"); bool MaxVal = cast<ConstantInt>(ObjectSize->getArgOperand(1))->isZero(); - ObjSizeMode Mode; + ObjectSizeOpts EvalOptions; // Unless we have to fold this to something, try to be as accurate as // possible. if (MustSucceed) - Mode = MaxVal ? ObjSizeMode::Max : ObjSizeMode::Min; + EvalOptions.EvalMode = + MaxVal ? ObjectSizeOpts::Mode::Max : ObjectSizeOpts::Mode::Min; else - Mode = ObjSizeMode::Exact; + EvalOptions.EvalMode = ObjectSizeOpts::Mode::Exact; + + EvalOptions.NullIsUnknownSize = + cast<ConstantInt>(ObjectSize->getArgOperand(2))->isOne(); // FIXME: Does it make sense to just return a failure value if the size won't // fit in the output and `!MustSucceed`? uint64_t Size; auto *ResultType = cast<IntegerType>(ObjectSize->getType()); - if (getObjectSize(ObjectSize->getArgOperand(0), Size, DL, TLI, false, Mode) && + if (getObjectSize(ObjectSize->getArgOperand(0), Size, DL, TLI, EvalOptions) && isUIntN(ResultType->getBitWidth(), Size)) return ConstantInt::get(ResultType, Size); @@ -443,7 +454,7 @@ STATISTIC(ObjectVisitorLoad, APInt ObjectSizeOffsetVisitor::align(APInt Size, uint64_t Align) { - if (RoundToAlign && Align) + if (Options.RoundToAlign && Align) return APInt(IntTyBits, alignTo(Size.getZExtValue(), Align)); return Size; } @@ -451,9 +462,8 @@ APInt ObjectSizeOffsetVisitor::align(APInt Size, uint64_t Align) { ObjectSizeOffsetVisitor::ObjectSizeOffsetVisitor(const DataLayout &DL, const TargetLibraryInfo *TLI, LLVMContext &Context, - bool RoundToAlign, - ObjSizeMode Mode) - : DL(DL), TLI(TLI), RoundToAlign(RoundToAlign), Mode(Mode) { + ObjectSizeOpts Options) + : DL(DL), TLI(TLI), Options(Options) { // Pointer size must be rechecked for each object visited since it could have // a different address space. 
} @@ -596,7 +606,9 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitCallSite(CallSite CS) { } SizeOffsetType -ObjectSizeOffsetVisitor::visitConstantPointerNull(ConstantPointerNull&) { +ObjectSizeOffsetVisitor::visitConstantPointerNull(ConstantPointerNull& CPN) { + if (Options.NullIsUnknownSize && CPN.getType()->getAddressSpace() == 0) + return unknown(); return std::make_pair(Zero, Zero); } @@ -663,12 +675,12 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitSelectInst(SelectInst &I) { if (TrueResult == FalseResult) { return TrueSide; } - if (Mode == ObjSizeMode::Min) { + if (Options.EvalMode == ObjectSizeOpts::Mode::Min) { if (TrueResult.slt(FalseResult)) return TrueSide; return FalseSide; } - if (Mode == ObjSizeMode::Max) { + if (Options.EvalMode == ObjectSizeOpts::Mode::Max) { if (TrueResult.sgt(FalseResult)) return TrueSide; return FalseSide; @@ -719,7 +731,10 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute(Value *V) { } SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute_(Value *V) { - ObjectSizeOffsetVisitor Visitor(DL, TLI, Context, RoundToAlign); + ObjectSizeOpts ObjSizeOptions; + ObjSizeOptions.RoundToAlign = RoundToAlign; + + ObjectSizeOffsetVisitor Visitor(DL, TLI, Context, ObjSizeOptions); SizeOffsetType Const = Visitor.compute(V); if (Visitor.bothKnown(Const)) return std::make_pair(ConstantInt::get(Context, Const.first), diff --git a/contrib/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp b/contrib/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp index 66a0d145dcd8..188885063b39 100644 --- a/contrib/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp +++ b/contrib/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -691,6 +691,7 @@ MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom( // load query, we can safely ignore it (scan past it). if (isLoad) continue; + LLVM_FALLTHROUGH; default: // Otherwise, there is a potential dependence. Return a clobber. return MemDepResult::getClobber(Inst); diff --git a/contrib/llvm/lib/Analysis/MemoryLocation.cpp b/contrib/llvm/lib/Analysis/MemoryLocation.cpp index a0ae72f1415f..9db6c499129a 100644 --- a/contrib/llvm/lib/Analysis/MemoryLocation.cpp +++ b/contrib/llvm/lib/Analysis/MemoryLocation.cpp @@ -142,9 +142,9 @@ MemoryLocation MemoryLocation::getForArgument(ImmutableCallSite CS, // for memcpy/memset. This is particularly important because the // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16 // whenever possible. - LibFunc::Func F; + LibFunc F; if (CS.getCalledFunction() && TLI.getLibFunc(*CS.getCalledFunction(), F) && - F == LibFunc::memset_pattern16 && TLI.has(F)) { + F == LibFunc_memset_pattern16 && TLI.has(F)) { assert((ArgIdx == 0 || ArgIdx == 1) && "Invalid argument index for memset_pattern16"); if (ArgIdx == 1) diff --git a/contrib/llvm/lib/Analysis/MemorySSA.cpp b/contrib/llvm/lib/Analysis/MemorySSA.cpp new file mode 100644 index 000000000000..e0e04a91410f --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemorySSA.cpp @@ -0,0 +1,2075 @@ +//===-- MemorySSA.cpp - Memory SSA Builder---------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------===// +// +// This file implements the MemorySSA class. 
+// +//===----------------------------------------------------------------===// +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/IR/AssemblyAnnotationWriter.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormattedStream.h" +#include "llvm/Transforms/Scalar.h" +#include <algorithm> + +#define DEBUG_TYPE "memoryssa" +using namespace llvm; +INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, + true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, + true) + +INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa", + "Memory SSA Printer", false, false) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa", + "Memory SSA Printer", false, false) + +static cl::opt<unsigned> MaxCheckLimit( + "memssa-check-limit", cl::Hidden, cl::init(100), + cl::desc("The maximum number of stores/phis MemorySSA" + "will consider trying to walk past (default = 100)")); + +static cl::opt<bool> + VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden, + cl::desc("Verify MemorySSA in legacy printer pass.")); + +namespace llvm { +/// \brief An assembly annotator class to print Memory SSA information in +/// comments. +class MemorySSAAnnotatedWriter : public AssemblyAnnotationWriter { + friend class MemorySSA; + const MemorySSA *MSSA; + +public: + MemorySSAAnnotatedWriter(const MemorySSA *M) : MSSA(M) {} + + virtual void emitBasicBlockStartAnnot(const BasicBlock *BB, + formatted_raw_ostream &OS) { + if (MemoryAccess *MA = MSSA->getMemoryAccess(BB)) + OS << "; " << *MA << "\n"; + } + + virtual void emitInstructionAnnot(const Instruction *I, + formatted_raw_ostream &OS) { + if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) + OS << "; " << *MA << "\n"; + } +}; +} + +namespace { +/// Our current alias analysis API differentiates heavily between calls and +/// non-calls, and functions called on one usually assert on the other. +/// This class encapsulates the distinction to simplify other code that wants +/// "Memory affecting instructions and related data" to use as a key. +/// For example, this class is used as a densemap key in the use optimizer. 
+class MemoryLocOrCall { +public: + MemoryLocOrCall() : IsCall(false) {} + MemoryLocOrCall(MemoryUseOrDef *MUD) + : MemoryLocOrCall(MUD->getMemoryInst()) {} + MemoryLocOrCall(const MemoryUseOrDef *MUD) + : MemoryLocOrCall(MUD->getMemoryInst()) {} + + MemoryLocOrCall(Instruction *Inst) { + if (ImmutableCallSite(Inst)) { + IsCall = true; + CS = ImmutableCallSite(Inst); + } else { + IsCall = false; + // There is no such thing as a memorylocation for a fence inst, and it is + // unique in that regard. + if (!isa<FenceInst>(Inst)) + Loc = MemoryLocation::get(Inst); + } + } + + explicit MemoryLocOrCall(const MemoryLocation &Loc) + : IsCall(false), Loc(Loc) {} + + bool IsCall; + ImmutableCallSite getCS() const { + assert(IsCall); + return CS; + } + MemoryLocation getLoc() const { + assert(!IsCall); + return Loc; + } + + bool operator==(const MemoryLocOrCall &Other) const { + if (IsCall != Other.IsCall) + return false; + + if (IsCall) + return CS.getCalledValue() == Other.CS.getCalledValue(); + return Loc == Other.Loc; + } + +private: + union { + ImmutableCallSite CS; + MemoryLocation Loc; + }; +}; +} + +namespace llvm { +template <> struct DenseMapInfo<MemoryLocOrCall> { + static inline MemoryLocOrCall getEmptyKey() { + return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey()); + } + static inline MemoryLocOrCall getTombstoneKey() { + return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey()); + } + static unsigned getHashValue(const MemoryLocOrCall &MLOC) { + if (MLOC.IsCall) + return hash_combine(MLOC.IsCall, + DenseMapInfo<const Value *>::getHashValue( + MLOC.getCS().getCalledValue())); + return hash_combine( + MLOC.IsCall, DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc())); + } + static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) { + return LHS == RHS; + } +}; + +enum class Reorderability { Always, IfNoAlias, Never }; + +/// This does one-way checks to see if Use could theoretically be hoisted above +/// MayClobber. This will not check the other way around. +/// +/// This assumes that, for the purposes of MemorySSA, Use comes directly after +/// MayClobber, with no potentially clobbering operations in between them. +/// (Where potentially clobbering ops are memory barriers, aliased stores, etc.) +static Reorderability getLoadReorderability(const LoadInst *Use, + const LoadInst *MayClobber) { + bool VolatileUse = Use->isVolatile(); + bool VolatileClobber = MayClobber->isVolatile(); + // Volatile operations may never be reordered with other volatile operations. + if (VolatileUse && VolatileClobber) + return Reorderability::Never; + + // The lang ref allows reordering of volatile and non-volatile operations. + // Whether an aliasing nonvolatile load and volatile load can be reordered, + // though, is ambiguous. Because it may not be best to exploit this ambiguity, + // we only allow volatile/non-volatile reordering if the volatile and + // non-volatile operations don't alias. + Reorderability Result = VolatileUse || VolatileClobber + ? Reorderability::IfNoAlias + : Reorderability::Always; + + // If a load is seq_cst, it cannot be moved above other loads. If its ordering + // is weaker, it can be moved above other loads. We just need to be sure that + // MayClobber isn't an acquire load, because loads can't be moved above + // acquire loads. + // + // Note that this explicitly *does* allow the free reordering of monotonic (or + // weaker) loads of the same address. 
+ bool SeqCstUse = Use->getOrdering() == AtomicOrdering::SequentiallyConsistent; + bool MayClobberIsAcquire = isAtLeastOrStrongerThan(MayClobber->getOrdering(), + AtomicOrdering::Acquire); + if (SeqCstUse || MayClobberIsAcquire) + return Reorderability::Never; + return Result; +} + +static bool instructionClobbersQuery(MemoryDef *MD, + const MemoryLocation &UseLoc, + const Instruction *UseInst, + AliasAnalysis &AA) { + Instruction *DefInst = MD->getMemoryInst(); + assert(DefInst && "Defining instruction not actually an instruction"); + ImmutableCallSite UseCS(UseInst); + + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) { + // These intrinsics will show up as affecting memory, but they are just + // markers. + switch (II->getIntrinsicID()) { + case Intrinsic::lifetime_start: + if (UseCS) + return false; + return AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), UseLoc); + case Intrinsic::lifetime_end: + case Intrinsic::invariant_start: + case Intrinsic::invariant_end: + case Intrinsic::assume: + return false; + default: + break; + } + } + + if (UseCS) { + ModRefInfo I = AA.getModRefInfo(DefInst, UseCS); + return I != MRI_NoModRef; + } + + if (auto *DefLoad = dyn_cast<LoadInst>(DefInst)) { + if (auto *UseLoad = dyn_cast<LoadInst>(UseInst)) { + switch (getLoadReorderability(UseLoad, DefLoad)) { + case Reorderability::Always: + return false; + case Reorderability::Never: + return true; + case Reorderability::IfNoAlias: + return !AA.isNoAlias(UseLoc, MemoryLocation::get(DefLoad)); + } + } + } + + return AA.getModRefInfo(DefInst, UseLoc) & MRI_Mod; +} + +static bool instructionClobbersQuery(MemoryDef *MD, const MemoryUseOrDef *MU, + const MemoryLocOrCall &UseMLOC, + AliasAnalysis &AA) { + // FIXME: This is a temporary hack to allow a single instructionClobbersQuery + // to exist while MemoryLocOrCall is pushed through places. + if (UseMLOC.IsCall) + return instructionClobbersQuery(MD, MemoryLocation(), MU->getMemoryInst(), + AA); + return instructionClobbersQuery(MD, UseMLOC.getLoc(), MU->getMemoryInst(), + AA); +} + +// Return true when MD may alias MU, return false otherwise. +bool MemorySSAUtil::defClobbersUseOrDef(MemoryDef *MD, const MemoryUseOrDef *MU, + AliasAnalysis &AA) { + return instructionClobbersQuery(MD, MU, MemoryLocOrCall(MU), AA); +} +} + +namespace { +struct UpwardsMemoryQuery { + // True if our original query started off as a call + bool IsCall; + // The pointer location we started the query with. This will be empty if + // IsCall is true. + MemoryLocation StartingLoc; + // This is the instruction we were querying about. 
+ const Instruction *Inst; + // The MemoryAccess we actually got called with, used to test local domination + const MemoryAccess *OriginalAccess; + + UpwardsMemoryQuery() + : IsCall(false), Inst(nullptr), OriginalAccess(nullptr) {} + + UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access) + : IsCall(ImmutableCallSite(Inst)), Inst(Inst), OriginalAccess(Access) { + if (!IsCall) + StartingLoc = MemoryLocation::get(Inst); + } +}; + +static bool lifetimeEndsAt(MemoryDef *MD, const MemoryLocation &Loc, + AliasAnalysis &AA) { + Instruction *Inst = MD->getMemoryInst(); + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + case Intrinsic::lifetime_end: + return AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), Loc); + default: + return false; + } + } + return false; +} + +static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysis &AA, + const Instruction *I) { + // If the memory can't be changed, then loads of the memory can't be + // clobbered. + // + // FIXME: We should handle invariant groups, as well. It's a bit harder, + // because we need to pay close attention to invariant group barriers. + return isa<LoadInst>(I) && (I->getMetadata(LLVMContext::MD_invariant_load) || + AA.pointsToConstantMemory(cast<LoadInst>(I)-> + getPointerOperand())); +} + +/// Verifies that `Start` is clobbered by `ClobberAt`, and that nothing +/// inbetween `Start` and `ClobberAt` can clobbers `Start`. +/// +/// This is meant to be as simple and self-contained as possible. Because it +/// uses no cache, etc., it can be relatively expensive. +/// +/// \param Start The MemoryAccess that we want to walk from. +/// \param ClobberAt A clobber for Start. +/// \param StartLoc The MemoryLocation for Start. +/// \param MSSA The MemorySSA isntance that Start and ClobberAt belong to. +/// \param Query The UpwardsMemoryQuery we used for our search. +/// \param AA The AliasAnalysis we used for our search. +static void LLVM_ATTRIBUTE_UNUSED +checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt, + const MemoryLocation &StartLoc, const MemorySSA &MSSA, + const UpwardsMemoryQuery &Query, AliasAnalysis &AA) { + assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?"); + + if (MSSA.isLiveOnEntryDef(Start)) { + assert(MSSA.isLiveOnEntryDef(ClobberAt) && + "liveOnEntry must clobber itself"); + return; + } + + bool FoundClobber = false; + DenseSet<MemoryAccessPair> VisitedPhis; + SmallVector<MemoryAccessPair, 8> Worklist; + Worklist.emplace_back(Start, StartLoc); + // Walk all paths from Start to ClobberAt, while looking for clobbers. If one + // is found, complain. + while (!Worklist.empty()) { + MemoryAccessPair MAP = Worklist.pop_back_val(); + // All we care about is that nothing from Start to ClobberAt clobbers Start. + // We learn nothing from revisiting nodes. + if (!VisitedPhis.insert(MAP).second) + continue; + + for (MemoryAccess *MA : def_chain(MAP.first)) { + if (MA == ClobberAt) { + if (auto *MD = dyn_cast<MemoryDef>(MA)) { + // instructionClobbersQuery isn't essentially free, so don't use `|=`, + // since it won't let us short-circuit. + // + // Also, note that this can't be hoisted out of the `Worklist` loop, + // since MD may only act as a clobber for 1 of N MemoryLocations. + FoundClobber = + FoundClobber || MSSA.isLiveOnEntryDef(MD) || + instructionClobbersQuery(MD, MAP.second, Query.Inst, AA); + } + break; + } + + // We should never hit liveOnEntry, unless it's the clobber. 
+ assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?"); + + if (auto *MD = dyn_cast<MemoryDef>(MA)) { + (void)MD; + assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) && + "Found clobber before reaching ClobberAt!"); + continue; + } + + assert(isa<MemoryPhi>(MA)); + Worklist.append(upward_defs_begin({MA, MAP.second}), upward_defs_end()); + } + } + + // If ClobberAt is a MemoryPhi, we can assume something above it acted as a + // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point. + assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) && + "ClobberAt never acted as a clobber"); +} + +/// Our algorithm for walking (and trying to optimize) clobbers, all wrapped up +/// in one class. +class ClobberWalker { + /// Save a few bytes by using unsigned instead of size_t. + using ListIndex = unsigned; + + /// Represents a span of contiguous MemoryDefs, potentially ending in a + /// MemoryPhi. + struct DefPath { + MemoryLocation Loc; + // Note that, because we always walk in reverse, Last will always dominate + // First. Also note that First and Last are inclusive. + MemoryAccess *First; + MemoryAccess *Last; + Optional<ListIndex> Previous; + + DefPath(const MemoryLocation &Loc, MemoryAccess *First, MemoryAccess *Last, + Optional<ListIndex> Previous) + : Loc(Loc), First(First), Last(Last), Previous(Previous) {} + + DefPath(const MemoryLocation &Loc, MemoryAccess *Init, + Optional<ListIndex> Previous) + : DefPath(Loc, Init, Init, Previous) {} + }; + + const MemorySSA &MSSA; + AliasAnalysis &AA; + DominatorTree &DT; + UpwardsMemoryQuery *Query; + + // Phi optimization bookkeeping + SmallVector<DefPath, 32> Paths; + DenseSet<ConstMemoryAccessPair> VisitedPhis; + + /// Find the nearest def or phi that `From` can legally be optimized to. + const MemoryAccess *getWalkTarget(const MemoryPhi *From) const { + assert(From->getNumOperands() && "Phi with no operands?"); + + BasicBlock *BB = From->getBlock(); + MemoryAccess *Result = MSSA.getLiveOnEntryDef(); + DomTreeNode *Node = DT.getNode(BB); + while ((Node = Node->getIDom())) { + auto *Defs = MSSA.getBlockDefs(Node->getBlock()); + if (Defs) + return &*Defs->rbegin(); + } + return Result; + } + + /// Result of calling walkToPhiOrClobber. + struct UpwardsWalkResult { + /// The "Result" of the walk. Either a clobber, the last thing we walked, or + /// both. + MemoryAccess *Result; + bool IsKnownClobber; + }; + + /// Walk to the next Phi or Clobber in the def chain starting at Desc.Last. + /// This will update Desc.Last as it walks. It will (optionally) also stop at + /// StopAt. 
+ /// + /// This does not test for whether StopAt is a clobber + UpwardsWalkResult + walkToPhiOrClobber(DefPath &Desc, + const MemoryAccess *StopAt = nullptr) const { + assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world"); + + for (MemoryAccess *Current : def_chain(Desc.Last)) { + Desc.Last = Current; + if (Current == StopAt) + return {Current, false}; + + if (auto *MD = dyn_cast<MemoryDef>(Current)) + if (MSSA.isLiveOnEntryDef(MD) || + instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA)) + return {MD, true}; + } + + assert(isa<MemoryPhi>(Desc.Last) && + "Ended at a non-clobber that's not a phi?"); + return {Desc.Last, false}; + } + + void addSearches(MemoryPhi *Phi, SmallVectorImpl<ListIndex> &PausedSearches, + ListIndex PriorNode) { + auto UpwardDefs = make_range(upward_defs_begin({Phi, Paths[PriorNode].Loc}), + upward_defs_end()); + for (const MemoryAccessPair &P : UpwardDefs) { + PausedSearches.push_back(Paths.size()); + Paths.emplace_back(P.second, P.first, PriorNode); + } + } + + /// Represents a search that terminated after finding a clobber. This clobber + /// may or may not be present in the path of defs from LastNode..SearchStart, + /// since it may have been retrieved from cache. + struct TerminatedPath { + MemoryAccess *Clobber; + ListIndex LastNode; + }; + + /// Get an access that keeps us from optimizing to the given phi. + /// + /// PausedSearches is an array of indices into the Paths array. Its incoming + /// value is the indices of searches that stopped at the last phi optimization + /// target. It's left in an unspecified state. + /// + /// If this returns None, NewPaused is a vector of searches that terminated + /// at StopWhere. Otherwise, NewPaused is left in an unspecified state. + Optional<TerminatedPath> + getBlockingAccess(const MemoryAccess *StopWhere, + SmallVectorImpl<ListIndex> &PausedSearches, + SmallVectorImpl<ListIndex> &NewPaused, + SmallVectorImpl<TerminatedPath> &Terminated) { + assert(!PausedSearches.empty() && "No searches to continue?"); + + // BFS vs DFS really doesn't make a difference here, so just do a DFS with + // PausedSearches as our stack. + while (!PausedSearches.empty()) { + ListIndex PathIndex = PausedSearches.pop_back_val(); + DefPath &Node = Paths[PathIndex]; + + // If we've already visited this path with this MemoryLocation, we don't + // need to do so again. + // + // NOTE: That we just drop these paths on the ground makes caching + // behavior sporadic. e.g. given a diamond: + // A + // B C + // D + // + // ...If we walk D, B, A, C, we'll only cache the result of phi + // optimization for A, B, and D; C will be skipped because it dies here. + // This arguably isn't the worst thing ever, since: + // - We generally query things in a top-down order, so if we got below D + // without needing cache entries for {C, MemLoc}, then chances are + // that those cache entries would end up ultimately unused. + // - We still cache things for A, so C only needs to walk up a bit. + // If this behavior becomes problematic, we can fix without a ton of extra + // work. + if (!VisitedPhis.insert({Node.Last, Node.Loc}).second) + continue; + + UpwardsWalkResult Res = walkToPhiOrClobber(Node, /*StopAt=*/StopWhere); + if (Res.IsKnownClobber) { + assert(Res.Result != StopWhere); + // If this wasn't a cache hit, we hit a clobber when walking. That's a + // failure. + TerminatedPath Term{Res.Result, PathIndex}; + if (!MSSA.dominates(Res.Result, StopWhere)) + return Term; + + // Otherwise, it's a valid thing to potentially optimize to. 
+ Terminated.push_back(Term); + continue; + } + + if (Res.Result == StopWhere) { + // We've hit our target. Save this path off for if we want to continue + // walking. + NewPaused.push_back(PathIndex); + continue; + } + + assert(!MSSA.isLiveOnEntryDef(Res.Result) && "liveOnEntry is a clobber"); + addSearches(cast<MemoryPhi>(Res.Result), PausedSearches, PathIndex); + } + + return None; + } + + template <typename T, typename Walker> + struct generic_def_path_iterator + : public iterator_facade_base<generic_def_path_iterator<T, Walker>, + std::forward_iterator_tag, T *> { + generic_def_path_iterator() : W(nullptr), N(None) {} + generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {} + + T &operator*() const { return curNode(); } + + generic_def_path_iterator &operator++() { + N = curNode().Previous; + return *this; + } + + bool operator==(const generic_def_path_iterator &O) const { + if (N.hasValue() != O.N.hasValue()) + return false; + return !N.hasValue() || *N == *O.N; + } + + private: + T &curNode() const { return W->Paths[*N]; } + + Walker *W; + Optional<ListIndex> N; + }; + + using def_path_iterator = generic_def_path_iterator<DefPath, ClobberWalker>; + using const_def_path_iterator = + generic_def_path_iterator<const DefPath, const ClobberWalker>; + + iterator_range<def_path_iterator> def_path(ListIndex From) { + return make_range(def_path_iterator(this, From), def_path_iterator()); + } + + iterator_range<const_def_path_iterator> const_def_path(ListIndex From) const { + return make_range(const_def_path_iterator(this, From), + const_def_path_iterator()); + } + + struct OptznResult { + /// The path that contains our result. + TerminatedPath PrimaryClobber; + /// The paths that we can legally cache back from, but that aren't + /// necessarily the result of the Phi optimization. + SmallVector<TerminatedPath, 4> OtherClobbers; + }; + + ListIndex defPathIndex(const DefPath &N) const { + // The assert looks nicer if we don't need to do &N + const DefPath *NP = &N; + assert(!Paths.empty() && NP >= &Paths.front() && NP <= &Paths.back() && + "Out of bounds DefPath!"); + return NP - &Paths.front(); + } + + /// Try to optimize a phi as best as we can. Returns a SmallVector of Paths + /// that act as legal clobbers. Note that this won't return *all* clobbers. + /// + /// Phi optimization algorithm tl;dr: + /// - Find the earliest def/phi, A, we can optimize to + /// - Find if all paths from the starting memory access ultimately reach A + /// - If not, optimization isn't possible. + /// - Otherwise, walk from A to another clobber or phi, A'. + /// - If A' is a def, we're done. + /// - If A' is a phi, try to optimize it. + /// + /// A path is a series of {MemoryAccess, MemoryLocation} pairs. A path + /// terminates when a MemoryAccess that clobbers said MemoryLocation is found. + OptznResult tryOptimizePhi(MemoryPhi *Phi, MemoryAccess *Start, + const MemoryLocation &Loc) { + assert(Paths.empty() && VisitedPhis.empty() && + "Reset the optimization state."); + + Paths.emplace_back(Loc, Start, Phi, None); + // Stores how many "valid" optimization nodes we had prior to calling + // addSearches/getBlockingAccess. Necessary for caching if we had a blocker. + auto PriorPathsSize = Paths.size(); + + SmallVector<ListIndex, 16> PausedSearches; + SmallVector<ListIndex, 8> NewPaused; + SmallVector<TerminatedPath, 4> TerminatedPaths; + + addSearches(Phi, PausedSearches, 0); + + // Moves the TerminatedPath with the "most dominated" Clobber to the end of + // Paths. 
+ auto MoveDominatedPathToEnd = [&](SmallVectorImpl<TerminatedPath> &Paths) { + assert(!Paths.empty() && "Need a path to move"); + auto Dom = Paths.begin(); + for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I) + if (!MSSA.dominates(I->Clobber, Dom->Clobber)) + Dom = I; + auto Last = Paths.end() - 1; + if (Last != Dom) + std::iter_swap(Last, Dom); + }; + + MemoryPhi *Current = Phi; + while (1) { + assert(!MSSA.isLiveOnEntryDef(Current) && + "liveOnEntry wasn't treated as a clobber?"); + + const auto *Target = getWalkTarget(Current); + // If a TerminatedPath doesn't dominate Target, then it wasn't a legal + // optimization for the prior phi. + assert(all_of(TerminatedPaths, [&](const TerminatedPath &P) { + return MSSA.dominates(P.Clobber, Target); + })); + + // FIXME: This is broken, because the Blocker may be reported to be + // liveOnEntry, and we'll happily wait for that to disappear (read: never) + // For the moment, this is fine, since we do nothing with blocker info. + if (Optional<TerminatedPath> Blocker = getBlockingAccess( + Target, PausedSearches, NewPaused, TerminatedPaths)) { + + // Find the node we started at. We can't search based on N->Last, since + // we may have gone around a loop with a different MemoryLocation. + auto Iter = find_if(def_path(Blocker->LastNode), [&](const DefPath &N) { + return defPathIndex(N) < PriorPathsSize; + }); + assert(Iter != def_path_iterator()); + + DefPath &CurNode = *Iter; + assert(CurNode.Last == Current); + + // Two things: + // A. We can't reliably cache all of NewPaused back. Consider a case + // where we have two paths in NewPaused; one of which can't optimize + // above this phi, whereas the other can. If we cache the second path + // back, we'll end up with suboptimal cache entries. We can handle + // cases like this a bit better when we either try to find all + // clobbers that block phi optimization, or when our cache starts + // supporting unfinished searches. + // B. We can't reliably cache TerminatedPaths back here without doing + // extra checks; consider a case like: + // T + // / \ + // D C + // \ / + // S + // Where T is our target, C is a node with a clobber on it, D is a + // diamond (with a clobber *only* on the left or right node, N), and + // S is our start. Say we walk to D, through the node opposite N + // (read: ignoring the clobber), and see a cache entry in the top + // node of D. That cache entry gets put into TerminatedPaths. We then + // walk up to C (N is later in our worklist), find the clobber, and + // quit. If we append TerminatedPaths to OtherClobbers, we'll cache + // the bottom part of D to the cached clobber, ignoring the clobber + // in N. Again, this problem goes away if we start tracking all + // blockers for a given phi optimization. + TerminatedPath Result{CurNode.Last, defPathIndex(CurNode)}; + return {Result, {}}; + } + + // If there's nothing left to search, then all paths led to valid clobbers + // that we got from our cache; pick the nearest to the start, and allow + // the rest to be cached back. + if (NewPaused.empty()) { + MoveDominatedPathToEnd(TerminatedPaths); + TerminatedPath Result = TerminatedPaths.pop_back_val(); + return {Result, std::move(TerminatedPaths)}; + } + + MemoryAccess *DefChainEnd = nullptr; + SmallVector<TerminatedPath, 4> Clobbers; + for (ListIndex Paused : NewPaused) { + UpwardsWalkResult WR = walkToPhiOrClobber(Paths[Paused]); + if (WR.IsKnownClobber) + Clobbers.push_back({WR.Result, Paused}); + else + // Micro-opt: If we hit the end of the chain, save it. 
+ DefChainEnd = WR.Result; + } + + if (!TerminatedPaths.empty()) { + // If we couldn't find the dominating phi/liveOnEntry in the above loop, + // do it now. + if (!DefChainEnd) + for (auto *MA : def_chain(const_cast<MemoryAccess *>(Target))) + DefChainEnd = MA; + + // If any of the terminated paths don't dominate the phi we'll try to + // optimize, we need to figure out what they are and quit. + const BasicBlock *ChainBB = DefChainEnd->getBlock(); + for (const TerminatedPath &TP : TerminatedPaths) { + // Because we know that DefChainEnd is as "high" as we can go, we + // don't need local dominance checks; BB dominance is sufficient. + if (DT.dominates(ChainBB, TP.Clobber->getBlock())) + Clobbers.push_back(TP); + } + } + + // If we have clobbers in the def chain, find the one closest to Current + // and quit. + if (!Clobbers.empty()) { + MoveDominatedPathToEnd(Clobbers); + TerminatedPath Result = Clobbers.pop_back_val(); + return {Result, std::move(Clobbers)}; + } + + assert(all_of(NewPaused, + [&](ListIndex I) { return Paths[I].Last == DefChainEnd; })); + + // Because liveOnEntry is a clobber, this must be a phi. + auto *DefChainPhi = cast<MemoryPhi>(DefChainEnd); + + PriorPathsSize = Paths.size(); + PausedSearches.clear(); + for (ListIndex I : NewPaused) + addSearches(DefChainPhi, PausedSearches, I); + NewPaused.clear(); + + Current = DefChainPhi; + } + } + + void verifyOptResult(const OptznResult &R) const { + assert(all_of(R.OtherClobbers, [&](const TerminatedPath &P) { + return MSSA.dominates(P.Clobber, R.PrimaryClobber.Clobber); + })); + } + + void resetPhiOptznState() { + Paths.clear(); + VisitedPhis.clear(); + } + +public: + ClobberWalker(const MemorySSA &MSSA, AliasAnalysis &AA, DominatorTree &DT) + : MSSA(MSSA), AA(AA), DT(DT) {} + + void reset() {} + + /// Finds the nearest clobber for the given query, optimizing phis if + /// possible. + MemoryAccess *findClobber(MemoryAccess *Start, UpwardsMemoryQuery &Q) { + Query = &Q; + + MemoryAccess *Current = Start; + // This walker pretends uses don't exist. If we're handed one, silently grab + // its def. (This has the nice side-effect of ensuring we never cache uses) + if (auto *MU = dyn_cast<MemoryUse>(Start)) + Current = MU->getDefiningAccess(); + + DefPath FirstDesc(Q.StartingLoc, Current, Current, None); + // Fast path for the overly-common case (no crazy phi optimization + // necessary) + UpwardsWalkResult WalkResult = walkToPhiOrClobber(FirstDesc); + MemoryAccess *Result; + if (WalkResult.IsKnownClobber) { + Result = WalkResult.Result; + } else { + OptznResult OptRes = tryOptimizePhi(cast<MemoryPhi>(FirstDesc.Last), + Current, Q.StartingLoc); + verifyOptResult(OptRes); + resetPhiOptznState(); + Result = OptRes.PrimaryClobber.Clobber; + } + +#ifdef EXPENSIVE_CHECKS + checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA); +#endif + return Result; + } + + void verify(const MemorySSA *MSSA) { assert(MSSA == &this->MSSA); } +}; + +struct RenamePassData { + DomTreeNode *DTN; + DomTreeNode::const_iterator ChildIt; + MemoryAccess *IncomingVal; + + RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It, + MemoryAccess *M) + : DTN(D), ChildIt(It), IncomingVal(M) {} + void swap(RenamePassData &RHS) { + std::swap(DTN, RHS.DTN); + std::swap(ChildIt, RHS.ChildIt); + std::swap(IncomingVal, RHS.IncomingVal); + } +}; +} // anonymous namespace + +namespace llvm { +/// \brief A MemorySSAWalker that does AA walks to disambiguate accesses. It no +/// longer does caching on its own, +/// but the name has been retained for the moment. 
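+///
+/// Illustrative usage sketch (assumes an existing MemorySSA &MSSA and an
+/// Instruction &I that touches memory; clients normally reach this walker
+/// through MemorySSA::getWalker() rather than constructing it directly):
+/// \code
+///   MemorySSAWalker *W = MSSA.getWalker();
+///   if (MemoryUseOrDef *MA = MSSA.getMemoryAccess(&I)) {
+///     // Nearest dominating access that may clobber I (or liveOnEntry).
+///     MemoryAccess *Clobber = W->getClobberingMemoryAccess(MA);
+///     (void)Clobber;
+///   }
+/// \endcode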
+class MemorySSA::CachingWalker final : public MemorySSAWalker { + ClobberWalker Walker; + bool AutoResetWalker; + + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &); + void verifyRemoved(MemoryAccess *); + +public: + CachingWalker(MemorySSA *, AliasAnalysis *, DominatorTree *); + ~CachingWalker() override; + + using MemorySSAWalker::getClobberingMemoryAccess; + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *) override; + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, + const MemoryLocation &) override; + void invalidateInfo(MemoryAccess *) override; + + /// Whether we call resetClobberWalker() after each time we *actually* walk to + /// answer a clobber query. + void setAutoResetWalker(bool AutoReset) { AutoResetWalker = AutoReset; } + + /// Drop the walker's persistent data structures. + void resetClobberWalker() { Walker.reset(); } + + void verify(const MemorySSA *MSSA) override { + MemorySSAWalker::verify(MSSA); + Walker.verify(MSSA); + } +}; + +void MemorySSA::renameSuccessorPhis(BasicBlock *BB, MemoryAccess *IncomingVal, + bool RenameAllUses) { + // Pass through values to our successors + for (const BasicBlock *S : successors(BB)) { + auto It = PerBlockAccesses.find(S); + // Rename the phi nodes in our successor block + if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) + continue; + AccessList *Accesses = It->second.get(); + auto *Phi = cast<MemoryPhi>(&Accesses->front()); + if (RenameAllUses) { + int PhiIndex = Phi->getBasicBlockIndex(BB); + assert(PhiIndex != -1 && "Incomplete phi during partial rename"); + Phi->setIncomingValue(PhiIndex, IncomingVal); + } else + Phi->addIncoming(IncomingVal, BB); + } +} + +/// \brief Rename a single basic block into MemorySSA form. +/// Uses the standard SSA renaming algorithm. +/// \returns The new incoming value. +MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB, MemoryAccess *IncomingVal, + bool RenameAllUses) { + auto It = PerBlockAccesses.find(BB); + // Skip most processing if the list is empty. + if (It != PerBlockAccesses.end()) { + AccessList *Accesses = It->second.get(); + for (MemoryAccess &L : *Accesses) { + if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&L)) { + if (MUD->getDefiningAccess() == nullptr || RenameAllUses) + MUD->setDefiningAccess(IncomingVal); + if (isa<MemoryDef>(&L)) + IncomingVal = &L; + } else { + IncomingVal = &L; + } + } + } + return IncomingVal; +} + +/// \brief This is the standard SSA renaming algorithm. +/// +/// We walk the dominator tree in preorder, renaming accesses, and then filling +/// in phi nodes in our successors. +void MemorySSA::renamePass(DomTreeNode *Root, MemoryAccess *IncomingVal, + SmallPtrSetImpl<BasicBlock *> &Visited, + bool SkipVisited, bool RenameAllUses) { + SmallVector<RenamePassData, 32> WorkStack; + // Skip everything if we already renamed this block and we are skipping. + // Note: You can't sink this into the if, because we need it to occur + // regardless of whether we skip blocks or not. 
+ bool AlreadyVisited = !Visited.insert(Root->getBlock()).second; + if (SkipVisited && AlreadyVisited) + return; + + IncomingVal = renameBlock(Root->getBlock(), IncomingVal, RenameAllUses); + renameSuccessorPhis(Root->getBlock(), IncomingVal, RenameAllUses); + WorkStack.push_back({Root, Root->begin(), IncomingVal}); + + while (!WorkStack.empty()) { + DomTreeNode *Node = WorkStack.back().DTN; + DomTreeNode::const_iterator ChildIt = WorkStack.back().ChildIt; + IncomingVal = WorkStack.back().IncomingVal; + + if (ChildIt == Node->end()) { + WorkStack.pop_back(); + } else { + DomTreeNode *Child = *ChildIt; + ++WorkStack.back().ChildIt; + BasicBlock *BB = Child->getBlock(); + // Note: You can't sink this into the if, because we need it to occur + // regardless of whether we skip blocks or not. + AlreadyVisited = !Visited.insert(BB).second; + if (SkipVisited && AlreadyVisited) { + // We already visited this during our renaming, which can happen when + // being asked to rename multiple blocks. Figure out the incoming val, + // which is the last def. + // Incoming value can only change if there is a block def, and in that + // case, it's the last block def in the list. + if (auto *BlockDefs = getWritableBlockDefs(BB)) + IncomingVal = &*BlockDefs->rbegin(); + } else + IncomingVal = renameBlock(BB, IncomingVal, RenameAllUses); + renameSuccessorPhis(BB, IncomingVal, RenameAllUses); + WorkStack.push_back({Child, Child->begin(), IncomingVal}); + } + } +} + +/// \brief This handles unreachable block accesses by deleting phi nodes in +/// unreachable blocks, and marking all other unreachable MemoryAccess's as +/// being uses of the live on entry definition. +void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) { + assert(!DT->isReachableFromEntry(BB) && + "Reachable block found while handling unreachable blocks"); + + // Make sure phi nodes in our reachable successors end up with a + // LiveOnEntryDef for our incoming edge, even though our block is forward + // unreachable. We could just disconnect these blocks from the CFG fully, + // but we do not right now. + for (const BasicBlock *S : successors(BB)) { + if (!DT->isReachableFromEntry(S)) + continue; + auto It = PerBlockAccesses.find(S); + // Rename the phi nodes in our successor block + if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) + continue; + AccessList *Accesses = It->second.get(); + auto *Phi = cast<MemoryPhi>(&Accesses->front()); + Phi->addIncoming(LiveOnEntryDef.get(), BB); + } + + auto It = PerBlockAccesses.find(BB); + if (It == PerBlockAccesses.end()) + return; + + auto &Accesses = It->second; + for (auto AI = Accesses->begin(), AE = Accesses->end(); AI != AE;) { + auto Next = std::next(AI); + // If we have a phi, just remove it. We are going to replace all + // users with live on entry. 
+ if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(AI)) + UseOrDef->setDefiningAccess(LiveOnEntryDef.get()); + else + Accesses->erase(AI); + AI = Next; + } +} + +MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT) + : AA(AA), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr), + NextID(INVALID_MEMORYACCESS_ID) { + buildMemorySSA(); +} + +MemorySSA::~MemorySSA() { + // Drop all our references + for (const auto &Pair : PerBlockAccesses) + for (MemoryAccess &MA : *Pair.second) + MA.dropAllReferences(); +} + +MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) { + auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr)); + + if (Res.second) + Res.first->second = make_unique<AccessList>(); + return Res.first->second.get(); +} +MemorySSA::DefsList *MemorySSA::getOrCreateDefsList(const BasicBlock *BB) { + auto Res = PerBlockDefs.insert(std::make_pair(BB, nullptr)); + + if (Res.second) + Res.first->second = make_unique<DefsList>(); + return Res.first->second.get(); +} + +/// This class is a batch walker of all MemoryUse's in the program, and points +/// their defining access at the thing that actually clobbers them. Because it +/// is a batch walker that touches everything, it does not operate like the +/// other walkers. This walker is basically performing a top-down SSA renaming +/// pass, where the version stack is used as the cache. This enables it to be +/// significantly more time and memory efficient than using the regular walker, +/// which is walking bottom-up. +class MemorySSA::OptimizeUses { +public: + OptimizeUses(MemorySSA *MSSA, MemorySSAWalker *Walker, AliasAnalysis *AA, + DominatorTree *DT) + : MSSA(MSSA), Walker(Walker), AA(AA), DT(DT) { + Walker = MSSA->getWalker(); + } + + void optimizeUses(); + +private: + /// This represents where a given memorylocation is in the stack. + struct MemlocStackInfo { + // This essentially is keeping track of versions of the stack. Whenever + // the stack changes due to pushes or pops, these versions increase. + unsigned long StackEpoch; + unsigned long PopEpoch; + // This is the lower bound of places on the stack to check. It is equal to + // the place the last stack walk ended. + // Note: Correctness depends on this being initialized to 0, which densemap + // does + unsigned long LowerBound; + const BasicBlock *LowerBoundBlock; + // This is where the last walk for this memory location ended. + unsigned long LastKill; + bool LastKillValid; + }; + void optimizeUsesInBlock(const BasicBlock *, unsigned long &, unsigned long &, + SmallVectorImpl<MemoryAccess *> &, + DenseMap<MemoryLocOrCall, MemlocStackInfo> &); + MemorySSA *MSSA; + MemorySSAWalker *Walker; + AliasAnalysis *AA; + DominatorTree *DT; +}; + +/// Optimize the uses in a given block This is basically the SSA renaming +/// algorithm, with one caveat: We are able to use a single stack for all +/// MemoryUses. This is because the set of *possible* reaching MemoryDefs is +/// the same for every MemoryUse. The *actual* clobbering MemoryDef is just +/// going to be some position in that stack of possible ones. +/// +/// We track the stack positions that each MemoryLocation needs +/// to check, and last ended at. This is because we only want to check the +/// things that changed since last time. 
The same MemoryLocation should +/// get clobbered by the same store (getModRefInfo does not use invariantness or +/// things like this, and if they start, we can modify MemoryLocOrCall to +/// include relevant data) +void MemorySSA::OptimizeUses::optimizeUsesInBlock( + const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch, + SmallVectorImpl<MemoryAccess *> &VersionStack, + DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) { + + /// If no accesses, nothing to do. + MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB); + if (Accesses == nullptr) + return; + + // Pop everything that doesn't dominate the current block off the stack, + // increment the PopEpoch to account for this. + while (true) { + assert( + !VersionStack.empty() && + "Version stack should have liveOnEntry sentinel dominating everything"); + BasicBlock *BackBlock = VersionStack.back()->getBlock(); + if (DT->dominates(BackBlock, BB)) + break; + while (VersionStack.back()->getBlock() == BackBlock) + VersionStack.pop_back(); + ++PopEpoch; + } + + for (MemoryAccess &MA : *Accesses) { + auto *MU = dyn_cast<MemoryUse>(&MA); + if (!MU) { + VersionStack.push_back(&MA); + ++StackEpoch; + continue; + } + + if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) { + MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true); + continue; + } + + MemoryLocOrCall UseMLOC(MU); + auto &LocInfo = LocStackInfo[UseMLOC]; + // If the pop epoch changed, it means we've removed stuff from top of + // stack due to changing blocks. We may have to reset the lower bound or + // last kill info. + if (LocInfo.PopEpoch != PopEpoch) { + LocInfo.PopEpoch = PopEpoch; + LocInfo.StackEpoch = StackEpoch; + // If the lower bound was in something that no longer dominates us, we + // have to reset it. + // We can't simply track stack size, because the stack may have had + // pushes/pops in the meantime. + // XXX: This is non-optimal, but only is slower cases with heavily + // branching dominator trees. To get the optimal number of queries would + // be to make lowerbound and lastkill a per-loc stack, and pop it until + // the top of that stack dominates us. This does not seem worth it ATM. + // A much cheaper optimization would be to always explore the deepest + // branch of the dominator tree first. This will guarantee this resets on + // the smallest set of blocks. + if (LocInfo.LowerBoundBlock && LocInfo.LowerBoundBlock != BB && + !DT->dominates(LocInfo.LowerBoundBlock, BB)) { + // Reset the lower bound of things to check. + // TODO: Some day we should be able to reset to last kill, rather than + // 0. + LocInfo.LowerBound = 0; + LocInfo.LowerBoundBlock = VersionStack[0]->getBlock(); + LocInfo.LastKillValid = false; + } + } else if (LocInfo.StackEpoch != StackEpoch) { + // If all that has changed is the StackEpoch, we only have to check the + // new things on the stack, because we've checked everything before. In + // this case, the lower bound of things to check remains the same. + LocInfo.PopEpoch = PopEpoch; + LocInfo.StackEpoch = StackEpoch; + } + if (!LocInfo.LastKillValid) { + LocInfo.LastKill = VersionStack.size() - 1; + LocInfo.LastKillValid = true; + } + + // At this point, we should have corrected last kill and LowerBound to be + // in bounds. + assert(LocInfo.LowerBound < VersionStack.size() && + "Lower bound out of range"); + assert(LocInfo.LastKill < VersionStack.size() && + "Last kill info out of range"); + // In any case, the new upper bound is the top of the stack. 
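+    // For example, with a version stack of [liveOnEntry, D1, D2, D3] and a
+    // LowerBound still pointing at D1 from an earlier walk, only D3 and D2
+    // need to be re-examined for this location below.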
+ unsigned long UpperBound = VersionStack.size() - 1; + + if (UpperBound - LocInfo.LowerBound > MaxCheckLimit) { + DEBUG(dbgs() << "MemorySSA skipping optimization of " << *MU << " (" + << *(MU->getMemoryInst()) << ")" + << " because there are " << UpperBound - LocInfo.LowerBound + << " stores to disambiguate\n"); + // Because we did not walk, LastKill is no longer valid, as this may + // have been a kill. + LocInfo.LastKillValid = false; + continue; + } + bool FoundClobberResult = false; + while (UpperBound > LocInfo.LowerBound) { + if (isa<MemoryPhi>(VersionStack[UpperBound])) { + // For phis, use the walker, see where we ended up, go there + Instruction *UseInst = MU->getMemoryInst(); + MemoryAccess *Result = Walker->getClobberingMemoryAccess(UseInst); + // We are guaranteed to find it or something is wrong + while (VersionStack[UpperBound] != Result) { + assert(UpperBound != 0); + --UpperBound; + } + FoundClobberResult = true; + break; + } + + MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]); + // If the lifetime of the pointer ends at this instruction, it's live on + // entry. + if (!UseMLOC.IsCall && lifetimeEndsAt(MD, UseMLOC.getLoc(), *AA)) { + // Reset UpperBound to liveOnEntryDef's place in the stack + UpperBound = 0; + FoundClobberResult = true; + break; + } + if (instructionClobbersQuery(MD, MU, UseMLOC, *AA)) { + FoundClobberResult = true; + break; + } + --UpperBound; + } + // At the end of this loop, UpperBound is either a clobber, or lower bound + // PHI walking may cause it to be < LowerBound, and in fact, < LastKill. + if (FoundClobberResult || UpperBound < LocInfo.LastKill) { + MU->setDefiningAccess(VersionStack[UpperBound], true); + // We were last killed now by where we got to + LocInfo.LastKill = UpperBound; + } else { + // Otherwise, we checked all the new ones, and now we know we can get to + // LastKill. + MU->setDefiningAccess(VersionStack[LocInfo.LastKill], true); + } + LocInfo.LowerBound = VersionStack.size() - 1; + LocInfo.LowerBoundBlock = BB; + } +} + +/// Optimize uses to point to their actual clobbering definitions. +void MemorySSA::OptimizeUses::optimizeUses() { + SmallVector<MemoryAccess *, 16> VersionStack; + DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo; + VersionStack.push_back(MSSA->getLiveOnEntryDef()); + + unsigned long StackEpoch = 1; + unsigned long PopEpoch = 1; + // We perform a non-recursive top-down dominator tree walk. + for (const auto *DomNode : depth_first(DT->getRootNode())) + optimizeUsesInBlock(DomNode->getBlock(), StackEpoch, PopEpoch, VersionStack, + LocStackInfo); +} + +void MemorySSA::placePHINodes( + const SmallPtrSetImpl<BasicBlock *> &DefiningBlocks, + const DenseMap<const BasicBlock *, unsigned int> &BBNumbers) { + // Determine where our MemoryPhi's should go + ForwardIDFCalculator IDFs(*DT); + IDFs.setDefiningBlocks(DefiningBlocks); + SmallVector<BasicBlock *, 32> IDFBlocks; + IDFs.calculate(IDFBlocks); + + std::sort(IDFBlocks.begin(), IDFBlocks.end(), + [&BBNumbers](const BasicBlock *A, const BasicBlock *B) { + return BBNumbers.lookup(A) < BBNumbers.lookup(B); + }); + + // Now place MemoryPhi nodes. + for (auto &BB : IDFBlocks) + createMemoryPhi(BB); +} + +void MemorySSA::buildMemorySSA() { + // We create an access to represent "live on entry", for things like + // arguments or users of globals, where the memory they use is defined before + // the beginning of the function. We do not actually insert it into the IR. 
+ // We do not define a live on exit for the immediate uses, and thus our + // semantics do *not* imply that something with no immediate uses can simply + // be removed. + BasicBlock &StartingPoint = F.getEntryBlock(); + LiveOnEntryDef = make_unique<MemoryDef>(F.getContext(), nullptr, nullptr, + &StartingPoint, NextID++); + DenseMap<const BasicBlock *, unsigned int> BBNumbers; + unsigned NextBBNum = 0; + + // We maintain lists of memory accesses per-block, trading memory for time. We + // could just look up the memory access for every possible instruction in the + // stream. + SmallPtrSet<BasicBlock *, 32> DefiningBlocks; + // Go through each block, figure out where defs occur, and chain together all + // the accesses. + for (BasicBlock &B : F) { + BBNumbers[&B] = NextBBNum++; + bool InsertIntoDef = false; + AccessList *Accesses = nullptr; + DefsList *Defs = nullptr; + for (Instruction &I : B) { + MemoryUseOrDef *MUD = createNewAccess(&I); + if (!MUD) + continue; + + if (!Accesses) + Accesses = getOrCreateAccessList(&B); + Accesses->push_back(MUD); + if (isa<MemoryDef>(MUD)) { + InsertIntoDef = true; + if (!Defs) + Defs = getOrCreateDefsList(&B); + Defs->push_back(*MUD); + } + } + if (InsertIntoDef) + DefiningBlocks.insert(&B); + } + placePHINodes(DefiningBlocks, BBNumbers); + + // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get + // filled in with all blocks. + SmallPtrSet<BasicBlock *, 16> Visited; + renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited); + + CachingWalker *Walker = getWalkerImpl(); + + // We're doing a batch of updates; don't drop useful caches between them. + Walker->setAutoResetWalker(false); + OptimizeUses(this, Walker, AA, DT).optimizeUses(); + Walker->setAutoResetWalker(true); + Walker->resetClobberWalker(); + + // Mark the uses in unreachable blocks as live on entry, so that they go + // somewhere. + for (auto &BB : F) + if (!Visited.count(&BB)) + markUnreachableAsLiveOnEntry(&BB); +} + +MemorySSAWalker *MemorySSA::getWalker() { return getWalkerImpl(); } + +MemorySSA::CachingWalker *MemorySSA::getWalkerImpl() { + if (Walker) + return Walker.get(); + + Walker = make_unique<CachingWalker>(this, AA, DT); + return Walker.get(); +} + +// This is a helper function used by the creation routines. It places NewAccess +// into the access and defs lists for a given basic block, at the given +// insertion point. +void MemorySSA::insertIntoListsForBlock(MemoryAccess *NewAccess, + const BasicBlock *BB, + InsertionPlace Point) { + auto *Accesses = getOrCreateAccessList(BB); + if (Point == Beginning) { + // If it's a phi node, it goes first, otherwise, it goes after any phi + // nodes. 
+    if (isa<MemoryPhi>(NewAccess)) {
+      Accesses->push_front(NewAccess);
+      auto *Defs = getOrCreateDefsList(BB);
+      Defs->push_front(*NewAccess);
+    } else {
+      auto AI = find_if_not(
+          *Accesses, [](const MemoryAccess &MA) { return isa<MemoryPhi>(MA); });
+      Accesses->insert(AI, NewAccess);
+      if (!isa<MemoryUse>(NewAccess)) {
+        auto *Defs = getOrCreateDefsList(BB);
+        auto DI = find_if_not(
+            *Defs, [](const MemoryAccess &MA) { return isa<MemoryPhi>(MA); });
+        Defs->insert(DI, *NewAccess);
+      }
+    }
+  } else {
+    Accesses->push_back(NewAccess);
+    if (!isa<MemoryUse>(NewAccess)) {
+      auto *Defs = getOrCreateDefsList(BB);
+      Defs->push_back(*NewAccess);
+    }
+  }
+  BlockNumberingValid.erase(BB);
+}
+
+void MemorySSA::insertIntoListsBefore(MemoryAccess *What, const BasicBlock *BB,
+                                      AccessList::iterator InsertPt) {
+  auto *Accesses = getWritableBlockAccesses(BB);
+  bool WasEnd = InsertPt == Accesses->end();
+  Accesses->insert(AccessList::iterator(InsertPt), What);
+  if (!isa<MemoryUse>(What)) {
+    auto *Defs = getOrCreateDefsList(BB);
+    // If we got asked to insert at the end, we have an easy job, just shove it
+    // at the end. If we got asked to insert before an existing def, we also get
+    // an iterator. If we got asked to insert before a use, we have to hunt for
+    // the next def.
+    if (WasEnd) {
+      Defs->push_back(*What);
+    } else if (isa<MemoryDef>(InsertPt)) {
+      Defs->insert(InsertPt->getDefsIterator(), *What);
+    } else {
+      while (InsertPt != Accesses->end() && !isa<MemoryDef>(InsertPt))
+        ++InsertPt;
+      // Either we found a def, or we are inserting at the end.
+      if (InsertPt == Accesses->end())
+        Defs->push_back(*What);
+      else
+        Defs->insert(InsertPt->getDefsIterator(), *What);
+    }
+  }
+  BlockNumberingValid.erase(BB);
+}
+
+// Move What before Where in the IR. The end result is that What will belong to
+// the right lists and have the right Block set, but will not otherwise be
+// correct. It will not have the right defining access, and if it is a def,
+// things below it will not properly be updated.
+void MemorySSA::moveTo(MemoryUseOrDef *What, BasicBlock *BB,
+                       AccessList::iterator Where) {
+  // Keep it in the lookup tables; remove it from the lists.
+  removeFromLists(What, false);
+  What->setBlock(BB);
+  insertIntoListsBefore(What, BB, Where);
+}
+
+void MemorySSA::moveTo(MemoryUseOrDef *What, BasicBlock *BB,
+                       InsertionPlace Point) {
+  removeFromLists(What, false);
+  What->setBlock(BB);
+  insertIntoListsForBlock(What, BB, Point);
+}
+
+MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) {
+  assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB");
+  MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++);
+  // Phis are always placed at the front of the block.
+  insertIntoListsForBlock(Phi, BB, Beginning);
+  ValueToMemoryAccess[BB] = Phi;
+  return Phi;
+}
+
+MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I,
+                                               MemoryAccess *Definition) {
+  assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI");
+  MemoryUseOrDef *NewAccess = createNewAccess(I);
+  assert(
+      NewAccess != nullptr &&
+      "Tried to create a memory access for a non-memory touching instruction");
+  NewAccess->setDefiningAccess(Definition);
+  return NewAccess;
+}
+
+// Return true if the instruction has ordering constraints.
+// Note specifically that this only considers stores and loads
+// because others are still considered ModRef by getModRefInfo.
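+// For example, a volatile load is not "unordered", so createNewAccess below
+// models it as a MemoryDef even though alias analysis sees it as a plain read;
+// this preserves at least some relative ordering between volatile accesses.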
+static inline bool isOrdered(const Instruction *I) { + if (auto *SI = dyn_cast<StoreInst>(I)) { + if (!SI->isUnordered()) + return true; + } else if (auto *LI = dyn_cast<LoadInst>(I)) { + if (!LI->isUnordered()) + return true; + } + return false; +} +/// \brief Helper function to create new memory accesses +MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { + // The assume intrinsic has a control dependency which we model by claiming + // that it writes arbitrarily. Ignore that fake memory dependency here. + // FIXME: Replace this special casing with a more accurate modelling of + // assume's control dependency. + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) + if (II->getIntrinsicID() == Intrinsic::assume) + return nullptr; + + // Find out what affect this instruction has on memory. + ModRefInfo ModRef = AA->getModRefInfo(I); + // The isOrdered check is used to ensure that volatiles end up as defs + // (atomics end up as ModRef right now anyway). Until we separate the + // ordering chain from the memory chain, this enables people to see at least + // some relative ordering to volatiles. Note that getClobberingMemoryAccess + // will still give an answer that bypasses other volatile loads. TODO: + // Separate memory aliasing and ordering into two different chains so that we + // can precisely represent both "what memory will this read/write/is clobbered + // by" and "what instructions can I move this past". + bool Def = bool(ModRef & MRI_Mod) || isOrdered(I); + bool Use = bool(ModRef & MRI_Ref); + + // It's possible for an instruction to not modify memory at all. During + // construction, we ignore them. + if (!Def && !Use) + return nullptr; + + assert((Def || Use) && + "Trying to create a memory access with a non-memory instruction"); + + MemoryUseOrDef *MUD; + if (Def) + MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++); + else + MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent()); + ValueToMemoryAccess[I] = MUD; + return MUD; +} + +/// \brief Returns true if \p Replacer dominates \p Replacee . +bool MemorySSA::dominatesUse(const MemoryAccess *Replacer, + const MemoryAccess *Replacee) const { + if (isa<MemoryUseOrDef>(Replacee)) + return DT->dominates(Replacer->getBlock(), Replacee->getBlock()); + const auto *MP = cast<MemoryPhi>(Replacee); + // For a phi node, the use occurs in the predecessor block of the phi node. + // Since we may occur multiple times in the phi node, we have to check each + // operand to ensure Replacer dominates each operand where Replacee occurs. + for (const Use &Arg : MP->operands()) { + if (Arg.get() != Replacee && + !DT->dominates(Replacer->getBlock(), MP->getIncomingBlock(Arg))) + return false; + } + return true; +} + +/// \brief Properly remove \p MA from all of MemorySSA's lookup tables. 
+void MemorySSA::removeFromLookups(MemoryAccess *MA) { + assert(MA->use_empty() && + "Trying to remove memory access that still has uses"); + BlockNumbering.erase(MA); + if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) + MUD->setDefiningAccess(nullptr); + // Invalidate our walker's cache if necessary + if (!isa<MemoryUse>(MA)) + Walker->invalidateInfo(MA); + // The call below to erase will destroy MA, so we can't change the order we + // are doing things here + Value *MemoryInst; + if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) { + MemoryInst = MUD->getMemoryInst(); + } else { + MemoryInst = MA->getBlock(); + } + auto VMA = ValueToMemoryAccess.find(MemoryInst); + if (VMA->second == MA) + ValueToMemoryAccess.erase(VMA); +} + +/// \brief Properly remove \p MA from all of MemorySSA's lists. +/// +/// Because of the way the intrusive list and use lists work, it is important to +/// do removal in the right order. +/// ShouldDelete defaults to true, and will cause the memory access to also be +/// deleted, not just removed. +void MemorySSA::removeFromLists(MemoryAccess *MA, bool ShouldDelete) { + // The access list owns the reference, so we erase it from the non-owning list + // first. + if (!isa<MemoryUse>(MA)) { + auto DefsIt = PerBlockDefs.find(MA->getBlock()); + std::unique_ptr<DefsList> &Defs = DefsIt->second; + Defs->remove(*MA); + if (Defs->empty()) + PerBlockDefs.erase(DefsIt); + } + + // The erase call here will delete it. If we don't want it deleted, we call + // remove instead. + auto AccessIt = PerBlockAccesses.find(MA->getBlock()); + std::unique_ptr<AccessList> &Accesses = AccessIt->second; + if (ShouldDelete) + Accesses->erase(MA); + else + Accesses->remove(MA); + + if (Accesses->empty()) + PerBlockAccesses.erase(AccessIt); +} + +void MemorySSA::print(raw_ostream &OS) const { + MemorySSAAnnotatedWriter Writer(this); + F.print(OS, &Writer); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void MemorySSA::dump() const { print(dbgs()); } +#endif + +void MemorySSA::verifyMemorySSA() const { + verifyDefUses(F); + verifyDomination(F); + verifyOrdering(F); + Walker->verify(this); +} + +/// \brief Verify that the order and existence of MemoryAccesses matches the +/// order and existence of memory affecting instructions. +void MemorySSA::verifyOrdering(Function &F) const { + // Walk all the blocks, comparing what the lookups think and what the access + // lists think, as well as the order in the blocks vs the order in the access + // lists. + SmallVector<MemoryAccess *, 32> ActualAccesses; + SmallVector<MemoryAccess *, 32> ActualDefs; + for (BasicBlock &B : F) { + const AccessList *AL = getBlockAccesses(&B); + const auto *DL = getBlockDefs(&B); + MemoryAccess *Phi = getMemoryAccess(&B); + if (Phi) { + ActualAccesses.push_back(Phi); + ActualDefs.push_back(Phi); + } + + for (Instruction &I : B) { + MemoryAccess *MA = getMemoryAccess(&I); + assert((!MA || (AL && (isa<MemoryUse>(MA) || DL))) && + "We have memory affecting instructions " + "in this block but they are not in the " + "access list or defs list"); + if (MA) { + ActualAccesses.push_back(MA); + if (isa<MemoryDef>(MA)) + ActualDefs.push_back(MA); + } + } + // Either we hit the assert, really have no accesses, or we have both + // accesses and an access list. + // Same with defs. 
+ if (!AL && !DL) + continue; + assert(AL->size() == ActualAccesses.size() && + "We don't have the same number of accesses in the block as on the " + "access list"); + assert((DL || ActualDefs.size() == 0) && + "Either we should have a defs list, or we should have no defs"); + assert((!DL || DL->size() == ActualDefs.size()) && + "We don't have the same number of defs in the block as on the " + "def list"); + auto ALI = AL->begin(); + auto AAI = ActualAccesses.begin(); + while (ALI != AL->end() && AAI != ActualAccesses.end()) { + assert(&*ALI == *AAI && "Not the same accesses in the same order"); + ++ALI; + ++AAI; + } + ActualAccesses.clear(); + if (DL) { + auto DLI = DL->begin(); + auto ADI = ActualDefs.begin(); + while (DLI != DL->end() && ADI != ActualDefs.end()) { + assert(&*DLI == *ADI && "Not the same defs in the same order"); + ++DLI; + ++ADI; + } + } + ActualDefs.clear(); + } +} + +/// \brief Verify the domination properties of MemorySSA by checking that each +/// definition dominates all of its uses. +void MemorySSA::verifyDomination(Function &F) const { +#ifndef NDEBUG + for (BasicBlock &B : F) { + // Phi nodes are attached to basic blocks + if (MemoryPhi *MP = getMemoryAccess(&B)) + for (const Use &U : MP->uses()) + assert(dominates(MP, U) && "Memory PHI does not dominate it's uses"); + + for (Instruction &I : B) { + MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I)); + if (!MD) + continue; + + for (const Use &U : MD->uses()) + assert(dominates(MD, U) && "Memory Def does not dominate it's uses"); + } + } +#endif +} + +/// \brief Verify the def-use lists in MemorySSA, by verifying that \p Use +/// appears in the use list of \p Def. + +void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const { +#ifndef NDEBUG + // The live on entry use may cause us to get a NULL def here + if (!Def) + assert(isLiveOnEntryDef(Use) && + "Null def but use not point to live on entry def"); + else + assert(is_contained(Def->users(), Use) && + "Did not find use in def's use list"); +#endif +} + +/// \brief Verify the immediate use information, by walking all the memory +/// accesses and verifying that, for each use, it appears in the +/// appropriate def's use list +void MemorySSA::verifyDefUses(Function &F) const { + for (BasicBlock &B : F) { + // Phi nodes are attached to basic blocks + if (MemoryPhi *Phi = getMemoryAccess(&B)) { + assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance( + pred_begin(&B), pred_end(&B))) && + "Incomplete MemoryPhi Node"); + for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) + verifyUseInDefs(Phi->getIncomingValue(I), Phi); + } + + for (Instruction &I : B) { + if (MemoryUseOrDef *MA = getMemoryAccess(&I)) { + verifyUseInDefs(MA->getDefiningAccess(), MA); + } + } + } +} + +MemoryUseOrDef *MemorySSA::getMemoryAccess(const Instruction *I) const { + return cast_or_null<MemoryUseOrDef>(ValueToMemoryAccess.lookup(I)); +} + +MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const { + return cast_or_null<MemoryPhi>(ValueToMemoryAccess.lookup(cast<Value>(BB))); +} + +/// Perform a local numbering on blocks so that instruction ordering can be +/// determined in constant time. +/// TODO: We currently just number in order. If we numbered by N, we could +/// allow at least N-1 sequences of insertBefore or insertAfter (and at least +/// log2(N) sequences of mixed before and after) without needing to invalidate +/// the numbering. 
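+/// For example, numbering accesses as 10, 20, 30, ... (N = 10) would leave
+/// room for up to nine insertions between any two neighboring accesses before
+/// the block has to be renumbered.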
+void MemorySSA::renumberBlock(const BasicBlock *B) const { + // The pre-increment ensures the numbers really start at 1. + unsigned long CurrentNumber = 0; + const AccessList *AL = getBlockAccesses(B); + assert(AL != nullptr && "Asking to renumber an empty block"); + for (const auto &I : *AL) + BlockNumbering[&I] = ++CurrentNumber; + BlockNumberingValid.insert(B); +} + +/// \brief Determine, for two memory accesses in the same block, +/// whether \p Dominator dominates \p Dominatee. +/// \returns True if \p Dominator dominates \p Dominatee. +bool MemorySSA::locallyDominates(const MemoryAccess *Dominator, + const MemoryAccess *Dominatee) const { + + const BasicBlock *DominatorBlock = Dominator->getBlock(); + + assert((DominatorBlock == Dominatee->getBlock()) && + "Asking for local domination when accesses are in different blocks!"); + // A node dominates itself. + if (Dominatee == Dominator) + return true; + + // When Dominatee is defined on function entry, it is not dominated by another + // memory access. + if (isLiveOnEntryDef(Dominatee)) + return false; + + // When Dominator is defined on function entry, it dominates the other memory + // access. + if (isLiveOnEntryDef(Dominator)) + return true; + + if (!BlockNumberingValid.count(DominatorBlock)) + renumberBlock(DominatorBlock); + + unsigned long DominatorNum = BlockNumbering.lookup(Dominator); + // All numbers start with 1 + assert(DominatorNum != 0 && "Block was not numbered properly"); + unsigned long DominateeNum = BlockNumbering.lookup(Dominatee); + assert(DominateeNum != 0 && "Block was not numbered properly"); + return DominatorNum < DominateeNum; +} + +bool MemorySSA::dominates(const MemoryAccess *Dominator, + const MemoryAccess *Dominatee) const { + if (Dominator == Dominatee) + return true; + + if (isLiveOnEntryDef(Dominatee)) + return false; + + if (Dominator->getBlock() != Dominatee->getBlock()) + return DT->dominates(Dominator->getBlock(), Dominatee->getBlock()); + return locallyDominates(Dominator, Dominatee); +} + +bool MemorySSA::dominates(const MemoryAccess *Dominator, + const Use &Dominatee) const { + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Dominatee.getUser())) { + BasicBlock *UseBB = MP->getIncomingBlock(Dominatee); + // The def must dominate the incoming block of the phi. + if (UseBB != Dominator->getBlock()) + return DT->dominates(Dominator->getBlock(), UseBB); + // If the UseBB and the DefBB are the same, compare locally. + return locallyDominates(Dominator, cast<MemoryAccess>(Dominatee)); + } + // If it's not a PHI node use, the normal dominates can already handle it. 
+ return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser())); +} + +const static char LiveOnEntryStr[] = "liveOnEntry"; + +void MemoryAccess::print(raw_ostream &OS) const { + switch (getValueID()) { + case MemoryPhiVal: return static_cast<const MemoryPhi *>(this)->print(OS); + case MemoryDefVal: return static_cast<const MemoryDef *>(this)->print(OS); + case MemoryUseVal: return static_cast<const MemoryUse *>(this)->print(OS); + } + llvm_unreachable("invalid value id"); +} + +void MemoryDef::print(raw_ostream &OS) const { + MemoryAccess *UO = getDefiningAccess(); + + OS << getID() << " = MemoryDef("; + if (UO && UO->getID()) + OS << UO->getID(); + else + OS << LiveOnEntryStr; + OS << ')'; +} + +void MemoryPhi::print(raw_ostream &OS) const { + bool First = true; + OS << getID() << " = MemoryPhi("; + for (const auto &Op : operands()) { + BasicBlock *BB = getIncomingBlock(Op); + MemoryAccess *MA = cast<MemoryAccess>(Op); + if (!First) + OS << ','; + else + First = false; + + OS << '{'; + if (BB->hasName()) + OS << BB->getName(); + else + BB->printAsOperand(OS, false); + OS << ','; + if (unsigned ID = MA->getID()) + OS << ID; + else + OS << LiveOnEntryStr; + OS << '}'; + } + OS << ')'; +} + +void MemoryUse::print(raw_ostream &OS) const { + MemoryAccess *UO = getDefiningAccess(); + OS << "MemoryUse("; + if (UO && UO->getID()) + OS << UO->getID(); + else + OS << LiveOnEntryStr; + OS << ')'; +} + +void MemoryAccess::dump() const { +// Cannot completely remove virtual function even in release mode. +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + print(dbgs()); + dbgs() << "\n"; +#endif +} + +char MemorySSAPrinterLegacyPass::ID = 0; + +MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) { + initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); +} + +void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); +} + +bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) { + auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSA.print(dbgs()); + if (VerifyMemorySSA) + MSSA.verifyMemorySSA(); + return false; +} + +AnalysisKey MemorySSAAnalysis::Key; + +MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + return MemorySSAAnalysis::Result(make_unique<MemorySSA>(F, &AA, &DT)); +} + +PreservedAnalyses MemorySSAPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + OS << "MemorySSA for function: " << F.getName() << "\n"; + AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS); + + return PreservedAnalyses::all(); +} + +PreservedAnalyses MemorySSAVerifierPass::run(Function &F, + FunctionAnalysisManager &AM) { + AM.getResult<MemorySSAAnalysis>(F).getMSSA().verifyMemorySSA(); + + return PreservedAnalyses::all(); +} + +char MemorySSAWrapperPass::ID = 0; + +MemorySSAWrapperPass::MemorySSAWrapperPass() : FunctionPass(ID) { + initializeMemorySSAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void MemorySSAWrapperPass::releaseMemory() { MSSA.reset(); } + +void MemorySSAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<DominatorTreeWrapperPass>(); + AU.addRequiredTransitive<AAResultsWrapperPass>(); +} + +bool MemorySSAWrapperPass::runOnFunction(Function &F) { + auto &DT = 
getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + MSSA.reset(new MemorySSA(F, &AA, &DT)); + return false; +} + +void MemorySSAWrapperPass::verifyAnalysis() const { MSSA->verifyMemorySSA(); } + +void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const { + MSSA->print(OS); +} + +MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {} + +MemorySSA::CachingWalker::CachingWalker(MemorySSA *M, AliasAnalysis *A, + DominatorTree *D) + : MemorySSAWalker(M), Walker(*M, *A, *D), AutoResetWalker(true) {} + +MemorySSA::CachingWalker::~CachingWalker() {} + +void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) { + if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) + MUD->resetOptimized(); +} + +/// \brief Walk the use-def chains starting at \p MA and find +/// the MemoryAccess that actually clobbers Loc. +/// +/// \returns our clobbering memory access +MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( + MemoryAccess *StartingAccess, UpwardsMemoryQuery &Q) { + MemoryAccess *New = Walker.findClobber(StartingAccess, Q); +#ifdef EXPENSIVE_CHECKS + MemoryAccess *NewNoCache = Walker.findClobber(StartingAccess, Q); + assert(NewNoCache == New && "Cache made us hand back a different result?"); +#endif + if (AutoResetWalker) + resetClobberWalker(); + return New; +} + +MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( + MemoryAccess *StartingAccess, const MemoryLocation &Loc) { + if (isa<MemoryPhi>(StartingAccess)) + return StartingAccess; + + auto *StartingUseOrDef = cast<MemoryUseOrDef>(StartingAccess); + if (MSSA->isLiveOnEntryDef(StartingUseOrDef)) + return StartingUseOrDef; + + Instruction *I = StartingUseOrDef->getMemoryInst(); + + // Conservatively, fences are always clobbers, so don't perform the walk if we + // hit a fence. + if (!ImmutableCallSite(I) && I->isFenceLike()) + return StartingUseOrDef; + + UpwardsMemoryQuery Q; + Q.OriginalAccess = StartingUseOrDef; + Q.StartingLoc = Loc; + Q.Inst = I; + Q.IsCall = false; + + // Unlike the other function, do not walk to the def of a def, because we are + // handed something we already believe is the clobbering access. + MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef) + ? StartingUseOrDef->getDefiningAccess() + : StartingUseOrDef; + + MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q); + DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); + DEBUG(dbgs() << *StartingUseOrDef << "\n"); + DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); + DEBUG(dbgs() << *Clobber << "\n"); + return Clobber; +} + +MemoryAccess * +MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) { + auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA); + // If this is a MemoryPhi, we can't do anything. + if (!StartingAccess) + return MA; + + // If this is an already optimized use or def, return the optimized result. + // Note: Currently, we do not store the optimized def result because we'd need + // a separate field, since we can't use it as the defining access. + if (auto *MUD = dyn_cast<MemoryUseOrDef>(StartingAccess)) + if (MUD->isOptimized()) + return MUD->getOptimized(); + + const Instruction *I = StartingAccess->getMemoryInst(); + UpwardsMemoryQuery Q(I, StartingAccess); + // We can't sanely do anything with a fences, they conservatively + // clobber all memory, and have no locations to get pointers from to + // try to disambiguate. 
+ if (!Q.IsCall && I->isFenceLike()) + return StartingAccess; + + if (isUseTriviallyOptimizableToLiveOnEntry(*MSSA->AA, I)) { + MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef(); + if (auto *MUD = dyn_cast<MemoryUseOrDef>(StartingAccess)) + MUD->setOptimized(LiveOnEntry); + return LiveOnEntry; + } + + // Start with the thing we already think clobbers this location + MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); + + // At this point, DefiningAccess may be the live on entry def. + // If it is, we will not get a better result. + if (MSSA->isLiveOnEntryDef(DefiningAccess)) + return DefiningAccess; + + MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q); + DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); + DEBUG(dbgs() << *DefiningAccess << "\n"); + DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); + DEBUG(dbgs() << *Result << "\n"); + if (auto *MUD = dyn_cast<MemoryUseOrDef>(StartingAccess)) + MUD->setOptimized(Result); + + return Result; +} + +MemoryAccess * +DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) { + if (auto *Use = dyn_cast<MemoryUseOrDef>(MA)) + return Use->getDefiningAccess(); + return MA; +} + +MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess( + MemoryAccess *StartingAccess, const MemoryLocation &) { + if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess)) + return Use->getDefiningAccess(); + return StartingAccess; +} +} // namespace llvm + +void MemoryPhi::deleteMe(DerivedUser *Self) { + delete static_cast<MemoryPhi *>(Self); +} + +void MemoryDef::deleteMe(DerivedUser *Self) { + delete static_cast<MemoryDef *>(Self); +} + +void MemoryUse::deleteMe(DerivedUser *Self) { + delete static_cast<MemoryUse *>(Self); +} diff --git a/contrib/llvm/lib/Analysis/MemorySSAUpdater.cpp b/contrib/llvm/lib/Analysis/MemorySSAUpdater.cpp new file mode 100644 index 000000000000..da5c79ab6c81 --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemorySSAUpdater.cpp @@ -0,0 +1,493 @@ +//===-- MemorySSAUpdater.cpp - Memory SSA Updater--------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------===// +// +// This file implements the MemorySSAUpdater class. +// +//===----------------------------------------------------------------===// +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormattedStream.h" +#include "llvm/Analysis/MemorySSA.h" +#include <algorithm> + +#define DEBUG_TYPE "memoryssa" +using namespace llvm; + +// This is the marker algorithm from "Simple and Efficient Construction of +// Static Single Assignment Form" +// The simple, non-marker algorithm places phi nodes at any join +// Here, we place markers, and only place phi nodes if they end up necessary. +// They are only necessary if they break a cycle (IE we recursively visit +// ourselves again), or we discover, while getting the value of the operands, +// that there are two or more definitions needing to be merged. 
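+// (For instance, a block that reaches itself along a loop backedge is such a
+// cycle: the recursive walk re-enters the block and needs a phi there so it
+// has an operand to return.)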
+// This still will leave non-minimal form in the case of irreducible control
+// flow, where phi nodes may be in cycles with themselves, but unnecessary.
+MemoryAccess *MemorySSAUpdater::getPreviousDefRecursive(BasicBlock *BB) {
+  // Single predecessor case: just recurse, we can only have one definition.
+  if (BasicBlock *Pred = BB->getSinglePredecessor()) {
+    return getPreviousDefFromEnd(Pred);
+  } else if (VisitedBlocks.count(BB)) {
+    // We hit our node again, meaning we had a cycle; we must insert a phi
+    // node to break it so we have an operand. The only case this will
+    // insert useless phis is if we have irreducible control flow.
+    return MSSA->createMemoryPhi(BB);
+  } else if (VisitedBlocks.insert(BB).second) {
+    // Mark us visited so we can detect a cycle.
+    SmallVector<MemoryAccess *, 8> PhiOps;
+
+    // Recurse to get the values in our predecessors for placement of a
+    // potential phi node. This will insert phi nodes if we cycle in order to
+    // break the cycle and have an operand.
+    for (auto *Pred : predecessors(BB))
+      PhiOps.push_back(getPreviousDefFromEnd(Pred));
+
+    // Now try to simplify the ops to avoid placing a phi.
+    // This may return null if we never created a phi yet; that's okay.
+    MemoryPhi *Phi = dyn_cast_or_null<MemoryPhi>(MSSA->getMemoryAccess(BB));
+    bool PHIExistsButNeedsUpdate = false;
+    // See if the existing phi operands match what we need.
+    // Unlike normal SSA, we only allow one phi node per block, so we can't just
+    // create a new one.
+    if (Phi && Phi->getNumOperands() != 0)
+      if (!std::equal(Phi->op_begin(), Phi->op_end(), PhiOps.begin())) {
+        PHIExistsButNeedsUpdate = true;
+      }
+
+    // See if we can avoid the phi by simplifying it.
+    auto *Result = tryRemoveTrivialPhi(Phi, PhiOps);
+    // If we couldn't simplify, we may have to create a phi.
+    if (Result == Phi) {
+      if (!Phi)
+        Phi = MSSA->createMemoryPhi(BB);
+
+      // These will have been filled in by the recursive read we did above.
+      if (PHIExistsButNeedsUpdate) {
+        std::copy(PhiOps.begin(), PhiOps.end(), Phi->op_begin());
+        std::copy(pred_begin(BB), pred_end(BB), Phi->block_begin());
+      } else {
+        unsigned i = 0;
+        for (auto *Pred : predecessors(BB))
+          Phi->addIncoming(PhiOps[i++], Pred);
+      }
+
+      Result = Phi;
+    }
+    if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Result))
+      InsertedPHIs.push_back(MP);
+    // Set ourselves up for the next variable by resetting visited state.
+    VisitedBlocks.erase(BB);
+    return Result;
+  }
+  llvm_unreachable("Should have hit one of the three cases above");
+}
+
+// This starts at the memory access, and goes backwards in the block to find
+// the previous definition. If a definition is not found in the block of the
+// access, it continues globally, creating phi nodes to ensure we have a single
+// definition.
+MemoryAccess *MemorySSAUpdater::getPreviousDef(MemoryAccess *MA) {
+  auto *LocalResult = getPreviousDefInBlock(MA);
+
+  return LocalResult ? LocalResult : getPreviousDefRecursive(MA->getBlock());
+}
+
+// This starts at the memory access, and goes backwards in the block to find
+// the previous definition. If the definition is not found in the block of the
+// access, it returns nullptr.
+MemoryAccess *MemorySSAUpdater::getPreviousDefInBlock(MemoryAccess *MA) {
+  auto *Defs = MSSA->getWritableBlockDefs(MA->getBlock());
+
+  // It's possible there are no defs, or we got handed the first def to start.
+  if (Defs) {
+    // If this is a def, we can just use the def iterators.
+    if (!isa<MemoryUse>(MA)) {
+      auto Iter = MA->getReverseDefsIterator();
+      ++Iter;
+      if (Iter != Defs->rend())
+        return &*Iter;
+    } else {
+      // Otherwise, we have to walk the all-access iterator.
+      auto Iter = MA->getReverseIterator();
+      ++Iter;
+      while (&*Iter != &*Defs->begin()) {
+        if (!isa<MemoryUse>(*Iter))
+          return &*Iter;
+        --Iter;
+      }
+      // At this point it must be pointing at the first def.
+      assert(&*Iter == &*Defs->begin() &&
+             "Should have hit first def walking backwards");
+      return &*Iter;
+    }
+  }
+  return nullptr;
+}
+
+// This starts at the end of the block.
+MemoryAccess *MemorySSAUpdater::getPreviousDefFromEnd(BasicBlock *BB) {
+  auto *Defs = MSSA->getWritableBlockDefs(BB);
+
+  if (Defs)
+    return &*Defs->rbegin();
+
+  return getPreviousDefRecursive(BB);
+}
+// Recurse over a set of phi uses to eliminate the trivial ones.
+MemoryAccess *MemorySSAUpdater::recursePhi(MemoryAccess *Phi) {
+  if (!Phi)
+    return nullptr;
+  TrackingVH<MemoryAccess> Res(Phi);
+  SmallVector<TrackingVH<Value>, 8> Uses;
+  std::copy(Phi->user_begin(), Phi->user_end(), std::back_inserter(Uses));
+  for (auto &U : Uses) {
+    if (MemoryPhi *UsePhi = dyn_cast<MemoryPhi>(&*U)) {
+      auto OperRange = UsePhi->operands();
+      tryRemoveTrivialPhi(UsePhi, OperRange);
+    }
+  }
+  return Res;
+}
+
+// Eliminate trivial phis.
+// Phis are trivial if they are defined either by themselves or by all the same
+// argument.
+// IE phi(a, a) or b = phi(a, b) or c = phi(a, a, c)
+// We recursively try to remove them.
+template <class RangeType>
+MemoryAccess *MemorySSAUpdater::tryRemoveTrivialPhi(MemoryPhi *Phi,
+                                                    RangeType &Operands) {
+  // Detect equal or self arguments.
+  MemoryAccess *Same = nullptr;
+  for (auto &Op : Operands) {
+    // If the same or self, good so far.
+    if (Op == Phi || Op == Same)
+      continue;
+    // Not the same, so return the phi since it's not eliminatable by us.
+    if (Same)
+      return Phi;
+    Same = cast<MemoryAccess>(Op);
+  }
+  // We never found a non-self reference; the phi is undef.
+  if (Same == nullptr)
+    return MSSA->getLiveOnEntryDef();
+  if (Phi) {
+    Phi->replaceAllUsesWith(Same);
+    removeMemoryAccess(Phi);
+  }
+
+  // We should only end up recursing in case we replaced something, in which
+  // case, we may have made other Phis trivial.
+  return recursePhi(Same);
+}
+
+void MemorySSAUpdater::insertUse(MemoryUse *MU) {
+  InsertedPHIs.clear();
+  MU->setDefiningAccess(getPreviousDef(MU));
+  // Unlike for defs, there is no extra work to do. Because uses do not create
+  // new may-defs, there are only two cases:
+  //
+  // 1. There was a def already below us, and therefore, we should not have
+  // created a phi node because it was already needed for the def.
+  //
+  // 2. There is no def below us, and therefore, there is no extra renaming work
+  // to do.
+}
+
+// Set every incoming edge {BB, MP->getBlock()} of MemoryPhi MP to NewDef.
+static void setMemoryPhiValueForBlock(MemoryPhi *MP, const BasicBlock *BB,
+                                      MemoryAccess *NewDef) {
+  // Replace any operand that has BB as its incoming block with the new
+  // defining access.
+  int i = MP->getBasicBlockIndex(BB);
+  assert(i != -1 && "Should have found the basic block in the phi");
+  // We can't just compare i against getNumOperands since one is signed and the
+  // other not. So use it to index into the block iterator.
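+  // Note that a block can appear as a predecessor more than once (e.g. several
+  // switch cases branching to the same successor), so walk forward from the
+  // first match and rewrite each consecutive incoming entry for BB.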
+ for (auto BBIter = MP->block_begin() + i; BBIter != MP->block_end(); + ++BBIter) { + if (*BBIter != BB) + break; + MP->setIncomingValue(i, NewDef); + ++i; + } +} + +// A brief description of the algorithm: +// First, we compute what should define the new def, using the SSA +// construction algorithm. +// Then, we update the defs below us (and any new phi nodes) in the graph to +// point to the correct new defs, to ensure we only have one variable, and no +// disconnected stores. +void MemorySSAUpdater::insertDef(MemoryDef *MD, bool RenameUses) { + InsertedPHIs.clear(); + + // See if we had a local def, and if not, go hunting. + MemoryAccess *DefBefore = getPreviousDefInBlock(MD); + bool DefBeforeSameBlock = DefBefore != nullptr; + if (!DefBefore) + DefBefore = getPreviousDefRecursive(MD->getBlock()); + + // There is a def before us, which means we can replace any store/phi uses + // of that thing with us, since we are in the way of whatever was there + // before. + // We now define that def's memorydefs and memoryphis + if (DefBeforeSameBlock) { + for (auto UI = DefBefore->use_begin(), UE = DefBefore->use_end(); + UI != UE;) { + Use &U = *UI++; + // Leave the uses alone + if (isa<MemoryUse>(U.getUser())) + continue; + U.set(MD); + } + } + + // and that def is now our defining access. + // We change them in this order otherwise we will appear in the use list + // above and reset ourselves. + MD->setDefiningAccess(DefBefore); + + SmallVector<MemoryAccess *, 8> FixupList(InsertedPHIs.begin(), + InsertedPHIs.end()); + if (!DefBeforeSameBlock) { + // If there was a local def before us, we must have the same effect it + // did. Because every may-def is the same, any phis/etc we would create, it + // would also have created. If there was no local def before us, we + // performed a global update, and have to search all successors and make + // sure we update the first def in each of them (following all paths until + // we hit the first def along each path). This may also insert phi nodes. + // TODO: There are other cases we can skip this work, such as when we have a + // single successor, and only used a straight line of single pred blocks + // backwards to find the def. To make that work, we'd have to track whether + // getDefRecursive only ever used the single predecessor case. These types + // of paths also only exist in between CFG simplifications. + FixupList.push_back(MD); + } + + while (!FixupList.empty()) { + unsigned StartingPHISize = InsertedPHIs.size(); + fixupDefs(FixupList); + FixupList.clear(); + // Put any new phis on the fixup list, and process them + FixupList.append(InsertedPHIs.end() - StartingPHISize, InsertedPHIs.end()); + } + // Now that all fixups are done, rename all uses if we are asked. + if (RenameUses) { + SmallPtrSet<BasicBlock *, 16> Visited; + BasicBlock *StartBlock = MD->getBlock(); + // We are guaranteed there is a def in the block, because we just got it + // handed to us in this function. + MemoryAccess *FirstDef = &*MSSA->getWritableBlockDefs(StartBlock)->begin(); + // Convert to incoming value if it's a memorydef. A phi *is* already an + // incoming value. + if (auto *MD = dyn_cast<MemoryDef>(FirstDef)) + FirstDef = MD->getDefiningAccess(); + + MSSA->renamePass(MD->getBlock(), FirstDef, Visited); + // We just inserted a phi into this block, so the incoming value will become + // the phi anyway, so it does not matter what we pass. 
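+    // Renaming from each inserted phi's block lets the accesses it dominates
+    // pick up the new phi as their defining access.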
+    for (auto *MP : InsertedPHIs)
+      MSSA->renamePass(MP->getBlock(), nullptr, Visited);
+  }
+}
+
+void MemorySSAUpdater::fixupDefs(const SmallVectorImpl<MemoryAccess *> &Vars) {
+  SmallPtrSet<const BasicBlock *, 8> Seen;
+  SmallVector<const BasicBlock *, 16> Worklist;
+  for (auto *NewDef : Vars) {
+    // First, see if there is a local def after the operand.
+    auto *Defs = MSSA->getWritableBlockDefs(NewDef->getBlock());
+    auto DefIter = NewDef->getDefsIterator();
+
+    // If there is a local def after us, we only have to rename that.
+    if (++DefIter != Defs->end()) {
+      cast<MemoryDef>(DefIter)->setDefiningAccess(NewDef);
+      continue;
+    }
+
+    // Otherwise, we need to search down through the CFG.
+    // For each of our successors, handle it directly if there is a phi, or
+    // place it on the fixup worklist.
+    for (const auto *S : successors(NewDef->getBlock())) {
+      if (auto *MP = MSSA->getMemoryAccess(S))
+        setMemoryPhiValueForBlock(MP, NewDef->getBlock(), NewDef);
+      else
+        Worklist.push_back(S);
+    }
+
+    while (!Worklist.empty()) {
+      const BasicBlock *FixupBlock = Worklist.back();
+      Worklist.pop_back();
+
+      // Get the first def in the block that isn't a phi node.
+      if (auto *Defs = MSSA->getWritableBlockDefs(FixupBlock)) {
+        auto *FirstDef = &*Defs->begin();
+        // The loops above and below should have taken care of phi nodes.
+        assert(!isa<MemoryPhi>(FirstDef) &&
+               "Should have already handled phi nodes!");
+        // We are now this def's defining access; make sure we actually
+        // dominate it.
+        assert(MSSA->dominates(NewDef, FirstDef) &&
+               "Should have dominated the new access");
+
+        // This may insert new phi nodes, because we are not guaranteed the
+        // block we are processing has a single pred, and depending on where the
+        // store was inserted, it may require phi nodes below it.
+        cast<MemoryDef>(FirstDef)->setDefiningAccess(getPreviousDef(FirstDef));
+        return;
+      }
+      // We didn't find a def, so we must continue.
+      for (const auto *S : successors(FixupBlock)) {
+        // If there is a phi node, handle it.
+        // Otherwise, put the block on the worklist.
+        if (auto *MP = MSSA->getMemoryAccess(S))
+          setMemoryPhiValueForBlock(MP, FixupBlock, NewDef);
+        else {
+          // If we cycle, we should have ended up at a phi node that we already
+          // processed. FIXME: Double check this
+          if (!Seen.insert(S).second)
+            continue;
+          Worklist.push_back(S);
+        }
+      }
+    }
+  }
+}
+
+// Move What before Where in the MemorySSA IR.
+template <class WhereType>
+void MemorySSAUpdater::moveTo(MemoryUseOrDef *What, BasicBlock *BB,
+                              WhereType Where) {
+  // Replace all our users with our defining access.
+  What->replaceAllUsesWith(What->getDefiningAccess());
+
+  // Let MemorySSA take care of moving it around in the lists.
+  MSSA->moveTo(What, BB, Where);
+
+  // Now reinsert it into the IR and do whatever fixups are needed.
+  if (auto *MD = dyn_cast<MemoryDef>(What))
+    insertDef(MD);
+  else
+    insertUse(cast<MemoryUse>(What));
+}
+
+// Move What before Where in the MemorySSA IR.
+void MemorySSAUpdater::moveBefore(MemoryUseOrDef *What, MemoryUseOrDef *Where) {
+  moveTo(What, Where->getBlock(), Where->getIterator());
+}
+
+// Move What after Where in the MemorySSA IR.
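+// For example, a pass that sinks a store past an adjacent load could call
+// moveAfter(StoreDef, LoadUse) once the instructions themselves have moved.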
+void MemorySSAUpdater::moveAfter(MemoryUseOrDef *What, MemoryUseOrDef *Where) { + moveTo(What, Where->getBlock(), ++Where->getIterator()); +} + +void MemorySSAUpdater::moveToPlace(MemoryUseOrDef *What, BasicBlock *BB, + MemorySSA::InsertionPlace Where) { + return moveTo(What, BB, Where); +} + +/// \brief If all arguments of a MemoryPHI are defined by the same incoming +/// argument, return that argument. +static MemoryAccess *onlySingleValue(MemoryPhi *MP) { + MemoryAccess *MA = nullptr; + + for (auto &Arg : MP->operands()) { + if (!MA) + MA = cast<MemoryAccess>(Arg); + else if (MA != Arg) + return nullptr; + } + return MA; +} + +void MemorySSAUpdater::removeMemoryAccess(MemoryAccess *MA) { + assert(!MSSA->isLiveOnEntryDef(MA) && + "Trying to remove the live on entry def"); + // We can only delete phi nodes if they have no uses, or we can replace all + // uses with a single definition. + MemoryAccess *NewDefTarget = nullptr; + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(MA)) { + // Note that it is sufficient to know that all edges of the phi node have + // the same argument. If they do, by the definition of dominance frontiers + // (which we used to place this phi), that argument must dominate this phi, + // and thus, must dominate the phi's uses, and so we will not hit the assert + // below. + NewDefTarget = onlySingleValue(MP); + assert((NewDefTarget || MP->use_empty()) && + "We can't delete this memory phi"); + } else { + NewDefTarget = cast<MemoryUseOrDef>(MA)->getDefiningAccess(); + } + + // Re-point the uses at our defining access + if (!isa<MemoryUse>(MA) && !MA->use_empty()) { + // Reset optimized on users of this store, and reset the uses. + // A few notes: + // 1. This is a slightly modified version of RAUW to avoid walking the + // uses twice here. + // 2. If we wanted to be complete, we would have to reset the optimized + // flags on users of phi nodes if doing the below makes a phi node have all + // the same arguments. Instead, we prefer users to removeMemoryAccess those + // phi nodes, because doing it here would be N^3. + if (MA->hasValueHandle()) + ValueHandleBase::ValueIsRAUWd(MA, NewDefTarget); + // Note: We assume MemorySSA is not used in metadata since it's not really + // part of the IR. 
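+    // For example, when removing '2 = MemoryDef(1)', a user such as
+    // '3 = MemoryDef(2)' is rewritten below to '3 = MemoryDef(1)' and has its
+    // cached optimized access cleared.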
+ + while (!MA->use_empty()) { + Use &U = *MA->use_begin(); + if (auto *MUD = dyn_cast<MemoryUseOrDef>(U.getUser())) + MUD->resetOptimized(); + U.set(NewDefTarget); + } + } + + // The call below to erase will destroy MA, so we can't change the order we + // are doing things here + MSSA->removeFromLookups(MA); + MSSA->removeFromLists(MA); +} + +MemoryAccess *MemorySSAUpdater::createMemoryAccessInBB( + Instruction *I, MemoryAccess *Definition, const BasicBlock *BB, + MemorySSA::InsertionPlace Point) { + MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); + MSSA->insertIntoListsForBlock(NewAccess, BB, Point); + return NewAccess; +} + +MemoryUseOrDef *MemorySSAUpdater::createMemoryAccessBefore( + Instruction *I, MemoryAccess *Definition, MemoryUseOrDef *InsertPt) { + assert(I->getParent() == InsertPt->getBlock() && + "New and old access must be in the same block"); + MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); + MSSA->insertIntoListsBefore(NewAccess, InsertPt->getBlock(), + InsertPt->getIterator()); + return NewAccess; +} + +MemoryUseOrDef *MemorySSAUpdater::createMemoryAccessAfter( + Instruction *I, MemoryAccess *Definition, MemoryAccess *InsertPt) { + assert(I->getParent() == InsertPt->getBlock() && + "New and old access must be in the same block"); + MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); + MSSA->insertIntoListsBefore(NewAccess, InsertPt->getBlock(), + ++InsertPt->getIterator()); + return NewAccess; +} diff --git a/contrib/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/contrib/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp index f5ba637e58e2..3253f27c010d 100644 --- a/contrib/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp +++ b/contrib/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp @@ -28,7 +28,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/ValueSymbolTable.h" -#include "llvm/Object/IRObjectFile.h" +#include "llvm/Object/ModuleSymbolTable.h" #include "llvm/Pass.h" using namespace llvm; @@ -37,7 +37,8 @@ using namespace llvm; // Walk through the operands of a given User via worklist iteration and populate // the set of GlobalValue references encountered. Invoked either on an // Instruction or a GlobalVariable (which walks its initializer). -static void findRefEdges(const User *CurUser, SetVector<ValueInfo> &RefEdges, +static void findRefEdges(ModuleSummaryIndex &Index, const User *CurUser, + SetVector<ValueInfo> &RefEdges, SmallPtrSet<const User *, 8> &Visited) { SmallVector<const User *, 32> Worklist; Worklist.push_back(CurUser); @@ -61,7 +62,7 @@ static void findRefEdges(const User *CurUser, SetVector<ValueInfo> &RefEdges, // the reference set unless it is a callee. Callees are handled // specially by WriteFunction and are added to a separate list. if (!(CS && CS.isCallee(&OI))) - RefEdges.insert(GV); + RefEdges.insert(Index.getOrInsertValueInfo(GV)); continue; } Worklist.push_back(Operand); @@ -84,6 +85,92 @@ static bool isNonRenamableLocal(const GlobalValue &GV) { return GV.hasSection() && GV.hasLocalLinkage(); } +/// Determine whether this call has all constant integer arguments (excluding +/// "this") and summarize it to VCalls or ConstVCalls as appropriate. +static void addVCallToSet(DevirtCallSite Call, GlobalValue::GUID Guid, + SetVector<FunctionSummary::VFuncId> &VCalls, + SetVector<FunctionSummary::ConstVCall> &ConstVCalls) { + std::vector<uint64_t> Args; + // Start from the second argument to skip the "this" pointer. 
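+  // For example, a devirtualizable call 'f(%this, i32 1, i32 2)' is summarized
+  // as a ConstVCall with Args = {1, 2}; any non-constant (or wider than
+  // 64-bit) argument downgrades it to a plain VCall instead.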
+ for (auto &Arg : make_range(Call.CS.arg_begin() + 1, Call.CS.arg_end())) { + auto *CI = dyn_cast<ConstantInt>(Arg); + if (!CI || CI->getBitWidth() > 64) { + VCalls.insert({Guid, Call.Offset}); + return; + } + Args.push_back(CI->getZExtValue()); + } + ConstVCalls.insert({{Guid, Call.Offset}, std::move(Args)}); +} + +/// If this intrinsic call requires that we add information to the function +/// summary, do so via the non-constant reference arguments. +static void addIntrinsicToSummary( + const CallInst *CI, SetVector<GlobalValue::GUID> &TypeTests, + SetVector<FunctionSummary::VFuncId> &TypeTestAssumeVCalls, + SetVector<FunctionSummary::VFuncId> &TypeCheckedLoadVCalls, + SetVector<FunctionSummary::ConstVCall> &TypeTestAssumeConstVCalls, + SetVector<FunctionSummary::ConstVCall> &TypeCheckedLoadConstVCalls) { + switch (CI->getCalledFunction()->getIntrinsicID()) { + case Intrinsic::type_test: { + auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1)); + auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata()); + if (!TypeId) + break; + GlobalValue::GUID Guid = GlobalValue::getGUID(TypeId->getString()); + + // Produce a summary from type.test intrinsics. We only summarize type.test + // intrinsics that are used other than by an llvm.assume intrinsic. + // Intrinsics that are assumed are relevant only to the devirtualization + // pass, not the type test lowering pass. + bool HasNonAssumeUses = llvm::any_of(CI->uses(), [](const Use &CIU) { + auto *AssumeCI = dyn_cast<CallInst>(CIU.getUser()); + if (!AssumeCI) + return true; + Function *F = AssumeCI->getCalledFunction(); + return !F || F->getIntrinsicID() != Intrinsic::assume; + }); + if (HasNonAssumeUses) + TypeTests.insert(Guid); + + SmallVector<DevirtCallSite, 4> DevirtCalls; + SmallVector<CallInst *, 4> Assumes; + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); + for (auto &Call : DevirtCalls) + addVCallToSet(Call, Guid, TypeTestAssumeVCalls, + TypeTestAssumeConstVCalls); + + break; + } + + case Intrinsic::type_checked_load: { + auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(2)); + auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata()); + if (!TypeId) + break; + GlobalValue::GUID Guid = GlobalValue::getGUID(TypeId->getString()); + + SmallVector<DevirtCallSite, 4> DevirtCalls; + SmallVector<Instruction *, 4> LoadedPtrs; + SmallVector<Instruction *, 4> Preds; + bool HasNonCallUses = false; + findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, + HasNonCallUses, CI); + // Any non-call uses of the result of llvm.type.checked.load will + // prevent us from optimizing away the llvm.type.test. 
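+      // In that case we also record a plain type test; the devirtualizable
+      // call uses themselves are summarized as VCalls/ConstVCalls below.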
+ if (HasNonCallUses) + TypeTests.insert(Guid); + for (auto &Call : DevirtCalls) + addVCallToSet(Call, Guid, TypeCheckedLoadVCalls, + TypeCheckedLoadConstVCalls); + + break; + } + default: + break; + } +} + static void computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, const Function &F, BlockFrequencyInfo *BFI, @@ -99,6 +186,10 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, MapVector<ValueInfo, CalleeInfo> CallGraphEdges; SetVector<ValueInfo> RefEdges; SetVector<GlobalValue::GUID> TypeTests; + SetVector<FunctionSummary::VFuncId> TypeTestAssumeVCalls, + TypeCheckedLoadVCalls; + SetVector<FunctionSummary::ConstVCall> TypeTestAssumeConstVCalls, + TypeCheckedLoadConstVCalls; ICallPromotionAnalysis ICallAnalysis; bool HasInlineAsmMaybeReferencingInternal = false; @@ -108,7 +199,7 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, if (isa<DbgInfoIntrinsic>(I)) continue; ++NumInsts; - findRefEdges(&I, RefEdges, Visited); + findRefEdges(Index, &I, RefEdges, Visited); auto CS = ImmutableCallSite(&I); if (!CS) continue; @@ -133,29 +224,15 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, // Check if this is a direct call to a known function or a known // intrinsic, or an indirect call with profile data. if (CalledFunction) { - if (CalledFunction->isIntrinsic()) { - if (CalledFunction->getIntrinsicID() != Intrinsic::type_test) - continue; - // Produce a summary from type.test intrinsics. We only summarize - // type.test intrinsics that are used other than by an llvm.assume - // intrinsic. Intrinsics that are assumed are relevant only to the - // devirtualization pass, not the type test lowering pass. - bool HasNonAssumeUses = llvm::any_of(CI->uses(), [](const Use &CIU) { - auto *AssumeCI = dyn_cast<CallInst>(CIU.getUser()); - if (!AssumeCI) - return true; - Function *F = AssumeCI->getCalledFunction(); - return !F || F->getIntrinsicID() != Intrinsic::assume; - }); - if (HasNonAssumeUses) { - auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1)); - if (auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata())) - TypeTests.insert(GlobalValue::getGUID(TypeId->getString())); - } + if (CI && CalledFunction->isIntrinsic()) { + addIntrinsicToSummary( + CI, TypeTests, TypeTestAssumeVCalls, TypeCheckedLoadVCalls, + TypeTestAssumeConstVCalls, TypeCheckedLoadConstVCalls); + continue; } // We should have named any anonymous globals assert(CalledFunction->hasName()); - auto ScaledCount = BFI ? BFI->getBlockProfileCount(&BB) : None; + auto ScaledCount = PSI->getProfileCount(&I, BFI); auto Hotness = ScaledCount ? getHotness(ScaledCount.getValue(), PSI) : CalleeInfo::HotnessType::Unknown; @@ -163,7 +240,9 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, // to record the call edge to the alias in that case. Eventually // an alias summary will be created to associate the alias and // aliasee. - CallGraphEdges[cast<GlobalValue>(CalledValue)].updateHotness(Hotness); + CallGraphEdges[Index.getOrInsertValueInfo( + cast<GlobalValue>(CalledValue))] + .updateHotness(Hotness); } else { // Skip inline assembly calls. 
         if (CI && CI->isInlineAsm())
@@ -178,11 +257,17 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
           ICallAnalysis.getPromotionCandidatesForInstruction(
               &I, NumVals, TotalCount, NumCandidates);
       for (auto &Candidate : CandidateProfileData)
-        CallGraphEdges[Candidate.Value].updateHotness(
-            getHotness(Candidate.Count, PSI));
+        CallGraphEdges[Index.getOrInsertValueInfo(Candidate.Value)]
+            .updateHotness(getHotness(Candidate.Count, PSI));
     }
   }
 
+  // Explicitly add hot edges to enforce importing for designated GUIDs for
+  // sample PGO, to enable the same inlines as the profiled optimized binary.
+  for (auto &I : F.getImportGUIDs())
+    CallGraphEdges[Index.getOrInsertValueInfo(I)].updateHotness(
+        CalleeInfo::HotnessType::Hot);
+
   bool NonRenamableLocal = isNonRenamableLocal(F);
   bool NotEligibleForImport =
       NonRenamableLocal || HasInlineAsmMaybeReferencingInternal ||
@@ -190,10 +275,13 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
       // FIXME: refactor this to use the same code that inliner is using.
       F.isVarArg();
   GlobalValueSummary::GVFlags Flags(F.getLinkage(), NotEligibleForImport,
-                                    /* LiveRoot = */ false);
+                                    /* Live = */ false);
   auto FuncSummary = llvm::make_unique<FunctionSummary>(
       Flags, NumInsts, RefEdges.takeVector(), CallGraphEdges.takeVector(),
-      TypeTests.takeVector());
+      TypeTests.takeVector(), TypeTestAssumeVCalls.takeVector(),
+      TypeCheckedLoadVCalls.takeVector(),
+      TypeTestAssumeConstVCalls.takeVector(),
+      TypeCheckedLoadConstVCalls.takeVector());
   if (NonRenamableLocal)
     CantBePromoted.insert(F.getGUID());
   Index.addGlobalValueSummary(F.getName(), std::move(FuncSummary));
@@ -204,10 +292,10 @@ computeVariableSummary(ModuleSummaryIndex &Index, const GlobalVariable &V,
                        DenseSet<GlobalValue::GUID> &CantBePromoted) {
   SetVector<ValueInfo> RefEdges;
   SmallPtrSet<const User *, 8> Visited;
-  findRefEdges(&V, RefEdges, Visited);
+  findRefEdges(Index, &V, RefEdges, Visited);
   bool NonRenamableLocal = isNonRenamableLocal(V);
   GlobalValueSummary::GVFlags Flags(V.getLinkage(), NonRenamableLocal,
-                                    /* LiveRoot = */ false);
+                                    /* Live = */ false);
   auto GVarSummary =
       llvm::make_unique<GlobalVarSummary>(Flags, RefEdges.takeVector());
   if (NonRenamableLocal)
@@ -220,7 +308,7 @@ computeAliasSummary(ModuleSummaryIndex &Index, const GlobalAlias &A,
                     DenseSet<GlobalValue::GUID> &CantBePromoted) {
   bool NonRenamableLocal = isNonRenamableLocal(A);
   GlobalValueSummary::GVFlags Flags(A.getLinkage(), NonRenamableLocal,
-                                    /* LiveRoot = */ false);
+                                    /* Live = */ false);
   auto AS = llvm::make_unique<AliasSummary>(Flags, ArrayRef<ValueInfo>{});
   auto *Aliasee = A.getBaseObject();
   auto *AliaseeSummary = Index.getGlobalValueSummary(*Aliasee);
@@ -233,18 +321,16 @@ computeAliasSummary(ModuleSummaryIndex &Index, const GlobalAlias &A,
 // Set LiveRoot flag on entries matching the given value name.
static void setLiveRoot(ModuleSummaryIndex &Index, StringRef Name) { - auto SummaryList = - Index.findGlobalValueSummaryList(GlobalValue::getGUID(Name)); - if (SummaryList == Index.end()) - return; - for (auto &Summary : SummaryList->second) - Summary->setLiveRoot(); + if (ValueInfo VI = Index.getValueInfo(GlobalValue::getGUID(Name))) + for (auto &Summary : VI.getSummaryList()) + Summary->setLive(true); } ModuleSummaryIndex llvm::buildModuleSummaryIndex( const Module &M, std::function<BlockFrequencyInfo *(const Function &F)> GetBFICallback, ProfileSummaryInfo *PSI) { + assert(PSI); ModuleSummaryIndex Index; // Identify the local values in the llvm.used and llvm.compiler.used sets, @@ -326,9 +412,8 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( // be listed on the llvm.used or llvm.compiler.used global and marked as // referenced from there. ModuleSymbolTable::CollectAsmSymbols( - Triple(M.getTargetTriple()), M.getModuleInlineAsm(), - [&M, &Index, &CantBePromoted](StringRef Name, - object::BasicSymbolRef::Flags Flags) { + M, [&M, &Index, &CantBePromoted](StringRef Name, + object::BasicSymbolRef::Flags Flags) { // Symbols not marked as Weak or Global are local definitions. if (Flags & (object::BasicSymbolRef::SF_Weak | object::BasicSymbolRef::SF_Global)) @@ -338,8 +423,8 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( return; assert(GV->isDeclaration() && "Def in module asm already has definition"); GlobalValueSummary::GVFlags GVFlags(GlobalValue::InternalLinkage, - /* NotEligibleToImport */ true, - /* LiveRoot */ true); + /* NotEligibleToImport = */ true, + /* Live = */ true); CantBePromoted.insert(GlobalValue::getGUID(Name)); // Create the appropriate summary type. if (isa<Function>(GV)) { @@ -347,7 +432,11 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( llvm::make_unique<FunctionSummary>( GVFlags, 0, ArrayRef<ValueInfo>{}, ArrayRef<FunctionSummary::EdgeTy>{}, - ArrayRef<GlobalValue::GUID>{}); + ArrayRef<GlobalValue::GUID>{}, + ArrayRef<FunctionSummary::VFuncId>{}, + ArrayRef<FunctionSummary::VFuncId>{}, + ArrayRef<FunctionSummary::ConstVCall>{}, + ArrayRef<FunctionSummary::ConstVCall>{}); Index.addGlobalValueSummary(Name, std::move(Summary)); } else { std::unique_ptr<GlobalVarSummary> Summary = @@ -359,12 +448,16 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( } for (auto &GlobalList : Index) { - assert(GlobalList.second.size() == 1 && + // Ignore entries for references that are undefined in the current module. + if (GlobalList.second.SummaryList.empty()) + continue; + + assert(GlobalList.second.SummaryList.size() == 1 && "Expected module's index to have one summary per GUID"); - auto &Summary = GlobalList.second[0]; + auto &Summary = GlobalList.second.SummaryList[0]; bool AllRefsCanBeExternallyReferenced = llvm::all_of(Summary->refs(), [&](const ValueInfo &VI) { - return !CantBePromoted.count(VI.getValue()->getGUID()); + return !CantBePromoted.count(VI.getGUID()); }); if (!AllRefsCanBeExternallyReferenced) { Summary->setNotEligibleToImport(); @@ -374,9 +467,7 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( if (auto *FuncSummary = dyn_cast<FunctionSummary>(Summary.get())) { bool AllCallsCanBeExternallyReferenced = llvm::all_of( FuncSummary->calls(), [&](const FunctionSummary::EdgeTy &Edge) { - auto GUID = Edge.first.isGUID() ? 
Edge.first.getGUID() - : Edge.first.getValue()->getGUID(); - return !CantBePromoted.count(GUID); + return !CantBePromoted.count(Edge.first.getGUID()); }); if (!AllCallsCanBeExternallyReferenced) Summary->setNotEligibleToImport(); diff --git a/contrib/llvm/lib/Analysis/OptimizationDiagnosticInfo.cpp b/contrib/llvm/lib/Analysis/OptimizationDiagnosticInfo.cpp index fa8b07d61b01..e38e530c052d 100644 --- a/contrib/llvm/lib/Analysis/OptimizationDiagnosticInfo.cpp +++ b/contrib/llvm/lib/Analysis/OptimizationDiagnosticInfo.cpp @@ -23,14 +23,14 @@ using namespace llvm; -OptimizationRemarkEmitter::OptimizationRemarkEmitter(Function *F) +OptimizationRemarkEmitter::OptimizationRemarkEmitter(const Function *F) : F(F), BFI(nullptr) { if (!F->getContext().getDiagnosticHotnessRequested()) return; // First create a dominator tree. DominatorTree DT; - DT.recalculate(*F); + DT.recalculate(*const_cast<Function *>(F)); // Generate LoopInfo from it. LoopInfo LI; @@ -45,6 +45,18 @@ OptimizationRemarkEmitter::OptimizationRemarkEmitter(Function *F) BFI = OwnedBFI.get(); } +bool OptimizationRemarkEmitter::invalidate( + Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + // This analysis has no state and so can be trivially preserved but it needs + // a fresh view of BFI if it was constructed with one. + if (BFI && Inv.invalidate<BlockFrequencyAnalysis>(F, PA)) + return true; + + // Otherwise this analysis result remains valid. + return false; +} + Optional<uint64_t> OptimizationRemarkEmitter::computeHotness(const Value *V) { if (!BFI) return None; @@ -55,53 +67,59 @@ Optional<uint64_t> OptimizationRemarkEmitter::computeHotness(const Value *V) { namespace llvm { namespace yaml { -template <> struct MappingTraits<DiagnosticInfoOptimizationBase *> { - static void mapping(IO &io, DiagnosticInfoOptimizationBase *&OptDiag) { - assert(io.outputting() && "input not yet implemented"); +void MappingTraits<DiagnosticInfoOptimizationBase *>::mapping( + IO &io, DiagnosticInfoOptimizationBase *&OptDiag) { + assert(io.outputting() && "input not yet implemented"); + + if (io.mapTag("!Passed", + (OptDiag->getKind() == DK_OptimizationRemark || + OptDiag->getKind() == DK_MachineOptimizationRemark))) + ; + else if (io.mapTag( + "!Missed", + (OptDiag->getKind() == DK_OptimizationRemarkMissed || + OptDiag->getKind() == DK_MachineOptimizationRemarkMissed))) + ; + else if (io.mapTag( + "!Analysis", + (OptDiag->getKind() == DK_OptimizationRemarkAnalysis || + OptDiag->getKind() == DK_MachineOptimizationRemarkAnalysis))) + ; + else if (io.mapTag("!AnalysisFPCommute", + OptDiag->getKind() == + DK_OptimizationRemarkAnalysisFPCommute)) + ; + else if (io.mapTag("!AnalysisAliasing", + OptDiag->getKind() == + DK_OptimizationRemarkAnalysisAliasing)) + ; + else if (io.mapTag("!Failure", OptDiag->getKind() == DK_OptimizationFailure)) + ; + else + llvm_unreachable("Unknown remark type"); - if (io.mapTag("!Passed", OptDiag->getKind() == DK_OptimizationRemark)) - ; - else if (io.mapTag("!Missed", - OptDiag->getKind() == DK_OptimizationRemarkMissed)) - ; - else if (io.mapTag("!Analysis", - OptDiag->getKind() == DK_OptimizationRemarkAnalysis)) - ; - else if (io.mapTag("!AnalysisFPCommute", - OptDiag->getKind() == - DK_OptimizationRemarkAnalysisFPCommute)) - ; - else if (io.mapTag("!AnalysisAliasing", - OptDiag->getKind() == - DK_OptimizationRemarkAnalysisAliasing)) - ; - else - llvm_unreachable("todo"); - - // These are read-only for now. 
- DebugLoc DL = OptDiag->getDebugLoc(); - StringRef FN = GlobalValue::getRealLinkageName( - OptDiag->getFunction().getName()); - - StringRef PassName(OptDiag->PassName); - io.mapRequired("Pass", PassName); - io.mapRequired("Name", OptDiag->RemarkName); - if (!io.outputting() || DL) - io.mapOptional("DebugLoc", DL); - io.mapRequired("Function", FN); - io.mapOptional("Hotness", OptDiag->Hotness); - io.mapOptional("Args", OptDiag->Args); - } -}; + // These are read-only for now. + DiagnosticLocation DL = OptDiag->getLocation(); + StringRef FN = + GlobalValue::dropLLVMManglingEscape(OptDiag->getFunction().getName()); + + StringRef PassName(OptDiag->PassName); + io.mapRequired("Pass", PassName); + io.mapRequired("Name", OptDiag->RemarkName); + if (!io.outputting() || DL.isValid()) + io.mapOptional("DebugLoc", DL); + io.mapRequired("Function", FN); + io.mapOptional("Hotness", OptDiag->Hotness); + io.mapOptional("Args", OptDiag->Args); +} -template <> struct MappingTraits<DebugLoc> { - static void mapping(IO &io, DebugLoc &DL) { +template <> struct MappingTraits<DiagnosticLocation> { + static void mapping(IO &io, DiagnosticLocation &DL) { assert(io.outputting() && "input not yet implemented"); - auto *Scope = cast<DIScope>(DL.getScope()); - StringRef File = Scope->getFilename(); + StringRef File = DL.getFilename(); unsigned Line = DL.getLine(); - unsigned Col = DL.getCol(); + unsigned Col = DL.getColumn(); io.mapRequired("File", File); io.mapRequired("Line", Line); @@ -116,8 +134,8 @@ template <> struct MappingTraits<DiagnosticInfoOptimizationBase::Argument> { static void mapping(IO &io, DiagnosticInfoOptimizationBase::Argument &A) { assert(io.outputting() && "input not yet implemented"); io.mapRequired(A.Key.data(), A.Val); - if (A.DLoc) - io.mapOptional("DebugLoc", A.DLoc); + if (A.Loc.isValid()) + io.mapOptional("DebugLoc", A.Loc); } }; @@ -127,18 +145,20 @@ template <> struct MappingTraits<DiagnosticInfoOptimizationBase::Argument> { LLVM_YAML_IS_SEQUENCE_VECTOR(DiagnosticInfoOptimizationBase::Argument) void OptimizationRemarkEmitter::computeHotness( - DiagnosticInfoOptimizationBase &OptDiag) { - Value *V = OptDiag.getCodeRegion(); + DiagnosticInfoIROptimization &OptDiag) { + const Value *V = OptDiag.getCodeRegion(); if (V) OptDiag.setHotness(computeHotness(V)); } -void OptimizationRemarkEmitter::emit(DiagnosticInfoOptimizationBase &OptDiag) { +void OptimizationRemarkEmitter::emit( + DiagnosticInfoOptimizationBase &OptDiagBase) { + auto &OptDiag = cast<DiagnosticInfoIROptimization>(OptDiagBase); computeHotness(OptDiag); yaml::Output *Out = F->getContext().getDiagnosticsOutputFile(); if (Out) { - auto *P = &const_cast<DiagnosticInfoOptimizationBase &>(OptDiag); + auto *P = const_cast<DiagnosticInfoOptimizationBase *>(&OptDiagBase); *Out << P; } // FIXME: now that IsVerbose is part of DI, filtering for this will be moved @@ -147,72 +167,6 @@ void OptimizationRemarkEmitter::emit(DiagnosticInfoOptimizationBase &OptDiag) { F->getContext().diagnose(OptDiag); } -void OptimizationRemarkEmitter::emitOptimizationRemark(const char *PassName, - const DebugLoc &DLoc, - const Value *V, - const Twine &Msg) { - LLVMContext &Ctx = F->getContext(); - Ctx.diagnose(OptimizationRemark(PassName, *F, DLoc, Msg, computeHotness(V))); -} - -void OptimizationRemarkEmitter::emitOptimizationRemark(const char *PassName, - Loop *L, - const Twine &Msg) { - emitOptimizationRemark(PassName, L->getStartLoc(), L->getHeader(), Msg); -} - -void OptimizationRemarkEmitter::emitOptimizationRemarkMissed( - const char *PassName, const 
DebugLoc &DLoc, const Value *V, - const Twine &Msg, bool IsVerbose) { - LLVMContext &Ctx = F->getContext(); - if (!IsVerbose || shouldEmitVerbose()) - Ctx.diagnose( - OptimizationRemarkMissed(PassName, *F, DLoc, Msg, computeHotness(V))); -} - -void OptimizationRemarkEmitter::emitOptimizationRemarkMissed( - const char *PassName, Loop *L, const Twine &Msg, bool IsVerbose) { - emitOptimizationRemarkMissed(PassName, L->getStartLoc(), L->getHeader(), Msg, - IsVerbose); -} - -void OptimizationRemarkEmitter::emitOptimizationRemarkAnalysis( - const char *PassName, const DebugLoc &DLoc, const Value *V, - const Twine &Msg, bool IsVerbose) { - LLVMContext &Ctx = F->getContext(); - if (!IsVerbose || shouldEmitVerbose()) - Ctx.diagnose( - OptimizationRemarkAnalysis(PassName, *F, DLoc, Msg, computeHotness(V))); -} - -void OptimizationRemarkEmitter::emitOptimizationRemarkAnalysis( - const char *PassName, Loop *L, const Twine &Msg, bool IsVerbose) { - emitOptimizationRemarkAnalysis(PassName, L->getStartLoc(), L->getHeader(), - Msg, IsVerbose); -} - -void OptimizationRemarkEmitter::emitOptimizationRemarkAnalysisFPCommute( - const char *PassName, const DebugLoc &DLoc, const Value *V, - const Twine &Msg) { - LLVMContext &Ctx = F->getContext(); - Ctx.diagnose(OptimizationRemarkAnalysisFPCommute(PassName, *F, DLoc, Msg, - computeHotness(V))); -} - -void OptimizationRemarkEmitter::emitOptimizationRemarkAnalysisAliasing( - const char *PassName, const DebugLoc &DLoc, const Value *V, - const Twine &Msg) { - LLVMContext &Ctx = F->getContext(); - Ctx.diagnose(OptimizationRemarkAnalysisAliasing(PassName, *F, DLoc, Msg, - computeHotness(V))); -} - -void OptimizationRemarkEmitter::emitOptimizationRemarkAnalysisAliasing( - const char *PassName, Loop *L, const Twine &Msg) { - emitOptimizationRemarkAnalysisAliasing(PassName, L->getStartLoc(), - L->getHeader(), Msg); -} - OptimizationRemarkEmitterWrapperPass::OptimizationRemarkEmitterWrapperPass() : FunctionPass(ID) { initializeOptimizationRemarkEmitterWrapperPassPass( diff --git a/contrib/llvm/lib/Analysis/OrderedBasicBlock.cpp b/contrib/llvm/lib/Analysis/OrderedBasicBlock.cpp index 0f0016f22cc0..a04c0aef04be 100644 --- a/contrib/llvm/lib/Analysis/OrderedBasicBlock.cpp +++ b/contrib/llvm/lib/Analysis/OrderedBasicBlock.cpp @@ -55,7 +55,7 @@ bool OrderedBasicBlock::comesBefore(const Instruction *A, assert(II != IE && "Instruction not found?"); assert((Inst == A || Inst == B) && "Should find A or B"); LastInstFound = II; - return Inst == A; + return Inst != B; } /// \brief Find out whether \p A dominates \p B, meaning whether \p A diff --git a/contrib/llvm/lib/Analysis/PHITransAddr.cpp b/contrib/llvm/lib/Analysis/PHITransAddr.cpp index 84ecd4ab9809..682af4dc708e 100644 --- a/contrib/llvm/lib/Analysis/PHITransAddr.cpp +++ b/contrib/llvm/lib/Analysis/PHITransAddr.cpp @@ -227,7 +227,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, // Simplify the GEP to handle 'gep x, 0' -> x etc. if (Value *V = SimplifyGEPInst(GEP->getSourceElementType(), - GEPOps, DL, TLI, DT, AC)) { + GEPOps, {DL, TLI, DT, AC})) { for (unsigned i = 0, e = GEPOps.size(); i != e; ++i) RemoveInstInputs(GEPOps[i], InstInputs); @@ -276,7 +276,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, } // See if the add simplifies away. - if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, DL, TLI, DT, AC)) { + if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, {DL, TLI, DT, AC})) { // If we simplified the operands, the LHS is no longer an input, but Res // is. 
RemoveInstInputs(LHS, InstInputs); diff --git a/contrib/llvm/lib/Analysis/PostDominators.cpp b/contrib/llvm/lib/Analysis/PostDominators.cpp index cb9438a2f928..1caf151546d9 100644 --- a/contrib/llvm/lib/Analysis/PostDominators.cpp +++ b/contrib/llvm/lib/Analysis/PostDominators.cpp @@ -31,6 +31,15 @@ char PostDominatorTreeWrapperPass::ID = 0; INITIALIZE_PASS(PostDominatorTreeWrapperPass, "postdomtree", "Post-Dominator Tree Construction", true, true) +bool PostDominatorTree::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. + auto PAC = PA.getChecker<PostDominatorTreeAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + bool PostDominatorTreeWrapperPass::runOnFunction(Function &F) { DT.recalculate(F); return false; diff --git a/contrib/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/contrib/llvm/lib/Analysis/ProfileSummaryInfo.cpp index 16d3614c14c6..12b86daa602b 100644 --- a/contrib/llvm/lib/Analysis/ProfileSummaryInfo.cpp +++ b/contrib/llvm/lib/Analysis/ProfileSummaryInfo.cpp @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/ProfileSummary.h" @@ -55,22 +56,43 @@ static uint64_t getMinCountForPercentile(SummaryEntryVector &DS, // The profile summary metadata may be attached either by the frontend or by // any backend passes (IR level instrumentation, for example). This method // checks if the Summary is null and if so checks if the summary metadata is now -// available in the module and parses it to get the Summary object. -void ProfileSummaryInfo::computeSummary() { +// available in the module and parses it to get the Summary object. Returns true +// if a valid Summary is available. +bool ProfileSummaryInfo::computeSummary() { if (Summary) - return; + return true; auto *SummaryMD = M.getProfileSummary(); if (!SummaryMD) - return; + return false; Summary.reset(ProfileSummary::getFromMD(SummaryMD)); + return true; +} + +Optional<uint64_t> +ProfileSummaryInfo::getProfileCount(const Instruction *Inst, + BlockFrequencyInfo *BFI) { + if (!Inst) + return None; + assert((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) && + "We can only get profile count for call/invoke instruction."); + if (hasSampleProfile()) { + // In sample PGO mode, check if there is a profile metadata on the + // instruction. If it is present, determine hotness solely based on that, + // since the sampled entry count may not be accurate. + uint64_t TotalCount; + if (Inst->extractProfTotalWeight(TotalCount)) + return TotalCount; + } + if (BFI) + return BFI->getBlockProfileCount(Inst->getParent()); + return None; } /// Returns true if the function's entry is hot. If it returns false, it /// either means it is not hot or it is unknown whether it is hot or not (for /// example, no profile data is available). 
bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) { - computeSummary(); - if (!F || !Summary) + if (!F || !computeSummary()) return false; auto FunctionCount = F->getEntryCount(); // FIXME: The heuristic used below for determining hotness is based on @@ -79,17 +101,53 @@ bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) { return FunctionCount && isHotCount(FunctionCount.getValue()); } +/// Returns true if the function's entry or total call edge count is hot. +/// If it returns false, it either means it is not hot or it is unknown +/// whether it is hot or not (for example, no profile data is available). +bool ProfileSummaryInfo::isFunctionHotInCallGraph(const Function *F) { + if (!F || !computeSummary()) + return false; + if (auto FunctionCount = F->getEntryCount()) + if (isHotCount(FunctionCount.getValue())) + return true; + + uint64_t TotalCallCount = 0; + for (const auto &BB : *F) + for (const auto &I : BB) + if (isa<CallInst>(I) || isa<InvokeInst>(I)) + if (auto CallCount = getProfileCount(&I, nullptr)) + TotalCallCount += CallCount.getValue(); + return isHotCount(TotalCallCount); +} + +/// Returns true if the function's entry and total call edge count is cold. +/// If it returns false, it either means it is not cold or it is unknown +/// whether it is cold or not (for example, no profile data is available). +bool ProfileSummaryInfo::isFunctionColdInCallGraph(const Function *F) { + if (!F || !computeSummary()) + return false; + if (auto FunctionCount = F->getEntryCount()) + if (!isColdCount(FunctionCount.getValue())) + return false; + + uint64_t TotalCallCount = 0; + for (const auto &BB : *F) + for (const auto &I : BB) + if (isa<CallInst>(I) || isa<InvokeInst>(I)) + if (auto CallCount = getProfileCount(&I, nullptr)) + TotalCallCount += CallCount.getValue(); + return isColdCount(TotalCallCount); +} + /// Returns true if the function's entry is a cold. If it returns false, it /// either means it is not cold or it is unknown whether it is cold or not (for /// example, no profile data is available). bool ProfileSummaryInfo::isFunctionEntryCold(const Function *F) { - computeSummary(); if (!F) return false; - if (F->hasFnAttribute(Attribute::Cold)) { + if (F->hasFnAttribute(Attribute::Cold)) return true; - } - if (!Summary) + if (!computeSummary()) return false; auto FunctionCount = F->getEntryCount(); // FIXME: The heuristic used below for determining coldness is based on @@ -100,9 +158,7 @@ bool ProfileSummaryInfo::isFunctionEntryCold(const Function *F) { /// Compute the hot and cold thresholds. void ProfileSummaryInfo::computeThresholds() { - if (!Summary) - computeSummary(); - if (!Summary) + if (!computeSummary()) return; auto &DetailedSummary = Summary->getDetailedSummary(); HotCountThreshold = @@ -125,20 +181,25 @@ bool ProfileSummaryInfo::isColdCount(uint64_t C) { bool ProfileSummaryInfo::isHotBB(const BasicBlock *B, BlockFrequencyInfo *BFI) { auto Count = BFI->getBlockProfileCount(B); - if (Count && isHotCount(*Count)) - return true; - // Use extractProfTotalWeight to get BB count. - // For Sample PGO, BFI may not provide accurate BB count due to errors - // magnified during sample count propagation. This serves as a backup plan - // to ensure all hot BB will not be missed. - // The query currently has false positives as branch instruction cloning does - // not update/scale branch weights. Unlike false negatives, this will not cause - // performance problem. 
- uint64_t TotalCount; - if (B->getTerminator()->extractProfTotalWeight(TotalCount) && - isHotCount(TotalCount)) - return true; - return false; + return Count && isHotCount(*Count); +} + +bool ProfileSummaryInfo::isColdBB(const BasicBlock *B, + BlockFrequencyInfo *BFI) { + auto Count = BFI->getBlockProfileCount(B); + return Count && isColdCount(*Count); +} + +bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS, + BlockFrequencyInfo *BFI) { + auto C = getProfileCount(CS.getInstruction(), BFI); + return C && isHotCount(*C); +} + +bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS, + BlockFrequencyInfo *BFI) { + auto C = getProfileCount(CS.getInstruction(), BFI); + return C && isColdCount(*C); } INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info", diff --git a/contrib/llvm/lib/Analysis/RegionInfo.cpp b/contrib/llvm/lib/Analysis/RegionInfo.cpp index 8c084ddd2266..63ef8d28d44a 100644 --- a/contrib/llvm/lib/Analysis/RegionInfo.cpp +++ b/contrib/llvm/lib/Analysis/RegionInfo.cpp @@ -83,6 +83,15 @@ RegionInfo::~RegionInfo() { } +bool RegionInfo::invalidate(Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &) { + // Check whether the analysis, all analyses on functions, or the function's + // CFG have been preserved. + auto PAC = PA.getChecker<RegionInfoAnalysis>(); + return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || + PAC.preservedSet<CFGAnalyses>()); +} + void RegionInfo::updateStatistics(Region *R) { ++numRegions; diff --git a/contrib/llvm/lib/Analysis/RegionPass.cpp b/contrib/llvm/lib/Analysis/RegionPass.cpp index 7358aa6810a1..b38e6225c840 100644 --- a/contrib/llvm/lib/Analysis/RegionPass.cpp +++ b/contrib/llvm/lib/Analysis/RegionPass.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/RegionPass.h" #include "llvm/Analysis/RegionIterator.h" +#include "llvm/IR/OptBisect.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Timer.h" #include "llvm/Support/raw_ostream.h" @@ -206,6 +207,8 @@ public: return false; } + + StringRef getPassName() const override { return "Print Region IR"; }
 };
 
 char PrintRegionPass::ID = 0;
@@ -278,3 +281,18 @@ Pass *RegionPass::createPrinterPass(raw_ostream &O,
                                     const std::string &Banner) const {
   return new PrintRegionPass(Banner, O);
 }
+
+bool RegionPass::skipRegion(Region &R) const {
+  Function &F = *R.getEntry()->getParent();
+  if (!F.getContext().getOptBisect().shouldRunPass(this, R))
+    return true;
+
+  if (F.hasFnAttribute(Attribute::OptimizeNone)) {
+    // Report this only once per function.
+    if (R.getEntry() == &F.getEntryBlock())
+      DEBUG(dbgs() << "Skipping pass '" << getPassName()
+                   << "' on function " << F.getName() << "\n");
+    return true;
+  }
+  return false;
+}
diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
index ed328f12c463..d96697cafbe9 100644
--- a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -89,6 +89,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/SaveAndRestore.h"
@@ -127,16 +128,35 @@ static cl::opt<unsigned> MulOpsInlineThreshold(
     cl::desc("Threshold for inlining multiplication operands into a SCEV"),
     cl::init(1000));
 
+static cl::opt<unsigned> AddOpsInlineThreshold(
+    "scev-addops-inline-threshold", cl::Hidden,
+    cl::desc("Threshold for inlining addition operands into a SCEV"),
+    cl::init(500));
+
 static cl::opt<unsigned> MaxSCEVCompareDepth(
     "scalar-evolution-max-scev-compare-depth", cl::Hidden,
     cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
     cl::init(32));
 
+static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
+    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
+    cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
+    cl::init(2));
+
 static cl::opt<unsigned> MaxValueCompareDepth(
     "scalar-evolution-max-value-compare-depth", cl::Hidden,
     cl::desc("Maximum depth of recursive value complexity comparisons"),
     cl::init(2));
 
+static cl::opt<unsigned>
+    MaxAddExprDepth("scalar-evolution-max-addexpr-depth", cl::Hidden,
+                    cl::desc("Maximum depth of recursive AddExpr"),
+                    cl::init(32));
+
+static cl::opt<unsigned> MaxConstantEvolvingDepth(
+    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
+    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
+
 //===----------------------------------------------------------------------===//
 // SCEV class definitions
 //===----------------------------------------------------------------------===//
@@ -145,11 +165,12 @@ static cl::opt<unsigned> MaxValueCompareDepth(
 // Implementation of the SCEV class.
 //
 
-LLVM_DUMP_METHOD
-void SCEV::dump() const {
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void SCEV::dump() const {
   print(dbgs());
   dbgs() << '\n';
 }
+#endif
 
 void SCEV::print(raw_ostream &OS) const {
   switch (static_cast<SCEVTypes>(getSCEVType())) {
@@ -563,7 +584,7 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
 static int CompareSCEVComplexity(
     SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV,
     const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
-    unsigned Depth = 0) {
+    DominatorTree &DT, unsigned Depth = 0) {
   // Fast-path: SCEVs are uniqued so we can do a quick equality check.
if (LHS == RHS) return 0; @@ -608,12 +629,19 @@ static int CompareSCEVComplexity( const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); - // Compare addrec loop depths. + // There is always a dominance between two recs that are used by one SCEV, + // so we can safely sort recs by loop header dominance. We require such + // order in getAddExpr. const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); if (LLoop != RLoop) { - unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth(); - if (LDepth != RDepth) - return (int)LDepth - (int)RDepth; + const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader(); + assert(LHead != RHead && "Two loops share the same header?"); + if (DT.dominates(LHead, RHead)) + return 1; + else + assert(DT.dominates(RHead, LHead) && + "No dominance between recurrences used by one SCEV?"); + return -1; } // Addrec complexity grows with operand count. @@ -624,7 +652,7 @@ static int CompareSCEVComplexity( // Lexicographically compare. for (unsigned i = 0; i != LNumOps; ++i) { int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i), - RA->getOperand(i), Depth + 1); + RA->getOperand(i), DT, Depth + 1); if (X != 0) return X; } @@ -648,7 +676,7 @@ static int CompareSCEVComplexity( if (i >= RNumOps) return 1; int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i), - RC->getOperand(i), Depth + 1); + RC->getOperand(i), DT, Depth + 1); if (X != 0) return X; } @@ -662,10 +690,10 @@ static int CompareSCEVComplexity( // Lexicographically compare udiv expressions. int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(), - Depth + 1); + DT, Depth + 1); if (X != 0) return X; - X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), + X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), DT, Depth + 1); if (X == 0) EqCacheSCEV.insert({LHS, RHS}); @@ -680,7 +708,7 @@ static int CompareSCEVComplexity( // Compare cast expressions by operand. int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(), - RC->getOperand(), Depth + 1); + RC->getOperand(), DT, Depth + 1); if (X == 0) EqCacheSCEV.insert({LHS, RHS}); return X; @@ -703,7 +731,7 @@ static int CompareSCEVComplexity( /// land in memory. /// static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, - LoopInfo *LI) { + LoopInfo *LI, DominatorTree &DT) { if (Ops.size() < 2) return; // Noop SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache; @@ -711,15 +739,16 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, // This is the common case, which also happens to be trivially simple. // Special case it. const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; - if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0) + if (CompareSCEVComplexity(EqCache, LI, RHS, LHS, DT) < 0) std::swap(LHS, RHS); return; } // Do the rough sort by complexity. 
std::stable_sort(Ops.begin(), Ops.end(), - [&EqCache, LI](const SCEV *LHS, const SCEV *RHS) { - return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0; + [&EqCache, LI, &DT](const SCEV *LHS, const SCEV *RHS) { + return + CompareSCEVComplexity(EqCache, LI, LHS, RHS, DT) < 0; }); // Now that we are sorted by complexity, group elements of the same @@ -1073,7 +1102,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, APInt Mult(W, i); unsigned TwoFactors = Mult.countTrailingZeros(); T += TwoFactors; - Mult = Mult.lshr(TwoFactors); + Mult.lshrInPlace(TwoFactors); OddFactorial *= Mult; } @@ -1256,7 +1285,8 @@ static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, namespace { struct ExtendOpTraitsBase { - typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *); + typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)( + const SCEV *, Type *, ScalarEvolution::ExtendCacheTy &Cache); }; // Used to make code generic over signed and unsigned overflow. @@ -1285,8 +1315,9 @@ struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase { } }; -const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< - SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr; +const ExtendOpTraitsBase::GetExtendExprTy + ExtendOpTraits<SCEVSignExtendExpr>::GetExtendExpr = + &ScalarEvolution::getSignExtendExprCached; template <> struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase { @@ -1301,8 +1332,9 @@ struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase { } }; -const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< - SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr; +const ExtendOpTraitsBase::GetExtendExprTy + ExtendOpTraits<SCEVZeroExtendExpr>::GetExtendExpr = + &ScalarEvolution::getZeroExtendExprCached; } // The recurrence AR has been shown to have no signed/unsigned wrap or something @@ -1314,7 +1346,8 @@ const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< // "sext/zext(PostIncAR)" template <typename ExtendOpTy> static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE) { + ScalarEvolution *SE, + ScalarEvolution::ExtendCacheTy &Cache) { auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType; auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; @@ -1361,9 +1394,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); const SCEV *OperandExtendedStart = - SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy), - (SE->*GetExtendExpr)(Step, WideTy)); - if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) { + SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Cache), + (SE->*GetExtendExpr)(Step, WideTy, Cache)); + if ((SE->*GetExtendExpr)(Start, WideTy, Cache) == OperandExtendedStart) { if (PreAR && AR->getNoWrapFlags(WrapType)) { // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then @@ -1388,15 +1421,17 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // Get the normalized zero or sign extended expression for this AddRec's Start. 
template <typename ExtendOpTy> static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE) { + ScalarEvolution *SE, + ScalarEvolution::ExtendCacheTy &Cache) { auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; - const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE); + const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Cache); if (!PreStart) - return (SE->*GetExtendExpr)(AR->getStart(), Ty); + return (SE->*GetExtendExpr)(AR->getStart(), Ty, Cache); - return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty), - (SE->*GetExtendExpr)(PreStart, Ty)); + return SE->getAddExpr( + (SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, Cache), + (SE->*GetExtendExpr)(PreStart, Ty, Cache)); } // Try to prove away overflow by looking at "nearby" add recurrences. A @@ -1476,8 +1511,31 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, return false; } -const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, - Type *Ty) { +const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) { + // Use the local cache to prevent exponential behavior of + // getZeroExtendExprImpl. + ExtendCacheTy Cache; + return getZeroExtendExprCached(Op, Ty, Cache); +} + +/// Query \p Cache before calling getZeroExtendExprImpl. If there is no +/// related entry in the \p Cache, call getZeroExtendExprImpl and save +/// the result in the \p Cache. +const SCEV *ScalarEvolution::getZeroExtendExprCached(const SCEV *Op, Type *Ty, + ExtendCacheTy &Cache) { + auto It = Cache.find({Op, Ty}); + if (It != Cache.end()) + return It->second; + const SCEV *ZExt = getZeroExtendExprImpl(Op, Ty, Cache); + auto InsertResult = Cache.insert({{Op, Ty}, ZExt}); + assert(InsertResult.second && "Expect the key was not in the cache"); + (void)InsertResult; + return ZExt; +} + +/// The real implementation of getZeroExtendExpr. +const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, + ExtendCacheTy &Cache) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1487,11 +1545,11 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) return getConstant( - cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty))); + cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty))); // zext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) - return getZeroExtendExpr(SZ->getOperand(), Ty); + return getZeroExtendExprCached(SZ->getOperand(), Ty, Cache); // Before doing any expensive analysis, check to see if we've already // computed a SCEV for this Op and Ty. @@ -1535,8 +1593,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // we don't need to do any further analysis. if (AR->hasNoUnsignedWrap()) return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache), + getZeroExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags()); // Check whether the backedge-taken count is SCEVCouldNotCompute. 
// Note that this serves two purposes: It filters out loops that are @@ -1561,21 +1619,22 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow. const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step); - const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy); - const SCEV *WideStart = getZeroExtendExpr(Start, WideTy); + const SCEV *ZAdd = + getZeroExtendExprCached(getAddExpr(Start, ZMul), WideTy, Cache); + const SCEV *WideStart = getZeroExtendExprCached(Start, WideTy, Cache); const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy); - const SCEV *OperandExtendedAdd = - getAddExpr(WideStart, - getMulExpr(WideMaxBECount, - getZeroExtendExpr(Step, WideTy))); + getZeroExtendExprCached(CastedMaxBECount, WideTy, Cache); + const SCEV *OperandExtendedAdd = getAddExpr( + WideStart, getMulExpr(WideMaxBECount, getZeroExtendExprCached( + Step, WideTy, Cache))); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NUW, which is propagated to this AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache), + getZeroExtendExprCached(Step, Ty, Cache), L, + AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as signed. // This covers loops that count down. @@ -1589,7 +1648,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1621,8 +1680,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache), + getZeroExtendExprCached(Step, Ty, Cache), L, + AR->getNoWrapFlags()); } } else if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - @@ -1637,7 +1697,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. 
return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache), getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1646,8 +1706,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) { const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); return getAddRecExpr( - getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), - getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache), + getZeroExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags()); } } @@ -1658,7 +1718,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, // commute the zero extension with the addition operation. SmallVector<const SCEV *, 4> Ops; for (const auto *Op : SA->operands()) - Ops.push_back(getZeroExtendExpr(Op, Ty)); + Ops.push_back(getZeroExtendExprCached(Op, Ty, Cache)); return getAddExpr(Ops, SCEV::FlagNUW); } } @@ -1672,8 +1732,31 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, return S; } -const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, - Type *Ty) { +const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) { + // Use the local cache to prevent exponential behavior of + // getSignExtendExprImpl. + ExtendCacheTy Cache; + return getSignExtendExprCached(Op, Ty, Cache); +} + +/// Query \p Cache before calling getSignExtendExprImpl. If there is no +/// related entry in the \p Cache, call getSignExtendExprImpl and save +/// the result in the \p Cache. +const SCEV *ScalarEvolution::getSignExtendExprCached(const SCEV *Op, Type *Ty, + ExtendCacheTy &Cache) { + auto It = Cache.find({Op, Ty}); + if (It != Cache.end()) + return It->second; + const SCEV *SExt = getSignExtendExprImpl(Op, Ty, Cache); + auto InsertResult = Cache.insert({{Op, Ty}, SExt}); + assert(InsertResult.second && "Expect the key was not in the cache"); + (void)InsertResult; + return SExt; +} + +/// The real implementation of getSignExtendExpr. +const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, + ExtendCacheTy &Cache) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1683,11 +1766,11 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) return getConstant( - cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty))); + cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty))); // sext(sext(x)) --> sext(x) if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) - return getSignExtendExpr(SS->getOperand(), Ty); + return getSignExtendExprCached(SS->getOperand(), Ty, Cache); // sext(zext(x)) --> zext(x) if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) @@ -1726,8 +1809,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) - return getAddExpr(getSignExtendExpr(SC1, Ty), - getSignExtendExpr(SMul, Ty)); + return getAddExpr(getSignExtendExprCached(SC1, Ty, Cache), + getSignExtendExprCached(SMul, Ty, Cache)); } } } @@ -1738,7 +1821,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // commute the sign extension with the addition operation. 
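[Editor's illustration, not part of the patch.] The getZeroExtendExprCached / getSignExtendExprCached functions introduced here are a plain memoization wrapper: look the (operand, type) key up in a per-query cache and fall through to the Impl routine only on a miss, so shared sub-expressions of a DAG-shaped input are processed once instead of once per path. A self-contained toy with the same Cached/Impl split (toy Expr type and a node-count computation standing in for the real extension logic; all names hypothetical):

#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

struct Expr {
  std::vector<const Expr *> Ops;      // children; shared nodes make this a DAG
};

using CacheTy = std::map<const Expr *, std::uint64_t>;

std::uint64_t sizeImpl(const Expr *E, CacheTy &Cache);

// The cached wrapper: consult the cache, otherwise do the real work and
// remember the result, mirroring getZeroExtendExprCached above.
std::uint64_t sizeCached(const Expr *E, CacheTy &Cache) {
  auto It = Cache.find(E);
  if (It != Cache.end())
    return It->second;
  std::uint64_t Result = sizeImpl(E, Cache);
  auto Inserted = Cache.insert({E, Result});
  assert(Inserted.second && "key was expected to be absent");
  (void)Inserted;
  return Result;
}

// The "real" implementation recurses through the cached entry point, exactly
// as getZeroExtendExprImpl recurses through getZeroExtendExprCached.
std::uint64_t sizeImpl(const Expr *E, CacheTy &Cache) {
  std::uint64_t N = 1;
  for (const Expr *Op : E->Ops)
    N += sizeCached(Op, Cache);
  return N;
}

The real cache is keyed on the (SCEV, Type) pair; the toy keys on the node alone only to keep the sketch short.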
SmallVector<const SCEV *, 4> Ops; for (const auto *Op : SA->operands()) - Ops.push_back(getSignExtendExpr(Op, Ty)); + Ops.push_back(getSignExtendExprCached(Op, Ty, Cache)); return getAddExpr(Ops, SCEV::FlagNSW); } } @@ -1762,8 +1845,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // we don't need to do any further analysis. if (AR->hasNoSignedWrap()) return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache), + getSignExtendExprCached(Step, Ty, Cache), L, SCEV::FlagNSW); // Check whether the backedge-taken count is SCEVCouldNotCompute. // Note that this serves two purposes: It filters out loops that are @@ -1788,21 +1871,22 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow. const SCEV *SMul = getMulExpr(CastedMaxBECount, Step); - const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy); - const SCEV *WideStart = getSignExtendExpr(Start, WideTy); + const SCEV *SAdd = + getSignExtendExprCached(getAddExpr(Start, SMul), WideTy, Cache); + const SCEV *WideStart = getSignExtendExprCached(Start, WideTy, Cache); const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy); - const SCEV *OperandExtendedAdd = - getAddExpr(WideStart, - getMulExpr(WideMaxBECount, - getSignExtendExpr(Step, WideTy))); + getZeroExtendExpr(CastedMaxBECount, WideTy); + const SCEV *OperandExtendedAdd = getAddExpr( + WideStart, getMulExpr(WideMaxBECount, getSignExtendExprCached( + Step, WideTy, Cache))); if (SAdd == OperandExtendedAdd) { // Cache knowledge of AR NSW, which is propagated to this AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache), + getSignExtendExprCached(Step, Ty, Cache), L, + AR->getNoWrapFlags()); } // Similar to above, only this time treat the step value as unsigned. // This covers loops that count up with an unsigned step. @@ -1823,7 +1907,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Return the expression with the addrec on the outside. return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } @@ -1855,8 +1939,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. 
const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache), + getSignExtendExprCached(Step, Ty, Cache), L, + AR->getNoWrapFlags()); } } @@ -1870,18 +1955,18 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) { - Start = getSignExtendExpr(Start, Ty); + Start = getSignExtendExprCached(Start, Ty, Cache); const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L, AR->getNoWrapFlags()); - return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); + return getAddExpr(Start, getSignExtendExprCached(NewAR, Ty, Cache)); } } if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) { const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), - getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache), + getSignExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags()); } } @@ -2093,9 +2178,66 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, return Flags; } +bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) { + if (!isLoopInvariant(S, L)) + return false; + // If a value depends on a SCEVUnknown which is defined after the loop, we + // conservatively assume that we cannot calculate it at the loop's entry. + struct FindDominatedSCEVUnknown { + bool Found = false; + const Loop *L; + DominatorTree &DT; + LoopInfo &LI; + + FindDominatedSCEVUnknown(const Loop *L, DominatorTree &DT, LoopInfo &LI) + : L(L), DT(DT), LI(LI) {} + + bool checkSCEVUnknown(const SCEVUnknown *SU) { + if (auto *I = dyn_cast<Instruction>(SU->getValue())) { + if (DT.dominates(L->getHeader(), I->getParent())) + Found = true; + else + assert(DT.dominates(I->getParent(), L->getHeader()) && + "No dominance relationship between SCEV and loop?"); + } + return false; + } + + bool follow(const SCEV *S) { + switch (static_cast<SCEVTypes>(S->getSCEVType())) { + case scConstant: + return false; + case scAddRecExpr: + case scTruncate: + case scZeroExtend: + case scSignExtend: + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: + case scUDivExpr: + return true; + case scUnknown: + return checkSCEVUnknown(cast<SCEVUnknown>(S)); + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + return false; + } + + bool isDone() { return Found; } + }; + + FindDominatedSCEVUnknown FSU(L, DT, LI); + SCEVTraversal<FindDominatedSCEVUnknown> ST(FSU); + ST.visitAll(S); + return !FSU.Found; +} + /// Get a canonical add expression, or something simpler if possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, - SCEV::NoWrapFlags Flags) { + SCEV::NoWrapFlags Flags, + unsigned Depth) { assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); @@ -2108,7 +2250,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, #endif // Sort by complexity, this groups all similar expression types together. 
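[Editor's illustration, not part of the patch.] FindDominatedSCEVUnknown in isAvailableAtLoopEntry above is a visitor with the follow()/isDone() contract that SCEVTraversal expects: follow() decides whether to descend into a node's operands, isDone() lets the walk stop as soon as an interesting leaf is found. A simplified recursive walk over an acyclic toy expression graph with the same shape (the real traversal is worklist-based; names here are made up):

#include <vector>

struct Node {
  bool IsLeafOfInterest = false;       // plays the role of a SCEVUnknown that
                                       // is defined inside the loop
  std::vector<const Node *> Ops;
};

// Visitor: follow() says whether to visit a node's operands, isDone() allows
// the traversal to terminate early.
struct FindInterestingLeaf {
  bool Found = false;
  bool follow(const Node *N) {
    if (N->IsLeafOfInterest) {
      Found = true;
      return false;                    // nothing useful below this leaf
    }
    return true;                       // keep walking into the operands
  }
  bool isDone() const { return Found; }
};

// Depth-first walk honoring the protocol above.
template <typename Visitor> void visitAll(const Node *Root, Visitor &V) {
  if (V.isDone() || !V.follow(Root))
    return;
  for (const Node *Op : Root->Ops) {
    if (V.isDone())
      return;
    visitAll(Op, V);
  }
}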
- GroupByComplexity(Ops, &LI); + GroupByComplexity(Ops, &LI, DT); Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); @@ -2134,6 +2276,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, if (Ops.size() == 1) return Ops[0]; } + // Limit recursion calls depth + if (Depth > MaxAddExprDepth) + return getOrCreateAddExpr(Ops, Flags); + // Okay, check to see if the same value occurs in the operand list more than // once. If so, merge them together into an multiply expression. Since we // sorted the list, these values are required to be adjacent. @@ -2205,7 +2351,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, } if (Ok) { // Evaluate the expression in the larger type. - const SCEV *Fold = getAddExpr(LargeOps, Flags); + const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1); // If it folds to something simple, use it. Otherwise, don't. if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) return getTruncateExpr(Fold, DstType); @@ -2220,6 +2366,9 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, if (Idx < Ops.size()) { bool DeletedAdd = false; while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) { + if (Ops.size() > AddOpsInlineThreshold || + Add->getNumOperands() > AddOpsInlineThreshold) + break; // If we have an add, expand the add operands onto the end of the operands // list. Ops.erase(Ops.begin()+Idx); @@ -2231,7 +2380,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // and they are not necessarily sorted. Recurse to resort and resimplify // any operands we just acquired. if (DeletedAdd) - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Skip over the add expression until we get to a multiply. @@ -2266,13 +2415,14 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops.push_back(getConstant(AccumulatedConstant)); for (auto &MulOp : MulOpLists) if (MulOp.first != 0) - Ops.push_back(getMulExpr(getConstant(MulOp.first), - getAddExpr(MulOp.second))); + Ops.push_back(getMulExpr( + getConstant(MulOp.first), + getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1))); if (Ops.empty()) return getZero(Ty); if (Ops.size() == 1) return Ops[0]; - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } } @@ -2297,8 +2447,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); InnerMul = getMulExpr(MulOps); } - const SCEV *One = getOne(Ty); - const SCEV *AddOne = getAddExpr(One, InnerMul); + SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul}; + const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { @@ -2309,7 +2459,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops.erase(Ops.begin()+AddOp-1); } Ops.push_back(OuterMul); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Check this multiply against other multiplies being added together. 
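[Editor's illustration, not part of the patch.] Threading a Depth parameter through getAddExpr and bailing out to getOrCreateAddExpr once Depth exceeds MaxAddExprDepth is a simple way to bound the cost of a self-recursive simplifier: past the cutoff the expression is still built correctly, just without further folding. A toy version of the same guard (hypothetical names; the real MaxAddExprDepth is a command-line option):

#include <vector>

// Hypothetical cut-off, analogous to MaxAddExprDepth above.
static const unsigned MaxDepth = 32;

struct Expr {
  bool IsAdd = false;
  std::vector<const Expr *> Ops;   // operands when IsAdd is true
};

// Flatten nested additions into one operand list, but stop descending once
// the recursion exceeds MaxDepth: deeper adds are kept as opaque operands
// instead of being folded further.  The result is still correct, merely less
// simplified -- the same trade-off getAddExpr makes when it falls back to
// getOrCreateAddExpr.
void flattenAdd(const Expr *E, std::vector<const Expr *> &Out,
                unsigned Depth = 0) {
  if (E->IsAdd && Depth <= MaxDepth) {
    for (const Expr *Op : E->Ops)
      flattenAdd(Op, Out, Depth + 1);
    return;
  }
  Out.push_back(E);                // leaf, or an add beyond the depth limit
}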
@@ -2337,13 +2487,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); InnerMul2 = getMulExpr(MulOps); } - const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); + SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2}; + const SCEV *InnerMulSum = + getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); Ops.push_back(OuterMul); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } } } @@ -2363,7 +2515,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (isLoopInvariant(Ops[i], AddRecLoop)) { + if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; @@ -2379,7 +2531,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // This follows from the fact that the no-wrap flags on the outer add // expression are applicable on the 0th iteration, when the add recurrence // will be equal to its start value. - AddRecOps[0] = getAddExpr(LIOps, Flags); + AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer add and the inner addrec are guaranteed to have no overflow. @@ -2396,7 +2548,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, Ops[i] = NewRec; break; } - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } // Okay, if there weren't any loop invariants to be folded, check to see if @@ -2404,31 +2556,40 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // added together. If so, we can fold them. for (unsigned OtherIdx = Idx+1; OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); - ++OtherIdx) + ++OtherIdx) { + // We expect the AddRecExpr's to be sorted in reverse dominance order, + // so that the 1st found AddRecExpr is dominated by all others. 
+ assert(DT.dominates( + cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(), + AddRec->getLoop()->getHeader()) && + "AddRecExprs are not sorted in reverse dominance order?"); if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) { // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L> SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), AddRec->op_end()); for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); - ++OtherIdx) - if (const auto *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx])) - if (OtherAddRec->getLoop() == AddRecLoop) { - for (unsigned i = 0, e = OtherAddRec->getNumOperands(); - i != e; ++i) { - if (i >= AddRecOps.size()) { - AddRecOps.append(OtherAddRec->op_begin()+i, - OtherAddRec->op_end()); - break; - } - AddRecOps[i] = getAddExpr(AddRecOps[i], - OtherAddRec->getOperand(i)); + ++OtherIdx) { + const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]); + if (OtherAddRec->getLoop() == AddRecLoop) { + for (unsigned i = 0, e = OtherAddRec->getNumOperands(); + i != e; ++i) { + if (i >= AddRecOps.size()) { + AddRecOps.append(OtherAddRec->op_begin()+i, + OtherAddRec->op_end()); + break; } - Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + SmallVector<const SCEV *, 2> TwoOps = { + AddRecOps[i], OtherAddRec->getOperand(i)}; + AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); } + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + } + } // Step size has changed, so we cannot guarantee no self-wraparound. Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); - return getAddExpr(Ops); + return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); } + } // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. @@ -2436,18 +2597,24 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. + return getOrCreateAddExpr(Ops, Flags); +} + +const SCEV * +ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops, + SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; ID.AddInteger(scAddExpr); for (unsigned i = 0, e = Ops.size(); i != e; ++i) ID.AddPointer(Ops[i]); void *IP = nullptr; SCEVAddExpr *S = - static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); - S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), - O, Ops.size()); + S = new (SCEVAllocator) + SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); } S->setNoWrapFlags(Flags); @@ -2519,7 +2686,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, #endif // Sort by complexity, this groups all similar expression types together. 
- GroupByComplexity(Ops, &LI); + GroupByComplexity(Ops, &LI, DT); Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); @@ -2623,7 +2790,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (isLoopInvariant(Ops[i], AddRecLoop)) { + if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; @@ -2875,7 +3042,7 @@ static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { else if (ABW < BBW) A = A.zext(BBW); - return APIntOps::GreatestCommonDivisor(A, B); + return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B)); } /// Get a canonical unsigned division expression, or something simpler if @@ -2889,7 +3056,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, // end of this file for inspiration. const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS); - if (!Mul) + if (!Mul || !Mul->hasNoUnsignedWrap()) return getUDivExpr(LHS, RHS); if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) { @@ -3116,7 +3283,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { #endif // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, &LI); + GroupByComplexity(Ops, &LI, DT); // If there are any constants, fold them together. unsigned Idx = 0; @@ -3217,7 +3384,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { #endif // Sort by complexity, this groups all similar expression types together. - GroupByComplexity(Ops, &LI); + GroupByComplexity(Ops, &LI, DT); // If there are any constants, fold them together. unsigned Idx = 0; @@ -3385,6 +3552,10 @@ Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { return getDataLayout().getIntPtrType(Ty); } +Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const { + return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2; +} + const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } @@ -3770,7 +3941,7 @@ public: : SCEVRewriteVisitor(SE), L(L), Valid(true) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { - if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + if (!SE.isLoopInvariant(Expr, L)) Valid = false; return Expr; } @@ -3804,7 +3975,7 @@ public: const SCEV *visitUnknown(const SCEVUnknown *Expr) { // Only allow AddRecExprs for this loop. - if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + if (!SE.isLoopInvariant(Expr, L)) Valid = false; return Expr; } @@ -3909,9 +4080,9 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { case Instruction::Xor: if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1))) - // If the RHS of the xor is a signbit, then this is just an add. - // Instcombine turns add of signbit into xor as a strength reduction step. - if (RHSC->getValue().isSignBit()) + // If the RHS of the xor is a signmask, then this is just an add. + // Instcombine turns add of signmask into xor as a strength reduction step. + if (RHSC->getValue().isSignMask()) return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); return BinaryOp(Op); @@ -3984,6 +4155,56 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { return None; } +/// A helper function for createAddRecFromPHI to handle simple cases. 
+/// +/// This function tries to find an AddRec expression for the simplest (yet most +/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)). +/// If it fails, createAddRecFromPHI will use a more general, but slow, +/// technique for finding the AddRec expression. +const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, + Value *BEValueV, + Value *StartValueV) { + const Loop *L = LI.getLoopFor(PN->getParent()); + assert(L && L->getHeader() == PN->getParent()); + assert(BEValueV && StartValueV); + + auto BO = MatchBinaryOp(BEValueV, DT); + if (!BO) + return nullptr; + + if (BO->Opcode != Instruction::Add) + return nullptr; + + const SCEV *Accum = nullptr; + if (BO->LHS == PN && L->isLoopInvariant(BO->RHS)) + Accum = getSCEV(BO->RHS); + else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS)) + Accum = getSCEV(BO->LHS); + + if (!Accum) + return nullptr; + + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + if (BO->IsNUW) + Flags = setFlags(Flags, SCEV::FlagNUW); + if (BO->IsNSW) + Flags = setFlags(Flags, SCEV::FlagNSW); + + const SCEV *StartVal = getSCEV(StartValueV); + const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + + ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; + + // We can add Flags to the post-inc expression only if we + // know that it is *undefined behavior* for BEValueV to + // overflow. + if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) + if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) + (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + + return PHISCEV; +} + const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) @@ -4009,127 +4230,134 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { break; } } - if (BEValueV && StartValueV) { - // While we are analyzing this PHI node, handle its value symbolically. - const SCEV *SymbolicName = getUnknown(PN); - assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && - "PHI node already processed?"); - ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName}); + if (!BEValueV || !StartValueV) + return nullptr; - // Using this symbolic name for the PHI, analyze the value coming around - // the back-edge. - const SCEV *BEValue = getSCEV(BEValueV); + assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && + "PHI node already processed?"); - // NOTE: If BEValue is loop invariant, we know that the PHI node just - // has a special value for the first iteration of the loop. + // First, try to find AddRec expression without creating a fictituos symbolic + // value for PN. + if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV)) + return S; - // If the value coming around the backedge is an add with the symbolic - // value we just inserted, then we found a simple induction variable! - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) { - // If there is a single occurrence of the symbolic value, replace it - // with a recurrence. - unsigned FoundIndex = Add->getNumOperands(); + // Handle PHI node value symbolically. + const SCEV *SymbolicName = getUnknown(PN); + ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName}); + + // Using this symbolic name for the PHI, analyze the value coming around + // the back-edge. + const SCEV *BEValue = getSCEV(BEValueV); + + // NOTE: If BEValue is loop invariant, we know that the PHI node just + // has a special value for the first iteration of the loop. 
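[Editor's illustration, not part of the patch.] The fast path above matches the most common induction-variable shape directly: a header PHI whose back-edge value is the PHI itself plus a loop-invariant step. The source-level loop below is the kind of code it is aimed at; with a positive, loop-invariant Step the PHI becomes the add recurrence {Start,+,Step}<L> without going through the symbolic-name machinery:

// The IR for this loop contains a header PHI
//   i      = phi [ Start, preheader ], [ i.next, latch ]
//   i.next = add i, Step             ; Step defined outside the loop
// which is exactly the PN = PHI(Start, OP(Self, LoopInvariant)) form matched
// by createSimpleAffineAddRec.
long sumEveryStep(const long *A, long Start, long Step, long N) {
  long S = 0;
  for (long i = Start; i < N; i += Step)   // i <-> {Start,+,Step}<loop>
    S += A[i];
  return S;
}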
+ + // If the value coming around the backedge is an add with the symbolic + // value we just inserted, then we found a simple induction variable! + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) { + // If there is a single occurrence of the symbolic value, replace it + // with a recurrence. + unsigned FoundIndex = Add->getNumOperands(); + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if (Add->getOperand(i) == SymbolicName) + if (FoundIndex == e) { + FoundIndex = i; + break; + } + + if (FoundIndex != Add->getNumOperands()) { + // Create an add with everything but the specified operand. + SmallVector<const SCEV *, 8> Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) - if (Add->getOperand(i) == SymbolicName) - if (FoundIndex == e) { - FoundIndex = i; - break; + if (i != FoundIndex) + Ops.push_back(Add->getOperand(i)); + const SCEV *Accum = getAddExpr(Ops); + + // This is not a valid addrec if the step amount is varying each + // loop iteration, but is not itself an addrec in this loop. + if (isLoopInvariant(Accum, L) || + (isa<SCEVAddRecExpr>(Accum) && + cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + + if (auto BO = MatchBinaryOp(BEValueV, DT)) { + if (BO->Opcode == Instruction::Add && BO->LHS == PN) { + if (BO->IsNUW) + Flags = setFlags(Flags, SCEV::FlagNUW); + if (BO->IsNSW) + Flags = setFlags(Flags, SCEV::FlagNSW); } - - if (FoundIndex != Add->getNumOperands()) { - // Create an add with everything but the specified operand. - SmallVector<const SCEV *, 8> Ops; - for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) - if (i != FoundIndex) - Ops.push_back(Add->getOperand(i)); - const SCEV *Accum = getAddExpr(Ops); - - // This is not a valid addrec if the step amount is varying each - // loop iteration, but is not itself an addrec in this loop. - if (isLoopInvariant(Accum, L) || - (isa<SCEVAddRecExpr>(Accum) && - cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { - SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; - - if (auto BO = MatchBinaryOp(BEValueV, DT)) { - if (BO->Opcode == Instruction::Add && BO->LHS == PN) { - if (BO->IsNUW) - Flags = setFlags(Flags, SCEV::FlagNUW); - if (BO->IsNSW) - Flags = setFlags(Flags, SCEV::FlagNSW); - } - } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { - // If the increment is an inbounds GEP, then we know the address - // space cannot be wrapped around. We cannot make any guarantee - // about signed or unsigned overflow because pointers are - // unsigned but we may have a negative index from the base - // pointer. We can guarantee that no unsigned wrap occurs if the - // indices form a positive value. - if (GEP->isInBounds() && GEP->getOperand(0) == PN) { - Flags = setFlags(Flags, SCEV::FlagNW); - - const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); - if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) - Flags = setFlags(Flags, SCEV::FlagNUW); - } - - // We cannot transfer nuw and nsw flags from subtraction - // operations -- sub nuw X, Y is not the same as add nuw X, -Y - // for instance. + } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { + // If the increment is an inbounds GEP, then we know the address + // space cannot be wrapped around. We cannot make any guarantee + // about signed or unsigned overflow because pointers are + // unsigned but we may have a negative index from the base + // pointer. We can guarantee that no unsigned wrap occurs if the + // indices form a positive value. 
+ if (GEP->isInBounds() && GEP->getOperand(0) == PN) { + Flags = setFlags(Flags, SCEV::FlagNW); + + const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); + if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) + Flags = setFlags(Flags, SCEV::FlagNUW); } - const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + // We cannot transfer nuw and nsw flags from subtraction + // operations -- sub nuw X, Y is not the same as add nuw X, -Y + // for instance. + } - // Okay, for the entire analysis of this edge we assumed the PHI - // to be symbolic. We now need to go back and purge all of the - // entries for the scalars that use the symbolic expression. - forgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; + const SCEV *StartVal = getSCEV(StartValueV); + const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); - // We can add Flags to the post-inc expression only if we - // know that it us *undefined behavior* for BEValueV to - // overflow. - if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) - if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) - (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. + forgetSymbolicName(PN, SymbolicName); + ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; - return PHISCEV; - } + // We can add Flags to the post-inc expression only if we + // know that it is *undefined behavior* for BEValueV to + // overflow. + if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) + if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) + (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + + return PHISCEV; } - } else { - // Otherwise, this could be a loop like this: - // i = 0; for (j = 1; ..; ++j) { .... i = j; } - // In this case, j = {1,+,1} and BEValue is j. - // Because the other in-value of i (0) fits the evolution of BEValue - // i really is an addrec evolution. - // - // We can generalize this saying that i is the shifted value of BEValue - // by one iteration: - // PHI(f(0), f({1,+,1})) --> f({0,+,1}) - const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); - const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); - if (Shifted != getCouldNotCompute() && - Start != getCouldNotCompute()) { - const SCEV *StartVal = getSCEV(StartValueV); - if (Start == StartVal) { - // Okay, for the entire analysis of this edge we assumed the PHI - // to be symbolic. We now need to go back and purge all of the - // entries for the scalars that use the symbolic expression. - forgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; - return Shifted; - } + } + } else { + // Otherwise, this could be a loop like this: + // i = 0; for (j = 1; ..; ++j) { .... i = j; } + // In this case, j = {1,+,1} and BEValue is j. + // Because the other in-value of i (0) fits the evolution of BEValue + // i really is an addrec evolution. 
+ // + // We can generalize this saying that i is the shifted value of BEValue + // by one iteration: + // PHI(f(0), f({1,+,1})) --> f({0,+,1}) + const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); + const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); + if (Shifted != getCouldNotCompute() && + Start != getCouldNotCompute()) { + const SCEV *StartVal = getSCEV(StartValueV); + if (Start == StartVal) { + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. + forgetSymbolicName(PN, SymbolicName); + ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; + return Shifted; } } - - // Remove the temporary PHI node SCEV that has been inserted while intending - // to create an AddRecExpr for this PHI node. We can not keep this temporary - // as it will prevent later (possibly simpler) SCEV expressions to be added - // to the ValueExprMap. - eraseValueFromMap(PN); } + // Remove the temporary PHI node SCEV that has been inserted while intending + // to create an AddRecExpr for this PHI node. We can not keep this temporary + // as it will prevent later (possibly simpler) SCEV expressions to be added + // to the ValueExprMap. + eraseValueFromMap(PN); + return nullptr; } @@ -4289,7 +4517,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC)) + if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC})) if (LI.replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -4409,8 +4637,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { return getGEPExpr(GEP, IndexExprs); } -uint32_t -ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { +uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) return C->getAPInt().countTrailingZeros(); @@ -4420,14 +4647,16 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? - getTypeSizeInBits(E->getType()) : OpRes; + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) + ? getTypeSizeInBits(E->getType()) + : OpRes; } if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? - getTypeSizeInBits(E->getType()) : OpRes; + return OpRes == getTypeSizeInBits(E->getOperand()->getType()) + ? 
getTypeSizeInBits(E->getType()) + : OpRes; } if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { @@ -4444,8 +4673,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { uint32_t BitWidth = getTypeSizeInBits(M->getType()); for (unsigned i = 1, e = M->getNumOperands(); SumOpRes != BitWidth && i != e; ++i) - SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), - BitWidth); + SumOpRes = + std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); return SumOpRes; } @@ -4475,17 +4704,25 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { // For a SCEVUnknown, ask ValueTracking. - unsigned BitWidth = getTypeSizeInBits(U->getType()); - APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC, - nullptr, &DT); - return Zeros.countTrailingOnes(); + KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT); + return Known.countMinTrailingZeros(); } // SCEVUDivExpr return 0; } +uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { + auto I = MinTrailingZerosCache.find(S); + if (I != MinTrailingZerosCache.end()) + return I->second; + + uint32_t Result = GetMinTrailingZerosImpl(S); + auto InsertPair = MinTrailingZerosCache.insert({S, Result}); + assert(InsertPair.second && "Should insert a new key"); + return InsertPair.first->second; +} + /// Helper method to assign a range to V from metadata present in the IR. static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) @@ -4632,7 +4869,7 @@ ScalarEvolution::getRange(const SCEV *S, } } - return setRange(AddRec, SignHint, ConservativeResult); + return setRange(AddRec, SignHint, std::move(ConservativeResult)); } if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { @@ -4647,11 +4884,11 @@ ScalarEvolution::getRange(const SCEV *S, const DataLayout &DL = getDataLayout(); if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { // For a SCEVUnknown, ask ValueTracking. - APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, &AC, nullptr, &DT); - if (Ones != ~Zeros + 1) + KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT); + if (Known.One != ~Known.Zero + 1) ConservativeResult = - ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)); + ConservativeResult.intersectWith(ConstantRange(Known.One, + ~Known.Zero + 1)); } else { assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED && "generalize as needed!"); @@ -4662,10 +4899,78 @@ ScalarEvolution::getRange(const SCEV *S, APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1)); } - return setRange(U, SignHint, ConservativeResult); + return setRange(U, SignHint, std::move(ConservativeResult)); } - return setRange(S, SignHint, ConservativeResult); + return setRange(S, SignHint, std::move(ConservativeResult)); +} + +// Given a StartRange, Step and MaxBECount for an expression compute a range of +// values that the expression can take. Initially, the expression has a value +// from StartRange and then is changed by Step up to MaxBECount times. Signed +// argument defines if we treat Step as signed or unsigned. +static ConstantRange getRangeForAffineARHelper(APInt Step, + const ConstantRange &StartRange, + const APInt &MaxBECount, + unsigned BitWidth, bool Signed) { + // If either Step or MaxBECount is 0, then the expression won't change, and we + // just need to return the initial range. 
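[Editor's illustration, not part of the patch.] KnownBits packages the two masks the older interface passed separately: Known.Zero has a bit set where the value is proved 0, Known.One where it is proved 1, and the minimum number of trailing zeros used above is just the run of proved-zero low bits. A simplified fixed-width stand-in (plain uint64_t masks rather than the real APInt-based class):

#include <cstdint>

// Simplified 64-bit stand-in for llvm::KnownBits.
struct KnownBits64 {
  std::uint64_t Zero = 0;   // bit i set => bit i of the value is known to be 0
  std::uint64_t One = 0;    // bit i set => bit i of the value is known to be 1

  // Number of low bits that are provably zero: the length of the run of set
  // bits at the bottom of Zero.
  unsigned countMinTrailingZeros() const {
    unsigned N = 0;
    std::uint64_t Z = Zero;
    while (N < 64 && (Z & 1)) {
      ++N;
      Z >>= 1;
    }
    return N;
  }
};

// Example: a value known to be a multiple of 8 has its three low bits proved
// zero, so countMinTrailingZeros() == 3.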
+ if (Step == 0 || MaxBECount == 0) + return StartRange; + + // If we don't know anything about the initial value (i.e. StartRange is + // FullRange), then we don't know anything about the final range either. + // Return FullRange. + if (StartRange.isFullSet()) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + // If Step is signed and negative, then we use its absolute value, but we also + // note that we're moving in the opposite direction. + bool Descending = Signed && Step.isNegative(); + + if (Signed) + // This is correct even for INT_SMIN. Let's look at i8 to illustrate this: + // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128. + // This equations hold true due to the well-defined wrap-around behavior of + // APInt. + Step = Step.abs(); + + // Check if Offset is more than full span of BitWidth. If it is, the + // expression is guaranteed to overflow. + if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount)) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + // Offset is by how much the expression can change. Checks above guarantee no + // overflow here. + APInt Offset = Step * MaxBECount; + + // Minimum value of the final range will match the minimal value of StartRange + // if the expression is increasing and will be decreased by Offset otherwise. + // Maximum value of the final range will match the maximal value of StartRange + // if the expression is decreasing and will be increased by Offset otherwise. + APInt StartLower = StartRange.getLower(); + APInt StartUpper = StartRange.getUpper() - 1; + APInt MovedBoundary = Descending ? (StartLower - std::move(Offset)) + : (StartUpper + std::move(Offset)); + + // It's possible that the new minimum/maximum value will fall into the initial + // range (due to wrap around). This means that the expression can take any + // value in this bitwidth, and we have to return full range. + if (StartRange.contains(MovedBoundary)) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + APInt NewLower = + Descending ? std::move(MovedBoundary) : std::move(StartLower); + APInt NewUpper = + Descending ? std::move(StartUpper) : std::move(MovedBoundary); + NewUpper += 1; + + // If we end up with full range, return a proper full range. + if (NewLower == NewUpper) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + // No overflow detected, return [StartLower, StartUpper + Offset + 1) range. + return ConstantRange(std::move(NewLower), std::move(NewUpper)); } ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, @@ -4676,60 +4981,30 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && "Precondition!"); - ConstantRange Result(BitWidth, /* isFullSet = */ true); - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange ZExtMaxBECountRange = MaxBECountRange.zextOrTrunc(BitWidth * 2); + APInt MaxBECountValue = MaxBECountRange.getUnsignedMax(); + // First, consider step signed. 
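[Editor's illustration, not part of the patch.] The arithmetic in getRangeForAffineARHelper can be checked with concrete numbers: starting somewhere in StartRange and moving by at most Step per iteration for at most MaxBECount iterations shifts one end of the range by Step * MaxBECount, unless that offset can wrap, in which case the only safe answer is the full set. A deliberately simplified unsigned-only version over closed 64-bit intervals (the real helper also handles signed, possibly negative steps and wrapping ConstantRanges; names are hypothetical):

#include <cstdint>
#include <optional>

using U64 = std::uint64_t;

struct Range { U64 Min, Max; };                 // closed interval, Min <= Max

// The value starts somewhere in Start and grows by at most Step per
// iteration, for at most MaxBECount iterations.  Returns std::nullopt for
// "full set" (no useful bound; the computation may wrap).
std::optional<Range> rangeForAffineAR(Range Start, U64 Step, U64 MaxBECount) {
  if (Step == 0 || MaxBECount == 0)
    return Start;                               // the value never changes
  // If Step * MaxBECount does not fit in 64 bits, the value can wrap all the
  // way around, so no bound better than the full set exists.
  if (MaxBECount > UINT64_MAX / Step)
    return std::nullopt;
  U64 Offset = Step * MaxBECount;
  // If adding the offset to the maximum overflows, the result may wrap back
  // around (the real helper does a finer wrap-into-the-start-range check).
  if (Offset > UINT64_MAX - Start.Max)
    return std::nullopt;
  return Range{Start.Min, Start.Max + Offset};
}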
+ ConstantRange StartSRange = getSignedRange(Start); ConstantRange StepSRange = getSignedRange(Step); - ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2); - - ConstantRange StartURange = getUnsignedRange(Start); - ConstantRange EndURange = - StartURange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for unsigned overflow. - ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2); - ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2); - if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - ZExtEndURange) { - APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), - EndURange.getUnsignedMin()); - APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), - EndURange.getUnsignedMax()); - bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); - if (!IsFullRange) - Result = - Result.intersectWith(ConstantRange(Min, Max + 1)); - } - ConstantRange StartSRange = getSignedRange(Start); - ConstantRange EndSRange = - StartSRange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for signed overflow. This must be done with ConstantRange - // arithmetic because we could be called from within the ScalarEvolution - // overflow checking code. - ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2); - ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2); - if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - SExtEndSRange) { - APInt Min = - APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin()); - APInt Max = - APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax()); - bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); - if (!IsFullRange) - Result = - Result.intersectWith(ConstantRange(Min, Max + 1)); - } + // If Step can be both positive and negative, we need to find ranges for the + // maximum absolute step values in both directions and union them. + ConstantRange SR = + getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange, + MaxBECountValue, BitWidth, /* Signed = */ true); + SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(), + StartSRange, MaxBECountValue, + BitWidth, /* Signed = */ true)); - return Result; + // Next, consider step unsigned. + ConstantRange UR = getRangeForAffineARHelper( + getUnsignedRange(Step).getUnsignedMax(), getUnsignedRange(Start), + MaxBECountValue, BitWidth, /* Signed = */ false); + + // Finally, intersect signed and unsigned ranges. + return SR.intersectWith(UR); } ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, @@ -4875,7 +5150,8 @@ bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { return false; // Only proceed if we can prove that I does not yield poison. - if (!isKnownNotFullPoison(I)) return false; + if (!programUndefinedIfFullPoison(I)) + return false; // At this point we know that if I is executed, then it does not wrap // according to at least one of NSW or NUW. 
If I is not executed, then we do @@ -5141,19 +5417,34 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned LZ = A.countLeadingZeros(); unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(), + KnownBits Known(BitWidth); + computeKnownBits(BO->LHS, Known, getDataLayout(), 0, &AC, nullptr, &DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); - if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { - const SCEV *MulCount = getConstant(ConstantInt::get( - getContext(), APInt::getOneBitSet(BitWidth, TZ))); + if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) { + const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); + const SCEV *LHS = getSCEV(BO->LHS); + const SCEV *ShiftedLHS = nullptr; + if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) { + if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) { + // For an expression like (x * 8) & 8, simplify the multiply. + unsigned MulZeros = OpC->getAPInt().countTrailingZeros(); + unsigned GCD = std::min(MulZeros, TZ); + APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); + SmallVector<const SCEV*, 4> MulOps; + MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD))); + MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end()); + auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); + ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); + } + } + if (!ShiftedLHS) + ShiftedLHS = getUDivExpr(LHS, MulCount); return getMulExpr( getZeroExtendExpr( - getTruncateExpr( - getUDivExactExpr(getSCEV(BO->LHS), MulCount), + getTruncateExpr(ShiftedLHS, IntegerType::get(getContext(), BitWidth - LZ - TZ)), BO->LHS->getType()), MulCount); @@ -5211,7 +5502,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // If C is a low-bits mask, the zero extend is serving to // mask off the high bits. Complement the operand and // re-apply the zext. - if (APIntOps::isMask(Z0TySize, CI->getValue())) + if (CI->getValue().isMask(Z0TySize)) return getZeroExtendExpr(getNotSCEV(Z0), UTy); // If C is a single bit, it may be in the sign-bit position @@ -5219,7 +5510,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // using an add, which is equivalent, and re-apply the zext. APInt Trunc = CI->getValue().trunc(Z0TySize); if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && - Trunc.isSignBit()) + Trunc.isSignMask()) return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), UTy); } @@ -5255,28 +5546,55 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case Instruction::AShr: - // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) - if (Operator *L = dyn_cast<Operator>(BO->LHS)) - if (L->getOpcode() == Instruction::Shl && - L->getOperand(1) == BO->RHS) { - uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (CI->getValue().uge(BitWidth)) - break; + // AShr X, C, where C is a constant. 
+ ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS); + if (!CI) + break; + + Type *OuterTy = BO->LHS->getType(); + uint64_t BitWidth = getTypeSizeInBits(OuterTy); + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. + if (CI->getValue().uge(BitWidth)) + break; - uint64_t Amt = BitWidth - CI->getZExtValue(); - if (Amt == BitWidth) - return getSCEV(L->getOperand(0)); // shift by zero --> noop + if (CI->isNullValue()) + return getSCEV(BO->LHS); // shift by zero --> noop + + uint64_t AShrAmt = CI->getZExtValue(); + Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); + + Operator *L = dyn_cast<Operator>(BO->LHS); + if (L && L->getOpcode() == Instruction::Shl) { + // X = Shl A, n + // Y = AShr X, m + // Both n and m are constant. + + const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); + if (L->getOperand(1) == BO->RHS) + // For a two-shift sext-inreg, i.e. n = m, + // use sext(trunc(x)) as the SCEV expression. + return getSignExtendExpr( + getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy); + + ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1)); + if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) { + uint64_t ShlAmt = ShlAmtCI->getZExtValue(); + if (ShlAmt > AShrAmt) { + // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV + // expression. We already checked that ShlAmt < BitWidth, so + // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as + // ShlAmt - AShrAmt < Amt. + APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, + ShlAmt - AShrAmt); return getSignExtendExpr( - getTruncateExpr(getSCEV(L->getOperand(0)), - IntegerType::get(getContext(), Amt)), - BO->LHS->getType()); + getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy), + getConstant(Mul)), OuterTy); } + } + } break; } } @@ -5348,7 +5666,7 @@ static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { return ((unsigned)ExitConst->getZExtValue()) + 1; } -unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { +unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripCount(L, ExitingBB); @@ -5356,7 +5674,7 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { return 0; } -unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, +unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && @@ -5366,13 +5684,13 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, return getConstantTripCount(ExitCount); } -unsigned ScalarEvolution::getSmallConstantMaxTripCount(Loop *L) { +unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) { const auto *MaxExitCount = dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L)); return getConstantTripCount(MaxExitCount); } -unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { +unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { if (BasicBlock *ExitingBB = L->getExitingBlock()) return getSmallConstantTripMultiple(L, ExitingBB); @@ -5393,7 +5711,7 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { /// As explained in the comments for getSmallConstantTripCount, this assumes /// that control exits the loop via ExitingBlock. 
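[Editor's illustration, not part of the patch.] The Shl/AShr rewrite above rests on an identity that can be brute-forced at a small bit width: for m <= n < BitWidth, ashr(shl(x, n), m) equals sext(trunc(x, BitWidth - m) * 2^(n - m)), which is exactly the expression the code builds (and the n == m case degenerates to sext(trunc(x))). A throwaway 8-bit exhaustive check, written with plain integers and made-up helper names:

#include <cassert>

// Sign-extend the low 'bits' bits of v to a plain int (bits <= 31).
static int signExtend(unsigned v, unsigned bits) {
  unsigned mask = (1u << bits) - 1;
  v &= mask;
  return (v & (1u << (bits - 1))) ? int(v) - int(mask + 1) : int(v);
}

// Arithmetic right shift of an 8-bit value: logical shift, then fill the
// vacated high bits with copies of the sign bit.
static int ashr8(unsigned x, unsigned m) {
  unsigned shifted = (x >> m) & 0xFF;
  if (x & 0x80)
    shifted |= (0xFFu << (8 - m)) & 0xFF;
  return signExtend(shifted, 8);
}

int main() {
  const unsigned BW = 8;
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned n = 0; n < BW; ++n)
      for (unsigned m = 0; m <= n; ++m) {
        // Left side: ashr(shl(a, n), m), all within 8 bits.
        int lhs = ashr8((a << n) & 0xFF, m);
        // Right side: sext(trunc(a, 8 - m) * 2^(n - m)), the rewritten form.
        unsigned tbits = BW - m;
        unsigned prod = (a & ((1u << tbits) - 1)) << (n - m);
        int rhs = signExtend(prod, tbits);
        assert(lhs == rhs && "the shl/ashr rewrite is exact");
      }
  return 0;
}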
unsigned -ScalarEvolution::getSmallConstantTripMultiple(Loop *L, +ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && @@ -5403,17 +5721,16 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, return 1; // Get the trip count from the BE count by adding 1. - const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType())); - // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt - // to factor simple cases. - if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul)) - TCMul = Mul->getOperand(0); - - const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul); - if (!MulC) - return 1; + const SCEV *TCExpr = getAddExpr(ExitCount, getOne(ExitCount->getType())); + + const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr); + if (!TC) + // Attempt to factor more general cases. Returns the greatest power of + // two divisor. If overflow happens, the trip count expression is still + // divisible by the greatest power of 2 divisor returned. + return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr)); - ConstantInt *Result = MulC->getValue(); + ConstantInt *Result = TC->getValue(); // Guard against huge trip counts (this requires checking // for zero to handle the case where the trip count == -1 and the @@ -5428,7 +5745,8 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, /// Get the expression for the number of loop iterations for which this loop is /// guaranteed not to exit via ExitingBlock. Otherwise return /// SCEVCouldNotCompute. -const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { +const SCEV *ScalarEvolution::getExitCount(const Loop *L, + BasicBlock *ExitingBlock) { return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); } @@ -5681,6 +5999,8 @@ ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax()) return SE->getCouldNotCompute(); + assert((isa<SCEVCouldNotCompute>(getMax()) || isa<SCEVConstant>(getMax())) && + "No point in having a non-constant max backedge taken count!"); return getMax(); } @@ -5705,6 +6025,45 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, return false; } +ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) + : ExactNotTaken(E), MaxNotTaken(E), MaxOrZero(false) { + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); +} + +ScalarEvolution::ExitLimit::ExitLimit( + const SCEV *E, const SCEV *M, bool MaxOrZero, + ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList) + : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) { + assert((isa<SCEVCouldNotCompute>(ExactNotTaken) || + !isa<SCEVCouldNotCompute>(MaxNotTaken)) && + "Exact is not allowed to be less precise than Max"); + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); + for (auto *PredSet : PredSetList) + for (auto *P : *PredSet) + addPredicate(P); +} + +ScalarEvolution::ExitLimit::ExitLimit( + const SCEV *E, const SCEV *M, bool MaxOrZero, + const SmallPtrSetImpl<const SCEVPredicate *> &PredSet) + : ExitLimit(E, M, MaxOrZero, {&PredSet}) { + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken 
count!"); +} + +ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M, + bool MaxOrZero) + : ExitLimit(E, M, MaxOrZero, None) { + assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || + isa<SCEVConstant>(MaxNotTaken)) && + "No point in having a non-constant max backedge taken count!"); +} + /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( @@ -5728,6 +6087,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate)); }); + assert((isa<SCEVCouldNotCompute>(MaxCount) || isa<SCEVConstant>(MaxCount)) && + "No point in having a non-constant max backedge taken count!"); } /// Invalidate this result and free the ExitNotTakenInfo array. @@ -5886,24 +6247,74 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, return getCouldNotCompute(); } -ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimitFromCond(const Loop *L, - Value *ExitCond, - BasicBlock *TBB, - BasicBlock *FBB, - bool ControlsExit, - bool AllowPredicates) { +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond( + const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, + bool ControlsExit, bool AllowPredicates) { + ScalarEvolution::ExitLimitCacheTy Cache(L, TBB, FBB, AllowPredicates); + return computeExitLimitFromCondCached(Cache, L, ExitCond, TBB, FBB, + ControlsExit, AllowPredicates); +} + +Optional<ScalarEvolution::ExitLimit> +ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond, + BasicBlock *TBB, BasicBlock *FBB, + bool ControlsExit, bool AllowPredicates) { + (void)this->L; + (void)this->TBB; + (void)this->FBB; + (void)this->AllowPredicates; + + assert(this->L == L && this->TBB == TBB && this->FBB == FBB && + this->AllowPredicates == AllowPredicates && + "Variance in assumed invariant key components!"); + auto Itr = TripCountMap.find({ExitCond, ControlsExit}); + if (Itr == TripCountMap.end()) + return None; + return Itr->second; +} + +void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond, + BasicBlock *TBB, BasicBlock *FBB, + bool ControlsExit, + bool AllowPredicates, + const ExitLimit &EL) { + assert(this->L == L && this->TBB == TBB && this->FBB == FBB && + this->AllowPredicates == AllowPredicates && + "Variance in assumed invariant key components!"); + + auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL}); + assert(InsertResult.second && "Expected successful insertion!"); + (void)InsertResult; +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, BasicBlock *TBB, + BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) { + + if (auto MaybeEL = + Cache.find(L, ExitCond, TBB, FBB, ControlsExit, AllowPredicates)) + return *MaybeEL; + + ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, TBB, FBB, + ControlsExit, AllowPredicates); + Cache.insert(L, ExitCond, TBB, FBB, ControlsExit, AllowPredicates, EL); + return EL; +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, BasicBlock *TBB, + BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. 
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); - ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit, - AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit, - AllowPredicates); + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit, + AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -5939,7 +6350,7 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, // to not. if (isa<SCEVCouldNotCompute>(MaxBECount) && !isa<SCEVCouldNotCompute>(BECount)) - MaxBECount = BECount; + MaxBECount = getConstant(getUnsignedRange(BECount).getUnsignedMax()); return ExitLimit(BECount, MaxBECount, false, {&EL0.Predicates, &EL1.Predicates}); @@ -5947,12 +6358,12 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); - ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit, - AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit, - AllowPredicates); + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit, + AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -6337,13 +6748,12 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most // bitwidth(K) iterations. Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); - bool KnownZero, KnownOne; - ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr, - Predecessor->getTerminator(), &DT); + KnownBits Known = computeKnownBits(FirstValue, DL, 0, nullptr, + Predecessor->getTerminator(), &DT); auto *Ty = cast<IntegerType>(RHS->getType()); - if (KnownZero) + if (Known.isNonNegative()) StableValue = ConstantInt::get(Ty, 0); - else if (KnownOne) + else if (Known.isNegative()) StableValue = ConstantInt::get(Ty, -1, true); else return getCouldNotCompute(); @@ -6408,7 +6818,10 @@ static bool canConstantEvolve(Instruction *I, const Loop *L) { /// recursing through each instruction operand until reaching a loop header phi. static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, - DenseMap<Instruction *, PHINode *> &PHIMap) { + DenseMap<Instruction *, PHINode *> &PHIMap, + unsigned Depth) { + if (Depth > MaxConstantEvolvingDepth) + return nullptr; // Otherwise, we can evaluate this instruction if all of its operands are // constant or derived from a PHI node themselves. @@ -6428,7 +6841,7 @@ getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, if (!P) { // Recurse and memoize the results, whether a phi is found or not. 
// This recursive call invalidates pointers into PHIMap. - P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap); + P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1); PHIMap[OpInst] = P; } if (!P) @@ -6455,7 +6868,7 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { // Record non-constant instructions contained by the loop. DenseMap<Instruction *, PHINode *> PHIMap; - return getConstantEvolvingPHIOperands(I, L, PHIMap); + return getConstantEvolvingPHIOperands(I, L, PHIMap, 0); } /// EvaluateExpression - Given an expression that passes the @@ -7014,10 +7427,10 @@ const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { /// A and B isn't important. /// /// If the equation does not have a solution, SCEVCouldNotCompute is returned. -static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, +static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, ScalarEvolution &SE) { uint32_t BW = A.getBitWidth(); - assert(BW == B.getBitWidth() && "Bit widths must be the same."); + assert(BW == SE.getTypeSizeInBits(B->getType())); assert(A != 0 && "A must be non-zero."); // 1. D = gcd(A, N) @@ -7031,7 +7444,7 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, // // B is divisible by D if and only if the multiplicity of prime factor 2 for B // is not less than multiplicity of this prime factor for D. - if (B.countTrailingZeros() < Mult2) + if (SE.GetMinTrailingZeros(B) < Mult2) return SE.getCouldNotCompute(); // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic @@ -7049,9 +7462,8 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, // I * (B / D) mod (N / D) // To simplify the computation, we factor out the divide by D: // (I * B mod N) / D - APInt Result = (I * B).lshr(Mult2); - - return SE.getConstant(Result); + const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2)); + return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D); } /// Find the roots of the quadratic equation for the given quadratic chrec @@ -7074,50 +7486,50 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { const APInt &M = MC->getAPInt(); const APInt &N = NC->getAPInt(); APInt Two(BitWidth, 2); - APInt Four(BitWidth, 4); - - { - using namespace APIntOps; - const APInt& C = L; - // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C - // The B coefficient is M-N/2 - APInt B(M); - B -= sdiv(N,Two); - - // The A coefficient is N/2 - APInt A(N.sdiv(Two)); - - // Compute the B^2-4ac term. - APInt SqrtTerm(B); - SqrtTerm *= B; - SqrtTerm -= Four * (A * C); - - if (SqrtTerm.isNegative()) { - // The loop is provably infinite. - return None; - } - // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest - // integer value or else APInt::sqrt() will assert. - APInt SqrtVal(SqrtTerm.sqrt()); + // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C - // Compute the two solutions for the quadratic formula. - // The divisions must be performed as signed divisions. - APInt NegB(-B); - APInt TwoA(A << 1); - if (TwoA.isMinValue()) - return None; + // The A coefficient is N/2 + APInt A = N.sdiv(Two); + + // The B coefficient is M-N/2 + APInt B = M; + B -= A; // A is the same as N/2. + + // The C coefficient is L. + const APInt& C = L; - LLVMContext &Context = SE.getContext(); + // Compute the B^2-4ac term. 
+ APInt SqrtTerm = B; + SqrtTerm *= B; + SqrtTerm -= 4 * (A * C); - ConstantInt *Solution1 = - ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); - ConstantInt *Solution2 = - ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + if (SqrtTerm.isNegative()) { + // The loop is provably infinite. + return None; + } + + // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest + // integer value or else APInt::sqrt() will assert. + APInt SqrtVal = SqrtTerm.sqrt(); + + // Compute the two solutions for the quadratic formula. + // The divisions must be performed as signed divisions. + APInt NegB = -std::move(B); + APInt TwoA = std::move(A); + TwoA <<= 1; + if (TwoA.isNullValue()) + return None; - return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), - cast<SCEVConstant>(SE.getConstant(Solution2))); - } // end APIntOps namespace + LLVMContext &Context = SE.getContext(); + + ConstantInt *Solution1 = + ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); + ConstantInt *Solution2 = + ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + + return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), + cast<SCEVConstant>(SE.getConstant(Solution2))); } ScalarEvolution::ExitLimit @@ -7233,62 +7645,6 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates); } - // As a special case, handle the instance where Step is a positive power of - // two. In this case, determining whether Step divides Distance evenly can be - // done by counting and comparing the number of trailing zeros of Step and - // Distance. - if (!CountDown) { - const APInt &StepV = StepC->getAPInt(); - // StepV.isPowerOf2() returns true if StepV is an positive power of two. It - // also returns true if StepV is maximally negative (eg, INT_MIN), but that - // case is not handled as this code is guarded by !CountDown. - if (StepV.isPowerOf2() && - GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) { - // Here we've constrained the equation to be of the form - // - // 2^(N + k) * Distance' = (StepV == 2^N) * X (mod 2^W) ... (0) - // - // where we're operating on a W bit wide integer domain and k is - // non-negative. The smallest unsigned solution for X is the trip count. - // - // (0) is equivalent to: - // - // 2^(N + k) * Distance' - 2^N * X = L * 2^W - // <=> 2^N(2^k * Distance' - X) = L * 2^(W - N) * 2^N - // <=> 2^k * Distance' - X = L * 2^(W - N) - // <=> 2^k * Distance' = L * 2^(W - N) + X ... (1) - // - // The smallest X satisfying (1) is unsigned remainder of dividing the LHS - // by 2^(W - N). - // - // <=> X = 2^k * Distance' URem 2^(W - N) ... (2) - // - // E.g. say we're solving - // - // 2 * Val = 2 * X (in i8) ... (3) - // - // then from (2), we get X = Val URem i8 128 (k = 0 in this case). - // - // Note: It is tempting to solve (3) by setting X = Val, but Val is not - // necessarily the smallest unsigned value of X that satisfies (3). - // E.g. if Val is i8 -127 then the smallest value of X that satisfies (3) - // is i8 1, not i8 -127 - - const auto *ModuloResult = getUDivExactExpr(Distance, Step); - - // Since SCEV does not have a URem node, we construct one using a truncate - // and a zero extend. 
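
// Illustrative sketch (not LLVM code): the rewritten SolveQuadraticEquation
// above keeps the coefficient mapping A = N/2, B = M - N/2, C = L for the
// chrec {L,+,M,+,N}, because the chrec evaluates to L + M*x + N*x*(x-1)/2 at
// iteration x.  Small numeric cross-check with plain integers and doubles;
// the removed power-of-two fast path of howFarToZero resumes below.
#include <cassert>
#include <cmath>

// Value of the chrec {L,+,M,+,N} at iteration X:
//   L + M*X + N*X*(X-1)/2  ==  A*X^2 + B*X + C  with  A = N/2, B = M - N/2, C = L.
static long long chrecAt(long long L, long long M, long long N, long long X) {
  return L + M * X + N * X * (X - 1) / 2;
}

int main() {
  // {-4,+,1,+,2}: A = 1, B = 0, C = -4, i.e. x^2 - 4, which is zero at x = 2.
  long long L = -4, M = 1, N = 2;
  long long A = N / 2, B = M - A, C = L;
  double Disc = double(B) * B - 4.0 * double(A) * double(C);
  long long Root = std::llround((-double(B) + std::sqrt(Disc)) / (2.0 * A));
  assert(Root == 2 && chrecAt(L, M, N, Root) == 0);
  return 0;
}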
- - unsigned NarrowWidth = StepV.getBitWidth() - StepV.countTrailingZeros(); - auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth); - auto *WideTy = Distance->getType(); - - const SCEV *Limit = - getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); - return ExitLimit(Limit, Limit, false, Predicates); - } - } - // If the condition controls loop exit (the loop exits only if the expression // is true) and the addition is no-wrap we can use unsigned divide to // compute the backedge count. In this case, the step may not divide the @@ -7298,16 +7654,20 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact, false, Predicates); + const SCEV *Max = + Exact == getCouldNotCompute() + ? Exact + : getConstant(getUnsignedRange(Exact).getUnsignedMax()); + return ExitLimit(Exact, Max, false, Predicates); } - // Then, try to solve the above equation provided that Start is constant. - if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { - const SCEV *E = SolveLinEquationWithOverflow( - StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); - return ExitLimit(E, E, false, Predicates); - } - return getCouldNotCompute(); + // Solve the general equation. + const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(), + getNegativeSCEV(Start), *this); + const SCEV *M = E == getCouldNotCompute() + ? E + : getConstant(getUnsignedRange(E).getUnsignedMax()); + return ExitLimit(E, M, false, Predicates); } ScalarEvolution::ExitLimit @@ -7822,6 +8182,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, case ICmpInst::ICMP_SGE: std::swap(LHS, RHS); + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLE: // X s<= (X + C)<nsw> if C >= 0 if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative()) @@ -7835,6 +8196,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, case ICmpInst::ICMP_SGT: std::swap(LHS, RHS); + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLT: // X s< (X + C)<nsw> if C > 0 if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && @@ -8192,6 +8554,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin))) return true; + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: @@ -8206,6 +8569,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) return true; + LLVM_FALLTHROUGH; default: // No change @@ -8488,19 +8852,161 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, llvm_unreachable("covered switch fell through?!"); } +bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS, + unsigned Depth) { + assert(getTypeSizeInBits(LHS->getType()) == + getTypeSizeInBits(RHS->getType()) && + "LHS and RHS have different sizes?"); + assert(getTypeSizeInBits(FoundLHS->getType()) == + getTypeSizeInBits(FoundRHS->getType()) && + "FoundLHS and FoundRHS have different sizes?"); + // We want to avoid hurting the compile time with analysis of too big trees. + if (Depth > MaxSCEVOperationsImplicationDepth) + return false; + // We only want to work with ICMP_SGT comparison so far. + // TODO: Extend to ICMP_UGT? 
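
// Illustrative sketch (not LLVM code): with the power-of-two fast path gone,
// the remaining general case of howFarToZero above is handled by
// SolveLinEquationWithOverflow, i.e. finding the smallest X with
// Step * X == -Start (mod 2^BW).  The same arithmetic for BW == 32, using
// plain wrap-around unsigned math in place of APInt:
#include <cassert>
#include <cstdint>

// Trailing-zero count of a non-zero 32-bit value (portable stand-in for
// APInt::countTrailingZeros / GetMinTrailingZeros on a constant).
static unsigned ctz32(uint32_t V) {
  unsigned N = 0;
  while (!(V & 1)) { V >>= 1; ++N; }
  return N;
}

// Multiplicative inverse of an odd value modulo 2^32 via Newton iteration;
// every step doubles the number of correct low bits (3, 6, 12, 24, 48).
static uint32_t inverseMod2_32(uint32_t A) {
  assert((A & 1) && "only odd values are invertible mod 2^32");
  uint32_t X = A;               // already correct to 3 bits: A*A == 1 (mod 8)
  for (int I = 0; I < 4; ++I)
    X *= 2 - A * X;             // wrap-around multiply is arithmetic mod 2^32
  return X;
}

// Smallest X with A*X == B (mod 2^32), or false if no solution exists.
// Same steps as SolveLinEquationWithOverflow: D = gcd(A, 2^32) = 2^K,
// solvable iff 2^K divides B, then X = ((A/D)^-1 * B mod 2^32) / D.
static bool solveLin(uint32_t A, uint32_t B, uint32_t &X) {
  assert(A != 0 && "A must be non-zero");
  unsigned K = ctz32(A);
  if (B == 0) { X = 0; return true; }
  if (ctz32(B) < K)             // B is not divisible by D
    return false;
  uint32_t I = inverseMod2_32(A >> K);
  X = (I * B) >> K;
  return true;
}

int main() {
  uint32_t X = 0;
  assert(solveLin(4, 8, X) && X == 2);       // 4 * 2 == 8
  assert(!solveLin(4, 6, X));                // 4*X is divisible by 4, 6 is not
  assert(solveLin(3, 7, X) && 3u * X == 7u); // odd step: always solvable
  return 0;
}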
+ if (Pred == ICmpInst::ICMP_SLT) { + Pred = ICmpInst::ICMP_SGT; + std::swap(LHS, RHS); + std::swap(FoundLHS, FoundRHS); + } + if (Pred != ICmpInst::ICMP_SGT) + return false; + + auto GetOpFromSExt = [&](const SCEV *S) { + if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S)) + return Ext->getOperand(); + // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off + // the constant in some cases. + return S; + }; + + // Acquire values from extensions. + auto *OrigFoundLHS = FoundLHS; + LHS = GetOpFromSExt(LHS); + FoundLHS = GetOpFromSExt(FoundLHS); + + // Is the SGT predicate can be proved trivially or using the found context. + auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) { + return isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, S1, S2) || + isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS, + FoundRHS, Depth + 1); + }; + + if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) { + // We want to avoid creation of any new non-constant SCEV. Since we are + // going to compare the operands to RHS, we should be certain that we don't + // need any size extensions for this. So let's decline all cases when the + // sizes of types of LHS and RHS do not match. + // TODO: Maybe try to get RHS from sext to catch more cases? + if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType())) + return false; + + // Should not overflow. + if (!LHSAddExpr->hasNoSignedWrap()) + return false; + + auto *LL = LHSAddExpr->getOperand(0); + auto *LR = LHSAddExpr->getOperand(1); + auto *MinusOne = getNegativeSCEV(getOne(RHS->getType())); + + // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context. + auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) { + return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS); + }; + // Try to prove the following rule: + // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS). + // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS). + if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL)) + return true; + } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) { + Value *LL, *LR; + // FIXME: Once we have SDiv implemented, we can get rid of this matching. + using namespace llvm::PatternMatch; + if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) { + // Rules for division. + // We are going to perform some comparisons with Denominator and its + // derivative expressions. In general case, creating a SCEV for it may + // lead to a complex analysis of the entire graph, and in particular it + // can request trip count recalculation for the same loop. This would + // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid + // this, we only want to create SCEVs that are constants in this section. + // So we bail if Denominator is not a constant. + if (!isa<ConstantInt>(LR)) + return false; + + auto *Denominator = cast<SCEVConstant>(getSCEV(LR)); + + // We want to make sure that LHS = FoundLHS / Denominator. If it is so, + // then a SCEV for the numerator already exists and matches with FoundLHS. + auto *Numerator = getExistingSCEV(LL); + if (!Numerator || Numerator->getType() != FoundLHS->getType()) + return false; + + // Make sure that the numerator matches with FoundLHS and the denominator + // is positive. 
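
// Illustrative sketch (not LLVM code): brute-force check, over a small grid
// of signed values, of the first division rule stated further down:
//   FoundLHS > FoundRHS, LHS == FoundLHS /s Denominator, Denominator > 0,
//   FoundRHS > Denominator - 2, RHS <= 0   ==>   LHS > RHS.
// The Numerator/Denominator checks of isImpliedViaOperations continue below.
#include <cassert>

int main() {
  for (int Denom = 1; Denom <= 16; ++Denom)
    for (int FoundRHS = Denom - 1; FoundRHS <= 64; ++FoundRHS)      // FoundRHS > Denom - 2
      for (int FoundLHS = FoundRHS + 1; FoundLHS <= 65; ++FoundLHS) // FoundLHS > FoundRHS
        for (int RHS = -8; RHS <= 0; ++RHS) {                       // RHS <= 0
          int LHS = FoundLHS / Denom; // C++ '/' truncates like SDiv here
          assert(LHS > RHS && "division rule violated");
        }
  return 0;
}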
+ if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator)) + return false; + + auto *DTy = Denominator->getType(); + auto *FRHSTy = FoundRHS->getType(); + if (DTy->isPointerTy() != FRHSTy->isPointerTy()) + // One of types is a pointer and another one is not. We cannot extend + // them properly to a wider type, so let us just reject this case. + // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help + // to avoid this check. + return false; + + // Given that: + // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0. + auto *WTy = getWiderType(DTy, FRHSTy); + auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy); + auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy); + + // Try to prove the following rule: + // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS). + // For example, given that FoundLHS > 2. It means that FoundLHS is at + // least 3. If we divide it by Denominator < 4, we will have at least 1. + auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2)); + if (isKnownNonPositive(RHS) && + IsSGTViaContext(FoundRHSExt, DenomMinusTwo)) + return true; + + // Try to prove the following rule: + // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS). + // For example, given that FoundLHS > -3. Then FoundLHS is at least -2. + // If we divide it by Denominator > 2, then: + // 1. If FoundLHS is negative, then the result is 0. + // 2. If FoundLHS is non-negative, then the result is non-negative. + // Anyways, the result is non-negative. + auto *MinusOne = getNegativeSCEV(getOne(WTy)); + auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt); + if (isKnownNegative(RHS) && + IsSGTViaContext(FoundRHSExt, NegDenomMinusOne)) + return true; + } + } + + return false; +} + +bool +ScalarEvolution::isKnownViaSimpleReasoning(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || + IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || + isKnownPredicateViaNoOverflow(Pred, LHS, RHS); +} + bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { - auto IsKnownPredicateFull = - [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { - return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || - IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || - IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || - isKnownPredicateViaNoOverflow(Pred, LHS, RHS); - }; - switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -8510,30 +9016,34 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if 
(IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && - IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + if (isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } + // Maybe it can be proved via operations? + if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS)) + return true; + return false; } @@ -8551,7 +9061,7 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, if (!Addend) return false; - APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt(); + const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt(); // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the // antecedent "`FoundLHS` `Pred` `FoundRHS`". @@ -8563,7 +9073,7 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, // We can also compute the range of values for `LHS` that satisfy the // consequent, "`LHS` `Pred` `RHS`": - APInt ConstRHS = cast<SCEVConstant>(RHS)->getAPInt(); + const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt(); ConstantRange SatisfyingLHSRange = ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); @@ -8588,7 +9098,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, .getSignedMax(); // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow! - return (MaxValue - MaxStrideMinusOne).slt(MaxRHS); + return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS); } APInt MaxRHS = getUnsignedRange(RHS).getUnsignedMax(); @@ -8597,7 +9107,7 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, .getUnsignedMax(); // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow! - return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); + return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS); } bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, @@ -8614,7 +9124,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, .getSignedMax(); // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow! - return (MinValue + MaxStrideMinusOne).sgt(MinRHS); + return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS); } APInt MinRHS = getUnsignedRange(RHS).getUnsignedMin(); @@ -8623,7 +9133,7 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, .getUnsignedMax(); // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow! 
- return (MinValue + MaxStrideMinusOne).ugt(MinRHS); + return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS); } const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, @@ -8790,8 +9300,9 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, getConstant(StrideForMaxBECount), false); } - if (isa<SCEVCouldNotCompute>(MaxBECount)) - MaxBECount = BECount; + if (isa<SCEVCouldNotCompute>(MaxBECount) && + !isa<SCEVCouldNotCompute>(BECount)) + MaxBECount = getConstant(getUnsignedRange(BECount).getUnsignedMax()); return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); } @@ -8914,9 +9425,8 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, // the upper value of the range must be the first possible exit value. // If A is negative then the lower of the range is the last possible loop // value. Also note that we already checked for a full range. - APInt One(BitWidth,1); APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt(); - APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); + APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower(); // The exit value should be (End+A)/A. APInt ExitVal = (End + A).udiv(A); @@ -8932,7 +9442,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, // Ensure that the previous value is in the range. This is a sanity check. assert(Range.contains( EvaluateConstantChrecAtConstant(this, - ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) && + ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) && "Linear scev computation is off in a bad way!"); return SE.getConstant(ExitValue); } else if (isQuadratic()) { @@ -9083,8 +9593,11 @@ struct SCEVCollectAddRecMultiplies { bool HasAddRec = false; SmallVector<const SCEV *, 0> Operands; for (auto Op : Mul->operands()) { - if (isa<SCEVUnknown>(Op)) { + const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Op); + if (Unknown && !isa<CallInst>(Unknown->getValue())) { Operands.push_back(Op); + } else if (Unknown) { + HasAddRec = true; } else { bool ContainsAddRec; SCEVHasAddRec ContiansAddRec(ContainsAddRec); @@ -9238,7 +9751,7 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, SmallVectorImpl<const SCEV *> &Sizes, - const SCEV *ElementSize) const { + const SCEV *ElementSize) { if (Terms.size() < 1 || !ElementSize) return; @@ -9254,7 +9767,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, }); // Remove duplicates. - std::sort(Terms.begin(), Terms.end()); + array_pod_sort(Terms.begin(), Terms.end()); Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end()); // Put larger terms first. @@ -9262,13 +9775,11 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, return numberOfTerms(LHS) > numberOfTerms(RHS); }); - ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); - // Try to divide all terms by the element size. If term is not divisible by // element size, proceed with the original term. for (const SCEV *&Term : Terms) { const SCEV *Q, *R; - SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); + SCEVDivision::divide(*this, Term, ElementSize, &Q, &R); if (!Q->isZero()) Term = Q; } @@ -9277,7 +9788,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, // Remove constant factors. 
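
// Illustrative sketch: the recurring pattern in the hunks above (including
// howManyLessThans just above) is that, when the exact backedge-taken count
// is computable but no better max is known, the max count is now materialized
// as the constant unsigned-range maximum of the exact expression, which is
// what the new "non-constant max backedge taken count" asserts rely on.  Only
// the llvm::ConstantRange query below is real API; the scenario is made up.
#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include <cassert>
using namespace llvm;

int main() {
  // A count known only to lie in the half-open unsigned range [0, 100) over
  // i8 gets the constant max count 99.
  ConstantRange ExactRange(APInt(8, 0), APInt(8, 100));
  APInt MaxCount = ExactRange.getUnsignedMax();
  assert(MaxCount == 99);
  return 0;
}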
for (const SCEV *T : Terms) - if (const SCEV *NewT = removeConstantFactors(SE, T)) + if (const SCEV *NewT = removeConstantFactors(*this, T)) NewTerms.push_back(NewT); DEBUG({ @@ -9286,8 +9797,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, dbgs() << *T << "\n"; }); - if (NewTerms.empty() || - !findArrayDimensionsRec(SE, NewTerms, Sizes)) { + if (NewTerms.empty() || !findArrayDimensionsRec(*this, NewTerms, Sizes)) { Sizes.clear(); return; } @@ -9524,6 +10034,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) ValueExprMap(std::move(Arg.ValueExprMap)), PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), WalkingBEDominatingConds(false), ProvingSplitPredicate(false), + MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)), BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), PredicatedBackedgeTakenCounts( std::move(Arg.PredicatedBackedgeTakenCounts)), @@ -9621,6 +10132,13 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "Unpredictable predicated backedge-taken count. "; } OS << "\n"; + + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + OS << "Loop "; + L->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": "; + OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n"; + } } static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { @@ -9929,6 +10447,7 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { SignedRanges.erase(S); ExprValueMap.erase(S); HasRecMap.erase(S); + MinTrailingZerosCache.erase(S); auto RemoveSCEVFromBackedgeMap = [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { @@ -9946,84 +10465,75 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); } -typedef DenseMap<const Loop *, std::string> VerifyMap; +void ScalarEvolution::verify() const { + ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); + ScalarEvolution SE2(F, TLI, AC, DT, LI); -/// replaceSubString - Replaces all occurrences of From in Str with To. -static void replaceSubString(std::string &Str, StringRef From, StringRef To) { - size_t Pos = 0; - while ((Pos = Str.find(From, Pos)) != std::string::npos) { - Str.replace(Pos, From.size(), To.data(), To.size()); - Pos += To.size(); - } -} + SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end()); -/// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis. -static void -getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) { - std::string &S = Map[L]; - if (S.empty()) { - raw_string_ostream OS(S); - SE.getBackedgeTakenCount(L)->print(OS); + // Map's SCEV expressions from one ScalarEvolution "universe" to another. + struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> { + const SCEV *visitConstant(const SCEVConstant *Constant) { + return SE.getConstant(Constant->getAPInt()); + } + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + return SE.getUnknown(Expr->getValue()); + } - // false and 0 are semantically equivalent. This can happen in dead loops. - replaceSubString(OS.str(), "false", "0"); - // Remove wrap flags, their use in SCEV is highly fragile. - // FIXME: Remove this when SCEV gets smarter about them. 
- replaceSubString(OS.str(), "<nw>", ""); - replaceSubString(OS.str(), "<nsw>", ""); - replaceSubString(OS.str(), "<nuw>", ""); - } + const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + return SE.getCouldNotCompute(); + } + SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {} + }; - for (auto *R : reverse(*L)) - getLoopBackedgeTakenCounts(R, Map, SE); // recurse. -} + SCEVMapper SCM(SE2); -void ScalarEvolution::verify() const { - ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); + while (!LoopStack.empty()) { + auto *L = LoopStack.pop_back_val(); + LoopStack.insert(LoopStack.end(), L->begin(), L->end()); - // Gather stringified backedge taken counts for all loops using SCEV's caches. - // FIXME: It would be much better to store actual values instead of strings, - // but SCEV pointers will change if we drop the caches. - VerifyMap BackedgeDumpsOld, BackedgeDumpsNew; - for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I) - getLoopBackedgeTakenCounts(*I, BackedgeDumpsOld, SE); + auto *CurBECount = SCM.visit( + const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L)); + auto *NewBECount = SE2.getBackedgeTakenCount(L); - // Gather stringified backedge taken counts for all loops using a fresh - // ScalarEvolution object. - ScalarEvolution SE2(F, TLI, AC, DT, LI); - for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I) - getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE2); - - // Now compare whether they're the same with and without caches. This allows - // verifying that no pass changed the cache. - assert(BackedgeDumpsOld.size() == BackedgeDumpsNew.size() && - "New loops suddenly appeared!"); - - for (VerifyMap::iterator OldI = BackedgeDumpsOld.begin(), - OldE = BackedgeDumpsOld.end(), - NewI = BackedgeDumpsNew.begin(); - OldI != OldE; ++OldI, ++NewI) { - assert(OldI->first == NewI->first && "Loop order changed!"); - - // Compare the stringified SCEVs. We don't care if undef backedgetaken count - // changes. - // FIXME: We currently ignore SCEV changes from/to CouldNotCompute. This - // means that a pass is buggy or SCEV has to learn a new pattern but is - // usually not harmful. - if (OldI->second != NewI->second && - OldI->second.find("undef") == std::string::npos && - NewI->second.find("undef") == std::string::npos && - OldI->second != "***COULDNOTCOMPUTE***" && - NewI->second != "***COULDNOTCOMPUTE***") { - dbgs() << "SCEVValidator: SCEV for loop '" - << OldI->first->getHeader()->getName() - << "' changed from '" << OldI->second - << "' to '" << NewI->second << "'!\n"; + if (CurBECount == SE2.getCouldNotCompute() || + NewBECount == SE2.getCouldNotCompute()) { + // NB! This situation is legal, but is very suspicious -- whatever pass + // change the loop to make a trip count go from could not compute to + // computable or vice-versa *should have* invalidated SCEV. However, we + // choose not to assert here (for now) since we don't want false + // positives. + continue; + } + + if (containsUndefs(CurBECount) || containsUndefs(NewBECount)) { + // SCEV treats "undef" as an unknown but consistent value (i.e. it does + // not propagate undef aggressively). This means we can (and do) fail + // verification in cases where a transform makes the trip count of a loop + // go from "undef" to "undef+1" (say). The transform is fine, since in + // both cases the loop iterates "undef" times, but SCEV thinks we + // increased the trip count of the loop by 1 incorrectly. 
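
// Illustrative sketch (not LLVM code): why the new verify() needs the
// SCEVMapper defined above.  SCEV nodes are uniqued per analysis instance
// ("universe"), so pointers from this ScalarEvolution cannot be compared with
// pointers from the fresh SE2; each expression is instead rebuilt inside SE2,
// and only the leaf cases need explicit handling.  Toy model with a made-up
// Expr/Universe pair; the delta comparison of the real loop continues below.
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Expr {
  std::string Leaf;              // non-empty => leaf (a named unknown)
  std::vector<const Expr *> Ops; // otherwise an "add" of operands
};

struct Universe {
  std::vector<std::unique_ptr<Expr>> Pool;
  std::map<std::string, const Expr *> Leaves; // uniquing table for leaves

  const Expr *getLeaf(const std::string &Name) {
    const Expr *&Slot = Leaves[Name];
    if (!Slot) {
      Pool.push_back(std::make_unique<Expr>());
      Pool.back()->Leaf = Name;
      Slot = Pool.back().get();
    }
    return Slot;
  }
  const Expr *getAdd(std::vector<const Expr *> Ops) {
    Pool.push_back(std::make_unique<Expr>());
    Pool.back()->Ops = std::move(Ops);
    return Pool.back().get();
  }
};

// The "mapper": rebuild E inside Target, overriding only the leaf case;
// interior nodes are re-created generically from their rebuilt operands.
static const Expr *mapInto(const Expr *E, Universe &Target) {
  if (!E->Leaf.empty())
    return Target.getLeaf(E->Leaf);
  std::vector<const Expr *> Ops;
  for (const Expr *Op : E->Ops)
    Ops.push_back(mapInto(Op, Target));
  return Target.getAdd(std::move(Ops));
}

int main() {
  Universe A, B;
  const Expr *InA = A.getAdd({A.getLeaf("n"), A.getLeaf("m")});
  const Expr *InB = mapInto(InA, B);
  assert(InB != InA);                    // different universes, different pointers
  assert(InB->Ops[0] == B.getLeaf("n")); // but leaves are uniqued inside B
  return 0;
}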
+ continue; + } + + if (SE.getTypeSizeInBits(CurBECount->getType()) > + SE.getTypeSizeInBits(NewBECount->getType())) + NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType()); + else if (SE.getTypeSizeInBits(CurBECount->getType()) < + SE.getTypeSizeInBits(NewBECount->getType())) + CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType()); + + auto *ConstantDelta = + dyn_cast<SCEVConstant>(SE2.getMinusSCEV(CurBECount, NewBECount)); + + if (ConstantDelta && ConstantDelta->getAPInt() != 0) { + dbgs() << "Trip Count Changed!\n"; + dbgs() << "Old: " << *CurBECount << "\n"; + dbgs() << "New: " << *NewBECount << "\n"; + dbgs() << "Delta: " << *ConstantDelta << "\n"; std::abort(); } } - - // TODO: Verify more things. } bool ScalarEvolution::invalidate( diff --git a/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp index d15a7dbd20e6..f9b9df2bc707 100644 --- a/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -1268,8 +1268,7 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { if (PostIncLoops.count(L)) { PostIncLoopSet Loops; Loops.insert(L); - Normalized = cast<SCEVAddRecExpr>(TransformForPostIncUse( - Normalize, S, nullptr, nullptr, Loops, SE, SE.DT)); + Normalized = cast<SCEVAddRecExpr>(normalizeForPostIncUse(S, Loops, SE)); } // Strip off any non-loop-dominating component from the addrec start. @@ -1306,12 +1305,17 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { // Expand the core addrec. If we need post-loop scaling, force it to // expand to an integer type to avoid the need for additional casting. Type *ExpandTy = PostLoopScale ? IntTy : STy; + // We can't use a pointer type for the addrec if the pointer type is + // non-integral. + Type *AddRecPHIExpandTy = + DL.isNonIntegralPointerType(STy) ? Normalized->getType() : ExpandTy; + // In some cases, we decide to reuse an existing phi node but need to truncate // it and/or invert the step. Type *TruncTy = nullptr; bool InvertStep = false; - PHINode *PN = getAddRecExprPHILiterally(Normalized, L, ExpandTy, IntTy, - TruncTy, InvertStep); + PHINode *PN = getAddRecExprPHILiterally(Normalized, L, AddRecPHIExpandTy, + IntTy, TruncTy, InvertStep); // Accommodate post-inc mode, if necessary. Value *Result; @@ -1384,8 +1388,15 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { // Re-apply any non-loop-dominating offset. if (PostLoopOffset) { if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) { - const SCEV *const OffsetArray[1] = { PostLoopOffset }; - Result = expandAddToGEP(OffsetArray, OffsetArray+1, PTy, IntTy, Result); + if (Result->getType()->isIntegerTy()) { + Value *Base = expandCodeFor(PostLoopOffset, ExpandTy); + const SCEV *const OffsetArray[1] = {SE.getUnknown(Result)}; + Result = expandAddToGEP(OffsetArray, OffsetArray + 1, PTy, IntTy, Base); + } else { + const SCEV *const OffsetArray[1] = {PostLoopOffset}; + Result = + expandAddToGEP(OffsetArray, OffsetArray + 1, PTy, IntTy, Result); + } } else { Result = InsertNoopCastOfTo(Result, IntTy); Result = Builder.CreateAdd(Result, @@ -1773,9 +1784,10 @@ SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L, /// /// This does not depend on any SCEVExpander state but should be used in /// the same context that SCEVExpander is used. 
-unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, - SmallVectorImpl<WeakVH> &DeadInsts, - const TargetTransformInfo *TTI) { +unsigned +SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, + SmallVectorImpl<WeakTrackingVH> &DeadInsts, + const TargetTransformInfo *TTI) { // Find integer phis in order of increasing width. SmallVector<PHINode*, 8> Phis; for (auto &I : *L->getHeader()) { @@ -1800,7 +1812,7 @@ unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, // so narrow phis can reuse them. for (PHINode *Phi : Phis) { auto SimplifyPHINode = [&](PHINode *PN) -> Value * { - if (Value *V = SimplifyInstruction(PN, DL, &SE.TLI, &SE.DT, &SE.AC)) + if (Value *V = SimplifyInstruction(PN, {DL, &SE.TLI, &SE.DT, &SE.AC})) return V; if (!SE.isSCEVable(PN->getType())) return nullptr; diff --git a/contrib/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp b/contrib/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp index c1f9503816ee..54c44c8e542d 100644 --- a/contrib/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp +++ b/contrib/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -12,243 +12,107 @@ // //===----------------------------------------------------------------------===// -#include "llvm/IR/Dominators.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolutionNormalization.h" using namespace llvm; -/// IVUseShouldUsePostIncValue - We have discovered a "User" of an IV expression -/// and now we need to decide whether the user should use the preinc or post-inc -/// value. If this user should use the post-inc version of the IV, return true. -/// -/// Choosing wrong here can break dominance properties (if we choose to use the -/// post-inc value when we cannot) or it can end up adding extra live-ranges to -/// the loop, resulting in reg-reg copies (if we use the pre-inc value when we -/// should use the post-inc value). -static bool IVUseShouldUsePostIncValue(Instruction *User, Value *Operand, - const Loop *L, DominatorTree *DT) { - // If the user is in the loop, use the preinc value. - if (L->contains(User)) return false; - - BasicBlock *LatchBlock = L->getLoopLatch(); - if (!LatchBlock) - return false; - - // Ok, the user is outside of the loop. If it is dominated by the latch - // block, use the post-inc value. - if (DT->dominates(LatchBlock, User->getParent())) - return true; - - // There is one case we have to be careful of: PHI nodes. These little guys - // can live in blocks that are not dominated by the latch block, but (since - // their uses occur in the predecessor block, not the block the PHI lives in) - // should still use the post-inc value. Check for this case now. - PHINode *PN = dyn_cast<PHINode>(User); - if (!PN || !Operand) return false; // not a phi, not dominated by latch block. - - // Look at all of the uses of Operand by the PHI node. If any use corresponds - // to a block that is not dominated by the latch block, give up and use the - // preincremented value. - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (PN->getIncomingValue(i) == Operand && - !DT->dominates(LatchBlock, PN->getIncomingBlock(i))) - return false; - - // Okay, all uses of Operand by PN are in predecessor blocks that really are - // dominated by the latch block. Use the post-incremented value. - return true; -} +/// TransformKind - Different types of transformations that +/// TransformForPostIncUse can do. 
+enum TransformKind { + /// Normalize - Normalize according to the given loops. + Normalize, + /// Denormalize - Perform the inverse transform on the expression with the + /// given loop set. + Denormalize +}; namespace { - -/// Hold the state used during post-inc expression transformation, including a -/// map of transformed expressions. -class PostIncTransform { - TransformKind Kind; - PostIncLoopSet &Loops; - ScalarEvolution &SE; - DominatorTree &DT; - - DenseMap<const SCEV*, const SCEV*> Transformed; - -public: - PostIncTransform(TransformKind kind, PostIncLoopSet &loops, - ScalarEvolution &se, DominatorTree &dt): - Kind(kind), Loops(loops), SE(se), DT(dt) {} - - const SCEV *TransformSubExpr(const SCEV *S, Instruction *User, - Value *OperandValToReplace); - -protected: - const SCEV *TransformImpl(const SCEV *S, Instruction *User, - Value *OperandValToReplace); +struct NormalizeDenormalizeRewriter + : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> { + const TransformKind Kind; + + // NB! Pred is a function_ref. Storing it here is okay only because + // we're careful about the lifetime of NormalizeDenormalizeRewriter. + const NormalizePredTy Pred; + + NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred, + ScalarEvolution &SE) + : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind), + Pred(Pred) {} + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr); }; - } // namespace -/// Implement post-inc transformation for all valid expression types. -const SCEV *PostIncTransform:: -TransformImpl(const SCEV *S, Instruction *User, Value *OperandValToReplace) { - - if (const SCEVCastExpr *X = dyn_cast<SCEVCastExpr>(S)) { - const SCEV *O = X->getOperand(); - const SCEV *N = TransformSubExpr(O, User, OperandValToReplace); - if (O != N) - switch (S->getSCEVType()) { - case scZeroExtend: return SE.getZeroExtendExpr(N, S->getType()); - case scSignExtend: return SE.getSignExtendExpr(N, S->getType()); - case scTruncate: return SE.getTruncateExpr(N, S->getType()); - default: llvm_unreachable("Unexpected SCEVCastExpr kind!"); - } - return S; +const SCEV * +NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { + SmallVector<const SCEV *, 8> Operands; + + transform(AR->operands(), std::back_inserter(Operands), + [&](const SCEV *Op) { return visit(Op); }); + + if (!Pred(AR)) + return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); + + // Normalization and denormalization are fancy names for decrementing and + // incrementing a SCEV expression with respect to a set of loops. Since + // Pred(AR) has returned true, we know we need to normalize or denormalize AR + // with respect to its loop. + + if (Kind == Denormalize) { + // Denormalization / "partial increment" is essentially the same as \c + // SCEVAddRecExpr::getPostIncExpr. Here we use an explicit loop to make the + // symmetry with Normalization clear. + for (int i = 0, e = Operands.size() - 1; i < e; i++) + Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]); + } else { + assert(Kind == Normalize && "Only two possibilities!"); + + // Normalization / "partial decrement" is a bit more subtle. Since + // incrementing a SCEV expression (in general) changes the step of the SCEV + // expression as well, we cannot use the step of the current expression. + // Instead, we have to use the step of the very expression we're trying to + // compute! 
+ // + // We solve the issue by recursively building up the result, starting from + // the "least significant" operand in the add recurrence: + // + // Base case: + // Single operand add recurrence. It's its own normalization. + // + // N-operand case: + // {S_{N-1},+,S_{N-2},+,...,+,S_0} = S + // + // Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its + // normalization by induction. We subtract the normalized step + // recurrence from S_{N-1} to get the normalization of S. + + for (int i = Operands.size() - 2; i >= 0; i--) + Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]); } - if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { - // An addrec. This is the interesting part. - SmallVector<const SCEV *, 8> Operands; - const Loop *L = AR->getLoop(); - // The addrec conceptually uses its operands at loop entry. - Instruction *LUser = &L->getHeader()->front(); - // Transform each operand. - for (SCEVNAryExpr::op_iterator I = AR->op_begin(), E = AR->op_end(); - I != E; ++I) { - Operands.push_back(TransformSubExpr(*I, LUser, nullptr)); - } - // Conservatively use AnyWrap until/unless we need FlagNW. - const SCEV *Result = SE.getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); - switch (Kind) { - case NormalizeAutodetect: - // Normalize this SCEV by subtracting the expression for the final step. - // We only allow affine AddRecs to be normalized, otherwise we would not - // be able to correctly denormalize. - // e.g. {1,+,3,+,2} == {-2,+,1,+,2} + {3,+,2} - // Normalized form: {-2,+,1,+,2} - // Denormalized form: {1,+,3,+,2} - // - // However, denormalization would use a different step expression than - // normalization (see getPostIncExpr), generating the wrong final - // expression: {-2,+,1,+,2} + {1,+,2} => {-1,+,3,+,2} - if (AR->isAffine() && - IVUseShouldUsePostIncValue(User, OperandValToReplace, L, &DT)) { - const SCEV *TransformedStep = - TransformSubExpr(AR->getStepRecurrence(SE), - User, OperandValToReplace); - Result = SE.getMinusSCEV(Result, TransformedStep); - Loops.insert(L); - } -#if 0 - // This assert is conceptually correct, but ScalarEvolution currently - // sometimes fails to canonicalize two equal SCEVs to exactly the same - // form. It's possibly a pessimization when this happens, but it isn't a - // correctness problem, so disable this assert for now. - assert(S == TransformSubExpr(Result, User, OperandValToReplace) && - "SCEV normalization is not invertible!"); -#endif - break; - case Normalize: - // We want to normalize step expression, because otherwise we might not be - // able to denormalize to the original expression. - // - // Here is an example what will happen if we don't normalize step: - // ORIGINAL ISE: - // {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25> - // NORMALIZED ISE: - // {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+, - // (100 /u {0,+,1}<%bb16>)}<%bb25> - // DENORMALIZED BACK ISE: - // {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+, - // (100 /u {1,+,1}<%bb16>)}<%bb25> - // Note that the initial value changes after normalization + - // denormalization, which isn't correct. - if (Loops.count(L)) { - const SCEV *TransformedStep = - TransformSubExpr(AR->getStepRecurrence(SE), - User, OperandValToReplace); - Result = SE.getMinusSCEV(Result, TransformedStep); - } -#if 0 - // See the comment on the assert above. 
- assert(S == TransformSubExpr(Result, User, OperandValToReplace) && - "SCEV normalization is not invertible!"); -#endif - break; - case Denormalize: - // Here we want to normalize step expressions for the same reasons, as - // stated above. - if (Loops.count(L)) { - const SCEV *TransformedStep = - TransformSubExpr(AR->getStepRecurrence(SE), - User, OperandValToReplace); - Result = SE.getAddExpr(Result, TransformedStep); - } - break; - } - return Result; - } - - if (const SCEVNAryExpr *X = dyn_cast<SCEVNAryExpr>(S)) { - SmallVector<const SCEV *, 8> Operands; - bool Changed = false; - // Transform each operand. - for (SCEVNAryExpr::op_iterator I = X->op_begin(), E = X->op_end(); - I != E; ++I) { - const SCEV *O = *I; - const SCEV *N = TransformSubExpr(O, User, OperandValToReplace); - Changed |= N != O; - Operands.push_back(N); - } - // If any operand actually changed, return a transformed result. - if (Changed) - switch (S->getSCEVType()) { - case scAddExpr: return SE.getAddExpr(Operands); - case scMulExpr: return SE.getMulExpr(Operands); - case scSMaxExpr: return SE.getSMaxExpr(Operands); - case scUMaxExpr: return SE.getUMaxExpr(Operands); - default: llvm_unreachable("Unexpected SCEVNAryExpr kind!"); - } - return S; - } - - if (const SCEVUDivExpr *X = dyn_cast<SCEVUDivExpr>(S)) { - const SCEV *LO = X->getLHS(); - const SCEV *RO = X->getRHS(); - const SCEV *LN = TransformSubExpr(LO, User, OperandValToReplace); - const SCEV *RN = TransformSubExpr(RO, User, OperandValToReplace); - if (LO != LN || RO != RN) - return SE.getUDivExpr(LN, RN); - return S; - } - - llvm_unreachable("Unexpected SCEV kind!"); + return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); } -/// Manage recursive transformation across an expression DAG. Revisiting -/// expressions would lead to exponential recursion. -const SCEV *PostIncTransform:: -TransformSubExpr(const SCEV *S, Instruction *User, Value *OperandValToReplace) { - - if (isa<SCEVConstant>(S) || isa<SCEVUnknown>(S)) - return S; - - const SCEV *Result = Transformed.lookup(S); - if (Result) - return Result; +const SCEV *llvm::normalizeForPostIncUse(const SCEV *S, + const PostIncLoopSet &Loops, + ScalarEvolution &SE) { + auto Pred = [&](const SCEVAddRecExpr *AR) { + return Loops.count(AR->getLoop()); + }; + return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); +} - Result = TransformImpl(S, User, OperandValToReplace); - Transformed[S] = Result; - return Result; +const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred, + ScalarEvolution &SE) { + return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); } -/// Top level driver for transforming an expression DAG into its requested -/// post-inc form (either "Normalized" or "Denormalized"). 
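
// Illustrative sketch (not LLVM code): the two operand passes of
// NormalizeDenormalizeRewriter above, modeled on plain integers standing in
// for the SCEV operands of {Op[0],+,Op[1],+,...}.  Normalizing shifts an add
// recurrence one iteration earlier, denormalizing (the post-increment form)
// shifts it one iteration later, and the two passes are exact inverses, a
// property the removed comments above call out as problematic for the old
// per-subexpression scheme on non-affine recurrences.
#include <cassert>
#include <vector>

using Rec = std::vector<long long>; // assumes at least one operand

// Denormalize == "partial increment": fold each step into the operand it
// feeds, least-significant first, as in the ascending loop above.
static Rec denormalize(Rec Ops) {
  for (size_t I = 0, E = Ops.size() - 1; I < E; ++I)
    Ops[I] += Ops[I + 1];
  return Ops;
}

// Normalize == "partial decrement": built bottom-up from the already
// normalized higher-order operands, mirroring the descending loop above.
static Rec normalize(Rec Ops) {
  for (size_t I = Ops.size() - 1; I-- > 0;)
    Ops[I] -= Ops[I + 1];
  return Ops;
}

int main() {
  // {1,+,3,+,2} evaluates to (i+1)^2 at iteration i.
  Rec R = {1, 3, 2};
  // Its normalization {0,+,1,+,2} evaluates to i^2 (one iteration earlier),
  // and the two transforms round-trip exactly.
  assert(normalize(R) == Rec({0, 1, 2}));
  assert(denormalize(normalize(R)) == R);
  // Denormalizing {1,+,3,+,2} gives {4,+,5,+,2}, which evaluates to (i+2)^2.
  assert(denormalize(R) == Rec({4, 5, 2}));
  return 0;
}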
-const SCEV *llvm::TransformForPostIncUse(TransformKind Kind, - const SCEV *S, - Instruction *User, - Value *OperandValToReplace, - PostIncLoopSet &Loops, - ScalarEvolution &SE, - DominatorTree &DT) { - PostIncTransform Transform(Kind, Loops, SE, DT); - return Transform.TransformSubExpr(S, User, OperandValToReplace); +const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S, + const PostIncLoopSet &Loops, + ScalarEvolution &SE) { + auto Pred = [&](const SCEVAddRecExpr *AR) { + return Loops.count(AR->getLoop()); + }; + return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S); } diff --git a/contrib/llvm/lib/Analysis/SparsePropagation.cpp b/contrib/llvm/lib/Analysis/SparsePropagation.cpp index 79dc84e25533..470f4bee1e0a 100644 --- a/contrib/llvm/lib/Analysis/SparsePropagation.cpp +++ b/contrib/llvm/lib/Analysis/SparsePropagation.cpp @@ -195,7 +195,7 @@ void SparseSolver::getFeasibleSuccessors(TerminatorInst &TI, Succs.assign(TI.getNumSuccessors(), true); return; } - SwitchInst::CaseIt Case = SI.findCaseValue(cast<ConstantInt>(C)); + SwitchInst::CaseHandle Case = *SI.findCaseValue(cast<ConstantInt>(C)); Succs[Case.getSuccessorIndex()] = true; } diff --git a/contrib/llvm/lib/Analysis/TargetLibraryInfo.cpp b/contrib/llvm/lib/Analysis/TargetLibraryInfo.cpp index 112118ab77eb..2be5d5caf7c2 100644 --- a/contrib/llvm/lib/Analysis/TargetLibraryInfo.cpp +++ b/contrib/llvm/lib/Analysis/TargetLibraryInfo.cpp @@ -13,6 +13,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/ADT/Triple.h" +#include "llvm/IR/Constants.h" #include "llvm/Support/CommandLine.h" using namespace llvm; @@ -82,24 +83,24 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, if (T.getArch() == Triple::r600 || T.getArch() == Triple::amdgcn) { - TLI.setUnavailable(LibFunc::ldexp); - TLI.setUnavailable(LibFunc::ldexpf); - TLI.setUnavailable(LibFunc::ldexpl); - TLI.setUnavailable(LibFunc::exp10); - TLI.setUnavailable(LibFunc::exp10f); - TLI.setUnavailable(LibFunc::exp10l); - TLI.setUnavailable(LibFunc::log10); - TLI.setUnavailable(LibFunc::log10f); - TLI.setUnavailable(LibFunc::log10l); + TLI.setUnavailable(LibFunc_ldexp); + TLI.setUnavailable(LibFunc_ldexpf); + TLI.setUnavailable(LibFunc_ldexpl); + TLI.setUnavailable(LibFunc_exp10); + TLI.setUnavailable(LibFunc_exp10f); + TLI.setUnavailable(LibFunc_exp10l); + TLI.setUnavailable(LibFunc_log10); + TLI.setUnavailable(LibFunc_log10f); + TLI.setUnavailable(LibFunc_log10l); } // There are no library implementations of mempcy and memset for AMD gpus and // these can be difficult to lower in the backend. if (T.getArch() == Triple::r600 || T.getArch() == Triple::amdgcn) { - TLI.setUnavailable(LibFunc::memcpy); - TLI.setUnavailable(LibFunc::memset); - TLI.setUnavailable(LibFunc::memset_pattern16); + TLI.setUnavailable(LibFunc_memcpy); + TLI.setUnavailable(LibFunc_memset); + TLI.setUnavailable(LibFunc_memset_pattern16); return; } @@ -107,21 +108,21 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, // All versions of watchOS support it. 
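
// Illustrative sketch (not part of this patch): the TargetLibraryInfo hunks
// here are the mechanical rename of the LibFunc::xyz enumerators to
// LibFunc_xyz; this sketch assumes the renamed enum type is spelled plain
// LibFunc, as in upstream LLVM of this era.  Typical consumer-side use of the
// renamed enumerators with the unchanged query API:
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Function.h"
using namespace llvm;

// Resolve a callee to its LibFunc enumerator, then ask whether the current
// target actually provides that routine.
static bool isUsableMemSetPattern16(const Function &Callee,
                                    const TargetLibraryInfo &TLI) {
  LibFunc F;
  return TLI.getLibFunc(Callee, F) && F == LibFunc_memset_pattern16 &&
         TLI.has(F);
}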
if (T.isMacOSX()) { if (T.isMacOSXVersionLT(10, 5)) - TLI.setUnavailable(LibFunc::memset_pattern16); + TLI.setUnavailable(LibFunc_memset_pattern16); } else if (T.isiOS()) { if (T.isOSVersionLT(3, 0)) - TLI.setUnavailable(LibFunc::memset_pattern16); + TLI.setUnavailable(LibFunc_memset_pattern16); } else if (!T.isWatchOS()) { - TLI.setUnavailable(LibFunc::memset_pattern16); + TLI.setUnavailable(LibFunc_memset_pattern16); } if (!hasSinCosPiStret(T)) { - TLI.setUnavailable(LibFunc::sinpi); - TLI.setUnavailable(LibFunc::sinpif); - TLI.setUnavailable(LibFunc::cospi); - TLI.setUnavailable(LibFunc::cospif); - TLI.setUnavailable(LibFunc::sincospi_stret); - TLI.setUnavailable(LibFunc::sincospif_stret); + TLI.setUnavailable(LibFunc_sinpi); + TLI.setUnavailable(LibFunc_sinpif); + TLI.setUnavailable(LibFunc_cospi); + TLI.setUnavailable(LibFunc_cospif); + TLI.setUnavailable(LibFunc_sincospi_stret); + TLI.setUnavailable(LibFunc_sincospif_stret); } if (T.isMacOSX() && T.getArch() == Triple::x86 && @@ -131,179 +132,223 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, // has a $UNIX2003 suffix. The two implementations are identical except // for the return value in some edge cases. However, we don't want to // generate code that depends on the old symbols. - TLI.setAvailableWithName(LibFunc::fwrite, "fwrite$UNIX2003"); - TLI.setAvailableWithName(LibFunc::fputs, "fputs$UNIX2003"); + TLI.setAvailableWithName(LibFunc_fwrite, "fwrite$UNIX2003"); + TLI.setAvailableWithName(LibFunc_fputs, "fputs$UNIX2003"); } // iprintf and friends are only available on XCore and TCE. if (T.getArch() != Triple::xcore && T.getArch() != Triple::tce) { - TLI.setUnavailable(LibFunc::iprintf); - TLI.setUnavailable(LibFunc::siprintf); - TLI.setUnavailable(LibFunc::fiprintf); + TLI.setUnavailable(LibFunc_iprintf); + TLI.setUnavailable(LibFunc_siprintf); + TLI.setUnavailable(LibFunc_fiprintf); } if (T.isOSWindows() && !T.isOSCygMing()) { // Win32 does not support long double - TLI.setUnavailable(LibFunc::acosl); - TLI.setUnavailable(LibFunc::asinl); - TLI.setUnavailable(LibFunc::atanl); - TLI.setUnavailable(LibFunc::atan2l); - TLI.setUnavailable(LibFunc::ceill); - TLI.setUnavailable(LibFunc::copysignl); - TLI.setUnavailable(LibFunc::cosl); - TLI.setUnavailable(LibFunc::coshl); - TLI.setUnavailable(LibFunc::expl); - TLI.setUnavailable(LibFunc::fabsf); // Win32 and Win64 both lack fabsf - TLI.setUnavailable(LibFunc::fabsl); - TLI.setUnavailable(LibFunc::floorl); - TLI.setUnavailable(LibFunc::fmaxl); - TLI.setUnavailable(LibFunc::fminl); - TLI.setUnavailable(LibFunc::fmodl); - TLI.setUnavailable(LibFunc::frexpl); - TLI.setUnavailable(LibFunc::ldexpf); - TLI.setUnavailable(LibFunc::ldexpl); - TLI.setUnavailable(LibFunc::logl); - TLI.setUnavailable(LibFunc::modfl); - TLI.setUnavailable(LibFunc::powl); - TLI.setUnavailable(LibFunc::sinl); - TLI.setUnavailable(LibFunc::sinhl); - TLI.setUnavailable(LibFunc::sqrtl); - TLI.setUnavailable(LibFunc::tanl); - TLI.setUnavailable(LibFunc::tanhl); + TLI.setUnavailable(LibFunc_acosl); + TLI.setUnavailable(LibFunc_asinl); + TLI.setUnavailable(LibFunc_atanl); + TLI.setUnavailable(LibFunc_atan2l); + TLI.setUnavailable(LibFunc_ceill); + TLI.setUnavailable(LibFunc_copysignl); + TLI.setUnavailable(LibFunc_cosl); + TLI.setUnavailable(LibFunc_coshl); + TLI.setUnavailable(LibFunc_expl); + TLI.setUnavailable(LibFunc_fabsf); // Win32 and Win64 both lack fabsf + TLI.setUnavailable(LibFunc_fabsl); + TLI.setUnavailable(LibFunc_floorl); + TLI.setUnavailable(LibFunc_fmaxl); + 
TLI.setUnavailable(LibFunc_fminl); + TLI.setUnavailable(LibFunc_fmodl); + TLI.setUnavailable(LibFunc_frexpl); + TLI.setUnavailable(LibFunc_ldexpf); + TLI.setUnavailable(LibFunc_ldexpl); + TLI.setUnavailable(LibFunc_logl); + TLI.setUnavailable(LibFunc_modfl); + TLI.setUnavailable(LibFunc_powl); + TLI.setUnavailable(LibFunc_sinl); + TLI.setUnavailable(LibFunc_sinhl); + TLI.setUnavailable(LibFunc_sqrtl); + TLI.setUnavailable(LibFunc_tanl); + TLI.setUnavailable(LibFunc_tanhl); // Win32 only has C89 math - TLI.setUnavailable(LibFunc::acosh); - TLI.setUnavailable(LibFunc::acoshf); - TLI.setUnavailable(LibFunc::acoshl); - TLI.setUnavailable(LibFunc::asinh); - TLI.setUnavailable(LibFunc::asinhf); - TLI.setUnavailable(LibFunc::asinhl); - TLI.setUnavailable(LibFunc::atanh); - TLI.setUnavailable(LibFunc::atanhf); - TLI.setUnavailable(LibFunc::atanhl); - TLI.setUnavailable(LibFunc::cbrt); - TLI.setUnavailable(LibFunc::cbrtf); - TLI.setUnavailable(LibFunc::cbrtl); - TLI.setUnavailable(LibFunc::exp2); - TLI.setUnavailable(LibFunc::exp2f); - TLI.setUnavailable(LibFunc::exp2l); - TLI.setUnavailable(LibFunc::expm1); - TLI.setUnavailable(LibFunc::expm1f); - TLI.setUnavailable(LibFunc::expm1l); - TLI.setUnavailable(LibFunc::log2); - TLI.setUnavailable(LibFunc::log2f); - TLI.setUnavailable(LibFunc::log2l); - TLI.setUnavailable(LibFunc::log1p); - TLI.setUnavailable(LibFunc::log1pf); - TLI.setUnavailable(LibFunc::log1pl); - TLI.setUnavailable(LibFunc::logb); - TLI.setUnavailable(LibFunc::logbf); - TLI.setUnavailable(LibFunc::logbl); - TLI.setUnavailable(LibFunc::nearbyint); - TLI.setUnavailable(LibFunc::nearbyintf); - TLI.setUnavailable(LibFunc::nearbyintl); - TLI.setUnavailable(LibFunc::rint); - TLI.setUnavailable(LibFunc::rintf); - TLI.setUnavailable(LibFunc::rintl); - TLI.setUnavailable(LibFunc::round); - TLI.setUnavailable(LibFunc::roundf); - TLI.setUnavailable(LibFunc::roundl); - TLI.setUnavailable(LibFunc::trunc); - TLI.setUnavailable(LibFunc::truncf); - TLI.setUnavailable(LibFunc::truncl); + TLI.setUnavailable(LibFunc_acosh); + TLI.setUnavailable(LibFunc_acoshf); + TLI.setUnavailable(LibFunc_acoshl); + TLI.setUnavailable(LibFunc_asinh); + TLI.setUnavailable(LibFunc_asinhf); + TLI.setUnavailable(LibFunc_asinhl); + TLI.setUnavailable(LibFunc_atanh); + TLI.setUnavailable(LibFunc_atanhf); + TLI.setUnavailable(LibFunc_atanhl); + TLI.setUnavailable(LibFunc_cbrt); + TLI.setUnavailable(LibFunc_cbrtf); + TLI.setUnavailable(LibFunc_cbrtl); + TLI.setUnavailable(LibFunc_exp2); + TLI.setUnavailable(LibFunc_exp2f); + TLI.setUnavailable(LibFunc_exp2l); + TLI.setUnavailable(LibFunc_expm1); + TLI.setUnavailable(LibFunc_expm1f); + TLI.setUnavailable(LibFunc_expm1l); + TLI.setUnavailable(LibFunc_log2); + TLI.setUnavailable(LibFunc_log2f); + TLI.setUnavailable(LibFunc_log2l); + TLI.setUnavailable(LibFunc_log1p); + TLI.setUnavailable(LibFunc_log1pf); + TLI.setUnavailable(LibFunc_log1pl); + TLI.setUnavailable(LibFunc_logb); + TLI.setUnavailable(LibFunc_logbf); + TLI.setUnavailable(LibFunc_logbl); + TLI.setUnavailable(LibFunc_nearbyint); + TLI.setUnavailable(LibFunc_nearbyintf); + TLI.setUnavailable(LibFunc_nearbyintl); + TLI.setUnavailable(LibFunc_rint); + TLI.setUnavailable(LibFunc_rintf); + TLI.setUnavailable(LibFunc_rintl); + TLI.setUnavailable(LibFunc_round); + TLI.setUnavailable(LibFunc_roundf); + TLI.setUnavailable(LibFunc_roundl); + TLI.setUnavailable(LibFunc_trunc); + TLI.setUnavailable(LibFunc_truncf); + TLI.setUnavailable(LibFunc_truncl); // Win32 provides some C99 math with mangled names - 
TLI.setAvailableWithName(LibFunc::copysign, "_copysign"); + TLI.setAvailableWithName(LibFunc_copysign, "_copysign"); if (T.getArch() == Triple::x86) { // Win32 on x86 implements single-precision math functions as macros - TLI.setUnavailable(LibFunc::acosf); - TLI.setUnavailable(LibFunc::asinf); - TLI.setUnavailable(LibFunc::atanf); - TLI.setUnavailable(LibFunc::atan2f); - TLI.setUnavailable(LibFunc::ceilf); - TLI.setUnavailable(LibFunc::copysignf); - TLI.setUnavailable(LibFunc::cosf); - TLI.setUnavailable(LibFunc::coshf); - TLI.setUnavailable(LibFunc::expf); - TLI.setUnavailable(LibFunc::floorf); - TLI.setUnavailable(LibFunc::fminf); - TLI.setUnavailable(LibFunc::fmaxf); - TLI.setUnavailable(LibFunc::fmodf); - TLI.setUnavailable(LibFunc::logf); - TLI.setUnavailable(LibFunc::log10f); - TLI.setUnavailable(LibFunc::modff); - TLI.setUnavailable(LibFunc::powf); - TLI.setUnavailable(LibFunc::sinf); - TLI.setUnavailable(LibFunc::sinhf); - TLI.setUnavailable(LibFunc::sqrtf); - TLI.setUnavailable(LibFunc::tanf); - TLI.setUnavailable(LibFunc::tanhf); + TLI.setUnavailable(LibFunc_acosf); + TLI.setUnavailable(LibFunc_asinf); + TLI.setUnavailable(LibFunc_atanf); + TLI.setUnavailable(LibFunc_atan2f); + TLI.setUnavailable(LibFunc_ceilf); + TLI.setUnavailable(LibFunc_copysignf); + TLI.setUnavailable(LibFunc_cosf); + TLI.setUnavailable(LibFunc_coshf); + TLI.setUnavailable(LibFunc_expf); + TLI.setUnavailable(LibFunc_floorf); + TLI.setUnavailable(LibFunc_fminf); + TLI.setUnavailable(LibFunc_fmaxf); + TLI.setUnavailable(LibFunc_fmodf); + TLI.setUnavailable(LibFunc_logf); + TLI.setUnavailable(LibFunc_log10f); + TLI.setUnavailable(LibFunc_modff); + TLI.setUnavailable(LibFunc_powf); + TLI.setUnavailable(LibFunc_sinf); + TLI.setUnavailable(LibFunc_sinhf); + TLI.setUnavailable(LibFunc_sqrtf); + TLI.setUnavailable(LibFunc_tanf); + TLI.setUnavailable(LibFunc_tanhf); } + // These definitions are due to math-finite.h header on Linux + TLI.setUnavailable(LibFunc_acos_finite); + TLI.setUnavailable(LibFunc_acosf_finite); + TLI.setUnavailable(LibFunc_acosl_finite); + TLI.setUnavailable(LibFunc_acosh_finite); + TLI.setUnavailable(LibFunc_acoshf_finite); + TLI.setUnavailable(LibFunc_acoshl_finite); + TLI.setUnavailable(LibFunc_asin_finite); + TLI.setUnavailable(LibFunc_asinf_finite); + TLI.setUnavailable(LibFunc_asinl_finite); + TLI.setUnavailable(LibFunc_atan2_finite); + TLI.setUnavailable(LibFunc_atan2f_finite); + TLI.setUnavailable(LibFunc_atan2l_finite); + TLI.setUnavailable(LibFunc_atanh_finite); + TLI.setUnavailable(LibFunc_atanhf_finite); + TLI.setUnavailable(LibFunc_atanhl_finite); + TLI.setUnavailable(LibFunc_cosh_finite); + TLI.setUnavailable(LibFunc_coshf_finite); + TLI.setUnavailable(LibFunc_coshl_finite); + TLI.setUnavailable(LibFunc_exp10_finite); + TLI.setUnavailable(LibFunc_exp10f_finite); + TLI.setUnavailable(LibFunc_exp10l_finite); + TLI.setUnavailable(LibFunc_exp2_finite); + TLI.setUnavailable(LibFunc_exp2f_finite); + TLI.setUnavailable(LibFunc_exp2l_finite); + TLI.setUnavailable(LibFunc_exp_finite); + TLI.setUnavailable(LibFunc_expf_finite); + TLI.setUnavailable(LibFunc_expl_finite); + TLI.setUnavailable(LibFunc_log10_finite); + TLI.setUnavailable(LibFunc_log10f_finite); + TLI.setUnavailable(LibFunc_log10l_finite); + TLI.setUnavailable(LibFunc_log2_finite); + TLI.setUnavailable(LibFunc_log2f_finite); + TLI.setUnavailable(LibFunc_log2l_finite); + TLI.setUnavailable(LibFunc_log_finite); + TLI.setUnavailable(LibFunc_logf_finite); + TLI.setUnavailable(LibFunc_logl_finite); + 
TLI.setUnavailable(LibFunc_pow_finite); + TLI.setUnavailable(LibFunc_powf_finite); + TLI.setUnavailable(LibFunc_powl_finite); + TLI.setUnavailable(LibFunc_sinh_finite); + TLI.setUnavailable(LibFunc_sinhf_finite); + TLI.setUnavailable(LibFunc_sinhl_finite); + // Win32 does *not* provide provide these functions, but they are // generally available on POSIX-compliant systems: - TLI.setUnavailable(LibFunc::access); - TLI.setUnavailable(LibFunc::bcmp); - TLI.setUnavailable(LibFunc::bcopy); - TLI.setUnavailable(LibFunc::bzero); - TLI.setUnavailable(LibFunc::chmod); - TLI.setUnavailable(LibFunc::chown); - TLI.setUnavailable(LibFunc::closedir); - TLI.setUnavailable(LibFunc::ctermid); - TLI.setUnavailable(LibFunc::fdopen); - TLI.setUnavailable(LibFunc::ffs); - TLI.setUnavailable(LibFunc::fileno); - TLI.setUnavailable(LibFunc::flockfile); - TLI.setUnavailable(LibFunc::fseeko); - TLI.setUnavailable(LibFunc::fstat); - TLI.setUnavailable(LibFunc::fstatvfs); - TLI.setUnavailable(LibFunc::ftello); - TLI.setUnavailable(LibFunc::ftrylockfile); - TLI.setUnavailable(LibFunc::funlockfile); - TLI.setUnavailable(LibFunc::getc_unlocked); - TLI.setUnavailable(LibFunc::getitimer); - TLI.setUnavailable(LibFunc::getlogin_r); - TLI.setUnavailable(LibFunc::getpwnam); - TLI.setUnavailable(LibFunc::gettimeofday); - TLI.setUnavailable(LibFunc::htonl); - TLI.setUnavailable(LibFunc::htons); - TLI.setUnavailable(LibFunc::lchown); - TLI.setUnavailable(LibFunc::lstat); - TLI.setUnavailable(LibFunc::memccpy); - TLI.setUnavailable(LibFunc::mkdir); - TLI.setUnavailable(LibFunc::ntohl); - TLI.setUnavailable(LibFunc::ntohs); - TLI.setUnavailable(LibFunc::open); - TLI.setUnavailable(LibFunc::opendir); - TLI.setUnavailable(LibFunc::pclose); - TLI.setUnavailable(LibFunc::popen); - TLI.setUnavailable(LibFunc::pread); - TLI.setUnavailable(LibFunc::pwrite); - TLI.setUnavailable(LibFunc::read); - TLI.setUnavailable(LibFunc::readlink); - TLI.setUnavailable(LibFunc::realpath); - TLI.setUnavailable(LibFunc::rmdir); - TLI.setUnavailable(LibFunc::setitimer); - TLI.setUnavailable(LibFunc::stat); - TLI.setUnavailable(LibFunc::statvfs); - TLI.setUnavailable(LibFunc::stpcpy); - TLI.setUnavailable(LibFunc::stpncpy); - TLI.setUnavailable(LibFunc::strcasecmp); - TLI.setUnavailable(LibFunc::strncasecmp); - TLI.setUnavailable(LibFunc::times); - TLI.setUnavailable(LibFunc::uname); - TLI.setUnavailable(LibFunc::unlink); - TLI.setUnavailable(LibFunc::unsetenv); - TLI.setUnavailable(LibFunc::utime); - TLI.setUnavailable(LibFunc::utimes); - TLI.setUnavailable(LibFunc::write); + TLI.setUnavailable(LibFunc_access); + TLI.setUnavailable(LibFunc_bcmp); + TLI.setUnavailable(LibFunc_bcopy); + TLI.setUnavailable(LibFunc_bzero); + TLI.setUnavailable(LibFunc_chmod); + TLI.setUnavailable(LibFunc_chown); + TLI.setUnavailable(LibFunc_closedir); + TLI.setUnavailable(LibFunc_ctermid); + TLI.setUnavailable(LibFunc_fdopen); + TLI.setUnavailable(LibFunc_ffs); + TLI.setUnavailable(LibFunc_fileno); + TLI.setUnavailable(LibFunc_flockfile); + TLI.setUnavailable(LibFunc_fseeko); + TLI.setUnavailable(LibFunc_fstat); + TLI.setUnavailable(LibFunc_fstatvfs); + TLI.setUnavailable(LibFunc_ftello); + TLI.setUnavailable(LibFunc_ftrylockfile); + TLI.setUnavailable(LibFunc_funlockfile); + TLI.setUnavailable(LibFunc_getc_unlocked); + TLI.setUnavailable(LibFunc_getitimer); + TLI.setUnavailable(LibFunc_getlogin_r); + TLI.setUnavailable(LibFunc_getpwnam); + TLI.setUnavailable(LibFunc_gettimeofday); + TLI.setUnavailable(LibFunc_htonl); + TLI.setUnavailable(LibFunc_htons); + 
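The __*_finite entries being disabled here come from glibc's bits/math-finite.h, which finite-math-only builds use to redirect the standard math calls to __*_finite entry points; the Windows CRT has no such symbols, hence the blanket setUnavailable calls. A rough sketch of the redirect mechanism, using a plain GCC asm label instead of glibc's actual __REDIRECT_NTH macros (simplified illustration only, not part of this change):

// Simplified idea only; glibc guards the real redirect with __FINITE_MATH_ONLY__.
extern "C" double __exp_finite(double);
extern "C" double exp(double) __asm__("__exp_finite"); // calls to exp() bind to __exp_finite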
TLI.setUnavailable(LibFunc_lchown); + TLI.setUnavailable(LibFunc_lstat); + TLI.setUnavailable(LibFunc_memccpy); + TLI.setUnavailable(LibFunc_mkdir); + TLI.setUnavailable(LibFunc_ntohl); + TLI.setUnavailable(LibFunc_ntohs); + TLI.setUnavailable(LibFunc_open); + TLI.setUnavailable(LibFunc_opendir); + TLI.setUnavailable(LibFunc_pclose); + TLI.setUnavailable(LibFunc_popen); + TLI.setUnavailable(LibFunc_pread); + TLI.setUnavailable(LibFunc_pwrite); + TLI.setUnavailable(LibFunc_read); + TLI.setUnavailable(LibFunc_readlink); + TLI.setUnavailable(LibFunc_realpath); + TLI.setUnavailable(LibFunc_rmdir); + TLI.setUnavailable(LibFunc_setitimer); + TLI.setUnavailable(LibFunc_stat); + TLI.setUnavailable(LibFunc_statvfs); + TLI.setUnavailable(LibFunc_stpcpy); + TLI.setUnavailable(LibFunc_stpncpy); + TLI.setUnavailable(LibFunc_strcasecmp); + TLI.setUnavailable(LibFunc_strncasecmp); + TLI.setUnavailable(LibFunc_times); + TLI.setUnavailable(LibFunc_uname); + TLI.setUnavailable(LibFunc_unlink); + TLI.setUnavailable(LibFunc_unsetenv); + TLI.setUnavailable(LibFunc_utime); + TLI.setUnavailable(LibFunc_utimes); + TLI.setUnavailable(LibFunc_write); // Win32 does *not* provide provide these functions, but they are // specified by C99: - TLI.setUnavailable(LibFunc::atoll); - TLI.setUnavailable(LibFunc::frexpf); - TLI.setUnavailable(LibFunc::llabs); + TLI.setUnavailable(LibFunc_atoll); + TLI.setUnavailable(LibFunc_frexpf); + TLI.setUnavailable(LibFunc_llabs); } switch (T.getOS()) { @@ -311,28 +356,28 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, // exp10 and exp10f are not available on OS X until 10.9 and iOS until 7.0 // and their names are __exp10 and __exp10f. exp10l is not available on // OS X or iOS. - TLI.setUnavailable(LibFunc::exp10l); + TLI.setUnavailable(LibFunc_exp10l); if (T.isMacOSXVersionLT(10, 9)) { - TLI.setUnavailable(LibFunc::exp10); - TLI.setUnavailable(LibFunc::exp10f); + TLI.setUnavailable(LibFunc_exp10); + TLI.setUnavailable(LibFunc_exp10f); } else { - TLI.setAvailableWithName(LibFunc::exp10, "__exp10"); - TLI.setAvailableWithName(LibFunc::exp10f, "__exp10f"); + TLI.setAvailableWithName(LibFunc_exp10, "__exp10"); + TLI.setAvailableWithName(LibFunc_exp10f, "__exp10f"); } break; case Triple::IOS: case Triple::TvOS: case Triple::WatchOS: - TLI.setUnavailable(LibFunc::exp10l); + TLI.setUnavailable(LibFunc_exp10l); if (!T.isWatchOS() && (T.isOSVersionLT(7, 0) || (T.isOSVersionLT(9, 0) && (T.getArch() == Triple::x86 || T.getArch() == Triple::x86_64)))) { - TLI.setUnavailable(LibFunc::exp10); - TLI.setUnavailable(LibFunc::exp10f); + TLI.setUnavailable(LibFunc_exp10); + TLI.setUnavailable(LibFunc_exp10f); } else { - TLI.setAvailableWithName(LibFunc::exp10, "__exp10"); - TLI.setAvailableWithName(LibFunc::exp10f, "__exp10f"); + TLI.setAvailableWithName(LibFunc_exp10, "__exp10"); + TLI.setAvailableWithName(LibFunc_exp10f, "__exp10f"); } break; case Triple::Linux: @@ -344,9 +389,9 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, // Fall through to disable all of them. 
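Roughly how these availability tables are consumed downstream (a sketch, not part of this change; the helper name is invented): passes check availability with TargetLibraryInfo::has() and then emit the per-target name, so on newer Darwin an exp10 libcall comes out as __exp10.

#include "llvm/Analysis/TargetLibraryInfo.h"
using namespace llvm;

// Returns the callee name to use for exp10, or an empty StringRef if the
// target provides no such libcall (e.g. the 'default' case below).
static StringRef exp10NameForTarget(const TargetLibraryInfo &TLI) {
  if (!TLI.has(LibFunc_exp10))
    return StringRef();
  return TLI.getName(LibFunc_exp10); // "__exp10" on newer Darwin, "exp10" elsewhere
}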
LLVM_FALLTHROUGH; default: - TLI.setUnavailable(LibFunc::exp10); - TLI.setUnavailable(LibFunc::exp10f); - TLI.setUnavailable(LibFunc::exp10l); + TLI.setUnavailable(LibFunc_exp10); + TLI.setUnavailable(LibFunc_exp10f); + TLI.setUnavailable(LibFunc_exp10l); } // ffsl is available on at least Darwin, Mac OS X, iOS, FreeBSD, and @@ -364,7 +409,7 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, case Triple::Linux: break; default: - TLI.setUnavailable(LibFunc::ffsl); + TLI.setUnavailable(LibFunc_ffsl); } // ffsll is available on at least FreeBSD and Linux (GLIBC): @@ -380,7 +425,7 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, case Triple::Linux: break; default: - TLI.setUnavailable(LibFunc::ffsll); + TLI.setUnavailable(LibFunc_ffsll); } // The following functions are available on at least FreeBSD: @@ -388,30 +433,30 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, // http://svn.freebsd.org/base/head/lib/libc/string/flsl.c // http://svn.freebsd.org/base/head/lib/libc/string/flsll.c if (!T.isOSFreeBSD()) { - TLI.setUnavailable(LibFunc::fls); - TLI.setUnavailable(LibFunc::flsl); - TLI.setUnavailable(LibFunc::flsll); + TLI.setUnavailable(LibFunc_fls); + TLI.setUnavailable(LibFunc_flsl); + TLI.setUnavailable(LibFunc_flsll); } // The following functions are available on at least Linux: if (!T.isOSLinux()) { - TLI.setUnavailable(LibFunc::dunder_strdup); - TLI.setUnavailable(LibFunc::dunder_strtok_r); - TLI.setUnavailable(LibFunc::dunder_isoc99_scanf); - TLI.setUnavailable(LibFunc::dunder_isoc99_sscanf); - TLI.setUnavailable(LibFunc::under_IO_getc); - TLI.setUnavailable(LibFunc::under_IO_putc); - TLI.setUnavailable(LibFunc::memalign); - TLI.setUnavailable(LibFunc::fopen64); - TLI.setUnavailable(LibFunc::fseeko64); - TLI.setUnavailable(LibFunc::fstat64); - TLI.setUnavailable(LibFunc::fstatvfs64); - TLI.setUnavailable(LibFunc::ftello64); - TLI.setUnavailable(LibFunc::lstat64); - TLI.setUnavailable(LibFunc::open64); - TLI.setUnavailable(LibFunc::stat64); - TLI.setUnavailable(LibFunc::statvfs64); - TLI.setUnavailable(LibFunc::tmpfile64); + TLI.setUnavailable(LibFunc_dunder_strdup); + TLI.setUnavailable(LibFunc_dunder_strtok_r); + TLI.setUnavailable(LibFunc_dunder_isoc99_scanf); + TLI.setUnavailable(LibFunc_dunder_isoc99_sscanf); + TLI.setUnavailable(LibFunc_under_IO_getc); + TLI.setUnavailable(LibFunc_under_IO_putc); + TLI.setUnavailable(LibFunc_memalign); + TLI.setUnavailable(LibFunc_fopen64); + TLI.setUnavailable(LibFunc_fseeko64); + TLI.setUnavailable(LibFunc_fstat64); + TLI.setUnavailable(LibFunc_fstatvfs64); + TLI.setUnavailable(LibFunc_ftello64); + TLI.setUnavailable(LibFunc_lstat64); + TLI.setUnavailable(LibFunc_open64); + TLI.setUnavailable(LibFunc_stat64); + TLI.setUnavailable(LibFunc_statvfs64); + TLI.setUnavailable(LibFunc_tmpfile64); } // As currently implemented in clang, NVPTX code has no standard library to @@ -427,9 +472,9 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, // optimizations, so this situation should be fixed. if (T.isNVPTX()) { TLI.disableAllFunctions(); - TLI.setAvailable(LibFunc::nvvm_reflect); + TLI.setAvailable(LibFunc_nvvm_reflect); } else { - TLI.setUnavailable(LibFunc::nvvm_reflect); + TLI.setUnavailable(LibFunc_nvvm_reflect); } TLI.addVectorizableFunctionsFromVecLib(ClVectorLibrary); @@ -496,13 +541,13 @@ static StringRef sanitizeFunctionName(StringRef funcName) { // Check for \01 prefix that is used to mangle __asm declarations and // strip it if present. 
- return GlobalValue::getRealLinkageName(funcName); + return GlobalValue::dropLLVMManglingEscape(funcName); } bool TargetLibraryInfoImpl::getLibFunc(StringRef funcName, - LibFunc::Func &F) const { + LibFunc &F) const { StringRef const *Start = &StandardNames[0]; - StringRef const *End = &StandardNames[LibFunc::NumLibFuncs]; + StringRef const *End = &StandardNames[NumLibFuncs]; funcName = sanitizeFunctionName(funcName); if (funcName.empty()) @@ -513,14 +558,14 @@ bool TargetLibraryInfoImpl::getLibFunc(StringRef funcName, return LHS < RHS; }); if (I != End && *I == funcName) { - F = (LibFunc::Func)(I - Start); + F = (LibFunc)(I - Start); return true; } return false; } bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy, - LibFunc::Func F, + LibFunc F, const DataLayout *DL) const { LLVMContext &Ctx = FTy.getContext(); Type *PCharTy = Type::getInt8PtrTy(Ctx); @@ -531,504 +576,706 @@ bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy, unsigned NumParams = FTy.getNumParams(); switch (F) { - case LibFunc::strlen: + case LibFunc_strlen: return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() && FTy.getReturnType()->isIntegerTy()); - case LibFunc::strchr: - case LibFunc::strrchr: + case LibFunc_strchr: + case LibFunc_strrchr: return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0) == FTy.getReturnType() && FTy.getParamType(1)->isIntegerTy()); - case LibFunc::strtol: - case LibFunc::strtod: - case LibFunc::strtof: - case LibFunc::strtoul: - case LibFunc::strtoll: - case LibFunc::strtold: - case LibFunc::strtoull: + case LibFunc_strtol: + case LibFunc_strtod: + case LibFunc_strtof: + case LibFunc_strtoul: + case LibFunc_strtoll: + case LibFunc_strtold: + case LibFunc_strtoull: return ((NumParams == 2 || NumParams == 3) && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::strcat: + case LibFunc_strcat: return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0) == FTy.getReturnType() && FTy.getParamType(1) == FTy.getReturnType()); - case LibFunc::strncat: + case LibFunc_strncat: return (NumParams == 3 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0) == FTy.getReturnType() && FTy.getParamType(1) == FTy.getReturnType() && FTy.getParamType(2)->isIntegerTy()); - case LibFunc::strcpy_chk: - case LibFunc::stpcpy_chk: + case LibFunc_strcpy_chk: + case LibFunc_stpcpy_chk: --NumParams; if (!IsSizeTTy(FTy.getParamType(NumParams))) return false; LLVM_FALLTHROUGH; - case LibFunc::strcpy: - case LibFunc::stpcpy: + case LibFunc_strcpy: + case LibFunc_stpcpy: return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(0) && FTy.getParamType(0) == FTy.getParamType(1) && FTy.getParamType(0) == PCharTy); - case LibFunc::strncpy_chk: - case LibFunc::stpncpy_chk: + case LibFunc_strncpy_chk: + case LibFunc_stpncpy_chk: --NumParams; if (!IsSizeTTy(FTy.getParamType(NumParams))) return false; LLVM_FALLTHROUGH; - case LibFunc::strncpy: - case LibFunc::stpncpy: + case LibFunc_strncpy: + case LibFunc_stpncpy: return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && FTy.getParamType(0) == FTy.getParamType(1) && FTy.getParamType(0) == PCharTy && FTy.getParamType(2)->isIntegerTy()); - case LibFunc::strxfrm: + case LibFunc_strxfrm: return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::strcmp: + case LibFunc_strcmp: return (NumParams == 2 && FTy.getReturnType()->isIntegerTy(32) && 
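For context, a minimal sketch of the caller-side pattern after the LibFunc::Func -> LibFunc rename (the surrounding helper is invented; only the TargetLibraryInfo calls are real):

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Function.h"
using namespace llvm;

static bool isRecognizedMemcpy(const Function &Callee, const TargetLibraryInfo &TLI) {
  LibFunc TheLibFunc; // spelled LibFunc::Func before this change
  // The Function overload also runs isValidProtoForLibFunc, so a declaration
  // with the wrong prototype is not treated as the libcall.
  return TLI.getLibFunc(Callee, TheLibFunc) && TheLibFunc == LibFunc_memcpy &&
         TLI.has(TheLibFunc);
}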
FTy.getParamType(0)->isPointerTy() && FTy.getParamType(0) == FTy.getParamType(1)); - case LibFunc::strncmp: + case LibFunc_strncmp: return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(0) == FTy.getParamType(1) && FTy.getParamType(2)->isIntegerTy()); - case LibFunc::strspn: - case LibFunc::strcspn: + case LibFunc_strspn: + case LibFunc_strcspn: return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(0) == FTy.getParamType(1) && FTy.getReturnType()->isIntegerTy()); - case LibFunc::strcoll: - case LibFunc::strcasecmp: - case LibFunc::strncasecmp: + case LibFunc_strcoll: + case LibFunc_strcasecmp: + case LibFunc_strncasecmp: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::strstr: + case LibFunc_strstr: return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::strpbrk: + case LibFunc_strpbrk: return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && FTy.getReturnType() == FTy.getParamType(0) && FTy.getParamType(0) == FTy.getParamType(1)); - case LibFunc::strtok: - case LibFunc::strtok_r: + case LibFunc_strtok: + case LibFunc_strtok_r: return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::scanf: - case LibFunc::setbuf: - case LibFunc::setvbuf: + case LibFunc_scanf: + case LibFunc_setbuf: + case LibFunc_setvbuf: return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::strdup: - case LibFunc::strndup: + case LibFunc_strdup: + case LibFunc_strndup: return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0)->isPointerTy()); - case LibFunc::sscanf: - case LibFunc::stat: - case LibFunc::statvfs: - case LibFunc::sprintf: + case LibFunc_sscanf: + case LibFunc_stat: + case LibFunc_statvfs: + case LibFunc_siprintf: + case LibFunc_sprintf: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::snprintf: + case LibFunc_snprintf: return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(2)->isPointerTy()); - case LibFunc::setitimer: + case LibFunc_setitimer: return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && FTy.getParamType(2)->isPointerTy()); - case LibFunc::system: + case LibFunc_system: return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::malloc: + case LibFunc_malloc: return (NumParams == 1 && FTy.getReturnType()->isPointerTy()); - case LibFunc::memcmp: - return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && - FTy.getParamType(1)->isPointerTy() && - FTy.getReturnType()->isIntegerTy(32)); + case LibFunc_memcmp: + return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); - case LibFunc::memchr: - case LibFunc::memrchr: - return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + case LibFunc_memchr: + case LibFunc_memrchr: + return (NumParams == 3 && FTy.getReturnType()->isPointerTy() && + FTy.getReturnType() == FTy.getParamType(0) && FTy.getParamType(1)->isIntegerTy(32) && - FTy.getParamType(2)->isIntegerTy() && - FTy.getReturnType()->isPointerTy()); - case LibFunc::modf: - case LibFunc::modff: - case LibFunc::modfl: + IsSizeTTy(FTy.getParamType(2))); + case LibFunc_modf: + case LibFunc_modff: + case LibFunc_modfl: return (NumParams >= 2 && 
FTy.getParamType(1)->isPointerTy()); - case LibFunc::memcpy_chk: - case LibFunc::memmove_chk: + case LibFunc_memcpy_chk: + case LibFunc_memmove_chk: --NumParams; if (!IsSizeTTy(FTy.getParamType(NumParams))) return false; LLVM_FALLTHROUGH; - case LibFunc::memcpy: - case LibFunc::mempcpy: - case LibFunc::memmove: + case LibFunc_memcpy: + case LibFunc_mempcpy: + case LibFunc_memmove: return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy() && IsSizeTTy(FTy.getParamType(2))); - case LibFunc::memset_chk: + case LibFunc_memset_chk: --NumParams; if (!IsSizeTTy(FTy.getParamType(NumParams))) return false; LLVM_FALLTHROUGH; - case LibFunc::memset: + case LibFunc_memset: return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isIntegerTy() && IsSizeTTy(FTy.getParamType(2))); - case LibFunc::memccpy: + case LibFunc_memccpy: return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::memalign: + case LibFunc_memalign: return (FTy.getReturnType()->isPointerTy()); - case LibFunc::realloc: - return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && - FTy.getReturnType()->isPointerTy()); - case LibFunc::read: + case LibFunc_realloc: + case LibFunc_reallocf: + return (NumParams == 2 && FTy.getReturnType() == PCharTy && + FTy.getParamType(0) == FTy.getReturnType() && + IsSizeTTy(FTy.getParamType(1))); + case LibFunc_read: return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::rewind: - case LibFunc::rmdir: - case LibFunc::remove: - case LibFunc::realpath: + case LibFunc_rewind: + case LibFunc_rmdir: + case LibFunc_remove: + case LibFunc_realpath: return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::rename: + case LibFunc_rename: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::readlink: + case LibFunc_readlink: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::write: + case LibFunc_write: return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::bcopy: - case LibFunc::bcmp: + case LibFunc_bcopy: + case LibFunc_bcmp: return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::bzero: + case LibFunc_bzero: return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::calloc: + case LibFunc_calloc: return (NumParams == 2 && FTy.getReturnType()->isPointerTy()); - case LibFunc::atof: - case LibFunc::atoi: - case LibFunc::atol: - case LibFunc::atoll: - case LibFunc::ferror: - case LibFunc::getenv: - case LibFunc::getpwnam: - case LibFunc::pclose: - case LibFunc::perror: - case LibFunc::printf: - case LibFunc::puts: - case LibFunc::uname: - case LibFunc::under_IO_getc: - case LibFunc::unlink: - case LibFunc::unsetenv: + case LibFunc_atof: + case LibFunc_atoi: + case LibFunc_atol: + case LibFunc_atoll: + case LibFunc_ferror: + case LibFunc_getenv: + case LibFunc_getpwnam: + case LibFunc_iprintf: + case LibFunc_pclose: + case LibFunc_perror: + case LibFunc_printf: + case LibFunc_puts: + case LibFunc_uname: + case LibFunc_under_IO_getc: + case LibFunc_unlink: + case LibFunc_unsetenv: return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::chmod: - case LibFunc::chown: - case LibFunc::clearerr: - case LibFunc::closedir: - case 
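The *_chk cases above share their check with the plain variants by dropping the trailing object-size parameter and falling through. Condensed into one invented helper for illustration (IsSizeTTy stands in for the lambda defined earlier in this function):

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/DerivedTypes.h"

// Sketch: __memcpy_chk(dst, src, n, objsize) is the memcpy prototype plus one
// trailing size_t, so strip it and reuse the same check.
static bool isMemcpyLikeProto(const llvm::FunctionType &FTy, bool IsChkVariant,
                              llvm::function_ref<bool(llvm::Type *)> IsSizeTTy) {
  unsigned NumParams = FTy.getNumParams();
  if (IsChkVariant) {
    if (NumParams == 0 || !IsSizeTTy(FTy.getParamType(NumParams - 1)))
      return false;
    --NumParams;
  }
  return NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) &&
         FTy.getParamType(0)->isPointerTy() &&
         FTy.getParamType(1)->isPointerTy() && IsSizeTTy(FTy.getParamType(2));
}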
LibFunc::ctermid: - case LibFunc::fclose: - case LibFunc::feof: - case LibFunc::fflush: - case LibFunc::fgetc: - case LibFunc::fileno: - case LibFunc::flockfile: - case LibFunc::free: - case LibFunc::fseek: - case LibFunc::fseeko64: - case LibFunc::fseeko: - case LibFunc::fsetpos: - case LibFunc::ftell: - case LibFunc::ftello64: - case LibFunc::ftello: - case LibFunc::ftrylockfile: - case LibFunc::funlockfile: - case LibFunc::getc: - case LibFunc::getc_unlocked: - case LibFunc::getlogin_r: - case LibFunc::mkdir: - case LibFunc::mktime: - case LibFunc::times: + case LibFunc_access: + case LibFunc_chmod: + case LibFunc_chown: + case LibFunc_clearerr: + case LibFunc_closedir: + case LibFunc_ctermid: + case LibFunc_fclose: + case LibFunc_feof: + case LibFunc_fflush: + case LibFunc_fgetc: + case LibFunc_fileno: + case LibFunc_flockfile: + case LibFunc_free: + case LibFunc_fseek: + case LibFunc_fseeko64: + case LibFunc_fseeko: + case LibFunc_fsetpos: + case LibFunc_ftell: + case LibFunc_ftello64: + case LibFunc_ftello: + case LibFunc_ftrylockfile: + case LibFunc_funlockfile: + case LibFunc_getc: + case LibFunc_getc_unlocked: + case LibFunc_getlogin_r: + case LibFunc_mkdir: + case LibFunc_mktime: + case LibFunc_times: return (NumParams != 0 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::access: - return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::fopen: + case LibFunc_fopen: return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::fdopen: + case LibFunc_fdopen: return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::fputc: - case LibFunc::fstat: - case LibFunc::frexp: - case LibFunc::frexpf: - case LibFunc::frexpl: - case LibFunc::fstatvfs: + case LibFunc_fputc: + case LibFunc_fstat: + case LibFunc_frexp: + case LibFunc_frexpf: + case LibFunc_frexpl: + case LibFunc_fstatvfs: return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::fgets: + case LibFunc_fgets: return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(2)->isPointerTy()); - case LibFunc::fread: + case LibFunc_fread: return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(3)->isPointerTy()); - case LibFunc::fwrite: + case LibFunc_fwrite: return (NumParams == 4 && FTy.getReturnType()->isIntegerTy() && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isIntegerTy() && FTy.getParamType(2)->isIntegerTy() && FTy.getParamType(3)->isPointerTy()); - case LibFunc::fputs: + case LibFunc_fputs: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::fscanf: - case LibFunc::fprintf: - return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + case LibFunc_fscanf: + case LibFunc_fiprintf: + case LibFunc_fprintf: + return (NumParams >= 2 && FTy.getReturnType()->isIntegerTy() && + FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::fgetpos: + case LibFunc_fgetpos: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::gets: - case LibFunc::getchar: - case LibFunc::getitimer: + case LibFunc_getchar: + return (NumParams == 0 && FTy.getReturnType()->isIntegerTy()); + case LibFunc_gets: + return (NumParams == 1 && FTy.getParamType(0) == PCharTy); + case LibFunc_getitimer: return (NumParams == 2 && 
FTy.getParamType(1)->isPointerTy()); - case LibFunc::ungetc: + case LibFunc_ungetc: return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::utime: - case LibFunc::utimes: + case LibFunc_utime: + case LibFunc_utimes: return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::putc: + case LibFunc_putc: return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::pread: - case LibFunc::pwrite: + case LibFunc_pread: + case LibFunc_pwrite: return (NumParams == 4 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::popen: + case LibFunc_popen: return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::vscanf: + case LibFunc_vscanf: return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::vsscanf: + case LibFunc_vsscanf: return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && FTy.getParamType(2)->isPointerTy()); - case LibFunc::vfscanf: + case LibFunc_vfscanf: return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && FTy.getParamType(2)->isPointerTy()); - case LibFunc::valloc: + case LibFunc_valloc: return (FTy.getReturnType()->isPointerTy()); - case LibFunc::vprintf: + case LibFunc_vprintf: return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::vfprintf: - case LibFunc::vsprintf: + case LibFunc_vfprintf: + case LibFunc_vsprintf: return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::vsnprintf: + case LibFunc_vsnprintf: return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(2)->isPointerTy()); - case LibFunc::open: + case LibFunc_open: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::opendir: + case LibFunc_opendir: return (NumParams == 1 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0)->isPointerTy()); - case LibFunc::tmpfile: + case LibFunc_tmpfile: return (FTy.getReturnType()->isPointerTy()); - case LibFunc::htonl: - case LibFunc::htons: - case LibFunc::ntohl: - case LibFunc::ntohs: - case LibFunc::lstat: + case LibFunc_htonl: + case LibFunc_ntohl: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getReturnType() == FTy.getParamType(0)); + case LibFunc_htons: + case LibFunc_ntohs: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(16) && + FTy.getReturnType() == FTy.getParamType(0)); + case LibFunc_lstat: return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::lchown: + case LibFunc_lchown: return (NumParams == 3 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::qsort: + case LibFunc_qsort: return (NumParams == 4 && FTy.getParamType(3)->isPointerTy()); - case LibFunc::dunder_strdup: - case LibFunc::dunder_strndup: + case LibFunc_dunder_strdup: + case LibFunc_dunder_strndup: return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0)->isPointerTy()); - case LibFunc::dunder_strtok_r: + case LibFunc_dunder_strtok_r: return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::under_IO_putc: + case LibFunc_under_IO_putc: return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::dunder_isoc99_scanf: + case LibFunc_dunder_isoc99_scanf: return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::stat64: - case LibFunc::lstat64: - case 
LibFunc::statvfs64: + case LibFunc_stat64: + case LibFunc_lstat64: + case LibFunc_statvfs64: return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::dunder_isoc99_sscanf: + case LibFunc_dunder_isoc99_sscanf: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::fopen64: + case LibFunc_fopen64: return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::tmpfile64: + case LibFunc_tmpfile64: return (FTy.getReturnType()->isPointerTy()); - case LibFunc::fstat64: - case LibFunc::fstatvfs64: + case LibFunc_fstat64: + case LibFunc_fstatvfs64: return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); - case LibFunc::open64: + case LibFunc_open64: return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy()); - case LibFunc::gettimeofday: + case LibFunc_gettimeofday: return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); - case LibFunc::Znwj: // new(unsigned int); - case LibFunc::Znwm: // new(unsigned long); - case LibFunc::Znaj: // new[](unsigned int); - case LibFunc::Znam: // new[](unsigned long); - case LibFunc::msvc_new_int: // new(unsigned int); - case LibFunc::msvc_new_longlong: // new(unsigned long long); - case LibFunc::msvc_new_array_int: // new[](unsigned int); - case LibFunc::msvc_new_array_longlong: // new[](unsigned long long); - return (NumParams == 1); - - case LibFunc::memset_pattern16: + // new(unsigned int); + case LibFunc_Znwj: + // new(unsigned long); + case LibFunc_Znwm: + // new[](unsigned int); + case LibFunc_Znaj: + // new[](unsigned long); + case LibFunc_Znam: + // new(unsigned int); + case LibFunc_msvc_new_int: + // new(unsigned long long); + case LibFunc_msvc_new_longlong: + // new[](unsigned int); + case LibFunc_msvc_new_array_int: + // new[](unsigned long long); + case LibFunc_msvc_new_array_longlong: + return (NumParams == 1 && FTy.getReturnType()->isPointerTy()); + + // new(unsigned int, nothrow); + case LibFunc_ZnwjRKSt9nothrow_t: + // new(unsigned long, nothrow); + case LibFunc_ZnwmRKSt9nothrow_t: + // new[](unsigned int, nothrow); + case LibFunc_ZnajRKSt9nothrow_t: + // new[](unsigned long, nothrow); + case LibFunc_ZnamRKSt9nothrow_t: + // new(unsigned int, nothrow); + case LibFunc_msvc_new_int_nothrow: + // new(unsigned long long, nothrow); + case LibFunc_msvc_new_longlong_nothrow: + // new[](unsigned int, nothrow); + case LibFunc_msvc_new_array_int_nothrow: + // new[](unsigned long long, nothrow); + case LibFunc_msvc_new_array_longlong_nothrow: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy()); + + // void operator delete[](void*); + case LibFunc_ZdaPv: + // void operator delete(void*); + case LibFunc_ZdlPv: + // void operator delete[](void*); + case LibFunc_msvc_delete_array_ptr32: + // void operator delete[](void*); + case LibFunc_msvc_delete_array_ptr64: + // void operator delete(void*); + case LibFunc_msvc_delete_ptr32: + // void operator delete(void*); + case LibFunc_msvc_delete_ptr64: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + + // void operator delete[](void*, nothrow); + case LibFunc_ZdaPvRKSt9nothrow_t: + // void operator delete[](void*, unsigned int); + case LibFunc_ZdaPvj: + // void operator delete[](void*, unsigned long); + case LibFunc_ZdaPvm: + // void operator delete(void*, nothrow); + case LibFunc_ZdlPvRKSt9nothrow_t: + // void operator 
delete(void*, unsigned int); + case LibFunc_ZdlPvj: + // void operator delete(void*, unsigned long); + case LibFunc_ZdlPvm: + // void operator delete[](void*, unsigned int); + case LibFunc_msvc_delete_array_ptr32_int: + // void operator delete[](void*, nothrow); + case LibFunc_msvc_delete_array_ptr32_nothrow: + // void operator delete[](void*, unsigned long long); + case LibFunc_msvc_delete_array_ptr64_longlong: + // void operator delete[](void*, nothrow); + case LibFunc_msvc_delete_array_ptr64_nothrow: + // void operator delete(void*, unsigned int); + case LibFunc_msvc_delete_ptr32_int: + // void operator delete(void*, nothrow); + case LibFunc_msvc_delete_ptr32_nothrow: + // void operator delete(void*, unsigned long long); + case LibFunc_msvc_delete_ptr64_longlong: + // void operator delete(void*, nothrow); + case LibFunc_msvc_delete_ptr64_nothrow: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); + + case LibFunc_memset_pattern16: return (!FTy.isVarArg() && NumParams == 3 && - isa<PointerType>(FTy.getParamType(0)) && - isa<PointerType>(FTy.getParamType(1)) && - isa<IntegerType>(FTy.getParamType(2))); - - // int __nvvm_reflect(const char *); - case LibFunc::nvvm_reflect: - return (NumParams == 1 && isa<PointerType>(FTy.getParamType(0))); - - case LibFunc::sin: - case LibFunc::sinf: - case LibFunc::sinl: - case LibFunc::cos: - case LibFunc::cosf: - case LibFunc::cosl: - case LibFunc::tan: - case LibFunc::tanf: - case LibFunc::tanl: - case LibFunc::exp: - case LibFunc::expf: - case LibFunc::expl: - case LibFunc::exp2: - case LibFunc::exp2f: - case LibFunc::exp2l: - case LibFunc::log: - case LibFunc::logf: - case LibFunc::logl: - case LibFunc::log10: - case LibFunc::log10f: - case LibFunc::log10l: - case LibFunc::log2: - case LibFunc::log2f: - case LibFunc::log2l: - case LibFunc::fabs: - case LibFunc::fabsf: - case LibFunc::fabsl: - case LibFunc::floor: - case LibFunc::floorf: - case LibFunc::floorl: - case LibFunc::ceil: - case LibFunc::ceilf: - case LibFunc::ceill: - case LibFunc::trunc: - case LibFunc::truncf: - case LibFunc::truncl: - case LibFunc::rint: - case LibFunc::rintf: - case LibFunc::rintl: - case LibFunc::nearbyint: - case LibFunc::nearbyintf: - case LibFunc::nearbyintl: - case LibFunc::round: - case LibFunc::roundf: - case LibFunc::roundl: - case LibFunc::sqrt: - case LibFunc::sqrtf: - case LibFunc::sqrtl: + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isIntegerTy()); + + case LibFunc_cxa_guard_abort: + case LibFunc_cxa_guard_acquire: + case LibFunc_cxa_guard_release: + case LibFunc_nvvm_reflect: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + + case LibFunc_sincospi_stret: + case LibFunc_sincospif_stret: + return (NumParams == 1 && FTy.getParamType(0)->isFloatingPointTy()); + + case LibFunc_acos: + case LibFunc_acos_finite: + case LibFunc_acosf: + case LibFunc_acosf_finite: + case LibFunc_acosh: + case LibFunc_acosh_finite: + case LibFunc_acoshf: + case LibFunc_acoshf_finite: + case LibFunc_acoshl: + case LibFunc_acoshl_finite: + case LibFunc_acosl: + case LibFunc_acosl_finite: + case LibFunc_asin: + case LibFunc_asin_finite: + case LibFunc_asinf: + case LibFunc_asinf_finite: + case LibFunc_asinh: + case LibFunc_asinhf: + case LibFunc_asinhl: + case LibFunc_asinl: + case LibFunc_asinl_finite: + case LibFunc_atan: + case LibFunc_atanf: + case LibFunc_atanh: + case LibFunc_atanh_finite: + case LibFunc_atanhf: + case LibFunc_atanhf_finite: + case LibFunc_atanhl: + case 
LibFunc_atanhl_finite: + case LibFunc_atanl: + case LibFunc_cbrt: + case LibFunc_cbrtf: + case LibFunc_cbrtl: + case LibFunc_ceil: + case LibFunc_ceilf: + case LibFunc_ceill: + case LibFunc_cos: + case LibFunc_cosf: + case LibFunc_cosh: + case LibFunc_cosh_finite: + case LibFunc_coshf: + case LibFunc_coshf_finite: + case LibFunc_coshl: + case LibFunc_coshl_finite: + case LibFunc_cosl: + case LibFunc_exp10: + case LibFunc_exp10_finite: + case LibFunc_exp10f: + case LibFunc_exp10f_finite: + case LibFunc_exp10l: + case LibFunc_exp10l_finite: + case LibFunc_exp2: + case LibFunc_exp2_finite: + case LibFunc_exp2f: + case LibFunc_exp2f_finite: + case LibFunc_exp2l: + case LibFunc_exp2l_finite: + case LibFunc_exp: + case LibFunc_exp_finite: + case LibFunc_expf: + case LibFunc_expf_finite: + case LibFunc_expl: + case LibFunc_expl_finite: + case LibFunc_expm1: + case LibFunc_expm1f: + case LibFunc_expm1l: + case LibFunc_fabs: + case LibFunc_fabsf: + case LibFunc_fabsl: + case LibFunc_floor: + case LibFunc_floorf: + case LibFunc_floorl: + case LibFunc_log10: + case LibFunc_log10_finite: + case LibFunc_log10f: + case LibFunc_log10f_finite: + case LibFunc_log10l: + case LibFunc_log10l_finite: + case LibFunc_log1p: + case LibFunc_log1pf: + case LibFunc_log1pl: + case LibFunc_log2: + case LibFunc_log2_finite: + case LibFunc_log2f: + case LibFunc_log2f_finite: + case LibFunc_log2l: + case LibFunc_log2l_finite: + case LibFunc_log: + case LibFunc_log_finite: + case LibFunc_logb: + case LibFunc_logbf: + case LibFunc_logbl: + case LibFunc_logf: + case LibFunc_logf_finite: + case LibFunc_logl: + case LibFunc_logl_finite: + case LibFunc_nearbyint: + case LibFunc_nearbyintf: + case LibFunc_nearbyintl: + case LibFunc_rint: + case LibFunc_rintf: + case LibFunc_rintl: + case LibFunc_round: + case LibFunc_roundf: + case LibFunc_roundl: + case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_sinh: + case LibFunc_sinh_finite: + case LibFunc_sinhf: + case LibFunc_sinhf_finite: + case LibFunc_sinhl: + case LibFunc_sinhl_finite: + case LibFunc_sinl: + case LibFunc_sqrt: + case LibFunc_sqrt_finite: + case LibFunc_sqrtf: + case LibFunc_sqrtf_finite: + case LibFunc_sqrtl: + case LibFunc_sqrtl_finite: + case LibFunc_tan: + case LibFunc_tanf: + case LibFunc_tanh: + case LibFunc_tanhf: + case LibFunc_tanhl: + case LibFunc_tanl: + case LibFunc_trunc: + case LibFunc_truncf: + case LibFunc_truncl: return (NumParams == 1 && FTy.getReturnType()->isFloatingPointTy() && FTy.getReturnType() == FTy.getParamType(0)); - case LibFunc::fmin: - case LibFunc::fminf: - case LibFunc::fminl: - case LibFunc::fmax: - case LibFunc::fmaxf: - case LibFunc::fmaxl: - case LibFunc::copysign: - case LibFunc::copysignf: - case LibFunc::copysignl: - case LibFunc::pow: - case LibFunc::powf: - case LibFunc::powl: + case LibFunc_atan2: + case LibFunc_atan2_finite: + case LibFunc_atan2f: + case LibFunc_atan2f_finite: + case LibFunc_atan2l: + case LibFunc_atan2l_finite: + case LibFunc_fmin: + case LibFunc_fminf: + case LibFunc_fminl: + case LibFunc_fmax: + case LibFunc_fmaxf: + case LibFunc_fmaxl: + case LibFunc_fmod: + case LibFunc_fmodf: + case LibFunc_fmodl: + case LibFunc_copysign: + case LibFunc_copysignf: + case LibFunc_copysignl: + case LibFunc_pow: + case LibFunc_pow_finite: + case LibFunc_powf: + case LibFunc_powf_finite: + case LibFunc_powl: + case LibFunc_powl_finite: return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() && FTy.getReturnType() == FTy.getParamType(0) && FTy.getReturnType() == FTy.getParamType(1)); - case LibFunc::ffs: - 
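All of the unary math entries above funnel into the same final check: exactly one parameter, a floating-point return type, and the return type equal to the parameter type. A mis-declared function therefore simply stops being recognized as the libcall, e.g. (illustrative declarations only, not meant to be linked):

extern "C" double sin(double);  // accepted: return type == parameter type, both FP
extern "C" float  sinf(double); // rejected by the check above: float != double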
case LibFunc::ffsl: - case LibFunc::ffsll: - case LibFunc::fls: - case LibFunc::flsl: - case LibFunc::flsll: + case LibFunc_ldexp: + case LibFunc_ldexpf: + case LibFunc_ldexpl: + return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() && + FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(1)->isIntegerTy(32)); + + case LibFunc_ffs: + case LibFunc_ffsl: + case LibFunc_ffsll: + case LibFunc_fls: + case LibFunc_flsl: + case LibFunc_flsll: return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && FTy.getParamType(0)->isIntegerTy()); - case LibFunc::isdigit: - case LibFunc::isascii: - case LibFunc::toascii: + case LibFunc_isdigit: + case LibFunc_isascii: + case LibFunc_toascii: + case LibFunc_putchar: return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && FTy.getReturnType() == FTy.getParamType(0)); - case LibFunc::abs: - case LibFunc::labs: - case LibFunc::llabs: + case LibFunc_abs: + case LibFunc_labs: + case LibFunc_llabs: return (NumParams == 1 && FTy.getReturnType()->isIntegerTy() && FTy.getReturnType() == FTy.getParamType(0)); - case LibFunc::cxa_atexit: + case LibFunc_cxa_atexit: return (NumParams == 3 && FTy.getReturnType()->isIntegerTy() && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy() && FTy.getParamType(2)->isPointerTy()); - case LibFunc::sinpi: - case LibFunc::cospi: + case LibFunc_sinpi: + case LibFunc_cospi: return (NumParams == 1 && FTy.getReturnType()->isDoubleTy() && FTy.getReturnType() == FTy.getParamType(0)); - case LibFunc::sinpif: - case LibFunc::cospif: + case LibFunc_sinpif: + case LibFunc_cospif: return (NumParams == 1 && FTy.getReturnType()->isFloatTy() && FTy.getReturnType() == FTy.getParamType(0)); - default: - // Assume the other functions are correct. - // FIXME: It'd be really nice to cover them all. - return true; + case LibFunc_strnlen: + return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(1) && + FTy.getParamType(0) == PCharTy && + FTy.getParamType(1) == SizeTTy); + + case LibFunc_posix_memalign: + return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1) == SizeTTy && FTy.getParamType(2) == SizeTTy); + + case LibFunc_wcslen: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() && + FTy.getReturnType()->isIntegerTy()); + + case LibFunc::NumLibFuncs: + break; } + + llvm_unreachable("Invalid libfunc"); } bool TargetLibraryInfoImpl::getLibFunc(const Function &FDecl, - LibFunc::Func &F) const { + LibFunc &F) const { const DataLayout *DL = FDecl.getParent() ? 
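Dropping the old permissive default in favor of listing every enumerator (ending at LibFunc::NumLibFuncs) plus llvm_unreachable means -Wswitch now flags any newly added LibFunc that lacks a prototype check. The same pattern in miniature, with an invented enum:

#include "llvm/Support/ErrorHandling.h"

enum Shape { Shape_circle, Shape_square, NumShapes };

static unsigned numSides(Shape S) {
  switch (S) {
  case Shape_circle: return 0;
  case Shape_square: return 4;
  case NumShapes:    break; // sentinel: listed only to keep the switch exhaustive
  }
  llvm_unreachable("Invalid shape");
}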
&FDecl.getParent()->getDataLayout() : nullptr; return getLibFunc(FDecl.getName(), F) && @@ -1134,6 +1381,14 @@ void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib( {"powf", "__svml_powf8", 8}, {"powf", "__svml_powf16", 16}, + { "__pow_finite", "__svml_pow2", 2 }, + { "__pow_finite", "__svml_pow4", 4 }, + { "__pow_finite", "__svml_pow8", 8 }, + + { "__powf_finite", "__svml_powf4", 4 }, + { "__powf_finite", "__svml_powf8", 8 }, + { "__powf_finite", "__svml_powf16", 16 }, + {"llvm.pow.f64", "__svml_pow2", 2}, {"llvm.pow.f64", "__svml_pow4", 4}, {"llvm.pow.f64", "__svml_pow8", 8}, @@ -1150,6 +1405,14 @@ void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib( {"expf", "__svml_expf8", 8}, {"expf", "__svml_expf16", 16}, + { "__exp_finite", "__svml_exp2", 2 }, + { "__exp_finite", "__svml_exp4", 4 }, + { "__exp_finite", "__svml_exp8", 8 }, + + { "__expf_finite", "__svml_expf4", 4 }, + { "__expf_finite", "__svml_expf8", 8 }, + { "__expf_finite", "__svml_expf16", 16 }, + {"llvm.exp.f64", "__svml_exp2", 2}, {"llvm.exp.f64", "__svml_exp4", 4}, {"llvm.exp.f64", "__svml_exp8", 8}, @@ -1166,6 +1429,14 @@ void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib( {"logf", "__svml_logf8", 8}, {"logf", "__svml_logf16", 16}, + { "__log_finite", "__svml_log2", 2 }, + { "__log_finite", "__svml_log4", 4 }, + { "__log_finite", "__svml_log8", 8 }, + + { "__logf_finite", "__svml_logf4", 4 }, + { "__logf_finite", "__svml_logf8", 8 }, + { "__logf_finite", "__svml_logf16", 16 }, + {"llvm.log.f64", "__svml_log2", 2}, {"llvm.log.f64", "__svml_log4", 4}, {"llvm.log.f64", "__svml_log8", 8}, @@ -1248,6 +1519,21 @@ TargetLibraryInfoImpl &TargetLibraryAnalysis::lookupInfoImpl(const Triple &T) { return *Impl; } +unsigned TargetLibraryInfoImpl::getTargetWCharSize(const Triple &T) { + // See also clang/lib/Basic/Targets.cpp. 
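The new __*_finite rows mean that calls produced by finite-math builds can be mapped onto SVML as well. A sketch of the lookup side (the wrapping function is invented; the TargetLibraryInfo queries are existing API):

#include "llvm/Analysis/TargetLibraryInfo.h"
using namespace llvm;

// After addVectorizableFunctionsFromVecLib(SVML), the scalar name maps to a
// vector variant for the requested vectorization factor.
static StringRef svmlNameForExpFinite(const TargetLibraryInfo &TLI, unsigned VF) {
  if (!TLI.isFunctionVectorizable("__exp_finite", VF))
    return StringRef();
  return TLI.getVectorizedFunction("__exp_finite", VF); // e.g. "__svml_exp4" for VF=4
}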
+ if (T.isPS4() || T.isOSWindows() || T.isArch16Bit()) + return 2; + if (T.getArch() == Triple::xcore) + return 1; + return 4; +} + +unsigned TargetLibraryInfoImpl::getWCharSize(const Module &M) const { + if (auto *ShortWChar = cast_or_null<ConstantAsMetadata>( + M.getModuleFlag("wchar_size"))) + return cast<ConstantInt>(ShortWChar->getValue())->getZExtValue(); + return getTargetWCharSize(Triple(M.getTargetTriple())); +} TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass() : ImmutablePass(ID), TLIImpl(), TLI(TLIImpl) { diff --git a/contrib/llvm/lib/Analysis/TargetTransformInfo.cpp b/contrib/llvm/lib/Analysis/TargetTransformInfo.cpp index 5c0d1aac1b98..ac646716476b 100644 --- a/contrib/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/contrib/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -83,6 +83,12 @@ int TargetTransformInfo::getIntrinsicCost( return Cost; } +unsigned +TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI, + unsigned &JTSize) const { + return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize); +} + int TargetTransformInfo::getUserCost(const User *U) const { int Cost = TTIImpl->getUserCost(U); assert(Cost >= 0 && "TTI should not produce negative costs!"); @@ -97,6 +103,10 @@ bool TargetTransformInfo::isSourceOfDivergence(const Value *V) const { return TTIImpl->isSourceOfDivergence(V); } +unsigned TargetTransformInfo::getFlatAddressSpace() const { + return TTIImpl->getFlatAddressSpace(); +} + bool TargetTransformInfo::isLoweredToCall(const Function *F) const { return TTIImpl->isLoweredToCall(F); } @@ -139,6 +149,10 @@ bool TargetTransformInfo::isLegalMaskedScatter(Type *DataType) const { return TTIImpl->isLegalMaskedGather(DataType); } +bool TargetTransformInfo::prefersVectorizedAddressing() const { + return TTIImpl->prefersVectorizedAddressing(); +} + int TargetTransformInfo::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, @@ -182,10 +196,29 @@ bool TargetTransformInfo::shouldBuildLookupTablesForConstant(Constant *C) const return TTIImpl->shouldBuildLookupTablesForConstant(C); } +unsigned TargetTransformInfo:: +getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) const { + return TTIImpl->getScalarizationOverhead(Ty, Insert, Extract); +} + +unsigned TargetTransformInfo:: +getOperandsScalarizationOverhead(ArrayRef<const Value *> Args, + unsigned VF) const { + return TTIImpl->getOperandsScalarizationOverhead(Args, VF); +} + +bool TargetTransformInfo::supportsEfficientVectorElementLoadStore() const { + return TTIImpl->supportsEfficientVectorElementLoadStore(); +} + bool TargetTransformInfo::enableAggressiveInterleaving(bool LoopHasReductions) const { return TTIImpl->enableAggressiveInterleaving(LoopHasReductions); } +bool TargetTransformInfo::expandMemCmp(Instruction *I, unsigned &MaxLoadSize) const { + return TTIImpl->expandMemCmp(I, MaxLoadSize); +} + bool TargetTransformInfo::enableInterleavedAccessVectorization() const { return TTIImpl->enableInterleavedAccessVectorization(); } @@ -254,6 +287,16 @@ unsigned TargetTransformInfo::getRegisterBitWidth(bool Vector) const { return TTIImpl->getRegisterBitWidth(Vector); } +unsigned TargetTransformInfo::getMinVectorRegisterBitWidth() const { + return TTIImpl->getMinVectorRegisterBitWidth(); +} + +bool TargetTransformInfo::shouldConsiderAddressTypePromotion( + const Instruction &I, bool &AllowPromotionWithoutCommonHeader) const { + return TTIImpl->shouldConsiderAddressTypePromotion( + I, AllowPromotionWithoutCommonHeader); +} + unsigned 
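With getWCharSize in place, passes that fold wide-string routines can ask for the target's wchar_t width instead of guessing. A sketch, under the assumption that the TargetLibraryInfo wrapper forwards to the Impl method added above (the helper itself is invented):

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
using namespace llvm;

// Returns the integer type matching the target's wchar_t, based on the
// "wchar_size" module flag that clang emits, with the per-triple fallback above.
static IntegerType *getWCharIntType(const Function &F, const TargetLibraryInfo &TLI) {
  unsigned WCharBytes = TLI.getWCharSize(*F.getParent()); // 2 on Windows/PS4, 4 on most targets
  return IntegerType::get(F.getContext(), 8 * WCharBytes);
}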
TargetTransformInfo::getCacheLineSize() const { return TTIImpl->getCacheLineSize(); } @@ -293,8 +336,10 @@ int TargetTransformInfo::getShuffleCost(ShuffleKind Kind, Type *Ty, int Index, } int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, - Type *Src) const { - int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src); + Type *Src, const Instruction *I) const { + assert ((I == nullptr || I->getOpcode() == Opcode) && + "Opcode should reflect passed instruction."); + int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, I); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } @@ -314,8 +359,10 @@ int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const { } int TargetTransformInfo::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, - Type *CondTy) const { - int Cost = TTIImpl->getCmpSelInstrCost(Opcode, ValTy, CondTy); + Type *CondTy, const Instruction *I) const { + assert ((I == nullptr || I->getOpcode() == Opcode) && + "Opcode should reflect passed instruction."); + int Cost = TTIImpl->getCmpSelInstrCost(Opcode, ValTy, CondTy, I); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } @@ -329,8 +376,11 @@ int TargetTransformInfo::getVectorInstrCost(unsigned Opcode, Type *Val, int TargetTransformInfo::getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, - unsigned AddressSpace) const { - int Cost = TTIImpl->getMemoryOpCost(Opcode, Src, Alignment, AddressSpace); + unsigned AddressSpace, + const Instruction *I) const { + assert ((I == nullptr || I->getOpcode() == Opcode) && + "Opcode should reflect passed instruction."); + int Cost = TTIImpl->getMemoryOpCost(Opcode, Src, Alignment, AddressSpace, I); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } @@ -363,17 +413,17 @@ int TargetTransformInfo::getInterleavedMemoryOpCost( } int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef<Type *> Tys, - FastMathFlags FMF) const { - int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Tys, FMF); + ArrayRef<Type *> Tys, FastMathFlags FMF, + unsigned ScalarizationCostPassed) const { + int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Tys, FMF, + ScalarizationCostPassed); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef<Value *> Args, - FastMathFlags FMF) const { - int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Args, FMF); + ArrayRef<Value *> Args, FastMathFlags FMF, unsigned VF) const { + int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } @@ -462,6 +512,15 @@ unsigned TargetTransformInfo::getStoreVectorFactor(unsigned VF, return TTIImpl->getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); } +bool TargetTransformInfo::useReductionIntrinsic(unsigned Opcode, + Type *Ty, ReductionFlags Flags) const { + return TTIImpl->useReductionIntrinsic(Opcode, Ty, Flags); +} + +bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const { + return TTIImpl->shouldExpandReduction(II); +} + TargetTransformInfo::Concept::~Concept() {} TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} diff --git a/contrib/llvm/lib/Analysis/TypeMetadataUtils.cpp b/contrib/llvm/lib/Analysis/TypeMetadataUtils.cpp index f56754167360..6871e4887c9e 100644 --- a/contrib/llvm/lib/Analysis/TypeMetadataUtils.cpp +++ 
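Several cost hooks now take the instruction itself (with an assert that the opcode matches), so targets can look at context when pricing it. A sketch of the new call shape (the loop and names are invented):

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static int castCostInBlock(const BasicBlock &BB, const TargetTransformInfo &TTI) {
  int Total = 0;
  for (const Instruction &I : BB)
    if (auto *CI = dyn_cast<CastInst>(&I))
      // Passing CI lets the target inspect the cast's context; the opcode
      // argument must match CI->getOpcode() or the new assert fires.
      Total += TTI.getCastInstrCost(CI->getOpcode(), CI->getDestTy(),
                                    CI->getSrcTy(), CI);
  return Total;
}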
b/contrib/llvm/lib/Analysis/TypeMetadataUtils.cpp @@ -39,7 +39,7 @@ findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls, // Search for virtual calls that load from VPtr and add them to DevirtCalls. static void -findLoadCallsAtConstantOffset(Module *M, +findLoadCallsAtConstantOffset(const Module *M, SmallVectorImpl<DevirtCallSite> &DevirtCalls, Value *VPtr, int64_t Offset) { for (const Use &U : VPtr->uses()) { @@ -62,10 +62,10 @@ findLoadCallsAtConstantOffset(Module *M, void llvm::findDevirtualizableCallsForTypeTest( SmallVectorImpl<DevirtCallSite> &DevirtCalls, - SmallVectorImpl<CallInst *> &Assumes, CallInst *CI) { + SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI) { assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test); - Module *M = CI->getParent()->getParent()->getParent(); + const Module *M = CI->getParent()->getParent()->getParent(); // Find llvm.assume intrinsics for this llvm.type.test call. for (const Use &CIU : CI->uses()) { @@ -86,7 +86,8 @@ void llvm::findDevirtualizableCallsForTypeTest( void llvm::findDevirtualizableCallsForTypeCheckedLoad( SmallVectorImpl<DevirtCallSite> &DevirtCalls, SmallVectorImpl<Instruction *> &LoadedPtrs, - SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, CallInst *CI) { + SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, + const CallInst *CI) { assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_checked_load); @@ -96,7 +97,7 @@ void llvm::findDevirtualizableCallsForTypeCheckedLoad( return; } - for (Use &U : CI->uses()) { + for (const Use &U : CI->uses()) { auto CIU = U.getUser(); if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) { if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) { diff --git a/contrib/llvm/lib/Analysis/ValueTracking.cpp b/contrib/llvm/lib/Analysis/ValueTracking.cpp index be6285803e2b..a5dceb6c2271 100644 --- a/contrib/llvm/lib/Analysis/ValueTracking.cpp +++ b/contrib/llvm/lib/Analysis/ValueTracking.cpp @@ -20,11 +20,13 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" @@ -37,6 +39,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Statepoint.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include <algorithm> #include <array> @@ -57,8 +60,8 @@ static cl::opt<bool> DontImproveNonNegativePhiBits("dont-improve-non-negative-phi-bits", cl::Hidden, cl::init(true)); -/// Returns the bitwidth of the given scalar or pointer type (if unknown returns -/// 0). For vector types, returns the element type's bitwidth. +/// Returns the bitwidth of the given scalar or pointer type. For vector types, +/// returns the element type's bitwidth. static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { if (unsigned BitWidth = Ty->getScalarSizeInBits()) return BitWidth; @@ -76,6 +79,9 @@ struct Query { AssumptionCache *AC; const Instruction *CxtI; const DominatorTree *DT; + // Unlike the other analyses, this may be a nullptr because not all clients + // provide it currently. 
+ OptimizationRemarkEmitter *ORE; /// Set of assumptions that should be excluded from further queries. /// This is because of the potential for mutual recursion to cause @@ -83,18 +89,18 @@ struct Query { /// classic case of this is assume(x = y), which will attempt to determine /// bits in x from bits in y, which will attempt to determine bits in y from /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call - /// isKnownNonZero, which calls computeKnownBits and ComputeSignBit and - /// isKnownToBeAPowerOfTwo (all of which can call computeKnownBits), and so - /// on. + /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo + /// (all of which can call computeKnownBits), and so on. std::array<const Value *, MaxDepth> Excluded; unsigned NumExcluded; Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) - : DL(DL), AC(AC), CxtI(CxtI), DT(DT), NumExcluded(0) {} + const DominatorTree *DT, OptimizationRemarkEmitter *ORE = nullptr) + : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), NumExcluded(0) {} Query(const Query &Q, const Value *NewExcl) - : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), NumExcluded(Q.NumExcluded) { + : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), + NumExcluded(Q.NumExcluded) { Excluded = Q.Excluded; Excluded[NumExcluded++] = NewExcl; assert(NumExcluded <= Excluded.size()); @@ -125,15 +131,28 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { return nullptr; } -static void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, +static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, const Query &Q); -void llvm::computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, +void llvm::computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - ::computeKnownBits(V, KnownZero, KnownOne, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT)); + const DominatorTree *DT, + OptimizationRemarkEmitter *ORE) { + ::computeKnownBits(V, Known, Depth, + Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); +} + +static KnownBits computeKnownBits(const Value *V, unsigned Depth, + const Query &Q); + +KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL, + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT, + OptimizationRemarkEmitter *ORE) { + return ::computeKnownBits(V, Depth, + Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); } bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, @@ -145,22 +164,24 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, assert(LHS->getType()->isIntOrIntVectorTy() && "LHS and RHS should be integers"); IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType()); - APInt LHSKnownZero(IT->getBitWidth(), 0), LHSKnownOne(IT->getBitWidth(), 0); - APInt RHSKnownZero(IT->getBitWidth(), 0), RHSKnownOne(IT->getBitWidth(), 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, DL, 0, AC, CxtI, DT); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, DL, 0, AC, CxtI, DT); - return (LHSKnownZero | RHSKnownZero).isAllOnesValue(); + KnownBits LHSKnown(IT->getBitWidth()); + KnownBits RHSKnown(IT->getBitWidth()); + computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT); + computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT); + return (LHSKnown.Zero | RHSKnown.Zero).isAllOnesValue(); } -static void ComputeSignBit(const Value *V, bool 
&KnownZero, bool &KnownOne, - unsigned Depth, const Query &Q); -void llvm::ComputeSignBit(const Value *V, bool &KnownZero, bool &KnownOne, - const DataLayout &DL, unsigned Depth, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - ::ComputeSignBit(V, KnownZero, KnownOne, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT)); +bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *CxtI) { + for (const User *U : CxtI->users()) { + if (const ICmpInst *IC = dyn_cast<ICmpInst>(U)) + if (IC->isEquality()) + if (Constant *C = dyn_cast<Constant>(IC->getOperand(1))) + if (C->isNullValue()) + continue; + return false; + } + return true; } static bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, @@ -187,9 +208,8 @@ bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - bool NonNegative, Negative; - ComputeSignBit(V, NonNegative, Negative, DL, Depth, AC, CxtI, DT); - return NonNegative; + KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT); + return Known.isNonNegative(); } bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth, @@ -207,9 +227,8 @@ bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth, bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - bool NonNegative, Negative; - ComputeSignBit(V, NonNegative, Negative, DL, Depth, AC, CxtI, DT); - return Negative; + KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT); + return Known.isNegative(); } static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q); @@ -246,91 +265,65 @@ unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL, static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, bool NSW, - APInt &KnownZero, APInt &KnownOne, - APInt &KnownZero2, APInt &KnownOne2, + KnownBits &KnownOut, KnownBits &Known2, unsigned Depth, const Query &Q) { - if (!Add) { - if (const ConstantInt *CLHS = dyn_cast<ConstantInt>(Op0)) { - // We know that the top bits of C-X are clear if X contains less bits - // than C (i.e. no wrap-around can happen). For example, 20-X is - // positive if we can prove that X is >= 0 and < 16. - if (!CLHS->getValue().isNegative()) { - unsigned BitWidth = KnownZero.getBitWidth(); - unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros(); - // NLZ can't be BitWidth with no sign bit - APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1); - computeKnownBits(Op1, KnownZero2, KnownOne2, Depth + 1, Q); - - // If all of the MaskV bits are known to be zero, then we know the - // output top bits are zero, because we now know that the output is - // from [0-C]. - if ((KnownZero2 & MaskV) == MaskV) { - unsigned NLZ2 = CLHS->getValue().countLeadingZeros(); - // Top bits known zero. - KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2); - } - } - } - } - - unsigned BitWidth = KnownZero.getBitWidth(); + unsigned BitWidth = KnownOut.getBitWidth(); // If an initial sequence of bits in the result is not needed, the // corresponding bits in the operands are not needed. 
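// A quick sanity check (plain C++, 4-bit, not the LLVM helpers) of the
// carry-based rule the rewritten code below implements: if LHS has bit 0
// known one and bit 1 known zero, and RHS has bits 1..0 known one, then the
// carry into bit 1 is known and both low bits of LHS + RHS are known zero.
#include <cassert>

int main() {
  for (unsigned A = 0; A < 16; ++A) {
    if ((A & 0x3) != 0x1)
      continue;                              // LHS inconsistent with its knowns
    for (unsigned B = 0; B < 16; ++B) {
      if ((B & 0x3) != 0x3)
        continue;                            // RHS inconsistent with its knowns
      assert((((A + B) & 0xF) & 0x3) == 0);  // sum bits 1..0 are known zero
    }
  }
  return 0;
}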
- APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, Depth + 1, Q); - computeKnownBits(Op1, KnownZero2, KnownOne2, Depth + 1, Q); + KnownBits LHSKnown(BitWidth); + computeKnownBits(Op0, LHSKnown, Depth + 1, Q); + computeKnownBits(Op1, Known2, Depth + 1, Q); // Carry in a 1 for a subtract, rather than a 0. - APInt CarryIn(BitWidth, 0); + uint64_t CarryIn = 0; if (!Add) { // Sum = LHS + ~RHS + 1 - std::swap(KnownZero2, KnownOne2); - CarryIn.setBit(0); + std::swap(Known2.Zero, Known2.One); + CarryIn = 1; } - APInt PossibleSumZero = ~LHSKnownZero + ~KnownZero2 + CarryIn; - APInt PossibleSumOne = LHSKnownOne + KnownOne2 + CarryIn; + APInt PossibleSumZero = ~LHSKnown.Zero + ~Known2.Zero + CarryIn; + APInt PossibleSumOne = LHSKnown.One + Known2.One + CarryIn; // Compute known bits of the carry. - APInt CarryKnownZero = ~(PossibleSumZero ^ LHSKnownZero ^ KnownZero2); - APInt CarryKnownOne = PossibleSumOne ^ LHSKnownOne ^ KnownOne2; + APInt CarryKnownZero = ~(PossibleSumZero ^ LHSKnown.Zero ^ Known2.Zero); + APInt CarryKnownOne = PossibleSumOne ^ LHSKnown.One ^ Known2.One; // Compute set of known bits (where all three relevant bits are known). - APInt LHSKnown = LHSKnownZero | LHSKnownOne; - APInt RHSKnown = KnownZero2 | KnownOne2; - APInt CarryKnown = CarryKnownZero | CarryKnownOne; - APInt Known = LHSKnown & RHSKnown & CarryKnown; + APInt LHSKnownUnion = LHSKnown.Zero | LHSKnown.One; + APInt RHSKnownUnion = Known2.Zero | Known2.One; + APInt CarryKnownUnion = CarryKnownZero | CarryKnownOne; + APInt Known = LHSKnownUnion & RHSKnownUnion & CarryKnownUnion; assert((PossibleSumZero & Known) == (PossibleSumOne & Known) && "known bits of sum differ"); // Compute known bits of the result. - KnownZero = ~PossibleSumOne & Known; - KnownOne = PossibleSumOne & Known; + KnownOut.Zero = ~PossibleSumOne & Known; + KnownOut.One = PossibleSumOne & Known; // Are we still trying to solve for the sign bit? - if (!Known.isNegative()) { + if (!Known.isSignBitSet()) { if (NSW) { // Adding two non-negative numbers, or subtracting a negative number from // a non-negative one, can't wrap into negative. - if (LHSKnownZero.isNegative() && KnownZero2.isNegative()) - KnownZero |= APInt::getSignBit(BitWidth); + if (LHSKnown.isNonNegative() && Known2.isNonNegative()) + KnownOut.makeNonNegative(); // Adding two negative numbers, or subtracting a non-negative number from // a negative one, can't wrap into non-negative. - else if (LHSKnownOne.isNegative() && KnownOne2.isNegative()) - KnownOne |= APInt::getSignBit(BitWidth); + else if (LHSKnown.isNegative() && Known2.isNegative()) + KnownOut.makeNegative(); } } } static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, - APInt &KnownZero, APInt &KnownOne, - APInt &KnownZero2, APInt &KnownOne2, + KnownBits &Known, KnownBits &Known2, unsigned Depth, const Query &Q) { - unsigned BitWidth = KnownZero.getBitWidth(); - computeKnownBits(Op1, KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(Op0, KnownZero2, KnownOne2, Depth + 1, Q); + unsigned BitWidth = Known.getBitWidth(); + computeKnownBits(Op1, Known, Depth + 1, Q); + computeKnownBits(Op0, Known2, Depth + 1, Q); bool isKnownNegative = false; bool isKnownNonNegative = false; @@ -340,10 +333,10 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, // The product of a number with itself is non-negative. 
isKnownNonNegative = true; } else { - bool isKnownNonNegativeOp1 = KnownZero.isNegative(); - bool isKnownNonNegativeOp0 = KnownZero2.isNegative(); - bool isKnownNegativeOp1 = KnownOne.isNegative(); - bool isKnownNegativeOp0 = KnownOne2.isNegative(); + bool isKnownNonNegativeOp1 = Known.isNonNegative(); + bool isKnownNonNegativeOp0 = Known2.isNonNegative(); + bool isKnownNegativeOp1 = Known.isNegative(); + bool isKnownNegativeOp0 = Known2.isNegative(); // The product of two numbers with the same sign is non-negative. isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) || (isKnownNonNegativeOp1 && isKnownNonNegativeOp0); @@ -361,38 +354,37 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, // Also compute a conservative estimate for high known-0 bits. // More trickiness is possible, but this is sufficient for the // interesting case of alignment computation. - KnownOne.clearAllBits(); - unsigned TrailZ = KnownZero.countTrailingOnes() + - KnownZero2.countTrailingOnes(); - unsigned LeadZ = std::max(KnownZero.countLeadingOnes() + - KnownZero2.countLeadingOnes(), + unsigned TrailZ = Known.countMinTrailingZeros() + + Known2.countMinTrailingZeros(); + unsigned LeadZ = std::max(Known.countMinLeadingZeros() + + Known2.countMinLeadingZeros(), BitWidth) - BitWidth; TrailZ = std::min(TrailZ, BitWidth); LeadZ = std::min(LeadZ, BitWidth); - KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ) | - APInt::getHighBitsSet(BitWidth, LeadZ); + Known.resetAll(); + Known.Zero.setLowBits(TrailZ); + Known.Zero.setHighBits(LeadZ); // Only make use of no-wrap flags if we failed to compute the sign bit // directly. This matters if the multiplication always overflows, in // which case we prefer to follow the result of the direct computation, // though as the program is invoking undefined behaviour we can choose // whatever we like here. - if (isKnownNonNegative && !KnownOne.isNegative()) - KnownZero.setBit(BitWidth - 1); - else if (isKnownNegative && !KnownZero.isNegative()) - KnownOne.setBit(BitWidth - 1); + if (isKnownNonNegative && !Known.isNegative()) + Known.makeNonNegative(); + else if (isKnownNegative && !Known.isNonNegative()) + Known.makeNegative(); } void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, - APInt &KnownZero, - APInt &KnownOne) { - unsigned BitWidth = KnownZero.getBitWidth(); + KnownBits &Known) { + unsigned BitWidth = Known.getBitWidth(); unsigned NumRanges = Ranges.getNumOperands() / 2; assert(NumRanges >= 1); - KnownZero.setAllBits(); - KnownOne.setAllBits(); + Known.Zero.setAllBits(); + Known.One.setAllBits(); for (unsigned i = 0; i < NumRanges; ++i) { ConstantInt *Lower = @@ -406,8 +398,8 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros(); APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits); - KnownOne &= Range.getUnsignedMax() & Mask; - KnownZero &= ~Range.getUnsignedMax() & Mask; + Known.One &= Range.getUnsignedMax() & Mask; + Known.Zero &= ~Range.getUnsignedMax() & Mask; } } @@ -516,15 +508,14 @@ bool llvm::isValidAssumeForContext(const Instruction *Inv, return !isEphemeralValueOf(Inv, CxtI); } -static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, - APInt &KnownOne, unsigned Depth, - const Query &Q) { +static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, + unsigned Depth, const Query &Q) { // Use of assumptions is context-sensitive. If we don't have a context, we // cannot use them! 
if (!Q.AC || !Q.CxtI) return; - unsigned BitWidth = KnownZero.getBitWidth(); + unsigned BitWidth = Known.getBitWidth(); // Note that the patterns below need to be kept in sync with the code // in AssumptionCache::updateAffectedValues. @@ -549,8 +540,13 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { assert(BitWidth == 1 && "assume operand is not i1?"); - KnownZero.clearAllBits(); - KnownOne.setAllBits(); + Known.setAllOnes(); + return; + } + if (match(Arg, m_Not(m_Specific(V))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + assert(BitWidth == 1 && "assume operand is not i1?"); + Known.setAllZero(); return; } @@ -568,122 +564,126 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, // assume(v = a) if (match(Arg, m_c_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); - KnownZero |= RHSKnownZero; - KnownOne |= RHSKnownOne; + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + Known.Zero |= RHSKnown.Zero; + Known.One |= RHSKnown.One; // assume(v & b = a) } else if (match(Arg, m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); - APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); - computeKnownBits(B, MaskKnownZero, MaskKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits MaskKnown(BitWidth); + computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); // For those bits in the mask that are known to be one, we can propagate // known bits from the RHS to V. - KnownZero |= RHSKnownZero & MaskKnownOne; - KnownOne |= RHSKnownOne & MaskKnownOne; + Known.Zero |= RHSKnown.Zero & MaskKnown.One; + Known.One |= RHSKnown.One & MaskKnown.One; // assume(~(v & b) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); - APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); - computeKnownBits(B, MaskKnownZero, MaskKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits MaskKnown(BitWidth); + computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); // For those bits in the mask that are known to be one, we can propagate // inverted known bits from the RHS to V. 
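// Worked instance (plain C++, 8-bit; the & 0xFF truncation stands in for an
// i8 NOT) of the inverted propagation described just above: under the
// assumption (~(v & 0xF0) & 0xFF) == 0xCF, the bits of v covered by the mask
// must be the complement of the RHS there, so bits 7..6 of v are known zero
// and bits 5..4 are known one.
#include <cassert>

int main() {
  const unsigned Mask = 0xF0, RHS = 0xCF;
  for (unsigned V = 0; V < 256; ++V) {
    if ((~(V & Mask) & 0xFF) != RHS)
      continue;                                   // ruled out by the assume
    assert((V & (RHS & Mask)) == 0);              // RHS ones under the mask -> zeros in v
    assert((V & (~RHS & Mask)) == (~RHS & Mask)); // RHS zeros under the mask -> ones in v
  }
  return 0;
}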
- KnownZero |= RHSKnownOne & MaskKnownOne; - KnownOne |= RHSKnownZero & MaskKnownOne; + Known.Zero |= RHSKnown.One & MaskKnown.One; + Known.One |= RHSKnown.Zero & MaskKnown.One; // assume(v | b = a) } else if (match(Arg, m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); - APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); - computeKnownBits(B, BKnownZero, BKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); // For those bits in B that are known to be zero, we can propagate known // bits from the RHS to V. - KnownZero |= RHSKnownZero & BKnownZero; - KnownOne |= RHSKnownOne & BKnownZero; + Known.Zero |= RHSKnown.Zero & BKnown.Zero; + Known.One |= RHSKnown.One & BKnown.Zero; // assume(~(v | b) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); - APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); - computeKnownBits(B, BKnownZero, BKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); // For those bits in B that are known to be zero, we can propagate // inverted known bits from the RHS to V. - KnownZero |= RHSKnownOne & BKnownZero; - KnownOne |= RHSKnownZero & BKnownZero; + Known.Zero |= RHSKnown.One & BKnown.Zero; + Known.One |= RHSKnown.Zero & BKnown.Zero; // assume(v ^ b = a) } else if (match(Arg, m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); - APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); - computeKnownBits(B, BKnownZero, BKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); // For those bits in B that are known to be zero, we can propagate known // bits from the RHS to V. For those bits in B that are known to be one, // we can propagate inverted known bits from the RHS to V. 
- KnownZero |= RHSKnownZero & BKnownZero; - KnownOne |= RHSKnownOne & BKnownZero; - KnownZero |= RHSKnownOne & BKnownOne; - KnownOne |= RHSKnownZero & BKnownOne; + Known.Zero |= RHSKnown.Zero & BKnown.Zero; + Known.One |= RHSKnown.One & BKnown.Zero; + Known.Zero |= RHSKnown.One & BKnown.One; + Known.One |= RHSKnown.Zero & BKnown.One; // assume(~(v ^ b) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); - APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); - computeKnownBits(B, BKnownZero, BKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); // For those bits in B that are known to be zero, we can propagate // inverted known bits from the RHS to V. For those bits in B that are // known to be one, we can propagate known bits from the RHS to V. - KnownZero |= RHSKnownOne & BKnownZero; - KnownOne |= RHSKnownZero & BKnownZero; - KnownZero |= RHSKnownZero & BKnownOne; - KnownOne |= RHSKnownOne & BKnownOne; + Known.Zero |= RHSKnown.One & BKnown.Zero; + Known.One |= RHSKnown.Zero & BKnown.Zero; + Known.Zero |= RHSKnown.Zero & BKnown.One; + Known.One |= RHSKnown.One & BKnown.One; // assume(v << c = a) } else if (match(Arg, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them to known // bits in V shifted to the right by C. - KnownZero |= RHSKnownZero.lshr(C->getZExtValue()); - KnownOne |= RHSKnownOne.lshr(C->getZExtValue()); + RHSKnown.Zero.lshrInPlace(C->getZExtValue()); + Known.Zero |= RHSKnown.Zero; + RHSKnown.One.lshrInPlace(C->getZExtValue()); + Known.One |= RHSKnown.One; // assume(~(v << c) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them inverted // to known bits in V shifted to the right by C. 
- KnownZero |= RHSKnownOne.lshr(C->getZExtValue()); - KnownOne |= RHSKnownZero.lshr(C->getZExtValue()); + RHSKnown.One.lshrInPlace(C->getZExtValue()); + Known.Zero |= RHSKnown.One; + RHSKnown.Zero.lshrInPlace(C->getZExtValue()); + Known.One |= RHSKnown.Zero; // assume(v >> c = a) } else if (match(Arg, m_c_ICmp(Pred, m_CombineOr(m_LShr(m_V, m_ConstantInt(C)), @@ -691,12 +691,12 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them to known // bits in V shifted to the right by C. - KnownZero |= RHSKnownZero << C->getZExtValue(); - KnownOne |= RHSKnownOne << C->getZExtValue(); + Known.Zero |= RHSKnown.Zero << C->getZExtValue(); + Known.One |= RHSKnown.One << C->getZExtValue(); // assume(~(v >> c) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_CombineOr( m_LShr(m_V, m_ConstantInt(C)), @@ -704,146 +704,147 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them inverted // to known bits in V shifted to the right by C. - KnownZero |= RHSKnownOne << C->getZExtValue(); - KnownOne |= RHSKnownZero << C->getZExtValue(); + Known.Zero |= RHSKnown.One << C->getZExtValue(); + Known.One |= RHSKnown.Zero << C->getZExtValue(); // assume(v >=_s c) where c is non-negative } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_SGE && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - if (RHSKnownZero.isNegative()) { + if (RHSKnown.isNonNegative()) { // We know that the sign bit is zero. - KnownZero |= APInt::getSignBit(BitWidth); + Known.makeNonNegative(); } // assume(v >_s c) where c is at least -1. } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_SGT && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - if (RHSKnownOne.isAllOnesValue() || RHSKnownZero.isNegative()) { + if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) { // We know that the sign bit is zero. 
- KnownZero |= APInt::getSignBit(BitWidth); + Known.makeNonNegative(); } // assume(v <=_s c) where c is negative } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_SLE && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - if (RHSKnownOne.isNegative()) { + if (RHSKnown.isNegative()) { // We know that the sign bit is one. - KnownOne |= APInt::getSignBit(BitWidth); + Known.makeNegative(); } // assume(v <_s c) where c is non-positive } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_SLT && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - if (RHSKnownZero.isAllOnesValue() || RHSKnownOne.isNegative()) { + if (RHSKnown.isZero() || RHSKnown.isNegative()) { // We know that the sign bit is one. - KnownOne |= APInt::getSignBit(BitWidth); + Known.makeNegative(); } // assume(v <=_u c) } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_ULE && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // Whatever high bits in c are zero are known to be zero. - KnownZero |= - APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); - // assume(v <_u c) + Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); + // assume(v <_u c) } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_ULT && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // Whatever high bits in c are zero are known to be zero (if c is a power // of 2, then one more). if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I))) - KnownZero |= - APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()+1); + Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1); else - KnownZero |= - APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); + Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); } } // If assumptions conflict with each other or previous known bits, then we - // have a logical fallacy. This should only happen when a program has - // undefined behavior. We can't assert/crash, so clear out the known bits and - // hope for the best. - - // FIXME: Publish a warning/remark that we have encountered UB or the compiler - // is broken. - - // FIXME: Implement a stronger version of "I give up" by invalidating/clearing - // the assumption cache. This should indicate that the cache is corrupted so - // future callers will not waste time repopulating it with faulty assumptions. - - if ((KnownZero & KnownOne) != 0) { - KnownZero.clearAllBits(); - KnownOne.clearAllBits(); + // have a logical fallacy. It's possible that the assumption is not reachable, + // so this isn't a real bug. 
On the other hand, the program may have undefined + // behavior, or we might have a bug in the compiler. We can't assert/crash, so + // clear out the known bits, try to warn the user, and hope for the best. + if (Known.Zero.intersects(Known.One)) { + Known.resetAll(); + + if (Q.ORE) { + auto *CxtI = const_cast<Instruction *>(Q.CxtI); + OptimizationRemarkAnalysis ORA("value-tracking", "BadAssumption", CxtI); + Q.ORE->emit(ORA << "Detected conflicting code assumptions. Program may " + "have undefined behavior, or compiler may have " + "internal error."); + } } } // Compute known bits from a shift operator, including those with a -// non-constant shift amount. KnownZero and KnownOne are the outputs of this -// function. KnownZero2 and KnownOne2 are pre-allocated temporaries with the -// same bit width as KnownZero and KnownOne. KZF and KOF are operator-specific -// functors that, given the known-zero or known-one bits respectively, and a -// shift amount, compute the implied known-zero or known-one bits of the shift -// operator's result respectively for that shift amount. The results from calling -// KZF and KOF are conservatively combined for all permitted shift amounts. +// non-constant shift amount. Known is the outputs of this function. Known2 is a +// pre-allocated temporary with the/ same bit width as Known. KZF and KOF are +// operator-specific functors that, given the known-zero or known-one bits +// respectively, and a shift amount, compute the implied known-zero or known-one +// bits of the shift operator's result respectively for that shift amount. The +// results from calling KZF and KOF are conservatively combined for all +// permitted shift amounts. static void computeKnownBitsFromShiftOperator( - const Operator *I, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, - APInt &KnownOne2, unsigned Depth, const Query &Q, + const Operator *I, KnownBits &Known, KnownBits &Known2, + unsigned Depth, const Query &Q, function_ref<APInt(const APInt &, unsigned)> KZF, function_ref<APInt(const APInt &, unsigned)> KOF) { - unsigned BitWidth = KnownZero.getBitWidth(); + unsigned BitWidth = Known.getBitWidth(); if (auto *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { unsigned ShiftAmt = SA->getLimitedValue(BitWidth-1); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); - KnownZero = KZF(KnownZero, ShiftAmt); - KnownOne = KOF(KnownOne, ShiftAmt); - // If there is conflict between KnownZero and KnownOne, this must be an - // overflowing left shift, so the shift result is undefined. Clear KnownZero - // and KnownOne bits so that other code could propagate this undef. - if ((KnownZero & KnownOne) != 0) { - KnownZero.clearAllBits(); - KnownOne.clearAllBits(); - } + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + Known.Zero = KZF(Known.Zero, ShiftAmt); + Known.One = KOF(Known.One, ShiftAmt); + // If there is conflict between Known.Zero and Known.One, this must be an + // overflowing left shift, so the shift result is undefined. Clear Known + // bits so that other code could propagate this undef. + if ((Known.Zero & Known.One) != 0) + Known.resetAll(); return; } - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + + // If the shift amount could be greater than or equal to the bit-width of the LHS, the + // value could be undef, so we don't know anything about it. 
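// Standalone sketch (plain C++, 8-bit, not the LLVM APInt calls) of the bound
// enforced by the check that follows: the largest value the shift amount can
// still take is the complement of its known-zero bits.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned BitWidth = 8;             // width of the value being shifted
  uint8_t AmtKnownZero = 0xF8;             // only the low 3 amount bits unknown
  uint8_t MaxPossibleAmt = ~AmtKnownZero;  // 0x07, i.e. at most 7
  assert(MaxPossibleAmt < BitWidth);       // provably in range: keep analyzing

  AmtKnownZero = 0xF0;                     // low 4 amount bits unknown
  MaxPossibleAmt = ~AmtKnownZero;          // 0x0F, i.e. possibly 15 >= 8
  assert(MaxPossibleAmt >= BitWidth);      // may be out of range: give up
  return 0;
}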
+ if ((~Known.Zero).uge(BitWidth)) { + Known.resetAll(); + return; + } - // Note: We cannot use KnownZero.getLimitedValue() here, because if + // Note: We cannot use Known.Zero.getLimitedValue() here, because if // BitWidth > 64 and any upper bits are known, we'll end up returning the // limit value (which implies all bits are known). - uint64_t ShiftAmtKZ = KnownZero.zextOrTrunc(64).getZExtValue(); - uint64_t ShiftAmtKO = KnownOne.zextOrTrunc(64).getZExtValue(); + uint64_t ShiftAmtKZ = Known.Zero.zextOrTrunc(64).getZExtValue(); + uint64_t ShiftAmtKO = Known.One.zextOrTrunc(64).getZExtValue(); // It would be more-clearly correct to use the two temporaries for this // calculation. Reusing the APInts here to prevent unnecessary allocations. - KnownZero.clearAllBits(); - KnownOne.clearAllBits(); + Known.resetAll(); // If we know the shifter operand is nonzero, we can sometimes infer more // known bits. However this is expensive to compute, so be lazy about it and @@ -858,9 +859,10 @@ static void computeKnownBitsFromShiftOperator( return; } - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); - KnownZero = KnownOne = APInt::getAllOnesValue(BitWidth); + Known.Zero.setAllBits(); + Known.One.setAllBits(); for (unsigned ShiftAmt = 0; ShiftAmt < BitWidth; ++ShiftAmt) { // Combine the shifted known input bits only for those shift amounts // compatible with its known constraints. @@ -879,8 +881,8 @@ static void computeKnownBitsFromShiftOperator( continue; } - KnownZero &= KZF(KnownZero2, ShiftAmt); - KnownOne &= KOF(KnownOne2, ShiftAmt); + Known.Zero &= KZF(Known2.Zero, ShiftAmt); + Known.One &= KOF(Known2.One, ShiftAmt); } // If there are no compatible shift amounts, then we've proven that the shift @@ -888,33 +890,30 @@ static void computeKnownBitsFromShiftOperator( // return anything we'd like, but we need to make sure the sets of known bits // stay disjoint (it should be better for some other code to actually // propagate the undef than to pick a value here using known bits). - if ((KnownZero & KnownOne) != 0) { - KnownZero.clearAllBits(); - KnownOne.clearAllBits(); - } + if (Known.Zero.intersects(Known.One)) + Known.resetAll(); } -static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, - APInt &KnownOne, unsigned Depth, - const Query &Q) { - unsigned BitWidth = KnownZero.getBitWidth(); +static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, + unsigned Depth, const Query &Q) { + unsigned BitWidth = Known.getBitWidth(); - APInt KnownZero2(KnownZero), KnownOne2(KnownOne); + KnownBits Known2(Known); switch (I->getOpcode()) { default: break; case Instruction::Load: if (MDNode *MD = cast<LoadInst>(I)->getMetadata(LLVMContext::MD_range)) - computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne); + computeKnownBitsFromRangeMetadata(*MD, Known); break; case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); // Output known-1 bits are only known if set in both the LHS & RHS. - KnownOne &= KnownOne2; + Known.One &= Known2.One; // Output known-0 are known to be clear if zero in either the LHS | RHS. 
- KnownZero |= KnownZero2; + Known.Zero |= Known2.Zero; // and(x, add (x, -1)) is a common idiom that always clears the low bit; // here we handle the more general case of adding any odd number by @@ -922,118 +921,113 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // TODO: This could be generalized to clearing any bit set in y where the // following bit is known to be unset in y. Value *Y = nullptr; - if (match(I->getOperand(0), m_Add(m_Specific(I->getOperand(1)), - m_Value(Y))) || - match(I->getOperand(1), m_Add(m_Specific(I->getOperand(0)), - m_Value(Y)))) { - APInt KnownZero3(BitWidth, 0), KnownOne3(BitWidth, 0); - computeKnownBits(Y, KnownZero3, KnownOne3, Depth + 1, Q); - if (KnownOne3.countTrailingOnes() > 0) - KnownZero |= APInt::getLowBitsSet(BitWidth, 1); + if (!Known.Zero[0] && !Known.One[0] && + (match(I->getOperand(0), m_Add(m_Specific(I->getOperand(1)), + m_Value(Y))) || + match(I->getOperand(1), m_Add(m_Specific(I->getOperand(0)), + m_Value(Y))))) { + Known2.resetAll(); + computeKnownBits(Y, Known2, Depth + 1, Q); + if (Known2.countMinTrailingOnes() > 0) + Known.Zero.setBit(0); } break; } case Instruction::Or: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); // Output known-0 bits are only known if clear in both the LHS & RHS. - KnownZero &= KnownZero2; + Known.Zero &= Known2.Zero; // Output known-1 are known to be set if set in either the LHS | RHS. - KnownOne |= KnownOne2; + Known.One |= Known2.One; break; } case Instruction::Xor: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); // Output known-0 bits are known if clear or set in both the LHS & RHS. - APInt KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2); + APInt KnownZeroOut = (Known.Zero & Known2.Zero) | (Known.One & Known2.One); // Output known-1 are known to be set if set in only one of the LHS, RHS. - KnownOne = (KnownZero & KnownOne2) | (KnownOne & KnownZero2); - KnownZero = KnownZeroOut; + Known.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero); + Known.Zero = std::move(KnownZeroOut); break; } case Instruction::Mul: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, KnownZero, - KnownOne, KnownZero2, KnownOne2, Depth, Q); + computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, Known, + Known2, Depth, Q); break; } case Instruction::UDiv: { // For the purposes of computing leading zeros we can conservatively // treat a udiv as a logical right shift by the power of 2 known to // be less than the denominator. 
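// Brute-force check (plain C++, 8-bit) of the leading-zero bound computed
// below: a numerator with at least 2 leading zeros (<= 63) divided by a
// denominator with at most 4 leading zeros (>= 8) has at least
// 2 + 8 - 4 - 1 == 5 leading zeros, i.e. the quotient is at most 7.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t Num = 0; Num <= 63; ++Num)
    for (uint32_t Den = 8; Den <= 255; ++Den)
      assert(Num / Den <= 7);
  return 0;
}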
- computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); - unsigned LeadZ = KnownZero2.countLeadingOnes(); - - KnownOne2.clearAllBits(); - KnownZero2.clearAllBits(); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, Depth + 1, Q); - unsigned RHSUnknownLeadingOnes = KnownOne2.countLeadingZeros(); - if (RHSUnknownLeadingOnes != BitWidth) - LeadZ = std::min(BitWidth, - LeadZ + BitWidth - RHSUnknownLeadingOnes - 1); - - KnownZero = APInt::getHighBitsSet(BitWidth, LeadZ); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + unsigned LeadZ = Known2.countMinLeadingZeros(); + + Known2.resetAll(); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + unsigned RHSMaxLeadingZeros = Known2.countMaxLeadingZeros(); + if (RHSMaxLeadingZeros != BitWidth) + LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1); + + Known.Zero.setHighBits(LeadZ); break; } case Instruction::Select: { - computeKnownBits(I->getOperand(2), KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, Depth + 1, Q); - - const Value *LHS; - const Value *RHS; + const Value *LHS, *RHS; SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor; if (SelectPatternResult::isMinOrMax(SPF)) { - computeKnownBits(RHS, KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(LHS, KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(RHS, Known, Depth + 1, Q); + computeKnownBits(LHS, Known2, Depth + 1, Q); } else { - computeKnownBits(I->getOperand(2), KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(I->getOperand(2), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); } unsigned MaxHighOnes = 0; unsigned MaxHighZeros = 0; if (SPF == SPF_SMAX) { // If both sides are negative, the result is negative. - if (KnownOne[BitWidth - 1] && KnownOne2[BitWidth - 1]) + if (Known.isNegative() && Known2.isNegative()) // We can derive a lower bound on the result by taking the max of the // leading one bits. MaxHighOnes = - std::max(KnownOne.countLeadingOnes(), KnownOne2.countLeadingOnes()); + std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes()); // If either side is non-negative, the result is non-negative. - else if (KnownZero[BitWidth - 1] || KnownZero2[BitWidth - 1]) + else if (Known.isNonNegative() || Known2.isNonNegative()) MaxHighZeros = 1; } else if (SPF == SPF_SMIN) { // If both sides are non-negative, the result is non-negative. - if (KnownZero[BitWidth - 1] && KnownZero2[BitWidth - 1]) + if (Known.isNonNegative() && Known2.isNonNegative()) // We can derive an upper bound on the result by taking the max of the // leading zero bits. - MaxHighZeros = std::max(KnownZero.countLeadingOnes(), - KnownZero2.countLeadingOnes()); + MaxHighZeros = std::max(Known.countMinLeadingZeros(), + Known2.countMinLeadingZeros()); // If either side is negative, the result is negative. - else if (KnownOne[BitWidth - 1] || KnownOne2[BitWidth - 1]) + else if (Known.isNegative() || Known2.isNegative()) MaxHighOnes = 1; } else if (SPF == SPF_UMAX) { // We can derive a lower bound on the result by taking the max of the // leading one bits. MaxHighOnes = - std::max(KnownOne.countLeadingOnes(), KnownOne2.countLeadingOnes()); + std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes()); } else if (SPF == SPF_UMIN) { // We can derive an upper bound on the result by taking the max of the // leading zero bits. 
MaxHighZeros = - std::max(KnownZero.countLeadingOnes(), KnownZero2.countLeadingOnes()); + std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros()); } // Only known if known in both the LHS and RHS. - KnownOne &= KnownOne2; - KnownZero &= KnownZero2; + Known.One &= Known2.One; + Known.Zero &= Known2.Zero; if (MaxHighOnes > 0) - KnownOne |= APInt::getHighBitsSet(BitWidth, MaxHighOnes); + Known.One.setHighBits(MaxHighOnes); if (MaxHighZeros > 0) - KnownZero |= APInt::getHighBitsSet(BitWidth, MaxHighZeros); + Known.Zero.setHighBits(MaxHighZeros); break; } case Instruction::FPTrunc: @@ -1057,14 +1051,12 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, SrcBitWidth = Q.DL.getTypeSizeInBits(SrcTy->getScalarType()); assert(SrcBitWidth && "SrcBitWidth can't be zero"); - KnownZero = KnownZero.zextOrTrunc(SrcBitWidth); - KnownOne = KnownOne.zextOrTrunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); - KnownZero = KnownZero.zextOrTrunc(BitWidth); - KnownOne = KnownOne.zextOrTrunc(BitWidth); + Known = Known.zextOrTrunc(SrcBitWidth); + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + Known = Known.zextOrTrunc(BitWidth); // Any top bits are known to be zero. if (BitWidth > SrcBitWidth) - KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + Known.Zero.setBitsFrom(SrcBitWidth); break; } case Instruction::BitCast: { @@ -1073,7 +1065,7 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // TODO: For now, not handling conversions like: // (bitcast i64 %x to <2 x i32>) !I->getType()->isVectorTy()) { - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); break; } break; @@ -1082,90 +1074,75 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // Compute the bits in the result that are not present in the input. unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); - KnownZero = KnownZero.trunc(SrcBitWidth); - KnownOne = KnownOne.trunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); - KnownZero = KnownZero.zext(BitWidth); - KnownOne = KnownOne.zext(BitWidth); - + Known = Known.trunc(SrcBitWidth); + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); // If the sign bit of the input is known set or clear, then we know the // top bits of the result. - if (KnownZero[SrcBitWidth-1]) // Input sign bit known zero - KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); - else if (KnownOne[SrcBitWidth-1]) // Input sign bit known set - KnownOne |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + Known = Known.sext(BitWidth); break; } case Instruction::Shl: { // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - auto KZF = [BitWidth, NSW](const APInt &KnownZero, unsigned ShiftAmt) { - APInt KZResult = - (KnownZero << ShiftAmt) | - APInt::getLowBitsSet(BitWidth, ShiftAmt); // Low bits known 0. + auto KZF = [NSW](const APInt &KnownZero, unsigned ShiftAmt) { + APInt KZResult = KnownZero << ShiftAmt; + KZResult.setLowBits(ShiftAmt); // Low bits known 0. // If this shift has "nsw" keyword, then the result is either a poison // value or has the same sign bit as the first operand. 
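// Minimal plain-C++ analogue (8-bit) of the KZF functor above, without the
// nsw clause handled next: shift the known-zero mask and mark the vacated
// low bits as known zero.
#include <cassert>
#include <cstdint>

static uint8_t shlKnownZero(uint8_t KnownZero, unsigned ShiftAmt) {
  return uint8_t(KnownZero << ShiftAmt) | uint8_t((1u << ShiftAmt) - 1);
}

int main() {
  // If x has its top two bits known zero (x <= 0x3F), then x << 2 has its two
  // low bits known zero, while the shifted-out top bits tell us nothing.
  assert(shlKnownZero(0xC0, 2) == 0x03);
  return 0;
}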
- if (NSW && KnownZero.isNegative()) - KZResult.setBit(BitWidth - 1); + if (NSW && KnownZero.isSignBitSet()) + KZResult.setSignBit(); return KZResult; }; - auto KOF = [BitWidth, NSW](const APInt &KnownOne, unsigned ShiftAmt) { + auto KOF = [NSW](const APInt &KnownOne, unsigned ShiftAmt) { APInt KOResult = KnownOne << ShiftAmt; - if (NSW && KnownOne.isNegative()) - KOResult.setBit(BitWidth - 1); + if (NSW && KnownOne.isSignBitSet()) + KOResult.setSignBit(); return KOResult; }; - computeKnownBitsFromShiftOperator(I, KnownZero, KnownOne, - KnownZero2, KnownOne2, Depth, Q, KZF, - KOF); + computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); break; } case Instruction::LShr: { // (ushr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 - auto KZF = [BitWidth](const APInt &KnownZero, unsigned ShiftAmt) { - return APIntOps::lshr(KnownZero, ShiftAmt) | - // High bits known zero. - APInt::getHighBitsSet(BitWidth, ShiftAmt); + auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) { + APInt KZResult = KnownZero.lshr(ShiftAmt); + // High bits known zero. + KZResult.setHighBits(ShiftAmt); + return KZResult; }; - auto KOF = [BitWidth](const APInt &KnownOne, unsigned ShiftAmt) { - return APIntOps::lshr(KnownOne, ShiftAmt); + auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) { + return KnownOne.lshr(ShiftAmt); }; - computeKnownBitsFromShiftOperator(I, KnownZero, KnownOne, - KnownZero2, KnownOne2, Depth, Q, KZF, - KOF); + computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); break; } case Instruction::AShr: { // (ashr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 - auto KZF = [BitWidth](const APInt &KnownZero, unsigned ShiftAmt) { - return APIntOps::ashr(KnownZero, ShiftAmt); + auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) { + return KnownZero.ashr(ShiftAmt); }; - auto KOF = [BitWidth](const APInt &KnownOne, unsigned ShiftAmt) { - return APIntOps::ashr(KnownOne, ShiftAmt); + auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) { + return KnownOne.ashr(ShiftAmt); }; - computeKnownBitsFromShiftOperator(I, KnownZero, KnownOne, - KnownZero2, KnownOne2, Depth, Q, KZF, - KOF); + computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); break; } case Instruction::Sub: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, - KnownZero, KnownOne, KnownZero2, KnownOne2, Depth, - Q); + Known, Known2, Depth, Q); break; } case Instruction::Add: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, - KnownZero, KnownOne, KnownZero2, KnownOne2, Depth, - Q); + Known, Known2, Depth, Q); break; } case Instruction::SRem: @@ -1173,37 +1150,33 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, APInt RA = Rem->getValue().abs(); if (RA.isPowerOf2()) { APInt LowBits = RA - 1; - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, - Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); // The low bits of the first operand are unchanged by the srem. - KnownZero = KnownZero2 & LowBits; - KnownOne = KnownOne2 & LowBits; + Known.Zero = Known2.Zero & LowBits; + Known.One = Known2.One & LowBits; // If the first operand is non-negative or has all low bits zero, then // the upper bits are all zero. 
- if (KnownZero2[BitWidth-1] || ((KnownZero2 & LowBits) == LowBits)) - KnownZero |= ~LowBits; + if (Known2.isNonNegative() || LowBits.isSubsetOf(Known2.Zero)) + Known.Zero |= ~LowBits; // If the first operand is negative and not all low bits are zero, then // the upper bits are all one. - if (KnownOne2[BitWidth-1] && ((KnownOne2 & LowBits) != 0)) - KnownOne |= ~LowBits; + if (Known2.isNegative() && LowBits.intersects(Known2.One)) + Known.One |= ~LowBits; - assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?"); + break; } } // The sign bit is the LHS's sign bit, except when the result of the // remainder is zero. - if (KnownZero.isNonNegative()) { - APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, - Q); - // If it's known zero, our sign bit is also zero. - if (LHSKnownZero.isNegative()) - KnownZero.setBit(BitWidth - 1); - } + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + // If it's known zero, our sign bit is also zero. + if (Known2.isNonNegative()) + Known.makeNonNegative(); break; case Instruction::URem: { @@ -1211,22 +1184,22 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, const APInt &RA = Rem->getValue(); if (RA.isPowerOf2()) { APInt LowBits = (RA - 1); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); - KnownZero |= ~LowBits; - KnownOne &= LowBits; + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + Known.Zero |= ~LowBits; + Known.One &= LowBits; break; } } // Since the result is less than or equal to either operand, any leading // zero bits in either operand must also exist in the result. - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); - unsigned Leaders = std::max(KnownZero.countLeadingOnes(), - KnownZero2.countLeadingOnes()); - KnownOne.clearAllBits(); - KnownZero = APInt::getHighBitsSet(BitWidth, Leaders); + unsigned Leaders = + std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros()); + Known.resetAll(); + Known.Zero.setHighBits(Leaders); break; } @@ -1237,16 +1210,15 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, Align = Q.DL.getABITypeAlignment(AI->getAllocatedType()); if (Align > 0) - KnownZero = APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); + Known.Zero.setLowBits(countTrailingZeros(Align)); break; } case Instruction::GetElementPtr: { // Analyze all of the subscripts of this getelementptr instruction // to determine if we can prove known low zero bits. 
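// Plain-C++ sanity check of the trailing-zero accumulation done below for
// GEPs (hypothetical constants, not the LLVM API): Base + Index * TypeSize
// keeps at least min(tz(Base), tz(Index) + tz(TypeSize)) low zero bits.
#include <algorithm>
#include <cassert>
#include <cstdint>

static unsigned tz(uint64_t V) {          // count trailing zero bits
  unsigned N = 0;
  while (N < 64 && !(V & 1)) { V >>= 1; ++N; }
  return N;
}

int main() {
  const uint64_t Base = 0x1000;           // 4096-byte aligned base address
  const uint64_t TypeSize = 8;            // e.g. an array of i64
  for (uint64_t Index = 0; Index < 1024; ++Index) {
    uint64_t Addr = Base + Index * TypeSize;
    unsigned Bound = std::min(tz(Base), tz(Index) + tz(TypeSize));
    assert(tz(Addr) >= Bound);
  }
  return 0;
}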
- APInt LocalKnownZero(BitWidth, 0), LocalKnownOne(BitWidth, 0); - computeKnownBits(I->getOperand(0), LocalKnownZero, LocalKnownOne, Depth + 1, - Q); - unsigned TrailZ = LocalKnownZero.countTrailingOnes(); + KnownBits LocalKnown(BitWidth); + computeKnownBits(I->getOperand(0), LocalKnown, Depth + 1, Q); + unsigned TrailZ = LocalKnown.countMinTrailingZeros(); gep_type_iterator GTI = gep_type_begin(I); for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) { @@ -1276,15 +1248,15 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, } unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy); - LocalKnownZero = LocalKnownOne = APInt(GEPOpiBits, 0); - computeKnownBits(Index, LocalKnownZero, LocalKnownOne, Depth + 1, Q); + LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0); + computeKnownBits(Index, LocalKnown, Depth + 1, Q); TrailZ = std::min(TrailZ, unsigned(countTrailingZeros(TypeSize) + - LocalKnownZero.countTrailingOnes())); + LocalKnown.countMinTrailingZeros())); } } - KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ); + Known.Zero.setLowBits(TrailZ); break; } case Instruction::PHI: { @@ -1319,15 +1291,14 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, break; // Ok, we have a PHI of the form L op= R. Check for low // zero bits. - computeKnownBits(R, KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(R, Known2, Depth + 1, Q); // We need to take the minimum number of known bits - APInt KnownZero3(KnownZero), KnownOne3(KnownOne); - computeKnownBits(L, KnownZero3, KnownOne3, Depth + 1, Q); + KnownBits Known3(Known); + computeKnownBits(L, Known3, Depth + 1, Q); - KnownZero = APInt::getLowBitsSet( - BitWidth, std::min(KnownZero2.countTrailingOnes(), - KnownZero3.countTrailingOnes())); + Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(), + Known3.countMinTrailingZeros())); if (DontImproveNonNegativePhiBits) break; @@ -1344,25 +1315,25 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // (add non-negative, non-negative) --> non-negative // (add negative, negative) --> negative if (Opcode == Instruction::Add) { - if (KnownZero2.isNegative() && KnownZero3.isNegative()) - KnownZero.setBit(BitWidth - 1); - else if (KnownOne2.isNegative() && KnownOne3.isNegative()) - KnownOne.setBit(BitWidth - 1); + if (Known2.isNonNegative() && Known3.isNonNegative()) + Known.makeNonNegative(); + else if (Known2.isNegative() && Known3.isNegative()) + Known.makeNegative(); } // (sub nsw non-negative, negative) --> non-negative // (sub nsw negative, non-negative) --> negative else if (Opcode == Instruction::Sub && LL == I) { - if (KnownZero2.isNegative() && KnownOne3.isNegative()) - KnownZero.setBit(BitWidth - 1); - else if (KnownOne2.isNegative() && KnownZero3.isNegative()) - KnownOne.setBit(BitWidth - 1); + if (Known2.isNonNegative() && Known3.isNegative()) + Known.makeNonNegative(); + else if (Known2.isNegative() && Known3.isNonNegative()) + Known.makeNegative(); } // (mul nsw non-negative, non-negative) --> non-negative - else if (Opcode == Instruction::Mul && KnownZero2.isNegative() && - KnownZero3.isNegative()) - KnownZero.setBit(BitWidth - 1); + else if (Opcode == Instruction::Mul && Known2.isNonNegative() && + Known3.isNonNegative()) + Known.makeNonNegative(); } break; @@ -1376,27 +1347,26 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // Otherwise take the unions of the known bit sets of the 
operands, // taking conservative care to avoid excessive recursion. - if (Depth < MaxDepth - 1 && !KnownZero && !KnownOne) { + if (Depth < MaxDepth - 1 && !Known.Zero && !Known.One) { // Skip if every incoming value references to ourself. if (dyn_cast_or_null<UndefValue>(P->hasConstantValue())) break; - KnownZero = APInt::getAllOnesValue(BitWidth); - KnownOne = APInt::getAllOnesValue(BitWidth); + Known.Zero.setAllBits(); + Known.One.setAllBits(); for (Value *IncValue : P->incoming_values()) { // Skip direct self references. if (IncValue == P) continue; - KnownZero2 = APInt(BitWidth, 0); - KnownOne2 = APInt(BitWidth, 0); + Known2 = KnownBits(BitWidth); // Recurse, but cap the recursion to one level, because we don't // want to waste time spinning around in loops. - computeKnownBits(IncValue, KnownZero2, KnownOne2, MaxDepth - 1, Q); - KnownZero &= KnownZero2; - KnownOne &= KnownOne2; + computeKnownBits(IncValue, Known2, MaxDepth - 1, Q); + Known.Zero &= Known2.Zero; + Known.One &= Known2.One; // If all bits have been ruled out, there's no need to check // more operands. - if (!KnownZero && !KnownOne) + if (!Known.Zero && !Known.One) break; } } @@ -1408,45 +1378,60 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // and then intersect with known bits based on other properties of the // function. if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range)) - computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne); + computeKnownBitsFromRangeMetadata(*MD, Known); if (const Value *RV = ImmutableCallSite(I).getReturnedArgOperand()) { - computeKnownBits(RV, KnownZero2, KnownOne2, Depth + 1, Q); - KnownZero |= KnownZero2; - KnownOne |= KnownOne2; + computeKnownBits(RV, Known2, Depth + 1, Q); + Known.Zero |= Known2.Zero; + Known.One |= Known2.One; } if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { default: break; + case Intrinsic::bitreverse: + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + Known.Zero |= Known2.Zero.reverseBits(); + Known.One |= Known2.One.reverseBits(); + break; case Intrinsic::bswap: - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); - KnownZero |= KnownZero2.byteSwap(); - KnownOne |= KnownOne2.byteSwap(); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + Known.Zero |= Known2.Zero.byteSwap(); + Known.One |= Known2.One.byteSwap(); break; - case Intrinsic::ctlz: + case Intrinsic::ctlz: { + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + // If we have a known 1, its position is our upper bound. + unsigned PossibleLZ = Known2.One.countLeadingZeros(); + // If this call is undefined for 0, the result will be less than 2^n. + if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext())) + PossibleLZ = std::min(PossibleLZ, BitWidth - 1); + unsigned LowBits = Log2_32(PossibleLZ)+1; + Known.Zero.setBitsFrom(LowBits); + break; + } case Intrinsic::cttz: { - unsigned LowBits = Log2_32(BitWidth)+1; + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + // If we have a known 1, its position is our upper bound. + unsigned PossibleTZ = Known2.One.countTrailingZeros(); // If this call is undefined for 0, the result will be less than 2^n. 
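// Worked check (plain C++, 8-bit) of the cttz bound used here: if the operand
// has a known one at bit 5, cttz can be at most 5, so only the low
// Log2(5) + 1 == 3 bits of the result can ever be set.
#include <cassert>
#include <cstdint>

static unsigned cttz8(uint8_t V) {         // returns 8 for V == 0
  unsigned N = 0;
  while (N < 8 && !((V >> N) & 1)) ++N;
  return N;
}

int main() {
  for (unsigned X = 0; X < 256; ++X)
    if (X & 0x20)                              // bit 5 known one
      assert((cttz8(uint8_t(X)) & ~0x7u) == 0); // result bits above 2 are zero
  return 0;
}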
if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext())) - LowBits -= 1; - KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - LowBits); + PossibleTZ = std::min(PossibleTZ, BitWidth - 1); + unsigned LowBits = Log2_32(PossibleTZ)+1; + Known.Zero.setBitsFrom(LowBits); break; } case Intrinsic::ctpop: { - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); // We can bound the space the count needs. Also, bits known to be zero // can't contribute to the population. - unsigned BitsPossiblySet = BitWidth - KnownZero2.countPopulation(); - unsigned LeadingZeros = - APInt(BitWidth, BitsPossiblySet).countLeadingZeros(); - assert(LeadingZeros <= BitWidth); - KnownZero |= APInt::getHighBitsSet(BitWidth, LeadingZeros); - KnownOne &= ~KnownZero; + unsigned BitsPossiblySet = Known2.countMaxPopulation(); + unsigned LowBits = Log2_32(BitsPossiblySet)+1; + Known.Zero.setBitsFrom(LowBits); // TODO: we could bound KnownOne using the lower bound on the number // of bits which might be set provided by popcnt KnownOne2. break; } case Intrinsic::x86_sse42_crc32_64_64: - KnownZero |= APInt::getHighBitsSet(64, 32); + Known.Zero.setBitsFrom(32); break; } } @@ -1456,7 +1441,7 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // tracking the specific element. But at least we might find information // valid for all elements of the vector (for example if vector is sign // extended, shifted, etc). - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); break; case Instruction::ExtractValue: if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I->getOperand(0))) { @@ -1468,20 +1453,19 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: computeKnownBitsAddSub(true, II->getArgOperand(0), - II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, Depth, Q); + II->getArgOperand(1), false, Known, Known2, + Depth, Q); break; case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: computeKnownBitsAddSub(false, II->getArgOperand(0), - II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, Depth, Q); + II->getArgOperand(1), false, Known, Known2, + Depth, Q); break; case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false, - KnownZero, KnownOne, KnownZero2, KnownOne2, Depth, - Q); + Known, Known2, Depth, Q); break; } } @@ -1490,7 +1474,15 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, } /// Determine which bits of V are known to be either zero or one and return -/// them in the KnownZero/KnownOne bit sets. +/// them. +KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) { + KnownBits Known(getBitWidth(V->getType(), Q.DL)); + computeKnownBits(V, Known, Depth, Q); + return Known; +} + +/// Determine which bits of V are known to be either zero or one and return +/// them in the Known bit set. /// /// NOTE: we cannot consider 'undef' to be "IsZero" here. 
The problem is that /// we cannot optimize based on the assumption that it is zero without changing @@ -1504,11 +1496,11 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, /// where V is a vector, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. -void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, - unsigned Depth, const Query &Q) { +void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, + const Query &Q) { assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); - unsigned BitWidth = KnownZero.getBitWidth(); + unsigned BitWidth = Known.getBitWidth(); assert((V->getType()->isIntOrIntVectorTy() || V->getType()->getScalarType()->isPointerTy()) && @@ -1516,21 +1508,19 @@ void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, assert((Q.DL.getTypeSizeInBits(V->getType()->getScalarType()) == BitWidth) && (!V->getType()->isIntOrIntVectorTy() || V->getType()->getScalarSizeInBits() == BitWidth) && - KnownZero.getBitWidth() == BitWidth && - KnownOne.getBitWidth() == BitWidth && - "V, KnownOne and KnownZero should have same BitWidth"); + "V and Known should have same BitWidth"); + (void)BitWidth; const APInt *C; if (match(V, m_APInt(C))) { // We know all of the bits for a scalar constant or a splat vector constant! - KnownOne = *C; - KnownZero = ~KnownOne; + Known.One = *C; + Known.Zero = ~Known.One; return; } // Null and aggregate-zero are all-zeros. if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) { - KnownOne.clearAllBits(); - KnownZero = APInt::getAllOnesValue(BitWidth); + Known.setAllZero(); return; } // Handle a constant vector by taking the intersection of the known bits of @@ -1538,12 +1528,12 @@ void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) { // We know that CDS must be a vector of integers. Take the intersection of // each element. - KnownZero.setAllBits(); KnownOne.setAllBits(); - APInt Elt(KnownZero.getBitWidth(), 0); + Known.Zero.setAllBits(); Known.One.setAllBits(); + APInt Elt(BitWidth, 0); for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) { Elt = CDS->getElementAsInteger(i); - KnownZero &= ~Elt; - KnownOne &= Elt; + Known.Zero &= ~Elt; + Known.One &= Elt; } return; } @@ -1551,25 +1541,24 @@ void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, if (const auto *CV = dyn_cast<ConstantVector>(V)) { // We know that CV must be a vector of integers. Take the intersection of // each element. - KnownZero.setAllBits(); KnownOne.setAllBits(); - APInt Elt(KnownZero.getBitWidth(), 0); + Known.Zero.setAllBits(); Known.One.setAllBits(); + APInt Elt(BitWidth, 0); for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) { Constant *Element = CV->getAggregateElement(i); auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element); if (!ElementCI) { - KnownZero.clearAllBits(); - KnownOne.clearAllBits(); + Known.resetAll(); return; } Elt = ElementCI->getValue(); - KnownZero &= ~Elt; - KnownOne &= Elt; + Known.Zero &= ~Elt; + Known.One &= Elt; } return; } // Start out not knowing anything. - KnownZero.clearAllBits(); KnownOne.clearAllBits(); + Known.resetAll(); // We can't imply anything about undefs. 
if (isa<UndefValue>(V)) @@ -1588,44 +1577,27 @@ void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, // the bits of its aliasee. if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { if (!GA->isInterposable()) - computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(GA->getAliasee(), Known, Depth + 1, Q); return; } if (const Operator *I = dyn_cast<Operator>(V)) - computeKnownBitsFromOperator(I, KnownZero, KnownOne, Depth, Q); + computeKnownBitsFromOperator(I, Known, Depth, Q); - // Aligned pointers have trailing zeros - refine KnownZero set + // Aligned pointers have trailing zeros - refine Known.Zero set if (V->getType()->isPointerTy()) { unsigned Align = V->getPointerAlignment(Q.DL); if (Align) - KnownZero |= APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); + Known.Zero.setLowBits(countTrailingZeros(Align)); } - // computeKnownBitsFromAssume strictly refines KnownZero and - // KnownOne. Therefore, we run them after computeKnownBitsFromOperator. + // computeKnownBitsFromAssume strictly refines Known. + // Therefore, we run them after computeKnownBitsFromOperator. // Check whether a nearby assume intrinsic can determine some known bits. - computeKnownBitsFromAssume(V, KnownZero, KnownOne, Depth, Q); + computeKnownBitsFromAssume(V, Known, Depth, Q); - assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); -} - -/// Determine whether the sign bit is known to be zero or one. -/// Convenience wrapper around computeKnownBits. -void ComputeSignBit(const Value *V, bool &KnownZero, bool &KnownOne, - unsigned Depth, const Query &Q) { - unsigned BitWidth = getBitWidth(V->getType(), Q.DL); - if (!BitWidth) { - KnownZero = false; - KnownOne = false; - return; - } - APInt ZeroBits(BitWidth, 0); - APInt OneBits(BitWidth, 0); - computeKnownBits(V, ZeroBits, OneBits, Depth, Q); - KnownOne = OneBits[BitWidth - 1]; - KnownZero = ZeroBits[BitWidth - 1]; + assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?"); } /// Return true if the given value is known to have exactly one @@ -1648,9 +1620,9 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, if (match(V, m_Shl(m_One(), m_Value()))) return true; - // (signbit) >>l X is clearly a power of two if the one is not shifted off the - // bottom. If it is shifted off the bottom then the result is undefined. - if (match(V, m_LShr(m_SignBit(), m_Value()))) + // (signmask) >>l X is clearly a power of two if the one is not shifted off + // the bottom. If it is shifted off the bottom then the result is undefined. + if (match(V, m_LShr(m_SignMask(), m_Value()))) return true; // The remaining tests are all recursive, so bail out if we hit the limit. @@ -1697,18 +1669,18 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, return true; unsigned BitWidth = V->getType()->getScalarSizeInBits(); - APInt LHSZeroBits(BitWidth, 0), LHSOneBits(BitWidth, 0); - computeKnownBits(X, LHSZeroBits, LHSOneBits, Depth, Q); + KnownBits LHSBits(BitWidth); + computeKnownBits(X, LHSBits, Depth, Q); - APInt RHSZeroBits(BitWidth, 0), RHSOneBits(BitWidth, 0); - computeKnownBits(Y, RHSZeroBits, RHSOneBits, Depth, Q); + KnownBits RHSBits(BitWidth); + computeKnownBits(Y, RHSBits, Depth, Q); // If i8 V is a power of two or zero: // ZeroBits: 1 1 1 0 1 1 1 1 // ~ZeroBits: 0 0 0 1 0 0 0 0 - if ((~(LHSZeroBits & RHSZeroBits)).isPowerOf2()) + if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2()) // If OrZero isn't set, we cannot give back a zero result. 
// Make sure either the LHS or RHS has a bit set. - if (OrZero || RHSOneBits.getBoolValue() || LHSOneBits.getBoolValue()) + if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue()) return true; } } @@ -1805,10 +1777,12 @@ static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) return true; } -/// Return true if the given value is known to be non-zero when defined. -/// For vectors return true if every element is known to be non-zero when -/// defined. Supports values with integer or pointer type and vectors of -/// integers. +/// Return true if the given value is known to be non-zero when defined. For +/// vectors, return true if every element is known to be non-zero when +/// defined. For pointers, if the context instruction and dominator tree are +/// specified, perform context-sensitive analysis and return true if the +/// pointer couldn't possibly be null at the specified instruction. +/// Supports values with integer or pointer type and vectors of integers. bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { if (auto *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) @@ -1851,7 +1825,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { // Check for pointer simplifications. if (V->getType()->isPointerTy()) { - if (isKnownNonNull(V)) + if (isKnownNonNullAt(V, Q.CxtI, Q.DT)) return true; if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) if (isGEPKnownNonNull(GEP, Depth, Q)) @@ -1871,16 +1845,15 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined // if the lowest bit is shifted off the end. - if (BitWidth && match(V, m_Shl(m_Value(X), m_Value(Y)))) { + if (match(V, m_Shl(m_Value(X), m_Value(Y)))) { // shl nuw can't remove any non-zero bits. const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); if (BO->hasNoUnsignedWrap()) return isKnownNonZero(X, Depth, Q); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(X, KnownZero, KnownOne, Depth, Q); - if (KnownOne[0]) + KnownBits Known(BitWidth); + computeKnownBits(X, Known, Depth, Q); + if (Known.One[0]) return true; } // shr X, Y != 0 if X is negative. Note that the value of the shift is not @@ -1891,25 +1864,20 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { if (BO->isExact()) return isKnownNonZero(X, Depth, Q); - bool XKnownNonNegative, XKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, Depth, Q); - if (XKnownNegative) + KnownBits Known = computeKnownBits(X, Depth, Q); + if (Known.isNegative()) return true; // If the shifter operand is a constant, and all of the bits shifted // out are known to be zero, and X is known non-zero then at least one // non-zero bit must remain. if (ConstantInt *Shift = dyn_cast<ConstantInt>(Y)) { - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(X, KnownZero, KnownOne, Depth, Q); - auto ShiftVal = Shift->getLimitedValue(BitWidth - 1); // Is there a known one in the portion not shifted out? - if (KnownOne.countLeadingZeros() < BitWidth - ShiftVal) + if (Known.countMaxLeadingZeros() < BitWidth - ShiftVal) return true; // Are all the bits to be shifted out known zero? - if (KnownZero.countTrailingOnes() >= ShiftVal) + if (Known.countMinTrailingZeros() >= ShiftVal) return isKnownNonZero(X, Depth, Q); } } @@ -1919,40 +1887,34 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { } // X + Y. 
else if (match(V, m_Add(m_Value(X), m_Value(Y)))) { - bool XKnownNonNegative, XKnownNegative; - bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, Depth, Q); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Depth, Q); + KnownBits XKnown = computeKnownBits(X, Depth, Q); + KnownBits YKnown = computeKnownBits(Y, Depth, Q); // If X and Y are both non-negative (as signed values) then their sum is not // zero unless both X and Y are zero. - if (XKnownNonNegative && YKnownNonNegative) + if (XKnown.isNonNegative() && YKnown.isNonNegative()) if (isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q)) return true; // If X and Y are both negative (as signed values) then their sum is not // zero unless both X and Y equal INT_MIN. - if (BitWidth && XKnownNegative && YKnownNegative) { - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); + if (XKnown.isNegative() && YKnown.isNegative()) { APInt Mask = APInt::getSignedMaxValue(BitWidth); // The sign bit of X is set. If some other bit is set then X is not equal // to INT_MIN. - computeKnownBits(X, KnownZero, KnownOne, Depth, Q); - if ((KnownOne & Mask) != 0) + if (XKnown.One.intersects(Mask)) return true; // The sign bit of Y is set. If some other bit is set then Y is not equal // to INT_MIN. - computeKnownBits(Y, KnownZero, KnownOne, Depth, Q); - if ((KnownOne & Mask) != 0) + if (YKnown.One.intersects(Mask)) return true; } // The sum of a non-negative number and a power of two is not zero. - if (XKnownNonNegative && + if (XKnown.isNonNegative() && isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q)) return true; - if (YKnownNonNegative && + if (YKnown.isNonNegative() && isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q)) return true; } @@ -1998,11 +1960,9 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { return true; } - if (!BitWidth) return false; - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, Depth, Q); - return KnownOne != 0; + KnownBits Known(BitWidth); + computeKnownBits(V, Known, Depth, Q); + return Known.One != 0; } /// Return true if V2 == V1 + X, where X is known non-zero. @@ -2034,14 +1994,13 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q) { // Are any known bits in V1 contradictory to known bits in V2? If V1 // has a known zero where V2 has a known one, they must not be equal. auto BitWidth = Ty->getBitWidth(); - APInt KnownZero1(BitWidth, 0); - APInt KnownOne1(BitWidth, 0); - computeKnownBits(V1, KnownZero1, KnownOne1, 0, Q); - APInt KnownZero2(BitWidth, 0); - APInt KnownOne2(BitWidth, 0); - computeKnownBits(V2, KnownZero2, KnownOne2, 0, Q); - - auto OppositeBits = (KnownZero1 & KnownOne2) | (KnownZero2 & KnownOne1); + KnownBits Known1(BitWidth); + computeKnownBits(V1, Known1, 0, Q); + KnownBits Known2(BitWidth); + computeKnownBits(V2, Known2, 0, Q); + + APInt OppositeBits = (Known1.Zero & Known2.One) | + (Known2.Zero & Known1.One); if (OppositeBits.getBoolValue()) return true; } @@ -2059,9 +2018,9 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q) { /// for all of the elements in the vector. 
bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, const Query &Q) { - APInt KnownZero(Mask.getBitWidth(), 0), KnownOne(Mask.getBitWidth(), 0); - computeKnownBits(V, KnownZero, KnownOne, Depth, Q); - return (KnownZero & Mask) == Mask; + KnownBits Known(Mask.getBitWidth()); + computeKnownBits(V, Known, Depth, Q); + return Mask.isSubsetOf(Known.Zero); } /// For vector constants, loop over the elements and find the constant with the @@ -2092,13 +2051,29 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V, return MinSignBits; } +static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, + const Query &Q); + +static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, + const Query &Q) { + unsigned Result = ComputeNumSignBitsImpl(V, Depth, Q); + assert(Result > 0 && "At least one sign bit needs to be present!"); + return Result; +} + /// Return the number of times the sign bit of the register is replicated into /// the other bits. We know that at least 1 bit is always equal to the sign bit /// (itself), but other cases can give us information. For example, immediately /// after an "ashr X, 2", we know that the top 3 bits are all equal to each /// other, so we return 3. For vectors, return the number of sign bits for the /// vector element with the mininum number of known sign bits. -unsigned ComputeNumSignBits(const Value *V, unsigned Depth, const Query &Q) { +static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, + const Query &Q) { + + // We return the minimum number of sign bits that are guaranteed to be present + // in V, so for undef we have to conservatively return 1. We don't have the + // same behavior for poison though -- that's a FIXME today. + unsigned TyBits = Q.DL.getTypeSizeInBits(V->getType()->getScalarType()); unsigned Tmp, Tmp2; unsigned FirstAnswer = 1; @@ -2174,7 +2149,10 @@ unsigned ComputeNumSignBits(const Value *V, unsigned Depth, const Query &Q) { // ashr X, C -> adds C sign bits. Vectors too. const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { - Tmp += ShAmt->getZExtValue(); + unsigned ShAmtLimited = ShAmt->getZExtValue(); + if (ShAmtLimited >= TyBits) + break; // Bad shift. + Tmp += ShAmtLimited; if (Tmp > TyBits) Tmp = TyBits; } return Tmp; @@ -2220,17 +2198,17 @@ unsigned ComputeNumSignBits(const Value *V, unsigned Depth, const Query &Q) { // Special case decrementing a value (ADD X, -1): if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1))) if (CRHS->isAllOnesValue()) { - APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); + KnownBits Known(TyBits); + computeKnownBits(U->getOperand(0), Known, Depth + 1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. - if ((KnownZero | APInt(TyBits, 1)).isAllOnesValue()) + if ((Known.Zero | 1).isAllOnesValue()) return TyBits; // If we are subtracting one from a positive number, there is no carry // out of the result. - if (KnownZero.isNegative()) + if (Known.isNonNegative()) return Tmp; } @@ -2245,16 +2223,16 @@ unsigned ComputeNumSignBits(const Value *V, unsigned Depth, const Query &Q) { // Handle NEG. 
if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0))) if (CLHS->isNullValue()) { - APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); + KnownBits Known(TyBits); + computeKnownBits(U->getOperand(1), Known, Depth + 1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. - if ((KnownZero | APInt(TyBits, 1)).isAllOnesValue()) + if ((Known.Zero | 1).isAllOnesValue()) return TyBits; // If the input is known to be positive (the sign bit is known clear), // the output of the NEG has the same number of sign bits as the input. - if (KnownZero.isNegative()) + if (Known.isNonNegative()) return Tmp2; // Otherwise, we treat this like a SUB. @@ -2306,19 +2284,12 @@ unsigned ComputeNumSignBits(const Value *V, unsigned Depth, const Query &Q) { if (unsigned VecSignBits = computeNumSignBitsVectorConstant(V, TyBits)) return VecSignBits; - APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(V, KnownZero, KnownOne, Depth, Q); + KnownBits Known(TyBits); + computeKnownBits(V, Known, Depth, Q); // If we know that the sign bit is either zero or one, determine the number of // identical bits in the top of the input value. - if (KnownZero.isNegative()) - return std::max(FirstAnswer, KnownZero.countLeadingOnes()); - - if (KnownOne.isNegative()) - return std::max(FirstAnswer, KnownOne.countLeadingOnes()); - - // computeKnownBits gave us no extra information about the top bits. - return FirstAnswer; + return std::max(FirstAnswer, Known.countMinSignBits()); } /// This function computes the integer multiple of Base that equals V. @@ -2368,6 +2339,7 @@ bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple, case Instruction::SExt: if (!LookThroughSExt) return false; // otherwise fall through to ZExt + LLVM_FALLTHROUGH; case Instruction::ZExt: return ComputeMultiple(I->getOperand(0), Base, Multiple, LookThroughSExt, Depth+1); @@ -2453,7 +2425,7 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS, if (!TLI) return Intrinsic::not_intrinsic; - LibFunc::Func Func; + LibFunc Func; // We're going to make assumptions on the semantics of the functions, check // that the target knows that it's available in this environment and it does // not have local linkage. 
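The ComputeNumSignBits hunks above collapse the old KnownZero/KnownOne leading-ones logic into a single std::max(FirstAnswer, Known.countMinSignBits()). As a hedged, standalone illustration of that minimum-sign-bit computation (plain 8-bit masks stand in for APInt/KnownBits, and minSignBits8 is an illustrative name, not LLVM API):

#include <cstdint>
#include <cstdio>

// Minimum number of leading bits guaranteed to equal the sign bit, derived
// from masks of bits known to be zero and bits known to be one.
// Illustrative sketch only; mirrors the reasoning in the hunk above.
static unsigned minSignBits8(uint8_t KnownZero, uint8_t KnownOne) {
  auto countLeadingOnes = [](uint8_t V) {
    unsigned N = 0;
    for (int Bit = 7; Bit >= 0 && (V & (1u << Bit)); --Bit)
      ++N;
    return N;
  };
  if (KnownZero & 0x80)   // Sign bit known clear: length of the leading known-zero run.
    return countLeadingOnes(KnownZero);
  if (KnownOne & 0x80)    // Sign bit known set: length of the leading known-one run.
    return countLeadingOnes(KnownOne);
  return 1;               // Nothing known: only the sign bit itself is guaranteed.
}

int main() {
  // Top three bits known zero -> at least three copies of the sign bit (prints 3).
  printf("%u\n", minSignBits8(/*KnownZero=*/0xE0, /*KnownOne=*/0x00));
  // Sign bit unknown -> conservative answer (prints 1).
  printf("%u\n", minSignBits8(/*KnownZero=*/0x0F, /*KnownOne=*/0x00));
  return 0;
}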
@@ -2468,81 +2440,81 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS, switch (Func) { default: break; - case LibFunc::sin: - case LibFunc::sinf: - case LibFunc::sinl: + case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_sinl: return Intrinsic::sin; - case LibFunc::cos: - case LibFunc::cosf: - case LibFunc::cosl: + case LibFunc_cos: + case LibFunc_cosf: + case LibFunc_cosl: return Intrinsic::cos; - case LibFunc::exp: - case LibFunc::expf: - case LibFunc::expl: + case LibFunc_exp: + case LibFunc_expf: + case LibFunc_expl: return Intrinsic::exp; - case LibFunc::exp2: - case LibFunc::exp2f: - case LibFunc::exp2l: + case LibFunc_exp2: + case LibFunc_exp2f: + case LibFunc_exp2l: return Intrinsic::exp2; - case LibFunc::log: - case LibFunc::logf: - case LibFunc::logl: + case LibFunc_log: + case LibFunc_logf: + case LibFunc_logl: return Intrinsic::log; - case LibFunc::log10: - case LibFunc::log10f: - case LibFunc::log10l: + case LibFunc_log10: + case LibFunc_log10f: + case LibFunc_log10l: return Intrinsic::log10; - case LibFunc::log2: - case LibFunc::log2f: - case LibFunc::log2l: + case LibFunc_log2: + case LibFunc_log2f: + case LibFunc_log2l: return Intrinsic::log2; - case LibFunc::fabs: - case LibFunc::fabsf: - case LibFunc::fabsl: + case LibFunc_fabs: + case LibFunc_fabsf: + case LibFunc_fabsl: return Intrinsic::fabs; - case LibFunc::fmin: - case LibFunc::fminf: - case LibFunc::fminl: + case LibFunc_fmin: + case LibFunc_fminf: + case LibFunc_fminl: return Intrinsic::minnum; - case LibFunc::fmax: - case LibFunc::fmaxf: - case LibFunc::fmaxl: + case LibFunc_fmax: + case LibFunc_fmaxf: + case LibFunc_fmaxl: return Intrinsic::maxnum; - case LibFunc::copysign: - case LibFunc::copysignf: - case LibFunc::copysignl: + case LibFunc_copysign: + case LibFunc_copysignf: + case LibFunc_copysignl: return Intrinsic::copysign; - case LibFunc::floor: - case LibFunc::floorf: - case LibFunc::floorl: + case LibFunc_floor: + case LibFunc_floorf: + case LibFunc_floorl: return Intrinsic::floor; - case LibFunc::ceil: - case LibFunc::ceilf: - case LibFunc::ceill: + case LibFunc_ceil: + case LibFunc_ceilf: + case LibFunc_ceill: return Intrinsic::ceil; - case LibFunc::trunc: - case LibFunc::truncf: - case LibFunc::truncl: + case LibFunc_trunc: + case LibFunc_truncf: + case LibFunc_truncl: return Intrinsic::trunc; - case LibFunc::rint: - case LibFunc::rintf: - case LibFunc::rintl: + case LibFunc_rint: + case LibFunc_rintf: + case LibFunc_rintl: return Intrinsic::rint; - case LibFunc::nearbyint: - case LibFunc::nearbyintf: - case LibFunc::nearbyintl: + case LibFunc_nearbyint: + case LibFunc_nearbyintf: + case LibFunc_nearbyintl: return Intrinsic::nearbyint; - case LibFunc::round: - case LibFunc::roundf: - case LibFunc::roundl: + case LibFunc_round: + case LibFunc_roundf: + case LibFunc_roundl: return Intrinsic::round; - case LibFunc::pow: - case LibFunc::powf: - case LibFunc::powl: + case LibFunc_pow: + case LibFunc_powf: + case LibFunc_powl: return Intrinsic::pow; - case LibFunc::sqrt: - case LibFunc::sqrtf: - case LibFunc::sqrtl: + case LibFunc_sqrt: + case LibFunc_sqrtf: + case LibFunc_sqrtl: if (ICS->hasNoNaNs()) return Intrinsic::sqrt; return Intrinsic::not_intrinsic; @@ -2607,6 +2579,11 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, const TargetLibraryInfo *TLI, bool SignBitOnly, unsigned Depth) { + // TODO: This function does not do the right thing when SignBitOnly is true + // and we're lowering to a hypothetical IEEE 754-compliant-but-evil platform + // which flips the sign 
bits of NaNs. See + // https://llvm.org/bugs/show_bug.cgi?id=31702. + if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { return !CFP->getValueAPF().isNegative() || (!SignBitOnly && CFP->getValueAPF().isZero()); @@ -2650,7 +2627,8 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1); case Instruction::Call: - Intrinsic::ID IID = getIntrinsicForCallSite(cast<CallInst>(I), TLI); + const auto *CI = cast<CallInst>(I); + Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); switch (IID) { default: break; @@ -2667,16 +2645,37 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, case Intrinsic::exp: case Intrinsic::exp2: case Intrinsic::fabs: - case Intrinsic::sqrt: return true; + + case Intrinsic::sqrt: + // sqrt(x) is always >= -0 or NaN. Moreover, sqrt(x) == -0 iff x == -0. + if (!SignBitOnly) + return true; + return CI->hasNoNaNs() && (CI->hasNoSignedZeros() || + CannotBeNegativeZero(CI->getOperand(0), TLI)); + case Intrinsic::powi: - if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + if (ConstantInt *Exponent = dyn_cast<ConstantInt>(I->getOperand(1))) { // powi(x,n) is non-negative if n is even. - if (CI->getBitWidth() <= 64 && CI->getSExtValue() % 2u == 0) + if (Exponent->getBitWidth() <= 64 && Exponent->getSExtValue() % 2u == 0) return true; } + // TODO: This is not correct. Given that exp is an integer, here are the + // ways that pow can return a negative value: + // + // pow(x, exp) --> negative if exp is odd and x is negative. + // pow(-0, exp) --> -inf if exp is negative odd. + // pow(-0, exp) --> -0 if exp is positive odd. + // pow(-inf, exp) --> -0 if exp is negative odd. + // pow(-inf, exp) --> -inf if exp is positive odd. + // + // Therefore, if !SignBitOnly, we can return true if x >= +0 or x is NaN, + // but we must return false if x == -0. Unfortunately we do not currently + // have a way of expressing this constraint. See details in + // https://llvm.org/bugs/show_bug.cgi?id=31702. return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1); + case Intrinsic::fma: case Intrinsic::fmuladd: // x*x+y is non-negative if y is non-negative. @@ -2970,14 +2969,16 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, return Ptr; } -bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP) { +bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP, + unsigned CharSize) { // Make sure the GEP has exactly three arguments. if (GEP->getNumOperands() != 3) return false; - // Make sure the index-ee is a pointer to array of i8. + // Make sure the index-ee is a pointer to array of \p CharSize integers. + // CharSize. ArrayType *AT = dyn_cast<ArrayType>(GEP->getSourceElementType()); - if (!AT || !AT->getElementType()->isIntegerTy(8)) + if (!AT || !AT->getElementType()->isIntegerTy(CharSize)) return false; // Check to make sure that the first operand of the GEP is an integer and @@ -2989,11 +2990,9 @@ bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP) { return true; } -/// This function computes the length of a null-terminated C string pointed to -/// by V. If successful, it returns true and returns the string in Str. -/// If unsuccessful, it returns false. 
-bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, - uint64_t Offset, bool TrimAtNul) { +bool llvm::getConstantDataArrayInfo(const Value *V, + ConstantDataArraySlice &Slice, + unsigned ElementSize, uint64_t Offset) { assert(V); // Look through bitcast instructions and geps. @@ -3004,7 +3003,7 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { // The GEP operator should be based on a pointer to string constant, and is // indexing into the string constant. - if (!isGEPBasedOnPointerToString(GEP)) + if (!isGEPBasedOnPointerToString(GEP, ElementSize)) return false; // If the second index isn't a ConstantInt, then this is a variable index @@ -3015,8 +3014,8 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, StartIdx = CI->getZExtValue(); else return false; - return getConstantStringInfo(GEP->getOperand(0), Str, StartIdx + Offset, - TrimAtNul); + return getConstantDataArrayInfo(GEP->getOperand(0), Slice, ElementSize, + StartIdx + Offset); } // The GEP instruction, constant or instruction, must reference a global @@ -3026,30 +3025,72 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) return false; - // Handle the all-zeros case. + const ConstantDataArray *Array; + ArrayType *ArrayTy; if (GV->getInitializer()->isNullValue()) { - // This is a degenerate case. The initializer is constant zero so the - // length of the string must be zero. - Str = ""; - return true; + Type *GVTy = GV->getValueType(); + if ( (ArrayTy = dyn_cast<ArrayType>(GVTy)) ) { + // A zeroinitializer for the array; There is no ConstantDataArray. + Array = nullptr; + } else { + const DataLayout &DL = GV->getParent()->getDataLayout(); + uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy); + uint64_t Length = SizeInBytes / (ElementSize / 8); + if (Length <= Offset) + return false; + + Slice.Array = nullptr; + Slice.Offset = 0; + Slice.Length = Length - Offset; + return true; + } + } else { + // This must be a ConstantDataArray. + Array = dyn_cast<ConstantDataArray>(GV->getInitializer()); + if (!Array) + return false; + ArrayTy = Array->getType(); } + if (!ArrayTy->getElementType()->isIntegerTy(ElementSize)) + return false; - // This must be a ConstantDataArray. - const auto *Array = dyn_cast<ConstantDataArray>(GV->getInitializer()); - if (!Array || !Array->isString()) + uint64_t NumElts = ArrayTy->getArrayNumElements(); + if (Offset > NumElts) return false; - // Get the number of elements in the array. - uint64_t NumElts = Array->getType()->getArrayNumElements(); + Slice.Array = Array; + Slice.Offset = Offset; + Slice.Length = NumElts - Offset; + return true; +} - // Start out with the entire array in the StringRef. - Str = Array->getAsString(); +/// This function computes the length of a null-terminated C string pointed to +/// by V. If successful, it returns true and returns the string in Str. +/// If unsuccessful, it returns false. +bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, + uint64_t Offset, bool TrimAtNul) { + ConstantDataArraySlice Slice; + if (!getConstantDataArrayInfo(V, Slice, 8, Offset)) + return false; - if (Offset > NumElts) + if (Slice.Array == nullptr) { + if (TrimAtNul) { + Str = StringRef(); + return true; + } + if (Slice.Length == 1) { + Str = StringRef("", 1); + return true; + } + // We cannot instantiate a StringRef as we do not have an apropriate string + // of 0s at hand. 
return false; + } + // Start out with the entire array in the StringRef. + Str = Slice.Array->getAsString(); // Skip over 'offset' bytes. - Str = Str.substr(Offset); + Str = Str.substr(Slice.Offset); if (TrimAtNul) { // Trim off the \0 and anything after it. If the array is not nul @@ -3067,7 +3108,8 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, /// If we can compute the length of the string pointed to by /// the specified pointer, return 'len+1'. If we can't, return 0. static uint64_t GetStringLengthH(const Value *V, - SmallPtrSetImpl<const PHINode*> &PHIs) { + SmallPtrSetImpl<const PHINode*> &PHIs, + unsigned CharSize) { // Look through noop bitcast instructions. V = V->stripPointerCasts(); @@ -3080,7 +3122,7 @@ static uint64_t GetStringLengthH(const Value *V, // If it was new, see if all the input strings are the same length. uint64_t LenSoFar = ~0ULL; for (Value *IncValue : PN->incoming_values()) { - uint64_t Len = GetStringLengthH(IncValue, PHIs); + uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize); if (Len == 0) return 0; // Unknown length -> unknown. if (Len == ~0ULL) continue; @@ -3096,9 +3138,9 @@ static uint64_t GetStringLengthH(const Value *V, // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y) if (const SelectInst *SI = dyn_cast<SelectInst>(V)) { - uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs); + uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize); if (Len1 == 0) return 0; - uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs); + uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize); if (Len2 == 0) return 0; if (Len1 == ~0ULL) return Len2; if (Len2 == ~0ULL) return Len1; @@ -3107,20 +3149,30 @@ static uint64_t GetStringLengthH(const Value *V, } // Otherwise, see if we can read the string. - StringRef StrData; - if (!getConstantStringInfo(V, StrData)) + ConstantDataArraySlice Slice; + if (!getConstantDataArrayInfo(V, Slice, CharSize)) return 0; - return StrData.size()+1; + if (Slice.Array == nullptr) + return 1; + + // Search for nul characters + unsigned NullIndex = 0; + for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) { + if (Slice.Array->getElementAsInteger(Slice.Offset + NullIndex) == 0) + break; + } + + return NullIndex + 1; } /// If we can compute the length of the string pointed to by /// the specified pointer, return 'len+1'. If we can't, return 0. -uint64_t llvm::GetStringLength(const Value *V) { +uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) { if (!V->getType()->isPointerTy()) return 0; SmallPtrSet<const PHINode*, 32> PHIs; - uint64_t Len = GetStringLengthH(V, PHIs); + uint64_t Len = GetStringLengthH(V, PHIs, CharSize); // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return // an empty string as a length. return Len == ~0ULL ? 1 : Len; @@ -3167,6 +3219,9 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, if (GA->isInterposable()) return V; V = GA->getAliasee(); + } else if (isa<AllocaInst>(V)) { + // An alloca can't be further simplified. + return V; } else { if (auto CS = CallSite(V)) if (Value *RV = CS.getReturnedArgOperand()) { @@ -3177,7 +3232,7 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, // See if InstructionSimplify knows any relevant tricks. if (Instruction *I = dyn_cast<Instruction>(V)) // TODO: Acquire a DominatorTree and AssumptionCache and use them. 
- if (Value *Simplified = SimplifyInstruction(I, DL, nullptr)) { + if (Value *Simplified = SimplifyInstruction(I, {DL, I})) { V = Simplified; continue; } @@ -3298,59 +3353,12 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, LI->getAlignment(), DL, CtxI, DT); } case Instruction::Call: { - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { - switch (II->getIntrinsicID()) { - // These synthetic intrinsics have no side-effects and just mark - // information about their operands. - // FIXME: There are other no-op synthetic instructions that potentially - // should be considered at least *safe* to speculate... - case Intrinsic::dbg_declare: - case Intrinsic::dbg_value: - return true; + auto *CI = cast<const CallInst>(Inst); + const Function *Callee = CI->getCalledFunction(); - case Intrinsic::bitreverse: - case Intrinsic::bswap: - case Intrinsic::ctlz: - case Intrinsic::ctpop: - case Intrinsic::cttz: - case Intrinsic::objectsize: - case Intrinsic::sadd_with_overflow: - case Intrinsic::smul_with_overflow: - case Intrinsic::ssub_with_overflow: - case Intrinsic::uadd_with_overflow: - case Intrinsic::umul_with_overflow: - case Intrinsic::usub_with_overflow: - return true; - // These intrinsics are defined to have the same behavior as libm - // functions except for setting errno. - case Intrinsic::sqrt: - case Intrinsic::fma: - case Intrinsic::fmuladd: - return true; - // These intrinsics are defined to have the same behavior as libm - // functions, and the corresponding libm functions never set errno. - case Intrinsic::trunc: - case Intrinsic::copysign: - case Intrinsic::fabs: - case Intrinsic::minnum: - case Intrinsic::maxnum: - return true; - // These intrinsics are defined to have the same behavior as libm - // functions, which never overflow when operating on the IEEE754 types - // that we support, and never set errno otherwise. - case Intrinsic::ceil: - case Intrinsic::floor: - case Intrinsic::nearbyint: - case Intrinsic::rint: - case Intrinsic::round: - return true; - // TODO: are convert_{from,to}_fp16 safe? - // TODO: can we list target-specific intrinsics here? - default: break; - } - } - return false; // The called function could have undefined behavior or - // side-effects, even if marked readnone nounwind. + // The called function could have undefined behavior or side-effects, even + // if marked readnone nounwind. + return Callee && Callee->isSpeculatable(); } case Instruction::VAArg: case Instruction::Alloca: @@ -3423,6 +3431,16 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V, if (NumUsesExplored >= DomConditionsMaxUses) break; NumUsesExplored++; + + // If the value is used as an argument to a call or invoke, then argument + // attributes may provide an answer about null-ness. + if (auto CS = ImmutableCallSite(U)) + if (auto *CalledFunc = CS.getCalledFunction()) + for (const Argument &Arg : CalledFunc->args()) + if (CS.getArgOperand(Arg.getArgNo()) == V && + Arg.hasNonNullAttr() && DT->dominates(CS.getInstruction(), CtxI)) + return true; + // Consider only compare instructions uniquely controlling a branch CmpInst::Predicate Pred; if (!match(const_cast<User *>(U), @@ -3477,38 +3495,34 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, // we can guarantee that the result does not overflow. 
// Ref: "Hacker's Delight" by Henry Warren unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, DL, /*Depth=*/0, AC, CxtI, - DT); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, DL, /*Depth=*/0, AC, CxtI, - DT); + KnownBits LHSKnown(BitWidth); + KnownBits RHSKnown(BitWidth); + computeKnownBits(LHS, LHSKnown, DL, /*Depth=*/0, AC, CxtI, DT); + computeKnownBits(RHS, RHSKnown, DL, /*Depth=*/0, AC, CxtI, DT); // Note that underestimating the number of zero bits gives a more // conservative answer. - unsigned ZeroBits = LHSKnownZero.countLeadingOnes() + - RHSKnownZero.countLeadingOnes(); + unsigned ZeroBits = LHSKnown.countMinLeadingZeros() + + RHSKnown.countMinLeadingZeros(); // First handle the easy case: if we have enough zero bits there's // definitely no overflow. if (ZeroBits >= BitWidth) return OverflowResult::NeverOverflows; // Get the largest possible values for each operand. - APInt LHSMax = ~LHSKnownZero; - APInt RHSMax = ~RHSKnownZero; + APInt LHSMax = ~LHSKnown.Zero; + APInt RHSMax = ~RHSKnown.Zero; // We know the multiply operation doesn't overflow if the maximum values for // each operand will not overflow after we multiply them together. bool MaxOverflow; - LHSMax.umul_ov(RHSMax, MaxOverflow); + (void)LHSMax.umul_ov(RHSMax, MaxOverflow); if (!MaxOverflow) return OverflowResult::NeverOverflows; // We know it always overflows if multiplying the smallest possible values for // the operands also results in overflow. bool MinOverflow; - LHSKnownOne.umul_ov(RHSKnownOne, MinOverflow); + (void)LHSKnown.One.umul_ov(RHSKnown.One, MinOverflow); if (MinOverflow) return OverflowResult::AlwaysOverflows; @@ -3521,21 +3535,17 @@ OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - bool LHSKnownNonNegative, LHSKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, /*Depth=*/0, - AC, CxtI, DT); - if (LHSKnownNonNegative || LHSKnownNegative) { - bool RHSKnownNonNegative, RHSKnownNegative; - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, /*Depth=*/0, - AC, CxtI, DT); - - if (LHSKnownNegative && RHSKnownNegative) { + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); + if (LHSKnown.isNonNegative() || LHSKnown.isNegative()) { + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + + if (LHSKnown.isNegative() && RHSKnown.isNegative()) { // The sign bit is set in both cases: this MUST overflow. // Create a simple add instruction, and insert it into the struct. return OverflowResult::AlwaysOverflows; } - if (LHSKnownNonNegative && RHSKnownNonNegative) { + if (LHSKnown.isNonNegative() && RHSKnown.isNonNegative()) { // The sign bit is clear in both cases: this CANNOT overflow. // Create a simple add instruction, and insert it into the struct. return OverflowResult::NeverOverflows; @@ -3545,6 +3555,51 @@ OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, return OverflowResult::MayOverflow; } +/// \brief Return true if we can prove that adding the two values of the +/// knownbits will not overflow. +/// Otherwise return false. +static bool checkRippleForSignedAdd(const KnownBits &LHSKnown, + const KnownBits &RHSKnown) { + // Addition of two 2's complement numbers having opposite signs will never + // overflow. 
+ if ((LHSKnown.isNegative() && RHSKnown.isNonNegative()) || + (LHSKnown.isNonNegative() && RHSKnown.isNegative())) + return true; + + // If either of the values is known to be non-negative, adding them can only + // overflow if the second is also non-negative, so we can assume that. + // Two non-negative numbers will only overflow if there is a carry to the + // sign bit, so we can check if even when the values are as big as possible + // there is no overflow to the sign bit. + if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) { + APInt MaxLHS = ~LHSKnown.Zero; + MaxLHS.clearSignBit(); + APInt MaxRHS = ~RHSKnown.Zero; + MaxRHS.clearSignBit(); + APInt Result = std::move(MaxLHS) + std::move(MaxRHS); + return Result.isSignBitClear(); + } + + // If either of the values is known to be negative, adding them can only + // overflow if the second is also negative, so we can assume that. + // Two negative number will only overflow if there is no carry to the sign + // bit, so we can check if even when the values are as small as possible + // there is overflow to the sign bit. + if (LHSKnown.isNegative() || RHSKnown.isNegative()) { + APInt MinLHS = LHSKnown.One; + MinLHS.clearSignBit(); + APInt MinRHS = RHSKnown.One; + MinRHS.clearSignBit(); + APInt Result = std::move(MinLHS) + std::move(MinRHS); + return Result.isSignBitSet(); + } + + // If we reached here it means that we know nothing about the sign bits. + // In this case we can't know if there will be an overflow, since by + // changing the sign bits any two values can be made to overflow. + return false; +} + static OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS, const AddOperator *Add, @@ -3556,18 +3611,29 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS, return OverflowResult::NeverOverflows; } - bool LHSKnownNonNegative, LHSKnownNegative; - bool RHSKnownNonNegative, RHSKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, /*Depth=*/0, - AC, CxtI, DT); - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, /*Depth=*/0, - AC, CxtI, DT); + // If LHS and RHS each have at least two sign bits, the addition will look + // like + // + // XX..... + + // YY..... + // + // If the carry into the most significant position is 0, X and Y can't both + // be 1 and therefore the carry out of the addition is also 0. + // + // If the carry into the most significant position is 1, X and Y can't both + // be 0 and therefore the carry out of the addition is also 1. + // + // Since the carry into the most significant position is always equal to + // the carry out of the addition, there is no signed overflow. + if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 && + ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1) + return OverflowResult::NeverOverflows; + + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); - if ((LHSKnownNonNegative && RHSKnownNegative) || - (LHSKnownNegative && RHSKnownNonNegative)) { - // The sign bits are opposite: this CANNOT overflow. + if (checkRippleForSignedAdd(LHSKnown, RHSKnown)) return OverflowResult::NeverOverflows; - } // The remaining code needs Add to be available. Early returns if not so. if (!Add) @@ -3578,14 +3644,13 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS, // @llvm.assume'ed non-negative rather than proved so from analyzing its // operands. 
bool LHSOrRHSKnownNonNegative = - (LHSKnownNonNegative || RHSKnownNonNegative); - bool LHSOrRHSKnownNegative = (LHSKnownNegative || RHSKnownNegative); + (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()); + bool LHSOrRHSKnownNegative = + (LHSKnown.isNegative() || RHSKnown.isNegative()); if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) { - bool AddKnownNonNegative, AddKnownNegative; - ComputeSignBit(Add, AddKnownNonNegative, AddKnownNegative, DL, - /*Depth=*/0, AC, CxtI, DT); - if ((AddKnownNonNegative && LHSOrRHSKnownNonNegative) || - (AddKnownNegative && LHSOrRHSKnownNegative)) { + KnownBits AddKnown = computeKnownBits(Add, DL, /*Depth=*/0, AC, CxtI, DT); + if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) || + (AddKnown.isNegative() && LHSOrRHSKnownNegative)) { return OverflowResult::NeverOverflows; } } @@ -3700,6 +3765,8 @@ bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) { return false; if (isa<ReturnInst>(I)) return false; + if (isa<UnreachableInst>(I)) + return false; // Calls can throw, or contain an infinite loop, or kill the process. if (auto CS = ImmutableCallSite(I)) { @@ -3748,79 +3815,33 @@ bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I, bool llvm::propagatesFullPoison(const Instruction *I) { switch (I->getOpcode()) { - case Instruction::Add: - case Instruction::Sub: - case Instruction::Xor: - case Instruction::Trunc: - case Instruction::BitCast: - case Instruction::AddrSpaceCast: - // These operations all propagate poison unconditionally. Note that poison - // is not any particular value, so xor or subtraction of poison with - // itself still yields poison, not zero. - return true; - - case Instruction::AShr: - case Instruction::SExt: - // For these operations, one bit of the input is replicated across - // multiple output bits. A replicated poison bit is still poison. - return true; - - case Instruction::Shl: { - // Left shift *by* a poison value is poison. The number of - // positions to shift is unsigned, so no negative values are - // possible there. Left shift by zero places preserves poison. So - // it only remains to consider left shift of poison by a positive - // number of places. - // - // A left shift by a positive number of places leaves the lowest order bit - // non-poisoned. However, if such a shift has a no-wrap flag, then we can - // make the poison operand violate that flag, yielding a fresh full-poison - // value. - auto *OBO = cast<OverflowingBinaryOperator>(I); - return OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap(); - } - - case Instruction::Mul: { - // A multiplication by zero yields a non-poison zero result, so we need to - // rule out zero as an operand. Conservatively, multiplication by a - // non-zero constant is not multiplication by zero. - // - // Multiplication by a non-zero constant can leave some bits - // non-poisoned. For example, a multiplication by 2 leaves the lowest - // order bit unpoisoned. So we need to consider that. - // - // Multiplication by 1 preserves poison. If the multiplication has a - // no-wrap flag, then we can make the poison operand violate that flag - // when multiplied by any integer other than 0 and 1. - auto *OBO = cast<OverflowingBinaryOperator>(I); - if (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) { - for (Value *V : OBO->operands()) { - if (auto *CI = dyn_cast<ConstantInt>(V)) { - // A ConstantInt cannot yield poison, so we can assume that it is - // the other operand that is poison. 
- return !CI->isZero(); - } - } - } - return false; - } + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + case Instruction::Trunc: + case Instruction::BitCast: + case Instruction::AddrSpaceCast: + case Instruction::Mul: + case Instruction::Shl: + case Instruction::GetElementPtr: + // These operations all propagate poison unconditionally. Note that poison + // is not any particular value, so xor or subtraction of poison with + // itself still yields poison, not zero. + return true; - case Instruction::ICmp: - // Comparing poison with any value yields poison. This is why, for - // instance, x s< (x +nsw 1) can be folded to true. - return true; + case Instruction::AShr: + case Instruction::SExt: + // For these operations, one bit of the input is replicated across + // multiple output bits. A replicated poison bit is still poison. + return true; - case Instruction::GetElementPtr: - // A GEP implicitly represents a sequence of additions, subtractions, - // truncations, sign extensions and multiplications. The multiplications - // are by the non-zero sizes of some set of types, so we do not have to be - // concerned with multiplication by zero. If the GEP is in-bounds, then - // these operations are implicitly no-signed-wrap so poison is propagated - // by the arguments above for Add, Sub, Trunc, SExt and Mul. - return cast<GEPOperator>(I)->isInBounds(); + case Instruction::ICmp: + // Comparing poison with any value yields poison. This is why, for + // instance, x s< (x +nsw 1) can be folded to true. + return true; - default: - return false; + default: + return false; } } @@ -3849,7 +3870,7 @@ const Value *llvm::getGuaranteedNonFullPoisonOp(const Instruction *I) { } } -bool llvm::isKnownNotFullPoison(const Instruction *PoisonI) { +bool llvm::programUndefinedIfFullPoison(const Instruction *PoisonI) { // We currently only look for uses of poison values within the same basic // block, as that makes it easier to guarantee that the uses will be // executed given that PoisonI is executed. @@ -3923,6 +3944,37 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, Value *CmpLHS, Value *CmpRHS, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS) { + // Assume success. If there's no match, callers should not use these anyway. + LHS = TrueVal; + RHS = FalseVal; + + // Recognize variations of: + // CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v))) + const APInt *C1; + if (CmpRHS == TrueVal && match(CmpRHS, m_APInt(C1))) { + const APInt *C2; + + // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1) + if (match(FalseVal, m_SMin(m_Specific(CmpLHS), m_APInt(C2))) && + C1->slt(*C2) && Pred == CmpInst::ICMP_SLT) + return {SPF_SMAX, SPNB_NA, false}; + + // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1) + if (match(FalseVal, m_SMax(m_Specific(CmpLHS), m_APInt(C2))) && + C1->sgt(*C2) && Pred == CmpInst::ICMP_SGT) + return {SPF_SMIN, SPNB_NA, false}; + + // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1) + if (match(FalseVal, m_UMin(m_Specific(CmpLHS), m_APInt(C2))) && + C1->ult(*C2) && Pred == CmpInst::ICMP_ULT) + return {SPF_UMAX, SPNB_NA, false}; + + // (X >u C1) ? 
C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1) + if (match(FalseVal, m_UMax(m_Specific(CmpLHS), m_APInt(C2))) && + C1->ugt(*C2) && Pred == CmpInst::ICMP_UGT) + return {SPF_UMIN, SPNB_NA, false}; + } + if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -3930,23 +3982,16 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, // (X >s Y) ? 0 : Z ==> (Z >s 0) ? 0 : Z ==> SMIN(Z, 0) // (X <s Y) ? 0 : Z ==> (Z <s 0) ? 0 : Z ==> SMAX(Z, 0) if (match(TrueVal, m_Zero()) && - match(FalseVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) { - LHS = TrueVal; - RHS = FalseVal; + match(FalseVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; - } // Z = X -nsw Y // (X >s Y) ? Z : 0 ==> (Z >s 0) ? Z : 0 ==> SMAX(Z, 0) // (X <s Y) ? Z : 0 ==> (Z <s 0) ? Z : 0 ==> SMIN(Z, 0) if (match(FalseVal, m_Zero()) && - match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) { - LHS = TrueVal; - RHS = FalseVal; + match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; - } - const APInt *C1; if (!match(CmpRHS, m_APInt(C1))) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -3957,41 +4002,29 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, // Is the sign bit set? // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN - if (Pred == CmpInst::ICMP_SLT && *C1 == 0 && C2->isMaxSignedValue()) { - LHS = TrueVal; - RHS = FalseVal; + if (Pred == CmpInst::ICMP_SLT && *C1 == 0 && C2->isMaxSignedValue()) return {CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; - } // Is the sign bit clear? // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN if (Pred == CmpInst::ICMP_SGT && C1->isAllOnesValue() && - C2->isMinSignedValue()) { - LHS = TrueVal; - RHS = FalseVal; + C2->isMinSignedValue()) return {CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; - } } // Look through 'not' ops to find disguised signed min/max. // (X >s C) ? ~X : ~C ==> (~X <s ~C) ? ~X : ~C ==> SMIN(~X, ~C) // (X <s C) ? ~X : ~C ==> (~X >s ~C) ? ~X : ~C ==> SMAX(~X, ~C) if (match(TrueVal, m_Not(m_Specific(CmpLHS))) && - match(FalseVal, m_APInt(C2)) && ~(*C1) == *C2) { - LHS = TrueVal; - RHS = FalseVal; + match(FalseVal, m_APInt(C2)) && ~(*C1) == *C2) return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; - } // (X >s C) ? ~C : ~X ==> (~X <s ~C) ? ~C : ~X ==> SMAX(~C, ~X) // (X <s C) ? ~C : ~X ==> (~X >s ~C) ? ~C : ~X ==> SMIN(~C, ~X) if (match(FalseVal, m_Not(m_Specific(CmpLHS))) && - match(TrueVal, m_APInt(C2)) && ~(*C1) == *C2) { - LHS = TrueVal; - RHS = FalseVal; + match(TrueVal, m_APInt(C2)) && ~(*C1) == *C2) return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; - } return {SPF_UNKNOWN, SPNB_NA, false}; } @@ -4118,58 +4151,64 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2, Instruction::CastOps *CastOp) { - CastInst *CI = dyn_cast<CastInst>(V1); - Constant *C = dyn_cast<Constant>(V2); - if (!CI) - return nullptr; - *CastOp = CI->getOpcode(); - - if (auto *CI2 = dyn_cast<CastInst>(V2)) { - // If V1 and V2 are both the same cast from the same type, we can look - // through V1. 
- if (CI2->getOpcode() == CI->getOpcode() && - CI2->getSrcTy() == CI->getSrcTy()) - return CI2->getOperand(0); + auto *Cast1 = dyn_cast<CastInst>(V1); + if (!Cast1) return nullptr; - } else if (!C) { + + *CastOp = Cast1->getOpcode(); + Type *SrcTy = Cast1->getSrcTy(); + if (auto *Cast2 = dyn_cast<CastInst>(V2)) { + // If V1 and V2 are both the same cast from the same type, look through V1. + if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy()) + return Cast2->getOperand(0); return nullptr; } - Constant *CastedTo = nullptr; - - if (isa<ZExtInst>(CI) && CmpI->isUnsigned()) - CastedTo = ConstantExpr::getTrunc(C, CI->getSrcTy()); - - if (isa<SExtInst>(CI) && CmpI->isSigned()) - CastedTo = ConstantExpr::getTrunc(C, CI->getSrcTy(), true); - - if (isa<TruncInst>(CI)) - CastedTo = ConstantExpr::getIntegerCast(C, CI->getSrcTy(), CmpI->isSigned()); - - if (isa<FPTruncInst>(CI)) - CastedTo = ConstantExpr::getFPExtend(C, CI->getSrcTy(), true); - - if (isa<FPExtInst>(CI)) - CastedTo = ConstantExpr::getFPTrunc(C, CI->getSrcTy(), true); - - if (isa<FPToUIInst>(CI)) - CastedTo = ConstantExpr::getUIToFP(C, CI->getSrcTy(), true); - - if (isa<FPToSIInst>(CI)) - CastedTo = ConstantExpr::getSIToFP(C, CI->getSrcTy(), true); - - if (isa<UIToFPInst>(CI)) - CastedTo = ConstantExpr::getFPToUI(C, CI->getSrcTy(), true); + auto *C = dyn_cast<Constant>(V2); + if (!C) + return nullptr; - if (isa<SIToFPInst>(CI)) - CastedTo = ConstantExpr::getFPToSI(C, CI->getSrcTy(), true); + Constant *CastedTo = nullptr; + switch (*CastOp) { + case Instruction::ZExt: + if (CmpI->isUnsigned()) + CastedTo = ConstantExpr::getTrunc(C, SrcTy); + break; + case Instruction::SExt: + if (CmpI->isSigned()) + CastedTo = ConstantExpr::getTrunc(C, SrcTy, true); + break; + case Instruction::Trunc: + CastedTo = ConstantExpr::getIntegerCast(C, SrcTy, CmpI->isSigned()); + break; + case Instruction::FPTrunc: + CastedTo = ConstantExpr::getFPExtend(C, SrcTy, true); + break; + case Instruction::FPExt: + CastedTo = ConstantExpr::getFPTrunc(C, SrcTy, true); + break; + case Instruction::FPToUI: + CastedTo = ConstantExpr::getUIToFP(C, SrcTy, true); + break; + case Instruction::FPToSI: + CastedTo = ConstantExpr::getSIToFP(C, SrcTy, true); + break; + case Instruction::UIToFP: + CastedTo = ConstantExpr::getFPToUI(C, SrcTy, true); + break; + case Instruction::SIToFP: + CastedTo = ConstantExpr::getFPToSI(C, SrcTy, true); + break; + default: + break; + } if (!CastedTo) return nullptr; - Constant *CastedBack = - ConstantExpr::getCast(CI->getOpcode(), CastedTo, C->getType(), true); // Make sure the cast doesn't lose any information. 
+  Constant *CastedBack =
+      ConstantExpr::getCast(*CastOp, CastedTo, C->getType(), true);
   if (CastedBack != C)
     return nullptr;

@@ -4253,11 +4292,10 @@ static bool isTruePredicate(CmpInst::Predicate Pred,
   // If X & C == 0 then (X | C) == X +_{nuw} C
   if (match(A, m_Or(m_Value(X), m_APInt(CA))) &&
       match(B, m_Or(m_Specific(X), m_APInt(CB)))) {
-    unsigned BitWidth = CA->getBitWidth();
-    APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-    computeKnownBits(X, KnownZero, KnownOne, DL, Depth + 1, AC, CxtI, DT);
+    KnownBits Known(CA->getBitWidth());
+    computeKnownBits(X, Known, DL, Depth + 1, AC, CxtI, DT);

-    if ((KnownZero & *CA) == *CA && (KnownZero & *CB) == *CB)
+    if (CA->isSubsetOf(Known.Zero) && CB->isSubsetOf(Known.Zero))
       return true;
   }

diff --git a/contrib/llvm/lib/Analysis/VectorUtils.cpp b/contrib/llvm/lib/Analysis/VectorUtils.cpp
index 7e598f435ff5..2d2249da4e13 100644
--- a/contrib/llvm/lib/Analysis/VectorUtils.cpp
+++ b/contrib/llvm/lib/Analysis/VectorUtils.cpp
@@ -23,6 +23,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"

 using namespace llvm;
 using namespace llvm::PatternMatch;
@@ -488,3 +489,88 @@ Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) {

   return Inst;
 }
+
+Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
+                                     unsigned NumVecs) {
+  SmallVector<Constant *, 16> Mask;
+  for (unsigned i = 0; i < VF; i++)
+    for (unsigned j = 0; j < NumVecs; j++)
+      Mask.push_back(Builder.getInt32(j * VF + i));
+
+  return ConstantVector::get(Mask);
+}
+
+Constant *llvm::createStrideMask(IRBuilder<> &Builder, unsigned Start,
+                                 unsigned Stride, unsigned VF) {
+  SmallVector<Constant *, 16> Mask;
+  for (unsigned i = 0; i < VF; i++)
+    Mask.push_back(Builder.getInt32(Start + i * Stride));
+
+  return ConstantVector::get(Mask);
+}
+
+Constant *llvm::createSequentialMask(IRBuilder<> &Builder, unsigned Start,
+                                     unsigned NumInts, unsigned NumUndefs) {
+  SmallVector<Constant *, 16> Mask;
+  for (unsigned i = 0; i < NumInts; i++)
+    Mask.push_back(Builder.getInt32(Start + i));
+
+  Constant *Undef = UndefValue::get(Builder.getInt32Ty());
+  for (unsigned i = 0; i < NumUndefs; i++)
+    Mask.push_back(Undef);
+
+  return ConstantVector::get(Mask);
+}
+
+/// A helper function for concatenating vectors. This function concatenates two
+/// vectors having the same element type. If the second vector has fewer
+/// elements than the first, it is padded with undefs.
+static Value *concatenateTwoVectors(IRBuilder<> &Builder, Value *V1,
+                                    Value *V2) {
+  VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType());
+  VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType());
+  assert(VecTy1 && VecTy2 &&
+         VecTy1->getScalarType() == VecTy2->getScalarType() &&
+         "Expect two vectors with the same element type");
+
+  unsigned NumElts1 = VecTy1->getNumElements();
+  unsigned NumElts2 = VecTy2->getNumElements();
+  assert(NumElts1 >= NumElts2 && "Unexpect the first vector has less elements");
+
+  if (NumElts1 > NumElts2) {
+    // Extend with UNDEFs.
+    Constant *ExtMask =
+        createSequentialMask(Builder, 0, NumElts2, NumElts1 - NumElts2);
+    V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask);
+  }
+
+  Constant *Mask = createSequentialMask(Builder, 0, NumElts1 + NumElts2, 0);
+  return Builder.CreateShuffleVector(V1, V2, Mask);
+}
+
+Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) {
+  unsigned NumVecs = Vecs.size();
+  assert(NumVecs > 1 && "Should be at least two vectors");
+
+  SmallVector<Value *, 8> ResList;
+  ResList.append(Vecs.begin(), Vecs.end());
+  do {
+    SmallVector<Value *, 8> TmpList;
+    for (unsigned i = 0; i < NumVecs - 1; i += 2) {
+      Value *V0 = ResList[i], *V1 = ResList[i + 1];
+      assert((V0->getType() == V1->getType() || i == NumVecs - 2) &&
+             "Only the last vector may have a different type");
+
+      TmpList.push_back(concatenateTwoVectors(Builder, V0, V1));
+    }
+
+    // Push the last vector if the total number of vectors is odd.
+    if (NumVecs % 2 != 0)
+      TmpList.push_back(ResList[NumVecs - 1]);
+
+    ResList = TmpList;
+    NumVecs = ResList.size();
+  } while (NumVecs > 1);
+
+  return ResList[0];
+}
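A note on the sign-bit cases in the matchMinMax hunks above: the comments rely on the fact that when the compared constant is the maximum signed value, a signed select against zero picks the unsigned maximum. A small self-contained C++ check of that identity on 8-bit values (illustrative only; it exercises the arithmetic stated in the comment, not the LLVM matcher itself):

#include <algorithm>
#include <cassert>
#include <cstdint>

// The select form from the comment: (X <s 0) ? X : MAXVAL, with MAXVAL = 127
// for 8-bit values, evaluated with ordinary signed comparison.
static uint8_t selectForm(int8_t X) {
  return static_cast<uint8_t>(X < 0 ? X : INT8_MAX);
}

int main() {
  for (int v = -128; v <= 127; ++v) {
    int8_t X = static_cast<int8_t>(v);
    uint8_t U = static_cast<uint8_t>(X);
    // When the constant is the maximum signed value, the signed select equals
    // the unsigned maximum, which is the SPF_UMAX case in the diff comment.
    assert(selectForm(X) == std::max<uint8_t>(U, INT8_MAX));
  }
  return 0;
}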
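The rewritten lookThroughCast still only folds the constant operand through the cast when the round trip is lossless: the constant is converted to the cast's source type and then cast back, and the fold is rejected if CastedBack differs from C. A minimal sketch of that check for the zext-plus-unsigned-compare case, using plain integer types rather than Constant/ConstantExpr (the helper name is made up for illustration):

#include <cassert>
#include <cstdint>

// Can a 32-bit unsigned constant C be represented in a narrower source type
// (here uint8_t)?  Truncate, extend back, and require the round trip to be
// exact -- the same idea as the "CastedBack != C" test in the patch.
static bool fitsInSourceType(uint32_t C) {
  uint8_t Truncated = static_cast<uint8_t>(C);  // trunc to the source type
  uint32_t CastedBack = Truncated;              // zext back to the compare type
  return CastedBack == C;
}

int main() {
  assert(fitsInSourceType(200));   // 200 survives the round trip
  assert(!fitsInSourceType(300));  // 300 does not fit in 8 bits, so no fold
  return 0;
}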
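In the isTruePredicate hunk, the separate KnownZero/KnownOne pair is replaced by a KnownBits object, and the mask test (KnownZero & CA) == CA becomes CA->isSubsetOf(Known.Zero). A standalone sketch of why the two formulations agree, using 64-bit masks in place of APInt (names are illustrative):

#include <cassert>
#include <cstdint>

// Stand-in for APInt::isSubsetOf: every bit set in CA must also be set in
// KnownZero, i.e. (KnownZero & CA) == CA, i.e. (CA & ~KnownZero) == 0.
static bool isSubsetOf(uint64_t CA, uint64_t KnownZero) {
  return (CA & ~KnownZero) == 0;
}

int main() {
  uint64_t KnownZero = 0x0F;             // low four bits of X are known zero
  assert(isSubsetOf(0x03, KnownZero));   // so X | 0x03 adds 0x03 without carry
  assert(!isSubsetOf(0x13, KnownZero));  // bit 4 is not known zero
  assert((KnownZero & 0x03) == 0x03);    // the older formulation agrees
  return 0;
}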
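The new VectorUtils entry points only build constant index vectors for shufflevector. A rough standalone sketch of the index sequences they produce, with std::vector<int> standing in for ConstantVector and -1 standing in for an undef lane (assumed to mirror the loops in the diff, not the LLVM API itself):

#include <vector>

// Interleave mask for NumVecs vectors of VF lanes each: lane i of every
// source vector in turn, e.g. <0, VF, 1, VF+1, ...> for two vectors.
static std::vector<int> interleaveMask(unsigned VF, unsigned NumVecs) {
  std::vector<int> Mask;
  for (unsigned i = 0; i < VF; i++)
    for (unsigned j = 0; j < NumVecs; j++)
      Mask.push_back(static_cast<int>(j * VF + i));
  return Mask;
}

// Strided mask: Start, Start+Stride, Start+2*Stride, ...
static std::vector<int> strideMask(unsigned Start, unsigned Stride, unsigned VF) {
  std::vector<int> Mask;
  for (unsigned i = 0; i < VF; i++)
    Mask.push_back(static_cast<int>(Start + i * Stride));
  return Mask;
}

// Sequential mask padded with undef lanes (-1 here, UndefValue in the IR).
static std::vector<int> sequentialMask(unsigned Start, unsigned NumInts,
                                       unsigned NumUndefs) {
  std::vector<int> Mask;
  for (unsigned i = 0; i < NumInts; i++)
    Mask.push_back(static_cast<int>(Start + i));
  for (unsigned i = 0; i < NumUndefs; i++)
    Mask.push_back(-1);
  return Mask;
}

For example, interleaveMask(4, 2) yields <0, 4, 1, 5, 2, 6, 3, 7>, the shuffle that interleaves two 4-lane vectors lane by lane, while strideMask(1, 2, 4) yields <1, 3, 5, 7>, which picks one field out of an interleaved pair.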
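concatenateVectors reduces its input list pairwise, so N vectors are combined in roughly log2(N) rounds of two-operand shuffles instead of one long chain. The same reduction shape, sketched over plain strings so it runs without LLVM (string concatenation stands in for concatenateTwoVectors; the helper name is invented for the example):

#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Pairwise reduction: each round joins adjacent elements and carries an odd
// trailing element forward, mirroring the do/while loop in the diff.
static std::string concatAll(std::vector<std::string> List) {
  assert(List.size() > 1 && "Should be at least two vectors");
  while (List.size() > 1) {
    std::vector<std::string> Tmp;
    for (size_t i = 0; i + 1 < List.size(); i += 2)
      Tmp.push_back(List[i] + List[i + 1]);
    if (List.size() % 2 != 0)          // carry the odd element forward
      Tmp.push_back(List.back());
    List = std::move(Tmp);
  }
  return List[0];
}

int main() {
  // Three inputs: round one pairs "A"+"B" and carries "C"; round two joins them.
  assert(concatAll({"A", "B", "C"}) == "ABC");
  return 0;
}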