author | Dimitry Andric <dim@FreeBSD.org> | 2016-07-23 20:41:05 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2016-07-23 20:41:05 +0000 |
commit | 01095a5d43bbfde13731688ddcf6048ebb8b7721 | |
tree | 4def12e759965de927d963ac65840d663ef9d1ea /lib/Transforms | |
parent | f0f4822ed4b66e3579e92a89f368f8fb860e218e | |
Vendor import of llvm release_39 branch r276489
Tag: vendor/llvm/llvm-release_39-r276489
Notes:
svn path=/vendor/llvm/dist/; revision=303231
svn path=/vendor/llvm/llvm-release_39-r276489/; revision=303232; tag=vendor/llvm/llvm-release_39-r276489
Diffstat (limited to 'lib/Transforms')
178 files changed, 31236 insertions, 20371 deletions
diff --git a/lib/Transforms/Hello/CMakeLists.txt b/lib/Transforms/Hello/CMakeLists.txt index e0b81907c7fb..4a55dd9c04b8 100644 --- a/lib/Transforms/Hello/CMakeLists.txt +++ b/lib/Transforms/Hello/CMakeLists.txt @@ -15,4 +15,6 @@ add_llvm_loadable_module( LLVMHello DEPENDS intrinsics_gen + PLUGIN_TOOL + opt ) diff --git a/lib/Transforms/Hello/Makefile b/lib/Transforms/Hello/Makefile deleted file mode 100644 index f1e31489d3c9..000000000000 --- a/lib/Transforms/Hello/Makefile +++ /dev/null @@ -1,24 +0,0 @@ -##===- lib/Transforms/Hello/Makefile -----------------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. -LIBRARYNAME = LLVMHello -LOADABLE_MODULE = 1 -USEDLIBS = - -# If we don't need RTTI or EH, there's no reason to export anything -# from the hello plugin. -ifneq ($(REQUIRES_RTTI), 1) -ifneq ($(REQUIRES_EH), 1) -EXPORTED_SYMBOL_FILE = $(PROJ_SRC_DIR)/Hello.exports -endif -endif - -include $(LEVEL)/Makefile.common - diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp index 0e05129b5261..0716a3a9cbe9 100644 --- a/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -38,6 +38,7 @@ #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" @@ -68,6 +69,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + getAAResultsAnalysisUsage(AU); CallGraphSCCPass::getAnalysisUsage(AU); } @@ -78,19 +80,8 @@ namespace { initializeArgPromotionPass(*PassRegistry::getPassRegistry()); } - /// A vector used to hold the indices of a single GEP instruction - typedef std::vector<uint64_t> IndicesVector; - private: - bool isDenselyPacked(Type *type, const DataLayout &DL); - bool canPaddingBeAccessed(Argument *Arg); - CallGraphNode *PromoteArguments(CallGraphNode *CGN); - bool isSafeToPromoteArgument(Argument *Arg, bool isByVal, - AAResults &AAR) const; - CallGraphNode *DoPromotion(Function *F, - SmallPtrSetImpl<Argument*> &ArgsToPromote, - SmallPtrSetImpl<Argument*> &ByValArgsToTransform); - + using llvm::Pass::doInitialization; bool doInitialization(CallGraph &CG) override; /// The maximum number of elements to expand, or 0 for unlimited. 
@@ -98,6 +89,21 @@ namespace { }; } +/// A vector used to hold the indices of a single GEP instruction +typedef std::vector<uint64_t> IndicesVector; + +static CallGraphNode * +PromoteArguments(CallGraphNode *CGN, CallGraph &CG, + function_ref<AAResults &(Function &F)> AARGetter, + unsigned MaxElements); +static bool isDenselyPacked(Type *type, const DataLayout &DL); +static bool canPaddingBeAccessed(Argument *Arg); +static bool isSafeToPromoteArgument(Argument *Arg, bool isByVal, AAResults &AAR, + unsigned MaxElements); +static CallGraphNode * +DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, + SmallPtrSetImpl<Argument *> &ByValArgsToTransform, CallGraph &CG); + char ArgPromotion::ID = 0; INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", "Promote 'by reference' arguments to scalars", false, false) @@ -111,16 +117,19 @@ Pass *llvm::createArgumentPromotionPass(unsigned maxElements) { return new ArgPromotion(maxElements); } -bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { +static bool runImpl(CallGraphSCC &SCC, CallGraph &CG, + function_ref<AAResults &(Function &F)> AARGetter, + unsigned MaxElements) { bool Changed = false, LocalChange; do { // Iterate until we stop promoting from this SCC. LocalChange = false; // Attempt to promote arguments from all functions in this SCC. - for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) { - if (CallGraphNode *CGN = PromoteArguments(*I)) { + for (CallGraphNode *OldNode : SCC) { + if (CallGraphNode *NewNode = + PromoteArguments(OldNode, CG, AARGetter, MaxElements)) { LocalChange = true; - SCC.ReplaceNode(*I, CGN); + SCC.ReplaceNode(OldNode, NewNode); } } Changed |= LocalChange; // Remember that we changed something. @@ -129,8 +138,30 @@ bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { return Changed; } +bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { + if (skipSCC(SCC)) + return false; + + // Get the callgraph information that we need to update to reflect our + // changes. + CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + + // We compute dedicated AA results for each function in the SCC as needed. We + // use a lambda referencing external objects so that they live long enough to + // be queried, but we re-use them each time. + Optional<BasicAAResult> BAR; + Optional<AAResults> AAR; + auto AARGetter = [&](Function &F) -> AAResults & { + BAR.emplace(createLegacyPMBasicAAResult(*this, F)); + AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); + return *AAR; + }; + + return runImpl(SCC, CG, AARGetter, maxElements); +} + /// \brief Checks if a type could have padding bytes. -bool ArgPromotion::isDenselyPacked(Type *type, const DataLayout &DL) { +static bool isDenselyPacked(Type *type, const DataLayout &DL) { // There is no size information, so be conservative. if (!type->isSized()) @@ -166,7 +197,7 @@ bool ArgPromotion::isDenselyPacked(Type *type, const DataLayout &DL) { } /// \brief Checks if the padding bytes of an argument could be accessed. -bool ArgPromotion::canPaddingBeAccessed(Argument *arg) { +static bool canPaddingBeAccessed(Argument *arg) { assert(arg->hasByValAttr()); @@ -207,7 +238,10 @@ bool ArgPromotion::canPaddingBeAccessed(Argument *arg) { /// example, all callers are direct). If safe to promote some arguments, it /// calls the DoPromotion method. 
/// -CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) { +static CallGraphNode * +PromoteArguments(CallGraphNode *CGN, CallGraph &CG, + function_ref<AAResults &(Function &F)> AARGetter, + unsigned MaxElements) { Function *F = CGN->getFunction(); // Make sure that it is local to this module. @@ -242,20 +276,13 @@ CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) { const DataLayout &DL = F->getParent()->getDataLayout(); - // We need to manually construct BasicAA directly in order to disable its use - // of other function analyses. - BasicAAResult BAR(createLegacyPMBasicAAResult(*this, *F)); - - // Construct our own AA results for this function. We do this manually to - // work around the limitations of the legacy pass manager. - AAResults AAR(createLegacyPMAAResults(*this, *F, BAR)); + AAResults &AAR = AARGetter(*F); // Check to see which arguments are promotable. If an argument is promotable, // add it to ArgsToPromote. SmallPtrSet<Argument*, 8> ArgsToPromote; SmallPtrSet<Argument*, 8> ByValArgsToTransform; - for (unsigned i = 0, e = PointerArgs.size(); i != e; ++i) { - Argument *PtrArg = PointerArgs[i]; + for (Argument *PtrArg : PointerArgs) { Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType(); // Replace sret attribute with noalias. This reduces register pressure by @@ -285,10 +312,10 @@ CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) { (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg)); if (isSafeToPromote) { if (StructType *STy = dyn_cast<StructType>(AgTy)) { - if (maxElements > 0 && STy->getNumElements() > maxElements) { + if (MaxElements > 0 && STy->getNumElements() > MaxElements) { DEBUG(dbgs() << "argpromotion disable promoting argument '" << PtrArg->getName() << "' because it would require adding more" - << " than " << maxElements << " arguments to the function.\n"); + << " than " << MaxElements << " arguments to the function.\n"); continue; } @@ -302,7 +329,7 @@ CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) { } // Safe to transform, don't even bother trying to "promote" it. - // Passing the elements as a scalar will allow scalarrepl to hack on + // Passing the elements as a scalar will allow sroa to hack on // the new alloca we introduce. if (AllSimple) { ByValArgsToTransform.insert(PtrArg); @@ -328,7 +355,8 @@ CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) { } // Otherwise, see if we can promote the pointer to its value. - if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR)) + if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR, + MaxElements)) ArgsToPromote.insert(PtrArg); } @@ -336,7 +364,7 @@ CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) { if (ArgsToPromote.empty() && ByValArgsToTransform.empty()) return nullptr; - return DoPromotion(F, ArgsToPromote, ByValArgsToTransform); + return DoPromotion(F, ArgsToPromote, ByValArgsToTransform, CG); } /// AllCallersPassInValidPointerForArgument - Return true if we can prove that @@ -364,8 +392,7 @@ static bool AllCallersPassInValidPointerForArgument(Argument *Arg) { /// elements in Prefix is the same as the corresponding elements in Longer. /// /// This means it also returns true when Prefix and Longer are equal! 
-static bool IsPrefix(const ArgPromotion::IndicesVector &Prefix, - const ArgPromotion::IndicesVector &Longer) { +static bool IsPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) { if (Prefix.size() > Longer.size()) return false; return std::equal(Prefix.begin(), Prefix.end(), Longer.begin()); @@ -373,9 +400,9 @@ static bool IsPrefix(const ArgPromotion::IndicesVector &Prefix, /// Checks if Indices, or a prefix of Indices, is in Set. -static bool PrefixIn(const ArgPromotion::IndicesVector &Indices, - std::set<ArgPromotion::IndicesVector> &Set) { - std::set<ArgPromotion::IndicesVector>::iterator Low; +static bool PrefixIn(const IndicesVector &Indices, + std::set<IndicesVector> &Set) { + std::set<IndicesVector>::iterator Low; Low = Set.upper_bound(Indices); if (Low != Set.begin()) Low--; @@ -392,9 +419,9 @@ static bool PrefixIn(const ArgPromotion::IndicesVector &Indices, /// is already a prefix of Indices in Safe, Indices are implicitely marked safe /// already. Furthermore, any indices that Indices is itself a prefix of, are /// removed from Safe (since they are implicitely safe because of Indices now). -static void MarkIndicesSafe(const ArgPromotion::IndicesVector &ToMark, - std::set<ArgPromotion::IndicesVector> &Safe) { - std::set<ArgPromotion::IndicesVector>::iterator Low; +static void MarkIndicesSafe(const IndicesVector &ToMark, + std::set<IndicesVector> &Safe) { + std::set<IndicesVector>::iterator Low; Low = Safe.upper_bound(ToMark); // Guard against the case where Safe is empty if (Low != Safe.begin()) @@ -415,9 +442,9 @@ static void MarkIndicesSafe(const ArgPromotion::IndicesVector &ToMark, Low = Safe.insert(Low, ToMark); ++Low; // If there we're a prefix of longer index list(s), remove those - std::set<ArgPromotion::IndicesVector>::iterator End = Safe.end(); + std::set<IndicesVector>::iterator End = Safe.end(); while (Low != End && IsPrefix(ToMark, *Low)) { - std::set<ArgPromotion::IndicesVector>::iterator Remove = Low; + std::set<IndicesVector>::iterator Remove = Low; ++Low; Safe.erase(Remove); } @@ -428,9 +455,8 @@ static void MarkIndicesSafe(const ArgPromotion::IndicesVector &ToMark, /// This method limits promotion of aggregates to only promote up to three /// elements of the aggregate in order to avoid exploding the number of /// arguments passed in. -bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, - bool isByValOrInAlloca, - AAResults &AAR) const { +static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, + AAResults &AAR, unsigned MaxElements) { typedef std::set<IndicesVector> GEPIndicesSet; // Quick exit for unused arguments @@ -518,7 +544,8 @@ bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, // TODO: This runs the above loop over and over again for dead GEPs // Couldn't we just do increment the UI iterator earlier and erase the // use? - return isSafeToPromoteArgument(Arg, isByValOrInAlloca, AAR); + return isSafeToPromoteArgument(Arg, isByValOrInAlloca, AAR, + MaxElements); } // Ensure that all of the indices are constants. @@ -552,10 +579,10 @@ bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, // to make sure that we aren't promoting too many elements. If so, nothing // to do. 
if (ToPromote.find(Operands) == ToPromote.end()) { - if (maxElements > 0 && ToPromote.size() == maxElements) { + if (MaxElements > 0 && ToPromote.size() == MaxElements) { DEBUG(dbgs() << "argpromotion not promoting argument '" << Arg->getName() << "' because it would require adding more " - << "than " << maxElements << " arguments to the function.\n"); + << "than " << MaxElements << " arguments to the function.\n"); // We limit aggregate promotion to only promoting up to a fixed number // of elements of the aggregate. return false; @@ -575,10 +602,9 @@ bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, // blocks we know to be transparent to the load. SmallPtrSet<BasicBlock*, 16> TranspBlocks; - for (unsigned i = 0, e = Loads.size(); i != e; ++i) { + for (LoadInst *Load : Loads) { // Check to see if the load is invalidated from the start of the block to // the load itself. - LoadInst *Load = Loads[i]; BasicBlock *BB = Load->getParent(); MemoryLocation Loc = MemoryLocation::get(Load); @@ -604,9 +630,9 @@ bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, /// DoPromotion - This method actually performs the promotion of the specified /// arguments, and returns the new function. At this point, we know that it's /// safe to do so. -CallGraphNode *ArgPromotion::DoPromotion(Function *F, - SmallPtrSetImpl<Argument*> &ArgsToPromote, - SmallPtrSetImpl<Argument*> &ByValArgsToTransform) { +static CallGraphNode * +DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, + SmallPtrSetImpl<Argument *> &ByValArgsToTransform, CallGraph &CG) { // Start by computing a new prototype for the function, which is the same as // the old function, but has modified arguments. @@ -700,12 +726,11 @@ CallGraphNode *ArgPromotion::DoPromotion(Function *F, } // Add a parameter to the function for each element passed in. - for (ScalarizeTable::iterator SI = ArgIndices.begin(), - E = ArgIndices.end(); SI != E; ++SI) { + for (const auto &ArgIndex : ArgIndices) { // not allowed to dereference ->begin() if size() is 0 Params.push_back(GetElementPtrInst::getIndexedType( cast<PointerType>(I->getType()->getScalarType())->getElementType(), - SI->second)); + ArgIndex.second)); assert(Params.back()); } @@ -745,10 +770,6 @@ CallGraphNode *ArgPromotion::DoPromotion(Function *F, F->getParent()->getFunctionList().insert(F->getIterator(), NF); NF->takeName(F); - // Get the callgraph information that we need to update to reflect our - // changes. - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - // Get a new callgraph node for NF. CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF); @@ -800,27 +821,25 @@ CallGraphNode *ArgPromotion::DoPromotion(Function *F, // Store the Value* version of the indices in here, but declare it now // for reuse. std::vector<Value*> Ops; - for (ScalarizeTable::iterator SI = ArgIndices.begin(), - E = ArgIndices.end(); SI != E; ++SI) { + for (const auto &ArgIndex : ArgIndices) { Value *V = *AI; - LoadInst *OrigLoad = OriginalLoads[std::make_pair(&*I, SI->second)]; - if (!SI->second.empty()) { - Ops.reserve(SI->second.size()); + LoadInst *OrigLoad = + OriginalLoads[std::make_pair(&*I, ArgIndex.second)]; + if (!ArgIndex.second.empty()) { + Ops.reserve(ArgIndex.second.size()); Type *ElTy = V->getType(); - for (IndicesVector::const_iterator II = SI->second.begin(), - IE = SI->second.end(); - II != IE; ++II) { + for (unsigned long II : ArgIndex.second) { // Use i32 to index structs, and i64 for others (pointers/arrays). // This satisfies GEP constraints. 
Type *IdxTy = (ElTy->isStructTy() ? Type::getInt32Ty(F->getContext()) : Type::getInt64Ty(F->getContext())); - Ops.push_back(ConstantInt::get(IdxTy, *II)); + Ops.push_back(ConstantInt::get(IdxTy, II)); // Keep track of the type we're currently indexing. - ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(*II); + ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II); } // And create a GEP to extract those indices. - V = GetElementPtrInst::Create(SI->first, V, Ops, + V = GetElementPtrInst::Create(ArgIndex.first, V, Ops, V->getName() + ".idx", Call); Ops.clear(); } @@ -852,15 +871,18 @@ CallGraphNode *ArgPromotion::DoPromotion(Function *F, AttributesVec.push_back(AttributeSet::get(Call->getContext(), CallPAL.getFnAttributes())); + SmallVector<OperandBundleDef, 1> OpBundles; + CS.getOperandBundlesAsDefs(OpBundles); + Instruction *New; if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), - Args, "", Call); + Args, OpBundles, "", Call); cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); cast<InvokeInst>(New)->setAttributes(AttributeSet::get(II->getContext(), AttributesVec)); } else { - New = CallInst::Create(NF, Args, "", Call); + New = CallInst::Create(NF, Args, OpBundles, "", Call); cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); cast<CallInst>(New)->setAttributes(AttributeSet::get(New->getContext(), AttributesVec)); diff --git a/lib/Transforms/IPO/CMakeLists.txt b/lib/Transforms/IPO/CMakeLists.txt index 351b88fe2aa0..d6782c738cbe 100644 --- a/lib/Transforms/IPO/CMakeLists.txt +++ b/lib/Transforms/IPO/CMakeLists.txt @@ -19,7 +19,7 @@ add_llvm_library(LLVMipo Inliner.cpp Internalize.cpp LoopExtractor.cpp - LowerBitSets.cpp + LowerTypeTests.cpp MergeFunctions.cpp PartialInlining.cpp PassManagerBuilder.cpp @@ -27,6 +27,7 @@ add_llvm_library(LLVMipo SampleProfile.cpp StripDeadPrototypes.cpp StripSymbols.cpp + WholeProgramDevirt.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp index 0aa49d6fde01..d75ed206ad23 100644 --- a/lib/Transforms/IPO/ConstantMerge.cpp +++ b/lib/Transforms/IPO/ConstantMerge.cpp @@ -17,7 +17,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/ConstantMerge.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/SmallPtrSet.h" @@ -28,41 +28,13 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" using namespace llvm; #define DEBUG_TYPE "constmerge" STATISTIC(NumMerged, "Number of global constants merged"); -namespace { - struct ConstantMerge : public ModulePass { - static char ID; // Pass identification, replacement for typeid - ConstantMerge() : ModulePass(ID) { - initializeConstantMergePass(*PassRegistry::getPassRegistry()); - } - - // For this pass, process all of the globals in the module, eliminating - // duplicate constants. - bool runOnModule(Module &M) override; - - // Return true iff we can determine the alignment of this global variable. - bool hasKnownAlignment(GlobalVariable *GV) const; - - // Return the alignment of the global, including converting the default - // alignment to a concrete value. 
- unsigned getAlignment(GlobalVariable *GV) const; - - }; -} - -char ConstantMerge::ID = 0; -INITIALIZE_PASS(ConstantMerge, "constmerge", - "Merge Duplicate Global Constants", false, false) - -ModulePass *llvm::createConstantMergePass() { return new ConstantMerge(); } - - - /// Find values that are marked as llvm.used. static void FindUsedValues(GlobalVariable *LLVMUsed, SmallPtrSetImpl<const GlobalValue*> &UsedValues) { @@ -85,18 +57,17 @@ static bool IsBetterCanonical(const GlobalVariable &A, if (A.hasLocalLinkage() && !B.hasLocalLinkage()) return false; - return A.hasUnnamedAddr(); + return A.hasGlobalUnnamedAddr(); } -unsigned ConstantMerge::getAlignment(GlobalVariable *GV) const { +static unsigned getAlignment(GlobalVariable *GV) { unsigned Align = GV->getAlignment(); if (Align) return Align; return GV->getParent()->getDataLayout().getPreferredAlignment(GV); } -bool ConstantMerge::runOnModule(Module &M) { - +static bool mergeConstants(Module &M) { // Find all the globals that are marked "used". These cannot be merged. SmallPtrSet<const GlobalValue*, 8> UsedGlobals; FindUsedValues(M.getGlobalVariable("llvm.used"), UsedGlobals); @@ -181,11 +152,11 @@ bool ConstantMerge::runOnModule(Module &M) { if (!Slot || Slot == GV) continue; - if (!Slot->hasUnnamedAddr() && !GV->hasUnnamedAddr()) + if (!Slot->hasGlobalUnnamedAddr() && !GV->hasGlobalUnnamedAddr()) continue; - if (!GV->hasUnnamedAddr()) - Slot->setUnnamedAddr(false); + if (!GV->hasGlobalUnnamedAddr()) + Slot->setUnnamedAddr(GlobalValue::UnnamedAddr::None); // Make all uses of the duplicate constant use the canonical version. Replacements.push_back(std::make_pair(GV, Slot)); @@ -220,3 +191,34 @@ bool ConstantMerge::runOnModule(Module &M) { Replacements.clear(); } } + +PreservedAnalyses ConstantMergePass::run(Module &M, ModuleAnalysisManager &) { + if (!mergeConstants(M)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +namespace { +struct ConstantMergeLegacyPass : public ModulePass { + static char ID; // Pass identification, replacement for typeid + ConstantMergeLegacyPass() : ModulePass(ID) { + initializeConstantMergeLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + // For this pass, process all of the globals in the module, eliminating + // duplicate constants. 
+ bool runOnModule(Module &M) { + if (skipModule(M)) + return false; + return mergeConstants(M); + } +}; +} + +char ConstantMergeLegacyPass::ID = 0; +INITIALIZE_PASS(ConstantMergeLegacyPass, "constmerge", + "Merge Duplicate Global Constants", false, false) + +ModulePass *llvm::createConstantMergePass() { + return new ConstantMergeLegacyPass(); +} diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp index 5bbb7513005c..58731eaf6e30 100644 --- a/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/CrossDSOCFI.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Statistic.h" @@ -30,13 +30,14 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; #define DEBUG_TYPE "cross-dso-cfi" -STATISTIC(TypeIds, "Number of unique type identifiers"); +STATISTIC(NumTypeIds, "Number of unique type identifiers"); namespace { @@ -46,13 +47,10 @@ struct CrossDSOCFI : public ModulePass { initializeCrossDSOCFIPass(*PassRegistry::getPassRegistry()); } - Module *M; MDNode *VeryLikelyWeights; - ConstantInt *extractBitSetTypeId(MDNode *MD); - void buildCFICheck(); - - bool doInitialization(Module &M) override; + ConstantInt *extractNumericTypeId(MDNode *MD); + void buildCFICheck(Module &M); bool runOnModule(Module &M) override; }; @@ -65,18 +63,10 @@ char CrossDSOCFI::ID = 0; ModulePass *llvm::createCrossDSOCFIPass() { return new CrossDSOCFI; } -bool CrossDSOCFI::doInitialization(Module &Mod) { - M = &Mod; - VeryLikelyWeights = - MDBuilder(M->getContext()).createBranchWeights((1U << 20) - 1, 1); - - return false; -} - -/// extractBitSetTypeId - Extracts TypeId from a hash-based bitset MDNode. -ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) { +/// Extracts a numeric type identifier from an MDNode containing type metadata. +ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) { // This check excludes vtables for classes inside anonymous namespaces. - auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(0)); + auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(1)); if (!TM) return nullptr; auto C = dyn_cast_or_null<ConstantInt>(TM->getValue()); @@ -84,68 +74,63 @@ ConstantInt *CrossDSOCFI::extractBitSetTypeId(MDNode *MD) { // We are looking for i64 constants. if (C->getBitWidth() != 64) return nullptr; - // Sanity check. - auto FM = dyn_cast_or_null<ValueAsMetadata>(MD->getOperand(1)); - // Can be null if a function was removed by an optimization. - if (FM) { - auto F = dyn_cast<Function>(FM->getValue()); - // But can never be a function declaration. - assert(!F || !F->isDeclaration()); - (void)F; // Suppress unused variable warning in the no-asserts build. - } return C; } /// buildCFICheck - emits __cfi_check for the current module. -void CrossDSOCFI::buildCFICheck() { +void CrossDSOCFI::buildCFICheck(Module &M) { // FIXME: verify that __cfi_check ends up near the end of the code section, - // but before the jump slots created in LowerBitSets. 
- llvm::DenseSet<uint64_t> BitSetIds; - NamedMDNode *BitSetNM = M->getNamedMetadata("llvm.bitsets"); - - if (BitSetNM) - for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I) - if (ConstantInt *TypeId = extractBitSetTypeId(BitSetNM->getOperand(I))) - BitSetIds.insert(TypeId->getZExtValue()); - - LLVMContext &Ctx = M->getContext(); - Constant *C = M->getOrInsertFunction( - "__cfi_check", - FunctionType::get( - Type::getVoidTy(Ctx), - {Type::getInt64Ty(Ctx), PointerType::getUnqual(Type::getInt8Ty(Ctx))}, - false)); + // but before the jump slots created in LowerTypeTests. + llvm::DenseSet<uint64_t> TypeIds; + SmallVector<MDNode *, 2> Types; + for (GlobalObject &GO : M.global_objects()) { + Types.clear(); + GO.getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + // Sanity check. GO must not be a function declaration. + assert(!isa<Function>(&GO) || !cast<Function>(&GO)->isDeclaration()); + + if (ConstantInt *TypeId = extractNumericTypeId(Type)) + TypeIds.insert(TypeId->getZExtValue()); + } + } + + LLVMContext &Ctx = M.getContext(); + Constant *C = M.getOrInsertFunction( + "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx), + Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx), nullptr); Function *F = dyn_cast<Function>(C); F->setAlignment(4096); auto args = F->arg_begin(); - Argument &CallSiteTypeId = *(args++); + Value &CallSiteTypeId = *(args++); CallSiteTypeId.setName("CallSiteTypeId"); - Argument &Addr = *(args++); + Value &Addr = *(args++); Addr.setName("Addr"); + Value &CFICheckFailData = *(args++); + CFICheckFailData.setName("CFICheckFailData"); assert(args == F->arg_end()); BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F); + BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F); - BasicBlock *TrapBB = BasicBlock::Create(Ctx, "trap", F); - IRBuilder<> IRBTrap(TrapBB); - Function *TrapFn = Intrinsic::getDeclaration(M, Intrinsic::trap); - llvm::CallInst *TrapCall = IRBTrap.CreateCall(TrapFn); - TrapCall->setDoesNotReturn(); - TrapCall->setDoesNotThrow(); - IRBTrap.CreateUnreachable(); + BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F); + IRBuilder<> IRBFail(TrapBB); + Constant *CFICheckFailFn = M.getOrInsertFunction( + "__cfi_check_fail", Type::getVoidTy(Ctx), Type::getInt8PtrTy(Ctx), + Type::getInt8PtrTy(Ctx), nullptr); + IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr}); + IRBFail.CreateBr(ExitBB); - BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F); IRBuilder<> IRBExit(ExitBB); IRBExit.CreateRetVoid(); IRBuilder<> IRB(BB); - SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, BitSetIds.size()); - for (uint64_t TypeId : BitSetIds) { + SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, TypeIds.size()); + for (uint64_t TypeId : TypeIds) { ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId); BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F); IRBuilder<> IRBTest(TestBB); - Function *BitsetTestFn = - Intrinsic::getDeclaration(M, Intrinsic::bitset_test); + Function *BitsetTestFn = Intrinsic::getDeclaration(&M, Intrinsic::type_test); Value *Test = IRBTest.CreateCall( BitsetTestFn, {&Addr, MetadataAsValue::get( @@ -154,13 +139,26 @@ void CrossDSOCFI::buildCFICheck() { BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights); SI->addCase(CaseTypeId, TestBB); - ++TypeIds; + ++NumTypeIds; } } bool CrossDSOCFI::runOnModule(Module &M) { + if (skipModule(M)) + return false; + + VeryLikelyWeights = + MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1); if 
(M.getModuleFlag("Cross-DSO CFI") == nullptr) return false; - buildCFICheck(); + buildCFICheck(M); return true; } + +PreservedAnalyses CrossDSOCFIPass::run(Module &M, AnalysisManager<Module> &AM) { + CrossDSOCFI Impl; + bool Changed = Impl.runOnModule(M); + if (!Changed) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp index 4de3d95ab11d..c8c895b18796 100644 --- a/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -17,8 +17,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" -#include "llvm/ADT/DenseMap.h" +#include "llvm/Transforms/IPO/DeadArgumentElimination.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" @@ -35,8 +34,8 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include <map> #include <set> #include <tuple> using namespace llvm; @@ -51,77 +50,6 @@ namespace { /// DAE - The dead argument elimination pass. /// class DAE : public ModulePass { - public: - - /// Struct that represents (part of) either a return value or a function - /// argument. Used so that arguments and return values can be used - /// interchangeably. - struct RetOrArg { - RetOrArg(const Function *F, unsigned Idx, bool IsArg) : F(F), Idx(Idx), - IsArg(IsArg) {} - const Function *F; - unsigned Idx; - bool IsArg; - - /// Make RetOrArg comparable, so we can put it into a map. - bool operator<(const RetOrArg &O) const { - return std::tie(F, Idx, IsArg) < std::tie(O.F, O.Idx, O.IsArg); - } - - /// Make RetOrArg comparable, so we can easily iterate the multimap. - bool operator==(const RetOrArg &O) const { - return F == O.F && Idx == O.Idx && IsArg == O.IsArg; - } - - std::string getDescription() const { - return (Twine(IsArg ? "Argument #" : "Return value #") + utostr(Idx) + - " of function " + F->getName()).str(); - } - }; - - /// Liveness enum - During our initial pass over the program, we determine - /// that things are either alive or maybe alive. We don't mark anything - /// explicitly dead (even if we know they are), since anything not alive - /// with no registered uses (in Uses) will never be marked alive and will - /// thus become dead in the end. - enum Liveness { Live, MaybeLive }; - - /// Convenience wrapper - RetOrArg CreateRet(const Function *F, unsigned Idx) { - return RetOrArg(F, Idx, false); - } - /// Convenience wrapper - RetOrArg CreateArg(const Function *F, unsigned Idx) { - return RetOrArg(F, Idx, true); - } - - typedef std::multimap<RetOrArg, RetOrArg> UseMap; - /// This maps a return value or argument to any MaybeLive return values or - /// arguments it uses. This allows the MaybeLive values to be marked live - /// when any of its users is marked live. - /// For example (indices are left out for clarity): - /// - Uses[ret F] = ret G - /// This means that F calls G, and F returns the value returned by G. - /// - Uses[arg F] = ret G - /// This means that some function calls G and passes its result as an - /// argument to F. - /// - Uses[ret F] = arg F - /// This means that F returns one of its own arguments. - /// - Uses[arg F] = arg G - /// This means that G calls F and passes one of its own (G's) arguments - /// directly to F. 
- UseMap Uses; - - typedef std::set<RetOrArg> LiveSet; - typedef std::set<const Function*> LiveFuncSet; - - /// This set contains all values that have been determined to be live. - LiveSet LiveValues; - /// This set contains all values that are cannot be changed in any way. - LiveFuncSet LiveFunctions; - - typedef SmallVector<RetOrArg, 5> UseVector; - protected: // DAH uses this to specify a different ID. explicit DAE(char &ID) : ModulePass(ID) {} @@ -132,25 +60,16 @@ namespace { initializeDAEPass(*PassRegistry::getPassRegistry()); } - bool runOnModule(Module &M) override; + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + DeadArgumentEliminationPass DAEP(ShouldHackArguments()); + ModuleAnalysisManager DummyMAM; + PreservedAnalyses PA = DAEP.run(M, DummyMAM); + return !PA.areAllPreserved(); + } virtual bool ShouldHackArguments() const { return false; } - - private: - Liveness MarkIfNotLive(RetOrArg Use, UseVector &MaybeLiveUses); - Liveness SurveyUse(const Use *U, UseVector &MaybeLiveUses, - unsigned RetValNum = -1U); - Liveness SurveyUses(const Value *V, UseVector &MaybeLiveUses); - - void SurveyFunction(const Function &F); - void MarkValue(const RetOrArg &RA, Liveness L, - const UseVector &MaybeLiveUses); - void MarkLive(const RetOrArg &RA); - void MarkLive(const Function &F); - void PropagateLiveness(const RetOrArg &RA); - bool RemoveDeadStuffFromFunction(Function *F); - bool DeleteDeadVarargs(Function &Fn); - bool RemoveDeadArgumentsFromCallers(Function &Fn); }; } @@ -183,7 +102,7 @@ ModulePass *llvm::createDeadArgHackingPass() { return new DAH(); } /// DeleteDeadVarargs - If this is an function that takes a ... list, and if /// llvm.vastart is never called, the varargs list is dead for the function. -bool DAE::DeleteDeadVarargs(Function &Fn) { +bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { assert(Fn.getFunctionType()->isVarArg() && "Function isn't varargs!"); if (Fn.isDeclaration() || !Fn.hasLocalLinkage()) return false; @@ -200,9 +119,9 @@ bool DAE::DeleteDeadVarargs(Function &Fn) { // Okay, we know we can transform this function if safe. Scan its body // looking for calls marked musttail or calls to llvm.vastart. - for (Function::iterator BB = Fn.begin(), E = Fn.end(); BB != E; ++BB) { - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { - CallInst *CI = dyn_cast<CallInst>(I); + for (BasicBlock &BB : Fn) { + for (Instruction &I : BB) { + CallInst *CI = dyn_cast<CallInst>(&I); if (!CI) continue; if (CI->isMustTailCall()) @@ -229,6 +148,7 @@ bool DAE::DeleteDeadVarargs(Function &Fn) { // Create the new function body and insert it into the module... 
Function *NF = Function::Create(NFTy, Fn.getLinkage()); NF->copyAttributesFrom(&Fn); + NF->setComdat(Fn.getComdat()); Fn.getParent()->getFunctionList().insert(Fn.getIterator(), NF); NF->takeName(&Fn); @@ -257,14 +177,17 @@ bool DAE::DeleteDeadVarargs(Function &Fn) { PAL = AttributeSet::get(Fn.getContext(), AttributesVec); } + SmallVector<OperandBundleDef, 1> OpBundles; + CS.getOperandBundlesAsDefs(OpBundles); + Instruction *New; if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), - Args, "", Call); + Args, OpBundles, "", Call); cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); cast<InvokeInst>(New)->setAttributes(PAL); } else { - New = CallInst::Create(NF, Args, "", Call); + New = CallInst::Create(NF, Args, OpBundles, "", Call); cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); cast<CallInst>(New)->setAttributes(PAL); if (cast<CallInst>(Call)->isTailCall()) @@ -316,8 +239,7 @@ bool DAE::DeleteDeadVarargs(Function &Fn) { /// RemoveDeadArgumentsFromCallers - Checks if the given function has any /// arguments that are unused, and changes the caller parameters to be undefined /// instead. -bool DAE::RemoveDeadArgumentsFromCallers(Function &Fn) -{ +bool DeadArgumentEliminationPass::RemoveDeadArgumentsFromCallers(Function &Fn) { // We cannot change the arguments if this TU does not define the function or // if the linker may choose a function body from another TU, even if the // nominal linkage indicates that other copies of the function have the same @@ -329,7 +251,7 @@ bool DAE::RemoveDeadArgumentsFromCallers(Function &Fn) // %v = load i32 %p // ret void // } - if (!Fn.isStrongDefinitionForLinker()) + if (!Fn.hasExactDefinition()) return false; // Functions with local linkage should already have been handled, except the @@ -409,7 +331,9 @@ static Type *getRetComponentType(const Function *F, unsigned Idx) { /// MarkIfNotLive - This checks Use for liveness in LiveValues. If Use is not /// live, it adds Use to the MaybeLiveUses argument. Returns the determined /// liveness of Use. -DAE::Liveness DAE::MarkIfNotLive(RetOrArg Use, UseVector &MaybeLiveUses) { +DeadArgumentEliminationPass::Liveness +DeadArgumentEliminationPass::MarkIfNotLive(RetOrArg Use, + UseVector &MaybeLiveUses) { // We're live if our use or its Function is already marked as live. if (LiveFunctions.count(Use.F) || LiveValues.count(Use)) return Live; @@ -428,8 +352,9 @@ DAE::Liveness DAE::MarkIfNotLive(RetOrArg Use, UseVector &MaybeLiveUses) { /// RetValNum is the return value number to use when this use is used in a /// return instruction. This is used in the recursion, you should always leave /// it at 0. -DAE::Liveness DAE::SurveyUse(const Use *U, - UseVector &MaybeLiveUses, unsigned RetValNum) { +DeadArgumentEliminationPass::Liveness +DeadArgumentEliminationPass::SurveyUse(const Use *U, UseVector &MaybeLiveUses, + unsigned RetValNum) { const User *V = U->getUser(); if (const ReturnInst *RI = dyn_cast<ReturnInst>(V)) { // The value is returned from a function. It's only live when the @@ -442,13 +367,14 @@ DAE::Liveness DAE::SurveyUse(const Use *U, // We might be live, depending on the liveness of Use. return MarkIfNotLive(Use, MaybeLiveUses); } else { - DAE::Liveness Result = MaybeLive; + DeadArgumentEliminationPass::Liveness Result = MaybeLive; for (unsigned i = 0; i < NumRetVals(F); ++i) { RetOrArg Use = CreateRet(F, i); // We might be live, depending on the liveness of Use. 
If any // sub-value is live, then the entire value is considered live. This // is a conservative choice, and better tracking is possible. - DAE::Liveness SubResult = MarkIfNotLive(Use, MaybeLiveUses); + DeadArgumentEliminationPass::Liveness SubResult = + MarkIfNotLive(Use, MaybeLiveUses); if (Result != Live) Result = SubResult; } @@ -514,7 +440,9 @@ DAE::Liveness DAE::SurveyUse(const Use *U, /// Adds all uses that cause the result to be MaybeLive to MaybeLiveRetUses. If /// the result is Live, MaybeLiveUses might be modified but its content should /// be ignored (since it might not be complete). -DAE::Liveness DAE::SurveyUses(const Value *V, UseVector &MaybeLiveUses) { +DeadArgumentEliminationPass::Liveness +DeadArgumentEliminationPass::SurveyUses(const Value *V, + UseVector &MaybeLiveUses) { // Assume it's dead (which will only hold if there are no uses at all..). Liveness Result = MaybeLive; // Check each use. @@ -534,7 +462,7 @@ DAE::Liveness DAE::SurveyUses(const Value *V, UseVector &MaybeLiveUses) { // We consider arguments of non-internal functions to be intrinsically alive as // well as arguments to functions which have their "address taken". // -void DAE::SurveyFunction(const Function &F) { +void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { // Functions with inalloca parameters are expecting args in a particular // register and memory layout. if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca)) { @@ -570,12 +498,13 @@ void DAE::SurveyFunction(const Function &F) { return; } - if (!F.hasLocalLinkage() && (!ShouldHackArguments() || F.isIntrinsic())) { + if (!F.hasLocalLinkage() && (!ShouldHackArguments || F.isIntrinsic())) { MarkLive(F); return; } - DEBUG(dbgs() << "DAE - Inspecting callers for fn: " << F.getName() << "\n"); + DEBUG(dbgs() << "DeadArgumentEliminationPass - Inspecting callers for fn: " + << F.getName() << "\n"); // Keep track of the number of live retvals, so we can skip checks once all // of them turn out to be live. unsigned NumLiveRetVals = 0; @@ -637,7 +566,8 @@ void DAE::SurveyFunction(const Function &F) { for (unsigned i = 0; i != RetCount; ++i) MarkValue(CreateRet(&F, i), RetValLiveness[i], MaybeLiveRetUses[i]); - DEBUG(dbgs() << "DAE - Inspecting args for fn: " << F.getName() << "\n"); + DEBUG(dbgs() << "DeadArgumentEliminationPass - Inspecting args for fn: " + << F.getName() << "\n"); // Now, check all of our arguments. unsigned i = 0; @@ -669,17 +599,16 @@ void DAE::SurveyFunction(const Function &F) { /// MaybeLive, it also takes all uses in MaybeLiveUses and records them in Uses, /// such that RA will be marked live if any use in MaybeLiveUses gets marked /// live later on. -void DAE::MarkValue(const RetOrArg &RA, Liveness L, - const UseVector &MaybeLiveUses) { +void DeadArgumentEliminationPass::MarkValue(const RetOrArg &RA, Liveness L, + const UseVector &MaybeLiveUses) { switch (L) { case Live: MarkLive(RA); break; case MaybeLive: { // Note any uses of this value, so this return value can be // marked live whenever one of the uses becomes live. - for (UseVector::const_iterator UI = MaybeLiveUses.begin(), - UE = MaybeLiveUses.end(); UI != UE; ++UI) - Uses.insert(std::make_pair(*UI, RA)); + for (const auto &MaybeLiveUse : MaybeLiveUses) + Uses.insert(std::make_pair(MaybeLiveUse, RA)); break; } } @@ -689,8 +618,9 @@ void DAE::MarkValue(const RetOrArg &RA, Liveness L, /// changed in any way. Additionally, /// mark any values that are used as this function's parameters or by its return /// values (according to Uses) live as well. 
-void DAE::MarkLive(const Function &F) { - DEBUG(dbgs() << "DAE - Intrinsically live fn: " << F.getName() << "\n"); +void DeadArgumentEliminationPass::MarkLive(const Function &F) { + DEBUG(dbgs() << "DeadArgumentEliminationPass - Intrinsically live fn: " + << F.getName() << "\n"); // Mark the function as live. LiveFunctions.insert(&F); // Mark all arguments as live. @@ -704,20 +634,21 @@ void DAE::MarkLive(const Function &F) { /// MarkLive - Mark the given return value or argument as live. Additionally, /// mark any values that are used by this value (according to Uses) live as /// well. -void DAE::MarkLive(const RetOrArg &RA) { +void DeadArgumentEliminationPass::MarkLive(const RetOrArg &RA) { if (LiveFunctions.count(RA.F)) return; // Function was already marked Live. if (!LiveValues.insert(RA).second) return; // We were already marked Live. - DEBUG(dbgs() << "DAE - Marking " << RA.getDescription() << " live\n"); + DEBUG(dbgs() << "DeadArgumentEliminationPass - Marking " + << RA.getDescription() << " live\n"); PropagateLiveness(RA); } /// PropagateLiveness - Given that RA is a live value, propagate it's liveness /// to any other values it uses (according to Uses). -void DAE::PropagateLiveness(const RetOrArg &RA) { +void DeadArgumentEliminationPass::PropagateLiveness(const RetOrArg &RA) { // We don't use upper_bound (or equal_range) here, because our recursive call // to ourselves is likely to cause the upper_bound (which is the first value // not belonging to RA) to become erased and the iterator invalidated. @@ -736,7 +667,7 @@ void DAE::PropagateLiveness(const RetOrArg &RA) { // that are not in LiveValues. Transform the function and all of the callees of // the function to not have these arguments and return values. // -bool DAE::RemoveDeadStuffFromFunction(Function *F) { +bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Don't modify fully live functions if (LiveFunctions.count(F)) return false; @@ -777,8 +708,9 @@ bool DAE::RemoveDeadStuffFromFunction(Function *F) { } } else { ++NumArgumentsEliminated; - DEBUG(dbgs() << "DAE - Removing argument " << i << " (" << I->getName() - << ") from " << F->getName() << "\n"); + DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing argument " << i + << " (" << I->getName() << ") from " << F->getName() + << "\n"); } } @@ -821,8 +753,8 @@ bool DAE::RemoveDeadStuffFromFunction(Function *F) { NewRetIdxs[i] = RetTypes.size() - 1; } else { ++NumRetValsEliminated; - DEBUG(dbgs() << "DAE - Removing return value " << i << " from " - << F->getName() << "\n"); + DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing return value " + << i << " from " << F->getName() << "\n"); } } if (RetTypes.size() > 1) { @@ -882,6 +814,7 @@ bool DAE::RemoveDeadStuffFromFunction(Function *F) { // Create the new function body and insert it into the module... Function *NF = Function::Create(NFTy, F->getLinkage()); NF->copyAttributesFrom(F); + NF->setComdat(F->getComdat()); NF->setAttributes(NewPAL); // Insert the new function before the old function, so we won't be processing // it again. @@ -950,14 +883,17 @@ bool DAE::RemoveDeadStuffFromFunction(Function *F) { // Reconstruct the AttributesList based on the vector we constructed. 
AttributeSet NewCallPAL = AttributeSet::get(F->getContext(), AttributesVec); + SmallVector<OperandBundleDef, 1> OpBundles; + CS.getOperandBundlesAsDefs(OpBundles); + Instruction *New; if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), - Args, "", Call->getParent()); + Args, OpBundles, "", Call->getParent()); cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv()); cast<InvokeInst>(New)->setAttributes(NewCallPAL); } else { - New = CallInst::Create(NF, Args, "", Call); + New = CallInst::Create(NF, Args, OpBundles, "", Call); cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); cast<CallInst>(New)->setAttributes(NewCallPAL); if (cast<CallInst>(Call)->isTailCall()) @@ -1045,8 +981,8 @@ bool DAE::RemoveDeadStuffFromFunction(Function *F) { // If we change the return value of the function we must rewrite any return // instructions. Check this now. if (F->getReturnType() != NF->getReturnType()) - for (Function::iterator BB = NF->begin(), E = NF->end(); BB != E; ++BB) - if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) { + for (BasicBlock &BB : *NF) + if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) { Value *RetVal; if (NFTy->getReturnType()->isVoidTy()) { @@ -1081,7 +1017,7 @@ bool DAE::RemoveDeadStuffFromFunction(Function *F) { // Replace the return instruction with one returning the new return // value (possibly 0 if we became void). ReturnInst::Create(F->getContext(), RetVal, RI); - BB->getInstList().erase(RI); + BB.getInstList().erase(RI); } // Patch the pointer to LLVM function in debug info descriptor. @@ -1093,14 +1029,15 @@ bool DAE::RemoveDeadStuffFromFunction(Function *F) { return true; } -bool DAE::runOnModule(Module &M) { +PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, + ModuleAnalysisManager &) { bool Changed = false; // First pass: Do a simple check to see if any functions can have their "..." // removed. We can do this if they never call va_start. This loop cannot be // fused with the next loop, because deleting a function invalidates // information computed while surveying other functions. - DEBUG(dbgs() << "DAE - Deleting dead varargs\n"); + DEBUG(dbgs() << "DeadArgumentEliminationPass - Deleting dead varargs\n"); for (Module::iterator I = M.begin(), E = M.end(); I != E; ) { Function &F = *I++; if (F.getFunctionType()->isVarArg()) @@ -1111,7 +1048,7 @@ bool DAE::runOnModule(Module &M) { // We assume all arguments are dead unless proven otherwise (allowing us to // determine that dead arguments passed into recursive functions are dead). 
// - DEBUG(dbgs() << "DAE - Determining liveness\n"); + DEBUG(dbgs() << "DeadArgumentEliminationPass - Determining liveness\n"); for (auto &F : M) SurveyFunction(F); @@ -1129,5 +1066,7 @@ bool DAE::runOnModule(Module &M) { for (auto &F : M) Changed |= RemoveDeadArgumentsFromCallers(F); - return Changed; + if (!Changed) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); } diff --git a/lib/Transforms/IPO/ElimAvailExtern.cpp b/lib/Transforms/IPO/ElimAvailExtern.cpp index af313a6b001d..98c4b1740306 100644 --- a/lib/Transforms/IPO/ElimAvailExtern.cpp +++ b/lib/Transforms/IPO/ElimAvailExtern.cpp @@ -13,10 +13,11 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/ElimAvailExtern.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Module.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/GlobalStatus.h" #include "llvm/Pass.h" using namespace llvm; @@ -26,30 +27,7 @@ using namespace llvm; STATISTIC(NumFunctions, "Number of functions removed"); STATISTIC(NumVariables, "Number of global variables removed"); -namespace { -struct EliminateAvailableExternally : public ModulePass { - static char ID; // Pass identification, replacement for typeid - EliminateAvailableExternally() : ModulePass(ID) { - initializeEliminateAvailableExternallyPass( - *PassRegistry::getPassRegistry()); - } - - // run - Do the EliminateAvailableExternally pass on the specified module, - // optionally updating the specified callgraph to reflect the changes. - // - bool runOnModule(Module &M) override; -}; -} - -char EliminateAvailableExternally::ID = 0; -INITIALIZE_PASS(EliminateAvailableExternally, "elim-avail-extern", - "Eliminate Available Externally Globals", false, false) - -ModulePass *llvm::createEliminateAvailableExternallyPass() { - return new EliminateAvailableExternally(); -} - -bool EliminateAvailableExternally::runOnModule(Module &M) { +static bool eliminateAvailableExternally(Module &M) { bool Changed = false; // Drop initializers of available externally global variables. @@ -82,3 +60,37 @@ bool EliminateAvailableExternally::runOnModule(Module &M) { return Changed; } + +PreservedAnalyses +EliminateAvailableExternallyPass::run(Module &M, ModuleAnalysisManager &) { + if (!eliminateAvailableExternally(M)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +namespace { +struct EliminateAvailableExternallyLegacyPass : public ModulePass { + static char ID; // Pass identification, replacement for typeid + EliminateAvailableExternallyLegacyPass() : ModulePass(ID) { + initializeEliminateAvailableExternallyLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + // run - Do the EliminateAvailableExternally pass on the specified module, + // optionally updating the specified callgraph to reflect the changes. 
+ // + bool runOnModule(Module &M) { + if (skipModule(M)) + return false; + return eliminateAvailableExternally(M); + } +}; +} + +char EliminateAvailableExternallyLegacyPass::ID = 0; +INITIALIZE_PASS(EliminateAvailableExternallyLegacyPass, "elim-avail-extern", + "Eliminate Available Externally Globals", false, false) + +ModulePass *llvm::createEliminateAvailableExternallyPass() { + return new EliminateAvailableExternallyLegacyPass(); +} diff --git a/lib/Transforms/IPO/ExtractGV.cpp b/lib/Transforms/IPO/ExtractGV.cpp index 1a3b9253d72f..479fd182598a 100644 --- a/lib/Transforms/IPO/ExtractGV.cpp +++ b/lib/Transforms/IPO/ExtractGV.cpp @@ -68,6 +68,9 @@ namespace { : ModulePass(ID), Named(GVs.begin(), GVs.end()), deleteStuff(deleteS) {} bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + // Visit the global inline asm. if (!deleteStuff) M.setModuleInlineAsm(""); @@ -101,20 +104,20 @@ namespace { } // Visit the Functions. - for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { + for (Function &F : M) { bool Delete = - deleteStuff == (bool)Named.count(&*I) && !I->isDeclaration(); + deleteStuff == (bool)Named.count(&F) && !F.isDeclaration(); if (!Delete) { - if (I->hasAvailableExternallyLinkage()) + if (F.hasAvailableExternallyLinkage()) continue; } - makeVisible(*I, Delete); + makeVisible(F, Delete); if (Delete) { // Make this a declaration and drop it's comdat. - I->deleteBody(); - I->setComdat(nullptr); + F.deleteBody(); + F.setComdat(nullptr); } } @@ -128,7 +131,7 @@ namespace { makeVisible(*CurI, Delete); if (Delete) { - Type *Ty = CurI->getType()->getElementType(); + Type *Ty = CurI->getValueType(); CurI->removeFromParent(); llvm::Value *Declaration; diff --git a/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 6df044762cf4..968712138208 100644 --- a/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -80,7 +80,8 @@ static void addForcedAttributes(Function &F) { } } -PreservedAnalyses ForceFunctionAttrsPass::run(Module &M) { +PreservedAnalyses ForceFunctionAttrsPass::run(Module &M, + ModuleAnalysisManager &) { if (ForceAttributes.empty()) return PreservedAnalyses::all(); diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp index 527fdd1885a4..fff544085414 100644 --- a/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/lib/Transforms/IPO/FunctionAttrs.cpp @@ -13,6 +13,7 @@ /// //===----------------------------------------------------------------------===// +#include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/IPO.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/SetVector.h" @@ -52,38 +53,6 @@ typedef SmallSetVector<Function *, 8> SCCNodeSet; } namespace { -struct PostOrderFunctionAttrs : public CallGraphSCCPass { - static char ID; // Pass identification, replacement for typeid - PostOrderFunctionAttrs() : CallGraphSCCPass(ID) { - initializePostOrderFunctionAttrsPass(*PassRegistry::getPassRegistry()); - } - - bool runOnSCC(CallGraphSCC &SCC) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - CallGraphSCCPass::getAnalysisUsage(AU); - } - -private: - TargetLibraryInfo *TLI; -}; -} - -char PostOrderFunctionAttrs::ID = 0; -INITIALIZE_PASS_BEGIN(PostOrderFunctionAttrs, "functionattrs", - "Deduce function attributes", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 
-INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(PostOrderFunctionAttrs, "functionattrs", - "Deduce function attributes", false, false) - -Pass *llvm::createPostOrderFunctionAttrsPass() { return new PostOrderFunctionAttrs(); } - -namespace { /// The three kinds of memory access relevant to 'readonly' and /// 'readnone' attributes. enum MemoryAccessKind { @@ -100,9 +69,10 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, AAResults &AAR, // Already perfect! return MAK_ReadNone; - // Definitions with weak linkage may be overridden at linktime with - // something that writes memory, so treat them like declarations. - if (F.isDeclaration() || F.mayBeOverridden()) { + // Non-exact function definitions may not be selected at link time, and an + // alternative version that writes to memory may be selected. See the comment + // on GlobalValue::isDefinitionExact for more details. + if (!F.hasExactDefinition()) { if (AliasAnalysis::onlyReadsMemory(MRB)) return MAK_ReadOnly; @@ -119,8 +89,12 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, AAResults &AAR, // Detect these now, skipping to the next instruction if one is found. CallSite CS(cast<Value>(I)); if (CS) { - // Ignore calls to functions in the same SCC. - if (CS.getCalledFunction() && SCCNodes.count(CS.getCalledFunction())) + // Ignore calls to functions in the same SCC, as long as the call sites + // don't have operand bundles. Calls with operand bundles are allowed to + // have memory effects not described by the memory effects of the call + // target. + if (!CS.hasOperandBundles() && CS.getCalledFunction() && + SCCNodes.count(CS.getCalledFunction())) continue; FunctionModRefBehavior MRB = AAR.getModRefBehavior(CS); @@ -311,8 +285,7 @@ struct ArgumentUsesTracker : public CaptureTracker { } Function *F = CS.getCalledFunction(); - if (!F || F->isDeclaration() || F->mayBeOverridden() || - !SCCNodes.count(F)) { + if (!F || !F->hasExactDefinition() || !SCCNodes.count(F)) { Captured = true; return true; } @@ -490,6 +463,11 @@ determinePointerReadAttrs(Argument *A, } case Instruction::Load: + // A volatile load has side effects beyond what readonly can be relied + // upon. + if (cast<LoadInst>(I)->isVolatile()) + return Attribute::None; + IsRead = true; break; @@ -517,9 +495,10 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { // Check each function in turn, determining which pointer arguments are not // captured. for (Function *F : SCCNodes) { - // Definitions with weak linkage may be overridden at linktime with - // something that captures pointers, so treat them like declarations. - if (F->isDeclaration() || F->mayBeOverridden()) + // We can infer and propagate function attributes only when we know that the + // definition we'll get at link time is *exactly* the definition we see now. + // For more details, see GlobalValue::mayBeDerefined. + if (!F->hasExactDefinition()) continue; // Functions that are readonly (or readnone) and nounwind and don't return @@ -557,12 +536,9 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { // then it must be calling into another function in our SCC. Save // its particulars for Argument-SCC analysis later. 
ArgumentGraphNode *Node = AG[&*A]; - for (SmallVectorImpl<Argument *>::iterator - UI = Tracker.Uses.begin(), - UE = Tracker.Uses.end(); - UI != UE; ++UI) { - Node->Uses.push_back(AG[*UI]); - if (*UI != A) + for (Argument *Use : Tracker.Uses) { + Node->Uses.push_back(AG[Use]); + if (Use != &*A) HasNonLocalUses = true; } } @@ -627,17 +603,15 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { SmallPtrSet<Argument *, 8> ArgumentSCCNodes; // Fill ArgumentSCCNodes with the elements of the ArgumentSCC. Used for // quickly looking up whether a given Argument is in this ArgumentSCC. - for (auto I = ArgumentSCC.begin(), E = ArgumentSCC.end(); I != E; ++I) { - ArgumentSCCNodes.insert((*I)->Definition); + for (ArgumentGraphNode *I : ArgumentSCC) { + ArgumentSCCNodes.insert(I->Definition); } for (auto I = ArgumentSCC.begin(), E = ArgumentSCC.end(); I != E && !SCCCaptured; ++I) { ArgumentGraphNode *N = *I; - for (SmallVectorImpl<ArgumentGraphNode *>::iterator UI = N->Uses.begin(), - UE = N->Uses.end(); - UI != UE; ++UI) { - Argument *A = (*UI)->Definition; + for (ArgumentGraphNode *Use : N->Uses) { + Argument *A = Use->Definition; if (A->hasNoCaptureAttr() || ArgumentSCCNodes.count(A)) continue; SCCCaptured = true; @@ -703,8 +677,8 @@ static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { /// doesn't alias any other pointer visible to the caller. static bool isFunctionMallocLike(Function *F, const SCCNodeSet &SCCNodes) { SmallSetVector<Value *, 8> FlowsToReturn; - for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) - if (ReturnInst *Ret = dyn_cast<ReturnInst>(I->getTerminator())) + for (BasicBlock &BB : *F) + if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB.getTerminator())) FlowsToReturn.insert(Ret->getReturnValue()); for (unsigned i = 0; i != FlowsToReturn.size(); ++i) { @@ -772,9 +746,10 @@ static bool addNoAliasAttrs(const SCCNodeSet &SCCNodes) { if (F->doesNotAlias(0)) continue; - // Definitions with weak linkage may be overridden at linktime, so - // treat them like declarations. - if (F->isDeclaration() || F->mayBeOverridden()) + // We can infer and propagate function attributes only when we know that the + // definition we'll get at link time is *exactly* the definition we see now. + // For more details, see GlobalValue::mayBeDerefined. + if (!F->hasExactDefinition()) return false; // We annotate noalias return values, which are only applicable to @@ -807,7 +782,7 @@ static bool addNoAliasAttrs(const SCCNodeSet &SCCNodes) { /// \p Speculative based on whether the returned conclusion is a speculative /// conclusion due to SCC calls. static bool isReturnNonNull(Function *F, const SCCNodeSet &SCCNodes, - const TargetLibraryInfo &TLI, bool &Speculative) { + bool &Speculative) { assert(F->getReturnType()->isPointerTy() && "nonnull only meaningful on pointer types"); Speculative = false; @@ -821,7 +796,7 @@ static bool isReturnNonNull(Function *F, const SCCNodeSet &SCCNodes, Value *RetVal = FlowsToReturn[i]; // If this value is locally known to be non-null, we're good - if (isKnownNonNull(RetVal, &TLI)) + if (isKnownNonNull(RetVal)) continue; // Otherwise, we need to look upwards since we can't make any local @@ -870,8 +845,7 @@ static bool isReturnNonNull(Function *F, const SCCNodeSet &SCCNodes, } /// Deduce nonnull attributes for the SCC. -static bool addNonNullAttrs(const SCCNodeSet &SCCNodes, - const TargetLibraryInfo &TLI) { +static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { // Speculative that all functions in the SCC return only nonnull // pointers. 
We may refute this as we analyze functions. bool SCCReturnsNonNull = true; @@ -886,9 +860,10 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes, Attribute::NonNull)) continue; - // Definitions with weak linkage may be overridden at linktime, so - // treat them like declarations. - if (F->isDeclaration() || F->mayBeOverridden()) + // We can infer and propagate function attributes only when we know that the + // definition we'll get at link time is *exactly* the definition we see now. + // For more details, see GlobalValue::mayBeDerefined. + if (!F->hasExactDefinition()) return false; // We annotate nonnull return values, which are only applicable to @@ -897,7 +872,7 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes, continue; bool Speculative = false; - if (isReturnNonNull(F, SCCNodes, TLI, Speculative)) { + if (isReturnNonNull(F, SCCNodes, Speculative)) { if (!Speculative) { // Mark the function eagerly since we may discover a function // which prevents us from speculating about the entire SCC @@ -930,6 +905,49 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes, return MadeChange; } +/// Remove the convergent attribute from all functions in the SCC if every +/// callsite within the SCC is not convergent (except for calls to functions +/// within the SCC). Returns true if changes were made. +static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) { + // For every function in SCC, ensure that either + // * it is not convergent, or + // * we can remove its convergent attribute. + bool HasConvergentFn = false; + for (Function *F : SCCNodes) { + if (!F->isConvergent()) continue; + HasConvergentFn = true; + + // Can't remove convergent from function declarations. + if (F->isDeclaration()) return false; + + // Can't remove convergent if any of our functions has a convergent call to a + // function not in the SCC. + for (Instruction &I : instructions(*F)) { + CallSite CS(&I); + // Bail if CS is a convergent call to a function not in the SCC. + if (CS && CS.isConvergent() && + SCCNodes.count(CS.getCalledFunction()) == 0) + return false; + } + } + + // If the SCC doesn't have any convergent functions, we have nothing to do. + if (!HasConvergentFn) return false; + + // If we got here, all of the calls the SCC makes to functions not in the SCC + // are non-convergent. Therefore all of the SCC's functions can also be made + // non-convergent. We'll remove the attr from the callsites in + // InstCombineCalls. + for (Function *F : SCCNodes) { + if (!F->isConvergent()) continue; + + DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName() + << "\n"); + F->setNotConvergent(); + } + return true; +} + static bool setDoesNotRecurse(Function &F) { if (F.doesNotRecurse()) return false; @@ -938,56 +956,129 @@ static bool setDoesNotRecurse(Function &F) { return true; } -static bool addNoRecurseAttrs(const CallGraphSCC &SCC) { +static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { // Try and identify functions that do not recurse. // If the SCC contains multiple nodes we know for sure there is recursion. - if (!SCC.isSingular()) + if (SCCNodes.size() != 1) return false; - const CallGraphNode *CGN = *SCC.begin(); - Function *F = CGN->getFunction(); + Function *F = *SCCNodes.begin(); if (!F || F->isDeclaration() || F->doesNotRecurse()) return false; // If all of the calls in F are identifiable and are to norecurse functions, F // is norecurse. 
This check also detects self-recursion as F is not currently // marked norecurse, so any called from F to F will not be marked norecurse. - if (std::all_of(CGN->begin(), CGN->end(), - [](const CallGraphNode::CallRecord &CR) { - Function *F = CR.second->getFunction(); - return F && F->doesNotRecurse(); - })) - // Function calls a potentially recursive function. - return setDoesNotRecurse(*F); - - // Nothing else we can deduce usefully during the postorder traversal. - return false; + for (Instruction &I : instructions(*F)) + if (auto CS = CallSite(&I)) { + Function *Callee = CS.getCalledFunction(); + if (!Callee || Callee == F || !Callee->doesNotRecurse()) + // Function calls a potentially recursive function. + return false; + } + + // Every call was to a non-recursive function other than this function, and + // we have no indirect recursion as the SCC size is one. This function cannot + // recurse. + return setDoesNotRecurse(*F); } -bool PostOrderFunctionAttrs::runOnSCC(CallGraphSCC &SCC) { - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - bool Changed = false; +PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, + CGSCCAnalysisManager &AM) { + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C).getManager(); - // We compute dedicated AA results for each function in the SCC as needed. We - // use a lambda referencing external objects so that they live long enough to - // be queried, but we re-use them each time. - Optional<BasicAAResult> BAR; - Optional<AAResults> AAR; + // We pass a lambda into functions to wire them up to the analysis manager + // for getting function analyses. auto AARGetter = [&](Function &F) -> AAResults & { - BAR.emplace(createLegacyPMBasicAAResult(*this, F)); - AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); - return *AAR; + return FAM.getResult<AAManager>(F); }; + // Fill SCCNodes with the elements of the SCC. Also track whether there are + // any external or opt-none nodes that will prevent us from optimizing any + // part of the SCC. + SCCNodeSet SCCNodes; + bool HasUnknownCall = false; + for (LazyCallGraph::Node &N : C) { + Function &F = N.getFunction(); + if (F.hasFnAttribute(Attribute::OptimizeNone)) { + // Treat any function we're trying not to optimize as if it were an + // indirect call and omit it from the node set used below. + HasUnknownCall = true; + continue; + } + // Track whether any functions in this SCC have an unknown call edge. + // Note: if this is ever a performance hit, we can common it with + // subsequent routines which also do scans over the instructions of the + // function. + if (!HasUnknownCall) + for (Instruction &I : instructions(F)) + if (auto CS = CallSite(&I)) + if (!CS.getCalledFunction()) { + HasUnknownCall = true; + break; + } + + SCCNodes.insert(&F); + } + + bool Changed = false; + Changed |= addReadAttrs(SCCNodes, AARGetter); + Changed |= addArgumentAttrs(SCCNodes); + + // If we have no external nodes participating in the SCC, we can deduce some + // more precise attributes as well. + if (!HasUnknownCall) { + Changed |= addNoAliasAttrs(SCCNodes); + Changed |= addNonNullAttrs(SCCNodes); + Changed |= removeConvergentAttrs(SCCNodes); + Changed |= addNoRecurseAttrs(SCCNodes); + } + + return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); +} + +namespace { +struct PostOrderFunctionAttrsLegacyPass : public CallGraphSCCPass { + static char ID; // Pass identification, replacement for typeid + PostOrderFunctionAttrsLegacyPass() : CallGraphSCCPass(ID) { + initializePostOrderFunctionAttrsLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnSCC(CallGraphSCC &SCC) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); + getAAResultsAnalysisUsage(AU); + CallGraphSCCPass::getAnalysisUsage(AU); + } +}; +} + +char PostOrderFunctionAttrsLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(PostOrderFunctionAttrsLegacyPass, "functionattrs", + "Deduce function attributes", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_END(PostOrderFunctionAttrsLegacyPass, "functionattrs", + "Deduce function attributes", false, false) + +Pass *llvm::createPostOrderFunctionAttrsLegacyPass() { return new PostOrderFunctionAttrsLegacyPass(); } + +template <typename AARGetterT> +static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { + bool Changed = false; + // Fill SCCNodes with the elements of the SCC. Used for quickly looking up // whether a given CallGraphNode is in this SCC. Also track whether there are // any external or opt-none nodes that will prevent us from optimizing any // part of the SCC. SCCNodeSet SCCNodes; bool ExternalNode = false; - for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) { - Function *F = (*I)->getFunction(); + for (CallGraphNode *I : SCC) { + Function *F = I->getFunction(); if (!F || F->hasFnAttribute(Attribute::OptimizeNone)) { // External node or function we're trying not to optimize - we both avoid // transform them and avoid leveraging information they provide. @@ -1005,28 +1096,37 @@ bool PostOrderFunctionAttrs::runOnSCC(CallGraphSCC &SCC) { // more precise attributes as well. if (!ExternalNode) { Changed |= addNoAliasAttrs(SCCNodes); - Changed |= addNonNullAttrs(SCCNodes, *TLI); + Changed |= addNonNullAttrs(SCCNodes); + Changed |= removeConvergentAttrs(SCCNodes); + Changed |= addNoRecurseAttrs(SCCNodes); } - Changed |= addNoRecurseAttrs(SCC); return Changed; } +bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) { + if (skipSCC(SCC)) + return false; + + // We compute dedicated AA results for each function in the SCC as needed. We + // use a lambda referencing external objects so that they live long enough to + // be queried, but we re-use them each time. + Optional<BasicAAResult> BAR; + Optional<AAResults> AAR; + auto AARGetter = [&](Function &F) -> AAResults & { + BAR.emplace(createLegacyPMBasicAAResult(*this, F)); + AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); + return *AAR; + }; + + return runImpl(SCC, AARGetter); +} + namespace { -/// A pass to do RPO deduction and propagation of function attributes. -/// -/// This pass provides a general RPO or "top down" propagation of -/// function attributes. For a few (rare) cases, we can deduce significantly -/// more about function attributes by working in RPO, so this pass -/// provides the compliment to the post-order pass above where the majority of -/// deduction is performed. -// FIXME: Currently there is no RPO CGSCC pass structure to slide into and so -// this is a boring module pass, but eventually it should be an RPO CGSCC pass -// when such infrastructure is available. 
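The legacy runOnSCC above keeps per-function alias-analysis results alive by emplacing them into Optionals owned by the caller and handing out a reference from a lambda. A minimal sketch of that lifetime idiom in plain C++ (no LLVM types; all names here are made up):

#include <iostream>
#include <optional>
#include <string>

struct Analysis { std::string ForFunction; };  // stand-in for a per-function result

int main() {
  std::optional<Analysis> Cache;  // owned by the caller, like BAR/AAR above
  auto Getter = [&](const std::string &F) -> Analysis & {
    Cache.emplace(Analysis{F});   // rebuilt on each query, re-using the same storage
    return *Cache;                // the returned reference stays valid after the call
  };
  std::cout << Getter("foo").ForFunction << "\n";
  std::cout << Getter("bar").ForFunction << "\n";  // previous result is replaced, not leaked
}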
-struct ReversePostOrderFunctionAttrs : public ModulePass { +struct ReversePostOrderFunctionAttrsLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid - ReversePostOrderFunctionAttrs() : ModulePass(ID) { - initializeReversePostOrderFunctionAttrsPass(*PassRegistry::getPassRegistry()); + ReversePostOrderFunctionAttrsLegacyPass() : ModulePass(ID) { + initializeReversePostOrderFunctionAttrsLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnModule(Module &M) override; @@ -1034,19 +1134,20 @@ struct ReversePostOrderFunctionAttrs : public ModulePass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired<CallGraphWrapperPass>(); + AU.addPreserved<CallGraphWrapperPass>(); } }; } -char ReversePostOrderFunctionAttrs::ID = 0; -INITIALIZE_PASS_BEGIN(ReversePostOrderFunctionAttrs, "rpo-functionattrs", +char ReversePostOrderFunctionAttrsLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(ReversePostOrderFunctionAttrsLegacyPass, "rpo-functionattrs", "Deduce function attributes in RPO", false, false) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END(ReversePostOrderFunctionAttrs, "rpo-functionattrs", +INITIALIZE_PASS_END(ReversePostOrderFunctionAttrsLegacyPass, "rpo-functionattrs", "Deduce function attributes in RPO", false, false) Pass *llvm::createReversePostOrderFunctionAttrsPass() { - return new ReversePostOrderFunctionAttrs(); + return new ReversePostOrderFunctionAttrsLegacyPass(); } static bool addNoRecurseAttrsTopDown(Function &F) { @@ -1078,7 +1179,7 @@ static bool addNoRecurseAttrsTopDown(Function &F) { return setDoesNotRecurse(F); } -bool ReversePostOrderFunctionAttrs::runOnModule(Module &M) { +static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) { // We only have a post-order SCC traversal (because SCCs are inherently // discovered in post-order), so we accumulate them in a vector and then walk // it in reverse. This is simpler than using the RPO iterator infrastructure @@ -1086,7 +1187,6 @@ bool ReversePostOrderFunctionAttrs::runOnModule(Module &M) { // graph. We can also cheat egregiously because we're primarily interested in // synthesizing norecurse and so we can only save the singular SCCs as SCCs // with multiple functions in them will clearly be recursive. 
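The norecurse deduction above only fires for a singleton SCC whose every call is a direct call to an already-norecurse function other than itself. A toy model of that rule in plain C++ (invented data, not LLVM API):

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // Callee lists per function; "" models an indirect (unknown) call target.
  std::map<std::string, std::vector<std::string>> Calls = {
      {"leaf", {}},             // makes no calls at all
      {"calls_leaf", {"leaf"}},
      {"self_rec", {"self_rec"}},
      {"indirect", {""}}};
  std::set<std::string> KnownNoRecurse = {"leaf"};

  auto deducible = [&](const std::string &F) {
    for (const std::string &Callee : Calls.at(F))
      if (Callee.empty() || Callee == F || !KnownNoRecurse.count(Callee))
        return false;  // indirect call, self call, or callee not known norecurse
    return true;
  };
  for (const auto &KV : Calls)
    std::cout << KV.first << (deducible(KV.first) ? ": norecurse\n" : ": unknown\n");
}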
- auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); SmallVector<Function *, 16> Worklist; for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) { if (I->size() != 1) @@ -1104,3 +1204,24 @@ bool ReversePostOrderFunctionAttrs::runOnModule(Module &M) { return Changed; } + +bool ReversePostOrderFunctionAttrsLegacyPass::runOnModule(Module &M) { + if (skipModule(M)) + return false; + + auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + + return deduceFunctionAttributeInRPO(M, CG); +} + +PreservedAnalyses +ReversePostOrderFunctionAttrsPass::run(Module &M, AnalysisManager<Module> &AM) { + auto &CG = AM.getResult<CallGraphAnalysis>(M); + + bool Changed = deduceFunctionAttributeInRPO(M, CG); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<CallGraphAnalysis>(); + return PA; +} diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp index 5e0df9505119..c9d075e76325 100644 --- a/lib/Transforms/IPO/FunctionImport.cpp +++ b/lib/Transforms/IPO/FunctionImport.cpp @@ -13,329 +13,670 @@ #include "llvm/Transforms/IPO/FunctionImport.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSet.h" +#include "llvm/ADT/Triple.h" #include "llvm/IR/AutoUpgrade.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" -#include "llvm/Object/FunctionIndexObjectFile.h" +#include "llvm/Object/IRObjectFile.h" +#include "llvm/Object/ModuleSummaryIndexObjectFile.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/IPO/Internalize.h" +#include "llvm/Transforms/Utils/FunctionImportUtils.h" -#include <map> +#define DEBUG_TYPE "function-import" using namespace llvm; -#define DEBUG_TYPE "function-import" +STATISTIC(NumImported, "Number of functions imported"); /// Limit on instruction count of imported functions. static cl::opt<unsigned> ImportInstrLimit( "import-instr-limit", cl::init(100), cl::Hidden, cl::value_desc("N"), cl::desc("Only import functions with less than N instructions")); +static cl::opt<float> + ImportInstrFactor("import-instr-evolution-factor", cl::init(0.7), + cl::Hidden, cl::value_desc("x"), + cl::desc("As we import functions, multiply the " + "`import-instr-limit` threshold by this factor " + "before processing newly imported functions")); + +static cl::opt<bool> PrintImports("print-imports", cl::init(false), cl::Hidden, + cl::desc("Print imported functions")); + +// Temporary allows the function import pass to disable always linking +// referenced discardable symbols. +static cl::opt<bool> + DontForceImportReferencedDiscardableSymbols("disable-force-link-odr", + cl::init(false), cl::Hidden); + +static cl::opt<bool> EnableImportMetadata( + "enable-import-metadata", cl::init( +#if !defined(NDEBUG) + true /*Enabled with asserts.*/ +#else + false +#endif + ), + cl::Hidden, cl::desc("Enable import metadata like 'thinlto_src_module'")); + // Load lazily a module from \p FileName in \p Context. static std::unique_ptr<Module> loadFile(const std::string &FileName, LLVMContext &Context) { SMDiagnostic Err; DEBUG(dbgs() << "Loading '" << FileName << "'\n"); - // Metadata isn't loaded or linked until after all functions are - // imported, after which it will be materialized and linked. 
+ // Metadata isn't loaded until functions are imported, to minimize + // the memory overhead. std::unique_ptr<Module> Result = getLazyIRFileModule(FileName, Err, Context, /* ShouldLazyLoadMetadata = */ true); if (!Result) { Err.print("function-import", errs()); - return nullptr; + report_fatal_error("Abort"); } return Result; } namespace { -/// Helper to load on demand a Module from file and cache it for subsequent -/// queries. It can be used with the FunctionImporter. -class ModuleLazyLoaderCache { - /// Cache of lazily loaded module for import. - StringMap<std::unique_ptr<Module>> ModuleMap; - /// Retrieve a Module from the cache or lazily load it on demand. - std::function<std::unique_ptr<Module>(StringRef FileName)> createLazyModule; +// Return true if the Summary describes a GlobalValue that can be externally +// referenced, i.e. it does not need renaming (linkage is not local) or renaming +// is possible (does not have a section for instance). +static bool canBeExternallyReferenced(const GlobalValueSummary &Summary) { + if (!Summary.needsRenaming()) + return true; -public: - /// Create the loader, Module will be initialized in \p Context. - ModuleLazyLoaderCache(std::function< - std::unique_ptr<Module>(StringRef FileName)> createLazyModule) - : createLazyModule(createLazyModule) {} - - /// Retrieve a Module from the cache or lazily load it on demand. - Module &operator()(StringRef FileName); - - std::unique_ptr<Module> takeModule(StringRef FileName) { - auto I = ModuleMap.find(FileName); - assert(I != ModuleMap.end()); - std::unique_ptr<Module> Ret = std::move(I->second); - ModuleMap.erase(I); - return Ret; - } -}; + if (Summary.hasSection()) + // Can't rename a global that needs renaming if has a section. + return false; -// Get a Module for \p FileName from the cache, or load it lazily. -Module &ModuleLazyLoaderCache::operator()(StringRef Identifier) { - auto &Module = ModuleMap[Identifier]; - if (!Module) - Module = createLazyModule(Identifier); - return *Module; + return true; } -} // anonymous namespace -/// Walk through the instructions in \p F looking for external -/// calls not already in the \p CalledFunctions set. If any are -/// found they are added to the \p Worklist for importing. -static void findExternalCalls(const Module &DestModule, Function &F, - const FunctionInfoIndex &Index, - StringSet<> &CalledFunctions, - SmallVector<StringRef, 64> &Worklist) { - // We need to suffix internal function calls imported from other modules, - // prepare the suffix ahead of time. - std::string Suffix; - if (F.getParent() != &DestModule) - Suffix = - (Twine(".llvm.") + - Twine(Index.getModuleId(F.getParent()->getModuleIdentifier()))).str(); - - for (auto &BB : F) { - for (auto &I : BB) { - if (isa<CallInst>(I)) { - auto CalledFunction = cast<CallInst>(I).getCalledFunction(); - // Insert any new external calls that have not already been - // added to set/worklist. - if (!CalledFunction || !CalledFunction->hasName()) - continue; - // Ignore intrinsics early - if (CalledFunction->isIntrinsic()) { - assert(CalledFunction->getIntrinsicID() != 0); - continue; - } - auto ImportedName = CalledFunction->getName(); - auto Renamed = (ImportedName + Suffix).str(); - // Rename internal functions - if (CalledFunction->hasInternalLinkage()) { - ImportedName = Renamed; - } - auto It = CalledFunctions.insert(ImportedName); - if (!It.second) { - // This is a call to a function we already considered, skip. 
- continue; - } - // Ignore functions already present in the destination module - auto *SrcGV = DestModule.getNamedValue(ImportedName); - if (SrcGV) { - if (GlobalAlias *SGA = dyn_cast<GlobalAlias>(SrcGV)) - SrcGV = SGA->getBaseObject(); - assert(isa<Function>(SrcGV) && "Name collision during import"); - if (!cast<Function>(SrcGV)->isDeclaration()) { - DEBUG(dbgs() << DestModule.getModuleIdentifier() << ": Ignoring " - << ImportedName << " already in DestinationModule\n"); - continue; - } +// Return true if \p GUID describes a GlobalValue that can be externally +// referenced, i.e. it does not need renaming (linkage is not local) or +// renaming is possible (does not have a section for instance). +static bool canBeExternallyReferenced(const ModuleSummaryIndex &Index, + GlobalValue::GUID GUID) { + auto Summaries = Index.findGlobalValueSummaryList(GUID); + if (Summaries == Index.end()) + return true; + if (Summaries->second.size() != 1) + // If there are multiple globals with this GUID, then we know it is + // not a local symbol, and it is necessarily externally referenced. + return true; + + // We don't need to check for the module path, because if it can't be + // externally referenced and we call it, it is necessarilly in the same + // module + return canBeExternallyReferenced(**Summaries->second.begin()); +} + +// Return true if the global described by \p Summary can be imported in another +// module. +static bool eligibleForImport(const ModuleSummaryIndex &Index, + const GlobalValueSummary &Summary) { + if (!canBeExternallyReferenced(Summary)) + // Can't import a global that needs renaming if has a section for instance. + // FIXME: we may be able to import it by copying it without promotion. + return false; + + // Check references (and potential calls) in the same module. If the current + // value references a global that can't be externally referenced it is not + // eligible for import. + bool AllRefsCanBeExternallyReferenced = + llvm::all_of(Summary.refs(), [&](const ValueInfo &VI) { + return canBeExternallyReferenced(Index, VI.getGUID()); + }); + if (!AllRefsCanBeExternallyReferenced) + return false; + + if (auto *FuncSummary = dyn_cast<FunctionSummary>(&Summary)) { + bool AllCallsCanBeExternallyReferenced = llvm::all_of( + FuncSummary->calls(), [&](const FunctionSummary::EdgeTy &Edge) { + return canBeExternallyReferenced(Index, Edge.first.getGUID()); + }); + if (!AllCallsCanBeExternallyReferenced) + return false; + } + return true; +} + +/// Given a list of possible callee implementation for a call site, select one +/// that fits the \p Threshold. +/// +/// FIXME: select "best" instead of first that fits. But what is "best"? +/// - The smallest: more likely to be inlined. +/// - The one with the least outgoing edges (already well optimized). +/// - One from a module already being imported from in order to reduce the +/// number of source modules parsed/linked. +/// - One that has PGO data attached. 
+/// - [insert you fancy metric here] +static const GlobalValueSummary * +selectCallee(const ModuleSummaryIndex &Index, + const GlobalValueSummaryList &CalleeSummaryList, + unsigned Threshold) { + auto It = llvm::find_if( + CalleeSummaryList, + [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) { + auto *GVSummary = SummaryPtr.get(); + if (GlobalValue::isInterposableLinkage(GVSummary->linkage())) + // There is no point in importing these, we can't inline them + return false; + if (auto *AS = dyn_cast<AliasSummary>(GVSummary)) { + GVSummary = &AS->getAliasee(); + // Alias can't point to "available_externally". However when we import + // linkOnceODR the linkage does not change. So we import the alias + // and aliasee only in this case. + // FIXME: we should import alias as available_externally *function*, + // the destination module does need to know it is an alias. + if (!GlobalValue::isLinkOnceODRLinkage(GVSummary->linkage())) + return false; } - Worklist.push_back(It.first->getKey()); - DEBUG(dbgs() << DestModule.getModuleIdentifier() - << ": Adding callee for : " << ImportedName << " : " - << F.getName() << "\n"); - } - } + auto *Summary = cast<FunctionSummary>(GVSummary); + + if (Summary->instCount() > Threshold) + return false; + + if (!eligibleForImport(Index, *Summary)) + return false; + + return true; + }); + if (It == CalleeSummaryList.end()) + return nullptr; + + return cast<GlobalValueSummary>(It->get()); +} + +/// Return the summary for the function \p GUID that fits the \p Threshold, or +/// null if there's no match. +static const GlobalValueSummary *selectCallee(GlobalValue::GUID GUID, + unsigned Threshold, + const ModuleSummaryIndex &Index) { + auto CalleeSummaryList = Index.findGlobalValueSummaryList(GUID); + if (CalleeSummaryList == Index.end()) + return nullptr; // This function does not have a summary + return selectCallee(Index, CalleeSummaryList->second, Threshold); +} + +/// Mark the global \p GUID as export by module \p ExportModulePath if found in +/// this module. If it is a GlobalVariable, we also mark any referenced global +/// in the current module as exported. +static void exportGlobalInModule(const ModuleSummaryIndex &Index, + StringRef ExportModulePath, + GlobalValue::GUID GUID, + FunctionImporter::ExportSetTy &ExportList) { + auto FindGlobalSummaryInModule = + [&](GlobalValue::GUID GUID) -> GlobalValueSummary *{ + auto SummaryList = Index.findGlobalValueSummaryList(GUID); + if (SummaryList == Index.end()) + // This global does not have a summary, it is not part of the ThinLTO + // process + return nullptr; + auto SummaryIter = llvm::find_if( + SummaryList->second, + [&](const std::unique_ptr<GlobalValueSummary> &Summary) { + return Summary->modulePath() == ExportModulePath; + }); + if (SummaryIter == SummaryList->second.end()) + return nullptr; + return SummaryIter->get(); + }; + + auto *Summary = FindGlobalSummaryInModule(GUID); + if (!Summary) + return; + // We found it in the current module, mark as exported + ExportList.insert(GUID); + + auto GVS = dyn_cast<GlobalVarSummary>(Summary); + if (!GVS) + return; + // FunctionImportGlobalProcessing::doPromoteLocalToGlobal() will always + // trigger importing the initializer for `constant unnamed addr` globals that + // are referenced. We conservatively export all the referenced symbols for + // every global to workaround this, so that the ExportList is accurate. + // FIXME: with a "isConstant" flag in the summary we could be more targetted. 
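As the FIXME above notes, selectCallee() currently takes the first candidate that fits rather than a "best" one. A simplified stand-alone illustration of that first-fit policy (hypothetical types, not the ThinLTO summary data structures):

#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct Candidate {
  std::string FromModule;
  unsigned InstCount;
  bool Interposable;  // e.g. weak definitions: importing them would not enable inlining
};

static std::optional<Candidate>
pickFirstFit(const std::vector<Candidate> &List, unsigned Threshold) {
  for (const Candidate &C : List)
    if (!C.Interposable && C.InstCount <= Threshold)
      return C;  // first candidate under the size threshold wins
  return std::nullopt;
}

int main() {
  std::vector<Candidate> List = {
      {"a.o", 250, false},  // too large for a 100-instruction limit
      {"b.o", 40, true},    // small enough, but interposable
      {"c.o", 60, false}};  // selected
  if (auto C = pickFirstFit(List, 100))
    std::cout << "import from " << C->FromModule << "\n";
}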
+ for (auto &Ref : GVS->refs()) { + auto GUID = Ref.getGUID(); + auto *RefSummary = FindGlobalSummaryInModule(GUID); + if (RefSummary) + // Found a ref in the current module, mark it as exported + ExportList.insert(GUID); } } -// Helper function: given a worklist and an index, will process all the worklist -// and decide what to import based on the summary information. -// -// Nothing is actually imported, functions are materialized in their source -// module and analyzed there. -// -// \p ModuleToFunctionsToImportMap is filled with the set of Function to import -// per Module. -static void GetImportList(Module &DestModule, - SmallVector<StringRef, 64> &Worklist, - StringSet<> &CalledFunctions, - std::map<StringRef, DenseSet<const GlobalValue *>> - &ModuleToFunctionsToImportMap, - const FunctionInfoIndex &Index, - ModuleLazyLoaderCache &ModuleLoaderCache) { - while (!Worklist.empty()) { - auto CalledFunctionName = Worklist.pop_back_val(); - DEBUG(dbgs() << DestModule.getModuleIdentifier() << ": Process import for " - << CalledFunctionName << "\n"); - - // Try to get a summary for this function call. - auto InfoList = Index.findFunctionInfoList(CalledFunctionName); - if (InfoList == Index.end()) { - DEBUG(dbgs() << DestModule.getModuleIdentifier() << ": No summary for " - << CalledFunctionName << " Ignoring.\n"); +using EdgeInfo = std::pair<const FunctionSummary *, unsigned /* Threshold */>; + +/// Compute the list of functions to import for a given caller. Mark these +/// imported functions and the symbols they reference in their source module as +/// exported from their source module. +static void computeImportForFunction( + const FunctionSummary &Summary, const ModuleSummaryIndex &Index, + unsigned Threshold, const GVSummaryMapTy &DefinedGVSummaries, + SmallVectorImpl<EdgeInfo> &Worklist, + FunctionImporter::ImportMapTy &ImportsForModule, + StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { + for (auto &Edge : Summary.calls()) { + auto GUID = Edge.first.getGUID(); + DEBUG(dbgs() << " edge -> " << GUID << " Threshold:" << Threshold << "\n"); + + if (DefinedGVSummaries.count(GUID)) { + DEBUG(dbgs() << "ignored! Target already in destination module.\n"); continue; } - assert(!InfoList->second.empty() && "No summary, error at import?"); - - // Comdat can have multiple entries, FIXME: what do we do with them? - auto &Info = InfoList->second[0]; - assert(Info && "Nullptr in list, error importing summaries?\n"); - - auto *Summary = Info->functionSummary(); - if (!Summary) { - // FIXME: in case we are lazyloading summaries, we can do it now. - DEBUG(dbgs() << DestModule.getModuleIdentifier() - << ": Missing summary for " << CalledFunctionName - << ", error at import?\n"); - llvm_unreachable("Missing summary"); - } - if (Summary->instCount() > ImportInstrLimit) { - DEBUG(dbgs() << DestModule.getModuleIdentifier() << ": Skip import of " - << CalledFunctionName << " with " << Summary->instCount() - << " instructions (limit " << ImportInstrLimit << ")\n"); + auto *CalleeSummary = selectCallee(GUID, Threshold, Index); + if (!CalleeSummary) { + DEBUG(dbgs() << "ignored! No qualifying callee with summary found.\n"); continue; } - - // Get the module path from the summary. - auto ModuleIdentifier = Summary->modulePath(); - DEBUG(dbgs() << DestModule.getModuleIdentifier() << ": Importing " - << CalledFunctionName << " from " << ModuleIdentifier << "\n"); - - auto &SrcModule = ModuleLoaderCache(ModuleIdentifier); - - // The function that we will import! 
- GlobalValue *SGV = SrcModule.getNamedValue(CalledFunctionName); - - if (!SGV) { - // The destination module is referencing function using their renamed name - // when importing a function that was originally local in the source - // module. The source module we have might not have been renamed so we try - // to remove the suffix added during the renaming to recover the original - // name in the source module. - std::pair<StringRef, StringRef> Split = - CalledFunctionName.split(".llvm."); - SGV = SrcModule.getNamedValue(Split.first); - assert(SGV && "Can't find function to import in source module"); + // "Resolve" the summary, traversing alias, + const FunctionSummary *ResolvedCalleeSummary; + if (isa<AliasSummary>(CalleeSummary)) { + ResolvedCalleeSummary = cast<FunctionSummary>( + &cast<AliasSummary>(CalleeSummary)->getAliasee()); + assert( + GlobalValue::isLinkOnceODRLinkage(ResolvedCalleeSummary->linkage()) && + "Unexpected alias to a non-linkonceODR in import list"); + } else + ResolvedCalleeSummary = cast<FunctionSummary>(CalleeSummary); + + assert(ResolvedCalleeSummary->instCount() <= Threshold && + "selectCallee() didn't honor the threshold"); + + auto ExportModulePath = ResolvedCalleeSummary->modulePath(); + auto &ProcessedThreshold = ImportsForModule[ExportModulePath][GUID]; + /// Since the traversal of the call graph is DFS, we can revisit a function + /// a second time with a higher threshold. In this case, it is added back to + /// the worklist with the new threshold. + if (ProcessedThreshold && ProcessedThreshold >= Threshold) { + DEBUG(dbgs() << "ignored! Target was already seen with Threshold " + << ProcessedThreshold << "\n"); + continue; } - if (!SGV) { - report_fatal_error(Twine("Can't load function '") + CalledFunctionName + - "' in Module '" + SrcModule.getModuleIdentifier() + - "', error in the summary?\n"); + // Mark this function as imported in this module, with the current Threshold + ProcessedThreshold = Threshold; + + // Make exports in the source module. + if (ExportLists) { + auto &ExportList = (*ExportLists)[ExportModulePath]; + ExportList.insert(GUID); + // Mark all functions and globals referenced by this function as exported + // to the outside if they are defined in the same source module. + for (auto &Edge : ResolvedCalleeSummary->calls()) { + auto CalleeGUID = Edge.first.getGUID(); + exportGlobalInModule(Index, ExportModulePath, CalleeGUID, ExportList); + } + for (auto &Ref : ResolvedCalleeSummary->refs()) { + auto GUID = Ref.getGUID(); + exportGlobalInModule(Index, ExportModulePath, GUID, ExportList); + } } - Function *F = dyn_cast<Function>(SGV); - if (!F && isa<GlobalAlias>(SGV)) { - auto *SGA = dyn_cast<GlobalAlias>(SGV); - F = dyn_cast<Function>(SGA->getBaseObject()); - CalledFunctionName = F->getName(); - } - assert(F && "Imported Function is ... not a Function"); - - // We cannot import weak_any functions/aliases without possibly affecting - // the order they are seen and selected by the linker, changing program - // semantics. - if (SGV->hasWeakAnyLinkage()) { - DEBUG(dbgs() << DestModule.getModuleIdentifier() - << ": Ignoring import request for weak-any " - << (isa<Function>(SGV) ? "function " : "alias ") - << CalledFunctionName << " from " - << SrcModule.getModuleIdentifier() << "\n"); + // Insert the newly imported function to the worklist. + Worklist.push_back(std::make_pair(ResolvedCalleeSummary, Threshold)); + } +} + +/// Given the list of globals defined in a module, compute the list of imports +/// as well as the list of "exports", i.e. 
the list of symbols referenced from +/// another module (that may require promotion). +static void ComputeImportForModule( + const GVSummaryMapTy &DefinedGVSummaries, const ModuleSummaryIndex &Index, + FunctionImporter::ImportMapTy &ImportsForModule, + StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { + // Worklist contains the list of function imported in this module, for which + // we will analyse the callees and may import further down the callgraph. + SmallVector<EdgeInfo, 128> Worklist; + + // Populate the worklist with the import for the functions in the current + // module + for (auto &GVSummary : DefinedGVSummaries) { + auto *Summary = GVSummary.second; + if (auto *AS = dyn_cast<AliasSummary>(Summary)) + Summary = &AS->getAliasee(); + auto *FuncSummary = dyn_cast<FunctionSummary>(Summary); + if (!FuncSummary) + // Skip import for global variables continue; - } + DEBUG(dbgs() << "Initalize import for " << GVSummary.first << "\n"); + computeImportForFunction(*FuncSummary, Index, ImportInstrLimit, + DefinedGVSummaries, Worklist, ImportsForModule, + ExportLists); + } - // Add the function to the import list - auto &Entry = ModuleToFunctionsToImportMap[SrcModule.getModuleIdentifier()]; - Entry.insert(F); + while (!Worklist.empty()) { + auto FuncInfo = Worklist.pop_back_val(); + auto *Summary = FuncInfo.first; + auto Threshold = FuncInfo.second; // Process the newly imported functions and add callees to the worklist. - F->materialize(); - findExternalCalls(DestModule, *F, Index, CalledFunctions, Worklist); + // Adjust the threshold + Threshold = Threshold * ImportInstrFactor; + + computeImportForFunction(*Summary, Index, Threshold, DefinedGVSummaries, + Worklist, ImportsForModule, ExportLists); } } +} // anonymous namespace + +/// Compute all the import and export for every module using the Index. +void llvm::ComputeCrossModuleImport( + const ModuleSummaryIndex &Index, + const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, + StringMap<FunctionImporter::ImportMapTy> &ImportLists, + StringMap<FunctionImporter::ExportSetTy> &ExportLists) { + // For each module that has function defined, compute the import/export lists. + for (auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) { + auto &ImportsForModule = ImportLists[DefinedGVSummaries.first()]; + DEBUG(dbgs() << "Computing import for Module '" + << DefinedGVSummaries.first() << "'\n"); + ComputeImportForModule(DefinedGVSummaries.second, Index, ImportsForModule, + &ExportLists); + } + +#ifndef NDEBUG + DEBUG(dbgs() << "Import/Export lists for " << ImportLists.size() + << " modules:\n"); + for (auto &ModuleImports : ImportLists) { + auto ModName = ModuleImports.first(); + auto &Exports = ExportLists[ModName]; + DEBUG(dbgs() << "* Module " << ModName << " exports " << Exports.size() + << " functions. Imports from " << ModuleImports.second.size() + << " modules.\n"); + for (auto &Src : ModuleImports.second) { + auto SrcModName = Src.first(); + DEBUG(dbgs() << " - " << Src.second.size() << " functions imported from " + << SrcModName << "\n"); + } + } +#endif +} + +/// Compute all the imports for the given module in the Index. +void llvm::ComputeCrossModuleImportForModule( + StringRef ModulePath, const ModuleSummaryIndex &Index, + FunctionImporter::ImportMapTy &ImportList) { + + // Collect the list of functions this module defines. + // GUID -> Summary + GVSummaryMapTy FunctionSummaryMap; + Index.collectDefinedFunctionsForModule(ModulePath, FunctionSummaryMap); + + // Compute the import list for this module. 
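With the defaults set up above (-import-instr-limit=100 and -import-instr-evolution-factor=0.7), the size limit shrinks geometrically as the import walk moves away from the original module; a quick back-of-the-envelope check:

#include <cstdio>

int main() {
  double Threshold = 100.0;   // -import-instr-limit default, applied to direct callees
  const double Factor = 0.7;  // -import-instr-evolution-factor default
  for (int Hop = 1; Hop <= 5; ++Hop) {
    std::printf("hop %d: import functions of at most ~%.0f instructions\n", Hop, Threshold);
    Threshold *= Factor;      // applied before processing each newly imported callee
  }
}

A function reached again along a cheaper path can still be re-queued with the higher threshold, which is what the ProcessedThreshold check above allows.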
+ DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n"); + ComputeImportForModule(FunctionSummaryMap, Index, ImportList); + +#ifndef NDEBUG + DEBUG(dbgs() << "* Module " << ModulePath << " imports from " + << ImportList.size() << " modules.\n"); + for (auto &Src : ImportList) { + auto SrcModName = Src.first(); + DEBUG(dbgs() << " - " << Src.second.size() << " functions imported from " + << SrcModName << "\n"); + } +#endif +} + +/// Compute the set of summaries needed for a ThinLTO backend compilation of +/// \p ModulePath. +void llvm::gatherImportedSummariesForModule( + StringRef ModulePath, + const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, + const StringMap<FunctionImporter::ImportMapTy> &ImportLists, + std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) { + // Include all summaries from the importing module. + ModuleToSummariesForIndex[ModulePath] = + ModuleToDefinedGVSummaries.lookup(ModulePath); + auto ModuleImports = ImportLists.find(ModulePath); + if (ModuleImports != ImportLists.end()) { + // Include summaries for imports. + for (auto &ILI : ModuleImports->second) { + auto &SummariesForIndex = ModuleToSummariesForIndex[ILI.first()]; + const auto &DefinedGVSummaries = + ModuleToDefinedGVSummaries.lookup(ILI.first()); + for (auto &GI : ILI.second) { + const auto &DS = DefinedGVSummaries.find(GI.first); + assert(DS != DefinedGVSummaries.end() && + "Expected a defined summary for imported global value"); + SummariesForIndex[GI.first] = DS->second; + } + } + } +} + +/// Emit the files \p ModulePath will import from into \p OutputFilename. +std::error_code llvm::EmitImportsFiles( + StringRef ModulePath, StringRef OutputFilename, + const StringMap<FunctionImporter::ImportMapTy> &ImportLists) { + auto ModuleImports = ImportLists.find(ModulePath); + std::error_code EC; + raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::F_None); + if (EC) + return EC; + if (ModuleImports != ImportLists.end()) + for (auto &ILI : ModuleImports->second) + ImportsOS << ILI.first() << "\n"; + return std::error_code(); +} + +/// Fixup WeakForLinker linkages in \p TheModule based on summary analysis. +void llvm::thinLTOResolveWeakForLinkerModule( + Module &TheModule, const GVSummaryMapTy &DefinedGlobals) { + auto updateLinkage = [&](GlobalValue &GV) { + if (!GlobalValue::isWeakForLinker(GV.getLinkage())) + return; + // See if the global summary analysis computed a new resolved linkage. + const auto &GS = DefinedGlobals.find(GV.getGUID()); + if (GS == DefinedGlobals.end()) + return; + auto NewLinkage = GS->second->linkage(); + if (NewLinkage == GV.getLinkage()) + return; + DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from " + << GV.getLinkage() << " to " << NewLinkage << "\n"); + GV.setLinkage(NewLinkage); + }; + + // Process functions and global now + for (auto &GV : TheModule) + updateLinkage(GV); + for (auto &GV : TheModule.globals()) + updateLinkage(GV); + for (auto &GV : TheModule.aliases()) + updateLinkage(GV); +} + +/// Run internalization on \p TheModule based on symmary analysis. +void llvm::thinLTOInternalizeModule(Module &TheModule, + const GVSummaryMapTy &DefinedGlobals) { + // Parse inline ASM and collect the list of symbols that are not defined in + // the current module. 
+ StringSet<> AsmUndefinedRefs; + object::IRObjectFile::CollectAsmUndefinedRefs( + Triple(TheModule.getTargetTriple()), TheModule.getModuleInlineAsm(), + [&AsmUndefinedRefs](StringRef Name, object::BasicSymbolRef::Flags Flags) { + if (Flags & object::BasicSymbolRef::SF_Undefined) + AsmUndefinedRefs.insert(Name); + }); + + // Declare a callback for the internalize pass that will ask for every + // candidate GlobalValue if it can be internalized or not. + auto MustPreserveGV = [&](const GlobalValue &GV) -> bool { + // Can't be internalized if referenced in inline asm. + if (AsmUndefinedRefs.count(GV.getName())) + return true; + + // Lookup the linkage recorded in the summaries during global analysis. + const auto &GS = DefinedGlobals.find(GV.getGUID()); + GlobalValue::LinkageTypes Linkage; + if (GS == DefinedGlobals.end()) { + // Must have been promoted (possibly conservatively). Find original + // name so that we can access the correct summary and see if it can + // be internalized again. + // FIXME: Eventually we should control promotion instead of promoting + // and internalizing again. + StringRef OrigName = + ModuleSummaryIndex::getOriginalNameBeforePromote(GV.getName()); + std::string OrigId = GlobalValue::getGlobalIdentifier( + OrigName, GlobalValue::InternalLinkage, + TheModule.getSourceFileName()); + const auto &GS = DefinedGlobals.find(GlobalValue::getGUID(OrigId)); + if (GS == DefinedGlobals.end()) { + // Also check the original non-promoted non-globalized name. In some + // cases a preempted weak value is linked in as a local copy because + // it is referenced by an alias (IRLinker::linkGlobalValueProto). + // In that case, since it was originally not a local value, it was + // recorded in the index using the original name. + // FIXME: This may not be needed once PR27866 is fixed. + const auto &GS = DefinedGlobals.find(GlobalValue::getGUID(OrigName)); + assert(GS != DefinedGlobals.end()); + Linkage = GS->second->linkage(); + } else { + Linkage = GS->second->linkage(); + } + } else + Linkage = GS->second->linkage(); + return !GlobalValue::isLocalLinkage(Linkage); + }; + + // FIXME: See if we can just internalize directly here via linkage changes + // based on the index, rather than invoking internalizeModule. + llvm::internalizeModule(TheModule, MustPreserveGV); +} + // Automatically import functions in Module \p DestModule based on the summaries // index. // -// The current implementation imports every called functions that exists in the -// summaries index. -bool FunctionImporter::importFunctions(Module &DestModule) { +bool FunctionImporter::importFunctions( + Module &DestModule, const FunctionImporter::ImportMapTy &ImportList, + bool ForceImportReferencedDiscardableSymbols) { DEBUG(dbgs() << "Starting import for Module " << DestModule.getModuleIdentifier() << "\n"); unsigned ImportedCount = 0; - /// First step is collecting the called external functions. - StringSet<> CalledFunctions; - SmallVector<StringRef, 64> Worklist; - for (auto &F : DestModule) { - if (F.isDeclaration() || F.hasFnAttribute(Attribute::OptimizeNone)) - continue; - findExternalCalls(DestModule, F, Index, CalledFunctions, Worklist); - } - if (Worklist.empty()) - return false; - - /// Second step: for every call to an external function, try to import it. 
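thinLTOInternalizeModule() above boils down to a per-global predicate: keep anything referenced from inline asm or whose summary-resolved linkage is not local, and internalize the rest. A simplified model of that decision (invented names, ignoring the promoted-name fallback handled above):

#include <iostream>
#include <map>
#include <set>
#include <string>

enum class Linkage { Internal, External };

int main() {
  std::set<std::string> AsmUndefinedRefs = {"used_from_asm"};
  std::map<std::string, Linkage> SummaryLinkage = {
      {"helper", Linkage::Internal},
      {"api_entry", Linkage::External},
      {"used_from_asm", Linkage::Internal}};

  auto MustPreserve = [&](const std::string &Name) {
    if (AsmUndefinedRefs.count(Name))
      return true;  // referenced from inline asm: cannot internalize
    auto It = SummaryLinkage.find(Name);
    // No summary entry: keep it conservatively (the real pass looks up the
    // original pre-promotion name before deciding).
    return It == SummaryLinkage.end() || It->second != Linkage::Internal;
  };

  for (const auto &KV : SummaryLinkage)
    std::cout << KV.first << ": " << (MustPreserve(KV.first) ? "keep" : "internalize") << "\n";
}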
- // Linker that will be used for importing function Linker TheLinker(DestModule); - - // Map of Module -> List of Function to import from the Module - std::map<StringRef, DenseSet<const GlobalValue *>> - ModuleToFunctionsToImportMap; - - // Analyze the summaries and get the list of functions to import by - // populating ModuleToFunctionsToImportMap - ModuleLazyLoaderCache ModuleLoaderCache(ModuleLoader); - GetImportList(DestModule, Worklist, CalledFunctions, - ModuleToFunctionsToImportMap, Index, ModuleLoaderCache); - assert(Worklist.empty() && "Worklist hasn't been flushed in GetImportList"); - - StringMap<std::unique_ptr<DenseMap<unsigned, MDNode *>>> - ModuleToTempMDValsMap; - // Do the actual import of functions now, one Module at a time - for (auto &FunctionsToImportPerModule : ModuleToFunctionsToImportMap) { + std::set<StringRef> ModuleNameOrderedList; + for (auto &FunctionsToImportPerModule : ImportList) { + ModuleNameOrderedList.insert(FunctionsToImportPerModule.first()); + } + for (auto &Name : ModuleNameOrderedList) { // Get the module for the import - auto &FunctionsToImport = FunctionsToImportPerModule.second; - std::unique_ptr<Module> SrcModule = - ModuleLoaderCache.takeModule(FunctionsToImportPerModule.first); + const auto &FunctionsToImportPerModule = ImportList.find(Name); + assert(FunctionsToImportPerModule != ImportList.end()); + std::unique_ptr<Module> SrcModule = ModuleLoader(Name); assert(&DestModule.getContext() == &SrcModule->getContext() && "Context mismatch"); - // Save the mapping of value ids to temporary metadata created when - // importing this function. If we have already imported from this module, - // add new temporary metadata to the existing mapping. - auto &TempMDVals = ModuleToTempMDValsMap[SrcModule->getModuleIdentifier()]; - if (!TempMDVals) - TempMDVals = llvm::make_unique<DenseMap<unsigned, MDNode *>>(); + // If modules were created with lazy metadata loading, materialize it + // now, before linking it (otherwise this will be a noop). + SrcModule->materializeMetadata(); + UpgradeDebugInfo(*SrcModule); + + auto &ImportGUIDs = FunctionsToImportPerModule->second; + // Find the globals to import + DenseSet<const GlobalValue *> GlobalsToImport; + for (Function &F : *SrcModule) { + if (!F.hasName()) + continue; + auto GUID = F.getGUID(); + auto Import = ImportGUIDs.count(GUID); + DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing function " << GUID + << " " << F.getName() << " from " + << SrcModule->getSourceFileName() << "\n"); + if (Import) { + F.materialize(); + if (EnableImportMetadata) { + // Add 'thinlto_src_module' metadata for statistics and debugging. + F.setMetadata( + "thinlto_src_module", + llvm::MDNode::get( + DestModule.getContext(), + {llvm::MDString::get(DestModule.getContext(), + SrcModule->getSourceFileName())})); + } + GlobalsToImport.insert(&F); + } + } + for (GlobalVariable &GV : SrcModule->globals()) { + if (!GV.hasName()) + continue; + auto GUID = GV.getGUID(); + auto Import = ImportGUIDs.count(GUID); + DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing global " << GUID + << " " << GV.getName() << " from " + << SrcModule->getSourceFileName() << "\n"); + if (Import) { + GV.materialize(); + GlobalsToImport.insert(&GV); + } + } + for (GlobalAlias &GA : SrcModule->aliases()) { + if (!GA.hasName()) + continue; + auto GUID = GA.getGUID(); + auto Import = ImportGUIDs.count(GUID); + DEBUG(dbgs() << (Import ? 
"Is" : "Not") << " importing alias " << GUID + << " " << GA.getName() << " from " + << SrcModule->getSourceFileName() << "\n"); + if (Import) { + // Alias can't point to "available_externally". However when we import + // linkOnceODR the linkage does not change. So we import the alias + // and aliasee only in this case. This has been handled by + // computeImportForFunction() + GlobalObject *GO = GA.getBaseObject(); + assert(GO->hasLinkOnceODRLinkage() && + "Unexpected alias to a non-linkonceODR in import list"); +#ifndef NDEBUG + if (!GlobalsToImport.count(GO)) + DEBUG(dbgs() << " alias triggers importing aliasee " << GO->getGUID() + << " " << GO->getName() << " from " + << SrcModule->getSourceFileName() << "\n"); +#endif + GO->materialize(); + GlobalsToImport.insert(GO); + GA.materialize(); + GlobalsToImport.insert(&GA); + } + } // Link in the specified functions. - if (TheLinker.linkInModule(std::move(SrcModule), Linker::Flags::None, - &Index, &FunctionsToImport, TempMDVals.get())) + if (renameModuleForThinLTO(*SrcModule, Index, &GlobalsToImport)) + return true; + + if (PrintImports) { + for (const auto *GV : GlobalsToImport) + dbgs() << DestModule.getSourceFileName() << ": Import " << GV->getName() + << " from " << SrcModule->getSourceFileName() << "\n"; + } + + // Instruct the linker that the client will take care of linkonce resolution + unsigned Flags = Linker::Flags::None; + if (!ForceImportReferencedDiscardableSymbols) + Flags |= Linker::Flags::DontForceLinkLinkonceODR; + + if (TheLinker.linkInModule(std::move(SrcModule), Flags, &GlobalsToImport)) report_fatal_error("Function Import: link error"); - ImportedCount += FunctionsToImport.size(); + ImportedCount += GlobalsToImport.size(); } - // Now link in metadata for all modules from which we imported functions. - for (StringMapEntry<std::unique_ptr<DenseMap<unsigned, MDNode *>>> &SME : - ModuleToTempMDValsMap) { - // Load the specified source module. - auto &SrcModule = ModuleLoaderCache(SME.getKey()); - // The modules were created with lazy metadata loading. Materialize it - // now, before linking it. - SrcModule.materializeMetadata(); - UpgradeDebugInfo(SrcModule); - - // Link in all necessary metadata from this module. - if (TheLinker.linkInMetadata(SrcModule, SME.getValue().get())) - return false; - } + NumImported += ImportedCount; DEBUG(dbgs() << "Imported " << ImportedCount << " functions for Module " << DestModule.getModuleIdentifier() << "\n"); @@ -355,11 +696,11 @@ static void diagnosticHandler(const DiagnosticInfo &DI) { OS << '\n'; } -/// Parse the function index out of an IR file and return the function +/// Parse the summary index out of an IR file and return the summary /// index object if found, or nullptr if not. 
-static std::unique_ptr<FunctionInfoIndex> -getFunctionIndexForFile(StringRef Path, std::string &Error, - DiagnosticHandlerFunction DiagnosticHandler) { +static std::unique_ptr<ModuleSummaryIndex> getModuleSummaryIndexForFile( + StringRef Path, std::string &Error, + const DiagnosticHandlerFunction &DiagnosticHandler) { std::unique_ptr<MemoryBuffer> Buffer; ErrorOr<std::unique_ptr<MemoryBuffer>> BufferOrErr = MemoryBuffer::getFile(Path); @@ -368,9 +709,9 @@ getFunctionIndexForFile(StringRef Path, std::string &Error, return nullptr; } Buffer = std::move(BufferOrErr.get()); - ErrorOr<std::unique_ptr<object::FunctionIndexObjectFile>> ObjOrErr = - object::FunctionIndexObjectFile::create(Buffer->getMemBufferRef(), - DiagnosticHandler); + ErrorOr<std::unique_ptr<object::ModuleSummaryIndexObjectFile>> ObjOrErr = + object::ModuleSummaryIndexObjectFile::create(Buffer->getMemBufferRef(), + DiagnosticHandler); if (std::error_code EC = ObjOrErr.getError()) { Error = EC.message(); return nullptr; @@ -381,32 +722,34 @@ getFunctionIndexForFile(StringRef Path, std::string &Error, namespace { /// Pass that performs cross-module function import provided a summary file. class FunctionImportPass : public ModulePass { - /// Optional function summary index to use for importing, otherwise + /// Optional module summary index to use for importing, otherwise /// the summary-file option must be specified. - const FunctionInfoIndex *Index; + const ModuleSummaryIndex *Index; public: /// Pass identification, replacement for typeid static char ID; /// Specify pass name for debug output - const char *getPassName() const override { - return "Function Importing"; - } + const char *getPassName() const override { return "Function Importing"; } - explicit FunctionImportPass(const FunctionInfoIndex *Index = nullptr) + explicit FunctionImportPass(const ModuleSummaryIndex *Index = nullptr) : ModulePass(ID), Index(Index) {} bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + if (SummaryFile.empty() && !Index) report_fatal_error("error: -function-import requires -summary-file or " "file from frontend\n"); - std::unique_ptr<FunctionInfoIndex> IndexPtr; + std::unique_ptr<ModuleSummaryIndex> IndexPtr; if (!SummaryFile.empty()) { if (Index) report_fatal_error("error: -summary-file and index from frontend\n"); std::string Error; - IndexPtr = getFunctionIndexForFile(SummaryFile, Error, diagnosticHandler); + IndexPtr = + getModuleSummaryIndexForFile(SummaryFile, Error, diagnosticHandler); if (!IndexPtr) { errs() << "Error loading file '" << SummaryFile << "': " << Error << "\n"; @@ -415,9 +758,14 @@ public: Index = IndexPtr.get(); } - // First we need to promote to global scope and rename any local values that + // First step is collecting the import list. + FunctionImporter::ImportMapTy ImportList; + ComputeCrossModuleImportForModule(M.getModuleIdentifier(), *Index, + ImportList); + + // Next we need to promote to global scope and rename any local values that // are potentially exported to other modules. 
- if (renameModuleForThinLTO(M, Index)) { + if (renameModuleForThinLTO(M, *Index, nullptr)) { errs() << "Error renaming module\n"; return false; } @@ -427,7 +775,8 @@ public: return loadFile(Identifier, M.getContext()); }; FunctionImporter Importer(*Index, ModuleLoader); - return Importer.importFunctions(M); + return Importer.importFunctions( + M, ImportList, !DontForceImportReferencedDiscardableSymbols); } }; } // anonymous namespace @@ -439,7 +788,7 @@ INITIALIZE_PASS_END(FunctionImportPass, "function-import", "Summary Based Function Import", false, false) namespace llvm { -Pass *createFunctionImportPass(const FunctionInfoIndex *Index = nullptr) { +Pass *createFunctionImportPass(const ModuleSummaryIndex *Index = nullptr) { return new FunctionImportPass(Index); } } diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp index 9b276ed28e2e..4c74698a1b61 100644 --- a/lib/Transforms/IPO/GlobalDCE.cpp +++ b/lib/Transforms/IPO/GlobalDCE.cpp @@ -15,15 +15,16 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/GlobalDCE.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/CtorUtils.h" #include "llvm/Transforms/Utils/GlobalStatus.h" -#include "llvm/Pass.h" #include <unordered_map> using namespace llvm; @@ -31,32 +32,41 @@ using namespace llvm; STATISTIC(NumAliases , "Number of global aliases removed"); STATISTIC(NumFunctions, "Number of functions removed"); +STATISTIC(NumIFuncs, "Number of indirect functions removed"); STATISTIC(NumVariables, "Number of global variables removed"); namespace { - struct GlobalDCE : public ModulePass { + class GlobalDCELegacyPass : public ModulePass { + public: static char ID; // Pass identification, replacement for typeid - GlobalDCE() : ModulePass(ID) { - initializeGlobalDCEPass(*PassRegistry::getPassRegistry()); + GlobalDCELegacyPass() : ModulePass(ID) { + initializeGlobalDCELegacyPassPass(*PassRegistry::getPassRegistry()); } // run - Do the GlobalDCE pass on the specified module, optionally updating // the specified callgraph to reflect the changes. // - bool runOnModule(Module &M) override; + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + + ModuleAnalysisManager DummyMAM; + auto PA = Impl.run(M, DummyMAM); + return !PA.areAllPreserved(); + } private: - SmallPtrSet<GlobalValue*, 32> AliveGlobals; - SmallPtrSet<Constant *, 8> SeenConstants; - std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; + GlobalDCEPass Impl; + }; +} - /// GlobalIsNeeded - mark the specific global value as needed, and - /// recursively mark anything that it uses as also needed. - void GlobalIsNeeded(GlobalValue *GV); - void MarkUsedGlobalsAsNeeded(Constant *C); +char GlobalDCELegacyPass::ID = 0; +INITIALIZE_PASS(GlobalDCELegacyPass, "globaldce", + "Dead Global Elimination", false, false) - bool RemoveUnusedGlobalValue(GlobalValue &GV); - }; +// Public interface to the GlobalDCEPass. +ModulePass *llvm::createGlobalDCEPass() { + return new GlobalDCELegacyPass(); } /// Returns true if F contains only a single "ret" instruction. 
@@ -68,13 +78,7 @@ static bool isEmptyFunction(Function *F) { return RI.getReturnValue() == nullptr; } -char GlobalDCE::ID = 0; -INITIALIZE_PASS(GlobalDCE, "globaldce", - "Dead Global Elimination", false, false) - -ModulePass *llvm::createGlobalDCEPass() { return new GlobalDCE(); } - -bool GlobalDCE::runOnModule(Module &M) { +PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) { bool Changed = false; // Remove empty functions from the global ctors list. @@ -92,21 +96,14 @@ bool GlobalDCE::runOnModule(Module &M) { ComdatMembers.insert(std::make_pair(C, &GA)); // Loop over the module, adding globals which are obviously necessary. - for (Function &F : M) { - Changed |= RemoveUnusedGlobalValue(F); - // Functions with external linkage are needed if they have a body - if (!F.isDeclaration() && !F.hasAvailableExternallyLinkage()) - if (!F.isDiscardableIfUnused()) - GlobalIsNeeded(&F); - } - - for (GlobalVariable &GV : M.globals()) { - Changed |= RemoveUnusedGlobalValue(GV); + for (GlobalObject &GO : M.global_objects()) { + Changed |= RemoveUnusedGlobalValue(GO); + // Functions with external linkage are needed if they have a body. // Externally visible & appending globals are needed, if they have an // initializer. - if (!GV.isDeclaration() && !GV.hasAvailableExternallyLinkage()) - if (!GV.isDiscardableIfUnused()) - GlobalIsNeeded(&GV); + if (!GO.isDeclaration() && !GO.hasAvailableExternallyLinkage()) + if (!GO.isDiscardableIfUnused()) + GlobalIsNeeded(&GO); } for (GlobalAlias &GA : M.aliases()) { @@ -116,6 +113,13 @@ bool GlobalDCE::runOnModule(Module &M) { GlobalIsNeeded(&GA); } + for (GlobalIFunc &GIF : M.ifuncs()) { + Changed |= RemoveUnusedGlobalValue(GIF); + // Externally visible ifuncs are needed. + if (!GIF.isDiscardableIfUnused()) + GlobalIsNeeded(&GIF); + } + // Now that all globals which are needed are in the AliveGlobals set, we loop // through the program, deleting those which are not alive. // @@ -150,6 +154,14 @@ bool GlobalDCE::runOnModule(Module &M) { GA.setAliasee(nullptr); } + // The third pass drops targets of ifuncs which are dead... + std::vector<GlobalIFunc*> DeadIFuncs; + for (GlobalIFunc &GIF : M.ifuncs()) + if (!AliveGlobals.count(&GIF)) { + DeadIFuncs.push_back(&GIF); + GIF.setResolver(nullptr); + } + if (!DeadFunctions.empty()) { // Now that all interferences have been dropped, delete the actual objects // themselves. @@ -180,17 +192,29 @@ bool GlobalDCE::runOnModule(Module &M) { Changed = true; } + // Now delete any dead aliases. + if (!DeadIFuncs.empty()) { + for (GlobalIFunc *GIF : DeadIFuncs) { + RemoveUnusedGlobalValue(*GIF); + M.getIFuncList().erase(GIF); + } + NumIFuncs += DeadIFuncs.size(); + Changed = true; + } + // Make sure that all memory is released AliveGlobals.clear(); SeenConstants.clear(); ComdatMembers.clear(); - return Changed; + if (Changed) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); } /// GlobalIsNeeded - the specific global value as needed, and /// recursively mark anything that it uses as also needed. -void GlobalDCE::GlobalIsNeeded(GlobalValue *G) { +void GlobalDCEPass::GlobalIsNeeded(GlobalValue *G) { // If the global is already in the set, no need to reprocess it. if (!AliveGlobals.insert(G).second) return; @@ -205,9 +229,9 @@ void GlobalDCE::GlobalIsNeeded(GlobalValue *G) { // referenced by the initializer to the alive set. 
if (GV->hasInitializer()) MarkUsedGlobalsAsNeeded(GV->getInitializer()); - } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(G)) { - // The target of a global alias is needed. - MarkUsedGlobalsAsNeeded(GA->getAliasee()); + } else if (GlobalIndirectSymbol *GIS = dyn_cast<GlobalIndirectSymbol>(G)) { + // The target of a global alias or ifunc is needed. + MarkUsedGlobalsAsNeeded(GIS->getIndirectSymbol()); } else { // Otherwise this must be a function object. We have to scan the body of // the function looking for constants and global values which are used as @@ -228,7 +252,7 @@ void GlobalDCE::GlobalIsNeeded(GlobalValue *G) { } } -void GlobalDCE::MarkUsedGlobalsAsNeeded(Constant *C) { +void GlobalDCEPass::MarkUsedGlobalsAsNeeded(Constant *C) { if (GlobalValue *GV = dyn_cast<GlobalValue>(C)) return GlobalIsNeeded(GV); @@ -248,7 +272,7 @@ void GlobalDCE::MarkUsedGlobalsAsNeeded(Constant *C) { // so, nuke it. This will reduce the reference count on the global value, which // might make it deader. // -bool GlobalDCE::RemoveUnusedGlobalValue(GlobalValue &GV) { +bool GlobalDCEPass::RemoveUnusedGlobalValue(GlobalValue &GV) { if (GV.use_empty()) return false; GV.removeDeadConstantUsers(); diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp index fd7736905fe8..310c29275faf 100644 --- a/lib/Transforms/IPO/GlobalOpt.cpp +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -13,7 +13,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/GlobalOpt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -40,11 +40,11 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/CtorUtils.h" +#include "llvm/Transforms/Utils/Evaluator.h" #include "llvm/Transforms/Utils/GlobalStatus.h" -#include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> -#include <deque> using namespace llvm; #define DEBUG_TYPE "globalopt" @@ -65,46 +65,6 @@ STATISTIC(NumAliasesResolved, "Number of global aliases resolved"); STATISTIC(NumAliasesRemoved, "Number of global aliases eliminated"); STATISTIC(NumCXXDtorsRemoved, "Number of global C++ destructors removed"); -namespace { - struct GlobalOpt : public ModulePass { - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - } - static char ID; // Pass identification, replacement for typeid - GlobalOpt() : ModulePass(ID) { - initializeGlobalOptPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - - private: - bool OptimizeFunctions(Module &M); - bool OptimizeGlobalVars(Module &M); - bool OptimizeGlobalAliases(Module &M); - bool deleteIfDead(GlobalValue &GV); - bool processGlobal(GlobalValue &GV); - bool processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS); - bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn); - - bool isPointerValueDeadOnEntryToFunction(const Function *F, - GlobalValue *GV); - - TargetLibraryInfo *TLI; - SmallSet<const Comdat *, 8> NotDiscardableComdats; - }; -} - -char GlobalOpt::ID = 0; -INITIALIZE_PASS_BEGIN(GlobalOpt, "globalopt", - "Global Variable Optimizer", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(GlobalOpt, 
"globalopt", - "Global Variable Optimizer", false, false) - -ModulePass *llvm::createGlobalOptimizerPass() { return new GlobalOpt(); } - /// Is this global variable possibly used by a leak checker as a root? If so, /// we might not really want to eliminate the stores to it. static bool isLeakCheckerRoot(GlobalVariable *GV) { @@ -120,7 +80,7 @@ static bool isLeakCheckerRoot(GlobalVariable *GV) { return false; SmallVector<Type *, 4> Types; - Types.push_back(cast<PointerType>(GV->getType())->getElementType()); + Types.push_back(GV->getValueType()); unsigned Limit = 20; do { @@ -329,7 +289,7 @@ static bool CleanupConstantGlobalUsers(Value *V, Constant *Init, // we already know what the result of any load from that GEP is. // TODO: Handle splats. if (Init && isa<ConstantAggregateZero>(Init) && GEP->isInBounds()) - SubInit = Constant::getNullValue(GEP->getType()->getElementType()); + SubInit = Constant::getNullValue(GEP->getResultElementType()); } Changed |= CleanupConstantGlobalUsers(GEP, SubInit, DL, TLI); @@ -475,7 +435,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { if (!GlobalUsersSafeToSRA(GV)) return nullptr; - assert(GV->hasLocalLinkage() && !GV->isConstant()); + assert(GV->hasLocalLinkage()); Constant *Init = GV->getInitializer(); Type *Ty = Init->getType(); @@ -499,6 +459,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { GV->getThreadLocalMode(), GV->getType()->getAddressSpace()); NGV->setExternallyInitialized(GV->isExternallyInitialized()); + NGV->copyAttributesFrom(GV); Globals.push_back(NGV); NewGlobals.push_back(NGV); @@ -533,6 +494,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { GV->getThreadLocalMode(), GV->getType()->getAddressSpace()); NGV->setExternallyInitialized(GV->isExternallyInitialized()); + NGV->copyAttributesFrom(GV); Globals.push_back(NGV); NewGlobals.push_back(NGV); @@ -867,9 +829,8 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, } Constant *RepValue = NewGV; - if (NewGV->getType() != GV->getType()->getElementType()) - RepValue = ConstantExpr::getBitCast(RepValue, - GV->getType()->getElementType()); + if (NewGV->getType() != GV->getValueType()) + RepValue = ConstantExpr::getBitCast(RepValue, GV->getValueType()); // If there is a comparison against null, we will insert a global bool to // keep track of whether the global was initialized yet or not. 
@@ -1283,6 +1244,9 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, std::vector<Value*> FieldGlobals; std::vector<Value*> FieldMallocs; + SmallVector<OperandBundleDef, 1> OpBundles; + CI->getOperandBundlesAsDefs(OpBundles); + unsigned AS = GV->getType()->getPointerAddressSpace(); for (unsigned FieldNo = 0, e = STy->getNumElements(); FieldNo != e;++FieldNo){ Type *FieldTy = STy->getElementType(FieldNo); @@ -1292,6 +1256,7 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, *GV->getParent(), PFieldTy, false, GlobalValue::InternalLinkage, Constant::getNullValue(PFieldTy), GV->getName() + ".f" + Twine(FieldNo), nullptr, GV->getThreadLocalMode()); + NGV->copyAttributesFrom(GV); FieldGlobals.push_back(NGV); unsigned TypeSize = DL.getTypeAllocSize(FieldTy); @@ -1300,7 +1265,7 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, Type *IntPtrTy = DL.getIntPtrType(CI->getType()); Value *NMI = CallInst::CreateMalloc(CI, IntPtrTy, FieldTy, ConstantInt::get(IntPtrTy, TypeSize), - NElems, nullptr, + NElems, OpBundles, nullptr, CI->getName() + ".f" + Twine(FieldNo)); FieldMallocs.push_back(NMI); new StoreInst(NMI, NGV, CI); @@ -1359,7 +1324,7 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, Cmp, NullPtrBlock); // Fill in FreeBlock. - CallInst::CreateFree(GVVal, BI); + CallInst::CreateFree(GVVal, OpBundles, BI); new StoreInst(Constant::getNullValue(GVVal->getType()), FieldGlobals[i], FreeBlock); BranchInst::Create(NextBlock, FreeBlock); @@ -1397,8 +1362,8 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, // Insert a store of null into each global. for (unsigned i = 0, e = FieldGlobals.size(); i != e; ++i) { - PointerType *PT = cast<PointerType>(FieldGlobals[i]->getType()); - Constant *Null = Constant::getNullValue(PT->getElementType()); + Type *ValTy = cast<GlobalValue>(FieldGlobals[i])->getValueType(); + Constant *Null = Constant::getNullValue(ValTy); new StoreInst(Null, FieldGlobals[i], SI); } // Erase the original store. @@ -1500,7 +1465,7 @@ static bool tryToOptimizeStoreOfMallocToGlobal(GlobalVariable *GV, CallInst *CI, // into multiple malloc'd arrays, one for each field. This is basically // SRoA for malloc'd memory. - if (Ordering != NotAtomic) + if (Ordering != AtomicOrdering::NotAtomic) return false; // If this is an allocation of a fixed size array of structs, analyze as a @@ -1525,9 +1490,11 @@ static bool tryToOptimizeStoreOfMallocToGlobal(GlobalVariable *GV, CallInst *CI, unsigned TypeSize = DL.getStructLayout(AllocSTy)->getSizeInBytes(); Value *AllocSize = ConstantInt::get(IntPtrTy, TypeSize); Value *NumElements = ConstantInt::get(IntPtrTy, AT->getNumElements()); - Instruction *Malloc = CallInst::CreateMalloc(CI, IntPtrTy, AllocSTy, - AllocSize, NumElements, - nullptr, CI->getName()); + SmallVector<OperandBundleDef, 1> OpBundles; + CI->getOperandBundlesAsDefs(OpBundles); + Instruction *Malloc = + CallInst::CreateMalloc(CI, IntPtrTy, AllocSTy, AllocSize, NumElements, + OpBundles, nullptr, CI->getName()); Instruction *Cast = new BitCastInst(Malloc, CI->getType(), "tmp", CI); CI->replaceAllUsesWith(Cast); CI->eraseFromParent(); @@ -1583,7 +1550,7 @@ static bool optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, /// boolean and select between the two values whenever it is used. This exposes /// the values to other scalar optimizations. 
static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { - Type *GVElType = GV->getType()->getElementType(); + Type *GVElType = GV->getValueType(); // If GVElType is already i1, it is already shrunk. If the type of the GV is // an FP value, pointer or vector, don't do this optimization because a select @@ -1611,6 +1578,7 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { GV->getName()+".b", GV->getThreadLocalMode(), GV->getType()->getAddressSpace()); + NewGV->copyAttributesFrom(GV); GV->getParent()->getGlobalList().insert(GV->getIterator(), NewGV); Constant *InitVal = GV->getInitializer(); @@ -1679,7 +1647,8 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { return true; } -bool GlobalOpt::deleteIfDead(GlobalValue &GV) { +static bool deleteIfDead(GlobalValue &GV, + SmallSet<const Comdat *, 8> &NotDiscardableComdats) { GV.removeDeadConstantUsers(); if (!GV.isDiscardableIfUnused()) @@ -1703,36 +1672,9 @@ bool GlobalOpt::deleteIfDead(GlobalValue &GV) { return true; } -/// Analyze the specified global variable and optimize it if possible. If we -/// make a change, return true. -bool GlobalOpt::processGlobal(GlobalValue &GV) { - // Do more involved optimizations if the global is internal. - if (!GV.hasLocalLinkage()) - return false; - - GlobalStatus GS; - - if (GlobalStatus::analyzeGlobal(&GV, GS)) - return false; - - bool Changed = false; - if (!GS.IsCompared && !GV.hasUnnamedAddr()) { - GV.setUnnamedAddr(true); - NumUnnamed++; - Changed = true; - } - - auto *GVar = dyn_cast<GlobalVariable>(&GV); - if (!GVar) - return Changed; - - if (GVar->isConstant() || !GVar->hasInitializer()) - return Changed; - - return processInternalGlobal(GVar, GS) || Changed; -} - -bool GlobalOpt::isPointerValueDeadOnEntryToFunction(const Function *F, GlobalValue *GV) { +static bool isPointerValueDeadOnEntryToFunction( + const Function *F, GlobalValue *GV, + function_ref<DominatorTree &(Function &)> LookupDomTree) { // Find all uses of GV. We expect them all to be in F, and if we can't // identify any of the uses we bail out. // @@ -1776,8 +1718,7 @@ bool GlobalOpt::isPointerValueDeadOnEntryToFunction(const Function *F, GlobalVal // of them are known not to depend on the value of the global at the function // entry point. We do this by ensuring that every load is dominated by at // least one store. - auto &DT = getAnalysis<DominatorTreeWrapperPass>(*const_cast<Function *>(F)) - .getDomTree(); + auto &DT = LookupDomTree(*const_cast<Function *>(F)); // The below check is quadratic. Check we're not going to do too many tests. // FIXME: Even though this will always have worst-case quadratic time, we @@ -1866,8 +1807,9 @@ static void makeAllConstantUsesInstructions(Constant *C) { /// Analyze the specified global variable and optimize /// it if possible. If we make a change, return true. -bool GlobalOpt::processInternalGlobal(GlobalVariable *GV, - const GlobalStatus &GS) { +static bool processInternalGlobal( + GlobalVariable *GV, const GlobalStatus &GS, TargetLibraryInfo *TLI, + function_ref<DominatorTree &(Function &)> LookupDomTree) { auto &DL = GV->getParent()->getDataLayout(); // If this is a first class global and has only one accessing function and // this function is non-recursive, we replace the global with a local alloca @@ -1879,16 +1821,17 @@ bool GlobalOpt::processInternalGlobal(GlobalVariable *GV, // If the global is in different address space, don't bring it to stack. 
if (!GS.HasMultipleAccessingFunctions && GS.AccessingFunction && - GV->getType()->getElementType()->isSingleValueType() && + GV->getValueType()->isSingleValueType() && GV->getType()->getAddressSpace() == 0 && !GV->isExternallyInitialized() && allNonInstructionUsersCanBeMadeInstructions(GV) && GS.AccessingFunction->doesNotRecurse() && - isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV) ) { + isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV, + LookupDomTree)) { DEBUG(dbgs() << "LOCALIZING GLOBAL: " << *GV << "\n"); Instruction &FirstI = const_cast<Instruction&>(*GS.AccessingFunction ->getEntryBlock().begin()); - Type *ElemTy = GV->getType()->getElementType(); + Type *ElemTy = GV->getValueType(); // FIXME: Pass Global's alignment when globals have alignment AllocaInst *Alloca = new AllocaInst(ElemTy, nullptr, GV->getName(), &FirstI); @@ -1896,7 +1839,7 @@ bool GlobalOpt::processInternalGlobal(GlobalVariable *GV, new StoreInst(GV->getInitializer(), Alloca, &FirstI); makeAllConstantUsesInstructions(GV); - + GV->replaceAllUsesWith(Alloca); GV->eraseFromParent(); ++NumLocalized; @@ -1926,7 +1869,8 @@ bool GlobalOpt::processInternalGlobal(GlobalVariable *GV, } return Changed; - } else if (GS.StoredType <= GlobalStatus::InitializerStored) { + } + if (GS.StoredType <= GlobalStatus::InitializerStored) { DEBUG(dbgs() << "MARKING CONSTANT: " << *GV << "\n"); GV->setConstant(true); @@ -1939,15 +1883,18 @@ bool GlobalOpt::processInternalGlobal(GlobalVariable *GV, << "all users and delete global!\n"); GV->eraseFromParent(); ++NumDeleted; + return true; } + // Fall through to the next check; see if we can optimize further. ++NumMarked; - return true; - } else if (!GV->getInitializer()->getType()->isSingleValueType()) { + } + if (!GV->getInitializer()->getType()->isSingleValueType()) { const DataLayout &DL = GV->getParent()->getDataLayout(); if (SRAGlobal(GV, DL)) return true; - } else if (GS.StoredType == GlobalStatus::StoredOnce && GS.StoredOnceValue) { + } + if (GS.StoredType == GlobalStatus::StoredOnce && GS.StoredOnceValue) { // If the initial value for the global was an undef value, and if only // one other value was stored into it, we can just change the // initializer to be the stored value, then delete all stores to the @@ -1978,7 +1925,7 @@ bool GlobalOpt::processInternalGlobal(GlobalVariable *GV, // Otherwise, if the global was not a boolean, we can shrink it to be a // boolean. if (Constant *SOVConstant = dyn_cast<Constant>(GS.StoredOnceValue)) { - if (GS.Ordering == NotAtomic) { + if (GS.Ordering == AtomicOrdering::NotAtomic) { if (TryToShrinkGlobalToBoolean(GV, SOVConstant)) { ++NumShrunkToBool; return true; @@ -1990,6 +1937,44 @@ bool GlobalOpt::processInternalGlobal(GlobalVariable *GV, return false; } +/// Analyze the specified global variable and optimize it if possible. If we +/// make a change, return true. +static bool +processGlobal(GlobalValue &GV, TargetLibraryInfo *TLI, + function_ref<DominatorTree &(Function &)> LookupDomTree) { + if (GV.getName().startswith("llvm.")) + return false; + + GlobalStatus GS; + + if (GlobalStatus::analyzeGlobal(&GV, GS)) + return false; + + bool Changed = false; + if (!GS.IsCompared && !GV.hasGlobalUnnamedAddr()) { + auto NewUnnamedAddr = GV.hasLocalLinkage() ? GlobalValue::UnnamedAddr::Global + : GlobalValue::UnnamedAddr::Local; + if (NewUnnamedAddr != GV.getUnnamedAddr()) { + GV.setUnnamedAddr(NewUnnamedAddr); + NumUnnamed++; + Changed = true; + } + } + + // Do more involved optimizations if the global is internal. 
+ if (!GV.hasLocalLinkage()) + return Changed; + + auto *GVar = dyn_cast<GlobalVariable>(&GV); + if (!GVar) + return Changed; + + if (GVar->isConstant() || !GVar->hasInitializer()) + return Changed; + + return processInternalGlobal(GVar, GS, TLI, LookupDomTree) || Changed; +} + /// Walk all of the direct calls of the specified function, changing them to /// FastCC. static void ChangeCalleesToFastCall(Function *F) { @@ -2034,7 +2019,10 @@ static bool isProfitableToMakeFastCC(Function *F) { return CC == CallingConv::C || CC == CallingConv::X86_ThisCall; } -bool GlobalOpt::OptimizeFunctions(Module &M) { +static bool +OptimizeFunctions(Module &M, TargetLibraryInfo *TLI, + function_ref<DominatorTree &(Function &)> LookupDomTree, + SmallSet<const Comdat *, 8> &NotDiscardableComdats) { bool Changed = false; // Optimize functions. for (Module::iterator FI = M.begin(), E = M.end(); FI != E; ) { @@ -2043,12 +2031,12 @@ bool GlobalOpt::OptimizeFunctions(Module &M) { if (!F->hasName() && !F->isDeclaration() && !F->hasLocalLinkage()) F->setLinkage(GlobalValue::InternalLinkage); - if (deleteIfDead(*F)) { + if (deleteIfDead(*F, NotDiscardableComdats)) { Changed = true; continue; } - Changed |= processGlobal(*F); + Changed |= processGlobal(*F, TLI, LookupDomTree); if (!F->hasLocalLinkage()) continue; @@ -2075,7 +2063,10 @@ bool GlobalOpt::OptimizeFunctions(Module &M) { return Changed; } -bool GlobalOpt::OptimizeGlobalVars(Module &M) { +static bool +OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI, + function_ref<DominatorTree &(Function &)> LookupDomTree, + SmallSet<const Comdat *, 8> &NotDiscardableComdats) { bool Changed = false; for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); @@ -2093,148 +2084,16 @@ bool GlobalOpt::OptimizeGlobalVars(Module &M) { GV->setInitializer(New); } - if (deleteIfDead(*GV)) { + if (deleteIfDead(*GV, NotDiscardableComdats)) { Changed = true; continue; } - Changed |= processGlobal(*GV); + Changed |= processGlobal(*GV, TLI, LookupDomTree); } return Changed; } -static inline bool -isSimpleEnoughValueToCommit(Constant *C, - SmallPtrSetImpl<Constant *> &SimpleConstants, - const DataLayout &DL); - -/// Return true if the specified constant can be handled by the code generator. -/// We don't want to generate something like: -/// void *X = &X/42; -/// because the code generator doesn't have a relocation that can handle that. -/// -/// This function should be called if C was not found (but just got inserted) -/// in SimpleConstants to avoid having to rescan the same constants all the -/// time. -static bool -isSimpleEnoughValueToCommitHelper(Constant *C, - SmallPtrSetImpl<Constant *> &SimpleConstants, - const DataLayout &DL) { - // Simple global addresses are supported, do not allow dllimport or - // thread-local globals. - if (auto *GV = dyn_cast<GlobalValue>(C)) - return !GV->hasDLLImportStorageClass() && !GV->isThreadLocal(); - - // Simple integer, undef, constant aggregate zero, etc are all supported. - if (C->getNumOperands() == 0 || isa<BlockAddress>(C)) - return true; - - // Aggregate values are safe if all their elements are. - if (isa<ConstantArray>(C) || isa<ConstantStruct>(C) || - isa<ConstantVector>(C)) { - for (Value *Op : C->operands()) - if (!isSimpleEnoughValueToCommit(cast<Constant>(Op), SimpleConstants, DL)) - return false; - return true; - } - - // We don't know exactly what relocations are allowed in constant expressions, - // so we allow &global+constantoffset, which is safe and uniformly supported - // across targets. 
- ConstantExpr *CE = cast<ConstantExpr>(C); - switch (CE->getOpcode()) { - case Instruction::BitCast: - // Bitcast is fine if the casted value is fine. - return isSimpleEnoughValueToCommit(CE->getOperand(0), SimpleConstants, DL); - - case Instruction::IntToPtr: - case Instruction::PtrToInt: - // int <=> ptr is fine if the int type is the same size as the - // pointer type. - if (DL.getTypeSizeInBits(CE->getType()) != - DL.getTypeSizeInBits(CE->getOperand(0)->getType())) - return false; - return isSimpleEnoughValueToCommit(CE->getOperand(0), SimpleConstants, DL); - - // GEP is fine if it is simple + constant offset. - case Instruction::GetElementPtr: - for (unsigned i = 1, e = CE->getNumOperands(); i != e; ++i) - if (!isa<ConstantInt>(CE->getOperand(i))) - return false; - return isSimpleEnoughValueToCommit(CE->getOperand(0), SimpleConstants, DL); - - case Instruction::Add: - // We allow simple+cst. - if (!isa<ConstantInt>(CE->getOperand(1))) - return false; - return isSimpleEnoughValueToCommit(CE->getOperand(0), SimpleConstants, DL); - } - return false; -} - -static inline bool -isSimpleEnoughValueToCommit(Constant *C, - SmallPtrSetImpl<Constant *> &SimpleConstants, - const DataLayout &DL) { - // If we already checked this constant, we win. - if (!SimpleConstants.insert(C).second) - return true; - // Check the constant. - return isSimpleEnoughValueToCommitHelper(C, SimpleConstants, DL); -} - - -/// Return true if this constant is simple enough for us to understand. In -/// particular, if it is a cast to anything other than from one pointer type to -/// another pointer type, we punt. We basically just support direct accesses to -/// globals and GEP's of globals. This should be kept up to date with -/// CommitValueTo. -static bool isSimpleEnoughPointerToCommit(Constant *C) { - // Conservatively, avoid aggregate types. This is because we don't - // want to worry about them partially overlapping other stores. - if (!cast<PointerType>(C->getType())->getElementType()->isSingleValueType()) - return false; - - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) - // Do not allow weak/*_odr/linkonce linkage or external globals. - return GV->hasUniqueInitializer(); - - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { - // Handle a constantexpr gep. - if (CE->getOpcode() == Instruction::GetElementPtr && - isa<GlobalVariable>(CE->getOperand(0)) && - cast<GEPOperator>(CE)->isInBounds()) { - GlobalVariable *GV = cast<GlobalVariable>(CE->getOperand(0)); - // Do not allow weak/*_odr/linkonce/dllimport/dllexport linkage or - // external globals. - if (!GV->hasUniqueInitializer()) - return false; - - // The first index must be zero. - ConstantInt *CI = dyn_cast<ConstantInt>(*std::next(CE->op_begin())); - if (!CI || !CI->isZero()) return false; - - // The remaining indices must be compile-time known integers within the - // notional bounds of the corresponding static array types. - if (!CE->isGEPWithNoNotionalOverIndexing()) - return false; - - return ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE); - - // A constantexpr bitcast from a pointer to another pointer is a no-op, - // and we know how to evaluate it by moving the bitcast from the pointer - // operand to the value operand. - } else if (CE->getOpcode() == Instruction::BitCast && - isa<GlobalVariable>(CE->getOperand(0))) { - // Do not allow weak/*_odr/linkonce/dllimport/dllexport linkage or - // external globals. 
- return cast<GlobalVariable>(CE->getOperand(0))->hasUniqueInitializer(); - } - } - - return false; -} - /// Evaluate a piece of a constantexpr store into a global initializer. This /// returns 'Init' modified to reflect 'Val' stored into it. At this point, the /// GEP operands of Addr [0, OpNo) have been stepped into. @@ -2298,533 +2157,10 @@ static void CommitValueTo(Constant *Val, Constant *Addr) { GV->setInitializer(EvaluateStoreInto(GV->getInitializer(), Val, CE, 2)); } -namespace { - -/// This class evaluates LLVM IR, producing the Constant representing each SSA -/// instruction. Changes to global variables are stored in a mapping that can -/// be iterated over after the evaluation is complete. Once an evaluation call -/// fails, the evaluation object should not be reused. -class Evaluator { -public: - Evaluator(const DataLayout &DL, const TargetLibraryInfo *TLI) - : DL(DL), TLI(TLI) { - ValueStack.emplace_back(); - } - - ~Evaluator() { - for (auto &Tmp : AllocaTmps) - // If there are still users of the alloca, the program is doing something - // silly, e.g. storing the address of the alloca somewhere and using it - // later. Since this is undefined, we'll just make it be null. - if (!Tmp->use_empty()) - Tmp->replaceAllUsesWith(Constant::getNullValue(Tmp->getType())); - } - - /// Evaluate a call to function F, returning true if successful, false if we - /// can't evaluate it. ActualArgs contains the formal arguments for the - /// function. - bool EvaluateFunction(Function *F, Constant *&RetVal, - const SmallVectorImpl<Constant*> &ActualArgs); - - /// Evaluate all instructions in block BB, returning true if successful, false - /// if we can't evaluate it. NewBB returns the next BB that control flows - /// into, or null upon return. - bool EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB); - - Constant *getVal(Value *V) { - if (Constant *CV = dyn_cast<Constant>(V)) return CV; - Constant *R = ValueStack.back().lookup(V); - assert(R && "Reference to an uncomputed value!"); - return R; - } - - void setVal(Value *V, Constant *C) { - ValueStack.back()[V] = C; - } - - const DenseMap<Constant*, Constant*> &getMutatedMemory() const { - return MutatedMemory; - } - - const SmallPtrSetImpl<GlobalVariable*> &getInvariants() const { - return Invariants; - } - -private: - Constant *ComputeLoadResult(Constant *P); - - /// As we compute SSA register values, we store their contents here. The back - /// of the deque contains the current function and the stack contains the - /// values in the calling frames. - std::deque<DenseMap<Value*, Constant*>> ValueStack; - - /// This is used to detect recursion. In pathological situations we could hit - /// exponential behavior, but at least there is nothing unbounded. - SmallVector<Function*, 4> CallStack; - - /// For each store we execute, we update this map. Loads check this to get - /// the most up-to-date value. If evaluation is successful, this state is - /// committed to the process. - DenseMap<Constant*, Constant*> MutatedMemory; - - /// To 'execute' an alloca, we create a temporary global variable to represent - /// its body. This vector is needed so we can delete the temporary globals - /// when we are done. - SmallVector<std::unique_ptr<GlobalVariable>, 32> AllocaTmps; - - /// These global variables have been marked invariant by the static - /// constructor. - SmallPtrSet<GlobalVariable*, 8> Invariants; - - /// These are constants we have checked and know to be simple enough to live - /// in a static initializer of a global. 
- SmallPtrSet<Constant*, 8> SimpleConstants; - - const DataLayout &DL; - const TargetLibraryInfo *TLI; -}; - -} // anonymous namespace - -/// Return the value that would be computed by a load from P after the stores -/// reflected by 'memory' have been performed. If we can't decide, return null. -Constant *Evaluator::ComputeLoadResult(Constant *P) { - // If this memory location has been recently stored, use the stored value: it - // is the most up-to-date. - DenseMap<Constant*, Constant*>::const_iterator I = MutatedMemory.find(P); - if (I != MutatedMemory.end()) return I->second; - - // Access it. - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(P)) { - if (GV->hasDefinitiveInitializer()) - return GV->getInitializer(); - return nullptr; - } - - // Handle a constantexpr getelementptr. - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(P)) - if (CE->getOpcode() == Instruction::GetElementPtr && - isa<GlobalVariable>(CE->getOperand(0))) { - GlobalVariable *GV = cast<GlobalVariable>(CE->getOperand(0)); - if (GV->hasDefinitiveInitializer()) - return ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE); - } - - return nullptr; // don't know how to evaluate. -} - -/// Evaluate all instructions in block BB, returning true if successful, false -/// if we can't evaluate it. NewBB returns the next BB that control flows into, -/// or null upon return. -bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, - BasicBlock *&NextBB) { - // This is the main evaluation loop. - while (1) { - Constant *InstResult = nullptr; - - DEBUG(dbgs() << "Evaluating Instruction: " << *CurInst << "\n"); - - if (StoreInst *SI = dyn_cast<StoreInst>(CurInst)) { - if (!SI->isSimple()) { - DEBUG(dbgs() << "Store is not simple! Can not evaluate.\n"); - return false; // no volatile/atomic accesses. - } - Constant *Ptr = getVal(SI->getOperand(1)); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { - DEBUG(dbgs() << "Folding constant ptr expression: " << *Ptr); - Ptr = ConstantFoldConstantExpression(CE, DL, TLI); - DEBUG(dbgs() << "; To: " << *Ptr << "\n"); - } - if (!isSimpleEnoughPointerToCommit(Ptr)) { - // If this is too complex for us to commit, reject it. - DEBUG(dbgs() << "Pointer is too complex for us to evaluate store."); - return false; - } - - Constant *Val = getVal(SI->getOperand(0)); - - // If this might be too difficult for the backend to handle (e.g. the addr - // of one global variable divided by another) then we can't commit it. - if (!isSimpleEnoughValueToCommit(Val, SimpleConstants, DL)) { - DEBUG(dbgs() << "Store value is too complex to evaluate store. " << *Val - << "\n"); - return false; - } - - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { - if (CE->getOpcode() == Instruction::BitCast) { - DEBUG(dbgs() << "Attempting to resolve bitcast on constant ptr.\n"); - // If we're evaluating a store through a bitcast, then we need - // to pull the bitcast off the pointer type and push it onto the - // stored value. - Ptr = CE->getOperand(0); - - Type *NewTy = cast<PointerType>(Ptr->getType())->getElementType(); - - // In order to push the bitcast onto the stored value, a bitcast - // from NewTy to Val's type must be legal. If it's not, we can try - // introspecting NewTy to find a legal conversion. - while (!Val->getType()->canLosslesslyBitCastTo(NewTy)) { - // If NewTy is a struct, we can convert the pointer to the struct - // into a pointer to its first member. - // FIXME: This could be extended to support arrays as well. 
- if (StructType *STy = dyn_cast<StructType>(NewTy)) { - NewTy = STy->getTypeAtIndex(0U); - - IntegerType *IdxTy = IntegerType::get(NewTy->getContext(), 32); - Constant *IdxZero = ConstantInt::get(IdxTy, 0, false); - Constant * const IdxList[] = {IdxZero, IdxZero}; - - Ptr = ConstantExpr::getGetElementPtr(nullptr, Ptr, IdxList); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) - Ptr = ConstantFoldConstantExpression(CE, DL, TLI); - - // If we can't improve the situation by introspecting NewTy, - // we have to give up. - } else { - DEBUG(dbgs() << "Failed to bitcast constant ptr, can not " - "evaluate.\n"); - return false; - } - } - - // If we found compatible types, go ahead and push the bitcast - // onto the stored value. - Val = ConstantExpr::getBitCast(Val, NewTy); - - DEBUG(dbgs() << "Evaluated bitcast: " << *Val << "\n"); - } - } - - MutatedMemory[Ptr] = Val; - } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CurInst)) { - InstResult = ConstantExpr::get(BO->getOpcode(), - getVal(BO->getOperand(0)), - getVal(BO->getOperand(1))); - DEBUG(dbgs() << "Found a BinaryOperator! Simplifying: " << *InstResult - << "\n"); - } else if (CmpInst *CI = dyn_cast<CmpInst>(CurInst)) { - InstResult = ConstantExpr::getCompare(CI->getPredicate(), - getVal(CI->getOperand(0)), - getVal(CI->getOperand(1))); - DEBUG(dbgs() << "Found a CmpInst! Simplifying: " << *InstResult - << "\n"); - } else if (CastInst *CI = dyn_cast<CastInst>(CurInst)) { - InstResult = ConstantExpr::getCast(CI->getOpcode(), - getVal(CI->getOperand(0)), - CI->getType()); - DEBUG(dbgs() << "Found a Cast! Simplifying: " << *InstResult - << "\n"); - } else if (SelectInst *SI = dyn_cast<SelectInst>(CurInst)) { - InstResult = ConstantExpr::getSelect(getVal(SI->getOperand(0)), - getVal(SI->getOperand(1)), - getVal(SI->getOperand(2))); - DEBUG(dbgs() << "Found a Select! Simplifying: " << *InstResult - << "\n"); - } else if (auto *EVI = dyn_cast<ExtractValueInst>(CurInst)) { - InstResult = ConstantExpr::getExtractValue( - getVal(EVI->getAggregateOperand()), EVI->getIndices()); - DEBUG(dbgs() << "Found an ExtractValueInst! Simplifying: " << *InstResult - << "\n"); - } else if (auto *IVI = dyn_cast<InsertValueInst>(CurInst)) { - InstResult = ConstantExpr::getInsertValue( - getVal(IVI->getAggregateOperand()), - getVal(IVI->getInsertedValueOperand()), IVI->getIndices()); - DEBUG(dbgs() << "Found an InsertValueInst! Simplifying: " << *InstResult - << "\n"); - } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(CurInst)) { - Constant *P = getVal(GEP->getOperand(0)); - SmallVector<Constant*, 8> GEPOps; - for (User::op_iterator i = GEP->op_begin() + 1, e = GEP->op_end(); - i != e; ++i) - GEPOps.push_back(getVal(*i)); - InstResult = - ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), P, GEPOps, - cast<GEPOperator>(GEP)->isInBounds()); - DEBUG(dbgs() << "Found a GEP! Simplifying: " << *InstResult - << "\n"); - } else if (LoadInst *LI = dyn_cast<LoadInst>(CurInst)) { - - if (!LI->isSimple()) { - DEBUG(dbgs() << "Found a Load! Not a simple load, can not evaluate.\n"); - return false; // no volatile/atomic accesses. - } - - Constant *Ptr = getVal(LI->getOperand(0)); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { - Ptr = ConstantFoldConstantExpression(CE, DL, TLI); - DEBUG(dbgs() << "Found a constant pointer expression, constant " - "folding: " << *Ptr << "\n"); - } - InstResult = ComputeLoadResult(Ptr); - if (!InstResult) { - DEBUG(dbgs() << "Failed to compute load result. Can not evaluate load." 
- "\n"); - return false; // Could not evaluate load. - } - - DEBUG(dbgs() << "Evaluated load: " << *InstResult << "\n"); - } else if (AllocaInst *AI = dyn_cast<AllocaInst>(CurInst)) { - if (AI->isArrayAllocation()) { - DEBUG(dbgs() << "Found an array alloca. Can not evaluate.\n"); - return false; // Cannot handle array allocs. - } - Type *Ty = AI->getType()->getElementType(); - AllocaTmps.push_back( - make_unique<GlobalVariable>(Ty, false, GlobalValue::InternalLinkage, - UndefValue::get(Ty), AI->getName())); - InstResult = AllocaTmps.back().get(); - DEBUG(dbgs() << "Found an alloca. Result: " << *InstResult << "\n"); - } else if (isa<CallInst>(CurInst) || isa<InvokeInst>(CurInst)) { - CallSite CS(&*CurInst); - - // Debug info can safely be ignored here. - if (isa<DbgInfoIntrinsic>(CS.getInstruction())) { - DEBUG(dbgs() << "Ignoring debug info.\n"); - ++CurInst; - continue; - } - - // Cannot handle inline asm. - if (isa<InlineAsm>(CS.getCalledValue())) { - DEBUG(dbgs() << "Found inline asm, can not evaluate.\n"); - return false; - } - - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { - if (MemSetInst *MSI = dyn_cast<MemSetInst>(II)) { - if (MSI->isVolatile()) { - DEBUG(dbgs() << "Can not optimize a volatile memset " << - "intrinsic.\n"); - return false; - } - Constant *Ptr = getVal(MSI->getDest()); - Constant *Val = getVal(MSI->getValue()); - Constant *DestVal = ComputeLoadResult(getVal(Ptr)); - if (Val->isNullValue() && DestVal && DestVal->isNullValue()) { - // This memset is a no-op. - DEBUG(dbgs() << "Ignoring no-op memset.\n"); - ++CurInst; - continue; - } - } - - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) { - DEBUG(dbgs() << "Ignoring lifetime intrinsic.\n"); - ++CurInst; - continue; - } - - if (II->getIntrinsicID() == Intrinsic::invariant_start) { - // We don't insert an entry into Values, as it doesn't have a - // meaningful return value. - if (!II->use_empty()) { - DEBUG(dbgs() << "Found unused invariant_start. Can't evaluate.\n"); - return false; - } - ConstantInt *Size = cast<ConstantInt>(II->getArgOperand(0)); - Value *PtrArg = getVal(II->getArgOperand(1)); - Value *Ptr = PtrArg->stripPointerCasts(); - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) { - Type *ElemTy = cast<PointerType>(GV->getType())->getElementType(); - if (!Size->isAllOnesValue() && - Size->getValue().getLimitedValue() >= - DL.getTypeStoreSize(ElemTy)) { - Invariants.insert(GV); - DEBUG(dbgs() << "Found a global var that is an invariant: " << *GV - << "\n"); - } else { - DEBUG(dbgs() << "Found a global var, but can not treat it as an " - "invariant.\n"); - } - } - // Continue even if we do nothing. - ++CurInst; - continue; - } else if (II->getIntrinsicID() == Intrinsic::assume) { - DEBUG(dbgs() << "Skipping assume intrinsic.\n"); - ++CurInst; - continue; - } - - DEBUG(dbgs() << "Unknown intrinsic. Can not evaluate.\n"); - return false; - } - - // Resolve function pointers. - Function *Callee = dyn_cast<Function>(getVal(CS.getCalledValue())); - if (!Callee || Callee->mayBeOverridden()) { - DEBUG(dbgs() << "Can not resolve function pointer.\n"); - return false; // Cannot resolve. - } - - SmallVector<Constant*, 8> Formals; - for (User::op_iterator i = CS.arg_begin(), e = CS.arg_end(); i != e; ++i) - Formals.push_back(getVal(*i)); - - if (Callee->isDeclaration()) { - // If this is a function we can constant fold, do it. 
- if (Constant *C = ConstantFoldCall(Callee, Formals, TLI)) { - InstResult = C; - DEBUG(dbgs() << "Constant folded function call. Result: " << - *InstResult << "\n"); - } else { - DEBUG(dbgs() << "Can not constant fold function call.\n"); - return false; - } - } else { - if (Callee->getFunctionType()->isVarArg()) { - DEBUG(dbgs() << "Can not constant fold vararg function call.\n"); - return false; - } - - Constant *RetVal = nullptr; - // Execute the call, if successful, use the return value. - ValueStack.emplace_back(); - if (!EvaluateFunction(Callee, RetVal, Formals)) { - DEBUG(dbgs() << "Failed to evaluate function.\n"); - return false; - } - ValueStack.pop_back(); - InstResult = RetVal; - - if (InstResult) { - DEBUG(dbgs() << "Successfully evaluated function. Result: " << - InstResult << "\n\n"); - } else { - DEBUG(dbgs() << "Successfully evaluated function. Result: 0\n\n"); - } - } - } else if (isa<TerminatorInst>(CurInst)) { - DEBUG(dbgs() << "Found a terminator instruction.\n"); - - if (BranchInst *BI = dyn_cast<BranchInst>(CurInst)) { - if (BI->isUnconditional()) { - NextBB = BI->getSuccessor(0); - } else { - ConstantInt *Cond = - dyn_cast<ConstantInt>(getVal(BI->getCondition())); - if (!Cond) return false; // Cannot determine. - - NextBB = BI->getSuccessor(!Cond->getZExtValue()); - } - } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurInst)) { - ConstantInt *Val = - dyn_cast<ConstantInt>(getVal(SI->getCondition())); - if (!Val) return false; // Cannot determine. - NextBB = SI->findCaseValue(Val).getCaseSuccessor(); - } else if (IndirectBrInst *IBI = dyn_cast<IndirectBrInst>(CurInst)) { - Value *Val = getVal(IBI->getAddress())->stripPointerCasts(); - if (BlockAddress *BA = dyn_cast<BlockAddress>(Val)) - NextBB = BA->getBasicBlock(); - else - return false; // Cannot determine. - } else if (isa<ReturnInst>(CurInst)) { - NextBB = nullptr; - } else { - // invoke, unwind, resume, unreachable. - DEBUG(dbgs() << "Can not handle terminator."); - return false; // Cannot handle this terminator. - } - - // We succeeded at evaluating this block! - DEBUG(dbgs() << "Successfully evaluated block.\n"); - return true; - } else { - // Did not know how to evaluate this! - DEBUG(dbgs() << "Failed to evaluate block due to unhandled instruction." - "\n"); - return false; - } - - if (!CurInst->use_empty()) { - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(InstResult)) - InstResult = ConstantFoldConstantExpression(CE, DL, TLI); - - setVal(&*CurInst, InstResult); - } - - // If we just processed an invoke, we finished evaluating the block. - if (InvokeInst *II = dyn_cast<InvokeInst>(CurInst)) { - NextBB = II->getNormalDest(); - DEBUG(dbgs() << "Found an invoke instruction. Finished Block.\n\n"); - return true; - } - - // Advance program counter. - ++CurInst; - } -} - -/// Evaluate a call to function F, returning true if successful, false if we -/// can't evaluate it. ActualArgs contains the formal arguments for the -/// function. -bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, - const SmallVectorImpl<Constant*> &ActualArgs) { - // Check to see if this function is already executing (recursion). If so, - // bail out. TODO: we might want to accept limited recursion. - if (std::find(CallStack.begin(), CallStack.end(), F) != CallStack.end()) - return false; - - CallStack.push_back(F); - - // Initialize arguments to the incoming values specified. 
- unsigned ArgNo = 0; - for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); AI != E; - ++AI, ++ArgNo) - setVal(&*AI, ActualArgs[ArgNo]); - - // ExecutedBlocks - We only handle non-looping, non-recursive code. As such, - // we can only evaluate any one basic block at most once. This set keeps - // track of what we have executed so we can detect recursive cases etc. - SmallPtrSet<BasicBlock*, 32> ExecutedBlocks; - - // CurBB - The current basic block we're evaluating. - BasicBlock *CurBB = &F->front(); - - BasicBlock::iterator CurInst = CurBB->begin(); - - while (1) { - BasicBlock *NextBB = nullptr; // Initialized to avoid compiler warnings. - DEBUG(dbgs() << "Trying to evaluate BB: " << *CurBB << "\n"); - - if (!EvaluateBlock(CurInst, NextBB)) - return false; - - if (!NextBB) { - // Successfully running until there's no next block means that we found - // the return. Fill it the return value and pop the call stack. - ReturnInst *RI = cast<ReturnInst>(CurBB->getTerminator()); - if (RI->getNumOperands()) - RetVal = getVal(RI->getOperand(0)); - CallStack.pop_back(); - return true; - } - - // Okay, we succeeded in evaluating this control flow. See if we have - // executed the new block before. If so, we have a looping function, - // which we cannot evaluate in reasonable time. - if (!ExecutedBlocks.insert(NextBB).second) - return false; // looped! - - // Okay, we have never been in this block before. Check to see if there - // are any PHI nodes. If so, evaluate them with information about where - // we came from. - PHINode *PN = nullptr; - for (CurInst = NextBB->begin(); - (PN = dyn_cast<PHINode>(CurInst)); ++CurInst) - setVal(PN, getVal(PN->getIncomingValueForBlock(CurBB))); - - // Advance to the next block. - CurBB = NextBB; - } -} - /// Evaluate static constructors in the function, if we can. Return true if we /// can, false otherwise. static bool EvaluateStaticConstructor(Function *F, const DataLayout &DL, - const TargetLibraryInfo *TLI) { + TargetLibraryInfo *TLI) { // Call the function. 
Evaluator Eval(DL, TLI); Constant *RetValDummy; @@ -2838,10 +2174,8 @@ static bool EvaluateStaticConstructor(Function *F, const DataLayout &DL, DEBUG(dbgs() << "FULLY EVALUATED GLOBAL CTOR FUNCTION '" << F->getName() << "' to " << Eval.getMutatedMemory().size() << " stores.\n"); - for (DenseMap<Constant*, Constant*>::const_iterator I = - Eval.getMutatedMemory().begin(), E = Eval.getMutatedMemory().end(); - I != E; ++I) - CommitValueTo(I->second, I->first); + for (const auto &I : Eval.getMutatedMemory()) + CommitValueTo(I.second, I.first); for (GlobalVariable *GV : Eval.getInvariants()) GV->setConstant(true); } @@ -2850,8 +2184,9 @@ static bool EvaluateStaticConstructor(Function *F, const DataLayout &DL, } static int compareNames(Constant *const *A, Constant *const *B) { - return (*A)->stripPointerCasts()->getName().compare( - (*B)->stripPointerCasts()->getName()); + Value *AStripped = (*A)->stripPointerCastsNoFollowAliases(); + Value *BStripped = (*B)->stripPointerCastsNoFollowAliases(); + return AStripped->getName().compare(BStripped->getName()); } static void setUsedInitializer(GlobalVariable &V, @@ -2995,7 +2330,9 @@ static bool hasUsesToReplace(GlobalAlias &GA, const LLVMUsed &U, return true; } -bool GlobalOpt::OptimizeGlobalAliases(Module &M) { +static bool +OptimizeGlobalAliases(Module &M, + SmallSet<const Comdat *, 8> &NotDiscardableComdats) { bool Changed = false; LLVMUsed Used(M); @@ -3010,13 +2347,13 @@ bool GlobalOpt::OptimizeGlobalAliases(Module &M) { if (!J->hasName() && !J->isDeclaration() && !J->hasLocalLinkage()) J->setLinkage(GlobalValue::InternalLinkage); - if (deleteIfDead(*J)) { + if (deleteIfDead(*J, NotDiscardableComdats)) { Changed = true; continue; } // If the aliasee may change at link time, nothing can be done - bail out. - if (J->mayBeOverridden()) + if (J->isInterposable()) continue; Constant *Aliasee = J->getAliasee(); @@ -3064,23 +2401,16 @@ bool GlobalOpt::OptimizeGlobalAliases(Module &M) { } static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc::cxa_atexit)) + LibFunc::Func F = LibFunc::cxa_atexit; + if (!TLI->has(F)) return nullptr; - Function *Fn = M.getFunction(TLI->getName(LibFunc::cxa_atexit)); - + Function *Fn = M.getFunction(TLI->getName(F)); if (!Fn) return nullptr; - FunctionType *FTy = Fn->getFunctionType(); - - // Checking that the function has the right return type, the right number of - // parameters and that they all have pointer types should be enough. - if (!FTy->getReturnType()->isIntegerTy() || - FTy->getNumParams() != 3 || - !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy() || - !FTy->getParamType(2)->isPointerTy()) + // Make sure that the function has the correct prototype. 
+ if (!TLI->getLibFunc(*Fn, F) || F != LibFunc::cxa_atexit) return nullptr; return Fn; @@ -3132,7 +2462,7 @@ static bool cxxDtorIsEmpty(const Function &Fn, return false; } -bool GlobalOpt::OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { +static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { /// Itanium C++ ABI p3.3.5: /// /// After constructing a global (or local static) object, that will require @@ -3179,12 +2509,11 @@ bool GlobalOpt::OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { return Changed; } -bool GlobalOpt::runOnModule(Module &M) { +static bool optimizeGlobalsInModule( + Module &M, const DataLayout &DL, TargetLibraryInfo *TLI, + function_ref<DominatorTree &(Function &)> LookupDomTree) { + SmallSet<const Comdat *, 8> NotDiscardableComdats; bool Changed = false; - - auto &DL = M.getDataLayout(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - bool LocalChange = true; while (LocalChange) { LocalChange = false; @@ -3204,7 +2533,8 @@ bool GlobalOpt::runOnModule(Module &M) { NotDiscardableComdats.insert(C); // Delete functions that are trivially dead, ccc -> fastcc - LocalChange |= OptimizeFunctions(M); + LocalChange |= + OptimizeFunctions(M, TLI, LookupDomTree, NotDiscardableComdats); // Optimize global_ctors list. LocalChange |= optimizeGlobalCtorsList(M, [&](Function *F) { @@ -3212,10 +2542,11 @@ bool GlobalOpt::runOnModule(Module &M) { }); // Optimize non-address-taken globals. - LocalChange |= OptimizeGlobalVars(M); + LocalChange |= OptimizeGlobalVars(M, TLI, LookupDomTree, + NotDiscardableComdats); // Resolve aliases, when possible. - LocalChange |= OptimizeGlobalAliases(M); + LocalChange |= OptimizeGlobalAliases(M, NotDiscardableComdats); // Try to remove trivial global destructors if they are not removed // already. 
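The alias-resolution loop above bails out whenever the definition it sees may be replaced at link time; a small sketch of that guard on its own (illustrative only; the helper name is made up):

#include "llvm/IR/GlobalAlias.h"

// Mirrors the check in OptimizeGlobalAliases above: an interposable alias
// (previously spelled mayBeOverridden()) may be substituted by the linker,
// so its aliasee must not be propagated into users.
static bool aliasIsSafeToResolve(const llvm::GlobalAlias &GA) {
  return !GA.isInterposable();
}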
@@ -3232,3 +2563,53 @@ bool GlobalOpt::runOnModule(Module &M) { return Changed; } +PreservedAnalyses GlobalOptPass::run(Module &M, AnalysisManager<Module> &AM) { + auto &DL = M.getDataLayout(); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); + auto &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto LookupDomTree = [&FAM](Function &F) -> DominatorTree &{ + return FAM.getResult<DominatorTreeAnalysis>(F); + }; + if (!optimizeGlobalsInModule(M, DL, &TLI, LookupDomTree)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +namespace { +struct GlobalOptLegacyPass : public ModulePass { + static char ID; // Pass identification, replacement for typeid + GlobalOptLegacyPass() : ModulePass(ID) { + initializeGlobalOptLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + + auto &DL = M.getDataLayout(); + auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto LookupDomTree = [this](Function &F) -> DominatorTree & { + return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); + }; + return optimizeGlobalsInModule(M, DL, TLI, LookupDomTree); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + } +}; +} + +char GlobalOptLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(GlobalOptLegacyPass, "globalopt", + "Global Variable Optimizer", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(GlobalOptLegacyPass, "globalopt", + "Global Variable Optimizer", false, false) + +ModulePass *llvm::createGlobalOptimizerPass() { + return new GlobalOptLegacyPass(); +} diff --git a/lib/Transforms/IPO/IPConstantPropagation.cpp b/lib/Transforms/IPO/IPConstantPropagation.cpp index af541d155254..916135e33cd5 100644 --- a/lib/Transforms/IPO/IPConstantPropagation.cpp +++ b/lib/Transforms/IPO/IPConstantPropagation.cpp @@ -41,44 +41,14 @@ namespace { } bool runOnModule(Module &M) override; - private: - bool PropagateConstantsIntoArguments(Function &F); - bool PropagateConstantReturn(Function &F); }; } -char IPCP::ID = 0; -INITIALIZE_PASS(IPCP, "ipconstprop", - "Interprocedural constant propagation", false, false) - -ModulePass *llvm::createIPConstantPropagationPass() { return new IPCP(); } - -bool IPCP::runOnModule(Module &M) { - bool Changed = false; - bool LocalChange = true; - - // FIXME: instead of using smart algorithms, we just iterate until we stop - // making changes. - while (LocalChange) { - LocalChange = false; - for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) - if (!I->isDeclaration()) { - // Delete any klingons. - I->removeDeadConstantUsers(); - if (I->hasLocalLinkage()) - LocalChange |= PropagateConstantsIntoArguments(*I); - Changed |= PropagateConstantReturn(*I); - } - Changed |= LocalChange; - } - return Changed; -} - /// PropagateConstantsIntoArguments - Look at all uses of the specified /// function. If all uses are direct call sites, and all pass a particular /// constant in for an argument, propagate that constant in as the argument. /// -bool IPCP::PropagateConstantsIntoArguments(Function &F) { +static bool PropagateConstantsIntoArguments(Function &F) { if (F.arg_empty() || F.use_empty()) return false; // No arguments? Early exit. 
// For each argument, keep track of its constant value and whether it is a @@ -157,13 +127,14 @@ bool IPCP::PropagateConstantsIntoArguments(Function &F) { // Additionally if a function always returns one of its arguments directly, // callers will be updated to use the value they pass in directly instead of // using the return value. -bool IPCP::PropagateConstantReturn(Function &F) { +static bool PropagateConstantReturn(Function &F) { if (F.getReturnType()->isVoidTy()) return false; // No return value. - // If this function could be overridden later in the link stage, we can't - // propagate information about its results into callers. - if (F.mayBeOverridden()) + // We can infer and propagate the return value only when we know that the + // definition we'll get at link time is *exactly* the definition we see now. + // For more details, see GlobalValue::mayBeDerefined. + if (!F.isDefinitionExact()) return false; // Check to see if this function returns a constant. @@ -176,8 +147,8 @@ bool IPCP::PropagateConstantReturn(Function &F) { RetVals.push_back(UndefValue::get(F.getReturnType())); unsigned NumNonConstant = 0; - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) - if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) { + for (BasicBlock &BB : F) + if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) { for (unsigned i = 0, e = RetVals.size(); i != e; ++i) { // Already found conflicting return values? Value *RV = RetVals[i]; @@ -277,3 +248,33 @@ bool IPCP::PropagateConstantReturn(Function &F) { if (MadeChange) ++NumReturnValProped; return MadeChange; } + +char IPCP::ID = 0; +INITIALIZE_PASS(IPCP, "ipconstprop", + "Interprocedural constant propagation", false, false) + +ModulePass *llvm::createIPConstantPropagationPass() { return new IPCP(); } + +bool IPCP::runOnModule(Module &M) { + if (skipModule(M)) + return false; + + bool Changed = false; + bool LocalChange = true; + + // FIXME: instead of using smart algorithms, we just iterate until we stop + // making changes. + while (LocalChange) { + LocalChange = false; + for (Function &F : M) + if (!F.isDeclaration()) { + // Delete any klingons. 
+ F.removeDeadConstantUsers(); + if (F.hasLocalLinkage()) + LocalChange |= PropagateConstantsIntoArguments(F); + Changed |= PropagateConstantReturn(F); + } + Changed |= LocalChange; + } + return Changed; +} diff --git a/lib/Transforms/IPO/IPO.cpp b/lib/Transforms/IPO/IPO.cpp index 89629cf06e08..3507eba81b2f 100644 --- a/lib/Transforms/IPO/IPO.cpp +++ b/lib/Transforms/IPO/IPO.cpp @@ -18,31 +18,32 @@ #include "llvm/InitializePasses.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/FunctionAttrs.h" using namespace llvm; void llvm::initializeIPO(PassRegistry &Registry) { initializeArgPromotionPass(Registry); - initializeConstantMergePass(Registry); + initializeConstantMergeLegacyPassPass(Registry); initializeCrossDSOCFIPass(Registry); initializeDAEPass(Registry); initializeDAHPass(Registry); initializeForceFunctionAttrsLegacyPassPass(Registry); - initializeGlobalDCEPass(Registry); - initializeGlobalOptPass(Registry); + initializeGlobalDCELegacyPassPass(Registry); + initializeGlobalOptLegacyPassPass(Registry); initializeIPCPPass(Registry); initializeAlwaysInlinerPass(Registry); initializeSimpleInlinerPass(Registry); initializeInferFunctionAttrsLegacyPassPass(Registry); - initializeInternalizePassPass(Registry); + initializeInternalizeLegacyPassPass(Registry); initializeLoopExtractorPass(Registry); initializeBlockExtractorPassPass(Registry); initializeSingleLoopExtractorPass(Registry); - initializeLowerBitSetsPass(Registry); + initializeLowerTypeTestsPass(Registry); initializeMergeFunctionsPass(Registry); - initializePartialInlinerPass(Registry); - initializePostOrderFunctionAttrsPass(Registry); - initializeReversePostOrderFunctionAttrsPass(Registry); + initializePartialInlinerLegacyPassPass(Registry); + initializePostOrderFunctionAttrsLegacyPassPass(Registry); + initializeReversePostOrderFunctionAttrsLegacyPassPass(Registry); initializePruneEHPass(Registry); initializeStripDeadPrototypesLegacyPassPass(Registry); initializeStripSymbolsPass(Registry); @@ -50,9 +51,10 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeStripDeadDebugInfoPass(Registry); initializeStripNonDebugSymbolsPass(Registry); initializeBarrierNoopPass(Registry); - initializeEliminateAvailableExternallyPass(Registry); - initializeSampleProfileLoaderPass(Registry); + initializeEliminateAvailableExternallyLegacyPassPass(Registry); + initializeSampleProfileLoaderLegacyPassPass(Registry); initializeFunctionImportPassPass(Registry); + initializeWholeProgramDevirtPass(Registry); } void LLVMInitializeIPO(LLVMPassRegistryRef R) { @@ -72,7 +74,7 @@ void LLVMAddDeadArgEliminationPass(LLVMPassManagerRef PM) { } void LLVMAddFunctionAttrsPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createPostOrderFunctionAttrsPass()); + unwrap(PM)->add(createPostOrderFunctionAttrsLegacyPass()); } void LLVMAddFunctionInliningPass(LLVMPassManagerRef PM) { @@ -104,10 +106,10 @@ void LLVMAddIPSCCPPass(LLVMPassManagerRef PM) { } void LLVMAddInternalizePass(LLVMPassManagerRef PM, unsigned AllButMain) { - std::vector<const char *> Export; - if (AllButMain) - Export.push_back("main"); - unwrap(PM)->add(createInternalizePass(Export)); + auto PreserveMain = [=](const GlobalValue &GV) { + return AllButMain && GV.getName() == "main"; + }; + unwrap(PM)->add(createInternalizePass(PreserveMain)); } void LLVMAddStripDeadPrototypesPass(LLVMPassManagerRef PM) { diff --git a/lib/Transforms/IPO/InferFunctionAttrs.cpp b/lib/Transforms/IPO/InferFunctionAttrs.cpp index 4295a7595c29..ab2d2bd8b02a 100644 --- 
a/lib/Transforms/IPO/InferFunctionAttrs.cpp +++ b/lib/Transforms/IPO/InferFunctionAttrs.cpp @@ -8,7 +8,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/InferFunctionAttrs.h" -#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/IR/Function.h" @@ -16,937 +15,27 @@ #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" using namespace llvm; #define DEBUG_TYPE "inferattrs" -STATISTIC(NumReadNone, "Number of functions inferred as readnone"); -STATISTIC(NumReadOnly, "Number of functions inferred as readonly"); -STATISTIC(NumArgMemOnly, "Number of functions inferred as argmemonly"); -STATISTIC(NumNoUnwind, "Number of functions inferred as nounwind"); -STATISTIC(NumNoCapture, "Number of arguments inferred as nocapture"); -STATISTIC(NumReadOnlyArg, "Number of arguments inferred as readonly"); -STATISTIC(NumNoAlias, "Number of function returns inferred as noalias"); -STATISTIC(NumNonNull, "Number of function returns inferred as nonnull returns"); - -static bool setDoesNotAccessMemory(Function &F) { - if (F.doesNotAccessMemory()) - return false; - F.setDoesNotAccessMemory(); - ++NumReadNone; - return true; -} - -static bool setOnlyReadsMemory(Function &F) { - if (F.onlyReadsMemory()) - return false; - F.setOnlyReadsMemory(); - ++NumReadOnly; - return true; -} - -static bool setOnlyAccessesArgMemory(Function &F) { - if (F.onlyAccessesArgMemory()) - return false; - F.setOnlyAccessesArgMemory (); - ++NumArgMemOnly; - return true; -} - - -static bool setDoesNotThrow(Function &F) { - if (F.doesNotThrow()) - return false; - F.setDoesNotThrow(); - ++NumNoUnwind; - return true; -} - -static bool setDoesNotCapture(Function &F, unsigned n) { - if (F.doesNotCapture(n)) - return false; - F.setDoesNotCapture(n); - ++NumNoCapture; - return true; -} - -static bool setOnlyReadsMemory(Function &F, unsigned n) { - if (F.onlyReadsMemory(n)) - return false; - F.setOnlyReadsMemory(n); - ++NumReadOnlyArg; - return true; -} - -static bool setDoesNotAlias(Function &F, unsigned n) { - if (F.doesNotAlias(n)) - return false; - F.setDoesNotAlias(n); - ++NumNoAlias; - return true; -} - -static bool setNonNull(Function &F, unsigned n) { - assert((n != AttributeSet::ReturnIndex || - F.getReturnType()->isPointerTy()) && - "nonnull applies only to pointers"); - if (F.getAttributes().hasAttribute(n, Attribute::NonNull)) - return false; - F.addAttribute(n, Attribute::NonNull); - ++NumNonNull; - return true; -} - -/// Analyze the name and prototype of the given function and set any applicable -/// attributes. -/// -/// Returns true if any attributes were set and false otherwise. 
-static bool inferPrototypeAttributes(Function &F, - const TargetLibraryInfo &TLI) { - if (F.hasFnAttribute(Attribute::OptimizeNone)) - return false; - - FunctionType *FTy = F.getFunctionType(); - LibFunc::Func TheLibFunc; - if (!(TLI.getLibFunc(F.getName(), TheLibFunc) && TLI.has(TheLibFunc))) - return false; - - bool Changed = false; - switch (TheLibFunc) { - case LibFunc::strlen: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::strchr: - case LibFunc::strrchr: - if (FTy->getNumParams() != 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isIntegerTy()) - return false; - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotThrow(F); - return Changed; - case LibFunc::strtol: - case LibFunc::strtod: - case LibFunc::strtof: - case LibFunc::strtoul: - case LibFunc::strtoll: - case LibFunc::strtold: - case LibFunc::strtoull: - if (FTy->getNumParams() < 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::strcpy: - case LibFunc::stpcpy: - case LibFunc::strcat: - case LibFunc::strncat: - case LibFunc::strncpy: - case LibFunc::stpncpy: - if (FTy->getNumParams() < 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::strxfrm: - if (FTy->getNumParams() != 3 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::strcmp: // 0,1 - case LibFunc::strspn: // 0,1 - case LibFunc::strncmp: // 0,1 - case LibFunc::strcspn: // 0,1 - case LibFunc::strcoll: // 0,1 - case LibFunc::strcasecmp: // 0,1 - case LibFunc::strncasecmp: // - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::strstr: - case LibFunc::strpbrk: - if (FTy->getNumParams() != 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::strtok: - case LibFunc::strtok_r: - if (FTy->getNumParams() < 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::scanf: - if (FTy->getNumParams() < 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::setbuf: - case LibFunc::setvbuf: - if (FTy->getNumParams() < 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::strdup: - case LibFunc::strndup: - if (FTy->getNumParams() < 1 || 
!FTy->getReturnType()->isPointerTy() || - !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::stat: - case LibFunc::statvfs: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::sscanf: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::sprintf: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::snprintf: - if (FTy->getNumParams() != 3 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(2)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 3); - Changed |= setOnlyReadsMemory(F, 3); - return Changed; - case LibFunc::setitimer: - if (FTy->getNumParams() != 3 || !FTy->getParamType(1)->isPointerTy() || - !FTy->getParamType(2)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setDoesNotCapture(F, 3); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::system: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - // May throw; "system" is a valid pthread cancellation point. 
- Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::malloc: - if (FTy->getNumParams() != 1 || !FTy->getReturnType()->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - return Changed; - case LibFunc::memcmp: - if (FTy->getNumParams() != 3 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::memchr: - case LibFunc::memrchr: - if (FTy->getNumParams() != 3) - return false; - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotThrow(F); - return Changed; - case LibFunc::modf: - case LibFunc::modff: - case LibFunc::modfl: - if (FTy->getNumParams() < 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::memcpy: - case LibFunc::memccpy: - case LibFunc::memmove: - if (FTy->getNumParams() < 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::memalign: - if (!FTy->getReturnType()->isPointerTy()) - return false; - Changed |= setDoesNotAlias(F, 0); - return Changed; - case LibFunc::mkdir: - if (FTy->getNumParams() == 0 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::mktime: - if (FTy->getNumParams() == 0 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::realloc: - if (FTy->getNumParams() != 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getReturnType()->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::read: - if (FTy->getNumParams() != 3 || !FTy->getParamType(1)->isPointerTy()) - return false; - // May throw; "read" is a valid pthread cancellation point. 
- Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::rewind: - if (FTy->getNumParams() < 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::rmdir: - case LibFunc::remove: - case LibFunc::realpath: - if (FTy->getNumParams() < 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::rename: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::readlink: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::write: - if (FTy->getNumParams() != 3 || !FTy->getParamType(1)->isPointerTy()) - return false; - // May throw; "write" is a valid pthread cancellation point. - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::bcopy: - if (FTy->getNumParams() != 3 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::bcmp: - if (FTy->getNumParams() != 3 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::bzero: - if (FTy->getNumParams() != 2 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::calloc: - if (FTy->getNumParams() != 2 || !FTy->getReturnType()->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - return Changed; - case LibFunc::chmod: - case LibFunc::chown: - if (FTy->getNumParams() == 0 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::ctermid: - case LibFunc::clearerr: - case LibFunc::closedir: - if (FTy->getNumParams() == 0 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::atoi: - case LibFunc::atol: - case LibFunc::atof: - case LibFunc::atoll: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::access: - if (FTy->getNumParams() != 2 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= 
setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::fopen: - if (FTy->getNumParams() != 2 || !FTy->getReturnType()->isPointerTy() || - !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::fdopen: - if (FTy->getNumParams() != 2 || !FTy->getReturnType()->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::feof: - case LibFunc::free: - case LibFunc::fseek: - case LibFunc::ftell: - case LibFunc::fgetc: - case LibFunc::fseeko: - case LibFunc::ftello: - case LibFunc::fileno: - case LibFunc::fflush: - case LibFunc::fclose: - case LibFunc::fsetpos: - case LibFunc::flockfile: - case LibFunc::funlockfile: - case LibFunc::ftrylockfile: - if (FTy->getNumParams() == 0 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::ferror: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F); - return Changed; - case LibFunc::fputc: - case LibFunc::fstat: - case LibFunc::frexp: - case LibFunc::frexpf: - case LibFunc::frexpl: - case LibFunc::fstatvfs: - if (FTy->getNumParams() != 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::fgets: - if (FTy->getNumParams() != 3 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(2)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 3); - return Changed; - case LibFunc::fread: - if (FTy->getNumParams() != 4 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(3)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 4); - return Changed; - case LibFunc::fwrite: - if (FTy->getNumParams() != 4 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(3)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 4); - return Changed; - case LibFunc::fputs: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::fscanf: - case LibFunc::fprintf: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::fgetpos: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= 
setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::getc: - case LibFunc::getlogin_r: - case LibFunc::getc_unlocked: - if (FTy->getNumParams() == 0 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::getenv: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setOnlyReadsMemory(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::gets: - case LibFunc::getchar: - Changed |= setDoesNotThrow(F); - return Changed; - case LibFunc::getitimer: - if (FTy->getNumParams() != 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::getpwnam: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::ungetc: - if (FTy->getNumParams() != 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::uname: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::unlink: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::unsetenv: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::utime: - case LibFunc::utimes: - if (FTy->getNumParams() != 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::putc: - if (FTy->getNumParams() != 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::puts: - case LibFunc::printf: - case LibFunc::perror: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::pread: - if (FTy->getNumParams() != 4 || !FTy->getParamType(1)->isPointerTy()) - return false; - // May throw; "pread" is a valid pthread cancellation point. - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::pwrite: - if (FTy->getNumParams() != 4 || !FTy->getParamType(1)->isPointerTy()) - return false; - // May throw; "pwrite" is a valid pthread cancellation point. 
- Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::putchar: - Changed |= setDoesNotThrow(F); - return Changed; - case LibFunc::popen: - if (FTy->getNumParams() != 2 || !FTy->getReturnType()->isPointerTy() || - !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::pclose: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::vscanf: - if (FTy->getNumParams() != 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::vsscanf: - if (FTy->getNumParams() != 3 || !FTy->getParamType(1)->isPointerTy() || - !FTy->getParamType(2)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::vfscanf: - if (FTy->getNumParams() != 3 || !FTy->getParamType(1)->isPointerTy() || - !FTy->getParamType(2)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::valloc: - if (!FTy->getReturnType()->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - return Changed; - case LibFunc::vprintf: - if (FTy->getNumParams() != 2 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::vfprintf: - case LibFunc::vsprintf: - if (FTy->getNumParams() != 3 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::vsnprintf: - if (FTy->getNumParams() != 4 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(2)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 3); - Changed |= setOnlyReadsMemory(F, 3); - return Changed; - case LibFunc::open: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy()) - return false; - // May throw; "open" is a valid pthread cancellation point. 
- Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::opendir: - if (FTy->getNumParams() != 1 || !FTy->getReturnType()->isPointerTy() || - !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::tmpfile: - if (!FTy->getReturnType()->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - return Changed; - case LibFunc::times: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::htonl: - case LibFunc::htons: - case LibFunc::ntohl: - case LibFunc::ntohs: - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAccessMemory(F); - return Changed; - case LibFunc::lstat: - if (FTy->getNumParams() != 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::lchown: - if (FTy->getNumParams() != 3 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::qsort: - if (FTy->getNumParams() != 4 || !FTy->getParamType(3)->isPointerTy()) - return false; - // May throw; places call through function pointer. - Changed |= setDoesNotCapture(F, 4); - return Changed; - case LibFunc::dunder_strdup: - case LibFunc::dunder_strndup: - if (FTy->getNumParams() < 1 || !FTy->getReturnType()->isPointerTy() || - !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::dunder_strtok_r: - if (FTy->getNumParams() != 3 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::under_IO_getc: - if (FTy->getNumParams() != 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::under_IO_putc: - if (FTy->getNumParams() != 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::dunder_isoc99_scanf: - if (FTy->getNumParams() < 1 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::stat64: - case LibFunc::lstat64: - case LibFunc::statvfs64: - if (FTy->getNumParams() < 1 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::dunder_isoc99_sscanf: - if (FTy->getNumParams() < 1 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return 
false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::fopen64: - if (FTy->getNumParams() != 2 || !FTy->getReturnType()->isPointerTy() || - !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - Changed |= setOnlyReadsMemory(F, 1); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - case LibFunc::fseeko64: - case LibFunc::ftello64: - if (FTy->getNumParams() == 0 || !FTy->getParamType(0)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - return Changed; - case LibFunc::tmpfile64: - if (!FTy->getReturnType()->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotAlias(F, 0); - return Changed; - case LibFunc::fstat64: - case LibFunc::fstatvfs64: - if (FTy->getNumParams() != 2 || !FTy->getParamType(1)->isPointerTy()) - return false; - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 2); - return Changed; - case LibFunc::open64: - if (FTy->getNumParams() < 2 || !FTy->getParamType(0)->isPointerTy()) - return false; - // May throw; "open" is a valid pthread cancellation point. - Changed |= setDoesNotCapture(F, 1); - Changed |= setOnlyReadsMemory(F, 1); - return Changed; - case LibFunc::gettimeofday: - if (FTy->getNumParams() != 2 || !FTy->getParamType(0)->isPointerTy() || - !FTy->getParamType(1)->isPointerTy()) - return false; - // Currently some platforms have the restrict keyword on the arguments to - // gettimeofday. To be conservative, do not add noalias to gettimeofday's - // arguments. - Changed |= setDoesNotThrow(F); - Changed |= setDoesNotCapture(F, 1); - Changed |= setDoesNotCapture(F, 2); - return Changed; - - case LibFunc::Znwj: // new(unsigned int) - case LibFunc::Znwm: // new(unsigned long) - case LibFunc::Znaj: // new[](unsigned int) - case LibFunc::Znam: // new[](unsigned long) - case LibFunc::msvc_new_int: // new(unsigned int) - case LibFunc::msvc_new_longlong: // new(unsigned long long) - case LibFunc::msvc_new_array_int: // new[](unsigned int) - case LibFunc::msvc_new_array_longlong: // new[](unsigned long long) - if (FTy->getNumParams() != 1) - return false; - // Operator new always returns a nonnull noalias pointer - Changed |= setNonNull(F, AttributeSet::ReturnIndex); - Changed |= setDoesNotAlias(F, AttributeSet::ReturnIndex); - return Changed; - - //TODO: add LibFunc entries for: - //case LibFunc::memset_pattern4: - //case LibFunc::memset_pattern8: - case LibFunc::memset_pattern16: - if (FTy->isVarArg() || FTy->getNumParams() != 3 || - !isa<PointerType>(FTy->getParamType(0)) || - !isa<PointerType>(FTy->getParamType(1)) || - !isa<IntegerType>(FTy->getParamType(2))) - return false; - - Changed |= setOnlyAccessesArgMemory(F); - Changed |= setOnlyReadsMemory(F, 2); - return Changed; - - default: - // FIXME: It'd be really nice to cover all the library functions we're - // aware of here. - return false; - } -} - static bool inferAllPrototypeAttributes(Module &M, const TargetLibraryInfo &TLI) { bool Changed = false; for (Function &F : M.functions()) - // We only infer things using the prototype if the definition isn't around - // to analyze directly. 
- if (F.isDeclaration()) - Changed |= inferPrototypeAttributes(F, TLI); + // We only infer things using the prototype and the name; we don't need + // definitions. + if (F.isDeclaration() && !F.hasFnAttribute((Attribute::OptimizeNone))) + Changed |= inferLibFuncAttributes(F, TLI); return Changed; } PreservedAnalyses InferFunctionAttrsPass::run(Module &M, - AnalysisManager<Module> *AM) { - auto &TLI = AM->getResult<TargetLibraryAnalysis>(M); + AnalysisManager<Module> &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); if (!inferAllPrototypeAttributes(M, TLI)) // If we didn't infer anything, preserve all analyses. @@ -970,6 +59,9 @@ struct InferFunctionAttrsLegacyPass : public ModulePass { } bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); return inferAllPrototypeAttributes(M, TLI); } diff --git a/lib/Transforms/IPO/InlineAlways.cpp b/lib/Transforms/IPO/InlineAlways.cpp index 1704bfea0b86..cb1ab95ec2af 100644 --- a/lib/Transforms/IPO/InlineAlways.cpp +++ b/lib/Transforms/IPO/InlineAlways.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/CallingConv.h" @@ -37,16 +38,17 @@ namespace { class AlwaysInliner : public Inliner { public: - // Use extremely low threshold. - AlwaysInliner() : Inliner(ID, -2000000000, /*InsertLifetime*/ true) { + AlwaysInliner() : Inliner(ID, /*InsertLifetime*/ true) { initializeAlwaysInlinerPass(*PassRegistry::getPassRegistry()); } - AlwaysInliner(bool InsertLifetime) - : Inliner(ID, -2000000000, InsertLifetime) { + AlwaysInliner(bool InsertLifetime) : Inliner(ID, InsertLifetime) { initializeAlwaysInlinerPass(*PassRegistry::getPassRegistry()); } + /// Main run interface method. We override here to avoid calling skipSCC(). + bool runOnSCC(CallGraphSCC &SCC) override { return inlineCalls(SCC); } + static char ID; // Pass identification, replacement for typeid InlineCost getInlineCost(CallSite CS) override; @@ -64,6 +66,7 @@ INITIALIZE_PASS_BEGIN(AlwaysInliner, "always-inline", "Inliner for always_inline functions", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(AlwaysInliner, "always-inline", "Inliner for always_inline functions", false, false) diff --git a/lib/Transforms/IPO/InlineSimple.cpp b/lib/Transforms/IPO/InlineSimple.cpp index 45609f891ed8..2aa650bd219d 100644 --- a/lib/Transforms/IPO/InlineSimple.cpp +++ b/lib/Transforms/IPO/InlineSimple.cpp @@ -14,6 +14,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CallSite.h" @@ -38,14 +39,20 @@ namespace { /// inliner pass and the always inliner pass. The two passes use different cost /// analyses to determine when to inline. class SimpleInliner : public Inliner { + // This field is populated based on one of the following: + // * optimization or size-optimization levels, + // * the --inline-threshold flag, or + // * a user specified value. 
+ int DefaultThreshold; public: - SimpleInliner() : Inliner(ID) { + SimpleInliner() + : Inliner(ID), DefaultThreshold(llvm::getDefaultInlineThreshold()) { initializeSimpleInlinerPass(*PassRegistry::getPassRegistry()); } - SimpleInliner(int Threshold) - : Inliner(ID, Threshold, /*InsertLifetime*/ true) { + explicit SimpleInliner(int Threshold) + : Inliner(ID), DefaultThreshold(Threshold) { initializeSimpleInlinerPass(*PassRegistry::getPassRegistry()); } @@ -54,7 +61,7 @@ public: InlineCost getInlineCost(CallSite CS) override { Function *Callee = CS.getCalledFunction(); TargetTransformInfo &TTI = TTIWP->getTTI(*Callee); - return llvm::getInlineCost(CS, getInlineThreshold(CS), TTI, ACT); + return llvm::getInlineCost(CS, DefaultThreshold, TTI, ACT, PSI); } bool runOnSCC(CallGraphSCC &SCC) override; @@ -64,17 +71,6 @@ private: TargetTransformInfoWrapperPass *TTIWP; }; -static int computeThresholdFromOptLevels(unsigned OptLevel, - unsigned SizeOptLevel) { - if (OptLevel > 2) - return 275; - if (SizeOptLevel == 1) // -Os - return 75; - if (SizeOptLevel == 2) // -Oz - return 25; - return 225; -} - } // end anonymous namespace char SimpleInliner::ID = 0; @@ -82,6 +78,7 @@ INITIALIZE_PASS_BEGIN(SimpleInliner, "inline", "Function Integration/Inlining", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(SimpleInliner, "inline", @@ -96,7 +93,7 @@ Pass *llvm::createFunctionInliningPass(int Threshold) { Pass *llvm::createFunctionInliningPass(unsigned OptLevel, unsigned SizeOptLevel) { return new SimpleInliner( - computeThresholdFromOptLevels(OptLevel, SizeOptLevel)); + llvm::computeThresholdFromOptLevels(OptLevel, SizeOptLevel)); } bool SimpleInliner::runOnSCC(CallGraphSCC &SCC) { diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp index bbe5f8761d5f..79535ca49780 100644 --- a/lib/Transforms/IPO/Inliner.cpp +++ b/lib/Transforms/IPO/Inliner.cpp @@ -13,7 +13,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO/InlinerPass.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -21,6 +20,7 @@ #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DataLayout.h" @@ -28,9 +28,9 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO/InlinerPass.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -47,40 +47,19 @@ STATISTIC(NumMergedAllocas, "Number of allocas merged together"); // if those would be more profitable and blocked inline steps. 
STATISTIC(NumCallerCallersAnalyzed, "Number of caller-callers analyzed"); -static cl::opt<int> -InlineLimit("inline-threshold", cl::Hidden, cl::init(225), cl::ZeroOrMore, - cl::desc("Control the amount of inlining to perform (default = 225)")); - -static cl::opt<int> -HintThreshold("inlinehint-threshold", cl::Hidden, cl::init(325), - cl::desc("Threshold for inlining functions with inline hint")); - -// We instroduce this threshold to help performance of instrumentation based -// PGO before we actually hook up inliner with analysis passes such as BPI and -// BFI. -static cl::opt<int> -ColdThreshold("inlinecold-threshold", cl::Hidden, cl::init(225), - cl::desc("Threshold for inlining functions with cold attribute")); - -// Threshold to use when optsize is specified (and there is no -inline-limit). -const int OptSizeThreshold = 75; +Inliner::Inliner(char &ID) : CallGraphSCCPass(ID), InsertLifetime(true) {} -Inliner::Inliner(char &ID) - : CallGraphSCCPass(ID), InlineThreshold(InlineLimit), InsertLifetime(true) { -} - -Inliner::Inliner(char &ID, int Threshold, bool InsertLifetime) - : CallGraphSCCPass(ID), - InlineThreshold(InlineLimit.getNumOccurrences() > 0 ? InlineLimit - : Threshold), - InsertLifetime(InsertLifetime) {} +Inliner::Inliner(char &ID, bool InsertLifetime) + : CallGraphSCCPass(ID), InsertLifetime(InsertLifetime) {} /// For this class, we declare that we require and preserve the call graph. /// If the derived class implements this method, it should /// always explicitly call the implementation here. void Inliner::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + getAAResultsAnalysisUsage(AU); CallGraphSCCPass::getAnalysisUsage(AU); } @@ -243,67 +222,6 @@ static bool InlineCallIfPossible(Pass &P, CallSite CS, InlineFunctionInfo &IFI, return true; } -unsigned Inliner::getInlineThreshold(CallSite CS) const { - int Threshold = InlineThreshold; // -inline-threshold or else selected by - // overall opt level - - // If -inline-threshold is not given, listen to the optsize attribute when it - // would decrease the threshold. - Function *Caller = CS.getCaller(); - bool OptSize = Caller && !Caller->isDeclaration() && - // FIXME: Use Function::optForSize(). - Caller->hasFnAttribute(Attribute::OptimizeForSize); - if (!(InlineLimit.getNumOccurrences() > 0) && OptSize && - OptSizeThreshold < Threshold) - Threshold = OptSizeThreshold; - - Function *Callee = CS.getCalledFunction(); - if (!Callee || Callee->isDeclaration()) - return Threshold; - - // If profile information is available, use that to adjust threshold of hot - // and cold functions. - // FIXME: The heuristic used below for determining hotness and coldness are - // based on preliminary SPEC tuning and may not be optimal. Replace this with - // a well-tuned heuristic based on *callsite* hotness and not callee hotness. - uint64_t FunctionCount = 0, MaxFunctionCount = 0; - bool HasPGOCounts = false; - if (Callee->getEntryCount() && - Callee->getParent()->getMaximumFunctionCount()) { - HasPGOCounts = true; - FunctionCount = Callee->getEntryCount().getValue(); - MaxFunctionCount = - Callee->getParent()->getMaximumFunctionCount().getValue(); - } - - // Listen to the inlinehint attribute or profile based hotness information - // when it would increase the threshold and the caller does not need to - // minimize its size. 
- bool InlineHint = - Callee->hasFnAttribute(Attribute::InlineHint) || - (HasPGOCounts && - FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount)); - if (InlineHint && HintThreshold > Threshold && - !Caller->hasFnAttribute(Attribute::MinSize)) - Threshold = HintThreshold; - - // Listen to the cold attribute or profile based coldness information - // when it would decrease the threshold. - bool ColdCallee = - Callee->hasFnAttribute(Attribute::Cold) || - (HasPGOCounts && - FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount)); - // Command line argument for InlineLimit will override the default - // ColdThreshold. If we have -inline-threshold but no -inlinecold-threshold, - // do not use the default cold threshold even if it is smaller. - if ((InlineLimit.getNumOccurrences() == 0 || - ColdThreshold.getNumOccurrences() > 0) && ColdCallee && - ColdThreshold < Threshold) - Threshold = ColdThreshold; - - return Threshold; -} - static void emitAnalysis(CallSite CS, const Twine &Msg) { Function *Caller = CS.getCaller(); LLVMContext &Ctx = Caller->getContext(); @@ -311,6 +229,76 @@ static void emitAnalysis(CallSite CS, const Twine &Msg) { emitOptimizationRemarkAnalysis(Ctx, DEBUG_TYPE, *Caller, DLoc, Msg); } +bool Inliner::shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, + int &TotalSecondaryCost) { + + // For now we only handle local or inline functions. + if (!Caller->hasLocalLinkage() && !Caller->hasLinkOnceODRLinkage()) + return false; + // Try to detect the case where the current inlining candidate caller (call + // it B) is a static or linkonce-ODR function and is an inlining candidate + // elsewhere, and the current candidate callee (call it C) is large enough + // that inlining it into B would make B too big to inline later. In these + // circumstances it may be best not to inline C into B, but to inline B into + // its callers. + // + // This only applies to static and linkonce-ODR functions because those are + // expected to be available for inlining in the translation units where they + // are used. Thus we will always have the opportunity to make local inlining + // decisions. Importantly the linkonce-ODR linkage covers inline functions + // and templates in C++. + // + // FIXME: All of this logic should be sunk into getInlineCost. It relies on + // the internal implementation of the inline cost metrics rather than + // treating them as truly abstract units etc. + TotalSecondaryCost = 0; + // The candidate cost to be imposed upon the current function. + int CandidateCost = IC.getCost() - (InlineConstants::CallPenalty + 1); + // This bool tracks what happens if we do NOT inline C into B. + bool callerWillBeRemoved = Caller->hasLocalLinkage(); + // This bool tracks what happens if we DO inline C into B. + bool inliningPreventsSomeOuterInline = false; + for (User *U : Caller->users()) { + CallSite CS2(U); + + // If this isn't a call to Caller (it could be some other sort + // of reference) skip it. Such references will prevent the caller + // from being removed. + if (!CS2 || CS2.getCalledFunction() != Caller) { + callerWillBeRemoved = false; + continue; + } + + InlineCost IC2 = getInlineCost(CS2); + ++NumCallerCallersAnalyzed; + if (!IC2) { + callerWillBeRemoved = false; + continue; + } + if (IC2.isAlways()) + continue; + + // See if inlining or original callsite would erase the cost delta of + // this callsite. We subtract off the penalty for the call instruction, + // which we would be deleting. 
+ if (IC2.getCostDelta() <= CandidateCost) { + inliningPreventsSomeOuterInline = true; + TotalSecondaryCost += IC2.getCost(); + } + } + // If all outer calls to Caller would get inlined, the cost for the last + // one is set very low by getInlineCost, in anticipation that Caller will + // be removed entirely. We did not account for this above unless there + // is only one caller of Caller. + if (callerWillBeRemoved && !Caller->use_empty()) + TotalSecondaryCost += InlineConstants::LastCallToStaticBonus; + + if (inliningPreventsSomeOuterInline && TotalSecondaryCost < IC.getCost()) + return true; + + return false; +} + /// Return true if the inliner should attempt to inline at the given CallSite. bool Inliner::shouldInline(CallSite CS) { InlineCost IC = getInlineCost(CS); @@ -342,77 +330,17 @@ bool Inliner::shouldInline(CallSite CS) { Twine(IC.getCostDelta() + IC.getCost()) + ")"); return false; } - - // Try to detect the case where the current inlining candidate caller (call - // it B) is a static or linkonce-ODR function and is an inlining candidate - // elsewhere, and the current candidate callee (call it C) is large enough - // that inlining it into B would make B too big to inline later. In these - // circumstances it may be best not to inline C into B, but to inline B into - // its callers. - // - // This only applies to static and linkonce-ODR functions because those are - // expected to be available for inlining in the translation units where they - // are used. Thus we will always have the opportunity to make local inlining - // decisions. Importantly the linkonce-ODR linkage covers inline functions - // and templates in C++. - // - // FIXME: All of this logic should be sunk into getInlineCost. It relies on - // the internal implementation of the inline cost metrics rather than - // treating them as truly abstract units etc. - if (Caller->hasLocalLinkage() || Caller->hasLinkOnceODRLinkage()) { - int TotalSecondaryCost = 0; - // The candidate cost to be imposed upon the current function. - int CandidateCost = IC.getCost() - (InlineConstants::CallPenalty + 1); - // This bool tracks what happens if we do NOT inline C into B. - bool callerWillBeRemoved = Caller->hasLocalLinkage(); - // This bool tracks what happens if we DO inline C into B. - bool inliningPreventsSomeOuterInline = false; - for (User *U : Caller->users()) { - CallSite CS2(U); - - // If this isn't a call to Caller (it could be some other sort - // of reference) skip it. Such references will prevent the caller - // from being removed. - if (!CS2 || CS2.getCalledFunction() != Caller) { - callerWillBeRemoved = false; - continue; - } - InlineCost IC2 = getInlineCost(CS2); - ++NumCallerCallersAnalyzed; - if (!IC2) { - callerWillBeRemoved = false; - continue; - } - if (IC2.isAlways()) - continue; - - // See if inlining or original callsite would erase the cost delta of - // this callsite. We subtract off the penalty for the call instruction, - // which we would be deleting. - if (IC2.getCostDelta() <= CandidateCost) { - inliningPreventsSomeOuterInline = true; - TotalSecondaryCost += IC2.getCost(); - } - } - // If all outer calls to Caller would get inlined, the cost for the last - // one is set very low by getInlineCost, in anticipation that Caller will - // be removed entirely. We did not account for this above unless there - // is only one caller of Caller. 
- if (callerWillBeRemoved && !Caller->use_empty()) - TotalSecondaryCost += InlineConstants::LastCallToStaticBonus; - - if (inliningPreventsSomeOuterInline && TotalSecondaryCost < IC.getCost()) { - DEBUG(dbgs() << " NOT Inlining: " << *CS.getInstruction() << - " Cost = " << IC.getCost() << - ", outer Cost = " << TotalSecondaryCost << '\n'); - emitAnalysis( - CS, Twine("Not inlining. Cost of inlining " + - CS.getCalledFunction()->getName() + - " increases the cost of inlining " + - CS.getCaller()->getName() + " in other contexts")); - return false; - } + int TotalSecondaryCost = 0; + if (shouldBeDeferred(Caller, CS, IC, TotalSecondaryCost)) { + DEBUG(dbgs() << " NOT Inlining: " << *CS.getInstruction() + << " Cost = " << IC.getCost() + << ", outer Cost = " << TotalSecondaryCost << '\n'); + emitAnalysis(CS, Twine("Not inlining. Cost of inlining " + + CS.getCalledFunction()->getName() + + " increases the cost of inlining " + + CS.getCaller()->getName() + " in other contexts")); + return false; } DEBUG(dbgs() << " Inlining: cost=" << IC.getCost() @@ -440,8 +368,15 @@ static bool InlineHistoryIncludes(Function *F, int InlineHistoryID, } bool Inliner::runOnSCC(CallGraphSCC &SCC) { + if (skipSCC(SCC)) + return false; + return inlineCalls(SCC); +} + +bool Inliner::inlineCalls(CallGraphSCC &SCC) { CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); ACT = &getAnalysis<AssumptionCacheTracker>(); + PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(CG.getModule()); auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); SmallPtrSet<Function*, 8> SCCFunctions; diff --git a/lib/Transforms/IPO/Internalize.cpp b/lib/Transforms/IPO/Internalize.cpp index 21bb5d000bc7..8c5c6f77077c 100644 --- a/lib/Transforms/IPO/Internalize.cpp +++ b/lib/Transforms/IPO/Internalize.cpp @@ -8,8 +8,8 @@ //===----------------------------------------------------------------------===// // // This pass loops over all of the functions and variables in the input module. -// If the function or variable is not in the list of external names given to -// the pass it is marked as internal. +// If the function or variable does not need to be preserved according to the +// client supplied callback, it is marked as internal. // // This transformation would not be legal in a regular compilation, but it gets // extra information from the linker about what is safe. @@ -19,98 +19,77 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/Internalize.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/GlobalStatus.h" -#include "llvm/Transforms/Utils/ModuleUtils.h" #include <fstream> #include <set> using namespace llvm; #define DEBUG_TYPE "internalize" -STATISTIC(NumAliases , "Number of aliases internalized"); +STATISTIC(NumAliases, "Number of aliases internalized"); STATISTIC(NumFunctions, "Number of functions internalized"); -STATISTIC(NumGlobals , "Number of global vars internalized"); +STATISTIC(NumGlobals, "Number of global vars internalized"); // APIFile - A file which contains a list of symbols that should not be marked // external. 
static cl::opt<std::string> -APIFile("internalize-public-api-file", cl::value_desc("filename"), - cl::desc("A file containing list of symbol names to preserve")); + APIFile("internalize-public-api-file", cl::value_desc("filename"), + cl::desc("A file containing list of symbol names to preserve")); // APIList - A list of symbols that should not be marked internal. static cl::list<std::string> -APIList("internalize-public-api-list", cl::value_desc("list"), - cl::desc("A list of symbol names to preserve"), - cl::CommaSeparated); + APIList("internalize-public-api-list", cl::value_desc("list"), + cl::desc("A list of symbol names to preserve"), cl::CommaSeparated); namespace { - class InternalizePass : public ModulePass { - std::set<std::string> ExternalNames; - public: - static char ID; // Pass identification, replacement for typeid - explicit InternalizePass(); - explicit InternalizePass(ArrayRef<const char *> ExportList); - void LoadFile(const char *Filename); - bool maybeInternalize(GlobalValue &GV, - const std::set<const Comdat *> &ExternalComdats); - void checkComdatVisibility(GlobalValue &GV, - std::set<const Comdat *> &ExternalComdats); - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addPreserved<CallGraphWrapperPass>(); - } - }; -} // end anonymous namespace - -char InternalizePass::ID = 0; -INITIALIZE_PASS(InternalizePass, "internalize", - "Internalize Global Symbols", false, false) - -InternalizePass::InternalizePass() : ModulePass(ID) { - initializeInternalizePassPass(*PassRegistry::getPassRegistry()); - if (!APIFile.empty()) // If a filename is specified, use it. - LoadFile(APIFile.c_str()); - ExternalNames.insert(APIList.begin(), APIList.end()); -} - -InternalizePass::InternalizePass(ArrayRef<const char *> ExportList) - : ModulePass(ID) { - initializeInternalizePassPass(*PassRegistry::getPassRegistry()); - for(ArrayRef<const char *>::const_iterator itr = ExportList.begin(); - itr != ExportList.end(); itr++) { - ExternalNames.insert(*itr); +// Helper to load an API list to preserve from file and expose it as a functor +// for internalization. +class PreserveAPIList { +public: + PreserveAPIList() { + if (!APIFile.empty()) + LoadFile(APIFile); + ExternalNames.insert(APIList.begin(), APIList.end()); } -} -void InternalizePass::LoadFile(const char *Filename) { - // Load the APIFile... - std::ifstream In(Filename); - if (!In.good()) { - errs() << "WARNING: Internalize couldn't load file '" << Filename - << "'! Continuing as if it's empty.\n"; - return; // Just continue as if the file were empty + bool operator()(const GlobalValue &GV) { + return ExternalNames.count(GV.getName()); } - while (In) { - std::string Symbol; - In >> Symbol; - if (!Symbol.empty()) - ExternalNames.insert(Symbol); + +private: + // Contains the set of symbols loaded from file + StringSet<> ExternalNames; + + void LoadFile(StringRef Filename) { + // Load the APIFile... + std::ifstream In(Filename.data()); + if (!In.good()) { + errs() << "WARNING: Internalize couldn't load file '" << Filename + << "'! 
Continuing as if it's empty.\n"; + return; // Just continue as if the file were empty + } + while (In) { + std::string Symbol; + In >> Symbol; + if (!Symbol.empty()) + ExternalNames.insert(Symbol); + } } -} +}; +} // end anonymous namespace -static bool isExternallyVisible(const GlobalValue &GV, - const std::set<std::string> &ExternalNames) { +bool InternalizePass::shouldPreserveGV(const GlobalValue &GV) { // Function must be defined here if (GV.isDeclaration()) return true; @@ -123,15 +102,17 @@ static bool isExternallyVisible(const GlobalValue &GV, if (GV.hasDLLExportStorageClass()) return true; - // Marked to keep external? - if (!GV.hasLocalLinkage() && ExternalNames.count(GV.getName())) + // Already local, has nothing to do. + if (GV.hasLocalLinkage()) + return false; + + // Check some special cases + if (AlwaysPreserved.count(GV.getName())) return true; - return false; + return MustPreserveGV(GV); } -// Internalize GV if it is possible to do so, i.e. it is not externally visible -// and is not a member of an externally visible comdat. bool InternalizePass::maybeInternalize( GlobalValue &GV, const std::set<const Comdat *> &ExternalComdats) { if (Comdat *C = GV.getComdat()) { @@ -148,7 +129,7 @@ bool InternalizePass::maybeInternalize( if (GV.hasLocalLinkage()) return false; - if (isExternallyVisible(GV, ExternalNames)) + if (shouldPreserveGV(GV)) return false; } @@ -165,13 +146,12 @@ void InternalizePass::checkComdatVisibility( if (!C) return; - if (isExternallyVisible(GV, ExternalNames)) + if (shouldPreserveGV(GV)) ExternalComdats.insert(C); } -bool InternalizePass::runOnModule(Module &M) { - CallGraphWrapperPass *CGPass = getAnalysisIfAvailable<CallGraphWrapperPass>(); - CallGraph *CG = CGPass ? &CGPass->getCallGraph() : nullptr; +bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { + bool Changed = false; CallGraphNode *ExternalNode = CG ? CG->getExternalCallingNode() : nullptr; SmallPtrSet<GlobalValue *, 8> Used; @@ -198,13 +178,14 @@ bool InternalizePass::runOnModule(Module &M) { // conservative, we internalize symbols in llvm.compiler.used, but we // keep llvm.compiler.used so that the symbol is not deleted by llvm. for (GlobalValue *V : Used) { - ExternalNames.insert(V->getName()); + AlwaysPreserved.insert(V->getName()); } // Mark all functions not in the api as internal. for (Function &I : M) { if (!maybeInternalize(I, ExternalComdats)) continue; + Changed = true; if (ExternalNode) // Remove a callgraph edge from the external node to this function. @@ -217,53 +198,97 @@ bool InternalizePass::runOnModule(Module &M) { // Never internalize the llvm.used symbol. It is used to implement // attribute((used)). // FIXME: Shouldn't this just filter on llvm.metadata section?? - ExternalNames.insert("llvm.used"); - ExternalNames.insert("llvm.compiler.used"); + AlwaysPreserved.insert("llvm.used"); + AlwaysPreserved.insert("llvm.compiler.used"); // Never internalize anchors used by the machine module info, else the info // won't find them. (see MachineModuleInfo.) - ExternalNames.insert("llvm.global_ctors"); - ExternalNames.insert("llvm.global_dtors"); - ExternalNames.insert("llvm.global.annotations"); + AlwaysPreserved.insert("llvm.global_ctors"); + AlwaysPreserved.insert("llvm.global_dtors"); + AlwaysPreserved.insert("llvm.global.annotations"); // Never internalize symbols code-gen inserts. // FIXME: We should probably add this (and the __stack_chk_guard) via some // type of call-back in CodeGen. 
- ExternalNames.insert("__stack_chk_fail"); - ExternalNames.insert("__stack_chk_guard"); + AlwaysPreserved.insert("__stack_chk_fail"); + AlwaysPreserved.insert("__stack_chk_guard"); // Mark all global variables with initializers that are not in the api as // internal as well. - for (Module::global_iterator I = M.global_begin(), E = M.global_end(); - I != E; ++I) { - if (!maybeInternalize(*I, ExternalComdats)) + for (auto &GV : M.globals()) { + if (!maybeInternalize(GV, ExternalComdats)) continue; + Changed = true; ++NumGlobals; - DEBUG(dbgs() << "Internalized gvar " << I->getName() << "\n"); + DEBUG(dbgs() << "Internalized gvar " << GV.getName() << "\n"); } // Mark all aliases that are not in the api as internal as well. - for (Module::alias_iterator I = M.alias_begin(), E = M.alias_end(); - I != E; ++I) { - if (!maybeInternalize(*I, ExternalComdats)) + for (auto &GA : M.aliases()) { + if (!maybeInternalize(GA, ExternalComdats)) continue; + Changed = true; ++NumAliases; - DEBUG(dbgs() << "Internalized alias " << I->getName() << "\n"); + DEBUG(dbgs() << "Internalized alias " << GA.getName() << "\n"); } - // We do not keep track of whether this pass changed the module because - // it adds unnecessary complexity: - // 1) This pass will generally be near the start of the pass pipeline, so - // there will be no analyses to invalidate. - // 2) This pass will most likely end up changing the module and it isn't worth - // worrying about optimizing the case where the module is unchanged. - return true; + return Changed; } -ModulePass *llvm::createInternalizePass() { return new InternalizePass(); } +InternalizePass::InternalizePass() : MustPreserveGV(PreserveAPIList()) {} + +PreservedAnalyses InternalizePass::run(Module &M, AnalysisManager<Module> &AM) { + if (!internalizeModule(M, AM.getCachedResult<CallGraphAnalysis>(M))) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<CallGraphAnalysis>(); + return PA; +} + +namespace { +class InternalizeLegacyPass : public ModulePass { + // Client supplied callback to control wheter a symbol must be preserved. + std::function<bool(const GlobalValue &)> MustPreserveGV; + +public: + static char ID; // Pass identification, replacement for typeid + + InternalizeLegacyPass() : ModulePass(ID), MustPreserveGV(PreserveAPIList()) {} + + InternalizeLegacyPass(std::function<bool(const GlobalValue &)> MustPreserveGV) + : ModulePass(ID), MustPreserveGV(std::move(MustPreserveGV)) { + initializeInternalizeLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + + CallGraphWrapperPass *CGPass = + getAnalysisIfAvailable<CallGraphWrapperPass>(); + CallGraph *CG = CGPass ? 
&CGPass->getCallGraph() : nullptr; + return internalizeModule(M, MustPreserveGV, CG); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addPreserved<CallGraphWrapperPass>(); + } +}; +} + +char InternalizeLegacyPass::ID = 0; +INITIALIZE_PASS(InternalizeLegacyPass, "internalize", + "Internalize Global Symbols", false, false) + +ModulePass *llvm::createInternalizePass() { + return new InternalizeLegacyPass(); +} -ModulePass *llvm::createInternalizePass(ArrayRef<const char *> ExportList) { - return new InternalizePass(ExportList); +ModulePass *llvm::createInternalizePass( + std::function<bool(const GlobalValue &)> MustPreserveGV) { + return new InternalizeLegacyPass(std::move(MustPreserveGV)); } diff --git a/lib/Transforms/IPO/LLVMBuild.txt b/lib/Transforms/IPO/LLVMBuild.txt index b5410f5f7757..bc3df98d504c 100644 --- a/lib/Transforms/IPO/LLVMBuild.txt +++ b/lib/Transforms/IPO/LLVMBuild.txt @@ -20,4 +20,4 @@ type = Library name = IPO parent = Transforms library_name = ipo -required_libraries = Analysis Core InstCombine IRReader Linker Object ProfileData Scalar Support TransformUtils Vectorize +required_libraries = Analysis Core InstCombine IRReader Linker Object ProfileData Scalar Support TransformUtils Vectorize Instrumentation diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp index 3c6a7bb7a17a..f898c3b5a935 100644 --- a/lib/Transforms/IPO/LoopExtractor.cpp +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -81,7 +81,7 @@ INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single", Pass *llvm::createLoopExtractorPass() { return new LoopExtractor(); } bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &) { - if (skipOptnoneFunction(L)) + if (skipLoop(L)) return false; // Only visit top-level loops. @@ -249,6 +249,9 @@ void BlockExtractorPass::SplitLandingPadPreds(Function *F) { } bool BlockExtractorPass::runOnModule(Module &M) { + if (skipModule(M)) + return false; + std::set<BasicBlock*> TranslatedBlocksToNotExtract; for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) { BasicBlock *BB = BlocksToNotExtract[i]; @@ -272,15 +275,13 @@ bool BlockExtractorPass::runOnModule(Module &M) { std::string &FuncName = BlocksToNotExtractByName.back().first; std::string &BlockName = BlocksToNotExtractByName.back().second; - for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) { - Function &F = *FI; + for (Function &F : M) { if (F.getName() != FuncName) continue; - for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { - BasicBlock &BB = *BI; + for (BasicBlock &BB : F) { if (BB.getName() != BlockName) continue; - TranslatedBlocksToNotExtract.insert(&*BI); + TranslatedBlocksToNotExtract.insert(&BB); } } @@ -290,18 +291,18 @@ bool BlockExtractorPass::runOnModule(Module &M) { // Now that we know which blocks to not extract, figure out which ones we WANT // to extract. 
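The Internalize changes above drop the ArrayRef<const char *> export-list constructor in favour of a caller-supplied predicate. A minimal sketch of driving the legacy pass with such a predicate follows; it assumes the usual declaration of createInternalizePass in llvm/Transforms/IPO.h, and the "keep only main" policy is just an example.

#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO.h"

// Internalize everything except 'main'. The pass still applies its own
// exceptions (llvm.used, externally visible comdats, etc.) on top of this.
void internalizeAllButMain(llvm::Module &M) {
  llvm::legacy::PassManager PM;
  PM.add(llvm::createInternalizePass(
      [](const llvm::GlobalValue &GV) { return GV.getName() == "main"; }));
  PM.run(M);
}

Constructing the pass with no predicate keeps the old -internalize-public-api-file / -internalize-public-api-list behaviour, since the default constructor installs PreserveAPIList as the callback.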
std::vector<BasicBlock*> BlocksToExtract; - for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { - SplitLandingPadPreds(&*F); - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) - if (!TranslatedBlocksToNotExtract.count(&*BB)) - BlocksToExtract.push_back(&*BB); + for (Function &F : M) { + SplitLandingPadPreds(&F); + for (BasicBlock &BB : F) + if (!TranslatedBlocksToNotExtract.count(&BB)) + BlocksToExtract.push_back(&BB); } - for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i) { + for (BasicBlock *BlockToExtract : BlocksToExtract) { SmallVector<BasicBlock*, 2> BlocksToExtractVec; - BlocksToExtractVec.push_back(BlocksToExtract[i]); + BlocksToExtractVec.push_back(BlockToExtract); if (const InvokeInst *II = - dyn_cast<InvokeInst>(BlocksToExtract[i]->getTerminator())) + dyn_cast<InvokeInst>(BlockToExtract->getTerminator())) BlocksToExtractVec.push_back(II->getUnwindDest()); CodeExtractor(BlocksToExtractVec).extractCodeRegion(); } diff --git a/lib/Transforms/IPO/LowerBitSets.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp index 7b515745c312..36089f0a8801 100644 --- a/lib/Transforms/IPO/LowerBitSets.cpp +++ b/lib/Transforms/IPO/LowerTypeTests.cpp @@ -1,4 +1,4 @@ -//===-- LowerBitSets.cpp - Bitset lowering pass ---------------------------===// +//===-- LowerTypeTests.cpp - type metadata lowering pass ------------------===// // // The LLVM Compiler Infrastructure // @@ -7,12 +7,12 @@ // //===----------------------------------------------------------------------===// // -// This pass lowers bitset metadata and calls to the llvm.bitset.test intrinsic. -// See http://llvm.org/docs/LangRef.html#bitsets for more information. +// This pass lowers type metadata and calls to the llvm.type.test intrinsic. +// See http://llvm.org/docs/TypeMetadata.html for more information. // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO/LowerBitSets.h" +#include "llvm/Transforms/IPO/LowerTypeTests.h" #include "llvm/Transforms/IPO.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Statistic.h" @@ -33,17 +33,18 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; +using namespace lowertypetests; -#define DEBUG_TYPE "lowerbitsets" +#define DEBUG_TYPE "lowertypetests" STATISTIC(ByteArraySizeBits, "Byte array size in bits"); STATISTIC(ByteArraySizeBytes, "Byte array size in bytes"); STATISTIC(NumByteArraysCreated, "Number of byte arrays created"); -STATISTIC(NumBitSetCallsLowered, "Number of bitset calls lowered"); -STATISTIC(NumBitSetDisjointSets, "Number of disjoint sets of bitsets"); +STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered"); +STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers"); static cl::opt<bool> AvoidReuse( - "lowerbitsets-avoid-reuse", + "lowertypetests-avoid-reuse", cl::desc("Try to avoid reuse of byte array addresses using aliases"), cl::Hidden, cl::init(true)); @@ -203,10 +204,10 @@ struct ByteArrayInfo { Constant *Mask; }; -struct LowerBitSets : public ModulePass { +struct LowerTypeTests : public ModulePass { static char ID; - LowerBitSets() : ModulePass(ID) { - initializeLowerBitSetsPass(*PassRegistry::getPassRegistry()); + LowerTypeTests() : ModulePass(ID) { + initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry()); } Module *M; @@ -221,105 +222,68 @@ struct LowerBitSets : public ModulePass { IntegerType *Int64Ty; IntegerType *IntPtrTy; - // The llvm.bitsets named metadata. 
- NamedMDNode *BitSetNM; - - // Mapping from bitset identifiers to the call sites that test them. - DenseMap<Metadata *, std::vector<CallInst *>> BitSetTestCallSites; + // Mapping from type identifiers to the call sites that test them. + DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites; std::vector<ByteArrayInfo> ByteArrayInfos; BitSetInfo - buildBitSet(Metadata *BitSet, + buildBitSet(Metadata *TypeId, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); ByteArrayInfo *createByteArray(BitSetInfo &BSI); void allocateByteArrays(); Value *createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, ByteArrayInfo *&BAI, Value *BitOffset); - void lowerBitSetCalls(ArrayRef<Metadata *> BitSets, - Constant *CombinedGlobalAddr, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); + void + lowerTypeTestCalls(ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, + const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); Value * lowerBitSetCall(CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI, Constant *CombinedGlobal, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); - void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> BitSets, + void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalVariable *> Globals); unsigned getJumpTableEntrySize(); Type *getJumpTableEntryType(); Constant *createJumpTableEntry(GlobalObject *Src, Function *Dest, unsigned Distance); - void verifyBitSetMDNode(MDNode *Op); - void buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, + void verifyTypeMDNode(GlobalObject *GO, MDNode *Type); + void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, ArrayRef<Function *> Functions); - void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> BitSets, + void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalObject *> Globals); - bool buildBitSets(); - bool eraseBitSetMetadata(); - - bool doInitialization(Module &M) override; + bool lower(); bool runOnModule(Module &M) override; }; } // anonymous namespace -INITIALIZE_PASS_BEGIN(LowerBitSets, "lowerbitsets", - "Lower bitset metadata", false, false) -INITIALIZE_PASS_END(LowerBitSets, "lowerbitsets", - "Lower bitset metadata", false, false) -char LowerBitSets::ID = 0; - -ModulePass *llvm::createLowerBitSetsPass() { return new LowerBitSets; } - -bool LowerBitSets::doInitialization(Module &Mod) { - M = &Mod; - const DataLayout &DL = Mod.getDataLayout(); - - Triple TargetTriple(M->getTargetTriple()); - LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX(); - Arch = TargetTriple.getArch(); - ObjectFormat = TargetTriple.getObjectFormat(); +INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false, + false) +char LowerTypeTests::ID = 0; - Int1Ty = Type::getInt1Ty(M->getContext()); - Int8Ty = Type::getInt8Ty(M->getContext()); - Int32Ty = Type::getInt32Ty(M->getContext()); - Int32PtrTy = PointerType::getUnqual(Int32Ty); - Int64Ty = Type::getInt64Ty(M->getContext()); - IntPtrTy = DL.getIntPtrType(M->getContext(), 0); +ModulePass *llvm::createLowerTypeTestsPass() { return new LowerTypeTests; } - BitSetNM = M->getNamedMetadata("llvm.bitsets"); - - BitSetTestCallSites.clear(); - - return false; -} - -/// Build a bit set for BitSet using the object layouts in +/// Build a bit set for TypeId using the object layouts in /// GlobalLayout. 
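Before the rewritten buildBitSet below: the pass no longer walks a module-level llvm.bitsets node; each member global now carries its own !type attachments of the form {offset, type identifier}, read through LLVMContext::MD_type. An illustrative helper (not part of the patch) that gathers those pairs for one global:

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalObject.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include <cstdint>
#include <utility>

// Collect (byte offset, type identifier) pairs from a global's !type metadata.
static void collectTypeEntries(
    llvm::GlobalObject &GO,
    llvm::SmallVectorImpl<std::pair<uint64_t, llvm::Metadata *>> &Out) {
  llvm::SmallVector<llvm::MDNode *, 2> Types;
  GO.getMetadata(llvm::LLVMContext::MD_type, Types);
  for (llvm::MDNode *Type : Types) {
    // Operand 0 is the offset of the address point within the global,
    // operand 1 is the type identifier (usually an MDString).
    uint64_t Offset =
        llvm::cast<llvm::ConstantInt>(
            llvm::cast<llvm::ConstantAsMetadata>(Type->getOperand(0))
                ->getValue())
            ->getZExtValue();
    Out.push_back({Offset, Type->getOperand(1).get()});
  }
}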
-BitSetInfo LowerBitSets::buildBitSet( - Metadata *BitSet, +BitSetInfo LowerTypeTests::buildBitSet( + Metadata *TypeId, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { BitSetBuilder BSB; - // Compute the byte offset of each element of this bitset. - if (BitSetNM) { - for (MDNode *Op : BitSetNM->operands()) { - if (Op->getOperand(0) != BitSet || !Op->getOperand(1)) - continue; - Constant *OpConst = - cast<ConstantAsMetadata>(Op->getOperand(1))->getValue(); - if (auto GA = dyn_cast<GlobalAlias>(OpConst)) - OpConst = GA->getAliasee(); - auto OpGlobal = dyn_cast<GlobalObject>(OpConst); - if (!OpGlobal) + // Compute the byte offset of each address associated with this type + // identifier. + SmallVector<MDNode *, 2> Types; + for (auto &GlobalAndOffset : GlobalLayout) { + Types.clear(); + GlobalAndOffset.first->getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + if (Type->getOperand(1) != TypeId) continue; uint64_t Offset = - cast<ConstantInt>(cast<ConstantAsMetadata>(Op->getOperand(2)) + cast<ConstantInt>(cast<ConstantAsMetadata>(Type->getOperand(0)) ->getValue())->getZExtValue(); - - Offset += GlobalLayout.find(OpGlobal)->second; - - BSB.addOffset(Offset); + BSB.addOffset(GlobalAndOffset.second + Offset); } } @@ -341,7 +305,7 @@ static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits, return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0)); } -ByteArrayInfo *LowerBitSets::createByteArray(BitSetInfo &BSI) { +ByteArrayInfo *LowerTypeTests::createByteArray(BitSetInfo &BSI) { // Create globals to stand in for byte arrays and masks. These never actually // get initialized, we RAUW and erase them later in allocateByteArrays() once // we know the offset and mask to use. @@ -360,7 +324,7 @@ ByteArrayInfo *LowerBitSets::createByteArray(BitSetInfo &BSI) { return BAI; } -void LowerBitSets::allocateByteArrays() { +void LowerTypeTests::allocateByteArrays() { std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(), [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) { return BAI1.BitSize > BAI2.BitSize; @@ -413,8 +377,8 @@ void LowerBitSets::allocateByteArrays() { /// Build a test that bit BitOffset is set in BSI, where /// BitSetGlobal is a global containing the bits in BSI. -Value *LowerBitSets::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, - ByteArrayInfo *&BAI, Value *BitOffset) { +Value *LowerTypeTests::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, + ByteArrayInfo *&BAI, Value *BitOffset) { if (BSI.BitSize <= 64) { // If the bit set is sufficiently small, we can avoid a load by bit testing // a constant. @@ -454,9 +418,9 @@ Value *LowerBitSets::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, } } -/// Lower a llvm.bitset.test call to its implementation. Returns the value to +/// Lower a llvm.type.test call to its implementation. Returns the value to /// replace the call with. -Value *LowerBitSets::lowerBitSetCall( +Value *LowerTypeTests::lowerBitSetCall( CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI, Constant *CombinedGlobalIntAddr, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { @@ -524,10 +488,10 @@ Value *LowerBitSets::lowerBitSetCall( return P; } -/// Given a disjoint set of bitsets and globals, layout the globals, build the -/// bit sets and lower the llvm.bitset.test calls. 
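createBitSetTest above has two paths: bitsets of at most 64 bits are tested inline against a constant, larger ones go through the byte arrays allocated in allocateByteArrays. As a rough scalar model of the inline path only (illustrative; the real code emits IR via IRBuilder, and the computation of BitOffset from the tested pointer lives in lowerBitSetCall):

#include <cstdint>

// Scalar model of the small-bitset test: Bits is the constant bitset and
// BitOffset is the candidate's position within the combined global's layout.
// An out-of-range offset can never be a member; otherwise test the bit.
static bool smallBitSetContains(uint64_t Bits, uint64_t BitOffset) {
  if (BitOffset >= 64)
    return false;
  return (Bits >> BitOffset) & 1;
}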
-void LowerBitSets::buildBitSetsFromGlobalVariables( - ArrayRef<Metadata *> BitSets, ArrayRef<GlobalVariable *> Globals) { +/// Given a disjoint set of type identifiers and globals, lay out the globals, +/// build the bit sets and lower the llvm.type.test calls. +void LowerTypeTests::buildBitSetsFromGlobalVariables( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalVariable *> Globals) { // Build a new global with the combined contents of the referenced globals. // This global is a struct whose even-indexed elements contain the original // contents of the referenced globals and whose odd-indexed elements contain @@ -544,7 +508,7 @@ void LowerBitSets::buildBitSetsFromGlobalVariables( // Cap at 128 was found experimentally to have a good data/instruction // overhead tradeoff. if (Padding > 128) - Padding = RoundUpToAlignment(InitSize, 128) - InitSize; + Padding = alignTo(InitSize, 128) - InitSize; GlobalInits.push_back( ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding))); @@ -565,7 +529,7 @@ void LowerBitSets::buildBitSetsFromGlobalVariables( // Multiply by 2 to account for padding elements. GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2); - lowerBitSetCalls(BitSets, CombinedGlobal, GlobalLayout); + lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout); // Build aliases pointing to offsets into the combined global for each // global from which we built the combined global, and replace references @@ -591,19 +555,19 @@ void LowerBitSets::buildBitSetsFromGlobalVariables( } } -void LowerBitSets::lowerBitSetCalls( - ArrayRef<Metadata *> BitSets, Constant *CombinedGlobalAddr, +void LowerTypeTests::lowerTypeTestCalls( + ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { Constant *CombinedGlobalIntAddr = ConstantExpr::getPtrToInt(CombinedGlobalAddr, IntPtrTy); - // For each bitset in this disjoint set... - for (Metadata *BS : BitSets) { + // For each type identifier in this disjoint set... + for (Metadata *TypeId : TypeIds) { // Build the bitset. - BitSetInfo BSI = buildBitSet(BS, GlobalLayout); + BitSetInfo BSI = buildBitSet(TypeId, GlobalLayout); DEBUG({ - if (auto BSS = dyn_cast<MDString>(BS)) - dbgs() << BSS->getString() << ": "; + if (auto MDS = dyn_cast<MDString>(TypeId)) + dbgs() << MDS->getString() << ": "; else dbgs() << "<unnamed>: "; BSI.print(dbgs()); @@ -611,9 +575,9 @@ void LowerBitSets::lowerBitSetCalls( ByteArrayInfo *BAI = nullptr; - // Lower each call to llvm.bitset.test for this bitset. - for (CallInst *CI : BitSetTestCallSites[BS]) { - ++NumBitSetCallsLowered; + // Lower each call to llvm.type.test for this type identifier. 
+ for (CallInst *CI : TypeTestCallSites[TypeId]) { + ++NumTypeTestCallsLowered; Value *Lowered = lowerBitSetCall(CI, BSI, BAI, CombinedGlobalIntAddr, GlobalLayout); CI->replaceAllUsesWith(Lowered); @@ -622,39 +586,32 @@ void LowerBitSets::lowerBitSetCalls( } } -void LowerBitSets::verifyBitSetMDNode(MDNode *Op) { - if (Op->getNumOperands() != 3) +void LowerTypeTests::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { + if (Type->getNumOperands() != 2) report_fatal_error( - "All operands of llvm.bitsets metadata must have 3 elements"); - if (!Op->getOperand(1)) - return; - - auto OpConstMD = dyn_cast<ConstantAsMetadata>(Op->getOperand(1)); - if (!OpConstMD) - report_fatal_error("Bit set element must be a constant"); - auto OpGlobal = dyn_cast<GlobalObject>(OpConstMD->getValue()); - if (!OpGlobal) - return; + "All operands of type metadata must have 2 elements"); - if (OpGlobal->isThreadLocal()) + if (GO->isThreadLocal()) report_fatal_error("Bit set element may not be thread-local"); - if (OpGlobal->hasSection()) - report_fatal_error("Bit set element may not have an explicit section"); + if (isa<GlobalVariable>(GO) && GO->hasSection()) + report_fatal_error( + "A member of a type identifier may not have an explicit section"); - if (isa<GlobalVariable>(OpGlobal) && OpGlobal->isDeclarationForLinker()) - report_fatal_error("Bit set global var element must be a definition"); + if (isa<GlobalVariable>(GO) && GO->isDeclarationForLinker()) + report_fatal_error( + "A global var member of a type identifier must be a definition"); - auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Op->getOperand(2)); + auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0)); if (!OffsetConstMD) - report_fatal_error("Bit set element offset must be a constant"); + report_fatal_error("Type offset must be a constant"); auto OffsetInt = dyn_cast<ConstantInt>(OffsetConstMD->getValue()); if (!OffsetInt) - report_fatal_error("Bit set element offset must be an integer constant"); + report_fatal_error("Type offset must be an integer constant"); } static const unsigned kX86JumpTableEntrySize = 8; -unsigned LowerBitSets::getJumpTableEntrySize() { +unsigned LowerTypeTests::getJumpTableEntrySize() { if (Arch != Triple::x86 && Arch != Triple::x86_64) report_fatal_error("Unsupported architecture for jump tables"); @@ -665,8 +622,9 @@ unsigned LowerBitSets::getJumpTableEntrySize() { // consists of an instruction sequence containing a relative branch to Dest. The // constant will be laid out at address Src+(Len*Distance) where Len is the // target-specific jump table entry size. -Constant *LowerBitSets::createJumpTableEntry(GlobalObject *Src, Function *Dest, - unsigned Distance) { +Constant *LowerTypeTests::createJumpTableEntry(GlobalObject *Src, + Function *Dest, + unsigned Distance) { if (Arch != Triple::x86 && Arch != Triple::x86_64) report_fatal_error("Unsupported architecture for jump tables"); @@ -693,7 +651,7 @@ Constant *LowerBitSets::createJumpTableEntry(GlobalObject *Src, Function *Dest, return ConstantStruct::getAnon(Fields, /*Packed=*/true); } -Type *LowerBitSets::getJumpTableEntryType() { +Type *LowerTypeTests::getJumpTableEntryType() { if (Arch != Triple::x86 && Arch != Triple::x86_64) report_fatal_error("Unsupported architecture for jump tables"); @@ -702,10 +660,10 @@ Type *LowerBitSets::getJumpTableEntryType() { /*Packed=*/true); } -/// Given a disjoint set of bitsets and functions, build a jump table for the -/// functions, build the bit sets and lower the llvm.bitset.test calls. 
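For function members the functions' own entry points cannot be rearranged, so, per the hunks above, the pass builds a jump table of fixed-size entries, 8 bytes each on x86 and x86-64, with entry i laid out at Src + EntrySize * Distance. Conceptually the lowered membership check then becomes a bounds-plus-alignment test on the table, refined by the per-type bitset; a toy model of just the bounds and alignment part (illustrative only):

#include <cstdint>

// Toy model: with 8-byte entries, member i of the jump table lives at
// TableBase + 8 * i, so a candidate is a plausible member iff it lands inside
// the table on an entry boundary. The real lowering additionally consults the
// bitset built for the tested type identifier.
static bool mayBeJumpTableMember(uintptr_t Addr, uintptr_t TableBase,
                                 unsigned NumEntries) {
  const uintptr_t EntrySize = 8; // kX86JumpTableEntrySize in the patch
  uintptr_t Off = Addr - TableBase;
  return Off < EntrySize * NumEntries && Off % EntrySize == 0;
}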
-void LowerBitSets::buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, - ArrayRef<Function *> Functions) { +/// Given a disjoint set of type identifiers and functions, build a jump table +/// for the functions, build the bit sets and lower the llvm.type.test calls. +void LowerTypeTests::buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, + ArrayRef<Function *> Functions) { // Unlike the global bitset builder, the function bitset builder cannot // re-arrange functions in a particular order and base its calculations on the // layout of the functions' entry points, as we have no idea how large a @@ -719,8 +677,7 @@ void LowerBitSets::buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, // verification done inside the module. // // In more concrete terms, suppose we have three functions f, g, h which are - // members of a single bitset, and a function foo that returns their - // addresses: + // of the same type, and a function foo that returns their addresses: // // f: // mov 0, %eax @@ -803,7 +760,7 @@ void LowerBitSets::buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, JumpTable->setSection(ObjectFormat == Triple::MachO ? "__TEXT,__text,regular,pure_instructions" : ".text"); - lowerBitSetCalls(BitSets, JumpTable, GlobalLayout); + lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout); // Build aliases pointing to offsets into the jump table, and replace // references to the original functions with references to the aliases. @@ -838,39 +795,32 @@ void LowerBitSets::buildBitSetsFromFunctions(ArrayRef<Metadata *> BitSets, ConstantArray::get(JumpTableType, JumpTableEntries)); } -void LowerBitSets::buildBitSetsFromDisjointSet( - ArrayRef<Metadata *> BitSets, ArrayRef<GlobalObject *> Globals) { - llvm::DenseMap<Metadata *, uint64_t> BitSetIndices; - llvm::DenseMap<GlobalObject *, uint64_t> GlobalIndices; - for (unsigned I = 0; I != BitSets.size(); ++I) - BitSetIndices[BitSets[I]] = I; - for (unsigned I = 0; I != Globals.size(); ++I) - GlobalIndices[Globals[I]] = I; - - // For each bitset, build a set of indices that refer to globals referenced by - // the bitset. - std::vector<std::set<uint64_t>> BitSetMembers(BitSets.size()); - if (BitSetNM) { - for (MDNode *Op : BitSetNM->operands()) { - // Op = { bitset name, global, offset } - if (!Op->getOperand(1)) - continue; - auto I = BitSetIndices.find(Op->getOperand(0)); - if (I == BitSetIndices.end()) - continue; - - auto OpGlobal = dyn_cast<GlobalObject>( - cast<ConstantAsMetadata>(Op->getOperand(1))->getValue()); - if (!OpGlobal) - continue; - BitSetMembers[I->second].insert(GlobalIndices[OpGlobal]); +void LowerTypeTests::buildBitSetsFromDisjointSet( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalObject *> Globals) { + llvm::DenseMap<Metadata *, uint64_t> TypeIdIndices; + for (unsigned I = 0; I != TypeIds.size(); ++I) + TypeIdIndices[TypeIds[I]] = I; + + // For each type identifier, build a set of indices that refer to members of + // the type identifier. + std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size()); + SmallVector<MDNode *, 2> Types; + unsigned GlobalIndex = 0; + for (GlobalObject *GO : Globals) { + Types.clear(); + GO->getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + // Type = { offset, type identifier } + unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)]; + TypeMembers[TypeIdIndex].insert(GlobalIndex); } + GlobalIndex++; } // Order the sets of indices by size. The GlobalLayoutBuilder works best // when given small index sets first. 
std::stable_sort( - BitSetMembers.begin(), BitSetMembers.end(), + TypeMembers.begin(), TypeMembers.end(), [](const std::set<uint64_t> &O1, const std::set<uint64_t> &O2) { return O1.size() < O2.size(); }); @@ -879,7 +829,7 @@ void LowerBitSets::buildBitSetsFromDisjointSet( // fragments. The GlobalLayoutBuilder tries to lay out members of fragments as // close together as possible. GlobalLayoutBuilder GLB(Globals.size()); - for (auto &&MemSet : BitSetMembers) + for (auto &&MemSet : TypeMembers) GLB.addFragment(MemSet); // Build the bitsets from this disjoint set. @@ -891,13 +841,13 @@ void LowerBitSets::buildBitSetsFromDisjointSet( for (auto &&Offset : F) { auto GV = dyn_cast<GlobalVariable>(Globals[Offset]); if (!GV) - report_fatal_error( - "Bit set may not contain both global variables and functions"); + report_fatal_error("Type identifier may not contain both global " + "variables and functions"); *OGI++ = GV; } } - buildBitSetsFromGlobalVariables(BitSets, OrderedGVs); + buildBitSetsFromGlobalVariables(TypeIds, OrderedGVs); } else { // Build a vector of functions with the computed layout. std::vector<Function *> OrderedFns(Globals.size()); @@ -906,102 +856,97 @@ void LowerBitSets::buildBitSetsFromDisjointSet( for (auto &&Offset : F) { auto Fn = dyn_cast<Function>(Globals[Offset]); if (!Fn) - report_fatal_error( - "Bit set may not contain both global variables and functions"); + report_fatal_error("Type identifier may not contain both global " + "variables and functions"); *OFI++ = Fn; } } - buildBitSetsFromFunctions(BitSets, OrderedFns); + buildBitSetsFromFunctions(TypeIds, OrderedFns); } } -/// Lower all bit sets in this module. -bool LowerBitSets::buildBitSets() { - Function *BitSetTestFunc = - M->getFunction(Intrinsic::getName(Intrinsic::bitset_test)); - if (!BitSetTestFunc) +/// Lower all type tests in this module. +bool LowerTypeTests::lower() { + Function *TypeTestFunc = + M->getFunction(Intrinsic::getName(Intrinsic::type_test)); + if (!TypeTestFunc || TypeTestFunc->use_empty()) return false; - // Equivalence class set containing bitsets and the globals they reference. - // This is used to partition the set of bitsets in the module into disjoint - // sets. + // Equivalence class set containing type identifiers and the globals that + // reference them. This is used to partition the set of type identifiers in + // the module into disjoint sets. typedef EquivalenceClasses<PointerUnion<GlobalObject *, Metadata *>> GlobalClassesTy; GlobalClassesTy GlobalClasses; - // Verify the bitset metadata and build a mapping from bitset identifiers to - // their last observed index in BitSetNM. This will used later to - // deterministically order the list of bitset identifiers. - llvm::DenseMap<Metadata *, unsigned> BitSetIdIndices; - if (BitSetNM) { - for (unsigned I = 0, E = BitSetNM->getNumOperands(); I != E; ++I) { - MDNode *Op = BitSetNM->getOperand(I); - verifyBitSetMDNode(Op); - BitSetIdIndices[Op->getOperand(0)] = I; + // Verify the type metadata and build a mapping from type identifiers to their + // last observed index in the list of globals. This will be used later to + // deterministically order the list of type identifiers. 
+ llvm::DenseMap<Metadata *, unsigned> TypeIdIndices; + unsigned I = 0; + SmallVector<MDNode *, 2> Types; + for (GlobalObject &GO : M->global_objects()) { + Types.clear(); + GO.getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + verifyTypeMDNode(&GO, Type); + TypeIdIndices[cast<MDNode>(Type)->getOperand(1)] = ++I; } } - for (const Use &U : BitSetTestFunc->uses()) { + for (const Use &U : TypeTestFunc->uses()) { auto CI = cast<CallInst>(U.getUser()); auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); if (!BitSetMDVal) report_fatal_error( - "Second argument of llvm.bitset.test must be metadata"); + "Second argument of llvm.type.test must be metadata"); auto BitSet = BitSetMDVal->getMetadata(); - // Add the call site to the list of call sites for this bit set. We also use - // BitSetTestCallSites to keep track of whether we have seen this bit set - // before. If we have, we don't need to re-add the referenced globals to the - // equivalence class. - std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, - bool> Ins = - BitSetTestCallSites.insert( + // Add the call site to the list of call sites for this type identifier. We + // also use TypeTestCallSites to keep track of whether we have seen this + // type identifier before. If we have, we don't need to re-add the + // referenced globals to the equivalence class. + std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, bool> + Ins = TypeTestCallSites.insert( std::make_pair(BitSet, std::vector<CallInst *>())); Ins.first->second.push_back(CI); if (!Ins.second) continue; - // Add the bitset to the equivalence class. + // Add the type identifier to the equivalence class. GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet); GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI); - if (!BitSetNM) - continue; - - // Add the referenced globals to the bitset's equivalence class. - for (MDNode *Op : BitSetNM->operands()) { - if (Op->getOperand(0) != BitSet || !Op->getOperand(1)) - continue; - - auto OpGlobal = dyn_cast<GlobalObject>( - cast<ConstantAsMetadata>(Op->getOperand(1))->getValue()); - if (!OpGlobal) - continue; - - CurSet = GlobalClasses.unionSets( - CurSet, GlobalClasses.findLeader(GlobalClasses.insert(OpGlobal))); + // Add the referenced globals to the type identifier's equivalence class. + for (GlobalObject &GO : M->global_objects()) { + Types.clear(); + GO.getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) + if (Type->getOperand(1) == BitSet) + CurSet = GlobalClasses.unionSets( + CurSet, GlobalClasses.findLeader(GlobalClasses.insert(&GO))); } } if (GlobalClasses.empty()) return false; - // Build a list of disjoint sets ordered by their maximum BitSetNM index - // for determinism. + // Build a list of disjoint sets ordered by their maximum global index for + // determinism. 
std::vector<std::pair<GlobalClassesTy::iterator, unsigned>> Sets; for (GlobalClassesTy::iterator I = GlobalClasses.begin(), E = GlobalClasses.end(); I != E; ++I) { if (!I->isLeader()) continue; - ++NumBitSetDisjointSets; + ++NumTypeIdDisjointSets; unsigned MaxIndex = 0; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I); MI != GlobalClasses.member_end(); ++MI) { if ((*MI).is<Metadata *>()) - MaxIndex = std::max(MaxIndex, BitSetIdIndices[MI->get<Metadata *>()]); + MaxIndex = std::max(MaxIndex, TypeIdIndices[MI->get<Metadata *>()]); } Sets.emplace_back(I, MaxIndex); } @@ -1013,26 +958,26 @@ bool LowerBitSets::buildBitSets() { // For each disjoint set we found... for (const auto &S : Sets) { - // Build the list of bitsets in this disjoint set. - std::vector<Metadata *> BitSets; + // Build the list of type identifiers in this disjoint set. + std::vector<Metadata *> TypeIds; std::vector<GlobalObject *> Globals; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(S.first); MI != GlobalClasses.member_end(); ++MI) { if ((*MI).is<Metadata *>()) - BitSets.push_back(MI->get<Metadata *>()); + TypeIds.push_back(MI->get<Metadata *>()); else Globals.push_back(MI->get<GlobalObject *>()); } - // Order bitsets by BitSetNM index for determinism. This ordering is stable - // as there is a one-to-one mapping between metadata and indices. - std::sort(BitSets.begin(), BitSets.end(), [&](Metadata *M1, Metadata *M2) { - return BitSetIdIndices[M1] < BitSetIdIndices[M2]; + // Order type identifiers by global index for determinism. This ordering is + // stable as there is a one-to-one mapping between metadata and indices. + std::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) { + return TypeIdIndices[M1] < TypeIdIndices[M2]; }); - // Lower the bitsets in this disjoint set. - buildBitSetsFromDisjointSet(BitSets, Globals); + // Build bitsets for this disjoint set. + buildBitSetsFromDisjointSet(TypeIds, Globals); } allocateByteArrays(); @@ -1040,16 +985,36 @@ bool LowerBitSets::buildBitSets() { return true; } -bool LowerBitSets::eraseBitSetMetadata() { - if (!BitSetNM) - return false; +// Initialization helper shared by the old and the new PM. 
+static void init(LowerTypeTests *LTT, Module &M) { + LTT->M = &M; + const DataLayout &DL = M.getDataLayout(); + Triple TargetTriple(M.getTargetTriple()); + LTT->LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX(); + LTT->Arch = TargetTriple.getArch(); + LTT->ObjectFormat = TargetTriple.getObjectFormat(); + LTT->Int1Ty = Type::getInt1Ty(M.getContext()); + LTT->Int8Ty = Type::getInt8Ty(M.getContext()); + LTT->Int32Ty = Type::getInt32Ty(M.getContext()); + LTT->Int32PtrTy = PointerType::getUnqual(LTT->Int32Ty); + LTT->Int64Ty = Type::getInt64Ty(M.getContext()); + LTT->IntPtrTy = DL.getIntPtrType(M.getContext(), 0); + LTT->TypeTestCallSites.clear(); +} - M->eraseNamedMetadata(BitSetNM); - return true; +bool LowerTypeTests::runOnModule(Module &M) { + if (skipModule(M)) + return false; + init(this, M); + return lower(); } -bool LowerBitSets::runOnModule(Module &M) { - bool Changed = buildBitSets(); - Changed |= eraseBitSetMetadata(); - return Changed; +PreservedAnalyses LowerTypeTestsPass::run(Module &M, + AnalysisManager<Module> &AM) { + LowerTypeTests Impl; + init(&Impl, M); + bool Changed = Impl.lower(); + if (!Changed) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); } diff --git a/lib/Transforms/IPO/Makefile b/lib/Transforms/IPO/Makefile deleted file mode 100644 index 5c42374139aa..000000000000 --- a/lib/Transforms/IPO/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Transforms/IPO/Makefile -------------------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. -LIBRARYNAME = LLVMipo -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common - diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 8a209a18c540..fe653a75ddb5 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -89,13 +89,10 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/FoldingSet.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Hashing.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -112,6 +109,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include <vector> using namespace llvm; @@ -189,7 +187,7 @@ public: private: /// Test whether two basic blocks have equivalent behaviour. - int cmpBasicBlocks(const BasicBlock *BBL, const BasicBlock *BBR); + int cmpBasicBlocks(const BasicBlock *BBL, const BasicBlock *BBR) const; /// Constants comparison. /// Its analog to lexicographical comparison between hypothetical numbers @@ -293,11 +291,11 @@ private: /// look at their particular properties (bit-width for vectors, and /// address space for pointers). /// If these properties are equal - compare their contents. - int cmpConstants(const Constant *L, const Constant *R); + int cmpConstants(const Constant *L, const Constant *R) const; /// Compares two global values by number. Uses the GlobalNumbersState to /// identify the same gobals across function calls. 
- int cmpGlobalValues(GlobalValue *L, GlobalValue *R); + int cmpGlobalValues(GlobalValue *L, GlobalValue *R) const; /// Assign or look up previously assigned numbers for the two values, and /// return whether the numbers are equal. Numbers are assigned in the order @@ -317,11 +315,11 @@ private: /// then left value is greater. /// In another words, we compare serial numbers, for more details /// see comments for sn_mapL and sn_mapR. - int cmpValues(const Value *L, const Value *R); + int cmpValues(const Value *L, const Value *R) const; /// Compare two Instructions for equivalence, similar to - /// Instruction::isSameOperationAs but with modifications to the type - /// comparison. + /// Instruction::isSameOperationAs. + /// /// Stages are listed in "most significant stage first" order: /// On each stage below, we do comparison between some left and right /// operation parts. If parts are non-equal, we assign parts comparison @@ -339,8 +337,9 @@ private: /// For example, for Load it would be: /// 6.1.Load: volatile (as boolean flag) /// 6.2.Load: alignment (as integer numbers) - /// 6.3.Load: synch-scope (as integer numbers) - /// 6.4.Load: range metadata (as integer numbers) + /// 6.3.Load: ordering (as underlying enum class value) + /// 6.4.Load: synch-scope (as integer numbers) + /// 6.5.Load: range metadata (as integer ranges) /// On this stage its better to see the code, since its not more than 10-15 /// strings for particular instruction, and could change sometimes. int cmpOperations(const Instruction *L, const Instruction *R) const; @@ -353,8 +352,9 @@ private: /// 3. Pointer operand type (using cmpType method). /// 4. Number of operands. /// 5. Compare operands, using cmpValues method. - int cmpGEPs(const GEPOperator *GEPL, const GEPOperator *GEPR); - int cmpGEPs(const GetElementPtrInst *GEPL, const GetElementPtrInst *GEPR) { + int cmpGEPs(const GEPOperator *GEPL, const GEPOperator *GEPR) const; + int cmpGEPs(const GetElementPtrInst *GEPL, + const GetElementPtrInst *GEPR) const { return cmpGEPs(cast<GEPOperator>(GEPL), cast<GEPOperator>(GEPR)); } @@ -401,12 +401,13 @@ private: int cmpTypes(Type *TyL, Type *TyR) const; int cmpNumbers(uint64_t L, uint64_t R) const; + int cmpOrderings(AtomicOrdering L, AtomicOrdering R) const; int cmpAPInts(const APInt &L, const APInt &R) const; int cmpAPFloats(const APFloat &L, const APFloat &R) const; int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const; int cmpMem(StringRef L, StringRef R) const; int cmpAttrs(const AttributeSet L, const AttributeSet R) const; - int cmpRangeMetadata(const MDNode* L, const MDNode* R) const; + int cmpRangeMetadata(const MDNode *L, const MDNode *R) const; int cmpOperandBundlesSchema(const Instruction *L, const Instruction *R) const; // The two functions undergoing comparison. @@ -445,7 +446,7 @@ private: /// But, we are still not able to compare operands of PHI nodes, since those /// could be operands from further BBs we didn't scan yet. /// So it's impossible to use dominance properties in general. 
- DenseMap<const Value*, int> sn_mapL, sn_mapR; + mutable DenseMap<const Value*, int> sn_mapL, sn_mapR; // The global state we will use GlobalNumberState* GlobalNumbers; @@ -477,6 +478,12 @@ int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const { return 0; } +int FunctionComparator::cmpOrderings(AtomicOrdering L, AtomicOrdering R) const { + if ((int)L < (int)R) return -1; + if ((int)L > (int)R) return 1; + return 0; +} + int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const { if (int Res = cmpNumbers(L.getBitWidth(), R.getBitWidth())) return Res; @@ -538,8 +545,8 @@ int FunctionComparator::cmpAttrs(const AttributeSet L, return 0; } -int FunctionComparator::cmpRangeMetadata(const MDNode* L, - const MDNode* R) const { +int FunctionComparator::cmpRangeMetadata(const MDNode *L, + const MDNode *R) const { if (L == R) return 0; if (!L) @@ -547,7 +554,7 @@ int FunctionComparator::cmpRangeMetadata(const MDNode* L, if (!R) return 1; // Range metadata is a sequence of numbers. Make sure they are the same - // sequence. + // sequence. // TODO: Note that as this is metadata, it is possible to drop and/or merge // this data when considering functions to merge. Thus this comparison would // return 0 (i.e. equivalent), but merging would become more complicated @@ -557,8 +564,8 @@ int FunctionComparator::cmpRangeMetadata(const MDNode* L, if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) return Res; for (size_t I = 0; I < L->getNumOperands(); ++I) { - ConstantInt* LLow = mdconst::extract<ConstantInt>(L->getOperand(I)); - ConstantInt* RLow = mdconst::extract<ConstantInt>(R->getOperand(I)); + ConstantInt *LLow = mdconst::extract<ConstantInt>(L->getOperand(I)); + ConstantInt *RLow = mdconst::extract<ConstantInt>(R->getOperand(I)); if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue())) return Res; } @@ -596,7 +603,8 @@ int FunctionComparator::cmpOperandBundlesSchema(const Instruction *L, /// type. /// 2. Compare constant contents. /// For more details see declaration comments. -int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) { +int FunctionComparator::cmpConstants(const Constant *L, + const Constant *R) const { Type *TyL = L->getType(); Type *TyR = R->getType(); @@ -793,7 +801,7 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) { } } -int FunctionComparator::cmpGlobalValues(GlobalValue *L, GlobalValue* R) { +int FunctionComparator::cmpGlobalValues(GlobalValue *L, GlobalValue *R) const { return cmpNumbers(GlobalNumbers->getNumber(L), GlobalNumbers->getNumber(R)); } @@ -898,9 +906,9 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { int FunctionComparator::cmpOperations(const Instruction *L, const Instruction *R) const { // Differences from Instruction::isSameOperationAs: - // * replace type comparison with calls to isEquivalentType. - // * we test for I->hasSameSubclassOptionalData (nuw/nsw/tail) at the top - // * because of the above, we don't test for the tail bit on calls later on + // * replace type comparison with calls to cmpTypes. + // * we test for I->getRawSubclassOptionalData (nuw/nsw/tail) at the top. + // * because of the above, we don't test for the tail bit on calls later on. 
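On the constification and cmpOrderings changes above: every cmp* helper in FunctionComparator returns a negative/zero/positive value, and callers chain them, bailing out at the first stage that differs. cmpOrderings exists because AtomicOrdering is a scoped enum in this release and needs an explicit cast before it can be compared numerically. A minimal sketch of the idiom with a stand-in enum (the Ordering type and cmpLoadLike below are hypothetical, not LLVM's):

#include <cstdint>

// Three-way helpers in the FunctionComparator style.
static int cmpNumbers(uint64_t L, uint64_t R) {
  if (L < R) return -1;
  if (L > R) return 1;
  return 0;
}

enum class Ordering { NotAtomic, Monotonic, Acquire, Release, SeqCst };

static int cmpOrderings(Ordering L, Ordering R) {
  // A scoped enum needs an explicit cast before numeric comparison.
  return cmpNumbers(static_cast<uint64_t>(L), static_cast<uint64_t>(R));
}

// Lexicographic composition: stop at the first stage that differs.
static int cmpLoadLike(bool VolL, bool VolR, unsigned AlignL, unsigned AlignR,
                       Ordering OrdL, Ordering OrdR) {
  if (int Res = cmpNumbers(VolL, VolR)) return Res;
  if (int Res = cmpNumbers(AlignL, AlignR)) return Res;
  return cmpOrderings(OrdL, OrdR);
}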
if (int Res = cmpNumbers(L->getOpcode(), R->getOpcode())) return Res; @@ -914,15 +922,6 @@ int FunctionComparator::cmpOperations(const Instruction *L, R->getRawSubclassOptionalData())) return Res; - if (const AllocaInst *AI = dyn_cast<AllocaInst>(L)) { - if (int Res = cmpTypes(AI->getAllocatedType(), - cast<AllocaInst>(R)->getAllocatedType())) - return Res; - if (int Res = - cmpNumbers(AI->getAlignment(), cast<AllocaInst>(R)->getAlignment())) - return Res; - } - // We have two instructions of identical opcode and #operands. Check to see // if all operands are the same type for (unsigned i = 0, e = L->getNumOperands(); i != e; ++i) { @@ -932,6 +931,12 @@ int FunctionComparator::cmpOperations(const Instruction *L, } // Check special state that is a part of some instructions. + if (const AllocaInst *AI = dyn_cast<AllocaInst>(L)) { + if (int Res = cmpTypes(AI->getAllocatedType(), + cast<AllocaInst>(R)->getAllocatedType())) + return Res; + return cmpNumbers(AI->getAlignment(), cast<AllocaInst>(R)->getAlignment()); + } if (const LoadInst *LI = dyn_cast<LoadInst>(L)) { if (int Res = cmpNumbers(LI->isVolatile(), cast<LoadInst>(R)->isVolatile())) return Res; @@ -939,7 +944,7 @@ int FunctionComparator::cmpOperations(const Instruction *L, cmpNumbers(LI->getAlignment(), cast<LoadInst>(R)->getAlignment())) return Res; if (int Res = - cmpNumbers(LI->getOrdering(), cast<LoadInst>(R)->getOrdering())) + cmpOrderings(LI->getOrdering(), cast<LoadInst>(R)->getOrdering())) return Res; if (int Res = cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope())) @@ -955,7 +960,7 @@ int FunctionComparator::cmpOperations(const Instruction *L, cmpNumbers(SI->getAlignment(), cast<StoreInst>(R)->getAlignment())) return Res; if (int Res = - cmpNumbers(SI->getOrdering(), cast<StoreInst>(R)->getOrdering())) + cmpOrderings(SI->getOrdering(), cast<StoreInst>(R)->getOrdering())) return Res; return cmpNumbers(SI->getSynchScope(), cast<StoreInst>(R)->getSynchScope()); } @@ -996,6 +1001,7 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpNumbers(LIndices[i], RIndices[i])) return Res; } + return 0; } if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(L)) { ArrayRef<unsigned> LIndices = EVI->getIndices(); @@ -1009,11 +1015,10 @@ int FunctionComparator::cmpOperations(const Instruction *L, } if (const FenceInst *FI = dyn_cast<FenceInst>(L)) { if (int Res = - cmpNumbers(FI->getOrdering(), cast<FenceInst>(R)->getOrdering())) + cmpOrderings(FI->getOrdering(), cast<FenceInst>(R)->getOrdering())) return Res; return cmpNumbers(FI->getSynchScope(), cast<FenceInst>(R)->getSynchScope()); } - if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(L)) { if (int Res = cmpNumbers(CXI->isVolatile(), cast<AtomicCmpXchgInst>(R)->isVolatile())) @@ -1021,11 +1026,13 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpNumbers(CXI->isWeak(), cast<AtomicCmpXchgInst>(R)->isWeak())) return Res; - if (int Res = cmpNumbers(CXI->getSuccessOrdering(), - cast<AtomicCmpXchgInst>(R)->getSuccessOrdering())) + if (int Res = + cmpOrderings(CXI->getSuccessOrdering(), + cast<AtomicCmpXchgInst>(R)->getSuccessOrdering())) return Res; - if (int Res = cmpNumbers(CXI->getFailureOrdering(), - cast<AtomicCmpXchgInst>(R)->getFailureOrdering())) + if (int Res = + cmpOrderings(CXI->getFailureOrdering(), + cast<AtomicCmpXchgInst>(R)->getFailureOrdering())) return Res; return cmpNumbers(CXI->getSynchScope(), cast<AtomicCmpXchgInst>(R)->getSynchScope()); @@ -1037,19 +1044,30 @@ int 
FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpNumbers(RMWI->isVolatile(), cast<AtomicRMWInst>(R)->isVolatile())) return Res; - if (int Res = cmpNumbers(RMWI->getOrdering(), + if (int Res = cmpOrderings(RMWI->getOrdering(), cast<AtomicRMWInst>(R)->getOrdering())) return Res; return cmpNumbers(RMWI->getSynchScope(), cast<AtomicRMWInst>(R)->getSynchScope()); } + if (const PHINode *PNL = dyn_cast<PHINode>(L)) { + const PHINode *PNR = cast<PHINode>(R); + // Ensure that in addition to the incoming values being identical + // (checked by the caller of this function), the incoming blocks + // are also identical. + for (unsigned i = 0, e = PNL->getNumIncomingValues(); i != e; ++i) { + if (int Res = + cmpValues(PNL->getIncomingBlock(i), PNR->getIncomingBlock(i))) + return Res; + } + } return 0; } // Determine whether two GEP operations perform the same underlying arithmetic. // Read method declaration comments for more details. int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, - const GEPOperator *GEPR) { + const GEPOperator *GEPR) const { unsigned int ASL = GEPL->getPointerAddressSpace(); unsigned int ASR = GEPR->getPointerAddressSpace(); @@ -1106,7 +1124,7 @@ int FunctionComparator::cmpInlineAsm(const InlineAsm *L, /// this is the first time the values are seen, they're added to the mapping so /// that we will detect mismatches on next use. /// See comments in declaration for more details. -int FunctionComparator::cmpValues(const Value *L, const Value *R) { +int FunctionComparator::cmpValues(const Value *L, const Value *R) const { // Catch self-reference case. if (L == FnL) { if (R == FnR) @@ -1149,7 +1167,7 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) { } // Test whether two basic blocks have equivalent behaviour. int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL, - const BasicBlock *BBR) { + const BasicBlock *BBR) const { BasicBlock::const_iterator InstL = BBL->begin(), InstLE = BBL->end(); BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end(); @@ -1186,7 +1204,8 @@ int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL, } } - ++InstL, ++InstR; + ++InstL; + ++InstR; } while (InstL != InstLE && InstR != InstRE); if (InstL != InstLE && InstR == InstRE) @@ -1249,7 +1268,7 @@ int FunctionComparator::compare() { // functions, then takes each block from each terminator in order. As an // artifact, this also means that unreachable blocks are ignored. SmallVector<const BasicBlock *, 8> FnLBBs, FnRBBs; - SmallSet<const BasicBlock *, 128> VisitedBBs; // in terms of F1. + SmallPtrSet<const BasicBlock *, 32> VisitedBBs; // in terms of F1. FnLBBs.push_back(&FnL->getEntryBlock()); FnRBBs.push_back(&FnR->getEntryBlock()); @@ -1517,6 +1536,9 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { } bool MergeFunctions::runOnModule(Module &M) { + if (skipModule(M)) + return false; + bool Changed = false; // All functions in the module, ordered by hash. Functions with a unique @@ -1555,28 +1577,12 @@ bool MergeFunctions::runOnModule(Module &M) { DEBUG(dbgs() << "size of module: " << M.size() << '\n'); DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n'); - // Insert only strong functions and merge them. Strong function merging - // always deletes one of them. 
- for (std::vector<WeakVH>::iterator I = Worklist.begin(), - E = Worklist.end(); I != E; ++I) { - if (!*I) continue; - Function *F = cast<Function>(*I); - if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() && - !F->mayBeOverridden()) { - Changed |= insert(F); - } - } - - // Insert only weak functions and merge them. By doing these second we - // create thunks to the strong function when possible. When two weak - // functions are identical, we create a new strong function with two weak - // weak thunks to it which are identical but not mergable. - for (std::vector<WeakVH>::iterator I = Worklist.begin(), - E = Worklist.end(); I != E; ++I) { - if (!*I) continue; - Function *F = cast<Function>(*I); - if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() && - F->mayBeOverridden()) { + // Insert functions and merge them. + for (WeakVH &I : Worklist) { + if (!I) + continue; + Function *F = cast<Function>(I); + if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) { Changed |= insert(F); } } @@ -1631,7 +1637,7 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { // Replace G with an alias to F if possible, or else a thunk to F. Deletes G. void MergeFunctions::writeThunkOrAlias(Function *F, Function *G) { - if (HasGlobalAliases && G->hasUnnamedAddr()) { + if (HasGlobalAliases && G->hasGlobalUnnamedAddr()) { if (G->hasExternalLinkage() || G->hasLocalLinkage() || G->hasWeakLinkage()) { writeAlias(F, G); @@ -1645,7 +1651,7 @@ void MergeFunctions::writeThunkOrAlias(Function *F, Function *G) { // Helper for writeThunk, // Selects proper bitcast operation, // but a bit simpler then CastInst::getCastOpcode. -static Value *createCast(IRBuilder<false> &Builder, Value *V, Type *DestTy) { +static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { Type *SrcTy = V->getType(); if (SrcTy->isStructTy()) { assert(DestTy->isStructTy()); @@ -1673,7 +1679,7 @@ static Value *createCast(IRBuilder<false> &Builder, Value *V, Type *DestTy) { // Replace G with a simple tail call to bitcast(F). Also replace direct uses // of G with bitcast(F). Deletes G. void MergeFunctions::writeThunk(Function *F, Function *G) { - if (!G->mayBeOverridden()) { + if (!G->isInterposable()) { // Redirect direct callers of G to F. replaceDirectCallers(G, F); } @@ -1688,7 +1694,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { Function *NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "", G->getParent()); BasicBlock *BB = BasicBlock::Create(F->getContext(), "", NewG); - IRBuilder<false> Builder(BB); + IRBuilder<> Builder(BB); SmallVector<Value *, 16> Args; unsigned i = 0; @@ -1734,8 +1740,8 @@ void MergeFunctions::writeAlias(Function *F, Function *G) { // Merge two equivalent functions. Upon completion, Function G is deleted. void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { - if (F->mayBeOverridden()) { - assert(G->mayBeOverridden()); + if (F->isInterposable()) { + assert(G->isInterposable()); // Make them both thunks to the same internal function. Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "", @@ -1816,20 +1822,16 @@ bool MergeFunctions::insert(Function *NewFunction) { // important when operating on more than one module independently to prevent // cycles of thunks calling each other when the modules are linked together. // - // When one function is weak and the other is strong there is an order imposed - // already. We process strong functions before weak functions. 
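The hunk that follows replaces the old strong-then-weak double pass over the worklist with a single loop, and folds the strong/weak preference and the name tie-break into one swap condition inside insert(). Condensed, using the same isInterposable and name ordering as the patch but written as a free function for readability:

#include <string>

// Should the newly inserted function become the canonical copy? Prefer a
// non-interposable (strong) definition over an interposable (weak) one, and
// break ties by name so the choice does not depend on insertion order.
static bool shouldSwapCanonical(bool OldIsInterposable,
                                const std::string &OldName,
                                bool NewIsInterposable,
                                const std::string &NewName) {
  if (OldIsInterposable && !NewIsInterposable)
    return true;
  if (OldIsInterposable == NewIsInterposable && OldName > NewName)
    return true;
  return false;
}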
- if ((OldF.getFunc()->mayBeOverridden() && NewFunction->mayBeOverridden()) || - (!OldF.getFunc()->mayBeOverridden() && !NewFunction->mayBeOverridden())) - if (OldF.getFunc()->getName() > NewFunction->getName()) { - // Swap the two functions. - Function *F = OldF.getFunc(); - replaceFunctionInTree(*Result.first, NewFunction); - NewFunction = F; - assert(OldF.getFunc() != F && "Must have swapped the functions."); - } - - // Never thunk a strong function to a weak function. - assert(!OldF.getFunc()->mayBeOverridden() || NewFunction->mayBeOverridden()); + // First of all, we process strong functions before weak functions. + if ((OldF.getFunc()->isInterposable() && !NewFunction->isInterposable()) || + (OldF.getFunc()->isInterposable() == NewFunction->isInterposable() && + OldF.getFunc()->getName() > NewFunction->getName())) { + // Swap the two functions. + Function *F = OldF.getFunc(); + replaceFunctionInTree(*Result.first, NewFunction); + NewFunction = F; + assert(OldF.getFunc() != F && "Must have swapped the functions."); + } DEBUG(dbgs() << " " << OldF.getFunc()->getName() << " == " << NewFunction->getName() << '\n'); diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp index 0c5c84bbccab..49c44173491e 100644 --- a/lib/Transforms/IPO/PartialInlining.cpp +++ b/lib/Transforms/IPO/PartialInlining.cpp @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/PartialInlining.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/CodeExtractor.h" using namespace llvm; @@ -28,27 +29,34 @@ using namespace llvm; STATISTIC(NumPartialInlined, "Number of functions partially inlined"); namespace { - struct PartialInliner : public ModulePass { - void getAnalysisUsage(AnalysisUsage &AU) const override { } - static char ID; // Pass identification, replacement for typeid - PartialInliner() : ModulePass(ID) { - initializePartialInlinerPass(*PassRegistry::getPassRegistry()); - } +struct PartialInlinerLegacyPass : public ModulePass { + static char ID; // Pass identification, replacement for typeid + PartialInlinerLegacyPass() : ModulePass(ID) { + initializePartialInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); + } - bool runOnModule(Module& M) override; + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + ModuleAnalysisManager DummyMAM; + auto PA = Impl.run(M, DummyMAM); + return !PA.areAllPreserved(); + } - private: - Function* unswitchFunction(Function* F); +private: + PartialInlinerPass Impl; }; } -char PartialInliner::ID = 0; -INITIALIZE_PASS(PartialInliner, "partial-inliner", - "Partial Inliner", false, false) +char PartialInlinerLegacyPass::ID = 0; +INITIALIZE_PASS(PartialInlinerLegacyPass, "partial-inliner", "Partial Inliner", + false, false) -ModulePass* llvm::createPartialInliningPass() { return new PartialInliner(); } +ModulePass *llvm::createPartialInliningPass() { + return new PartialInlinerLegacyPass(); +} -Function* PartialInliner::unswitchFunction(Function* F) { +Function *PartialInlinerPass::unswitchFunction(Function *F) { // First, verify that this function is an unswitching candidate... 
BasicBlock *entryBlock = &F->front(); BranchInst *BR = dyn_cast<BranchInst>(entryBlock->getTerminator()); @@ -71,10 +79,8 @@ Function* PartialInliner::unswitchFunction(Function* F) { // Clone the function, so that we can hack away on it. ValueToValueMapTy VMap; - Function* duplicateFunction = CloneFunction(F, VMap, - /*ModuleLevelChanges=*/false); + Function* duplicateFunction = CloneFunction(F, VMap); duplicateFunction->setLinkage(GlobalValue::InternalLinkage); - F->getParent()->getFunctionList().push_back(duplicateFunction); BasicBlock* newEntryBlock = cast<BasicBlock>(VMap[entryBlock]); BasicBlock* newReturnBlock = cast<BasicBlock>(VMap[returnBlock]); BasicBlock* newNonReturnBlock = cast<BasicBlock>(VMap[nonReturnBlock]); @@ -112,11 +118,10 @@ Function* PartialInliner::unswitchFunction(Function* F) { // Gather up the blocks that we're going to extract. std::vector<BasicBlock*> toExtract; toExtract.push_back(newNonReturnBlock); - for (Function::iterator FI = duplicateFunction->begin(), - FE = duplicateFunction->end(); FI != FE; ++FI) - if (&*FI != newEntryBlock && &*FI != newReturnBlock && - &*FI != newNonReturnBlock) - toExtract.push_back(&*FI); + for (BasicBlock &BB : *duplicateFunction) + if (&BB != newEntryBlock && &BB != newReturnBlock && + &BB != newNonReturnBlock) + toExtract.push_back(&BB); // The CodeExtractor needs a dominator tree. DominatorTree DT; @@ -131,11 +136,10 @@ Function* PartialInliner::unswitchFunction(Function* F) { // Inline the top-level if test into all callers. std::vector<User *> Users(duplicateFunction->user_begin(), duplicateFunction->user_end()); - for (std::vector<User*>::iterator UI = Users.begin(), UE = Users.end(); - UI != UE; ++UI) - if (CallInst *CI = dyn_cast<CallInst>(*UI)) + for (User *User : Users) + if (CallInst *CI = dyn_cast<CallInst>(User)) InlineFunction(CI, IFI); - else if (InvokeInst *II = dyn_cast<InvokeInst>(*UI)) + else if (InvokeInst *II = dyn_cast<InvokeInst>(User)) InlineFunction(II, IFI); // Ditch the duplicate, since we're done with it, and rewrite all remaining @@ -148,13 +152,13 @@ Function* PartialInliner::unswitchFunction(Function* F) { return extractedFunction; } -bool PartialInliner::runOnModule(Module& M) { +PreservedAnalyses PartialInlinerPass::run(Module &M, ModuleAnalysisManager &) { std::vector<Function*> worklist; worklist.reserve(M.size()); - for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) - if (!FI->use_empty() && !FI->isDeclaration()) - worklist.push_back(&*FI); - + for (Function &F : M) + if (!F.use_empty() && !F.isDeclaration()) + worklist.push_back(&F); + bool changed = false; while (!worklist.empty()) { Function* currFunc = worklist.back(); @@ -178,6 +182,8 @@ bool PartialInliner::runOnModule(Module& M) { } } - - return changed; + + if (changed) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); } diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index faada9c2a7db..cf5b76dc365b 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -16,23 +16,27 @@ #include "llvm-c/Transforms/PassManagerBuilder.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/CFLAliasAnalysis.h" +#include "llvm/Analysis/CFLAndersAliasAnalysis.h" +#include "llvm/Analysis/CFLSteensAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/ScopedNoAliasAA.h" #include 
"llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TypeBasedAliasAnalysis.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/FunctionInfo.h" #include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/ModuleSummaryIndex.h" #include "llvm/IR/Verifier.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" +#include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/IPO/InferFunctionAttrs.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Vectorize.h" using namespace llvm; @@ -58,10 +62,6 @@ static cl::opt<bool> ExtraVectorizerPasses( "extra-vectorizer-passes", cl::init(false), cl::Hidden, cl::desc("Run cleanup optimization passes after vectorization.")); -static cl::opt<bool> UseNewSROA("use-new-sroa", - cl::init(true), cl::Hidden, - cl::desc("Enable the new, experimental SROA pass")); - static cl::opt<bool> RunLoopRerolling("reroll-loops", cl::Hidden, cl::desc("Run the loop rerolling pass")); @@ -80,9 +80,19 @@ RunSLPAfterLoopVectorization("run-slp-after-loop-vectorization", cl::desc("Run the SLP vectorizer (and BB vectorizer) after the Loop " "vectorizer instead of before")); -static cl::opt<bool> UseCFLAA("use-cfl-aa", - cl::init(false), cl::Hidden, - cl::desc("Enable the new, experimental CFL alias analysis")); +// Experimental option to use CFL-AA +enum class CFLAAType { None, Steensgaard, Andersen, Both }; +static cl::opt<CFLAAType> + UseCFLAA("use-cfl-aa", cl::init(CFLAAType::None), cl::Hidden, + cl::desc("Enable the new, experimental CFL alias analysis"), + cl::values(clEnumValN(CFLAAType::None, "none", "Disable CFL-AA"), + clEnumValN(CFLAAType::Steensgaard, "steens", + "Enable unification-based CFL-AA"), + clEnumValN(CFLAAType::Andersen, "anders", + "Enable inclusion-based CFL-AA"), + clEnumValN(CFLAAType::Both, "both", + "Enable both variants of CFL-aa"), + clEnumValEnd)); static cl::opt<bool> EnableMLSM("mlsm", cl::init(true), cl::Hidden, @@ -92,25 +102,44 @@ static cl::opt<bool> EnableLoopInterchange( "enable-loopinterchange", cl::init(false), cl::Hidden, cl::desc("Enable the new, experimental LoopInterchange Pass")); -static cl::opt<bool> EnableLoopDistribute( - "enable-loop-distribute", cl::init(false), cl::Hidden, - cl::desc("Enable the new, experimental LoopDistribution Pass")); - static cl::opt<bool> EnableNonLTOGlobalsModRef( "enable-non-lto-gmr", cl::init(true), cl::Hidden, cl::desc( "Enable the GlobalsModRef AliasAnalysis outside of the LTO pipeline.")); static cl::opt<bool> EnableLoopLoadElim( - "enable-loop-load-elim", cl::init(false), cl::Hidden, - cl::desc("Enable the new, experimental LoopLoadElimination Pass")); + "enable-loop-load-elim", cl::init(true), cl::Hidden, + cl::desc("Enable the LoopLoadElimination Pass")); + +static cl::opt<std::string> RunPGOInstrGen( + "profile-generate", cl::init(""), cl::Hidden, + cl::desc("Enable generation phase of PGO instrumentation and specify the " + "path of profile data file")); + +static cl::opt<std::string> RunPGOInstrUse( + "profile-use", cl::init(""), cl::Hidden, cl::value_desc("filename"), + cl::desc("Enable use phase of PGO instrumentation and specify the path " + "of profile data file")); + +static cl::opt<bool> UseLoopVersioningLICM( + "enable-loop-versioning-licm", cl::init(false), cl::Hidden, + cl::desc("Enable the experimental Loop Versioning LICM 
pass")); + +static cl::opt<bool> + DisablePreInliner("disable-preinline", cl::init(false), cl::Hidden, + cl::desc("Disable pre-instrumentation inliner")); + +static cl::opt<int> PreInlineThreshold( + "preinline-threshold", cl::Hidden, cl::init(75), cl::ZeroOrMore, + cl::desc("Control the amount of inlining in pre-instrumentation inliner " + "(default = 75)")); PassManagerBuilder::PassManagerBuilder() { OptLevel = 2; SizeLevel = 0; LibraryInfo = nullptr; Inliner = nullptr; - FunctionIndex = nullptr; + ModuleSummary = nullptr; DisableUnitAtATime = false; DisableUnrollLoops = false; BBVectorize = RunBBVectorization; @@ -123,6 +152,10 @@ PassManagerBuilder::PassManagerBuilder() { VerifyOutput = false; MergeFunctions = false; PrepareForLTO = false; + PGOInstrGen = RunPGOInstrGen; + PGOInstrUse = RunPGOInstrUse; + PrepareForThinLTO = false; + PerformThinLTO = false; } PassManagerBuilder::~PassManagerBuilder() { @@ -137,11 +170,11 @@ static ManagedStatic<SmallVector<std::pair<PassManagerBuilder::ExtensionPointTy, void PassManagerBuilder::addGlobalExtension( PassManagerBuilder::ExtensionPointTy Ty, PassManagerBuilder::ExtensionFn Fn) { - GlobalExtensions->push_back(std::make_pair(Ty, Fn)); + GlobalExtensions->push_back(std::make_pair(Ty, std::move(Fn))); } void PassManagerBuilder::addExtension(ExtensionPointTy Ty, ExtensionFn Fn) { - Extensions.push_back(std::make_pair(Ty, Fn)); + Extensions.push_back(std::make_pair(Ty, std::move(Fn))); } void PassManagerBuilder::addExtensionsToPM(ExtensionPointTy ETy, @@ -156,15 +189,34 @@ void PassManagerBuilder::addExtensionsToPM(ExtensionPointTy ETy, void PassManagerBuilder::addInitialAliasAnalysisPasses( legacy::PassManagerBase &PM) const { + switch (UseCFLAA) { + case CFLAAType::Steensgaard: + PM.add(createCFLSteensAAWrapperPass()); + break; + case CFLAAType::Andersen: + PM.add(createCFLAndersAAWrapperPass()); + break; + case CFLAAType::Both: + PM.add(createCFLSteensAAWrapperPass()); + PM.add(createCFLAndersAAWrapperPass()); + break; + default: + break; + } + // Add TypeBasedAliasAnalysis before BasicAliasAnalysis so that // BasicAliasAnalysis wins if they disagree. This is intended to help // support "obvious" type-punning idioms. - if (UseCFLAA) - PM.add(createCFLAAWrapperPass()); PM.add(createTypeBasedAAWrapperPass()); PM.add(createScopedNoAliasAAWrapperPass()); } +void PassManagerBuilder::addInstructionCombiningPass( + legacy::PassManagerBase &PM) const { + bool ExpensiveCombines = OptLevel > 2; + PM.add(createInstructionCombiningPass(ExpensiveCombines)); +} + void PassManagerBuilder::populateFunctionPassManager( legacy::FunctionPassManager &FPM) { addExtensionsToPM(EP_EarlyAsPossible, FPM); @@ -178,94 +230,50 @@ void PassManagerBuilder::populateFunctionPassManager( addInitialAliasAnalysisPasses(FPM); FPM.add(createCFGSimplificationPass()); - if (UseNewSROA) - FPM.add(createSROAPass()); - else - FPM.add(createScalarReplAggregatesPass()); + FPM.add(createSROAPass()); FPM.add(createEarlyCSEPass()); + FPM.add(createGVNHoistPass()); FPM.add(createLowerExpectIntrinsicPass()); } -void PassManagerBuilder::populateModulePassManager( - legacy::PassManagerBase &MPM) { - // Allow forcing function attributes as a debugging and tuning aid. - MPM.add(createForceFunctionAttrsLegacyPass()); - - // If all optimizations are disabled, just run the always-inline pass and, - // if enabled, the function merging pass. - if (OptLevel == 0) { - if (Inliner) { - MPM.add(Inliner); - Inliner = nullptr; - } - - // FIXME: The BarrierNoopPass is a HACK! 
The inliner pass above implicitly - // creates a CGSCC pass manager, but we don't want to add extensions into - // that pass manager. To prevent this we insert a no-op module pass to reset - // the pass manager to get the same behavior as EP_OptimizerLast in non-O0 - // builds. The function merging pass is - if (MergeFunctions) - MPM.add(createMergeFunctionsPass()); - else if (!GlobalExtensions->empty() || !Extensions.empty()) - MPM.add(createBarrierNoopPass()); - - addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); +// Do PGO instrumentation generation or use pass as the option specified. +void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM) { + if (PGOInstrGen.empty() && PGOInstrUse.empty()) return; - } - - // Add LibraryInfo if we have some. - if (LibraryInfo) - MPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); - - addInitialAliasAnalysisPasses(MPM); - - if (!DisableUnitAtATime) { - // Infer attributes about declarations if possible. - MPM.add(createInferFunctionAttrsLegacyPass()); - - addExtensionsToPM(EP_ModuleOptimizerEarly, MPM); - - MPM.add(createIPSCCPPass()); // IP SCCP - MPM.add(createGlobalOptimizerPass()); // Optimize out global vars - // Promote any localized global vars - MPM.add(createPromoteMemoryToRegisterPass()); - - MPM.add(createDeadArgEliminationPass()); // Dead argument elimination - - MPM.add(createInstructionCombiningPass());// Clean up after IPCP & DAE + // Perform the preinline and cleanup passes for O1 and above. + // And avoid doing them if optimizing for size. + if (OptLevel > 0 && SizeLevel == 0 && !DisablePreInliner) { + // Create preinline pass. + MPM.add(createFunctionInliningPass(PreInlineThreshold)); + MPM.add(createSROAPass()); + MPM.add(createEarlyCSEPass()); // Catch trivial redundancies + MPM.add(createCFGSimplificationPass()); // Merge & remove BBs + MPM.add(createInstructionCombiningPass()); // Combine silly seq's addExtensionsToPM(EP_Peephole, MPM); - MPM.add(createCFGSimplificationPass()); // Clean up after IPCP & DAE } - - if (EnableNonLTOGlobalsModRef) - // We add a module alias analysis pass here. In part due to bugs in the - // analysis infrastructure this "works" in that the analysis stays alive - // for the entire SCC pass run below. - MPM.add(createGlobalsAAWrapperPass()); - - // Start of CallGraph SCC passes. - if (!DisableUnitAtATime) - MPM.add(createPruneEHPass()); // Remove dead EH info - if (Inliner) { - MPM.add(Inliner); - Inliner = nullptr; + if (!PGOInstrGen.empty()) { + MPM.add(createPGOInstrumentationGenLegacyPass()); + // Add the profile lowering pass. + InstrProfOptions Options; + Options.InstrProfileOutput = PGOInstrGen; + MPM.add(createInstrProfilingLegacyPass(Options)); } - if (!DisableUnitAtATime) - MPM.add(createPostOrderFunctionAttrsPass()); - if (OptLevel > 2) - MPM.add(createArgumentPromotionPass()); // Scalarize uninlined fn args - + if (!PGOInstrUse.empty()) + MPM.add(createPGOInstrumentationUseLegacyPass(PGOInstrUse)); +} +void PassManagerBuilder::addFunctionSimplificationPasses( + legacy::PassManagerBase &MPM) { // Start of function pass. // Break up aggregate allocas, using SSAUpdater. - if (UseNewSROA) - MPM.add(createSROAPass()); - else - MPM.add(createScalarReplAggregatesPass(-1, false)); + MPM.add(createSROAPass()); MPM.add(createEarlyCSEPass()); // Catch trivial redundancies + // Speculative execution if the target has divergent branches; otherwise nop. + MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass()); MPM.add(createJumpThreadingPass()); // Thread jumps. 
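As an aside on how these builder methods are consumed: a driver constructs a PassManagerBuilder, sets the knobs added in this revision (PGOInstrGen, PGOInstrUse, and so on), and then populates legacy pass managers. A minimal sketch, assuming the usual LLVM 3.9 headers; the inliner threshold and the profile file name are illustrative values only:

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
using namespace llvm;

static void buildPipelines(legacy::PassManager &MPM,
                           legacy::FunctionPassManager &FPM) {
  PassManagerBuilder PMB;
  PMB.OptLevel = 2;                                // -O2-style pipeline
  PMB.SizeLevel = 0;
  PMB.Inliner = createFunctionInliningPass(225);   // illustrative threshold
  PMB.PGOInstrGen = "default.profraw";             // non-empty: addPGOInstrPasses kicks in
  PMB.populateFunctionPassManager(FPM);            // per-function early passes
  PMB.populateModulePassManager(MPM);              // the module pipeline built here
}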
MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals MPM.add(createCFGSimplificationPass()); // Merge & remove BBs - MPM.add(createInstructionCombiningPass()); // Combine silly seq's + // Combine silly seq's + addInstructionCombiningPass(MPM); addExtensionsToPM(EP_Peephole, MPM); MPM.add(createTailCallEliminationPass()); // Eliminate tail calls @@ -276,7 +284,7 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createLICMPass()); // Hoist loop invariants MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3)); MPM.add(createCFGSimplificationPass()); - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars MPM.add(createLoopIdiomPass()); // Recognize idioms like memset. MPM.add(createLoopDeletionPass()); // Delete dead loops @@ -303,7 +311,7 @@ void PassManagerBuilder::populateModulePassManager( // Run instcombine after redundancy elimination to exploit opportunities // opened up by them. - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); addExtensionsToPM(EP_Peephole, MPM); MPM.add(createJumpThreadingPass()); // Thread jumps MPM.add(createCorrelatedValuePropagationPass()); @@ -320,7 +328,7 @@ void PassManagerBuilder::populateModulePassManager( if (BBVectorize) { MPM.add(createBBVectorizePass()); - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); addExtensionsToPM(EP_Peephole, MPM); if (OptLevel > 1 && UseGVNAfterVectorization) MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies @@ -338,18 +346,99 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createAggressiveDCEPass()); // Delete dead instructions MPM.add(createCFGSimplificationPass()); // Merge & remove BBs - MPM.add(createInstructionCombiningPass()); // Clean up after everything. + // Clean up after everything. + addInstructionCombiningPass(MPM); addExtensionsToPM(EP_Peephole, MPM); +} + +void PassManagerBuilder::populateModulePassManager( + legacy::PassManagerBase &MPM) { + // Allow forcing function attributes as a debugging and tuning aid. + MPM.add(createForceFunctionAttrsLegacyPass()); + + // If all optimizations are disabled, just run the always-inline pass and, + // if enabled, the function merging pass. + if (OptLevel == 0) { + addPGOInstrPasses(MPM); + if (Inliner) { + MPM.add(Inliner); + Inliner = nullptr; + } + + // FIXME: The BarrierNoopPass is a HACK! The inliner pass above implicitly + // creates a CGSCC pass manager, but we don't want to add extensions into + // that pass manager. To prevent this we insert a no-op module pass to reset + // the pass manager to get the same behavior as EP_OptimizerLast in non-O0 + // builds. The function merging pass is + if (MergeFunctions) + MPM.add(createMergeFunctionsPass()); + else if (!GlobalExtensions->empty() || !Extensions.empty()) + MPM.add(createBarrierNoopPass()); + + addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); + return; + } + + // Add LibraryInfo if we have some. + if (LibraryInfo) + MPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); + + addInitialAliasAnalysisPasses(MPM); + + if (!DisableUnitAtATime) { + // Infer attributes about declarations if possible. + MPM.add(createInferFunctionAttrsLegacyPass()); + + addExtensionsToPM(EP_ModuleOptimizerEarly, MPM); + + MPM.add(createIPSCCPPass()); // IP SCCP + MPM.add(createGlobalOptimizerPass()); // Optimize out global vars + // Promote any localized global vars. 
+ MPM.add(createPromoteMemoryToRegisterPass()); + + MPM.add(createDeadArgEliminationPass()); // Dead argument elimination + + addInstructionCombiningPass(MPM); // Clean up after IPCP & DAE + addExtensionsToPM(EP_Peephole, MPM); + MPM.add(createCFGSimplificationPass()); // Clean up after IPCP & DAE + } + + if (!PerformThinLTO) { + /// PGO instrumentation is added during the compile phase for ThinLTO, do + /// not run it a second time + addPGOInstrPasses(MPM); + } + + // Indirect call promotion that promotes intra-module targets only. + MPM.add(createPGOIndirectCallPromotionLegacyPass()); + + if (EnableNonLTOGlobalsModRef) + // We add a module alias analysis pass here. In part due to bugs in the + // analysis infrastructure this "works" in that the analysis stays alive + // for the entire SCC pass run below. + MPM.add(createGlobalsAAWrapperPass()); + + // Start of CallGraph SCC passes. + if (!DisableUnitAtATime) + MPM.add(createPruneEHPass()); // Remove dead EH info + if (Inliner) { + MPM.add(Inliner); + Inliner = nullptr; + } + if (!DisableUnitAtATime) + MPM.add(createPostOrderFunctionAttrsLegacyPass()); + if (OptLevel > 2) + MPM.add(createArgumentPromotionPass()); // Scalarize uninlined fn args + + addFunctionSimplificationPasses(MPM); // FIXME: This is a HACK! The inliner pass above implicitly creates a CGSCC // pass manager that we are specifically trying to avoid. To prevent this // we must insert a no-op module pass to reset the pass manager. MPM.add(createBarrierNoopPass()); - if (!DisableUnitAtATime) - MPM.add(createReversePostOrderFunctionAttrsPass()); - - if (!DisableUnitAtATime && OptLevel > 1 && !PrepareForLTO) { + if (!DisableUnitAtATime && OptLevel > 1 && !PrepareForLTO && + !PrepareForThinLTO) // Remove avail extern fns and globals definitions if we aren't // compiling an object file for later LTO. For LTO we want to preserve // these so they are eligible for inlining at link-time. Note if they @@ -360,6 +449,34 @@ void PassManagerBuilder::populateModulePassManager( // globals referenced by available external functions dead // and saves running remaining passes on the eliminated functions. MPM.add(createEliminateAvailableExternallyPass()); + + if (!DisableUnitAtATime) + MPM.add(createReversePostOrderFunctionAttrsPass()); + + // If we are planning to perform ThinLTO later, let's not bloat the code with + // unrolling/vectorization/... now. We'll first run the inliner + CGSCC passes + // during ThinLTO and perform the rest of the optimizations afterward. + if (PrepareForThinLTO) { + // Reduce the size of the IR as much as possible. + MPM.add(createGlobalOptimizerPass()); + // Rename anon function to be able to export them in the summary. + MPM.add(createNameAnonFunctionPass()); + return; + } + + if (PerformThinLTO) + // Optimize globals now when performing ThinLTO, this enables more + // optimizations later. + MPM.add(createGlobalOptimizerPass()); + + // Scheduling LoopVersioningLICM when inlining is over, because after that + // we may see more accurate aliasing. Reason to run this late is that too + // early versioning may prevent further inlining due to increase of code + // size. By placing it just after inlining other optimizations which runs + // later might get benefit of no-alias assumption in clone loop. 
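To make the comment above concrete, loop versioning for LICM conceptually emits a runtime alias check plus a no-alias clone of the loop in which invariant memory operations can finally be hoisted. A rough C-level sketch of the transformed shape (purely illustrative, not the pass's actual output):

void scale(int *a, const int *b, int n) {
  // Versioned form: if the runtime check proves the buffers do not overlap,
  // run a clone where *b is loop invariant and can be hoisted; otherwise fall
  // back to the original, conservatively ordered loop.
  if (a + n <= b || b + n <= a) {
    int inv = *b;                  // hoisted load in the no-alias clone
    for (int i = 0; i < n; ++i)
      a[i] = inv * 2;
  } else {
    for (int i = 0; i < n; ++i)
      a[i] = *b * 2;               // original loop, reload every iteration
  }
}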
+ if (UseLoopVersioningLICM) { + MPM.add(createLoopVersioningLICMPass()); // Do LoopVersioningLICM + MPM.add(createLICMPass()); // Hoist loop invariants } if (EnableNonLTOGlobalsModRef) @@ -391,9 +508,10 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1)); // Distribute loops to allow partial vectorization. I.e. isolate dependences - // into separate loop that would otherwise inhibit vectorization. - if (EnableLoopDistribute) - MPM.add(createLoopDistributePass()); + // into separate loop that would otherwise inhibit vectorization. This is + // currently only performed for loops marked with the metadata + // llvm.loop.distribute=true or when -enable-loop-distribute is specified. + MPM.add(createLoopDistributePass(/*ProcessAllLoopsByDefault=*/false)); MPM.add(createLoopVectorizePass(DisableUnrollLoops, LoopVectorize)); @@ -407,7 +525,7 @@ void PassManagerBuilder::populateModulePassManager( // on -O1 and no #pragma is found). Would be good to have these two passes // as function calls, so that we can only pass them when the vectorizer // changed the code. - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); if (OptLevel > 1 && ExtraVectorizerPasses) { // At higher optimization levels, try to clean up any runtime overlap and // alignment checks inserted by the vectorizer. We want to track correllated @@ -417,11 +535,11 @@ void PassManagerBuilder::populateModulePassManager( // dead (or speculatable) control flows or more combining opportunities. MPM.add(createEarlyCSEPass()); MPM.add(createCorrelatedValuePropagationPass()); - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); MPM.add(createLICMPass()); MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3)); MPM.add(createCFGSimplificationPass()); - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); } if (RunSLPAfterLoopVectorization) { @@ -434,7 +552,7 @@ void PassManagerBuilder::populateModulePassManager( if (BBVectorize) { MPM.add(createBBVectorizePass()); - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); addExtensionsToPM(EP_Peephole, MPM); if (OptLevel > 1 && UseGVNAfterVectorization) MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies @@ -449,19 +567,22 @@ void PassManagerBuilder::populateModulePassManager( addExtensionsToPM(EP_Peephole, MPM); MPM.add(createCFGSimplificationPass()); - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); if (!DisableUnrollLoops) { MPM.add(createLoopUnrollPass()); // Unroll small loops // LoopUnroll may generate some redundency to cleanup. - MPM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(MPM); // Runtime unrolling will introduce runtime check in loop prologue. If the // unrolled loop is a inner loop, then the prologue will be inside the // outer loop. LICM pass can help to promote the runtime check out if the // checked value is loop invariant. MPM.add(createLICMPass()); + + // Get rid of LCSSA nodes. + MPM.add(createInstructionSimplifierPass()); } // After vectorization and unrolling, assume intrinsics may tell us more @@ -487,11 +608,15 @@ void PassManagerBuilder::populateModulePassManager( } void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { + // Remove unused virtual tables to improve the quality of code generated by + // whole-program devirtualization and bitset lowering. 
+ PM.add(createGlobalDCEPass()); + // Provide AliasAnalysis services for optimizations. addInitialAliasAnalysisPasses(PM); - if (FunctionIndex) - PM.add(createFunctionImportPass(FunctionIndex)); + if (ModuleSummary) + PM.add(createFunctionImportPass(ModuleSummary)); // Allow forcing function attributes as a debugging and tuning aid. PM.add(createForceFunctionAttrsLegacyPass()); @@ -499,14 +624,32 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // Infer attributes about declarations if possible. PM.add(createInferFunctionAttrsLegacyPass()); - // Propagate constants at call sites into the functions they call. This - // opens opportunities for globalopt (and inlining) by substituting function - // pointers passed as arguments to direct uses of functions. - PM.add(createIPSCCPPass()); + if (OptLevel > 1) { + // Indirect call promotion. This should promote all the targets that are + // left by the earlier promotion pass that promotes intra-module targets. + // This two-step promotion is to save the compile time. For LTO, it should + // produce the same result as if we only do promotion here. + PM.add(createPGOIndirectCallPromotionLegacyPass(true)); + + // Propagate constants at call sites into the functions they call. This + // opens opportunities for globalopt (and inlining) by substituting function + // pointers passed as arguments to direct uses of functions. + PM.add(createIPSCCPPass()); + } - // Now that we internalized some globals, see if we can hack on them! - PM.add(createPostOrderFunctionAttrsPass()); + // Infer attributes about definitions. The readnone attribute in particular is + // required for virtual constant propagation. + PM.add(createPostOrderFunctionAttrsLegacyPass()); PM.add(createReversePostOrderFunctionAttrsPass()); + + // Apply whole-program devirtualization and virtual constant propagation. + PM.add(createWholeProgramDevirtPass()); + + // That's all we need at opt level 1. + if (OptLevel == 1) + return; + + // Now that we internalized some globals, see if we can hack on them! PM.add(createGlobalOptimizerPass()); // Promote any localized global vars. PM.add(createPromoteMemoryToRegisterPass()); @@ -522,7 +665,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // simplification opportunities, and both can propagate functions through // function pointers. When this happens, we often have to resolve varargs // calls, etc, so let instcombine do this. - PM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(PM); addExtensionsToPM(EP_Peephole, PM); // Inline small functions @@ -544,18 +687,15 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createArgumentPromotionPass()); // The IPO passes may leave cruft around. Clean up after them. - PM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(PM); addExtensionsToPM(EP_Peephole, PM); PM.add(createJumpThreadingPass()); // Break up allocas - if (UseNewSROA) - PM.add(createSROAPass()); - else - PM.add(createScalarReplAggregatesPass()); + PM.add(createSROAPass()); // Run a few AA driven optimizations here and now, to cleanup the code. - PM.add(createPostOrderFunctionAttrsPass()); // Add nocapture. + PM.add(createPostOrderFunctionAttrsLegacyPass()); // Add nocapture. PM.add(createGlobalsAAWrapperPass()); // IP alias analysis. PM.add(createLICMPass()); // Hoist loop invariants. 
@@ -573,15 +713,20 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { if (EnableLoopInterchange) PM.add(createLoopInterchangePass()); + if (!DisableUnrollLoops) + PM.add(createSimpleLoopUnrollPass()); // Unroll small loops PM.add(createLoopVectorizePass(true, LoopVectorize)); + // The vectorizer may have significantly shortened a loop body; unroll again. + if (!DisableUnrollLoops) + PM.add(createLoopUnrollPass()); // Now that we've optimized loops (in particular loop induction variables), // we may have exposed more scalar opportunities. Run parts of the scalar // optimizer again at this point. - PM.add(createInstructionCombiningPass()); // Initial cleanup + addInstructionCombiningPass(PM); // Initial cleanup PM.add(createCFGSimplificationPass()); // if-convert PM.add(createSCCPPass()); // Propagate exposed constants - PM.add(createInstructionCombiningPass()); // Clean up again + addInstructionCombiningPass(PM); // Clean up again PM.add(createBitTrackingDCEPass()); // More scalar chains could be vectorized due to more alias information @@ -597,7 +742,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createLoadCombinePass()); // Cleanup and simplify the code after the scalar optimizations. - PM.add(createInstructionCombiningPass()); + addInstructionCombiningPass(PM); addExtensionsToPM(EP_Peephole, PM); PM.add(createJumpThreadingPass()); @@ -620,6 +765,23 @@ void PassManagerBuilder::addLateLTOOptimizationPasses( PM.add(createMergeFunctionsPass()); } +void PassManagerBuilder::populateThinLTOPassManager( + legacy::PassManagerBase &PM) { + PerformThinLTO = true; + + if (VerifyInput) + PM.add(createVerifierPass()); + + if (ModuleSummary) + PM.add(createFunctionImportPass(ModuleSummary)); + + populateModulePassManager(PM); + + if (VerifyOutput) + PM.add(createVerifierPass()); + PerformThinLTO = false; +} + void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { if (LibraryInfo) PM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); @@ -627,17 +789,17 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { if (VerifyInput) PM.add(createVerifierPass()); - if (OptLevel > 1) + if (OptLevel != 0) addLTOOptimizationPasses(PM); // Create a function that performs CFI checks for cross-DSO calls with targets // in the current module. PM.add(createCrossDSOCFIPass()); - // Lower bit sets to globals. This pass supports Clang's control flow - // integrity mechanisms (-fsanitize=cfi*) and needs to run at link time if CFI - // is enabled. The pass does nothing if CFI is disabled. - PM.add(createLowerBitSetsPass()); + // Lower type metadata and the type.test intrinsic. This pass supports Clang's + // control flow integrity mechanisms (-fsanitize=cfi*) and needs to run at + // link time if CFI is enabled. The pass does nothing if CFI is disabled. + PM.add(createLowerTypeTestsPass()); if (OptLevel != 0) addLateLTOOptimizationPasses(PM); diff --git a/lib/Transforms/IPO/PruneEH.cpp b/lib/Transforms/IPO/PruneEH.cpp index 22a95fa03f7c..2aa3fa55cefd 100644 --- a/lib/Transforms/IPO/PruneEH.cpp +++ b/lib/Transforms/IPO/PruneEH.cpp @@ -16,7 +16,6 @@ #include "llvm/Transforms/IPO.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Analysis/CallGraph.h" @@ -48,10 +47,10 @@ namespace { // runOnSCC - Analyze the SCC, performing the transformation if possible. 
bool runOnSCC(CallGraphSCC &SCC) override; - bool SimplifyFunction(Function *F); - void DeleteBasicBlock(BasicBlock *BB); }; } +static bool SimplifyFunction(Function *F, CallGraph &CG); +static void DeleteBasicBlock(BasicBlock *BB, CallGraph &CG); char PruneEH::ID = 0; INITIALIZE_PASS_BEGIN(PruneEH, "prune-eh", @@ -62,22 +61,20 @@ INITIALIZE_PASS_END(PruneEH, "prune-eh", Pass *llvm::createPruneEHPass() { return new PruneEH(); } - -bool PruneEH::runOnSCC(CallGraphSCC &SCC) { +static bool runImpl(CallGraphSCC &SCC, CallGraph &CG) { SmallPtrSet<CallGraphNode *, 8> SCCNodes; - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); bool MadeChange = false; // Fill SCCNodes with the elements of the SCC. Used for quickly // looking up whether a given CallGraphNode is in this SCC. - for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) - SCCNodes.insert(*I); + for (CallGraphNode *I : SCC) + SCCNodes.insert(I); // First pass, scan all of the functions in the SCC, simplifying them // according to what we know. - for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) - if (Function *F = (*I)->getFunction()) - MadeChange |= SimplifyFunction(F); + for (CallGraphNode *I : SCC) + if (Function *F = I->getFunction()) + MadeChange |= SimplifyFunction(F, CG); // Next, check to see if any callees might throw or if there are any external // functions in this SCC: if so, we cannot prune any functions in this SCC. @@ -93,7 +90,10 @@ bool PruneEH::runOnSCC(CallGraphSCC &SCC) { if (!F) { SCCMightUnwind = true; SCCMightReturn = true; - } else if (F->isDeclaration() || F->mayBeOverridden()) { + } else if (F->isDeclaration() || F->isInterposable()) { + // Note: isInterposable (as opposed to hasExactDefinition) is fine above, + // since we're not inferring new attributes here, but only using existing, + // assumed to be correct, function attributes. SCCMightUnwind |= !F->doesNotThrow(); SCCMightReturn |= !F->doesNotReturn(); } else { @@ -153,8 +153,8 @@ bool PruneEH::runOnSCC(CallGraphSCC &SCC) { // If the SCC doesn't unwind or doesn't throw, note this fact. if (!SCCMightUnwind || !SCCMightReturn) - for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) { - Function *F = (*I)->getFunction(); + for (CallGraphNode *I : SCC) { + Function *F = I->getFunction(); if (!SCCMightUnwind && !F->hasFnAttribute(Attribute::NoUnwind)) { F->addFnAttr(Attribute::NoUnwind); @@ -167,22 +167,30 @@ bool PruneEH::runOnSCC(CallGraphSCC &SCC) { } } - for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) { + for (CallGraphNode *I : SCC) { // Convert any invoke instructions to non-throwing functions in this node // into call instructions with a branch. This makes the exception blocks // dead. - if (Function *F = (*I)->getFunction()) - MadeChange |= SimplifyFunction(F); + if (Function *F = I->getFunction()) + MadeChange |= SimplifyFunction(F, CG); } return MadeChange; } +bool PruneEH::runOnSCC(CallGraphSCC &SCC) { + if (skipSCC(SCC)) + return false; + CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + return runImpl(SCC, CG); +} + + // SimplifyFunction - Given information about callees, simplify the specified // function if we have invokes to non-unwinding functions or code after calls to // no-return functions. 
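A hedged source-level analogue of what SimplifyFunction exploits: once every callee in the SCC is known not to unwind, invokes become plain calls and their landing pads go dead, which is roughly the C++ situation below.

void callee() noexcept {}   // "nounwind" in IR terms

int caller() {
  try {
    callee();               // the invoke can be turned into an ordinary call
  } catch (...) {
    return -1;              // unreachable once callee is known not to throw,
  }                         // so PruneEH-style reasoning deletes this block
  return 0;
}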
-bool PruneEH::SimplifyFunction(Function *F) { +static bool SimplifyFunction(Function *F, CallGraph &CG) { bool MadeChange = false; for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) @@ -192,7 +200,7 @@ bool PruneEH::SimplifyFunction(Function *F) { // If the unwind block is now dead, nuke it. if (pred_empty(UnwindBlock)) - DeleteBasicBlock(UnwindBlock); // Delete the new BB. + DeleteBasicBlock(UnwindBlock, CG); // Delete the new BB. ++NumRemoved; MadeChange = true; @@ -211,7 +219,7 @@ bool PruneEH::SimplifyFunction(Function *F) { BB->getInstList().pop_back(); new UnreachableInst(BB->getContext(), &*BB); - DeleteBasicBlock(New); // Delete the new BB. + DeleteBasicBlock(New, CG); // Delete the new BB. MadeChange = true; ++NumUnreach; break; @@ -224,9 +232,8 @@ bool PruneEH::SimplifyFunction(Function *F) { /// DeleteBasicBlock - remove the specified basic block from the program, /// updating the callgraph to reflect any now-obsolete edges due to calls that /// exist in the BB. -void PruneEH::DeleteBasicBlock(BasicBlock *BB) { +static void DeleteBasicBlock(BasicBlock *BB, CallGraph &CG) { assert(pred_empty(BB) && "BB is not dead!"); - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); Instruction *TokenInst = nullptr; diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp index 928d92ef9d12..39de108edc06 100644 --- a/lib/Transforms/IPO/SampleProfile.cpp +++ b/lib/Transforms/IPO/SampleProfile.cpp @@ -22,10 +22,12 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/SampleProfile.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/IR/Constants.h" @@ -35,6 +37,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" @@ -76,16 +79,6 @@ static cl::opt<double> SampleProfileHotThreshold( "sample-profile-inline-hot-threshold", cl::init(0.1), cl::value_desc("N"), cl::desc("Inlined functions that account for more than N% of all samples " "collected in the parent function, will be inlined again.")); -static cl::opt<double> SampleProfileGlobalHotThreshold( - "sample-profile-global-hot-threshold", cl::init(30), cl::value_desc("N"), - cl::desc("Top-level functions that account for more than N% of all samples " - "collected in the profile, will be marked as hot for the inliner " - "to consider.")); -static cl::opt<double> SampleProfileGlobalColdThreshold( - "sample-profile-global-cold-threshold", cl::init(0.5), cl::value_desc("N"), - cl::desc("Top-level functions that account for less than N% of all samples " - "collected in the profile, will be marked as cold for the inliner " - "to consider.")); namespace { typedef DenseMap<const BasicBlock *, uint64_t> BlockWeightMap; @@ -100,30 +93,19 @@ typedef DenseMap<const BasicBlock *, SmallVector<const BasicBlock *, 8>> /// This pass reads profile data from the file specified by /// -sample-profile-file and annotates every affected function with the /// profile information found in that file. 
-class SampleProfileLoader : public ModulePass { +class SampleProfileLoader { public: - // Class identification, replacement for typeinfo - static char ID; - SampleProfileLoader(StringRef Name = SampleProfileFile) - : ModulePass(ID), DT(nullptr), PDT(nullptr), LI(nullptr), Reader(), + : DT(nullptr), PDT(nullptr), LI(nullptr), ACT(nullptr), Reader(), Samples(nullptr), Filename(Name), ProfileIsValid(false), - TotalCollectedSamples(0) { - initializeSampleProfileLoaderPass(*PassRegistry::getPassRegistry()); - } + TotalCollectedSamples(0) {} - bool doInitialization(Module &M) override; + bool doInitialization(Module &M); + bool runOnModule(Module &M); + void setACT(AssumptionCacheTracker *A) { ACT = A; } void dump() { Reader->dump(); } - const char *getPassName() const override { return "Sample profile pass"; } - - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - } - protected: bool runOnFunction(Function &F); unsigned getFunctionLoc(Function &F); @@ -133,14 +115,12 @@ protected: const FunctionSamples *findCalleeFunctionSamples(const CallInst &I) const; const FunctionSamples *findFunctionSamples(const Instruction &I) const; bool inlineHotFunctions(Function &F); - bool emitInlineHints(Function &F); void printEdgeWeight(raw_ostream &OS, Edge E); void printBlockWeight(raw_ostream &OS, const BasicBlock *BB) const; void printBlockEquivalence(raw_ostream &OS, const BasicBlock *BB); bool computeBlockWeights(Function &F); void findEquivalenceClasses(Function &F); - void findEquivalencesFor(BasicBlock *BB1, - SmallVector<BasicBlock *, 8> Descendants, + void findEquivalencesFor(BasicBlock *BB1, ArrayRef<BasicBlock *> Descendants, DominatorTreeBase<BasicBlock> *DomTree); void propagateWeights(Function &F); uint64_t visitEdge(Edge E, unsigned *NumUnknownEdges, Edge *UnknownEdge); @@ -163,10 +143,10 @@ protected: EdgeWeightMap EdgeWeights; /// \brief Set of visited blocks during propagation. - SmallPtrSet<const BasicBlock *, 128> VisitedBlocks; + SmallPtrSet<const BasicBlock *, 32> VisitedBlocks; /// \brief Set of visited edges during propagation. - SmallSet<Edge, 128> VisitedEdges; + SmallSet<Edge, 32> VisitedEdges; /// \brief Equivalence classes for block weights. /// @@ -181,6 +161,8 @@ protected: std::unique_ptr<DominatorTreeBase<BasicBlock>> PDT; std::unique_ptr<LoopInfo> LI; + AssumptionCacheTracker *ACT; + /// \brief Predecessors for each basic block in the CFG. 
BlockEdgeMap Predecessors; @@ -206,6 +188,32 @@ protected: uint64_t TotalCollectedSamples; }; +class SampleProfileLoaderLegacyPass : public ModulePass { +public: + // Class identification, replacement for typeinfo + static char ID; + + SampleProfileLoaderLegacyPass(StringRef Name = SampleProfileFile) + : ModulePass(ID), SampleLoader(Name) { + initializeSampleProfileLoaderLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + void dump() { SampleLoader.dump(); } + + bool doInitialization(Module &M) override { + return SampleLoader.doInitialization(M); + } + const char *getPassName() const override { return "Sample profile pass"; } + bool runOnModule(Module &M) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + } +private: + SampleProfileLoader SampleLoader; +}; + class SampleCoverageTracker { public: SampleCoverageTracker() : SampleCoverage(), TotalUsedSamples(0) {} @@ -285,7 +293,6 @@ bool callsiteIsHot(const FunctionSamples *CallerFS, (double)CallsiteTotalSamples / (double)ParentTotalSamples * 100.0; return PercentSamples >= SampleProfileHotThreshold; } - } /// Mark as used the sample record for the given function samples at @@ -445,7 +452,7 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS, /// \returns the weight of \p Inst. ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) const { - DebugLoc DLoc = Inst.getDebugLoc(); + const DebugLoc &DLoc = Inst.getDebugLoc(); if (!DLoc) return std::error_code(); @@ -453,6 +460,11 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) const { if (!FS) return std::error_code(); + // Ignore all dbg_value intrinsics. + const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); + if (II && II->getIntrinsicID() == Intrinsic::dbg_value) + return std::error_code(); + const DILocation *DIL = DLoc; unsigned Lineno = DLoc.getLine(); unsigned HeaderLineno = DIL->getScope()->getSubprogram()->getLine(); @@ -476,6 +488,13 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) const { << Inst << " (line offset: " << Lineno - HeaderLineno << "." << DIL->getDiscriminator() << " - weight: " << R.get() << ")\n"); + } else { + // If a call instruction is inlined in profile, but not inlined here, + // it means that the inlined callsite has no sample, thus the call + // instruction should have 0 count. + const CallInst *CI = dyn_cast<CallInst>(&Inst); + if (CI && findCalleeFunctionSamples(*CI)) + R = 0; } return R; } @@ -490,19 +509,22 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) const { /// \returns the weight for \p BB. ErrorOr<uint64_t> SampleProfileLoader::getBlockWeight(const BasicBlock *BB) const { - bool Found = false; - uint64_t Weight = 0; + DenseMap<uint64_t, uint64_t> CM; for (auto &I : BB->getInstList()) { const ErrorOr<uint64_t> &R = getInstWeight(I); - if (R && R.get() >= Weight) { - Weight = R.get(); - Found = true; + if (R) CM[R.get()]++; + } + if (CM.size() == 0) return std::error_code(); + uint64_t W = 0, C = 0; + for (const auto &C_W : CM) { + if (C_W.second == W) { + C = std::max(C, C_W.first); + } else if (C_W.second > W) { + C = C_W.first; + W = C_W.second; } } - if (Found) - return Weight; - else - return std::error_code(); + return C; } /// \brief Compute and store the weights of every basic block. 
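The reworked getBlockWeight above no longer takes the maximum instruction weight; it picks the sample count that occurs most often among the block's instructions, breaking ties toward the larger count. A small standalone sketch of that selection rule, using std::map in place of DenseMap for simplicity:

#include <cstdint>
#include <map>
#include <vector>

// Returns the sample count seen most often across the instructions of a
// block; on equal frequency the larger count wins, matching getBlockWeight.
static uint64_t mostFrequentCount(const std::vector<uint64_t> &InstCounts) {
  std::map<uint64_t, uint64_t> Freq;   // count value -> number of instructions
  for (uint64_t C : InstCounts)
    ++Freq[C];
  uint64_t Best = 0, BestFreq = 0;
  for (const auto &CF : Freq) {
    if (CF.second > BestFreq || (CF.second == BestFreq && CF.first > Best)) {
      Best = CF.first;
      BestFreq = CF.second;
    }
  }
  return Best;
}

// e.g. counts {10, 10, 700} yield 10: two instructions agree on 10 samples.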
@@ -549,19 +571,12 @@ SampleProfileLoader::findCalleeFunctionSamples(const CallInst &Inst) const { if (!SP) return nullptr; - Function *CalleeFunc = Inst.getCalledFunction(); - if (!CalleeFunc) { - return nullptr; - } - - StringRef CalleeName = CalleeFunc->getName(); const FunctionSamples *FS = findFunctionSamples(Inst); if (FS == nullptr) return nullptr; - return FS->findFunctionSamplesAt( - CallsiteLocation(getOffset(DIL->getLine(), SP->getLine()), - DIL->getDiscriminator(), CalleeName)); + return FS->findFunctionSamplesAt(LineLocation( + getOffset(DIL->getLine(), SP->getLine()), DIL->getDiscriminator())); } /// \brief Get the FunctionSamples for an instruction. @@ -575,22 +590,17 @@ SampleProfileLoader::findCalleeFunctionSamples(const CallInst &Inst) const { /// \returns the FunctionSamples pointer to the inlined instance. const FunctionSamples * SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { - SmallVector<CallsiteLocation, 10> S; + SmallVector<LineLocation, 10> S; const DILocation *DIL = Inst.getDebugLoc(); if (!DIL) { return Samples; } - StringRef CalleeName; - for (const DILocation *DIL = Inst.getDebugLoc(); DIL; - DIL = DIL->getInlinedAt()) { + for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) { DISubprogram *SP = DIL->getScope()->getSubprogram(); if (!SP) return nullptr; - if (!CalleeName.empty()) { - S.push_back(CallsiteLocation(getOffset(DIL->getLine(), SP->getLine()), - DIL->getDiscriminator(), CalleeName)); - } - CalleeName = SP->getLinkageName(); + S.push_back(LineLocation(getOffset(DIL->getLine(), SP->getLine()), + DIL->getDiscriminator())); } if (S.size() == 0) return Samples; @@ -601,63 +611,6 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { return FS; } -/// \brief Emit an inline hint if \p F is globally hot or cold. -/// -/// If \p F consumes a significant fraction of samples (indicated by -/// SampleProfileGlobalHotThreshold), apply the InlineHint attribute for the -/// inliner to consider the function hot. -/// -/// If \p F consumes a small fraction of samples (indicated by -/// SampleProfileGlobalColdThreshold), apply the Cold attribute for the inliner -/// to consider the function cold. -/// -/// FIXME - This setting of inline hints is sub-optimal. Instead of marking a -/// function globally hot or cold, we should be annotating individual callsites. -/// This is not currently possible, but work on the inliner will eventually -/// provide this ability. See http://reviews.llvm.org/D15003 for details and -/// discussion. -/// -/// \returns True if either attribute was applied to \p F. -bool SampleProfileLoader::emitInlineHints(Function &F) { - if (TotalCollectedSamples == 0) - return false; - - uint64_t FunctionSamples = Samples->getTotalSamples(); - double SamplesPercent = - (double)FunctionSamples / (double)TotalCollectedSamples * 100.0; - - // If the function collected more samples than the hot threshold, mark - // it globally hot. 
- if (SamplesPercent >= SampleProfileGlobalHotThreshold) { - F.addFnAttr(llvm::Attribute::InlineHint); - std::string Msg; - raw_string_ostream S(Msg); - S << "Applied inline hint to globally hot function '" << F.getName() - << "' with " << format("%.2f", SamplesPercent) - << "% of samples (threshold: " - << format("%.2f", SampleProfileGlobalHotThreshold.getValue()) << "%)"; - S.flush(); - emitOptimizationRemark(F.getContext(), DEBUG_TYPE, F, DebugLoc(), Msg); - return true; - } - - // If the function collected fewer samples than the cold threshold, mark - // it globally cold. - if (SamplesPercent <= SampleProfileGlobalColdThreshold) { - F.addFnAttr(llvm::Attribute::Cold); - std::string Msg; - raw_string_ostream S(Msg); - S << "Applied cold hint to globally cold function '" << F.getName() - << "' with " << format("%.2f", SamplesPercent) - << "% of samples (threshold: " - << format("%.2f", SampleProfileGlobalColdThreshold.getValue()) << "%)"; - S.flush(); - emitOptimizationRemark(F.getContext(), DEBUG_TYPE, F, DebugLoc(), Msg); - return true; - } - - return false; -} /// \brief Iteratively inline hot callsites of a function. /// @@ -685,7 +638,7 @@ bool SampleProfileLoader::inlineHotFunctions(Function &F) { } } for (auto CI : CIS) { - InlineFunctionInfo IFI; + InlineFunctionInfo IFI(nullptr, ACT); Function *CalledFunction = CI->getCalledFunction(); DebugLoc DLoc = CI->getDebugLoc(); uint64_t NumSamples = findCalleeFunctionSamples(*CI)->getTotalSamples(); @@ -731,7 +684,7 @@ bool SampleProfileLoader::inlineHotFunctions(Function &F) { /// with blocks from \p BB1's dominator tree, then /// this is the post-dominator tree, and vice versa. void SampleProfileLoader::findEquivalencesFor( - BasicBlock *BB1, SmallVector<BasicBlock *, 8> Descendants, + BasicBlock *BB1, ArrayRef<BasicBlock *> Descendants, DominatorTreeBase<BasicBlock> *DomTree) { const BasicBlock *EC = EquivalenceClass[BB1]; uint64_t Weight = BlockWeights[EC]; @@ -859,23 +812,31 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F) { // edge is unknown (see setEdgeOrBlockWeight). for (unsigned i = 0; i < 2; i++) { uint64_t TotalWeight = 0; - unsigned NumUnknownEdges = 0; - Edge UnknownEdge, SelfReferentialEdge; + unsigned NumUnknownEdges = 0, NumTotalEdges = 0; + Edge UnknownEdge, SelfReferentialEdge, SingleEdge; if (i == 0) { // First, visit all predecessor edges. + NumTotalEdges = Predecessors[BB].size(); for (auto *Pred : Predecessors[BB]) { Edge E = std::make_pair(Pred, BB); TotalWeight += visitEdge(E, &NumUnknownEdges, &UnknownEdge); if (E.first == E.second) SelfReferentialEdge = E; } + if (NumTotalEdges == 1) { + SingleEdge = std::make_pair(Predecessors[BB][0], BB); + } } else { // On the second round, visit all successor edges. + NumTotalEdges = Successors[BB].size(); for (auto *Succ : Successors[BB]) { Edge E = std::make_pair(BB, Succ); TotalWeight += visitEdge(E, &NumUnknownEdges, &UnknownEdge); } + if (NumTotalEdges == 1) { + SingleEdge = std::make_pair(BB, Successors[BB][0]); + } } // After visiting all the edges, there are three cases that we @@ -904,18 +865,24 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F) { if (NumUnknownEdges <= 1) { uint64_t &BBWeight = BlockWeights[EC]; if (NumUnknownEdges == 0) { - // If we already know the weight of all edges, the weight of the - // basic block can be computed. It should be no larger than the sum - // of all edge weights. 
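The propagation rules being adjusted here are easiest to see with numbers: once a block's weight is known, a single remaining unknown edge receives the difference between the block weight and the sum of the already-known edges, clamped at zero. A hedged one-liner of that rule (names are illustrative):

#include <cstdint>

// e.g. BlockWeight = 100 with known incoming edges summing to 70 leaves 30
// for the single unknown edge; an over-accounted block clamps to 0.
static uint64_t inferLastEdgeWeight(uint64_t BlockWeight, uint64_t KnownSum) {
  return BlockWeight > KnownSum ? BlockWeight - KnownSum : 0;
}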
- if (TotalWeight > BBWeight) { - BBWeight = TotalWeight; + if (!VisitedBlocks.count(EC)) { + // If we already know the weight of all edges, the weight of the + // basic block can be computed. It should be no larger than the sum + // of all edge weights. + if (TotalWeight > BBWeight) { + BBWeight = TotalWeight; + Changed = true; + DEBUG(dbgs() << "All edge weights for " << BB->getName() + << " known. Set weight for block: "; + printBlockWeight(dbgs(), BB);); + } + } else if (NumTotalEdges == 1 && + EdgeWeights[SingleEdge] < BlockWeights[EC]) { + // If there is only one edge for the visited basic block, use the + // block weight to adjust edge weight if edge weight is smaller. + EdgeWeights[SingleEdge] = BlockWeights[EC]; Changed = true; - DEBUG(dbgs() << "All edge weights for " << BB->getName() - << " known. Set weight for block: "; - printBlockWeight(dbgs(), BB);); } - if (VisitedBlocks.insert(EC).second) - Changed = true; } else if (NumUnknownEdges == 1 && VisitedBlocks.count(EC)) { // If there is a single unknown edge and the block has been // visited, then we can compute E's weight. @@ -1020,6 +987,19 @@ void SampleProfileLoader::propagateWeights(Function &F) { MDBuilder MDB(Ctx); for (auto &BI : F) { BasicBlock *BB = &BI; + + if (BlockWeights[BB]) { + for (auto &I : BB->getInstList()) { + if (CallInst *CI = dyn_cast<CallInst>(&I)) { + if (!dyn_cast<IntrinsicInst>(&I)) { + SmallVector<uint32_t, 1> Weights; + Weights.push_back(BlockWeights[BB]); + CI->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(Weights)); + } + } + } + } TerminatorInst *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 1) continue; @@ -1084,7 +1064,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { /// \returns the line number where \p F is defined. If it returns 0, /// it means that there is no debug information available for \p F. unsigned SampleProfileLoader::getFunctionLoc(Function &F) { - if (DISubprogram *S = getDISubprogram(&F)) + if (DISubprogram *S = F.getSubprogram()) return S->getLine(); // If the start of \p F is missing, emit a diagnostic to inform the user @@ -1165,8 +1145,6 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { DEBUG(dbgs() << "Line number for the first instruction in " << F.getName() << ": " << getFunctionLoc(F) << "\n"); - Changed |= emitInlineHints(F); - Changed |= inlineHotFunctions(F); // Compute basic block weights. 
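For reference, the metadata that propagateWeights attaches above (the new per-call weights as well as the per-terminator branch weights) is ordinary !prof branch_weights metadata built with MDBuilder. A minimal sketch of the same attachment for a two-way conditional branch:

#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
using namespace llvm;

// Tag a conditional branch with sample-derived weights so later consumers
// (block placement, inlining heuristics) can see the profile.
static void annotateBranch(BranchInst *BI, uint32_t TrueWeight,
                           uint32_t FalseWeight) {
  MDBuilder MDB(BI->getContext());
  BI->setMetadata(LLVMContext::MD_prof,
                  MDB.createBranchWeights(TrueWeight, FalseWeight));
}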
@@ -1190,7 +1168,7 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { unsigned Coverage = CoverageTracker.computeCoverage(Used, Total); if (Coverage < SampleProfileRecordCoverage) { F.getContext().diagnose(DiagnosticInfoSampleProfile( - getDISubprogram(&F)->getFilename(), getFunctionLoc(F), + F.getSubprogram()->getFilename(), getFunctionLoc(F), Twine(Used) + " of " + Twine(Total) + " available profile records (" + Twine(Coverage) + "%) were applied", DS_Warning)); @@ -1203,7 +1181,7 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { unsigned Coverage = CoverageTracker.computeCoverage(Used, Total); if (Coverage < SampleProfileSampleCoverage) { F.getContext().diagnose(DiagnosticInfoSampleProfile( - getDISubprogram(&F)->getFilename(), getFunctionLoc(F), + F.getSubprogram()->getFilename(), getFunctionLoc(F), Twine(Used) + " of " + Twine(Total) + " available profile samples (" + Twine(Coverage) + "%) were applied", DS_Warning)); @@ -1212,12 +1190,12 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { return Changed; } -char SampleProfileLoader::ID = 0; -INITIALIZE_PASS_BEGIN(SampleProfileLoader, "sample-profile", - "Sample Profile loader", false, false) -INITIALIZE_PASS_DEPENDENCY(AddDiscriminators) -INITIALIZE_PASS_END(SampleProfileLoader, "sample-profile", - "Sample Profile loader", false, false) +char SampleProfileLoaderLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(SampleProfileLoaderLegacyPass, "sample-profile", + "Sample Profile loader", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_END(SampleProfileLoaderLegacyPass, "sample-profile", + "Sample Profile loader", false, false) bool SampleProfileLoader::doInitialization(Module &M) { auto &Ctx = M.getContext(); @@ -1233,11 +1211,11 @@ bool SampleProfileLoader::doInitialization(Module &M) { } ModulePass *llvm::createSampleProfileLoaderPass() { - return new SampleProfileLoader(SampleProfileFile); + return new SampleProfileLoaderLegacyPass(SampleProfileFile); } ModulePass *llvm::createSampleProfileLoaderPass(StringRef Name) { - return new SampleProfileLoader(Name); + return new SampleProfileLoaderLegacyPass(Name); } bool SampleProfileLoader::runOnModule(Module &M) { @@ -1254,12 +1232,33 @@ bool SampleProfileLoader::runOnModule(Module &M) { clearFunctionData(); retval |= runOnFunction(F); } + M.setProfileSummary(Reader->getSummary().getMD(M.getContext())); return retval; } +bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) { + // FIXME: pass in AssumptionCache correctly for the new pass manager. 
+ SampleLoader.setACT(&getAnalysis<AssumptionCacheTracker>()); + return SampleLoader.runOnModule(M); +} + bool SampleProfileLoader::runOnFunction(Function &F) { + F.setEntryCount(0); Samples = Reader->getSamplesFor(F); if (!Samples->empty()) return emitAnnotations(F); return false; } + +PreservedAnalyses SampleProfileLoaderPass::run(Module &M, + AnalysisManager<Module> &AM) { + + SampleProfileLoader SampleLoader(SampleProfileFile); + + SampleLoader.doInitialization(M); + + if (!SampleLoader.runOnModule(M)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} diff --git a/lib/Transforms/IPO/StripDeadPrototypes.cpp b/lib/Transforms/IPO/StripDeadPrototypes.cpp index c94cc7c74a89..3c3c5dd19d1f 100644 --- a/lib/Transforms/IPO/StripDeadPrototypes.cpp +++ b/lib/Transforms/IPO/StripDeadPrototypes.cpp @@ -53,7 +53,8 @@ static bool stripDeadPrototypes(Module &M) { return MadeChange; } -PreservedAnalyses StripDeadPrototypesPass::run(Module &M) { +PreservedAnalyses StripDeadPrototypesPass::run(Module &M, + ModuleAnalysisManager &) { if (stripDeadPrototypes(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); @@ -69,6 +70,9 @@ public: *PassRegistry::getPassRegistry()); } bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + return stripDeadPrototypes(M); } }; diff --git a/lib/Transforms/IPO/StripSymbols.cpp b/lib/Transforms/IPO/StripSymbols.cpp index 46f352f7f9f1..fd250366cef2 100644 --- a/lib/Transforms/IPO/StripSymbols.cpp +++ b/lib/Transforms/IPO/StripSymbols.cpp @@ -21,7 +21,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" @@ -216,11 +215,11 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) { I->setName(""); // Internal symbols can't participate in linkage } - for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { - if (I->hasLocalLinkage() && llvmUsedValues.count(&*I) == 0) - if (!PreserveDbgInfo || !I->getName().startswith("llvm.dbg")) - I->setName(""); // Internal symbols can't participate in linkage - StripSymtab(I->getValueSymbolTable(), PreserveDbgInfo); + for (Function &I : M) { + if (I.hasLocalLinkage() && llvmUsedValues.count(&I) == 0) + if (!PreserveDbgInfo || !I.getName().startswith("llvm.dbg")) + I.setName(""); // Internal symbols can't participate in linkage + StripSymtab(I.getValueSymbolTable(), PreserveDbgInfo); } // Remove all names from types. @@ -230,6 +229,9 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) { } bool StripSymbols::runOnModule(Module &M) { + if (skipModule(M)) + return false; + bool Changed = false; Changed |= StripDebugInfo(M); if (!OnlyDebugInfo) @@ -238,10 +240,15 @@ bool StripSymbols::runOnModule(Module &M) { } bool StripNonDebugSymbols::runOnModule(Module &M) { + if (skipModule(M)) + return false; + return StripSymbolNames(M, true); } bool StripDebugDeclare::runOnModule(Module &M) { + if (skipModule(M)) + return false; Function *Declare = M.getFunction("llvm.dbg.declare"); std::vector<Constant*> DeadConstants; @@ -287,6 +294,9 @@ bool StripDebugDeclare::runOnModule(Module &M) { /// optimized away by the optimizer. This special pass removes debug info for /// such symbols. 
bool StripDeadDebugInfo::runOnModule(Module &M) { + if (skipModule(M)) + return false; + bool Changed = false; LLVMContext &C = M.getContext(); @@ -312,20 +322,6 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { } for (DICompileUnit *DIC : F.compile_units()) { - // Create our live subprogram list. - bool SubprogramChange = false; - for (DISubprogram *DISP : DIC->getSubprograms()) { - // Make sure we visit each subprogram only once. - if (!VisitedSet.insert(DISP).second) - continue; - - // If the function referenced by DISP is not null, the function is live. - if (LiveSPs.count(DISP)) - LiveSubprograms.push_back(DISP); - else - SubprogramChange = true; - } - // Create our live global variable list. bool GlobalVariableChange = false; for (DIGlobalVariable *DIG : DIC->getGlobalVariables()) { @@ -341,14 +337,8 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { GlobalVariableChange = true; } - // If we found dead subprograms or global variables, replace the current - // subprogram list/global variable list with our new live subprogram/global - // variable list. - if (SubprogramChange) { - DIC->replaceSubprograms(MDTuple::get(C, LiveSubprograms)); - Changed = true; - } - + // If we found dead global variables, replace the current global + // variable list with our new live global variable list. if (GlobalVariableChange) { DIC->replaceGlobalVariables(MDTuple::get(C, LiveGlobalVariables)); Changed = true; diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp new file mode 100644 index 000000000000..53eb4e2c9076 --- /dev/null +++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -0,0 +1,843 @@ +//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements whole program optimization of virtual calls in cases +// where we know (via !type metadata) that the list of callees is fixed. This +// includes the following: +// - Single implementation devirtualization: if a virtual call has a single +// possible callee, replace all calls with a direct call to that callee. +// - Virtual constant propagation: if the virtual function's return type is an +// integer <=64 bits and all possible callees are readnone, for each class and +// each list of constant arguments: evaluate the function, store the return +// value alongside the virtual table, and rewrite each virtual call as a load +// from the virtual table. +// - Uniform return value optimization: if the conditions for virtual constant +// propagation hold and each function returns the same constant value, replace +// each virtual call with that constant. +// - Unique return value optimization for i1 return values: if the conditions +// for virtual constant propagation hold and a single vtable's function +// returns 0, or a single vtable's function returns 1, replace each virtual +// call with a comparison of the vptr against that vtable's address. 
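The overview above is easiest to picture at the source level. Below is a minimal, standalone C++ sketch of what single implementation devirtualization buys (the class and function names are illustrative only, not taken from this patch): once whole-program type metadata proves that a vtable slot has exactly one possible callee, the indirect call through the vtable can be rewritten into a direct call. The pure virtual base also mirrors the note later in the file that __cxa_pure_virtual is disregarded as a possible target.

// Illustrative sketch only: the effect of single implementation
// devirtualization, expressed as plain C++. All names are hypothetical.
#include <cstdio>

struct Base {
  virtual int f() const = 0;      // __cxa_pure_virtual is never a real target
};
struct OnlyImpl final : Base {
  int f() const override { return 42; }
};

int callVirtual(const Base *P) {
  return P->f();                  // indirect call through the vtable
}

int callDevirtualized(const Base *P) {
  // With !type metadata proving OnlyImpl::f is the only possible callee,
  // the pass turns the call site into a direct call such as this one.
  return static_cast<const OnlyImpl *>(P)->OnlyImpl::f();
}

int main() {
  OnlyImpl O;
  std::printf("%d %d\n", callVirtual(&O), callDevirtualized(&O)); // 42 42
  return 0;
}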
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/WholeProgramDevirt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils/Evaluator.h" +#include "llvm/Transforms/Utils/Local.h" + +#include <set> + +using namespace llvm; +using namespace wholeprogramdevirt; + +#define DEBUG_TYPE "wholeprogramdevirt" + +// Find the minimum offset that we may store a value of size Size bits at. If +// IsAfter is set, look for an offset before the object, otherwise look for an +// offset after the object. +uint64_t +wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, + bool IsAfter, uint64_t Size) { + // Find a minimum offset taking into account only vtable sizes. + uint64_t MinByte = 0; + for (const VirtualCallTarget &Target : Targets) { + if (IsAfter) + MinByte = std::max(MinByte, Target.minAfterBytes()); + else + MinByte = std::max(MinByte, Target.minBeforeBytes()); + } + + // Build a vector of arrays of bytes covering, for each target, a slice of the + // used region (see AccumBitVector::BytesUsed in + // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively, + // this aligns the used regions to start at MinByte. + // + // In this example, A, B and C are vtables, # is a byte already allocated for + // a virtual function pointer, AAAA... (etc.) are the used regions for the + // vtables and Offset(X) is the value computed for the Offset variable below + // for X. + // + // Offset(A) + // | | + // |MinByte + // A: ################AAAAAAAA|AAAAAAAA + // B: ########BBBBBBBBBBBBBBBB|BBBB + // C: ########################|CCCCCCCCCCCCCCCC + // | Offset(B) | + // + // This code produces the slices of A, B and C that appear after the divider + // at MinByte. + std::vector<ArrayRef<uint8_t>> Used; + for (const VirtualCallTarget &Target : Targets) { + ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed + : Target.TM->Bits->Before.BytesUsed; + uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes() + : MinByte - Target.minBeforeBytes(); + + // Disregard used regions that are smaller than Offset. These are + // effectively all-free regions that do not need to be checked. + if (VTUsed.size() > Offset) + Used.push_back(VTUsed.slice(Offset)); + } + + if (Size == 1) { + // Find a free bit in each member of Used. + for (unsigned I = 0;; ++I) { + uint8_t BitsUsed = 0; + for (auto &&B : Used) + if (I < B.size()) + BitsUsed |= B[I]; + if (BitsUsed != 0xff) + return (MinByte + I) * 8 + + countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined); + } + } else { + // Find a free (Size/8) byte region in each member of Used. + // FIXME: see if alignment helps. 
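As an aside, the Size == 1 branch just shown reduces to a simple search over per-vtable "bytes used" bitmaps. A standalone sketch of that search, assuming the bitmaps have already been aligned so that index 0 corresponds to MinByte (the byte-region branch of the real function continues below in the diff):

// Illustrative sketch of the free-bit search (Size == 1 case), using plain
// containers instead of the pass's VirtualCallTarget machinery.
#include <cstdint>
#include <cstdio>
#include <vector>

// Returns a bit offset (relative to MinByte * 8) unused in every bitmap.
uint64_t findFreeBit(const std::vector<std::vector<uint8_t>> &Used) {
  for (unsigned I = 0;; ++I) {
    uint8_t BitsUsed = 0;
    for (const auto &B : Used)
      if (I < B.size())
        BitsUsed |= B[I];          // OR together the used bits of byte I
    if (BitsUsed != 0xff) {
      // Find the lowest clear bit in BitsUsed.
      unsigned Bit = 0;
      while (BitsUsed & (1u << Bit))
        ++Bit;
      return I * 8 + Bit;
    }
  }
}

int main() {
  // Byte 0 is fully used across the vtables; bit 3 of byte 1 is the first gap.
  std::vector<std::vector<uint8_t>> Used = {{0xff, 0x07}, {0x0f, 0x03}};
  std::printf("%llu\n", (unsigned long long)findFreeBit(Used)); // prints 11
  return 0;
}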
+ for (unsigned I = 0;; ++I) { + for (auto &&B : Used) { + unsigned Byte = 0; + while ((I + Byte) < B.size() && Byte < (Size / 8)) { + if (B[I + Byte]) + goto NextI; + ++Byte; + } + } + return (MinByte + I) * 8; + NextI:; + } + } +} + +void wholeprogramdevirt::setBeforeReturnValues( + MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore, + unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { + if (BitWidth == 1) + OffsetByte = -(AllocBefore / 8 + 1); + else + OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8); + OffsetBit = AllocBefore % 8; + + for (VirtualCallTarget &Target : Targets) { + if (BitWidth == 1) + Target.setBeforeBit(AllocBefore); + else + Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8); + } +} + +void wholeprogramdevirt::setAfterReturnValues( + MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter, + unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { + if (BitWidth == 1) + OffsetByte = AllocAfter / 8; + else + OffsetByte = (AllocAfter + 7) / 8; + OffsetBit = AllocAfter % 8; + + for (VirtualCallTarget &Target : Targets) { + if (BitWidth == 1) + Target.setAfterBit(AllocAfter); + else + Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8); + } +} + +VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM) + : Fn(Fn), TM(TM), + IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()) {} + +namespace { + +// A slot in a set of virtual tables. The TypeID identifies the set of virtual +// tables, and the ByteOffset is the offset in bytes from the address point to +// the virtual function pointer. +struct VTableSlot { + Metadata *TypeID; + uint64_t ByteOffset; +}; + +} + +namespace llvm { + +template <> struct DenseMapInfo<VTableSlot> { + static VTableSlot getEmptyKey() { + return {DenseMapInfo<Metadata *>::getEmptyKey(), + DenseMapInfo<uint64_t>::getEmptyKey()}; + } + static VTableSlot getTombstoneKey() { + return {DenseMapInfo<Metadata *>::getTombstoneKey(), + DenseMapInfo<uint64_t>::getTombstoneKey()}; + } + static unsigned getHashValue(const VTableSlot &I) { + return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^ + DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset); + } + static bool isEqual(const VTableSlot &LHS, + const VTableSlot &RHS) { + return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset; + } +}; + +} + +namespace { + +// A virtual call site. VTable is the loaded virtual table pointer, and CS is +// the indirect virtual call. +struct VirtualCallSite { + Value *VTable; + CallSite CS; + + // If non-null, this field points to the associated unsafe use count stored in + // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description + // of that field for details. + unsigned *NumUnsafeUses; + + void emitRemark() { + Function *F = CS.getCaller(); + emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, + CS.getInstruction()->getDebugLoc(), + "devirtualized call"); + } + + void replaceAndErase(Value *New) { + emitRemark(); + CS->replaceAllUsesWith(New); + if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { + BranchInst::Create(II->getNormalDest(), CS.getInstruction()); + II->getUnwindDest()->removePredecessor(II->getParent()); + } + CS->eraseFromParent(); + // This use is no longer unsafe. 
+ if (NumUnsafeUses) + --*NumUnsafeUses; + } +}; + +struct DevirtModule { + Module &M; + IntegerType *Int8Ty; + PointerType *Int8PtrTy; + IntegerType *Int32Ty; + + MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots; + + // This map keeps track of the number of "unsafe" uses of a loaded function + // pointer. The key is the associated llvm.type.test intrinsic call generated + // by this pass. An unsafe use is one that calls the loaded function pointer + // directly. Every time we eliminate an unsafe use (for example, by + // devirtualizing it or by applying virtual constant propagation), we + // decrement the value stored in this map. If a value reaches zero, we can + // eliminate the type check by RAUWing the associated llvm.type.test call with + // true. + std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; + + DevirtModule(Module &M) + : M(M), Int8Ty(Type::getInt8Ty(M.getContext())), + Int8PtrTy(Type::getInt8PtrTy(M.getContext())), + Int32Ty(Type::getInt32Ty(M.getContext())) {} + + void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc); + void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); + + void buildTypeIdentifierMap( + std::vector<VTableBits> &Bits, + DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); + bool + tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, + const std::set<TypeMemberInfo> &TypeMemberInfos, + uint64_t ByteOffset); + bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallSite> CallSites); + bool tryEvaluateFunctionsWithArgs( + MutableArrayRef<VirtualCallTarget> TargetsForSlot, + ArrayRef<ConstantInt *> Args); + bool tryUniformRetValOpt(IntegerType *RetType, + ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallSite> CallSites); + bool tryUniqueRetValOpt(unsigned BitWidth, + ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallSite> CallSites); + bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, + ArrayRef<VirtualCallSite> CallSites); + + void rebuildGlobal(VTableBits &B); + + bool run(); +}; + +struct WholeProgramDevirt : public ModulePass { + static char ID; + WholeProgramDevirt() : ModulePass(ID) { + initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); + } + bool runOnModule(Module &M) { + if (skipModule(M)) + return false; + + return DevirtModule(M).run(); + } +}; + +} // anonymous namespace + +INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt", + "Whole program devirtualization", false, false) +char WholeProgramDevirt::ID = 0; + +ModulePass *llvm::createWholeProgramDevirtPass() { + return new WholeProgramDevirt; +} + +PreservedAnalyses WholeProgramDevirtPass::run(Module &M, + ModuleAnalysisManager &) { + if (!DevirtModule(M).run()) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +void DevirtModule::buildTypeIdentifierMap( + std::vector<VTableBits> &Bits, + DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { + DenseMap<GlobalVariable *, VTableBits *> GVToBits; + Bits.reserve(M.getGlobalList().size()); + SmallVector<MDNode *, 2> Types; + for (GlobalVariable &GV : M.globals()) { + Types.clear(); + GV.getMetadata(LLVMContext::MD_type, Types); + if (Types.empty()) + continue; + + VTableBits *&BitsPtr = GVToBits[&GV]; + if (!BitsPtr) { + Bits.emplace_back(); + Bits.back().GV = &GV; + Bits.back().ObjectSize = + M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType()); + BitsPtr = &Bits.back(); + } + + for (MDNode *Type : 
Types) { + auto TypeID = Type->getOperand(1).get(); + + uint64_t Offset = + cast<ConstantInt>( + cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) + ->getZExtValue(); + + TypeIdMap[TypeID].insert({BitsPtr, Offset}); + } + } +} + +bool DevirtModule::tryFindVirtualCallTargets( + std::vector<VirtualCallTarget> &TargetsForSlot, + const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) { + for (const TypeMemberInfo &TM : TypeMemberInfos) { + if (!TM.Bits->GV->isConstant()) + return false; + + auto Init = dyn_cast<ConstantArray>(TM.Bits->GV->getInitializer()); + if (!Init) + return false; + ArrayType *VTableTy = Init->getType(); + + uint64_t ElemSize = + M.getDataLayout().getTypeAllocSize(VTableTy->getElementType()); + uint64_t GlobalSlotOffset = TM.Offset + ByteOffset; + if (GlobalSlotOffset % ElemSize != 0) + return false; + + unsigned Op = GlobalSlotOffset / ElemSize; + if (Op >= Init->getNumOperands()) + return false; + + auto Fn = dyn_cast<Function>(Init->getOperand(Op)->stripPointerCasts()); + if (!Fn) + return false; + + // We can disregard __cxa_pure_virtual as a possible call target, as + // calls to pure virtuals are UB. + if (Fn->getName() == "__cxa_pure_virtual") + continue; + + TargetsForSlot.push_back({Fn, &TM}); + } + + // Give up if we couldn't find any targets. + return !TargetsForSlot.empty(); +} + +bool DevirtModule::trySingleImplDevirt( + ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallSite> CallSites) { + // See if the program contains a single implementation of this virtual + // function. + Function *TheFn = TargetsForSlot[0].Fn; + for (auto &&Target : TargetsForSlot) + if (TheFn != Target.Fn) + return false; + + // If so, update each call site to call that implementation directly. + for (auto &&VCallSite : CallSites) { + VCallSite.emitRemark(); + VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( + TheFn, VCallSite.CS.getCalledValue()->getType())); + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; + } + return true; +} + +bool DevirtModule::tryEvaluateFunctionsWithArgs( + MutableArrayRef<VirtualCallTarget> TargetsForSlot, + ArrayRef<ConstantInt *> Args) { + // Evaluate each function and store the result in each target's RetVal + // field. + for (VirtualCallTarget &Target : TargetsForSlot) { + if (Target.Fn->arg_size() != Args.size() + 1) + return false; + for (unsigned I = 0; I != Args.size(); ++I) + if (Target.Fn->getFunctionType()->getParamType(I + 1) != + Args[I]->getType()) + return false; + + Evaluator Eval(M.getDataLayout(), nullptr); + SmallVector<Constant *, 2> EvalArgs; + EvalArgs.push_back( + Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); + EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end()); + Constant *RetVal; + if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || + !isa<ConstantInt>(RetVal)) + return false; + Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue(); + } + return true; +} + +bool DevirtModule::tryUniformRetValOpt( + IntegerType *RetType, ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallSite> CallSites) { + // Uniform return value optimization. If all functions return the same + // constant, replace all calls with that constant. 
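The comment above states the whole optimization: if every possible target evaluates to the same constant, each call site can simply become that constant (the real loop follows below in the diff). A standalone sketch, with plain integers standing in for the evaluated VirtualCallTarget results:

// Illustrative sketch of the uniform return value test.
#include <cstdint>
#include <optional>
#include <vector>

std::optional<uint64_t> uniformRetVal(const std::vector<uint64_t> &RetVals) {
  if (RetVals.empty())
    return std::nullopt;
  for (uint64_t V : RetVals)
    if (V != RetVals.front())
      return std::nullopt;        // not uniform: bail out
  return RetVals.front();         // replace each virtual call with this value
}

int main() {
  auto V = uniformRetVal({7, 7, 7});
  return (V && *V == 7) ? 0 : 1;  // all targets return 7 -> fold to 7
}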
+ uint64_t TheRetVal = TargetsForSlot[0].RetVal; + for (const VirtualCallTarget &Target : TargetsForSlot) + if (Target.RetVal != TheRetVal) + return false; + + auto TheRetValConst = ConstantInt::get(RetType, TheRetVal); + for (auto Call : CallSites) + Call.replaceAndErase(TheRetValConst); + return true; +} + +bool DevirtModule::tryUniqueRetValOpt( + unsigned BitWidth, ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallSite> CallSites) { + // IsOne controls whether we look for a 0 or a 1. + auto tryUniqueRetValOptFor = [&](bool IsOne) { + const TypeMemberInfo *UniqueMember = 0; + for (const VirtualCallTarget &Target : TargetsForSlot) { + if (Target.RetVal == (IsOne ? 1 : 0)) { + if (UniqueMember) + return false; + UniqueMember = Target.TM; + } + } + + // We should have found a unique member or bailed out by now. We already + // checked for a uniform return value in tryUniformRetValOpt. + assert(UniqueMember); + + // Replace each call with the comparison. + for (auto &&Call : CallSites) { + IRBuilder<> B(Call.CS.getInstruction()); + Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy); + OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset); + Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + Call.VTable, OneAddr); + Call.replaceAndErase(Cmp); + } + return true; + }; + + if (BitWidth == 1) { + if (tryUniqueRetValOptFor(true)) + return true; + if (tryUniqueRetValOptFor(false)) + return true; + } + return false; +} + +bool DevirtModule::tryVirtualConstProp( + MutableArrayRef<VirtualCallTarget> TargetsForSlot, + ArrayRef<VirtualCallSite> CallSites) { + // This only works if the function returns an integer. + auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType()); + if (!RetType) + return false; + unsigned BitWidth = RetType->getBitWidth(); + if (BitWidth > 64) + return false; + + // Make sure that each function does not access memory, takes at least one + // argument, does not use its first argument (which we assume is 'this'), + // and has the same return type. + for (VirtualCallTarget &Target : TargetsForSlot) { + if (!Target.Fn->doesNotAccessMemory() || Target.Fn->arg_empty() || + !Target.Fn->arg_begin()->use_empty() || + Target.Fn->getReturnType() != RetType) + return false; + } + + // Group call sites by the list of constant arguments they pass. + // The comparator ensures deterministic ordering. 
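The ByAPIntValue comparator defined next exists only to give this map a deterministic strict weak ordering over argument lists. A standalone sketch of the same grouping, using plain integers in place of ConstantInt and call-site indices in place of VirtualCallSite records:

// Illustrative sketch: group call sites by their list of constant arguments,
// with a lexicographic comparator so iteration order is deterministic.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

struct LexLess {
  bool operator()(const std::vector<uint64_t> &A,
                  const std::vector<uint64_t> &B) const {
    return std::lexicographical_compare(A.begin(), A.end(), B.begin(), B.end());
  }
};

int main() {
  std::map<std::vector<uint64_t>, std::vector<int>, LexLess> ByArgs;
  ByArgs[{1, 2}].push_back(0);
  ByArgs[{1, 2}].push_back(2);   // same constant args -> same group
  ByArgs[{3}].push_back(1);
  std::printf("%zu groups\n", ByArgs.size()); // prints "2 groups"
  return 0;
}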
+ struct ByAPIntValue { + bool operator()(const std::vector<ConstantInt *> &A, + const std::vector<ConstantInt *> &B) const { + return std::lexicographical_compare( + A.begin(), A.end(), B.begin(), B.end(), + [](ConstantInt *AI, ConstantInt *BI) { + return AI->getValue().ult(BI->getValue()); + }); + } + }; + std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>, + ByAPIntValue> + VCallSitesByConstantArg; + for (auto &&VCallSite : CallSites) { + std::vector<ConstantInt *> Args; + if (VCallSite.CS.getType() != RetType) + continue; + for (auto &&Arg : + make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) { + if (!isa<ConstantInt>(Arg)) + break; + Args.push_back(cast<ConstantInt>(&Arg)); + } + if (Args.size() + 1 != VCallSite.CS.arg_size()) + continue; + + VCallSitesByConstantArg[Args].push_back(VCallSite); + } + + for (auto &&CSByConstantArg : VCallSitesByConstantArg) { + if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first)) + continue; + + if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second)) + continue; + + if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second)) + continue; + + // Find an allocation offset in bits in all vtables associated with the + // type. + uint64_t AllocBefore = + findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth); + uint64_t AllocAfter = + findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth); + + // Calculate the total amount of padding needed to store a value at both + // ends of the object. + uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0; + for (auto &&Target : TargetsForSlot) { + TotalPaddingBefore += std::max<int64_t>( + (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0); + TotalPaddingAfter += std::max<int64_t>( + (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0); + } + + // If the amount of padding is too large, give up. + // FIXME: do something smarter here. + if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128) + continue; + + // Calculate the offset to the value as a (possibly negative) byte offset + // and (if applicable) a bit offset, and store the values in the targets. + int64_t OffsetByte; + uint64_t OffsetBit; + if (TotalPaddingBefore <= TotalPaddingAfter) + setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte, + OffsetBit); + else + setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte, + OffsetBit); + + // Rewrite each call to a load from OffsetByte/OffsetBit. + for (auto Call : CSByConstantArg.second) { + IRBuilder<> B(Call.CS.getInstruction()); + Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte); + if (BitWidth == 1) { + Value *Bits = B.CreateLoad(Addr); + Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); + Value *BitsAndBit = B.CreateAnd(Bits, Bit); + auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); + Call.replaceAndErase(IsBitSet); + } else { + Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); + Value *Val = B.CreateLoad(RetType, ValAddr); + Call.replaceAndErase(Val); + } + } + } + return true; +} + +void DevirtModule::rebuildGlobal(VTableBits &B) { + if (B.Before.Bytes.empty() && B.After.Bytes.empty()) + return; + + // Align each byte array to pointer width. + unsigned PointerSize = M.getDataLayout().getPointerSize(); + B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize)); + B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize)); + + // Before was stored in reverse order; flip it now. 
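The call rewriting shown above, in the one-bit case, amounts to a byte load at a possibly negative offset from the vtable address point followed by a bit test. A simplified standalone sketch that ignores the GEP and bitcast plumbing of the real code:

// Illustrative sketch: reading a propagated i1 return value stored at
// (OffsetByte, OffsetBit) relative to the vtable address point.
#include <cstdint>
#include <cstdio>
#include <vector>

bool loadPropagatedBit(const uint8_t *VTableAddressPoint, int64_t OffsetByte,
                       uint64_t OffsetBit) {
  uint8_t Bits = VTableAddressPoint[OffsetByte]; // offset may be negative
  return (Bits & (uint8_t(1) << OffsetBit)) != 0;
}

int main() {
  // One byte of propagated data stored before a fake two-entry vtable.
  std::vector<uint8_t> Storage = {0x04, 0xaa, 0xbb};
  const uint8_t *AddressPoint = Storage.data() + 1;
  // Bit 2 of the byte at offset -1 is set.
  std::printf("%d\n", loadPropagatedBit(AddressPoint, -1, 2)); // prints 1
  return 0;
}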
+ for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I) + std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]); + + // Build an anonymous global containing the before bytes, followed by the + // original initializer, followed by the after bytes. + auto NewInit = ConstantStruct::getAnon( + {ConstantDataArray::get(M.getContext(), B.Before.Bytes), + B.GV->getInitializer(), + ConstantDataArray::get(M.getContext(), B.After.Bytes)}); + auto NewGV = + new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(), + GlobalVariable::PrivateLinkage, NewInit, "", B.GV); + NewGV->setSection(B.GV->getSection()); + NewGV->setComdat(B.GV->getComdat()); + + // Copy the original vtable's metadata to the anonymous global, adjusting + // offsets as required. + NewGV->copyMetadata(B.GV, B.Before.Bytes.size()); + + // Build an alias named after the original global, pointing at the second + // element (the original initializer). + auto Alias = GlobalAlias::create( + B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "", + ConstantExpr::getGetElementPtr( + NewInit->getType(), NewGV, + ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0), + ConstantInt::get(Int32Ty, 1)}), + &M); + Alias->setVisibility(B.GV->getVisibility()); + Alias->takeName(B.GV); + + B.GV->replaceAllUsesWith(Alias); + B.GV->eraseFromParent(); +} + +void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, + Function *AssumeFunc) { + // Find all virtual calls via a virtual table pointer %p under an assumption + // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p + // points to a member of the type identifier %md. Group calls by (type ID, + // offset) pair (effectively the identity of the virtual function) and store + // to CallSlots. + DenseSet<Value *> SeenPtrs; + for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end(); + I != E;) { + auto CI = dyn_cast<CallInst>(I->getUser()); + ++I; + if (!CI) + continue; + + // Search for virtual calls based on %p and add them to DevirtCalls. + SmallVector<DevirtCallSite, 1> DevirtCalls; + SmallVector<CallInst *, 1> Assumes; + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); + + // If we found any, add them to CallSlots. Only do this if we haven't seen + // the vtable pointer before, as it may have been CSE'd with pointers from + // other call sites, and we don't want to process call sites multiple times. + if (!Assumes.empty()) { + Metadata *TypeId = + cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); + Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); + if (SeenPtrs.insert(Ptr).second) { + for (DevirtCallSite Call : DevirtCalls) { + CallSlots[{TypeId, Call.Offset}].push_back( + {CI->getArgOperand(0), Call.CS, nullptr}); + } + } + } + + // We no longer need the assumes or the type test. + for (auto Assume : Assumes) + Assume->eraseFromParent(); + // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we + // may use the vtable argument later. 
+ if (CI->use_empty()) + CI->eraseFromParent(); + } +} + +void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { + Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test); + + for (auto I = TypeCheckedLoadFunc->use_begin(), + E = TypeCheckedLoadFunc->use_end(); + I != E;) { + auto CI = dyn_cast<CallInst>(I->getUser()); + ++I; + if (!CI) + continue; + + Value *Ptr = CI->getArgOperand(0); + Value *Offset = CI->getArgOperand(1); + Value *TypeIdValue = CI->getArgOperand(2); + Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); + + SmallVector<DevirtCallSite, 1> DevirtCalls; + SmallVector<Instruction *, 1> LoadedPtrs; + SmallVector<Instruction *, 1> Preds; + bool HasNonCallUses = false; + findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, + HasNonCallUses, CI); + + // Start by generating "pessimistic" code that explicitly loads the function + // pointer from the vtable and performs the type check. If possible, we will + // eliminate the load and the type check later. + + // If possible, only generate the load at the point where it is used. + // This helps avoid unnecessary spills. + IRBuilder<> LoadB( + (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); + Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); + Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); + Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + + for (Instruction *LoadedPtr : LoadedPtrs) { + LoadedPtr->replaceAllUsesWith(LoadedValue); + LoadedPtr->eraseFromParent(); + } + + // Likewise for the type test. + IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI); + CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue}); + + for (Instruction *Pred : Preds) { + Pred->replaceAllUsesWith(TypeTestCall); + Pred->eraseFromParent(); + } + + // We have already erased any extractvalue instructions that refer to the + // intrinsic call, but the intrinsic may have other non-extractvalue uses + // (although this is unlikely). In that case, explicitly build a pair and + // RAUW it. + if (!CI->use_empty()) { + Value *Pair = UndefValue::get(CI->getType()); + IRBuilder<> B(CI); + Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); + Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); + CI->replaceAllUsesWith(Pair); + } + + // The number of unsafe uses is initially the number of uses. + auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall]; + NumUnsafeUses = DevirtCalls.size(); + + // If the function pointer has a non-call user, we cannot eliminate the type + // check, as one of those users may eventually call the pointer. Increment + // the unsafe use count to make sure it cannot reach zero. 
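The bookkeeping described above is a per-type-test counter: it starts at the number of virtual calls fed by the loaded pointer, is bumped for any non-call use, and only when later rewrites drive it to zero can the type test be folded to true (which DevirtModule::run does further down). A small standalone sketch of that pattern, with integer ids standing in for the intrinsic calls tracked in NumUnsafeUsesForTypeTest:

// Illustrative sketch of the unsafe-use counter.
#include <cstdio>
#include <map>

int main() {
  std::map<int, unsigned> NumUnsafeUses;
  NumUnsafeUses[/*TypeTestId=*/1] = 2;   // two virtual calls use this pointer

  // Each devirtualized or constant-propagated call site removes one unsafe use.
  --NumUnsafeUses[1];
  --NumUnsafeUses[1];

  for (const auto &U : NumUnsafeUses)
    if (U.second == 0)
      std::printf("type test %d can be replaced with true\n", U.first);
  return 0;
}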
+ if (HasNonCallUses) + ++NumUnsafeUses; + for (DevirtCallSite Call : DevirtCalls) { + CallSlots[{TypeId, Call.Offset}].push_back( + {Ptr, Call.CS, &NumUnsafeUses}); + } + + CI->eraseFromParent(); + } +} + +bool DevirtModule::run() { + Function *TypeTestFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_test)); + Function *TypeCheckedLoadFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); + + if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || + AssumeFunc->use_empty()) && + (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) + return false; + + if (TypeTestFunc && AssumeFunc) + scanTypeTestUsers(TypeTestFunc, AssumeFunc); + + if (TypeCheckedLoadFunc) + scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); + + // Rebuild type metadata into a map for easy lookup. + std::vector<VTableBits> Bits; + DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; + buildTypeIdentifierMap(Bits, TypeIdMap); + if (TypeIdMap.empty()) + return true; + + // For each (type, offset) pair: + bool DidVirtualConstProp = false; + for (auto &S : CallSlots) { + // Search each of the members of the type identifier for the virtual + // function implementation at offset S.first.ByteOffset, and add to + // TargetsForSlot. + std::vector<VirtualCallTarget> TargetsForSlot; + if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], + S.first.ByteOffset)) + continue; + + if (trySingleImplDevirt(TargetsForSlot, S.second)) + continue; + + DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second); + } + + // If we were able to eliminate all unsafe uses for a type checked load, + // eliminate the type test by replacing it with true. + if (TypeCheckedLoadFunc) { + auto True = ConstantInt::getTrue(M.getContext()); + for (auto &&U : NumUnsafeUsesForTypeTest) { + if (U.second == 0) { + U.first->replaceAllUsesWith(True); + U.first->eraseFromParent(); + } + } + } + + // Rebuild each global we touched as part of virtual constant propagation to + // include the before and after bytes. + if (DidVirtualConstProp) + for (VTableBits &B : Bits) + rebuildGlobal(B); + + return true; +} diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 6f49399f57bf..221a22007173 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -58,7 +58,6 @@ namespace { // operators inevitably call FAddendCoef's constructor which is not cheap. 
void operator=(const FAddendCoef &A); void operator+=(const FAddendCoef &A); - void operator-=(const FAddendCoef &A); void operator*=(const FAddendCoef &S); bool isOne() const { return isInt() && IntVal == 1; } @@ -123,11 +122,18 @@ namespace { bool isConstant() const { return Val == nullptr; } bool isZero() const { return Coeff.isZero(); } - void set(short Coefficient, Value *V) { Coeff.set(Coefficient), Val = V; } - void set(const APFloat& Coefficient, Value *V) - { Coeff.set(Coefficient); Val = V; } - void set(const ConstantFP* Coefficient, Value *V) - { Coeff.set(Coefficient->getValueAPF()); Val = V; } + void set(short Coefficient, Value *V) { + Coeff.set(Coefficient); + Val = V; + } + void set(const APFloat &Coefficient, Value *V) { + Coeff.set(Coefficient); + Val = V; + } + void set(const ConstantFP *Coefficient, Value *V) { + Coeff.set(Coefficient->getValueAPF()); + Val = V; + } void negate() { Coeff.negate(); } @@ -272,27 +278,6 @@ void FAddendCoef::operator+=(const FAddendCoef &That) { T.add(createAPFloatFromInt(T.getSemantics(), That.IntVal), RndMode); } -void FAddendCoef::operator-=(const FAddendCoef &That) { - enum APFloat::roundingMode RndMode = APFloat::rmNearestTiesToEven; - if (isInt() == That.isInt()) { - if (isInt()) - IntVal -= That.IntVal; - else - getFpVal().subtract(That.getFpVal(), RndMode); - return; - } - - if (isInt()) { - const APFloat &T = That.getFpVal(); - convertToFpType(T.getSemantics()); - getFpVal().subtract(T, RndMode); - return; - } - - APFloat &T = getFpVal(); - T.subtract(createAPFloatFromInt(T.getSemantics(), IntVal), RndMode); -} - void FAddendCoef::operator*=(const FAddendCoef &That) { if (That.isOne()) return; @@ -321,8 +306,6 @@ void FAddendCoef::operator*=(const FAddendCoef &That) { APFloat::rmNearestTiesToEven); else F0.multiply(That.getFpVal(), APFloat::rmNearestTiesToEven); - - return; } void FAddendCoef::negate() { @@ -716,10 +699,9 @@ Value *FAddCombine::createNaryFAdd bool LastValNeedNeg = false; // Iterate the addends, creating fadd/fsub using adjacent two addends. - for (AddendVect::const_iterator I = Opnds.begin(), E = Opnds.end(); - I != E; I++) { + for (const FAddend *Opnd : Opnds) { bool NeedNeg; - Value *V = createAddendVal(**I, NeedNeg); + Value *V = createAddendVal(*Opnd, NeedNeg); if (!LastVal) { LastVal = V; LastValNeedNeg = NeedNeg; @@ -808,9 +790,7 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { unsigned NegOpndNum = 0; // Adjust the number of instructions needed to emit the N-ary add. 
- for (AddendVect::const_iterator I = Opnds.begin(), E = Opnds.end(); - I != E; I++) { - const FAddend *Opnd = *I; + for (const FAddend *Opnd : Opnds) { if (Opnd->isConstant()) continue; @@ -1052,22 +1032,26 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A*B)+(A*C) -> A*(B+C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + const APInt *Val; + if (match(RHS, m_APInt(Val))) { // X + (signbit) --> X ^ signbit - const APInt &Val = CI->getValue(); - if (Val.isSignBit()) + if (Val->isSignBit()) return BinaryOperator::CreateXor(LHS, RHS); + } + // FIXME: Use the match above instead of dyn_cast to allow these transforms + // for splat vectors. + if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { // See if SimplifyDemandedBits can simplify this. This handles stuff like // (X & 254)+1 -> (X&254)|1 if (SimplifyDemandedInstructionBits(I)) @@ -1157,7 +1141,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateSub(LHS, V); if (Value *V = checkForNegativeOperand(I, Builder)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, AC, &I, DT)) @@ -1169,6 +1153,9 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateSub(SubOne(CRHS), X); } + // FIXME: We already did a check for ConstantInt RHS above this. + // FIXME: Is this pattern covered by another fold? No regression tests fail on + // removal. if (ConstantInt *CRHS = dyn_cast<ConstantInt>(RHS)) { // (X & FF00) + xx00 -> (X+xx00) & FF00 Value *X; @@ -1317,11 +1304,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (isa<Constant>(RHS)) { if (isa<PHINode>(LHS)) @@ -1415,7 +1402,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (I.hasUnsafeAlgebra()) { if (Value *V = FAddCombine(Builder).simplify(&I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); } return Changed ? &I : nullptr; @@ -1493,15 +1480,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A*B)-(A*C) -> A*(B-C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // If this is a 'B = x-(-A)', change to B = x+A. 
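The X + (signbit) --> X ^ signbit fold near the top of visitAdd above holds because the only set bit in the constant is the top bit, so the addition cannot carry into any retained bit position. A brute-force check over all 8-bit values:

// Illustrative sketch: verify X + 0x80 == X ^ 0x80 for every 8-bit value,
// i.e. adding the sign bit is the same as flipping it (mod 2^8).
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned X = 0; X < 256; ++X) {
    uint8_t Add = uint8_t(X + 0x80u);
    uint8_t Xor = uint8_t(X ^ 0x80u);
    if (Add != Xor) {
      std::printf("mismatch at %u\n", X);
      return 1;
    }
  }
  std::printf("X + signbit == X ^ signbit for all uint8_t\n");
  return 0;
}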
if (Value *V = dyn_castNegVal(Op1)) { @@ -1667,13 +1654,13 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && match(Op1, m_PtrToInt(m_Value(RHSOp)))) if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) - return ReplaceInstUsesWith(I, Res); + return replaceInstUsesWith(I, Res); // trunc(p)-trunc(q) -> trunc(p-q) if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) - return ReplaceInstUsesWith(I, Res); + return replaceInstUsesWith(I, Res); bool Changed = false; if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, I)) { @@ -1692,11 +1679,11 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // fsub nsz 0, X ==> fsub nsz -0.0, X if (I.getFastMathFlags().noSignedZeros() && match(Op0, m_Zero())) { @@ -1736,7 +1723,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (I.hasUnsafeAlgebra()) { if (Value *V = FAddCombine(Builder).simplify(&I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); } return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 76cefd97cd8f..1a6459b3d689 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -39,30 +39,29 @@ static inline Value *dyn_castNotVal(Value *V) { } /// Similar to getICmpCode but for FCmpInst. This encodes a fcmp predicate into -/// a three bit mask. It also returns whether it is an ordered predicate by -/// reference. -static unsigned getFCmpCode(FCmpInst::Predicate CC, bool &isOrdered) { - isOrdered = false; - switch (CC) { - case FCmpInst::FCMP_ORD: isOrdered = true; return 0; // 000 - case FCmpInst::FCMP_UNO: return 0; // 000 - case FCmpInst::FCMP_OGT: isOrdered = true; return 1; // 001 - case FCmpInst::FCMP_UGT: return 1; // 001 - case FCmpInst::FCMP_OEQ: isOrdered = true; return 2; // 010 - case FCmpInst::FCMP_UEQ: return 2; // 010 - case FCmpInst::FCMP_OGE: isOrdered = true; return 3; // 011 - case FCmpInst::FCMP_UGE: return 3; // 011 - case FCmpInst::FCMP_OLT: isOrdered = true; return 4; // 100 - case FCmpInst::FCMP_ULT: return 4; // 100 - case FCmpInst::FCMP_ONE: isOrdered = true; return 5; // 101 - case FCmpInst::FCMP_UNE: return 5; // 101 - case FCmpInst::FCMP_OLE: isOrdered = true; return 6; // 110 - case FCmpInst::FCMP_ULE: return 6; // 110 - // True -> 7 - default: - // Not expecting FCMP_FALSE and FCMP_TRUE; - llvm_unreachable("Unexpected FCmp predicate!"); - } +/// a four bit mask. +static unsigned getFCmpCode(FCmpInst::Predicate CC) { + assert(FCmpInst::FCMP_FALSE <= CC && CC <= FCmpInst::FCMP_TRUE && + "Unexpected FCmp predicate!"); + // Take advantage of the bit pattern of FCmpInst::Predicate here. 
+ // U L G E + static_assert(FCmpInst::FCMP_FALSE == 0, ""); // 0 0 0 0 + static_assert(FCmpInst::FCMP_OEQ == 1, ""); // 0 0 0 1 + static_assert(FCmpInst::FCMP_OGT == 2, ""); // 0 0 1 0 + static_assert(FCmpInst::FCMP_OGE == 3, ""); // 0 0 1 1 + static_assert(FCmpInst::FCMP_OLT == 4, ""); // 0 1 0 0 + static_assert(FCmpInst::FCMP_OLE == 5, ""); // 0 1 0 1 + static_assert(FCmpInst::FCMP_ONE == 6, ""); // 0 1 1 0 + static_assert(FCmpInst::FCMP_ORD == 7, ""); // 0 1 1 1 + static_assert(FCmpInst::FCMP_UNO == 8, ""); // 1 0 0 0 + static_assert(FCmpInst::FCMP_UEQ == 9, ""); // 1 0 0 1 + static_assert(FCmpInst::FCMP_UGT == 10, ""); // 1 0 1 0 + static_assert(FCmpInst::FCMP_UGE == 11, ""); // 1 0 1 1 + static_assert(FCmpInst::FCMP_ULT == 12, ""); // 1 1 0 0 + static_assert(FCmpInst::FCMP_ULE == 13, ""); // 1 1 0 1 + static_assert(FCmpInst::FCMP_UNE == 14, ""); // 1 1 1 0 + static_assert(FCmpInst::FCMP_TRUE == 15, ""); // 1 1 1 1 + return CC; } /// This is the complement of getICmpCode, which turns an opcode and two @@ -78,26 +77,16 @@ static Value *getNewICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, } /// This is the complement of getFCmpCode, which turns an opcode and two -/// operands into either a FCmp instruction. isordered is passed in to determine -/// which kind of predicate to use in the new fcmp instruction. -static Value *getFCmpValue(bool isordered, unsigned code, - Value *LHS, Value *RHS, +/// operands into either a FCmp instruction, or a true/false constant. +static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, InstCombiner::BuilderTy *Builder) { - CmpInst::Predicate Pred; - switch (code) { - default: llvm_unreachable("Illegal FCmp code!"); - case 0: Pred = isordered ? FCmpInst::FCMP_ORD : FCmpInst::FCMP_UNO; break; - case 1: Pred = isordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; break; - case 2: Pred = isordered ? FCmpInst::FCMP_OEQ : FCmpInst::FCMP_UEQ; break; - case 3: Pred = isordered ? FCmpInst::FCMP_OGE : FCmpInst::FCMP_UGE; break; - case 4: Pred = isordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; break; - case 5: Pred = isordered ? FCmpInst::FCMP_ONE : FCmpInst::FCMP_UNE; break; - case 6: Pred = isordered ? FCmpInst::FCMP_OLE : FCmpInst::FCMP_ULE; break; - case 7: - if (!isordered) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); - Pred = FCmpInst::FCMP_ORD; break; - } + const auto Pred = static_cast<FCmpInst::Predicate>(Code); + assert(FCmpInst::FCMP_FALSE <= Pred && Pred <= FCmpInst::FCMP_TRUE && + "Unexpected FCmp predicate!"); + if (Pred == FCmpInst::FCMP_FALSE) + return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); + if (Pred == FCmpInst::FCMP_TRUE) + return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); return Builder->CreateFCmp(Pred, LHS, RHS); } @@ -243,7 +232,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, if (CI->getValue() == ShlMask) // Masking out bits that the shift already masks. - return ReplaceInstUsesWith(TheAnd, Op); // No need for the and. + return replaceInstUsesWith(TheAnd, Op); // No need for the and. if (CI != AndRHS) { // Reducing bits set in and. TheAnd.setOperand(1, CI); @@ -263,7 +252,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, if (CI->getValue() == ShrMask) // Masking out bits that the shift already masks. - return ReplaceInstUsesWith(TheAnd, Op); + return replaceInstUsesWith(TheAnd, Op); if (CI != AndRHS) { TheAnd.setOperand(1, CI); // Reduce bits set in and cst. 
@@ -465,11 +454,9 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, if (CCst && CCst->isZero()) { // if C is zero, then both A and B qualify as mask result |= (icmp_eq ? (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_Mask_AllZeroes | FoldMskICmp_AMask_Mixed | FoldMskICmp_BMask_Mixed) : (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_AMask_NotMixed | FoldMskICmp_BMask_NotMixed)); if (icmp_abit) @@ -666,7 +653,7 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, if (!ICmpInst::isEquality(RHSCC)) return 0; - // Look for ANDs in on the right side of the RHS icmp. + // Look for ANDs on the right side of the RHS icmp. if (!ok && R2->getType()->isIntegerTy()) { if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { R11 = R2; @@ -694,9 +681,9 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, B = L21; C = L1; } - unsigned left_type = getTypeOfMaskedICmp(A, B, C, LHSCC); - unsigned right_type = getTypeOfMaskedICmp(A, D, E, RHSCC); - return left_type & right_type; + unsigned LeftType = getTypeOfMaskedICmp(A, B, C, LHSCC); + unsigned RightType = getTypeOfMaskedICmp(A, D, E, RHSCC); + return LeftType & RightType; } /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) @@ -705,9 +692,9 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, llvm::InstCombiner::BuilderTy *Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); - unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, + unsigned Mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, LHSCC, RHSCC); - if (mask == 0) return nullptr; + if (Mask == 0) return nullptr; assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) && "foldLogOpOfMaskedICmpsHelper must return an equality predicate."); @@ -723,48 +710,48 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // input and output). // In most cases we're going to produce an EQ for the "&&" case. - ICmpInst::Predicate NEWCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; + ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; if (!IsAnd) { // Convert the masking analysis into its equivalent with negated // comparisons. - mask = conjugateICmpMask(mask); + Mask = conjugateICmpMask(Mask); } - if (mask & FoldMskICmp_Mask_AllZeroes) { + if (Mask & FoldMskICmp_Mask_AllZeroes) { // (icmp eq (A & B), 0) & (icmp eq (A & D), 0) // -> (icmp eq (A & (B|D)), 0) - Value *newOr = Builder->CreateOr(B, D); - Value *newAnd = Builder->CreateAnd(A, newOr); - // we can't use C as zero, because we might actually handle + Value *NewOr = Builder->CreateOr(B, D); + Value *NewAnd = Builder->CreateAnd(A, NewOr); + // We can't use C as zero because we might actually handle // (icmp ne (A & B), B) & (icmp ne (A & D), D) - // with B and D, having a single bit set - Value *zero = Constant::getNullValue(A->getType()); - return Builder->CreateICmp(NEWCC, newAnd, zero); + // with B and D, having a single bit set. 
+ Value *Zero = Constant::getNullValue(A->getType()); + return Builder->CreateICmp(NewCC, NewAnd, Zero); } - if (mask & FoldMskICmp_BMask_AllOnes) { + if (Mask & FoldMskICmp_BMask_AllOnes) { // (icmp eq (A & B), B) & (icmp eq (A & D), D) // -> (icmp eq (A & (B|D)), (B|D)) - Value *newOr = Builder->CreateOr(B, D); - Value *newAnd = Builder->CreateAnd(A, newOr); - return Builder->CreateICmp(NEWCC, newAnd, newOr); + Value *NewOr = Builder->CreateOr(B, D); + Value *NewAnd = Builder->CreateAnd(A, NewOr); + return Builder->CreateICmp(NewCC, NewAnd, NewOr); } - if (mask & FoldMskICmp_AMask_AllOnes) { + if (Mask & FoldMskICmp_AMask_AllOnes) { // (icmp eq (A & B), A) & (icmp eq (A & D), A) // -> (icmp eq (A & (B&D)), A) - Value *newAnd1 = Builder->CreateAnd(B, D); - Value *newAnd = Builder->CreateAnd(A, newAnd1); - return Builder->CreateICmp(NEWCC, newAnd, A); + Value *NewAnd1 = Builder->CreateAnd(B, D); + Value *NewAnd2 = Builder->CreateAnd(A, NewAnd1); + return Builder->CreateICmp(NewCC, NewAnd2, A); } // Remaining cases assume at least that B and D are constant, and depend on - // their actual values. This isn't strictly, necessary, just a "handle the + // their actual values. This isn't strictly necessary, just a "handle the // easy cases for now" decision. ConstantInt *BCst = dyn_cast<ConstantInt>(B); if (!BCst) return nullptr; ConstantInt *DCst = dyn_cast<ConstantInt>(D); if (!DCst) return nullptr; - if (mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) { + if (Mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) { // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) @@ -777,7 +764,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (mask & FoldMskICmp_AMask_NotAllOnes) { + if (Mask & FoldMskICmp_AMask_NotAllOnes) { // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) // Only valid if one of the masks is a superset of the other (check "B|D" is @@ -789,7 +776,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (mask & FoldMskICmp_BMask_Mixed) { + if (Mask & FoldMskICmp_BMask_Mixed) { // (icmp eq (A & B), C) & (icmp eq (A & D), E) // We already know that B & C == C && D & E == E. // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of @@ -797,26 +784,26 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // contradict, then we can transform to // -> (icmp eq (A & (B|D)), (C|E)) // Currently, we only handle the case of B, C, D, and E being constant. - // we can't simply use C and E, because we might actually handle + // We can't simply use C and E because we might actually handle // (icmp ne (A & B), B) & (icmp eq (A & D), D) - // with B and D, having a single bit set + // with B and D, having a single bit set. 
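The Mask_AllZeroes and BMask_AllOnes rewrites above are ordinary bit identities: (A & B) == 0 && (A & D) == 0 holds exactly when (A & (B | D)) == 0, and (A & B) == B && (A & D) == D holds exactly when (A & (B | D)) == (B | D). An exhaustive check over 4-bit values:

// Illustrative sketch: verify the two masked-icmp identities used by
// foldLogOpOfMaskedICmps for all 4-bit A, B, D.
#include <cstdio>

int main() {
  for (unsigned A = 0; A < 16; ++A)
    for (unsigned B = 0; B < 16; ++B)
      for (unsigned D = 0; D < 16; ++D) {
        bool AllZeroLHS = (A & B) == 0 && (A & D) == 0;
        bool AllZeroRHS = (A & (B | D)) == 0;
        bool AllOnesLHS = (A & B) == B && (A & D) == D;
        bool AllOnesRHS = (A & (B | D)) == (B | D);
        if (AllZeroLHS != AllZeroRHS || AllOnesLHS != AllOnesRHS) {
          std::printf("counterexample: A=%u B=%u D=%u\n", A, B, D);
          return 1;
        }
      }
  std::printf("both identities hold for all 4-bit values\n");
  return 0;
}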
ConstantInt *CCst = dyn_cast<ConstantInt>(C); if (!CCst) return nullptr; ConstantInt *ECst = dyn_cast<ConstantInt>(E); if (!ECst) return nullptr; - if (LHSCC != NEWCC) + if (LHSCC != NewCC) CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); - if (RHSCC != NEWCC) + if (RHSCC != NewCC) ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); - // if there is a conflict we should actually return a false for the - // whole construct + // If there is a conflict, we should actually return a false for the + // whole construct. if (((BCst->getValue() & DCst->getValue()) & (CCst->getValue() ^ ECst->getValue())) != 0) return ConstantInt::get(LHS->getType(), !IsAnd); - Value *newOr1 = Builder->CreateOr(B, D); - Value *newOr2 = ConstantExpr::getOr(CCst, ECst); - Value *newAnd = Builder->CreateAnd(A, newOr1); - return Builder->CreateICmp(NEWCC, newAnd, newOr2); + Value *NewOr1 = Builder->CreateOr(B, D); + Value *NewOr2 = ConstantExpr::getOr(CCst, ECst); + Value *NewAnd = Builder->CreateAnd(A, NewOr1); + return Builder->CreateICmp(NewCC, NewAnd, NewOr2); } return nullptr; } @@ -915,15 +902,10 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (LHSCst == RHSCst && LHSCC == RHSCC) { // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C) - // where C is a power of 2 - if (LHSCC == ICmpInst::ICMP_ULT && - LHSCst->getValue().isPowerOf2()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); - } - + // where C is a power of 2 or // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - if (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero()) { + if ((LHSCC == ICmpInst::ICMP_ULT && LHSCst->getValue().isPowerOf2()) || + (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero())) { Value *NewOr = Builder->CreateOr(Val, Val2); return Builder->CreateICmp(LHSCC, NewOr, LHSCst); } @@ -975,16 +957,6 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) return nullptr; - // Make a constant range that's the intersection of the two icmp ranges. - // If the intersection is empty, we know that the result is false. - ConstantRange LHSRange = - ConstantRange::makeAllowedICmpRegion(LHSCC, LHSCst->getValue()); - ConstantRange RHSRange = - ConstantRange::makeAllowedICmpRegion(RHSCC, RHSCst->getValue()); - - if (LHSRange.intersectWith(RHSRange).isEmptySet()) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - // We can't fold (ugt x, C) & (sgt x, C2). if (!PredicatesFoldable(LHSCC, RHSCC)) return nullptr; @@ -1124,6 +1096,29 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { /// Optimize (fcmp)&(fcmp). NOTE: Unlike the rest of instcombine, this returns /// a Value which should already be inserted into the function. Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { + Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); + Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); + FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); + + if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { + // Swap RHS operands to match LHS. + Op1CC = FCmpInst::getSwappedPredicate(Op1CC); + std::swap(Op1LHS, Op1RHS); + } + + // Simplify (fcmp cc0 x, y) & (fcmp cc1 x, y). + // Suppose the relation between x and y is R, where R is one of + // U(1000), L(0100), G(0010) or E(0001), and CC0 and CC1 are the bitmasks for + // testing the desired relations. 
+ // + // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this: + // bool(R & CC0) && bool(R & CC1) + // = bool((R & CC0) & (R & CC1)) + // = bool(R & (CC0 & CC1)) <= by re-association, commutation, and idempotency + if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) + return getFCmpValue(getFCmpCode(Op0CC) & getFCmpCode(Op1CC), Op0LHS, Op0RHS, + Builder); + if (LHS->getPredicate() == FCmpInst::FCMP_ORD && RHS->getPredicate() == FCmpInst::FCMP_ORD) { if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) @@ -1147,56 +1142,6 @@ Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { return nullptr; } - Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); - Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); - FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); - - - if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { - // Swap RHS operands to match LHS. - Op1CC = FCmpInst::getSwappedPredicate(Op1CC); - std::swap(Op1LHS, Op1RHS); - } - - if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) { - // Simplify (fcmp cc0 x, y) & (fcmp cc1 x, y). - if (Op0CC == Op1CC) - return Builder->CreateFCmp((FCmpInst::Predicate)Op0CC, Op0LHS, Op0RHS); - if (Op0CC == FCmpInst::FCMP_FALSE || Op1CC == FCmpInst::FCMP_FALSE) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - if (Op0CC == FCmpInst::FCMP_TRUE) - return RHS; - if (Op1CC == FCmpInst::FCMP_TRUE) - return LHS; - - bool Op0Ordered; - bool Op1Ordered; - unsigned Op0Pred = getFCmpCode(Op0CC, Op0Ordered); - unsigned Op1Pred = getFCmpCode(Op1CC, Op1Ordered); - // uno && ord -> false - if (Op0Pred == 0 && Op1Pred == 0 && Op0Ordered != Op1Ordered) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - if (Op1Pred == 0) { - std::swap(LHS, RHS); - std::swap(Op0Pred, Op1Pred); - std::swap(Op0Ordered, Op1Ordered); - } - if (Op0Pred == 0) { - // uno && ueq -> uno && (uno || eq) -> uno - // ord && olt -> ord && (ord && lt) -> olt - if (!Op0Ordered && (Op0Ordered == Op1Ordered)) - return LHS; - if (Op0Ordered && (Op0Ordered == Op1Ordered)) - return RHS; - - // uno && oeq -> uno && (ord && eq) -> false - if (!Op0Ordered) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - // ord && ueq -> ord && (uno || eq) -> oeq - return getFCmpValue(true, Op1Pred, Op0LHS, Op0RHS, Builder); - } - } - return nullptr; } @@ -1248,19 +1193,131 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, return nullptr; } +Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { + auto LogicOpc = I.getOpcode(); + assert((LogicOpc == Instruction::And || LogicOpc == Instruction::Or || + LogicOpc == Instruction::Xor) && + "Unexpected opcode for bitwise logic folding"); + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + CastInst *Cast0 = dyn_cast<CastInst>(Op0); + if (!Cast0) + return nullptr; + + // This must be a cast from an integer or integer vector source type to allow + // transformation of the logic operation to the source type. + Type *DestTy = I.getType(); + Type *SrcTy = Cast0->getSrcTy(); + if (!SrcTy->isIntOrIntVectorTy()) + return nullptr; + + // If one operand is a bitcast and the other is a constant, move the logic + // operation ahead of the bitcast. That is, do the logic operation in the + // original type. This can eliminate useless bitcasts and allow normal + // combines that would otherwise be impeded by the bitcast. Canonicalization + // ensures that if there is a constant operand, it will be the second operand. 
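The re-association argument in the FoldAndOfFCmps comment above can be checked mechanically: with predicates encoded as 4-bit U/L/G/E masks (per getFCmpCode earlier in this hunk) and the actual relation R always one of the four one-hot values, bool(R & CC0) && bool(R & CC1) equals bool(R & (CC0 & CC1)). A standalone exhaustive check:

// Illustrative sketch: the AND of two fcmp predicates over the same operands
// is the predicate whose 4-bit mask is the AND of their masks.
#include <cstdio>

int main() {
  const unsigned Relations[] = {8, 4, 2, 1}; // U, L, G, E (one-hot)
  for (unsigned CC0 = 0; CC0 < 16; ++CC0)
    for (unsigned CC1 = 0; CC1 < 16; ++CC1)
      for (unsigned R : Relations) {
        bool Lhs = (R & CC0) != 0 && (R & CC1) != 0;
        bool Rhs = (R & (CC0 & CC1)) != 0;
        if (Lhs != Rhs) {
          std::printf("counterexample: CC0=%u CC1=%u R=%u\n", CC0, CC1, R);
          return 1;
        }
      }
  std::printf("bool(R & CC0) && bool(R & CC1) == bool(R & (CC0 & CC1))\n");
  return 0;
}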
+ Value *BC = nullptr; + Constant *C = nullptr; + if ((match(Op0, m_BitCast(m_Value(BC))) && match(Op1, m_Constant(C)))) { + Value *NewConstant = ConstantExpr::getBitCast(C, SrcTy); + Value *NewOp = Builder->CreateBinOp(LogicOpc, BC, NewConstant, I.getName()); + return CastInst::CreateBitOrPointerCast(NewOp, DestTy); + } + + CastInst *Cast1 = dyn_cast<CastInst>(Op1); + if (!Cast1) + return nullptr; + + // Both operands of the logic operation are casts. The casts must be of the + // same type for reduction. + auto CastOpcode = Cast0->getOpcode(); + if (CastOpcode != Cast1->getOpcode() || SrcTy != Cast1->getSrcTy()) + return nullptr; + + Value *Cast0Src = Cast0->getOperand(0); + Value *Cast1Src = Cast1->getOperand(0); + + // fold (logic (cast A), (cast B)) -> (cast (logic A, B)) + + // Only do this if the casts both really cause code to be generated. + if ((!isa<ICmpInst>(Cast0Src) || !isa<ICmpInst>(Cast1Src)) && + ShouldOptimizeCast(CastOpcode, Cast0Src, DestTy) && + ShouldOptimizeCast(CastOpcode, Cast1Src, DestTy)) { + Value *NewOp = Builder->CreateBinOp(LogicOpc, Cast0Src, Cast1Src, + I.getName()); + return CastInst::Create(CastOpcode, NewOp, DestTy); + } + + // For now, only 'and'/'or' have optimizations after this. + if (LogicOpc == Instruction::Xor) + return nullptr; + + // If this is logic(cast(icmp), cast(icmp)), try to fold this even if the + // cast is otherwise not optimizable. This happens for vector sexts. + ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src); + ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src); + if (ICmp0 && ICmp1) { + Value *Res = LogicOpc == Instruction::And ? FoldAndOfICmps(ICmp0, ICmp1) + : FoldOrOfICmps(ICmp0, ICmp1, &I); + if (Res) + return CastInst::Create(CastOpcode, Res, DestTy); + return nullptr; + } + + // If this is logic(cast(fcmp), cast(fcmp)), try to fold this even if the + // cast is otherwise not optimizable. This happens for vector sexts. + FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src); + FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src); + if (FCmp0 && FCmp1) { + Value *Res = LogicOpc == Instruction::And ? 
FoldAndOfFCmps(FCmp0, FCmp1) + : FoldOrOfFCmps(FCmp0, FCmp1); + if (Res) + return CastInst::Create(CastOpcode, Res, DestTy); + return nullptr; + } + + return nullptr; +} + +static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Canonicalize SExt or Not to the LHS + if (match(Op1, m_SExt(m_Value())) || match(Op1, m_Not(m_Value()))) { + std::swap(Op0, Op1); + } + + // Fold (and (sext bool to A), B) --> (select bool, B, 0) + Value *X = nullptr; + if (match(Op0, m_SExt(m_Value(X))) && + X->getType()->getScalarType()->isIntegerTy(1)) { + Value *Zero = Constant::getNullValue(Op1->getType()); + return SelectInst::Create(X, Op1, Zero); + } + + // Fold (and ~(sext bool to A), B) --> (select bool, 0, B) + if (match(Op0, m_Not(m_SExt(m_Value(X)))) && + X->getType()->getScalarType()->isIntegerTy(1)) { + Value *Zero = Constant::getNullValue(Op0->getType()); + return SelectInst::Create(X, Zero, Op1); + } + + return nullptr; +} + Instruction *InstCombiner::visitAnd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A|B)&(A|C) -> A|(B&C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -1268,7 +1325,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return &I; if (Value *V = SimplifyBSwap(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { const APInt &AndRHSMask = AndRHS->getValue(); @@ -1399,8 +1456,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { { Value *tmpOp0 = Op0; Value *tmpOp1 = Op1; - if (Op0->hasOneUse() && - match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_Value(B))))) { if (A == Op1 || B == Op1 ) { tmpOp1 = Op0; tmpOp0 = Op1; @@ -1408,12 +1464,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { } } - if (tmpOp1->hasOneUse() && - match(tmpOp1, m_Xor(m_Value(A), m_Value(B)))) { + if (match(tmpOp1, m_OneUse(m_Xor(m_Value(A), m_Value(B))))) { if (B == tmpOp0) { std::swap(A, B); } - // Notice that the patten (A&(~B)) is actually (A&(-1^B)), so if + // Notice that the pattern (A&(~B)) is actually (A&(-1^B)), so if // A is originally -1 (or a vector of -1 and undefs), then we enter // an endless loop. By checking that A is non-constant we ensure that // we will never get to the loop. @@ -1458,7 +1513,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); if (LHS && RHS) if (Value *Res = FoldAndOfICmps(LHS, RHS)) - return ReplaceInstUsesWith(I, Res); + return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary // number of 'and' instructions might have to be created. 
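The foldBoolSextMaskToSelect transform above rests on a simple identity: sign-extending an i1 yields either all-ones or all-zeros, so the mask picks either B or 0, which is exactly a select. A small self-contained check (illustrative names, not LLVM code):

#include <cassert>
#include <cstdint>

static int32_t sextBool(bool Bit) { return Bit ? -1 : 0; } // sext i1 -> i32

int main() {
  for (bool Bit : {false, true}) {
    int32_t B = 0x1234abcd;
    int32_t Masked = sextBool(Bit) & B;  // and (sext bool), B
    int32_t Selected = Bit ? B : 0;      // select bool, B, 0
    assert(Masked == Selected);
    // Negated form: and ~(sext bool), B --> select bool, 0, B
    assert((~sextBool(Bit) & B) == (Bit ? 0 : B));
  }
  return 0;
}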
@@ -1466,18 +1521,18 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = FoldAndOfICmps(LHS, Cmp)) - return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); + return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) if (Value *Res = FoldAndOfICmps(LHS, Cmp)) - return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, X)); + return replaceInstUsesWith(I, Builder->CreateAnd(Res, X)); } if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = FoldAndOfICmps(Cmp, RHS)) - return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); + return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) if (Value *Res = FoldAndOfICmps(Cmp, RHS)) - return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, X)); + return replaceInstUsesWith(I, Builder->CreateAnd(Res, X)); } } @@ -1485,92 +1540,46 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) if (Value *Res = FoldAndOfFCmps(LHS, RHS)) - return ReplaceInstUsesWith(I, Res); - - - if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { - Value *Op0COp = Op0C->getOperand(0); - Type *SrcTy = Op0COp->getType(); - // fold (and (cast A), (cast B)) -> (cast (and A, B)) - if (CastInst *Op1C = dyn_cast<CastInst>(Op1)) { - if (Op0C->getOpcode() == Op1C->getOpcode() && // same cast kind ? - SrcTy == Op1C->getOperand(0)->getType() && - SrcTy->isIntOrIntVectorTy()) { - Value *Op1COp = Op1C->getOperand(0); - - // Only do this if the casts both really cause code to be generated. - if (ShouldOptimizeCast(Op0C->getOpcode(), Op0COp, I.getType()) && - ShouldOptimizeCast(Op1C->getOpcode(), Op1COp, I.getType())) { - Value *NewOp = Builder->CreateAnd(Op0COp, Op1COp, I.getName()); - return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); - } + return replaceInstUsesWith(I, Res); - // If this is and(cast(icmp), cast(icmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. - if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1COp)) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0COp)) - if (Value *Res = FoldAndOfICmps(LHS, RHS)) - return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); - - // If this is and(cast(fcmp), cast(fcmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. - if (FCmpInst *RHS = dyn_cast<FCmpInst>(Op1COp)) - if (FCmpInst *LHS = dyn_cast<FCmpInst>(Op0COp)) - if (Value *Res = FoldAndOfFCmps(LHS, RHS)) - return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); - } - } + if (Instruction *CastedAnd = foldCastedBitwiseLogic(I)) + return CastedAnd; - // If we are masking off the sign bit of a floating-point value, convert - // this to the canonical fabs intrinsic call and cast back to integer. - // The backend should know how to optimize fabs(). - // TODO: This transform should also apply to vectors. 
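The block being removed here described folding a sign-bit mask on a float's bit pattern into fabs. A standalone check of that identity, assuming IEEE-754 single precision (helper names are illustrative):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

static uint32_t bitsOf(float F) {
  uint32_t U;
  std::memcpy(&U, &F, sizeof(U));
  return U;
}

int main() {
  for (float F : {1.5f, -1.5f, 0.0f, -0.0f, -123456.75f}) {
    uint32_t Masked = bitsOf(F) & 0x7fffffffu; // and (bitcast float), INT32_MAX
    assert(Masked == bitsOf(std::fabs(F)));    // == bitcast (fabs float)
  }
  return 0;
}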
- ConstantInt *CI; - if (isa<BitCastInst>(Op0C) && SrcTy->isFloatingPointTy() && - match(Op1, m_ConstantInt(CI)) && CI->isMaxValue(true)) { - Module *M = I.getModule(); - Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, SrcTy); - Value *Call = Builder->CreateCall(Fabs, Op0COp, "fabs"); - return CastInst::CreateBitOrPointerCast(Call, I.getType()); - } - } + if (Instruction *Select = foldBoolSextMaskToSelect(I)) + return Select; - { - Value *X = nullptr; - bool OpsSwapped = false; - // Canonicalize SExt or Not to the LHS - if (match(Op1, m_SExt(m_Value())) || - match(Op1, m_Not(m_Value()))) { - std::swap(Op0, Op1); - OpsSwapped = true; - } + return Changed ? &I : nullptr; +} - // Fold (and (sext bool to A), B) --> (select bool, B, 0) - if (match(Op0, m_SExt(m_Value(X))) && - X->getType()->getScalarType()->isIntegerTy(1)) { - Value *Zero = Constant::getNullValue(Op1->getType()); - return SelectInst::Create(X, Op1, Zero); - } +/// Given an OR instruction, check to see if this is a bswap idiom. If so, +/// insert the new intrinsic and return it. +Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - // Fold (and ~(sext bool to A), B) --> (select bool, 0, B) - if (match(Op0, m_Not(m_SExt(m_Value(X)))) && - X->getType()->getScalarType()->isIntegerTy(1)) { - Value *Zero = Constant::getNullValue(Op0->getType()); - return SelectInst::Create(X, Zero, Op1); - } + // Look through zero extends. + if (Instruction *Ext = dyn_cast<ZExtInst>(Op0)) + Op0 = Ext->getOperand(0); - if (OpsSwapped) - std::swap(Op0, Op1); - } + if (Instruction *Ext = dyn_cast<ZExtInst>(Op1)) + Op1 = Ext->getOperand(0); - return Changed ? &I : nullptr; -} + // (A | B) | C and A | (B | C) -> bswap if possible. + bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) || + match(Op1, m_Or(m_Value(), m_Value())); + + // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible. + bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) && + match(Op1, m_LogicalShift(m_Value(), m_Value())); + + // (A & B) | (C & D) -> bswap if possible. + bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && + match(Op1, m_And(m_Value(), m_Value())); + + if (!OrOfOrs && !OrOfShifts && !OrOfAnds) + return nullptr; -/// Given an OR instruction, check to see if this is a bswap or bitreverse -/// idiom. If so, insert the new intrinsic and return it. -Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) { SmallVector<Instruction*, 4> Insts; - if (!recognizeBitReverseOrBSwapIdiom(&I, true, false, Insts)) + if (!recognizeBSwapOrBitReverseIdiom(&I, true, false, Insts)) return nullptr; Instruction *LastInst = Insts.pop_back_val(); LastInst->removeFromParent(); @@ -1580,28 +1589,89 @@ Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) { return LastInst; } -/// We have an expression of the form (A&C)|(B&D). Check if A is (cond?-1:0) -/// and either B or D is ~(cond?-1,0) or (cond?0,-1), then we can simplify this -/// expression to "cond ? C : D or B". -static Instruction *MatchSelectFromAndOr(Value *A, Value *B, - Value *C, Value *D) { - // If A is not a select of -1/0, this cannot match. - Value *Cond = nullptr; - if (!match(A, m_SExt(m_Value(Cond))) || - !Cond->getType()->isIntegerTy(1)) +/// If all elements of two constant vectors are 0/-1 and inverses, return true. 
+static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { + unsigned NumElts = C1->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *EltC1 = C1->getAggregateElement(i); + Constant *EltC2 = C2->getAggregateElement(i); + if (!EltC1 || !EltC2) + return false; + + // One element must be all ones, and the other must be all zeros. + // FIXME: Allow undef elements. + if (!((match(EltC1, m_Zero()) && match(EltC2, m_AllOnes())) || + (match(EltC2, m_Zero()) && match(EltC1, m_AllOnes())))) + return false; + } + return true; +} + +/// We have an expression of the form (A & C) | (B & D). If A is a scalar or +/// vector composed of all-zeros or all-ones values and is the bitwise 'not' of +/// B, it can be used as the condition operand of a select instruction. +static Value *getSelectCondition(Value *A, Value *B, + InstCombiner::BuilderTy &Builder) { + // If these are scalars or vectors of i1, A can be used directly. + Type *Ty = A->getType(); + if (match(A, m_Not(m_Specific(B))) && Ty->getScalarType()->isIntegerTy(1)) + return A; + + // If A and B are sign-extended, look through the sexts to find the booleans. + Value *Cond; + if (match(A, m_SExt(m_Value(Cond))) && + Cond->getType()->getScalarType()->isIntegerTy(1) && + match(B, m_CombineOr(m_Not(m_SExt(m_Specific(Cond))), + m_SExt(m_Not(m_Specific(Cond)))))) + return Cond; + + // All scalar (and most vector) possibilities should be handled now. + // Try more matches that only apply to non-splat constant vectors. + if (!Ty->isVectorTy()) return nullptr; - // ((cond?-1:0)&C) | (B&(cond?0:-1)) -> cond ? C : B. - if (match(D, m_Not(m_SExt(m_Specific(Cond))))) - return SelectInst::Create(Cond, C, B); - if (match(D, m_SExt(m_Not(m_Specific(Cond))))) - return SelectInst::Create(Cond, C, B); - - // ((cond?-1:0)&C) | ((cond?0:-1)&D) -> cond ? C : D. - if (match(B, m_Not(m_SExt(m_Specific(Cond))))) - return SelectInst::Create(Cond, C, D); - if (match(B, m_SExt(m_Not(m_Specific(Cond))))) - return SelectInst::Create(Cond, C, D); + // If both operands are constants, see if the constants are inverse bitmasks. + Constant *AC, *BC; + if (match(A, m_Constant(AC)) && match(B, m_Constant(BC)) && + areInverseVectorBitmasks(AC, BC)) + return ConstantExpr::getTrunc(AC, CmpInst::makeCmpResultType(Ty)); + + // If both operands are xor'd with constants using the same sexted boolean + // operand, see if the constants are inverse bitmasks. + if (match(A, (m_Xor(m_SExt(m_Value(Cond)), m_Constant(AC)))) && + match(B, (m_Xor(m_SExt(m_Specific(Cond)), m_Constant(BC)))) && + Cond->getType()->getScalarType()->isIntegerTy(1) && + areInverseVectorBitmasks(AC, BC)) { + AC = ConstantExpr::getTrunc(AC, CmpInst::makeCmpResultType(Ty)); + return Builder.CreateXor(Cond, AC); + } + return nullptr; +} + +/// We have an expression of the form (A & C) | (B & D). Try to simplify this +/// to "A' ? C : D", where A' is a boolean or vector of booleans. +static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D, + InstCombiner::BuilderTy &Builder) { + // The potential condition of the select may be bitcasted. In that case, look + // through its bitcast and the corresponding bitcast of the 'not' condition. 
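The helpers above target the classic select pattern: when A is an all-ones/all-zeros mask and B is its bitwise not, (A & C) | (B & D) selects whole elements from C or D. A scalar sketch of the identity (plain C++, not patch code):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t C = 0xAAAAAAAAu, D = 0x55555555u;
  for (uint32_t A : {0u, ~0u}) {          // the "sext i1" style mask
    uint32_t B = ~A;                      // its inverse bitmask
    uint32_t Masked = (A & C) | (B & D);  // (A & C) | (B & D)
    uint32_t Selected = (A != 0) ? C : D; // select cond, C, D
    assert(Masked == Selected);
  }
  return 0;
}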
+ Type *OrigType = A->getType(); + Value *SrcA, *SrcB; + if (match(A, m_OneUse(m_BitCast(m_Value(SrcA)))) && + match(B, m_OneUse(m_BitCast(m_Value(SrcB))))) { + A = SrcA; + B = SrcB; + } + + if (Value *Cond = getSelectCondition(A, B, Builder)) { + // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) + // The bitcasts will either all exist or all not exist. The builder will + // not create unnecessary casts if the types already match. + Value *BitcastC = Builder.CreateBitCast(C, A->getType()); + Value *BitcastD = Builder.CreateBitCast(D, A->getType()); + Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); + return Builder.CreateBitCast(Select, OrigType); + } + return nullptr; } @@ -1940,6 +2010,27 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, /// Optimize (fcmp)|(fcmp). NOTE: Unlike the rest of instcombine, this returns /// a Value which should already be inserted into the function. Value *InstCombiner::FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { + Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); + Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); + FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); + + if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { + // Swap RHS operands to match LHS. + Op1CC = FCmpInst::getSwappedPredicate(Op1CC); + std::swap(Op1LHS, Op1RHS); + } + + // Simplify (fcmp cc0 x, y) | (fcmp cc1 x, y). + // This is a similar transformation to the one in FoldAndOfFCmps. + // + // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this: + // bool(R & CC0) || bool(R & CC1) + // = bool((R & CC0) | (R & CC1)) + // = bool(R & (CC0 | CC1)) <= by reversed distribution (contribution? ;) + if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) + return getFCmpValue(getFCmpCode(Op0CC) | getFCmpCode(Op1CC), Op0LHS, Op0RHS, + Builder); + if (LHS->getPredicate() == FCmpInst::FCMP_UNO && RHS->getPredicate() == FCmpInst::FCMP_UNO && LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType()) { @@ -1964,35 +2055,6 @@ Value *InstCombiner::FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { return nullptr; } - Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); - Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); - FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); - - if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { - // Swap RHS operands to match LHS. - Op1CC = FCmpInst::getSwappedPredicate(Op1CC); - std::swap(Op1LHS, Op1RHS); - } - if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) { - // Simplify (fcmp cc0 x, y) | (fcmp cc1 x, y). - if (Op0CC == Op1CC) - return Builder->CreateFCmp((FCmpInst::Predicate)Op0CC, Op0LHS, Op0RHS); - if (Op0CC == FCmpInst::FCMP_TRUE || Op1CC == FCmpInst::FCMP_TRUE) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); - if (Op0CC == FCmpInst::FCMP_FALSE) - return RHS; - if (Op1CC == FCmpInst::FCMP_FALSE) - return LHS; - bool Op0Ordered; - bool Op1Ordered; - unsigned Op0Pred = getFCmpCode(Op0CC, Op0Ordered); - unsigned Op1Pred = getFCmpCode(Op1CC, Op1Ordered); - if (Op0Ordered == Op1Ordered) { - // If both are ordered or unordered, return a new fcmp with - // or'ed predicates. 
- return getFCmpValue(Op0Ordered, Op0Pred|Op1Pred, Op0LHS, Op0RHS, Builder); - } - } return nullptr; } @@ -2062,14 +2124,14 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A&B)|(A&C) -> A&(B|C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -2077,7 +2139,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return &I; if (Value *V = SimplifyBSwap(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { ConstantInt *C1 = nullptr; Value *X = nullptr; @@ -2111,23 +2173,13 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return NV; } + // Given an OR instruction, check to see if this is a bswap. + if (Instruction *BSwap = MatchBSwap(I)) + return BSwap; + Value *A = nullptr, *B = nullptr; ConstantInt *C1 = nullptr, *C2 = nullptr; - // (A | B) | C and A | (B | C) -> bswap if possible. - bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) || - match(Op1, m_Or(m_Value(), m_Value())); - // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible. - bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) && - match(Op1, m_LogicalShift(m_Value(), m_Value())); - // (A & B) | (C & D) -> bswap if possible. - bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && - match(Op1, m_And(m_Value(), m_Value())); - - if (OrOfOrs || OrOfShifts || OrOfAnds) - if (Instruction *BSwap = MatchBSwapOrBitReverse(I)) - return BSwap; - // (X^C)|Y -> (X|Y)^C iff Y&C == 0 if (Op0->hasOneUse() && match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && @@ -2207,18 +2259,27 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - // (A & (C0?-1:0)) | (B & ~(C0?-1:0)) -> C0 ? A : B, and commuted variants. - // Don't do this for vector select idioms, the code generator doesn't handle - // them well yet. - if (!I.getType()->isVectorTy()) { - if (Instruction *Match = MatchSelectFromAndOr(A, B, C, D)) - return Match; - if (Instruction *Match = MatchSelectFromAndOr(B, A, D, C)) - return Match; - if (Instruction *Match = MatchSelectFromAndOr(C, B, A, D)) - return Match; - if (Instruction *Match = MatchSelectFromAndOr(D, A, B, C)) - return Match; + // Don't try to form a select if it's unlikely that we'll get rid of at + // least one of the operands. A select is generally more expensive than the + // 'or' that it is replacing. + if (Op0->hasOneUse() || Op1->hasOneUse()) { + // (Cond & C) | (~Cond & D) -> Cond ? C : D, and commuted variants. 
+ if (Value *V = matchSelectFromAndOr(A, C, B, D, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(A, C, D, B, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, B, D, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, D, B, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(B, D, A, C, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(B, D, C, A, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(D, B, A, C, *Builder)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(D, B, C, A, *Builder)) + return replaceInstUsesWith(I, V); } // ((A&~B)|(~A&B)) -> A^B @@ -2342,7 +2403,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); if (LHS && RHS) if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) - return ReplaceInstUsesWith(I, Res); + return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary // number of 'or' instructions might have to be created. @@ -2350,18 +2411,18 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (LHS && match(Op1, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) - return ReplaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + return replaceInstUsesWith(I, Builder->CreateOr(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) - return ReplaceInstUsesWith(I, Builder->CreateOr(Res, X)); + return replaceInstUsesWith(I, Builder->CreateOr(Res, X)); } if (RHS && match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) - return ReplaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + return replaceInstUsesWith(I, Builder->CreateOr(Res, Y)); if (auto *Cmp = dyn_cast<ICmpInst>(Y)) if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) - return ReplaceInstUsesWith(I, Builder->CreateOr(Res, X)); + return replaceInstUsesWith(I, Builder->CreateOr(Res, X)); } } @@ -2369,48 +2430,17 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) if (Value *Res = FoldOrOfFCmps(LHS, RHS)) - return ReplaceInstUsesWith(I, Res); - - // fold (or (cast A), (cast B)) -> (cast (or A, B)) - if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { - CastInst *Op1C = dyn_cast<CastInst>(Op1); - if (Op1C && Op0C->getOpcode() == Op1C->getOpcode()) {// same cast kind ? - Type *SrcTy = Op0C->getOperand(0)->getType(); - if (SrcTy == Op1C->getOperand(0)->getType() && - SrcTy->isIntOrIntVectorTy()) { - Value *Op0COp = Op0C->getOperand(0), *Op1COp = Op1C->getOperand(0); - - if ((!isa<ICmpInst>(Op0COp) || !isa<ICmpInst>(Op1COp)) && - // Only do this if the casts both really cause code to be - // generated. - ShouldOptimizeCast(Op0C->getOpcode(), Op0COp, I.getType()) && - ShouldOptimizeCast(Op1C->getOpcode(), Op1COp, I.getType())) { - Value *NewOp = Builder->CreateOr(Op0COp, Op1COp, I.getName()); - return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); - } + return replaceInstUsesWith(I, Res); - // If this is or(cast(icmp), cast(icmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. 
- if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1COp)) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0COp)) - if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) - return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); - - // If this is or(cast(fcmp), cast(fcmp)), try to fold this even if the - // cast is otherwise not optimizable. This happens for vector sexts. - if (FCmpInst *RHS = dyn_cast<FCmpInst>(Op1COp)) - if (FCmpInst *LHS = dyn_cast<FCmpInst>(Op0COp)) - if (Value *Res = FoldOrOfFCmps(LHS, RHS)) - return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); - } - } - } + if (Instruction *CastedOr = foldCastedBitwiseLogic(I)) + return CastedOr; - // or(sext(A), B) -> A ? -1 : B where A is an i1 - // or(A, sext(B)) -> B ? -1 : A where B is an i1 - if (match(Op0, m_SExt(m_Value(A))) && A->getType()->isIntegerTy(1)) + // or(sext(A), B) / or(B, sext(A)) --> A ? -1 : B, where A is i1 or <N x i1>. + if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->getScalarType()->isIntegerTy(1)) return SelectInst::Create(A, ConstantInt::getSigned(I.getType(), -1), Op1); - if (match(Op1, m_SExt(m_Value(A))) && A->getType()->isIntegerTy(1)) + if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->getScalarType()->isIntegerTy(1)) return SelectInst::Create(A, ConstantInt::getSigned(I.getType(), -1), Op0); // Note: If we've gotten to the point of visiting the outer OR, then the @@ -2447,14 +2477,14 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (A&B)^(A&C) -> A&(B^C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -2462,7 +2492,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return &I; if (Value *V = SimplifyBSwap(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Is this a ~ operation? if (Value *NotOp = dyn_castNotVal(&I)) { @@ -2731,29 +2761,14 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); unsigned Code = getICmpCode(LHS) ^ getICmpCode(RHS); bool isSigned = LHS->isSigned() || RHS->isSigned(); - return ReplaceInstUsesWith(I, + return replaceInstUsesWith(I, getNewICmpValue(isSigned, Code, Op0, Op1, Builder)); } } - // fold (xor (cast A), (cast B)) -> (cast (xor A, B)) - if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { - if (CastInst *Op1C = dyn_cast<CastInst>(Op1)) - if (Op0C->getOpcode() == Op1C->getOpcode()) { // same cast kind? - Type *SrcTy = Op0C->getOperand(0)->getType(); - if (SrcTy == Op1C->getOperand(0)->getType() && SrcTy->isIntegerTy() && - // Only do this if the casts both really cause code to be generated. - ShouldOptimizeCast(Op0C->getOpcode(), Op0C->getOperand(0), - I.getType()) && - ShouldOptimizeCast(Op1C->getOpcode(), Op1C->getOperand(0), - I.getType())) { - Value *NewOp = Builder->CreateXor(Op0C->getOperand(0), - Op1C->getOperand(0), I.getName()); - return CastInst::Create(Op0C->getOpcode(), NewOp, I.getType()); - } - } - } + if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) + return CastedXor; return Changed ? 
&I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 090245d1b22c..8acff91345d6 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -14,6 +14,7 @@ #include "InstCombineInternal.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" @@ -29,8 +30,8 @@ using namespace PatternMatch; STATISTIC(NumSimplified, "Number of library calls simplified"); -/// getPromotedType - Return the specified type promoted as it would be to pass -/// though a va_arg area. +/// Return the specified type promoted as it would be to pass though a va_arg +/// area. static Type *getPromotedType(Type *Ty) { if (IntegerType* ITy = dyn_cast<IntegerType>(Ty)) { if (ITy->getBitWidth() < 32) @@ -39,8 +40,8 @@ static Type *getPromotedType(Type *Ty) { return Ty; } -/// reduceToSingleValueType - Given an aggregate type which ultimately holds a -/// single scalar element, like {{{type}}} or [1 x type], return type. +/// Given an aggregate type which ultimately holds a single scalar element, +/// like {{{type}}} or [1 x type], return type. static Type *reduceToSingleValueType(Type *T) { while (!T->isSingleValueType()) { if (StructType *STy = dyn_cast<StructType>(T)) { @@ -60,6 +61,23 @@ static Type *reduceToSingleValueType(Type *T) { return T; } +/// Return a constant boolean vector that has true elements in all positions +/// where the input constant data vector has an element with the sign bit set. +static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { + SmallVector<Constant *, 32> BoolVec; + IntegerType *BoolTy = Type::getInt1Ty(V->getContext()); + for (unsigned I = 0, E = V->getNumElements(); I != E; ++I) { + Constant *Elt = V->getElementAsConstant(I); + assert((isa<ConstantInt>(Elt) || isa<ConstantFP>(Elt)) && + "Unexpected constant data vector element type"); + bool Sign = V->getElementType()->isIntegerTy() + ? cast<ConstantInt>(Elt)->isNegative() + : cast<ConstantFP>(Elt)->isNegative(); + BoolVec.push_back(ConstantInt::get(BoolTy, Sign)); + } + return ConstantVector::get(BoolVec); +} + Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, AC, DT); unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, AC, DT); @@ -197,7 +215,7 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { return nullptr; } -static Value *SimplifyX86immshift(const IntrinsicInst &II, +static Value *simplifyX86immShift(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { bool LogicalShift = false; bool ShiftLeft = false; @@ -307,83 +325,216 @@ static Value *SimplifyX86immshift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } -static Value *SimplifyX86extend(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder, - bool SignExtend) { - VectorType *SrcTy = cast<VectorType>(II.getArgOperand(0)->getType()); - VectorType *DstTy = cast<VectorType>(II.getType()); - unsigned NumDstElts = DstTy->getNumElements(); - - // Extract a subvector of the first NumDstElts lanes and sign/zero extend. 
- SmallVector<int, 8> ShuffleMask; - for (int i = 0; i != (int)NumDstElts; ++i) - ShuffleMask.push_back(i); - - Value *SV = Builder.CreateShuffleVector(II.getArgOperand(0), - UndefValue::get(SrcTy), ShuffleMask); - return SignExtend ? Builder.CreateSExt(SV, DstTy) - : Builder.CreateZExt(SV, DstTy); -} - -static Value *SimplifyX86insertps(const IntrinsicInst &II, +// Attempt to simplify AVX2 per-element shift intrinsics to a generic IR shift. +// Unlike the generic IR shifts, the intrinsics have defined behaviour for out +// of range shift amounts (logical - set to zero, arithmetic - splat sign bit). +static Value *simplifyX86varShift(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { - if (auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2))) { - VectorType *VecTy = cast<VectorType>(II.getType()); - assert(VecTy->getNumElements() == 4 && "insertps with wrong vector type"); - - // The immediate permute control byte looks like this: - // [3:0] - zero mask for each 32-bit lane - // [5:4] - select one 32-bit destination lane - // [7:6] - select one 32-bit source lane - - uint8_t Imm = CInt->getZExtValue(); - uint8_t ZMask = Imm & 0xf; - uint8_t DestLane = (Imm >> 4) & 0x3; - uint8_t SourceLane = (Imm >> 6) & 0x3; - - ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); - - // If all zero mask bits are set, this was just a weird way to - // generate a zero vector. - if (ZMask == 0xf) - return ZeroVector; - - // Initialize by passing all of the first source bits through. - int ShuffleMask[4] = { 0, 1, 2, 3 }; - - // We may replace the second operand with the zero vector. - Value *V1 = II.getArgOperand(1); - - if (ZMask) { - // If the zero mask is being used with a single input or the zero mask - // overrides the destination lane, this is a shuffle with the zero vector. - if ((II.getArgOperand(0) == II.getArgOperand(1)) || - (ZMask & (1 << DestLane))) { - V1 = ZeroVector; - // We may still move 32-bits of the first source vector from one lane - // to another. - ShuffleMask[DestLane] = SourceLane; - // The zero mask may override the previous insert operation. - for (unsigned i = 0; i < 4; ++i) - if ((ZMask >> i) & 0x1) - ShuffleMask[i] = i + 4; + bool LogicalShift = false; + bool ShiftLeft = false; + + switch (II.getIntrinsicID()) { + default: + return nullptr; + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + LogicalShift = false; + ShiftLeft = false; + break; + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + LogicalShift = true; + ShiftLeft = false; + break; + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + LogicalShift = true; + ShiftLeft = true; + break; + } + assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left"); + + // Simplify if all shift amounts are constant/undef. + auto *CShift = dyn_cast<Constant>(II.getArgOperand(1)); + if (!CShift) + return nullptr; + + auto Vec = II.getArgOperand(0); + auto VT = cast<VectorType>(II.getType()); + auto SVT = VT->getVectorElementType(); + int NumElts = VT->getNumElements(); + int BitWidth = SVT->getIntegerBitWidth(); + + // Collect each element's shift amount. + // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth. 
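For reference, the defined out-of-range behaviour mentioned above can be emulated per lane in plain scalar C++ as below (hypothetical helper names; this also assumes the usual arithmetic behaviour of signed right shift):

#include <cassert>
#include <cstdint>

static uint32_t psrlvLane(uint32_t V, uint32_t Amt) { // logical right shift
  return Amt >= 32 ? 0u : V >> Amt;                   // out of range -> 0
}

static int32_t psravLane(int32_t V, uint32_t Amt) {   // arithmetic right shift
  return V >> (Amt >= 32 ? 31 : Amt);                 // out of range -> sign splat
}

int main() {
  assert(psrlvLane(0xdeadbeefu, 40) == 0u);
  assert(psravLane(-8, 40) == -1);
  assert(psravLane(8, 40) == 0);
  assert(psrlvLane(0x80u, 4) == 0x8u);
  return 0;
}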
+ bool AnyOutOfRange = false; + SmallVector<int, 8> ShiftAmts; + for (int I = 0; I < NumElts; ++I) { + auto *CElt = CShift->getAggregateElement(I); + if (CElt && isa<UndefValue>(CElt)) { + ShiftAmts.push_back(-1); + continue; + } + + auto *COp = dyn_cast_or_null<ConstantInt>(CElt); + if (!COp) + return nullptr; + + // Handle out of range shifts. + // If LogicalShift - set to BitWidth (special case). + // If ArithmeticShift - set to (BitWidth - 1) (sign splat). + APInt ShiftVal = COp->getValue(); + if (ShiftVal.uge(BitWidth)) { + AnyOutOfRange = LogicalShift; + ShiftAmts.push_back(LogicalShift ? BitWidth : BitWidth - 1); + continue; + } + + ShiftAmts.push_back((int)ShiftVal.getZExtValue()); + } + + // If all elements out of range or UNDEF, return vector of zeros/undefs. + // ArithmeticShift should only hit this if they are all UNDEF. + auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; + if (llvm::all_of(ShiftAmts, OutOfRange)) { + SmallVector<Constant *, 8> ConstantVec; + for (int Idx : ShiftAmts) { + if (Idx < 0) { + ConstantVec.push_back(UndefValue::get(SVT)); } else { - // TODO: Model this case as 2 shuffles or a 'logical and' plus shuffle? - return nullptr; + assert(LogicalShift && "Logical shift expected"); + ConstantVec.push_back(ConstantInt::getNullValue(SVT)); } - } else { - // Replace the selected destination lane with the selected source lane. - ShuffleMask[DestLane] = SourceLane + 4; } + return ConstantVector::get(ConstantVec); + } - return Builder.CreateShuffleVector(II.getArgOperand(0), V1, ShuffleMask); + // We can't handle only some out of range values with generic logical shifts. + if (AnyOutOfRange) + return nullptr; + + // Build the shift amount constant vector. + SmallVector<Constant *, 8> ShiftVecAmts; + for (int Idx : ShiftAmts) { + if (Idx < 0) + ShiftVecAmts.push_back(UndefValue::get(SVT)); + else + ShiftVecAmts.push_back(ConstantInt::get(SVT, Idx)); } - return nullptr; + auto ShiftVec = ConstantVector::get(ShiftVecAmts); + + if (ShiftLeft) + return Builder.CreateShl(Vec, ShiftVec); + + if (LogicalShift) + return Builder.CreateLShr(Vec, ShiftVec); + + return Builder.CreateAShr(Vec, ShiftVec); +} + +static Value *simplifyX86movmsk(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + Value *Arg = II.getArgOperand(0); + Type *ResTy = II.getType(); + Type *ArgTy = Arg->getType(); + + // movmsk(undef) -> zero as we must ensure the upper bits are zero. + if (isa<UndefValue>(Arg)) + return Constant::getNullValue(ResTy); + + // We can't easily peek through x86_mmx types. + if (!ArgTy->isVectorTy()) + return nullptr; + + auto *C = dyn_cast<Constant>(Arg); + if (!C) + return nullptr; + + // Extract signbits of the vector input and pack into integer result. 
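A standalone sketch of the movmsk semantics being constant-folded here: the result packs each element's sign bit into the corresponding low bit of the scalar result (illustrative names, not LLVM code):

#include <cassert>
#include <cstdint>

static uint32_t movmskEmul(const int32_t *Elts, unsigned NumElts) {
  uint32_t Result = 0;
  for (unsigned I = 0; I != NumElts; ++I)
    if (Elts[I] < 0)        // element's most significant (sign) bit is set
      Result |= 1u << I;
  return Result;
}

int main() {
  int32_t V[4] = {-1, 2, -3, 4};
  assert(movmskEmul(V, 4) == 0b0101u); // bits 0 and 2 set
  return 0;
}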
+ APInt Result(ResTy->getPrimitiveSizeInBits(), 0); + for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) { + auto *COp = C->getAggregateElement(I); + if (!COp) + return nullptr; + if (isa<UndefValue>(COp)) + continue; + + auto *CInt = dyn_cast<ConstantInt>(COp); + auto *CFp = dyn_cast<ConstantFP>(COp); + if (!CInt && !CFp) + return nullptr; + + if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative())) + Result.setBit(I); + } + + return Constant::getIntegerValue(ResTy, Result); +} + +static Value *simplifyX86insertps(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2)); + if (!CInt) + return nullptr; + + VectorType *VecTy = cast<VectorType>(II.getType()); + assert(VecTy->getNumElements() == 4 && "insertps with wrong vector type"); + + // The immediate permute control byte looks like this: + // [3:0] - zero mask for each 32-bit lane + // [5:4] - select one 32-bit destination lane + // [7:6] - select one 32-bit source lane + + uint8_t Imm = CInt->getZExtValue(); + uint8_t ZMask = Imm & 0xf; + uint8_t DestLane = (Imm >> 4) & 0x3; + uint8_t SourceLane = (Imm >> 6) & 0x3; + + ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); + + // If all zero mask bits are set, this was just a weird way to + // generate a zero vector. + if (ZMask == 0xf) + return ZeroVector; + + // Initialize by passing all of the first source bits through. + uint32_t ShuffleMask[4] = { 0, 1, 2, 3 }; + + // We may replace the second operand with the zero vector. + Value *V1 = II.getArgOperand(1); + + if (ZMask) { + // If the zero mask is being used with a single input or the zero mask + // overrides the destination lane, this is a shuffle with the zero vector. + if ((II.getArgOperand(0) == II.getArgOperand(1)) || + (ZMask & (1 << DestLane))) { + V1 = ZeroVector; + // We may still move 32-bits of the first source vector from one lane + // to another. + ShuffleMask[DestLane] = SourceLane; + // The zero mask may override the previous insert operation. + for (unsigned i = 0; i < 4; ++i) + if ((ZMask >> i) & 0x1) + ShuffleMask[i] = i + 4; + } else { + // TODO: Model this case as 2 shuffles or a 'logical and' plus shuffle? + return nullptr; + } + } else { + // Replace the selected destination lane with the selected source lane. + ShuffleMask[DestLane] = SourceLane + 4; + } + + return Builder.CreateShuffleVector(II.getArgOperand(0), V1, ShuffleMask); } /// Attempt to simplify SSE4A EXTRQ/EXTRQI instructions using constant folding /// or conversion to a shuffle vector. -static Value *SimplifyX86extrq(IntrinsicInst &II, Value *Op0, +static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, ConstantInt *CILength, ConstantInt *CIIndex, InstCombiner::BuilderTy &Builder) { auto LowConstantHighUndef = [&](uint64_t Val) { @@ -476,7 +627,7 @@ static Value *SimplifyX86extrq(IntrinsicInst &II, Value *Op0, /// Attempt to simplify SSE4A INSERTQ/INSERTQI instructions using constant /// folding or conversion to a shuffle vector. -static Value *SimplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, +static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, APInt APLength, APInt APIndex, InstCombiner::BuilderTy &Builder) { @@ -571,74 +722,211 @@ static Value *SimplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, return nullptr; } -/// The shuffle mask for a perm2*128 selects any two halves of two 256-bit -/// source vectors, unless a zero bit is set. 
If a zero bit is set, -/// then ignore that half of the mask and clear that half of the vector. -static Value *SimplifyX86vperm2(const IntrinsicInst &II, +/// Attempt to convert pshufb* to shufflevector if the mask is constant. +static Value *simplifyX86pshufb(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { - if (auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2))) { - VectorType *VecTy = cast<VectorType>(II.getType()); - ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); + Constant *V = dyn_cast<Constant>(II.getArgOperand(1)); + if (!V) + return nullptr; + + auto *VecTy = cast<VectorType>(II.getType()); + auto *MaskEltTy = Type::getInt32Ty(II.getContext()); + unsigned NumElts = VecTy->getNumElements(); + assert((NumElts == 16 || NumElts == 32) && + "Unexpected number of elements in shuffle mask!"); + + // Construct a shuffle mask from constant integers or UNDEFs. + Constant *Indexes[32] = {NULL}; + + // Each byte in the shuffle control mask forms an index to permute the + // corresponding byte in the destination operand. + for (unsigned I = 0; I < NumElts; ++I) { + Constant *COp = V->getAggregateElement(I); + if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) + return nullptr; + + if (isa<UndefValue>(COp)) { + Indexes[I] = UndefValue::get(MaskEltTy); + continue; + } - // The immediate permute control byte looks like this: - // [1:0] - select 128 bits from sources for low half of destination - // [2] - ignore - // [3] - zero low half of destination - // [5:4] - select 128 bits from sources for high half of destination - // [6] - ignore - // [7] - zero high half of destination + int8_t Index = cast<ConstantInt>(COp)->getValue().getZExtValue(); - uint8_t Imm = CInt->getZExtValue(); + // If the most significant bit (bit[7]) of each byte of the shuffle + // control mask is set, then zero is written in the result byte. + // The zero vector is in the right-hand side of the resulting + // shufflevector. - bool LowHalfZero = Imm & 0x08; - bool HighHalfZero = Imm & 0x80; + // The value of each index for the high 128-bit lane is the least + // significant 4 bits of the respective shuffle control byte. + Index = ((Index < 0) ? NumElts : Index & 0x0F) + (I & 0xF0); + Indexes[I] = ConstantInt::get(MaskEltTy, Index); + } - // If both zero mask bits are set, this was just a weird way to - // generate a zero vector. - if (LowHalfZero && HighHalfZero) - return ZeroVector; + auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, NumElts)); + auto V1 = II.getArgOperand(0); + auto V2 = Constant::getNullValue(VecTy); + return Builder.CreateShuffleVector(V1, V2, ShuffleMask); +} - // If 0 or 1 zero mask bits are set, this is a simple shuffle. - unsigned NumElts = VecTy->getNumElements(); - unsigned HalfSize = NumElts / 2; - SmallVector<int, 8> ShuffleMask(NumElts); +/// Attempt to convert vpermilvar* to shufflevector if the mask is constant. +static Value *simplifyX86vpermilvar(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + Constant *V = dyn_cast<Constant>(II.getArgOperand(1)); + if (!V) + return nullptr; + + auto *MaskEltTy = Type::getInt32Ty(II.getContext()); + unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); + assert(NumElts == 8 || NumElts == 4 || NumElts == 2); - // The high bit of the selection field chooses the 1st or 2nd operand. - bool LowInputSelect = Imm & 0x02; - bool HighInputSelect = Imm & 0x20; + // Construct a shuffle mask from constant integers or UNDEFs. 
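The pshufb fold above decodes each control byte as: bit 7 set produces zero, otherwise the low four bits index into the 16-byte source lane. A byte-level emulation of that rule (hypothetical names, plain C++):

#include <cassert>
#include <cstdint>

static void pshufbEmul(const uint8_t Src[16], const uint8_t Ctrl[16],
                       uint8_t Dst[16]) {
  for (int I = 0; I != 16; ++I)
    Dst[I] = (Ctrl[I] & 0x80) ? 0 : Src[Ctrl[I] & 0x0F];
}

int main() {
  uint8_t Src[16], Ctrl[16], Dst[16];
  for (int I = 0; I != 16; ++I) {
    Src[I] = static_cast<uint8_t>(I * 3);
    Ctrl[I] = static_cast<uint8_t>(15 - I); // reverse the bytes
  }
  Ctrl[0] = 0x80;                           // and zero the first result byte
  pshufbEmul(Src, Ctrl, Dst);
  assert(Dst[0] == 0);
  assert(Dst[1] == Src[14]);
  return 0;
}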
+ Constant *Indexes[8] = {NULL}; - // The low bit of the selection field chooses the low or high half - // of the selected operand. - bool LowHalfSelect = Imm & 0x01; - bool HighHalfSelect = Imm & 0x10; + // The intrinsics only read one or two bits, clear the rest. + for (unsigned I = 0; I < NumElts; ++I) { + Constant *COp = V->getAggregateElement(I); + if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) + return nullptr; - // Determine which operand(s) are actually in use for this instruction. - Value *V0 = LowInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); - Value *V1 = HighInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); + if (isa<UndefValue>(COp)) { + Indexes[I] = UndefValue::get(MaskEltTy); + continue; + } - // If needed, replace operands based on zero mask. - V0 = LowHalfZero ? ZeroVector : V0; - V1 = HighHalfZero ? ZeroVector : V1; + APInt Index = cast<ConstantInt>(COp)->getValue(); + Index = Index.zextOrTrunc(32).getLoBits(2); - // Permute low half of result. - unsigned StartIndex = LowHalfSelect ? HalfSize : 0; - for (unsigned i = 0; i < HalfSize; ++i) - ShuffleMask[i] = StartIndex + i; + // The PD variants uses bit 1 to select per-lane element index, so + // shift down to convert to generic shuffle mask index. + if (II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd || + II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) + Index = Index.lshr(1); - // Permute high half of result. - StartIndex = HighHalfSelect ? HalfSize : 0; - StartIndex += NumElts; - for (unsigned i = 0; i < HalfSize; ++i) - ShuffleMask[i + HalfSize] = StartIndex + i; + // The _256 variants are a bit trickier since the mask bits always index + // into the corresponding 128 half. In order to convert to a generic + // shuffle, we have to make that explicit. + if ((II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_ps_256 || + II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) && + ((NumElts / 2) <= I)) { + Index += APInt(32, NumElts / 2); + } - return Builder.CreateShuffleVector(V0, V1, ShuffleMask); + Indexes[I] = ConstantInt::get(MaskEltTy, Index); } - return nullptr; + + auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, NumElts)); + auto V1 = II.getArgOperand(0); + auto V2 = UndefValue::get(V1->getType()); + return Builder.CreateShuffleVector(V1, V2, ShuffleMask); +} + +/// Attempt to convert vpermd/vpermps to shufflevector if the mask is constant. +static Value *simplifyX86vpermv(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + auto *V = dyn_cast<Constant>(II.getArgOperand(1)); + if (!V) + return nullptr; + + auto *VecTy = cast<VectorType>(II.getType()); + auto *MaskEltTy = Type::getInt32Ty(II.getContext()); + unsigned Size = VecTy->getNumElements(); + assert(Size == 8 && "Unexpected shuffle mask size"); + + // Construct a shuffle mask from constant integers or UNDEFs. 
+ Constant *Indexes[8] = {NULL}; + + for (unsigned I = 0; I < Size; ++I) { + Constant *COp = V->getAggregateElement(I); + if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) + return nullptr; + + if (isa<UndefValue>(COp)) { + Indexes[I] = UndefValue::get(MaskEltTy); + continue; + } + + APInt Index = cast<ConstantInt>(COp)->getValue(); + Index = Index.zextOrTrunc(32).getLoBits(3); + Indexes[I] = ConstantInt::get(MaskEltTy, Index); + } + + auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, Size)); + auto V1 = II.getArgOperand(0); + auto V2 = UndefValue::get(VecTy); + return Builder.CreateShuffleVector(V1, V2, ShuffleMask); +} + +/// The shuffle mask for a perm2*128 selects any two halves of two 256-bit +/// source vectors, unless a zero bit is set. If a zero bit is set, +/// then ignore that half of the mask and clear that half of the vector. +static Value *simplifyX86vperm2(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2)); + if (!CInt) + return nullptr; + + VectorType *VecTy = cast<VectorType>(II.getType()); + ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); + + // The immediate permute control byte looks like this: + // [1:0] - select 128 bits from sources for low half of destination + // [2] - ignore + // [3] - zero low half of destination + // [5:4] - select 128 bits from sources for high half of destination + // [6] - ignore + // [7] - zero high half of destination + + uint8_t Imm = CInt->getZExtValue(); + + bool LowHalfZero = Imm & 0x08; + bool HighHalfZero = Imm & 0x80; + + // If both zero mask bits are set, this was just a weird way to + // generate a zero vector. + if (LowHalfZero && HighHalfZero) + return ZeroVector; + + // If 0 or 1 zero mask bits are set, this is a simple shuffle. + unsigned NumElts = VecTy->getNumElements(); + unsigned HalfSize = NumElts / 2; + SmallVector<uint32_t, 8> ShuffleMask(NumElts); + + // The high bit of the selection field chooses the 1st or 2nd operand. + bool LowInputSelect = Imm & 0x02; + bool HighInputSelect = Imm & 0x20; + + // The low bit of the selection field chooses the low or high half + // of the selected operand. + bool LowHalfSelect = Imm & 0x01; + bool HighHalfSelect = Imm & 0x10; + + // Determine which operand(s) are actually in use for this instruction. + Value *V0 = LowInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); + Value *V1 = HighInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); + + // If needed, replace operands based on zero mask. + V0 = LowHalfZero ? ZeroVector : V0; + V1 = HighHalfZero ? ZeroVector : V1; + + // Permute low half of result. + unsigned StartIndex = LowHalfSelect ? HalfSize : 0; + for (unsigned i = 0; i < HalfSize; ++i) + ShuffleMask[i] = StartIndex + i; + + // Permute high half of result. + StartIndex = HighHalfSelect ? HalfSize : 0; + StartIndex += NumElts; + for (unsigned i = 0; i < HalfSize; ++i) + ShuffleMask[i + HalfSize] = StartIndex + i; + + return Builder.CreateShuffleVector(V0, V1, ShuffleMask); } /// Decode XOP integer vector comparison intrinsics. 
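The perm2*128 immediate decoding just above can be illustrated with an 8 x i32 stand-in for a 256-bit vector, two 128-bit halves of four elements each (hypothetical helper, plain C++):

#include <array>
#include <cassert>
#include <cstdint>

using V256 = std::array<uint32_t, 8>;

static V256 perm2x128Emul(V256 A, V256 B, uint8_t Imm) {
  V256 R{};
  const V256 *Srcs[2] = {&A, &B};
  for (int Half = 0; Half != 2; ++Half) {
    uint8_t Sel = (Imm >> (4 * Half)) & 0x3; // [1:0] / [5:4]: source half
    bool Zero = (Imm >> (4 * Half)) & 0x8;   // [3] / [7]: zero this half
    for (int I = 0; I != 4; ++I)
      R[Half * 4 + I] = Zero ? 0 : (*Srcs[Sel >> 1])[(Sel & 1) * 4 + I];
  }
  return R;
}

int main() {
  V256 A = {1, 2, 3, 4, 5, 6, 7, 8}, B = {11, 12, 13, 14, 15, 16, 17, 18};
  // Imm 0x31: low half <- high half of A, high half <- high half of B.
  V256 R = perm2x128Emul(A, B, 0x31);
  assert(R[0] == 5 && R[4] == 15);
  // Imm 0x08: zero the low half, high half <- low half of A.
  V256 Z = perm2x128Emul(A, B, 0x08);
  assert(Z[0] == 0 && Z[4] == 1);
  return 0;
}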
-static Value *SimplifyX86vpcom(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder, bool IsSigned) { +static Value *simplifyX86vpcom(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder, + bool IsSigned) { if (auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2))) { uint64_t Imm = CInt->getZExtValue() & 0x7; VectorType *VecTy = cast<VectorType>(II.getType()); @@ -667,21 +955,296 @@ static Value *SimplifyX86vpcom(const IntrinsicInst &II, return ConstantInt::getSigned(VecTy, -1); // TRUE } - if (Value *Cmp = Builder.CreateICmp(Pred, II.getArgOperand(0), II.getArgOperand(1))) + if (Value *Cmp = Builder.CreateICmp(Pred, II.getArgOperand(0), + II.getArgOperand(1))) return Builder.CreateSExtOrTrunc(Cmp, VecTy); } return nullptr; } -/// visitCallInst - CallInst simplification. This mostly only handles folding -/// of intrinsic instructions. For normal calls, it allows visitCallSite to do -/// the heavy lifting. -/// +static Value *simplifyMinnumMaxnum(const IntrinsicInst &II) { + Value *Arg0 = II.getArgOperand(0); + Value *Arg1 = II.getArgOperand(1); + + // fmin(x, x) -> x + if (Arg0 == Arg1) + return Arg0; + + const auto *C1 = dyn_cast<ConstantFP>(Arg1); + + // fmin(x, nan) -> x + if (C1 && C1->isNaN()) + return Arg0; + + // This is the value because if undef were NaN, we would return the other + // value and cannot return a NaN unless both operands are. + // + // fmin(undef, x) -> x + if (isa<UndefValue>(Arg0)) + return Arg1; + + // fmin(x, undef) -> x + if (isa<UndefValue>(Arg1)) + return Arg0; + + Value *X = nullptr; + Value *Y = nullptr; + if (II.getIntrinsicID() == Intrinsic::minnum) { + // fmin(x, fmin(x, y)) -> fmin(x, y) + // fmin(y, fmin(x, y)) -> fmin(x, y) + if (match(Arg1, m_FMin(m_Value(X), m_Value(Y)))) { + if (Arg0 == X || Arg0 == Y) + return Arg1; + } + + // fmin(fmin(x, y), x) -> fmin(x, y) + // fmin(fmin(x, y), y) -> fmin(x, y) + if (match(Arg0, m_FMin(m_Value(X), m_Value(Y)))) { + if (Arg1 == X || Arg1 == Y) + return Arg0; + } + + // TODO: fmin(nnan x, inf) -> x + // TODO: fmin(nnan ninf x, flt_max) -> x + if (C1 && C1->isInfinity()) { + // fmin(x, -inf) -> -inf + if (C1->isNegative()) + return Arg1; + } + } else { + assert(II.getIntrinsicID() == Intrinsic::maxnum); + // fmax(x, fmax(x, y)) -> fmax(x, y) + // fmax(y, fmax(x, y)) -> fmax(x, y) + if (match(Arg1, m_FMax(m_Value(X), m_Value(Y)))) { + if (Arg0 == X || Arg0 == Y) + return Arg1; + } + + // fmax(fmax(x, y), x) -> fmax(x, y) + // fmax(fmax(x, y), y) -> fmax(x, y) + if (match(Arg0, m_FMax(m_Value(X), m_Value(Y)))) { + if (Arg1 == X || Arg1 == Y) + return Arg0; + } + + // TODO: fmax(nnan x, -inf) -> x + // TODO: fmax(nnan ninf x, -flt_max) -> x + if (C1 && C1->isInfinity()) { + // fmax(x, inf) -> inf + if (!C1->isNegative()) + return Arg1; + } + } + return nullptr; +} + +static bool maskIsAllOneOrUndef(Value *Mask) { + auto *ConstMask = dyn_cast<Constant>(Mask); + if (!ConstMask) + return false; + if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask)) + return true; + for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; + ++I) { + if (auto *MaskElt = ConstMask->getAggregateElement(I)) + if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt)) + continue; + return false; + } + return true; +} + +static Value *simplifyMaskedLoad(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + // If the mask is all ones or undefs, this is a plain vector load of the 1st + // argument. 
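The minnum/maxnum folds above follow IEEE-754 minNum/maxNum semantics, which C's fmin/fmax also implement: a quiet NaN operand is ignored in favour of the other operand. A quick standalone check:

#include <cassert>
#include <cmath>

int main() {
  double X = 2.5;
  assert(std::fmin(X, std::nan("")) == X);      // fmin(x, nan) -> x
  assert(std::fmax(std::nan(""), X) == X);      // fmax(nan, x) -> x
  assert(std::fmin(X, X) == X);                 // fmin(x, x) -> x
  assert(std::fmin(X, -INFINITY) == -INFINITY); // fmin(x, -inf) -> -inf
  return 0;
}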
+ if (maskIsAllOneOrUndef(II.getArgOperand(2))) { + Value *LoadPtr = II.getArgOperand(0); + unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue(); + return Builder.CreateAlignedLoad(LoadPtr, Alignment, "unmaskedload"); + } + + return nullptr; +} + +static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) { + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + if (!ConstMask) + return nullptr; + + // If the mask is all zeros, this instruction does nothing. + if (ConstMask->isNullValue()) + return IC.eraseInstFromFunction(II); + + // If the mask is all ones, this is a plain vector store of the 1st argument. + if (ConstMask->isAllOnesValue()) { + Value *StorePtr = II.getArgOperand(1); + unsigned Alignment = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue(); + return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment); + } + + return nullptr; +} + +static Instruction *simplifyMaskedGather(IntrinsicInst &II, InstCombiner &IC) { + // If the mask is all zeros, return the "passthru" argument of the gather. + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); + if (ConstMask && ConstMask->isNullValue()) + return IC.replaceInstUsesWith(II, II.getArgOperand(3)); + + return nullptr; +} + +static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) { + // If the mask is all zeros, a scatter does nothing. + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + if (ConstMask && ConstMask->isNullValue()) + return IC.eraseInstFromFunction(II); + + return nullptr; +} + +// TODO: If the x86 backend knew how to convert a bool vector mask back to an +// XMM register mask efficiently, we could transform all x86 masked intrinsics +// to LLVM masked intrinsics and remove the x86 masked intrinsic defs. +static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) { + Value *Ptr = II.getOperand(0); + Value *Mask = II.getOperand(1); + Constant *ZeroVec = Constant::getNullValue(II.getType()); + + // Special case a zero mask since that's not a ConstantDataVector. + // This masked load instruction creates a zero vector. + if (isa<ConstantAggregateZero>(Mask)) + return IC.replaceInstUsesWith(II, ZeroVec); + + auto *ConstMask = dyn_cast<ConstantDataVector>(Mask); + if (!ConstMask) + return nullptr; + + // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic + // to allow target-independent optimizations. + + // First, cast the x86 intrinsic scalar pointer to a vector pointer to match + // the LLVM intrinsic definition for the pointer argument. + unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace(); + PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace); + Value *PtrCast = IC.Builder->CreateBitCast(Ptr, VecPtrTy, "castvec"); + + // Second, convert the x86 XMM integer vector mask to a vector of bools based + // on each element's most significant bit (the sign bit). + Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask); + + // The pass-through vector for an x86 masked load is a zero vector. + CallInst *NewMaskedLoad = + IC.Builder->CreateMaskedLoad(PtrCast, 1, BoolMask, ZeroVec); + return IC.replaceInstUsesWith(II, NewMaskedLoad); +} + +// TODO: If the x86 backend knew how to convert a bool vector mask back to an +// XMM register mask efficiently, we could transform all x86 masked intrinsics +// to LLVM masked intrinsics and remove the x86 masked intrinsic defs. 
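A plain C++ sketch of the mask conversion used when rewriting these x86 masked intrinsics: the XMM integer mask becomes one bool per element, true exactly when the element's sign bit is set (illustrative names, not the LLVM helper itself):

#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<bool> negativeIsTrueMask(const std::vector<int32_t> &Mask) {
  std::vector<bool> Bools;
  Bools.reserve(Mask.size());
  for (int32_t Elt : Mask)
    Bools.push_back(Elt < 0); // most significant (sign) bit set
  return Bools;
}

int main() {
  std::vector<bool> B = negativeIsTrueMask({-1, 0, INT32_MIN, 7});
  assert(B[0] && !B[1] && B[2] && !B[3]);
  return 0;
}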
+static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) { + Value *Ptr = II.getOperand(0); + Value *Mask = II.getOperand(1); + Value *Vec = II.getOperand(2); + + // Special case a zero mask since that's not a ConstantDataVector: + // this masked store instruction does nothing. + if (isa<ConstantAggregateZero>(Mask)) { + IC.eraseInstFromFunction(II); + return true; + } + + // The SSE2 version is too weird (eg, unaligned but non-temporal) to do + // anything else at this level. + if (II.getIntrinsicID() == Intrinsic::x86_sse2_maskmov_dqu) + return false; + + auto *ConstMask = dyn_cast<ConstantDataVector>(Mask); + if (!ConstMask) + return false; + + // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic + // to allow target-independent optimizations. + + // First, cast the x86 intrinsic scalar pointer to a vector pointer to match + // the LLVM intrinsic definition for the pointer argument. + unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace(); + PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace); + Value *PtrCast = IC.Builder->CreateBitCast(Ptr, VecPtrTy, "castvec"); + + // Second, convert the x86 XMM integer vector mask to a vector of bools based + // on each element's most significant bit (the sign bit). + Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask); + + IC.Builder->CreateMaskedStore(Vec, PtrCast, 1, BoolMask); + + // 'Replace uses' doesn't work for stores. Erase the original masked store. + IC.eraseInstFromFunction(II); + return true; +} + +// Returns true iff the 2 intrinsics have the same operands, limiting the +// comparison to the first NumOperands. +static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, + unsigned NumOperands) { + assert(I.getNumArgOperands() >= NumOperands && "Not enough operands"); + assert(E.getNumArgOperands() >= NumOperands && "Not enough operands"); + for (unsigned i = 0; i < NumOperands; i++) + if (I.getArgOperand(i) != E.getArgOperand(i)) + return false; + return true; +} + +// Remove trivially empty start/end intrinsic ranges, i.e. a start +// immediately followed by an end (ignoring debuginfo or other +// start/end intrinsics in between). As this handles only the most trivial +// cases, tracking the nesting level is not needed: +// +// call @llvm.foo.start(i1 0) ; &I +// call @llvm.foo.start(i1 0) +// call @llvm.foo.end(i1 0) ; This one will not be skipped: it will be removed +// call @llvm.foo.end(i1 0) +static bool removeTriviallyEmptyRange(IntrinsicInst &I, unsigned StartID, + unsigned EndID, InstCombiner &IC) { + assert(I.getIntrinsicID() == StartID && + "Start intrinsic does not have expected ID"); + BasicBlock::iterator BI(I), BE(I.getParent()->end()); + for (++BI; BI != BE; ++BI) { + if (auto *E = dyn_cast<IntrinsicInst>(BI)) { + if (isa<DbgInfoIntrinsic>(E) || E->getIntrinsicID() == StartID) + continue; + if (E->getIntrinsicID() == EndID && + haveSameOperands(I, *E, E->getNumArgOperands())) { + IC.eraseInstFromFunction(*E); + IC.eraseInstFromFunction(I); + return true; + } + } + break; + } + + return false; +} + +Instruction *InstCombiner::visitVAStartInst(VAStartInst &I) { + removeTriviallyEmptyRange(I, Intrinsic::vastart, Intrinsic::vaend, *this); + return nullptr; +} + +Instruction *InstCombiner::visitVACopyInst(VACopyInst &I) { + removeTriviallyEmptyRange(I, Intrinsic::vacopy, Intrinsic::vaend, *this); + return nullptr; +} + +/// CallInst simplification. This mostly only handles folding of intrinsic +/// instructions. 
For normal calls, it allows visitCallSite to do the heavy +/// lifting. Instruction *InstCombiner::visitCallInst(CallInst &CI) { auto Args = CI.arg_operands(); if (Value *V = SimplifyCall(CI.getCalledValue(), Args.begin(), Args.end(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(CI, V); + return replaceInstUsesWith(CI, V); if (isFreeCall(&CI, TLI)) return visitFree(CI); @@ -705,7 +1268,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // memmove/cpy/set of zero bytes is a noop. if (Constant *NumBytes = dyn_cast<Constant>(MI->getLength())) { if (NumBytes->isNullValue()) - return EraseInstFromFunction(CI); + return eraseInstFromFunction(CI); if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) if (CI->getZExtValue() == 1) { @@ -738,7 +1301,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { // memmove(x,x,size) -> noop. if (MTI->getSource() == MTI->getDest()) - return EraseInstFromFunction(CI); + return eraseInstFromFunction(CI); } // If we can determine a pointer alignment that is bigger than currently @@ -754,19 +1317,30 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Changed) return II; } - auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width, unsigned DemandedWidth) - { + auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width, + unsigned DemandedWidth) { APInt UndefElts(Width, 0); APInt DemandedElts = APInt::getLowBitsSet(Width, DemandedWidth); return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); }; + auto SimplifyDemandedVectorEltsHigh = [this](Value *Op, unsigned Width, + unsigned DemandedWidth) { + APInt UndefElts(Width, 0); + APInt DemandedElts = APInt::getHighBitsSet(Width, DemandedWidth); + return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); + }; switch (II->getIntrinsicID()) { default: break; case Intrinsic::objectsize: { uint64_t Size; - if (getObjectSize(II->getArgOperand(0), Size, DL, TLI)) - return ReplaceInstUsesWith(CI, ConstantInt::get(CI.getType(), Size)); + if (getObjectSize(II->getArgOperand(0), Size, DL, TLI)) { + APInt APSize(II->getType()->getIntegerBitWidth(), Size); + // Equality check to be sure that `Size` can fit in a value of type + // `II->getType()` + if (APSize == Size) + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), APSize)); + } return nullptr; } case Intrinsic::bswap: { @@ -775,7 +1349,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // bswap(bswap(x)) -> x if (match(IIOperand, m_BSwap(m_Value(X)))) - return ReplaceInstUsesWith(CI, X); + return replaceInstUsesWith(CI, X); // bswap(trunc(bswap(x))) -> trunc(lshr(x, c)) if (match(IIOperand, m_Trunc(m_BSwap(m_Value(X))))) { @@ -794,18 +1368,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // bitreverse(bitreverse(x)) -> x if (match(IIOperand, m_Intrinsic<Intrinsic::bitreverse>(m_Value(X)))) - return ReplaceInstUsesWith(CI, X); + return replaceInstUsesWith(CI, X); break; } + case Intrinsic::masked_load: + if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II, *Builder)) + return replaceInstUsesWith(CI, SimplifiedMaskedOp); + break; + case Intrinsic::masked_store: + return simplifyMaskedStore(*II, *this); + case Intrinsic::masked_gather: + return simplifyMaskedGather(*II, *this); + case Intrinsic::masked_scatter: + return simplifyMaskedScatter(*II, *this); + case Intrinsic::powi: if (ConstantInt *Power = dyn_cast<ConstantInt>(II->getArgOperand(1))) { // powi(x, 0) -> 1.0 if (Power->isZero()) - return 
ReplaceInstUsesWith(CI, ConstantFP::get(CI.getType(), 1.0)); + return replaceInstUsesWith(CI, ConstantFP::get(CI.getType(), 1.0)); // powi(x, 1) -> x if (Power->isOne()) - return ReplaceInstUsesWith(CI, II->getArgOperand(0)); + return replaceInstUsesWith(CI, II->getArgOperand(0)); // powi(x, -1) -> 1/x if (Power->isAllOnesValue()) return BinaryOperator::CreateFDiv(ConstantFP::get(CI.getType(), 1.0), @@ -825,7 +1410,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { unsigned TrailingZeros = KnownOne.countTrailingZeros(); APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros)); if ((Mask & KnownZero) == Mask) - return ReplaceInstUsesWith(CI, ConstantInt::get(IT, + return replaceInstUsesWith(CI, ConstantInt::get(IT, APInt(BitWidth, TrailingZeros))); } @@ -843,7 +1428,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { unsigned LeadingZeros = KnownOne.countLeadingZeros(); APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros)); if ((Mask & KnownZero) == Mask) - return ReplaceInstUsesWith(CI, ConstantInt::get(IT, + return replaceInstUsesWith(CI, ConstantInt::get(IT, APInt(BitWidth, LeadingZeros))); } @@ -882,84 +1467,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::maxnum: { Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - - // fmin(x, x) -> x - if (Arg0 == Arg1) - return ReplaceInstUsesWith(CI, Arg0); - - const ConstantFP *C0 = dyn_cast<ConstantFP>(Arg0); - const ConstantFP *C1 = dyn_cast<ConstantFP>(Arg1); - - // Canonicalize constants into the RHS. - if (C0 && !C1) { + // Canonicalize constants to the RHS. + if (isa<ConstantFP>(Arg0) && !isa<ConstantFP>(Arg1)) { II->setArgOperand(0, Arg1); II->setArgOperand(1, Arg0); return II; } - - // fmin(x, nan) -> x - if (C1 && C1->isNaN()) - return ReplaceInstUsesWith(CI, Arg0); - - // This is the value because if undef were NaN, we would return the other - // value and cannot return a NaN unless both operands are. 
- // - // fmin(undef, x) -> x - if (isa<UndefValue>(Arg0)) - return ReplaceInstUsesWith(CI, Arg1); - - // fmin(x, undef) -> x - if (isa<UndefValue>(Arg1)) - return ReplaceInstUsesWith(CI, Arg0); - - Value *X = nullptr; - Value *Y = nullptr; - if (II->getIntrinsicID() == Intrinsic::minnum) { - // fmin(x, fmin(x, y)) -> fmin(x, y) - // fmin(y, fmin(x, y)) -> fmin(x, y) - if (match(Arg1, m_FMin(m_Value(X), m_Value(Y)))) { - if (Arg0 == X || Arg0 == Y) - return ReplaceInstUsesWith(CI, Arg1); - } - - // fmin(fmin(x, y), x) -> fmin(x, y) - // fmin(fmin(x, y), y) -> fmin(x, y) - if (match(Arg0, m_FMin(m_Value(X), m_Value(Y)))) { - if (Arg1 == X || Arg1 == Y) - return ReplaceInstUsesWith(CI, Arg0); - } - - // TODO: fmin(nnan x, inf) -> x - // TODO: fmin(nnan ninf x, flt_max) -> x - if (C1 && C1->isInfinity()) { - // fmin(x, -inf) -> -inf - if (C1->isNegative()) - return ReplaceInstUsesWith(CI, Arg1); - } - } else { - assert(II->getIntrinsicID() == Intrinsic::maxnum); - // fmax(x, fmax(x, y)) -> fmax(x, y) - // fmax(y, fmax(x, y)) -> fmax(x, y) - if (match(Arg1, m_FMax(m_Value(X), m_Value(Y)))) { - if (Arg0 == X || Arg0 == Y) - return ReplaceInstUsesWith(CI, Arg1); - } - - // fmax(fmax(x, y), x) -> fmax(x, y) - // fmax(fmax(x, y), y) -> fmax(x, y) - if (match(Arg0, m_FMax(m_Value(X), m_Value(Y)))) { - if (Arg1 == X || Arg1 == Y) - return ReplaceInstUsesWith(CI, Arg0); - } - - // TODO: fmax(nnan x, -inf) -> x - // TODO: fmax(nnan ninf x, -flt_max) -> x - if (C1 && C1->isInfinity()) { - // fmax(x, inf) -> inf - if (!C1->isNegative()) - return ReplaceInstUsesWith(CI, Arg1); - } - } + if (Value *V = simplifyMinnumMaxnum(*II)) + return replaceInstUsesWith(*II, V); break; } case Intrinsic::ppc_altivec_lvx: @@ -1041,19 +1556,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; - case Intrinsic::x86_sse_storeu_ps: - case Intrinsic::x86_sse2_storeu_pd: - case Intrinsic::x86_sse2_storeu_dq: - // Turn X86 storeu -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, AC, DT) >= - 16) { - Type *OpPtrTy = - PointerType::getUnqual(II->getArgOperand(1)->getType()); - Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), OpPtrTy); - return new StoreInst(II->getArgOperand(1), Ptr); - } - break; - case Intrinsic::x86_vcvtph2ps_128: case Intrinsic::x86_vcvtph2ps_256: { auto Arg = II->getArgOperand(0); @@ -1070,12 +1572,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Constant folding: Convert to generic half to single conversion. if (isa<ConstantAggregateZero>(Arg)) - return ReplaceInstUsesWith(*II, ConstantAggregateZero::get(RetType)); + return replaceInstUsesWith(*II, ConstantAggregateZero::get(RetType)); if (isa<ConstantDataVector>(Arg)) { auto VectorHalfAsShorts = Arg; if (RetWidth < ArgWidth) { - SmallVector<int, 8> SubVecMask; + SmallVector<uint32_t, 8> SubVecMask; for (unsigned i = 0; i != RetWidth; ++i) SubVecMask.push_back((int)i); VectorHalfAsShorts = Builder->CreateShuffleVector( @@ -1087,7 +1589,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { auto VectorHalfs = Builder->CreateBitCast(VectorHalfAsShorts, VectorHalfType); auto VectorFloats = Builder->CreateFPExt(VectorHalfs, RetType); - return ReplaceInstUsesWith(*II, VectorFloats); + return replaceInstUsesWith(*II, VectorFloats); } // We only use the lowest lanes of the argument. 
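For context on the minnum/maxnum hunk just above: the removed folds are re-homed in the shared simplifyMinnumMaxnum() helper, but the identities themselves are unchanged and rest on IEEE-754 minNum/maxNum semantics, where a quiet NaN operand is ignored and an infinity of the right sign absorbs. A minimal standalone check of those identities in plain C++ (std::fmin follows the same semantics; this is not the LLVM code itself):

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  double x = 2.5, y = 7.0;
  double nan = std::numeric_limits<double>::quiet_NaN();
  double ninf = -std::numeric_limits<double>::infinity();

  // fmin(x, x) -> x
  assert(std::fmin(x, x) == x);
  // fmin(x, nan) -> x and fmax(x, nan) -> x: the quiet NaN operand is ignored.
  assert(std::fmin(x, nan) == x);
  assert(std::fmax(x, nan) == x);
  // fmin(x, fmin(x, y)) -> fmin(x, y): repeated operands are redundant.
  assert(std::fmin(x, std::fmin(x, y)) == std::fmin(x, y));
  // fmin(x, -inf) -> -inf: negative infinity absorbs under minNum.
  assert(std::fmin(x, ninf) == ninf);
  return 0;
}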
@@ -1117,6 +1619,107 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_mmx_pmovmskb: + case Intrinsic::x86_sse_movmsk_ps: + case Intrinsic::x86_sse2_movmsk_pd: + case Intrinsic::x86_sse2_pmovmskb_128: + case Intrinsic::x86_avx_movmsk_pd_256: + case Intrinsic::x86_avx_movmsk_ps_256: + case Intrinsic::x86_avx2_pmovmskb: { + if (Value *V = simplifyX86movmsk(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; + } + + case Intrinsic::x86_sse_comieq_ss: + case Intrinsic::x86_sse_comige_ss: + case Intrinsic::x86_sse_comigt_ss: + case Intrinsic::x86_sse_comile_ss: + case Intrinsic::x86_sse_comilt_ss: + case Intrinsic::x86_sse_comineq_ss: + case Intrinsic::x86_sse_ucomieq_ss: + case Intrinsic::x86_sse_ucomige_ss: + case Intrinsic::x86_sse_ucomigt_ss: + case Intrinsic::x86_sse_ucomile_ss: + case Intrinsic::x86_sse_ucomilt_ss: + case Intrinsic::x86_sse_ucomineq_ss: + case Intrinsic::x86_sse2_comieq_sd: + case Intrinsic::x86_sse2_comige_sd: + case Intrinsic::x86_sse2_comigt_sd: + case Intrinsic::x86_sse2_comile_sd: + case Intrinsic::x86_sse2_comilt_sd: + case Intrinsic::x86_sse2_comineq_sd: + case Intrinsic::x86_sse2_ucomieq_sd: + case Intrinsic::x86_sse2_ucomige_sd: + case Intrinsic::x86_sse2_ucomigt_sd: + case Intrinsic::x86_sse2_ucomile_sd: + case Intrinsic::x86_sse2_ucomilt_sd: + case Intrinsic::x86_sse2_ucomineq_sd: { + // These intrinsics only demand the 0th element of their input vectors. If + // we can simplify the input based on that, do so now. + bool MadeChange = false; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + unsigned VWidth = Arg0->getType()->getVectorNumElements(); + if (Value *V = SimplifyDemandedVectorEltsLow(Arg0, VWidth, 1)) { + II->setArgOperand(0, V); + MadeChange = true; + } + if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { + II->setArgOperand(1, V); + MadeChange = true; + } + if (MadeChange) + return II; + break; + } + + case Intrinsic::x86_sse_add_ss: + case Intrinsic::x86_sse_sub_ss: + case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse_div_ss: + case Intrinsic::x86_sse_min_ss: + case Intrinsic::x86_sse_max_ss: + case Intrinsic::x86_sse_cmp_ss: + case Intrinsic::x86_sse2_add_sd: + case Intrinsic::x86_sse2_sub_sd: + case Intrinsic::x86_sse2_mul_sd: + case Intrinsic::x86_sse2_div_sd: + case Intrinsic::x86_sse2_min_sd: + case Intrinsic::x86_sse2_max_sd: + case Intrinsic::x86_sse2_cmp_sd: { + // These intrinsics only demand the lowest element of the second input + // vector. + Value *Arg1 = II->getArgOperand(1); + unsigned VWidth = Arg1->getType()->getVectorNumElements(); + if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { + II->setArgOperand(1, V); + return II; + } + break; + } + + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: { + // These intrinsics demand the upper elements of the first input vector and + // the lowest element of the second input vector. + bool MadeChange = false; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + unsigned VWidth = Arg0->getType()->getVectorNumElements(); + if (Value *V = SimplifyDemandedVectorEltsHigh(Arg0, VWidth, VWidth - 1)) { + II->setArgOperand(0, V); + MadeChange = true; + } + if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { + II->setArgOperand(1, V); + MadeChange = true; + } + if (MadeChange) + return II; + break; + } + // Constant fold ashr( <A x Bi>, Ci ). // Constant fold lshr( <A x Bi>, Ci ). // Constant fold shl( <A x Bi>, Ci ). 
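The new cases above use SimplifyDemandedVectorEltsLow/High to tell the generic demanded-elements machinery that these scalar SSE intrinsics only read lane 0 of one operand (and, for round_ss/sd, only the upper lanes of the other). A rough scalar model of why that is safe for the add_ss-style operations, using plain arrays and a hypothetical helper addSS rather than the real intrinsic:

#include <array>
#include <cassert>

// Scalar-lane semantics of an "add_ss"-style operation: lane 0 becomes
// a[0] + b[0]; lanes 1..3 pass through from a. Only lane 0 of b is ever
// read, which is what SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)
// expresses for the intrinsics listed above.
static std::array<float, 4> addSS(std::array<float, 4> a,
                                  const std::array<float, 4> &b) {
  a[0] += b[0];
  return a;
}

int main() {
  std::array<float, 4> a = {1, 2, 3, 4};
  std::array<float, 4> b1 = {10, -1, -2, -3};
  std::array<float, 4> b2 = {10, 99, 98, 97};  // upper lanes differ
  // The result does not depend on the undemanded lanes of b.
  assert(addSS(a, b1) == addSS(a, b2));
  return 0;
}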
@@ -1136,8 +1739,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_pslli_d: case Intrinsic::x86_avx2_pslli_q: case Intrinsic::x86_avx2_pslli_w: - if (Value *V = SimplifyX86immshift(*II, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86immShift(*II, *Builder)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_sse2_psra_d: @@ -1156,8 +1759,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_psll_d: case Intrinsic::x86_avx2_psll_q: case Intrinsic::x86_avx2_psll_w: { - if (Value *V = SimplifyX86immshift(*II, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86immShift(*II, *Builder)) + return replaceInstUsesWith(*II, V); // SSE2/AVX2 uses only the first 64-bits of the 128-bit vector // operand to compute the shift amount. @@ -1173,35 +1776,23 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::x86_avx2_pmovsxbd: - case Intrinsic::x86_avx2_pmovsxbq: - case Intrinsic::x86_avx2_pmovsxbw: - case Intrinsic::x86_avx2_pmovsxdq: - case Intrinsic::x86_avx2_pmovsxwd: - case Intrinsic::x86_avx2_pmovsxwq: - if (Value *V = SimplifyX86extend(*II, *Builder, true)) - return ReplaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_sse41_pmovzxbd: - case Intrinsic::x86_sse41_pmovzxbq: - case Intrinsic::x86_sse41_pmovzxbw: - case Intrinsic::x86_sse41_pmovzxdq: - case Intrinsic::x86_sse41_pmovzxwd: - case Intrinsic::x86_sse41_pmovzxwq: - case Intrinsic::x86_avx2_pmovzxbd: - case Intrinsic::x86_avx2_pmovzxbq: - case Intrinsic::x86_avx2_pmovzxbw: - case Intrinsic::x86_avx2_pmovzxdq: - case Intrinsic::x86_avx2_pmovzxwd: - case Intrinsic::x86_avx2_pmovzxwq: - if (Value *V = SimplifyX86extend(*II, *Builder, false)) - return ReplaceInstUsesWith(*II, V); + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + if (Value *V = simplifyX86varShift(*II, *Builder)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_sse41_insertps: - if (Value *V = SimplifyX86insertps(*II, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86insertps(*II, *Builder)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_sse4a_extrq: { @@ -1223,19 +1814,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { : nullptr; // Attempt to simplify to a constant, shuffle vector or EXTRQI call. - if (Value *V = SimplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) + return replaceInstUsesWith(*II, V); // EXTRQ only uses the lowest 64-bits of the first 128-bit vector // operands and the lowest 16-bits of the second. 
+ bool MadeChange = false; if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) { II->setArgOperand(0, V); - return II; + MadeChange = true; } if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 2)) { II->setArgOperand(1, V); - return II; + MadeChange = true; } + if (MadeChange) + return II; break; } @@ -1252,8 +1846,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { ConstantInt *CIIndex = dyn_cast<ConstantInt>(II->getArgOperand(2)); // Attempt to simplify to a constant or shuffle vector. - if (Value *V = SimplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, *Builder)) + return replaceInstUsesWith(*II, V); // EXTRQI only uses the lowest 64-bits of the first 128-bit vector // operand. @@ -1281,11 +1875,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Attempt to simplify to a constant, shuffle vector or INSERTQI call. if (CI11) { - APInt V11 = CI11->getValue(); + const APInt &V11 = CI11->getValue(); APInt Len = V11.zextOrTrunc(6); APInt Idx = V11.lshr(8).zextOrTrunc(6); - if (Value *V = SimplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) + return replaceInstUsesWith(*II, V); } // INSERTQ only uses the lowest 64-bits of the first 128-bit vector @@ -1317,21 +1911,23 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (CILength && CIIndex) { APInt Len = CILength->getValue().zextOrTrunc(6); APInt Idx = CIIndex->getValue().zextOrTrunc(6); - if (Value *V = SimplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, *Builder)) + return replaceInstUsesWith(*II, V); } // INSERTQI only uses the lowest 64-bits of the first two 128-bit vector // operands. + bool MadeChange = false; if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) { II->setArgOperand(0, V); - return II; + MadeChange = true; } - if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 1)) { II->setArgOperand(1, V); - return II; + MadeChange = true; } + if (MadeChange) + return II; break; } @@ -1352,143 +1948,87 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // fold (blend A, A, Mask) -> A if (Op0 == Op1) - return ReplaceInstUsesWith(CI, Op0); + return replaceInstUsesWith(CI, Op0); // Zero Mask - select 1st argument. if (isa<ConstantAggregateZero>(Mask)) - return ReplaceInstUsesWith(CI, Op0); + return replaceInstUsesWith(CI, Op0); // Constant Mask - select 1st/2nd argument lane based on top bit of mask. - if (auto C = dyn_cast<ConstantDataVector>(Mask)) { - auto Tyi1 = Builder->getInt1Ty(); - auto SelectorType = cast<VectorType>(Mask->getType()); - auto EltTy = SelectorType->getElementType(); - unsigned Size = SelectorType->getNumElements(); - unsigned BitWidth = - EltTy->isFloatTy() - ? 32 - : (EltTy->isDoubleTy() ? 
64 : EltTy->getIntegerBitWidth()); - assert((BitWidth == 64 || BitWidth == 32 || BitWidth == 8) && - "Wrong arguments for variable blend intrinsic"); - SmallVector<Constant *, 32> Selectors; - for (unsigned I = 0; I < Size; ++I) { - // The intrinsics only read the top bit - uint64_t Selector; - if (BitWidth == 8) - Selector = C->getElementAsInteger(I); - else - Selector = C->getElementAsAPFloat(I).bitcastToAPInt().getZExtValue(); - Selectors.push_back(ConstantInt::get(Tyi1, Selector >> (BitWidth - 1))); - } - auto NewSelector = ConstantVector::get(Selectors); + if (auto *ConstantMask = dyn_cast<ConstantDataVector>(Mask)) { + Constant *NewSelector = getNegativeIsTrueBoolVec(ConstantMask); return SelectInst::Create(NewSelector, Op1, Op0, "blendv"); } break; } case Intrinsic::x86_ssse3_pshuf_b_128: - case Intrinsic::x86_avx2_pshuf_b: { - // Turn pshufb(V1,mask) -> shuffle(V1,Zero,mask) if mask is a constant. - auto *V = II->getArgOperand(1); - auto *VTy = cast<VectorType>(V->getType()); - unsigned NumElts = VTy->getNumElements(); - assert((NumElts == 16 || NumElts == 32) && - "Unexpected number of elements in shuffle mask!"); - // Initialize the resulting shuffle mask to all zeroes. - uint32_t Indexes[32] = {0}; - - if (auto *Mask = dyn_cast<ConstantDataVector>(V)) { - // Each byte in the shuffle control mask forms an index to permute the - // corresponding byte in the destination operand. - for (unsigned I = 0; I < NumElts; ++I) { - int8_t Index = Mask->getElementAsInteger(I); - // If the most significant bit (bit[7]) of each byte of the shuffle - // control mask is set, then zero is written in the result byte. - // The zero vector is in the right-hand side of the resulting - // shufflevector. - - // The value of each index is the least significant 4 bits of the - // shuffle control byte. - Indexes[I] = (Index < 0) ? NumElts : Index & 0xF; - } - } else if (!isa<ConstantAggregateZero>(V)) - break; - - // The value of each index for the high 128-bit lane is the least - // significant 4 bits of the respective shuffle control byte. - for (unsigned I = 16; I < NumElts; ++I) - Indexes[I] += I & 0xF0; - - auto NewC = ConstantDataVector::get(V->getContext(), - makeArrayRef(Indexes, NumElts)); - auto V1 = II->getArgOperand(0); - auto V2 = Constant::getNullValue(II->getType()); - auto Shuffle = Builder->CreateShuffleVector(V1, V2, NewC); - return ReplaceInstUsesWith(CI, Shuffle); - } + case Intrinsic::x86_avx2_pshuf_b: + if (Value *V = simplifyX86pshufb(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; case Intrinsic::x86_avx_vpermilvar_ps: case Intrinsic::x86_avx_vpermilvar_ps_256: case Intrinsic::x86_avx_vpermilvar_pd: - case Intrinsic::x86_avx_vpermilvar_pd_256: { - // Convert vpermil* to shufflevector if the mask is constant. - Value *V = II->getArgOperand(1); - unsigned Size = cast<VectorType>(V->getType())->getNumElements(); - assert(Size == 8 || Size == 4 || Size == 2); - uint32_t Indexes[8]; - if (auto C = dyn_cast<ConstantDataVector>(V)) { - // The intrinsics only read one or two bits, clear the rest. 
- for (unsigned I = 0; I < Size; ++I) { - uint32_t Index = C->getElementAsInteger(I) & 0x3; - if (II->getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd || - II->getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) - Index >>= 1; - Indexes[I] = Index; - } - } else if (isa<ConstantAggregateZero>(V)) { - for (unsigned I = 0; I < Size; ++I) - Indexes[I] = 0; - } else { - break; - } - // The _256 variants are a bit trickier since the mask bits always index - // into the corresponding 128 half. In order to convert to a generic - // shuffle, we have to make that explicit. - if (II->getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_ps_256 || - II->getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) { - for (unsigned I = Size / 2; I < Size; ++I) - Indexes[I] += Size / 2; - } - auto NewC = - ConstantDataVector::get(V->getContext(), makeArrayRef(Indexes, Size)); - auto V1 = II->getArgOperand(0); - auto V2 = UndefValue::get(V1->getType()); - auto Shuffle = Builder->CreateShuffleVector(V1, V2, NewC); - return ReplaceInstUsesWith(CI, Shuffle); - } + case Intrinsic::x86_avx_vpermilvar_pd_256: + if (Value *V = simplifyX86vpermilvar(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::x86_avx2_permd: + case Intrinsic::x86_avx2_permps: + if (Value *V = simplifyX86vpermv(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; case Intrinsic::x86_avx_vperm2f128_pd_256: case Intrinsic::x86_avx_vperm2f128_ps_256: case Intrinsic::x86_avx_vperm2f128_si_256: case Intrinsic::x86_avx2_vperm2i128: - if (Value *V = SimplifyX86vperm2(*II, *Builder)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86vperm2(*II, *Builder)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::x86_avx_maskload_ps: + case Intrinsic::x86_avx_maskload_pd: + case Intrinsic::x86_avx_maskload_ps_256: + case Intrinsic::x86_avx_maskload_pd_256: + case Intrinsic::x86_avx2_maskload_d: + case Intrinsic::x86_avx2_maskload_q: + case Intrinsic::x86_avx2_maskload_d_256: + case Intrinsic::x86_avx2_maskload_q_256: + if (Instruction *I = simplifyX86MaskedLoad(*II, *this)) + return I; + break; + + case Intrinsic::x86_sse2_maskmov_dqu: + case Intrinsic::x86_avx_maskstore_ps: + case Intrinsic::x86_avx_maskstore_pd: + case Intrinsic::x86_avx_maskstore_ps_256: + case Intrinsic::x86_avx_maskstore_pd_256: + case Intrinsic::x86_avx2_maskstore_d: + case Intrinsic::x86_avx2_maskstore_q: + case Intrinsic::x86_avx2_maskstore_d_256: + case Intrinsic::x86_avx2_maskstore_q_256: + if (simplifyX86MaskedStore(*II, *this)) + return nullptr; break; case Intrinsic::x86_xop_vpcomb: case Intrinsic::x86_xop_vpcomd: case Intrinsic::x86_xop_vpcomq: case Intrinsic::x86_xop_vpcomw: - if (Value *V = SimplifyX86vpcom(*II, *Builder, true)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86vpcom(*II, *Builder, true)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_xop_vpcomub: case Intrinsic::x86_xop_vpcomud: case Intrinsic::x86_xop_vpcomuq: case Intrinsic::x86_xop_vpcomuw: - if (Value *V = SimplifyX86vpcom(*II, *Builder, false)) - return ReplaceInstUsesWith(*II, V); + if (Value *V = simplifyX86vpcom(*II, *Builder, false)) + return replaceInstUsesWith(*II, V); break; case Intrinsic::ppc_altivec_vperm: @@ -1585,7 +2125,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Handle mul by zero first: if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1)) { - return ReplaceInstUsesWith(CI, ConstantAggregateZero::get(II->getType())); + return 
replaceInstUsesWith(CI, ConstantAggregateZero::get(II->getType())); } // Check for constant LHS & RHS - in this case we just simplify. @@ -1597,7 +2137,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { CV0 = ConstantExpr::getIntegerCast(CV0, NewVT, /*isSigned=*/!Zext); CV1 = ConstantExpr::getIntegerCast(CV1, NewVT, /*isSigned=*/!Zext); - return ReplaceInstUsesWith(CI, ConstantExpr::getMul(CV0, CV1)); + return replaceInstUsesWith(CI, ConstantExpr::getMul(CV0, CV1)); } // Couldn't simplify - canonicalize constant to the RHS. @@ -1615,7 +2155,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::AMDGPU_rcp: { + case Intrinsic::amdgcn_rcp: { if (const ConstantFP *C = dyn_cast<ConstantFP>(II->getArgOperand(0))) { const APFloat &ArgVal = C->getValueAPF(); APFloat Val(ArgVal.getSemantics(), 1.0); @@ -1624,18 +2164,43 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Only do this if it was exact and therefore not dependent on the // rounding mode. if (Status == APFloat::opOK) - return ReplaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Val)); + return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Val)); } break; } + case Intrinsic::amdgcn_frexp_mant: + case Intrinsic::amdgcn_frexp_exp: { + Value *Src = II->getArgOperand(0); + if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) { + int Exp; + APFloat Significand = frexp(C->getValueAPF(), Exp, + APFloat::rmNearestTiesToEven); + + if (II->getIntrinsicID() == Intrinsic::amdgcn_frexp_mant) { + return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), + Significand)); + } + + // Match instruction special case behavior. + if (Exp == APFloat::IEK_NaN || Exp == APFloat::IEK_Inf) + Exp = 0; + + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Exp)); + } + + if (isa<UndefValue>(Src)) + return replaceInstUsesWith(CI, UndefValue::get(II->getType())); + + break; + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. if (IntrinsicInst *SS = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { if (SS->getIntrinsicID() == Intrinsic::stacksave) { if (&*++SS->getIterator() == II) - return EraseInstFromFunction(CI); + return eraseInstFromFunction(CI); } } @@ -1653,8 +2218,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(BCI)) { // If there is a stackrestore below this one, remove this one. if (II->getIntrinsicID() == Intrinsic::stackrestore) - return EraseInstFromFunction(CI); - // Otherwise, ignore the intrinsic. + return eraseInstFromFunction(CI); + + // Bail if we cross over an intrinsic with side effects, such as + // llvm.stacksave, llvm.read_register, or llvm.setjmp. + if (II->mayHaveSideEffects()) { + CannotRemove = true; + break; + } } else { // If we found a non-intrinsic call, we can't remove the stack // restore. @@ -1668,42 +2239,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // are no allocas or calls between the restore and the return, nuke the // restore. if (!CannotRemove && (isa<ReturnInst>(TI) || isa<ResumeInst>(TI))) - return EraseInstFromFunction(CI); + return eraseInstFromFunction(CI); break; } - case Intrinsic::lifetime_start: { - // Remove trivially empty lifetime_start/end ranges, i.e. a start - // immediately followed by an end (ignoring debuginfo or other - // lifetime markers in between). 
- BasicBlock::iterator BI = II->getIterator(), BE = II->getParent()->end(); - for (++BI; BI != BE; ++BI) { - if (IntrinsicInst *LTE = dyn_cast<IntrinsicInst>(BI)) { - if (isa<DbgInfoIntrinsic>(LTE) || - LTE->getIntrinsicID() == Intrinsic::lifetime_start) - continue; - if (LTE->getIntrinsicID() == Intrinsic::lifetime_end) { - if (II->getOperand(0) == LTE->getOperand(0) && - II->getOperand(1) == LTE->getOperand(1)) { - EraseInstFromFunction(*LTE); - return EraseInstFromFunction(*II); - } - continue; - } - } - break; - } + case Intrinsic::lifetime_start: + if (removeTriviallyEmptyRange(*II, Intrinsic::lifetime_start, + Intrinsic::lifetime_end, *this)) + return nullptr; break; - } case Intrinsic::assume: { + Value *IIOperand = II->getArgOperand(0); + // Remove an assume if it is immediately followed by an identical assume. + if (match(II->getNextNode(), + m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand)))) + return eraseInstFromFunction(CI); + // Canonicalize assume(a && b) -> assume(a); assume(b); // Note: New assumption intrinsics created here are registered by // the InstCombineIRInserter object. - Value *IIOperand = II->getArgOperand(0), *A, *B, - *AssumeIntrinsic = II->getCalledValue(); + Value *AssumeIntrinsic = II->getCalledValue(), *A, *B; if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) { Builder->CreateCall(AssumeIntrinsic, A, II->getName()); Builder->CreateCall(AssumeIntrinsic, B, II->getName()); - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); } // assume(!(a || b)) -> assume(!a); assume(!b); if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) { @@ -1711,7 +2269,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->getName()); Builder->CreateCall(AssumeIntrinsic, Builder->CreateNot(B), II->getName()); - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); } // assume( (load addr) != null ) -> add 'nonnull' metadata to load @@ -1728,7 +2286,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (isValidAssumeForContext(II, LI, DT)) { MDNode *MD = MDNode::get(II->getContext(), None); LI->setMetadata(LLVMContext::MD_nonnull, MD); - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); } } // TODO: apply nonnull return attributes to calls and invokes @@ -1739,7 +2297,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { APInt KnownZero(1, 0), KnownOne(1, 0); computeKnownBits(IIOperand, KnownZero, KnownOne, 0, II); if (KnownOne.isAllOnesValue()) - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); break; } @@ -1748,46 +2306,38 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // facts about the relocate value, while being careful to // preserve relocation semantics. Value *DerivedPtr = cast<GCRelocateInst>(II)->getDerivedPtr(); - auto *GCRelocateType = cast<PointerType>(II->getType()); // Remove the relocation if unused, note that this check is required // to prevent the cases below from looping forever. if (II->use_empty()) - return EraseInstFromFunction(*II); + return eraseInstFromFunction(*II); // Undef is undef, even after relocation. // TODO: provide a hook for this in GCStrategy. This is clearly legal for // most practical collectors, but there was discussion in the review thread // about whether it was legal for all possible collectors. - if (isa<UndefValue>(DerivedPtr)) { - // gc_relocate is uncasted. Use undef of gc_relocate's type to replace it. 
- return ReplaceInstUsesWith(*II, UndefValue::get(GCRelocateType)); - } + if (isa<UndefValue>(DerivedPtr)) + // Use undef of gc_relocate's type to replace it. + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); - // The relocation of null will be null for most any collector. - // TODO: provide a hook for this in GCStrategy. There might be some weird - // collector this property does not hold for. - if (isa<ConstantPointerNull>(DerivedPtr)) { - // gc_relocate is uncasted. Use null-pointer of gc_relocate's type to replace it. - return ReplaceInstUsesWith(*II, ConstantPointerNull::get(GCRelocateType)); - } + if (auto *PT = dyn_cast<PointerType>(II->getType())) { + // The relocation of null will be null for most any collector. + // TODO: provide a hook for this in GCStrategy. There might be some + // weird collector this property does not hold for. + if (isa<ConstantPointerNull>(DerivedPtr)) + // Use null-pointer of gc_relocate's type to replace it. + return replaceInstUsesWith(*II, ConstantPointerNull::get(PT)); - // isKnownNonNull -> nonnull attribute - if (isKnownNonNullAt(DerivedPtr, II, DT, TLI)) - II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); - - // isDereferenceablePointer -> deref attribute - if (isDereferenceablePointer(DerivedPtr, DL)) { - if (Argument *A = dyn_cast<Argument>(DerivedPtr)) { - uint64_t Bytes = A->getDereferenceableBytes(); - II->addDereferenceableAttr(AttributeSet::ReturnIndex, Bytes); - } + // isKnownNonNull -> nonnull attribute + if (isKnownNonNullAt(DerivedPtr, II, DT)) + II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); } // TODO: bitcast(relocate(p)) -> relocate(bitcast(p)) // Canonicalize on the type from the uses to the defs // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...) + break; } } @@ -1800,8 +2350,8 @@ Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { return visitCallSite(&II); } -/// isSafeToEliminateVarargsCast - If this cast does not affect the value -/// passed through the varargs area, we can eliminate the use of the cast. +/// If this cast does not affect the value passed through the varargs area, we +/// can eliminate the use of the cast. static bool isSafeToEliminateVarargsCast(const CallSite CS, const DataLayout &DL, const CastInst *const CI, @@ -1833,26 +2383,22 @@ static bool isSafeToEliminateVarargsCast(const CallSite CS, return true; } -// Try to fold some different type of calls here. -// Currently we're only working with the checking functions, memcpy_chk, -// mempcpy_chk, memmove_chk, memset_chk, strcpy_chk, stpcpy_chk, strncpy_chk, -// strcat_chk and strncat_chk. Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { if (!CI->getCalledFunction()) return nullptr; auto InstCombineRAUW = [this](Instruction *From, Value *With) { - ReplaceInstUsesWith(*From, With); + replaceInstUsesWith(*From, With); }; LibCallSimplifier Simplifier(DL, TLI, InstCombineRAUW); if (Value *With = Simplifier.optimizeCall(CI)) { ++NumSimplified; - return CI->use_empty() ? CI : ReplaceInstUsesWith(*CI, With); + return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); } return nullptr; } -static IntrinsicInst *FindInitTrampolineFromAlloca(Value *TrampMem) { +static IntrinsicInst *findInitTrampolineFromAlloca(Value *TrampMem) { // Strip off at most one level of pointer casts, looking for an alloca. This // is good enough in practice and simpler than handling any number of casts. 
Value *Underlying = TrampMem->stripPointerCasts(); @@ -1891,7 +2437,7 @@ static IntrinsicInst *FindInitTrampolineFromAlloca(Value *TrampMem) { return InitTrampoline; } -static IntrinsicInst *FindInitTrampolineFromBB(IntrinsicInst *AdjustTramp, +static IntrinsicInst *findInitTrampolineFromBB(IntrinsicInst *AdjustTramp, Value *TrampMem) { // Visit all the previous instructions in the basic block, and try to find a // init.trampoline which has a direct path to the adjust.trampoline. @@ -1913,7 +2459,7 @@ static IntrinsicInst *FindInitTrampolineFromBB(IntrinsicInst *AdjustTramp, // call to llvm.init.trampoline if the call to the trampoline can be optimized // to a direct call to a function. Otherwise return NULL. // -static IntrinsicInst *FindInitTrampoline(Value *Callee) { +static IntrinsicInst *findInitTrampoline(Value *Callee) { Callee = Callee->stripPointerCasts(); IntrinsicInst *AdjustTramp = dyn_cast<IntrinsicInst>(Callee); if (!AdjustTramp || @@ -1922,15 +2468,14 @@ static IntrinsicInst *FindInitTrampoline(Value *Callee) { Value *TrampMem = AdjustTramp->getOperand(0); - if (IntrinsicInst *IT = FindInitTrampolineFromAlloca(TrampMem)) + if (IntrinsicInst *IT = findInitTrampolineFromAlloca(TrampMem)) return IT; - if (IntrinsicInst *IT = FindInitTrampolineFromBB(AdjustTramp, TrampMem)) + if (IntrinsicInst *IT = findInitTrampolineFromBB(AdjustTramp, TrampMem)) return IT; return nullptr; } -// visitCallSite - Improvements for call and invoke instructions. -// +/// Improvements for call and invoke instructions. Instruction *InstCombiner::visitCallSite(CallSite CS) { if (isAllocLikeFn(CS.getInstruction(), TLI)) @@ -1945,8 +2490,9 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { unsigned ArgNo = 0; for (Value *V : CS.args()) { - if (V->getType()->isPointerTy() && !CS.paramHasAttr(ArgNo+1, Attribute::NonNull) && - isKnownNonNullAt(V, CS.getInstruction(), DT, TLI)) + if (V->getType()->isPointerTy() && + !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && + isKnownNonNullAt(V, CS.getInstruction(), DT)) Indices.push_back(ArgNo + 1); ArgNo++; } @@ -1968,7 +2514,16 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { if (!isa<Function>(Callee) && transformConstExprCastCall(CS)) return nullptr; - if (Function *CalleeF = dyn_cast<Function>(Callee)) + if (Function *CalleeF = dyn_cast<Function>(Callee)) { + // Remove the convergent attr on calls when the callee is not convergent. + if (CS.isConvergent() && !CalleeF->isConvergent() && + !CalleeF->isIntrinsic()) { + DEBUG(dbgs() << "Removing convergent attr from instr " + << CS.getInstruction() << "\n"); + CS.setNotConvergent(); + return CS.getInstruction(); + } + // If the call and callee calling conventions don't match, this call must // be unreachable, as the call is undefined. if (CalleeF->getCallingConv() != CS.getCallingConv() && @@ -1983,9 +2538,9 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { // If OldCall does not return void then replaceAllUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!OldCall->getType()->isVoidTy()) - ReplaceInstUsesWith(*OldCall, UndefValue::get(OldCall->getType())); + replaceInstUsesWith(*OldCall, UndefValue::get(OldCall->getType())); if (isa<CallInst>(OldCall)) - return EraseInstFromFunction(*OldCall); + return eraseInstFromFunction(*OldCall); // We cannot remove an invoke, because it would change the CFG, just // change the callee to a null pointer. 
@@ -1993,12 +2548,13 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { Constant::getNullValue(CalleeF->getType())); return nullptr; } + } if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { // If CS does not return void then replaceAllUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!CS.getInstruction()->getType()->isVoidTy()) - ReplaceInstUsesWith(*CS.getInstruction(), + replaceInstUsesWith(*CS.getInstruction(), UndefValue::get(CS.getInstruction()->getType())); if (isa<InvokeInst>(CS.getInstruction())) { @@ -2013,10 +2569,10 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { UndefValue::get(Type::getInt1PtrTy(Callee->getContext())), CS.getInstruction()); - return EraseInstFromFunction(*CS.getInstruction()); + return eraseInstFromFunction(*CS.getInstruction()); } - if (IntrinsicInst *II = FindInitTrampoline(Callee)) + if (IntrinsicInst *II = findInitTrampoline(Callee)) return transformCallThroughTrampoline(CS, II); PointerType *PTy = cast<PointerType>(Callee->getType()); @@ -2048,15 +2604,14 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { Instruction *I = tryOptimizeCall(CI); // If we changed something return the result, etc. Otherwise let // the fallthrough check. - if (I) return EraseInstFromFunction(*I); + if (I) return eraseInstFromFunction(*I); } return Changed ? CS.getInstruction() : nullptr; } -// transformConstExprCastCall - If the callee is a constexpr cast of a function, -// attempt to move the cast to the arguments of the call/invoke. -// +/// If the callee is a constexpr cast of a function, attempt to move the cast to +/// the arguments of the call/invoke. bool InstCombiner::transformConstExprCastCall(CallSite CS) { Function *Callee = dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts()); @@ -2316,7 +2871,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { } if (!Caller->use_empty()) - ReplaceInstUsesWith(*Caller, NV); + replaceInstUsesWith(*Caller, NV); else if (Caller->hasValueHandle()) { if (OldRetTy == NV->getType()) ValueHandleBase::ValueIsRAUWd(Caller, NV); @@ -2326,14 +2881,12 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { ValueHandleBase::ValueIsDeleted(Caller); } - EraseInstFromFunction(*Caller); + eraseInstFromFunction(*Caller); return true; } -// transformCallThroughTrampoline - Turn a call to a function created by -// init_trampoline / adjust_trampoline intrinsic pair into a direct call to the -// underlying function. -// +/// Turn a call to a function created by init_trampoline / adjust_trampoline +/// intrinsic pair into a direct call to the underlying function. Instruction * InstCombiner::transformCallThroughTrampoline(CallSite CS, IntrinsicInst *Tramp) { @@ -2351,8 +2904,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, "transformCallThroughTrampoline called with incorrect CallSite."); Function *NestF =cast<Function>(Tramp->getArgOperand(1)->stripPointerCasts()); - PointerType *NestFPTy = cast<PointerType>(NestF->getType()); - FunctionType *NestFTy = cast<FunctionType>(NestFPTy->getElementType()); + FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType()); const AttributeSet &NestAttrs = NestF->getAttributes(); if (!NestAttrs.isEmpty()) { @@ -2412,7 +2964,8 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, Idx + (Idx >= NestIdx), B)); } - ++Idx, ++I; + ++Idx; + ++I; } while (1); } @@ -2446,7 +2999,8 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Add the original type. 
NewTypes.push_back(*I); - ++Idx, ++I; + ++Idx; + ++I; } while (1); } @@ -2461,15 +3015,18 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, const AttributeSet &NewPAL = AttributeSet::get(FTy->getContext(), NewAttrs); + SmallVector<OperandBundleDef, 1> OpBundles; + CS.getOperandBundlesAsDefs(OpBundles); + Instruction *NewCaller; if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { NewCaller = InvokeInst::Create(NewCallee, II->getNormalDest(), II->getUnwindDest(), - NewArgs); + NewArgs, OpBundles); cast<InvokeInst>(NewCaller)->setCallingConv(II->getCallingConv()); cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); } else { - NewCaller = CallInst::Create(NewCallee, NewArgs); + NewCaller = CallInst::Create(NewCallee, NewArgs, OpBundles); if (cast<CallInst>(Caller)->isTailCall()) cast<CallInst>(NewCaller)->setTailCall(); cast<CallInst>(NewCaller)-> diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 0f01d183b1ad..20556157188f 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -149,9 +149,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, // New is the allocation instruction, pointer typed. AI is the original // allocation instruction, also pointer typed. Thus, cast to use is BitCast. Value *NewCast = AllocaBuilder.CreateBitCast(New, AI.getType(), "tmpcast"); - ReplaceInstUsesWith(AI, NewCast); + replaceInstUsesWith(AI, NewCast); } - return ReplaceInstUsesWith(CI, New); + return replaceInstUsesWith(CI, New); } /// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns @@ -508,7 +508,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { " to avoid cast: " << CI << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); - return ReplaceInstUsesWith(CI, Res); + return replaceInstUsesWith(CI, Res); } // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0), likewise for vector. @@ -532,7 +532,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // If the shift amount is larger than the size of A, then the result is // known to be zero because all the input bits got shifted out. if (Cst->getZExtValue() >= ASize) - return ReplaceInstUsesWith(CI, Constant::getNullValue(DestTy)); + return replaceInstUsesWith(CI, Constant::getNullValue(DestTy)); // Since we're doing an lshr and a zero extend, and know that the shift // amount is smaller than ASize, it is always safe to do the shift in A's @@ -606,7 +606,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, In = Builder->CreateXor(In, One, In->getName() + ".not"); } - return ReplaceInstUsesWith(CI, In); + return replaceInstUsesWith(CI, In); } // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. 
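The context line above summarizes the transformZExtICmp fold this hunk touches: when X is known to have only its low bit possibly set, zext(X == 0) is exactly X ^ 1, so the compare and the extension collapse into a single xor. A quick exhaustive check over the two possible values of X (plain C++, not the InstCombine code):

#include <cassert>
#include <cstdint>

int main() {
  // X has only the low bit set, so X is 0 or 1.
  for (uint32_t X : {0u, 1u}) {
    uint32_t ZextCmp = (X == 0) ? 1u : 0u;  // zext i1 (icmp eq X, 0) to i32
    assert(ZextCmp == (X ^ 1u));            // equivalent xor form
  }
  return 0;
}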
@@ -636,7 +636,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, Constant *Res = ConstantInt::get(Type::getInt1Ty(CI.getContext()), isNE); Res = ConstantExpr::getZExt(Res, CI.getType()); - return ReplaceInstUsesWith(CI, Res); + return replaceInstUsesWith(CI, Res); } uint32_t ShAmt = KnownZeroMask.logBase2(); @@ -654,7 +654,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, } if (CI.getType() == In->getType()) - return ReplaceInstUsesWith(CI, In); + return replaceInstUsesWith(CI, In); return CastInst::CreateIntegerCast(In, CI.getType(), false/*ZExt*/); } } @@ -694,7 +694,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, if (ICI->getPredicate() == ICmpInst::ICMP_EQ) Result = Builder->CreateXor(Result, ConstantInt::get(ITy, 1)); Result->takeName(ICI); - return ReplaceInstUsesWith(CI, Result); + return replaceInstUsesWith(CI, Result); } } } @@ -872,7 +872,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { APInt::getHighBitsSet(DestBitSize, DestBitSize-SrcBitsKept), 0, &CI)) - return ReplaceInstUsesWith(CI, Res); + return replaceInstUsesWith(CI, Res); // We need to emit an AND to clear the high bits. Constant *C = ConstantInt::get(Res->getType(), @@ -986,7 +986,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { if (Pred == ICmpInst::ICMP_SGT) In = Builder->CreateNot(In, In->getName()+".not"); - return ReplaceInstUsesWith(CI, In); + return replaceInstUsesWith(CI, In); } } @@ -1009,7 +1009,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { Value *V = Pred == ICmpInst::ICMP_NE ? ConstantInt::getAllOnesValue(CI.getType()) : ConstantInt::getNullValue(CI.getType()); - return ReplaceInstUsesWith(CI, V); + return replaceInstUsesWith(CI, V); } if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) { @@ -1041,7 +1041,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { } if (CI.getType() == In->getType()) - return ReplaceInstUsesWith(CI, In); + return replaceInstUsesWith(CI, In); return CastInst::CreateIntegerCast(In, CI.getType(), true/*SExt*/); } } @@ -1137,7 +1137,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { ComputeSignBit(Src, KnownZero, KnownOne, 0, &CI); if (KnownZero) { Value *ZExt = Builder->CreateZExt(Src, DestTy); - return ReplaceInstUsesWith(CI, ZExt); + return replaceInstUsesWith(CI, ZExt); } // Attempt to extend the entire input expression tree to the destination @@ -1158,7 +1158,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // If the high bits are already filled with sign bit, just replace this // cast with the result. if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize) - return ReplaceInstUsesWith(CI, Res); + return replaceInstUsesWith(CI, Res); // We need to emit a shl + ashr to do the sign extend. 
Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); @@ -1400,8 +1400,11 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { Function *Overload = Intrinsic::getDeclaration( CI.getModule(), II->getIntrinsicID(), IntrinsicType); + SmallVector<OperandBundleDef, 1> OpBundles; + II->getOperandBundlesAsDefs(OpBundles); + Value *Args[] = { InnerTrunc }; - return CallInst::Create(Overload, Args, II->getName()); + return CallInst::Create(Overload, Args, OpBundles, II->getName()); } } } @@ -1451,7 +1454,7 @@ Instruction *InstCombiner::FoldItoFPtoI(Instruction &FI) { if (FITy->getScalarSizeInBits() < SrcTy->getScalarSizeInBits()) return new TruncInst(SrcI, FITy); if (SrcTy == FITy) - return ReplaceInstUsesWith(FI, SrcI); + return replaceInstUsesWith(FI, SrcI); return new BitCastInst(SrcI, FITy); } return nullptr; @@ -1796,7 +1799,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // Get rid of casts from one type to the same type. These are useless and can // be replaced by the operand. if (DestTy == Src->getType()) - return ReplaceInstUsesWith(CI, Src); + return replaceInstUsesWith(CI, Src); if (PointerType *DstPTy = dyn_cast<PointerType>(DestTy)) { PointerType *SrcPTy = cast<PointerType>(SrcTy); @@ -1811,6 +1814,13 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { if (Instruction *V = PromoteCastOfAllocation(CI, *AI)) return V; + // When the type pointed to is not sized the cast cannot be + // turned into a gep. + Type *PointeeType = + cast<PointerType>(Src->getType()->getScalarType())->getElementType(); + if (!PointeeType->isSized()) + return nullptr; + // If the source and destination are pointers, and this cast is equivalent // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. // This can enhance SROA and other transforms that want type-safe pointers. @@ -1854,7 +1864,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // assemble the elements of the vector manually. Try to rip the code out // and replace it with insertelements. if (Value *V = optimizeIntegerToVectorInsertions(CI, *this)) - return ReplaceInstUsesWith(CI, V); + return replaceInstUsesWith(CI, V); } } diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index d9311a343ead..bfd73f4bbac5 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -13,18 +13,19 @@ #include "InstCombineInternal.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Analysis/TargetLibraryInfo.h" using namespace llvm; using namespace PatternMatch; @@ -55,8 +56,8 @@ static bool HasAddOverflow(ConstantInt *Result, return Result->getValue().slt(In1->getValue()); } -/// AddWithOverflow - Compute Result = In1+In2, returning true if the result -/// overflowed for this type. +/// Compute Result = In1+In2, returning true if the result overflowed for this +/// type. 
static bool AddWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getAdd(In1, In2); @@ -90,8 +91,8 @@ static bool HasSubOverflow(ConstantInt *Result, return Result->getValue().sgt(In1->getValue()); } -/// SubWithOverflow - Compute Result = In1-In2, returning true if the result -/// overflowed for this type. +/// Compute Result = In1-In2, returning true if the result overflowed for this +/// type. static bool SubWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getSub(In1, In2); @@ -113,13 +114,21 @@ static bool SubWithOverflow(Constant *&Result, Constant *In1, IsSigned); } -/// isSignBitCheck - Given an exploded icmp instruction, return true if the -/// comparison only checks the sign bit. If it only checks the sign bit, set -/// TrueIfSigned if the result of the comparison is true when the input value is -/// signed. -static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS, +/// Given an icmp instruction, return true if any use of this comparison is a +/// branch on sign bit comparison. +static bool isBranchOnSignBitCheck(ICmpInst &I, bool isSignBit) { + for (auto *U : I.users()) + if (isa<BranchInst>(U)) + return isSignBit; + return false; +} + +/// Given an exploded icmp instruction, return true if the comparison only +/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the +/// result of the comparison is true when the input value is signed. +static bool isSignBitCheck(ICmpInst::Predicate Pred, ConstantInt *RHS, bool &TrueIfSigned) { - switch (pred) { + switch (Pred) { case ICmpInst::ICMP_SLT: // True if LHS s< 0 TrueIfSigned = true; return RHS->isZero(); @@ -145,21 +154,21 @@ static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS, /// Returns true if the exploded icmp can be expressed as a signed comparison /// to zero and updates the predicate accordingly. /// The signedness of the comparison is preserved. -static bool isSignTest(ICmpInst::Predicate &pred, const ConstantInt *RHS) { - if (!ICmpInst::isSigned(pred)) +static bool isSignTest(ICmpInst::Predicate &Pred, const ConstantInt *RHS) { + if (!ICmpInst::isSigned(Pred)) return false; if (RHS->isZero()) - return ICmpInst::isRelational(pred); + return ICmpInst::isRelational(Pred); if (RHS->isOne()) { - if (pred == ICmpInst::ICMP_SLT) { - pred = ICmpInst::ICMP_SLE; + if (Pred == ICmpInst::ICMP_SLT) { + Pred = ICmpInst::ICMP_SLE; return true; } } else if (RHS->isAllOnesValue()) { - if (pred == ICmpInst::ICMP_SGT) { - pred = ICmpInst::ICMP_SGE; + if (Pred == ICmpInst::ICMP_SGT) { + Pred = ICmpInst::ICMP_SGE; return true; } } @@ -167,19 +176,18 @@ static bool isSignTest(ICmpInst::Predicate &pred, const ConstantInt *RHS) { return false; } -// isHighOnes - Return true if the constant is of the form 1+0+. -// This is the same as lowones(~X). +/// Return true if the constant is of the form 1+0+. This is the same as +/// lowones(~X). static bool isHighOnes(const ConstantInt *CI) { return (~CI->getValue() + 1).isPowerOf2(); } -/// ComputeSignedMinMaxValuesFromKnownBits - Given a signed integer type and a -/// set of known zero and one bits, compute the maximum and minimum values that -/// could have the specified known zero and known one bits, returning them in -/// min/max. 
-static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero, - const APInt& KnownOne, - APInt& Min, APInt& Max) { +/// Given a signed integer type and a set of known zero and one bits, compute +/// the maximum and minimum values that could have the specified known zero and +/// known one bits, returning them in Min/Max. +static void ComputeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, + const APInt &KnownOne, + APInt &Min, APInt &Max) { assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && KnownZero.getBitWidth() == Min.getBitWidth() && KnownZero.getBitWidth() == Max.getBitWidth() && @@ -197,10 +205,9 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero, } } -// ComputeUnsignedMinMaxValuesFromKnownBits - Given an unsigned integer type and -// a set of known zero and one bits, compute the maximum and minimum values that -// could have the specified known zero and known one bits, returning them in -// min/max. +/// Given an unsigned integer type and a set of known zero and one bits, compute +/// the maximum and minimum values that could have the specified known zero and +/// known one bits, returning them in Min/Max. static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { @@ -216,14 +223,14 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, Max = KnownOne|UnknownBits; } -/// FoldCmpLoadFromIndexedGlobal - Called we see this pattern: +/// This is called when we see this pattern: /// cmp pred (load (gep GV, ...)), cmpcst -/// where GV is a global variable with a constant initializer. Try to simplify -/// this into some simple computation that does not need the load. For example +/// where GV is a global variable with a constant initializer. Try to simplify +/// this into some simple computation that does not need the load. For example /// we can optimize "icmp eq (load (gep "foo", 0, i)), 0" into "icmp eq i, 3". /// /// If AndCst is non-null, then the loaded value is masked with that constant -/// before doing the comparison. This handles cases like "A[i]&4 == 0". +/// before doing the comparison. This handles cases like "A[i]&4 == 0". Instruction *InstCombiner:: FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, ConstantInt *AndCst) { @@ -401,7 +408,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, if (SecondTrueElement != Overdefined) { // None true -> false. if (FirstTrueElement == Undefined) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); Value *FirstTrueIdx = ConstantInt::get(Idx->getType(), FirstTrueElement); @@ -421,7 +428,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, if (SecondFalseElement != Overdefined) { // None false -> true. if (FirstFalseElement == Undefined) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); Value *FirstFalseIdx = ConstantInt::get(Idx->getType(), FirstFalseElement); @@ -492,12 +499,12 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, return nullptr; } -/// EvaluateGEPOffsetExpression - Return a value that can be used to compare -/// the *offset* implied by a GEP to zero. For example, if we have &A[i], we -/// want to return 'i' for "icmp ne i, 0". Note that, in general, indices can -/// be complex, and scales are involved. 
The above expression would also be -/// legal to codegen as "icmp ne (i*4), 0" (assuming A is a pointer to i32). -/// This later form is less amenable to optimization though, and we are allowed +/// Return a value that can be used to compare the *offset* implied by a GEP to +/// zero. For example, if we have &A[i], we want to return 'i' for +/// "icmp ne i, 0". Note that, in general, indices can be complex, and scales +/// are involved. The above expression would also be legal to codegen as +/// "icmp ne (i*4), 0" (assuming A is a pointer to i32). +/// This latter form is less amenable to optimization though, and we are allowed /// to generate the first by knowing that pointer arithmetic doesn't overflow. /// /// If we can't emit an optimized form for this expression, this returns null. @@ -595,8 +602,323 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, return IC.Builder->CreateAdd(VariableIdx, OffsetVal, "offset"); } -/// FoldGEPICmp - Fold comparisons between a GEP instruction and something -/// else. At this point we know that the GEP is on the LHS of the comparison. +/// Returns true if we can rewrite Start as a GEP with pointer Base +/// and some integer offset. The nodes that need to be re-written +/// for this transformation will be added to Explored. +static bool canRewriteGEPAsOffset(Value *Start, Value *Base, + const DataLayout &DL, + SetVector<Value *> &Explored) { + SmallVector<Value *, 16> WorkList(1, Start); + Explored.insert(Base); + + // The following traversal gives us an order which can be used + // when doing the final transformation. Since in the final + // transformation we create the PHI replacement instructions first, + // we don't have to get them in any particular order. + // + // However, for other instructions we will have to traverse the + // operands of an instruction first, which means that we have to + // do a post-order traversal. + while (!WorkList.empty()) { + SetVector<PHINode *> PHIs; + + while (!WorkList.empty()) { + if (Explored.size() >= 100) + return false; + + Value *V = WorkList.back(); + + if (Explored.count(V) != 0) { + WorkList.pop_back(); + continue; + } + + if (!isa<IntToPtrInst>(V) && !isa<PtrToIntInst>(V) && + !isa<GEPOperator>(V) && !isa<PHINode>(V)) + // We've found some value that we can't explore which is different from + // the base. Therefore we can't do this transformation. + return false; + + if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) { + auto *CI = dyn_cast<CastInst>(V); + if (!CI->isNoopCast(DL)) + return false; + + if (Explored.count(CI->getOperand(0)) == 0) + WorkList.push_back(CI->getOperand(0)); + } + + if (auto *GEP = dyn_cast<GEPOperator>(V)) { + // We're limiting the GEP to having one index. This will preserve + // the original pointer type. We could handle more cases in the + // future. + if (GEP->getNumIndices() != 1 || !GEP->isInBounds() || + GEP->getType() != Start->getType()) + return false; + + if (Explored.count(GEP->getOperand(0)) == 0) + WorkList.push_back(GEP->getOperand(0)); + } + + if (WorkList.back() == V) { + WorkList.pop_back(); + // We've finished visiting this node, mark it as such. + Explored.insert(V); + } + + if (auto *PN = dyn_cast<PHINode>(V)) { + // We cannot transform PHIs on unsplittable basic blocks. + if (isa<CatchSwitchInst>(PN->getParent()->getTerminator())) + return false; + Explored.insert(PN); + PHIs.insert(PN); + } + } + + // Explore the PHI nodes further. 
+ for (auto *PN : PHIs) + for (Value *Op : PN->incoming_values()) + if (Explored.count(Op) == 0) + WorkList.push_back(Op); + } + + // Make sure that we can do this. Since we can't insert GEPs in a basic + // block before a PHI node, we can't easily do this transformation if + // we have PHI node users of transformed instructions. + for (Value *Val : Explored) { + for (Value *Use : Val->uses()) { + + auto *PHI = dyn_cast<PHINode>(Use); + auto *Inst = dyn_cast<Instruction>(Val); + + if (Inst == Base || Inst == PHI || !Inst || !PHI || + Explored.count(PHI) == 0) + continue; + + if (PHI->getParent() == Inst->getParent()) + return false; + } + } + return true; +} + +// Sets the appropriate insert point on Builder where we can add +// a replacement Instruction for V (if that is possible). +static void setInsertionPoint(IRBuilder<> &Builder, Value *V, + bool Before = true) { + if (auto *PHI = dyn_cast<PHINode>(V)) { + Builder.SetInsertPoint(&*PHI->getParent()->getFirstInsertionPt()); + return; + } + if (auto *I = dyn_cast<Instruction>(V)) { + if (!Before) + I = &*std::next(I->getIterator()); + Builder.SetInsertPoint(I); + return; + } + if (auto *A = dyn_cast<Argument>(V)) { + // Set the insertion point in the entry block. + BasicBlock &Entry = A->getParent()->getEntryBlock(); + Builder.SetInsertPoint(&*Entry.getFirstInsertionPt()); + return; + } + // Otherwise, this is a constant and we don't need to set a new + // insertion point. + assert(isa<Constant>(V) && "Setting insertion point for unknown value!"); +} + +/// Returns a re-written value of Start as an indexed GEP using Base as a +/// pointer. +static Value *rewriteGEPAsOffset(Value *Start, Value *Base, + const DataLayout &DL, + SetVector<Value *> &Explored) { + // Perform all the substitutions. This is a bit tricky because we can + // have cycles in our use-def chains. + // 1. Create the PHI nodes without any incoming values. + // 2. Create all the other values. + // 3. Add the edges for the PHI nodes. + // 4. Emit GEPs to get the original pointers. + // 5. Remove the original instructions. + Type *IndexType = IntegerType::get( + Base->getContext(), DL.getPointerTypeSizeInBits(Start->getType())); + + DenseMap<Value *, Value *> NewInsts; + NewInsts[Base] = ConstantInt::getNullValue(IndexType); + + // Create the new PHI nodes, without adding any incoming values. + for (Value *Val : Explored) { + if (Val == Base) + continue; + // Create empty phi nodes. This avoids cyclic dependencies when creating + // the remaining instructions. + if (auto *PHI = dyn_cast<PHINode>(Val)) + NewInsts[PHI] = PHINode::Create(IndexType, PHI->getNumIncomingValues(), + PHI->getName() + ".idx", PHI); + } + IRBuilder<> Builder(Base->getContext()); + + // Create all the other instructions. + for (Value *Val : Explored) { + + if (NewInsts.find(Val) != NewInsts.end()) + continue; + + if (auto *CI = dyn_cast<CastInst>(Val)) { + NewInsts[CI] = NewInsts[CI->getOperand(0)]; + continue; + } + if (auto *GEP = dyn_cast<GEPOperator>(Val)) { + Value *Index = NewInsts[GEP->getOperand(1)] ? NewInsts[GEP->getOperand(1)] + : GEP->getOperand(1); + setInsertionPoint(Builder, GEP); + // Indices might need to be sign extended. GEPs will magically do + // this, but we need to do it ourselves here. 
+ if (Index->getType()->getScalarSizeInBits() != + NewInsts[GEP->getOperand(0)]->getType()->getScalarSizeInBits()) { + Index = Builder.CreateSExtOrTrunc( + Index, NewInsts[GEP->getOperand(0)]->getType(), + GEP->getOperand(0)->getName() + ".sext"); + } + + auto *Op = NewInsts[GEP->getOperand(0)]; + if (isa<ConstantInt>(Op) && dyn_cast<ConstantInt>(Op)->isZero()) + NewInsts[GEP] = Index; + else + NewInsts[GEP] = Builder.CreateNSWAdd( + Op, Index, GEP->getOperand(0)->getName() + ".add"); + continue; + } + if (isa<PHINode>(Val)) + continue; + + llvm_unreachable("Unexpected instruction type"); + } + + // Add the incoming values to the PHI nodes. + for (Value *Val : Explored) { + if (Val == Base) + continue; + // All the instructions have been created, we can now add edges to the + // phi nodes. + if (auto *PHI = dyn_cast<PHINode>(Val)) { + PHINode *NewPhi = static_cast<PHINode *>(NewInsts[PHI]); + for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) { + Value *NewIncoming = PHI->getIncomingValue(I); + + if (NewInsts.find(NewIncoming) != NewInsts.end()) + NewIncoming = NewInsts[NewIncoming]; + + NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I)); + } + } + } + + for (Value *Val : Explored) { + if (Val == Base) + continue; + + // Depending on the type, for external users we have to emit + // a GEP or a GEP + ptrtoint. + setInsertionPoint(Builder, Val, false); + + // If required, create an inttoptr instruction for Base. + Value *NewBase = Base; + if (!Base->getType()->isPointerTy()) + NewBase = Builder.CreateBitOrPointerCast(Base, Start->getType(), + Start->getName() + "to.ptr"); + + Value *GEP = Builder.CreateInBoundsGEP( + Start->getType()->getPointerElementType(), NewBase, + makeArrayRef(NewInsts[Val]), Val->getName() + ".ptr"); + + if (!Val->getType()->isPointerTy()) { + Value *Cast = Builder.CreatePointerCast(GEP, Val->getType(), + Val->getName() + ".conv"); + GEP = Cast; + } + Val->replaceAllUsesWith(GEP); + } + + return NewInsts[Start]; +} + +/// Looks through GEPs, IntToPtrInsts and PtrToIntInsts in order to express +/// the input Value as a constant indexed GEP. Returns a pair containing +/// the GEPs Pointer and Index. +static std::pair<Value *, Value *> +getAsConstantIndexedAddress(Value *V, const DataLayout &DL) { + Type *IndexType = IntegerType::get(V->getContext(), + DL.getPointerTypeSizeInBits(V->getType())); + + Constant *Index = ConstantInt::getNullValue(IndexType); + while (true) { + if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { + // We accept only inbouds GEPs here to exclude the possibility of + // overflow. + if (!GEP->isInBounds()) + break; + if (GEP->hasAllConstantIndices() && GEP->getNumIndices() == 1 && + GEP->getType() == V->getType()) { + V = GEP->getOperand(0); + Constant *GEPIndex = static_cast<Constant *>(GEP->getOperand(1)); + Index = ConstantExpr::getAdd( + Index, ConstantExpr::getSExtOrBitCast(GEPIndex, IndexType)); + continue; + } + break; + } + if (auto *CI = dyn_cast<IntToPtrInst>(V)) { + if (!CI->isNoopCast(DL)) + break; + V = CI->getOperand(0); + continue; + } + if (auto *CI = dyn_cast<PtrToIntInst>(V)) { + if (!CI->isNoopCast(DL)) + break; + V = CI->getOperand(0); + continue; + } + break; + } + return {V, Index}; +} + +/// Converts (CMP GEPLHS, RHS) if this change would make RHS a constant. +/// We can look through PHIs, GEPs and casts in order to determine a common base +/// between GEPLHS and RHS. 
+static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, + const DataLayout &DL) { + if (!GEPLHS->hasAllConstantIndices()) + return nullptr; + + Value *PtrBase, *Index; + std::tie(PtrBase, Index) = getAsConstantIndexedAddress(GEPLHS, DL); + + // The set of nodes that will take part in this transformation. + SetVector<Value *> Nodes; + + if (!canRewriteGEPAsOffset(RHS, PtrBase, DL, Nodes)) + return nullptr; + + // We know we can re-write this as + // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) + // Since we've only looked through inbouds GEPs we know that we + // can't have overflow on either side. We can therefore re-write + // this as: + // OFFSET1 cmp OFFSET2 + Value *NewRHS = rewriteGEPAsOffset(RHS, PtrBase, DL, Nodes); + + // RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written + // GEP having PtrBase as the pointer base, and has returned in NewRHS the + // offset. Since Index is the offset of LHS to the base pointer, we will now + // compare the offsets instead of comparing the pointers. + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Index, NewRHS); +} + +/// Fold comparisons between a GEP instruction and something else. At this point +/// we know that the GEP is on the LHS of the comparison. Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I) { @@ -670,12 +992,13 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Value *Cmp = Builder->CreateICmp(ICmpInst::getSignedPredicate(Cond), LOffset, ROffset); - return ReplaceInstUsesWith(I, Cmp); + return replaceInstUsesWith(I, Cmp); } // Otherwise, the base pointers are different and the indices are - // different, bail out. - return nullptr; + // different. Try convert this to an indexed compare by looking through + // PHIs/casts. + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); } // If one of the GEPs has all zero indices, recurse. @@ -706,7 +1029,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, } if (NumDifferences == 0) // SAME GEP? - return ReplaceInstUsesWith(I, // No comparison is needed here. + return replaceInstUsesWith(I, // No comparison is needed here. Builder->getInt1(ICmpInst::isTrueWhenEqual(Cond))); else if (NumDifferences == 1 && GEPsInBounds) { @@ -727,7 +1050,10 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R); } } - return nullptr; + + // Try convert this to an indexed compare by looking through PHIs/casts as a + // last resort. + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); } Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, @@ -802,12 +1128,12 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } Type *CmpTy = CmpInst::makeCmpResultType(Other->getType()); - return ReplaceInstUsesWith( + return replaceInstUsesWith( ICI, ConstantInt::get(CmpTy, !CmpInst::isTrueWhenEqual(ICI.getPredicate()))); } -/// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X". +/// Fold "icmp pred (X+CI), X". 
Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred) { @@ -855,8 +1181,8 @@ Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); } -/// FoldICmpDivCst - Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS -/// and CmpRHS are both known to be integer constants. +/// Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS and CmpRHS are +/// both known to be integer constants. Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, ConstantInt *DivRHS) { ConstantInt *CmpRHS = cast<ConstantInt>(ICI.getOperand(1)); @@ -898,8 +1224,8 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // Get the ICmp opcode ICmpInst::Predicate Pred = ICI.getPredicate(); - /// If the division is known to be exact, then there is no remainder from the - /// divide, so the covered range size is unit, otherwise it is the divisor. + // If the division is known to be exact, then there is no remainder from the + // divide, so the covered range size is unit, otherwise it is the divisor. ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS; // Figure out the interval that is being checked. For example, a comparison @@ -973,46 +1299,46 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, default: llvm_unreachable("Unhandled icmp opcode!"); case ICmpInst::ICMP_EQ: if (LoOverflow && HiOverflow) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, X, LoBound); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, X, HiBound); - return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, + return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, X, LoBound); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, X, HiBound); - return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, + return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_SLT: if (LoOverflow == +1) // Low bound is greater than input range. - return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); if (LoOverflow == -1) // Low bound is less than input range. - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); return new ICmpInst(Pred, X, LoBound); case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); if (HiOverflow == -1) // High bound less than input range. 
- return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); if (Pred == ICmpInst::ICMP_UGT) return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); } } -/// FoldICmpShrCst - Handle "icmp(([al]shr X, cst1), cst2)". +/// Handle "icmp(([al]shr X, cst1), cst2)". Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, ConstantInt *ShAmt) { const APInt &CmpRHSV = cast<ConstantInt>(ICI.getOperand(1))->getValue(); @@ -1077,7 +1403,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero. bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; Constant *Cst = Builder->getInt1(IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); + return replaceInstUsesWith(ICI, Cst); } // Otherwise, check to see if the bits shifted out are known to be zero. @@ -1098,7 +1424,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, return nullptr; } -/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> +/// Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> /// (icmp eq/ne A, Log2(const2/const1)) -> /// (icmp eq/ne A, Log2(const2) - Log2(const1)). Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, @@ -1109,7 +1435,7 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, auto getConstant = [&I, this](bool IsTrue) { if (I.getPredicate() == I.ICMP_NE) IsTrue = !IsTrue; - return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); }; auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { @@ -1118,8 +1444,8 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, return new ICmpInst(Pred, LHS, RHS); }; - APInt AP1 = CI1->getValue(); - APInt AP2 = CI2->getValue(); + const APInt &AP1 = CI1->getValue(); + const APInt &AP2 = CI2->getValue(); // Don't bother doing any work for cases which InstSimplify handles. if (AP2 == 0) @@ -1163,7 +1489,7 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, return getConstant(false); } -/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" -> +/// Handle "(icmp eq/ne (shl const2, A), const1)" -> /// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, ConstantInt *CI1, @@ -1173,7 +1499,7 @@ Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, auto getConstant = [&I, this](bool IsTrue) { if (I.getPredicate() == I.ICMP_NE) IsTrue = !IsTrue; - return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); }; auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { @@ -1182,8 +1508,8 @@ Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, return new ICmpInst(Pred, LHS, RHS); }; - APInt AP1 = CI1->getValue(); - APInt AP2 = CI2->getValue(); + const APInt &AP1 = CI1->getValue(); + const APInt &AP2 = CI2->getValue(); // Don't bother doing any work for cases which InstSimplify handles. 
if (AP2 == 0) @@ -1208,8 +1534,7 @@ Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, return getConstant(false); } -/// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". -/// +/// Handle "icmp (instr, intcst)". Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Instruction *LHSI, ConstantInt *RHS) { @@ -1412,9 +1737,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // As a special case, check to see if this means that the // result is always true or false now. if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); + return replaceInstUsesWith(ICI, Builder->getFalse()); if (ICI.getPredicate() == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); + return replaceInstUsesWith(ICI, Builder->getTrue()); } else { ICI.setOperand(1, NewCst); Constant *NewAndCst; @@ -1674,7 +1999,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (Comp != RHS) {// Comparing against a bit that we know is zero. bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; Constant *Cst = Builder->getInt1(IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); + return replaceInstUsesWith(ICI, Cst); } // If the shift is NUW, then it is just shifting out zeros, no need for an @@ -1764,8 +2089,28 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, break; } - case Instruction::SDiv: case Instruction::UDiv: + if (ConstantInt *DivLHS = dyn_cast<ConstantInt>(LHSI->getOperand(0))) { + Value *X = LHSI->getOperand(1); + const APInt &C1 = RHS->getValue(); + const APInt &C2 = DivLHS->getValue(); + assert(C2 != 0 && "udiv 0, X should have been simplified already."); + // (icmp ugt (udiv C2, X), C1) -> (icmp ule X, C2/(C1+1)) + if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { + assert(!C1.isMaxValue() && + "icmp ugt X, UINT_MAX should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_ULE, X, + ConstantInt::get(X->getType(), C2.udiv(C1 + 1))); + } + // (icmp ult (udiv C2, X), C1) -> (icmp ugt X, C2/C1) + if (ICI.getPredicate() == ICmpInst::ICMP_ULT) { + assert(C1 != 0 && "icmp ult X, 0 should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_UGT, X, + ConstantInt::get(X->getType(), C2.udiv(C1))); + } + } + // fall-through + case Instruction::SDiv: // Fold: icmp pred ([us]div X, C1), C2 -> range test // Fold this div into the comparison, producing a range check. // Determine, based on the divide type, what the range is being @@ -1895,27 +2240,30 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } break; case Instruction::Xor: - // For the xor case, we can xor two constants together, eliminating - // the explicit xor. - if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) { - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - ConstantExpr::getXor(RHS, BOC)); - } else if (RHSV == 0) { - // Replace ((xor A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); + if (BO->hasOneUse()) { + if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) { + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. 
+ return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + ConstantExpr::getXor(RHS, BOC)); + } else if (RHSV == 0) { + // Replace ((xor A, B) != 0) with (A != B) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + BO->getOperand(1)); + } } break; case Instruction::Sub: - // Replace ((sub A, B) != C) with (B != A-C) if A & C are constants. - if (ConstantInt *BOp0C = dyn_cast<ConstantInt>(BO->getOperand(0))) { - if (BO->hasOneUse()) + if (BO->hasOneUse()) { + if (ConstantInt *BOp0C = dyn_cast<ConstantInt>(BO->getOperand(0))) { + // Replace ((sub A, B) != C) with (B != A-C) if A & C are constants. return new ICmpInst(ICI.getPredicate(), BO->getOperand(1), - ConstantExpr::getSub(BOp0C, RHS)); - } else if (RHSV == 0) { - // Replace ((sub A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); + ConstantExpr::getSub(BOp0C, RHS)); + } else if (RHSV == 0) { + // Replace ((sub A, B) != 0) with (A != B) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), + BO->getOperand(1)); + } } break; case Instruction::Or: @@ -1924,7 +2272,16 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { Constant *NotCI = ConstantExpr::getNot(RHS); if (!ConstantExpr::getAnd(BOC, NotCI)->isNullValue()) - return ReplaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); + return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); + + // Comparing if all bits outside of a constant mask are set? + // Replace (X | C) == -1 with (X & ~C) == ~C. + // This removes the -1 constant. + if (BO->hasOneUse() && RHS->isAllOnesValue()) { + Constant *NotBOC = ConstantExpr::getNot(BOC); + Value *And = Builder->CreateAnd(BO->getOperand(0), NotBOC); + return new ICmpInst(ICI.getPredicate(), And, NotBOC); + } } break; @@ -1933,7 +2290,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // If bits are being compared against that are and'd out, then the // comparison can never succeed! if ((RHSV & ~BOC->getValue()) != 0) - return ReplaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); + return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); // If we have ((X & C) == C), turn it into ((X & C) != 0). if (RHS == BOC && RHSV.isPowerOf2()) @@ -2013,11 +2370,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return nullptr; } -/// visitICmpInstWithCastAndCast - Handle icmp (cast x to y), (cast/cst). -/// We only handle extending casts so far. -/// -Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { - const CastInst *LHSCI = cast<CastInst>(ICI.getOperand(0)); +/// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so +/// far. 
+Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICmp) { + const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0)); Value *LHSCIOp = LHSCI->getOperand(0); Type *SrcTy = LHSCIOp->getType(); Type *DestTy = LHSCI->getType(); @@ -2028,7 +2384,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { if (LHSCI->getOpcode() == Instruction::PtrToInt && DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) { Value *RHSOp = nullptr; - if (PtrToIntOperator *RHSC = dyn_cast<PtrToIntOperator>(ICI.getOperand(1))) { + if (auto *RHSC = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { Value *RHSCIOp = RHSC->getOperand(0); if (RHSCIOp->getType()->getPointerAddressSpace() == LHSCIOp->getType()->getPointerAddressSpace()) { @@ -2037,11 +2393,12 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { if (LHSCIOp->getType() != RHSOp->getType()) RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType()); } - } else if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1))) + } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); + } if (RHSOp) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSOp); + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSOp); } // The code below only handles extension cast instructions, so far. @@ -2051,9 +2408,9 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { return nullptr; bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; - bool isSignedCmp = ICI.isSigned(); + bool isSignedCmp = ICmp.isSigned(); - if (CastInst *CI = dyn_cast<CastInst>(ICI.getOperand(1))) { + if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) { // Not an extension from the same type? RHSCIOp = CI->getOperand(0); if (RHSCIOp->getType() != LHSCIOp->getType()) @@ -2065,50 +2422,51 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { return nullptr; // Deal with equality cases early. - if (ICI.isEquality()) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSCIOp); + if (ICmp.isEquality()) + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); // A signed comparison of sign extended values simplifies into a // signed comparison. if (isSignedCmp && isSignedExt) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICI.getUnsignedPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, RHSCIOp); } - // If we aren't dealing with a constant on the RHS, exit early - ConstantInt *CI = dyn_cast<ConstantInt>(ICI.getOperand(1)); - if (!CI) + // If we aren't dealing with a constant on the RHS, exit early. + auto *C = dyn_cast<Constant>(ICmp.getOperand(1)); + if (!C) return nullptr; // Compute the constant that would happen if we truncated to SrcTy then - // reextended to DestTy. - Constant *Res1 = ConstantExpr::getTrunc(CI, SrcTy); - Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), - Res1, DestTy); + // re-extended to DestTy. + Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy); + Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), Res1, DestTy); // If the re-extended constant didn't change... - if (Res2 == CI) { + if (Res2 == C) { // Deal with equality cases early. 
- if (ICI.isEquality()) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, Res1); + if (ICmp.isEquality()) + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); // A signed comparison of sign extended values simplifies into a // signed comparison. if (isSignedExt && isSignedCmp) - return new ICmpInst(ICI.getPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICI.getUnsignedPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, Res1); } - // The re-extended constant changed so the constant cannot be represented - // in the shorter type. Consequently, we cannot emit a simple comparison. + // The re-extended constant changed, partly changed (in the case of a vector), + // or could not be determined to be equal (in the case of a constant + // expression), so the constant cannot be represented in the shorter type. + // Consequently, we cannot emit a simple comparison. // All the cases that fold to true or false will have already been handled // by SimplifyICmpInst, so only deal with the tricky case. - if (isSignedCmp || !isSignedExt) + if (isSignedCmp || !isSignedExt || !isa<ConstantInt>(C)) return nullptr; // Evaluate the comparison for LT (we invert for GT below). LE and GE cases @@ -2117,17 +2475,17 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // We're performing an unsigned comp with a sign extended value. // This is true if the input is >= 0. [aka >s -1] Constant *NegOne = Constant::getAllOnesValue(SrcTy); - Value *Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICI.getName()); + Value *Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICmp.getName()); // Finally, return the value computed. - if (ICI.getPredicate() == ICmpInst::ICMP_ULT) - return ReplaceInstUsesWith(ICI, Result); + if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) + return replaceInstUsesWith(ICmp, Result); - assert(ICI.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); + assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); return BinaryOperator::CreateNot(Result); } -/// ProcessUGT_ADDCST_ADD - The caller has matched a pattern of the form: +/// The caller has matched a pattern of the form: /// I = icmp ugt (add (add A, B), CI2), CI1 /// If this is of the form: /// sum = a + b @@ -2207,7 +2565,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // The inner add was the result of the narrow add, zero extended to the // wider type. Replace it with the result computed by the intrinsic. - IC.ReplaceInstUsesWith(*OrigAdd, ZExt); + IC.replaceInstUsesWith(*OrigAdd, ZExt); // The original icmp gets replaced with the overflow value. 
return ExtractValueInst::Create(Call, 1, "sadd.overflow"); @@ -2491,7 +2849,7 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, continue; if (TruncInst *TI = dyn_cast<TruncInst>(U)) { if (TI->getType()->getPrimitiveSizeInBits() == MulWidth) - IC.ReplaceInstUsesWith(*TI, Mul); + IC.replaceInstUsesWith(*TI, Mul); else TI->setOperand(0, Mul); } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) { @@ -2503,7 +2861,7 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, Instruction *Zext = cast<Instruction>(Builder->CreateZExt(ShortAnd, BO->getType())); IC.Worklist.Add(Zext); - IC.ReplaceInstUsesWith(*BO, Zext); + IC.replaceInstUsesWith(*BO, Zext); } else { llvm_unreachable("Unexpected Binary operation"); } @@ -2545,9 +2903,9 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, return ExtractValueInst::Create(Call, 1); } -// DemandedBitsLHSMask - When performing a comparison against a constant, -// it is possible that not all the bits in the LHS are demanded. This helper -// method computes the mask that IS demanded. +/// When performing a comparison against a constant, it is possible that not all +/// the bits in the LHS are demanded. This helper method computes the mask that +/// IS demanded. static APInt DemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, bool isSignCheck) { if (isSignCheck) @@ -2656,9 +3014,7 @@ bool InstCombiner::dominatesAllUses(const Instruction *DI, return true; } -/// -/// true when the instruction sequence within a block is select-cmp-br. -/// +/// Return true when the instruction sequence within a block is select-cmp-br. static bool isChainSelectCmpBranch(const SelectInst *SI) { const BasicBlock *BB = SI->getParent(); if (!BB) @@ -2672,7 +3028,6 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) { return true; } -/// /// \brief True when a select result is replaced by one of its operands /// in select-icmp sequence. This will eventually result in the elimination /// of the select. @@ -2738,6 +3093,63 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, return false; } +/// If we have an icmp le or icmp ge instruction with a constant operand, turn +/// it into the appropriate icmp lt or icmp gt instruction. This transform +/// allows them to be folded in visitICmpInst. +static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { + ICmpInst::Predicate Pred = I.getPredicate(); + if (Pred != ICmpInst::ICMP_SLE && Pred != ICmpInst::ICMP_SGE && + Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_UGE) + return nullptr; + + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + auto *Op1C = dyn_cast<Constant>(Op1); + if (!Op1C) + return nullptr; + + // Check if the constant operand can be safely incremented/decremented without + // overflowing/underflowing. For scalars, SimplifyICmpInst has already handled + // the edge cases for us, so we just assert on them. For vectors, we must + // handle the edge cases. + Type *Op1Type = Op1->getType(); + bool IsSigned = I.isSigned(); + bool IsLE = (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE); + auto *CI = dyn_cast<ConstantInt>(Op1C); + if (CI) { + // A <= MAX -> TRUE ; A >= MIN -> TRUE + assert(IsLE ? !CI->isMaxValue(IsSigned) : !CI->isMinValue(IsSigned)); + } else if (Op1Type->isVectorTy()) { + // TODO? If the edge cases for vectors were guaranteed to be handled as they + // are for scalar, we could remove the min/max checks. However, to do that, + // we would have to use insertelement/shufflevector to replace edge values. 
+ unsigned NumElts = Op1Type->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = Op1C->getAggregateElement(i); + if (!Elt) + return nullptr; + + if (isa<UndefValue>(Elt)) + continue; + // Bail out if we can't determine if this constant is min/max or if we + // know that this constant is min/max. + auto *CI = dyn_cast<ConstantInt>(Elt); + if (!CI || (IsLE ? CI->isMaxValue(IsSigned) : CI->isMinValue(IsSigned))) + return nullptr; + } + } else { + // ConstantExpr? + return nullptr; + } + + // Increment or decrement the constant and set the new comparison predicate: + // ULE -> ULT ; UGE -> UGT ; SLE -> SLT ; SGE -> SGT + Constant *OneOrNegOne = ConstantInt::get(Op1Type, IsLE ? 1 : -1, true); + CmpInst::Predicate NewPred = IsLE ? ICmpInst::ICMP_ULT: ICmpInst::ICMP_UGT; + NewPred = IsSigned ? ICmpInst::getSignedPredicate(NewPred) : NewPred; + return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne)); +} + Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2748,8 +3160,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { /// complex to least complex. This puts constants before unary operators, /// before binary operators. if (Op0Cplxity < Op1Cplxity || - (Op0Cplxity == Op1Cplxity && - swapMayExposeCSEOpportunities(Op0, Op1))) { + (Op0Cplxity == Op1Cplxity && swapMayExposeCSEOpportunities(Op0, Op1))) { I.swapOperands(); std::swap(Op0, Op1); Changed = true; @@ -2757,12 +3168,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC, &I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val // ie, abs(val) != 0 -> val != 0 - if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) - { + if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) { Value *Cond, *SelectTrue, *SelectFalse; if (match(Op0, m_Select(m_Value(Cond), m_Value(SelectTrue), m_Value(SelectFalse)))) { @@ -2780,47 +3190,50 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Type *Ty = Op0->getType(); // icmp's with boolean values can always be turned into bitwise operations - if (Ty->isIntegerTy(1)) { + if (Ty->getScalarType()->isIntegerTy(1)) { switch (I.getPredicate()) { default: llvm_unreachable("Invalid icmp instruction!"); - case ICmpInst::ICMP_EQ: { // icmp eq i1 A, B -> ~(A^B) - Value *Xor = Builder->CreateXor(Op0, Op1, I.getName()+"tmp"); + case ICmpInst::ICMP_EQ: { // icmp eq i1 A, B -> ~(A^B) + Value *Xor = Builder->CreateXor(Op0, Op1, I.getName() + "tmp"); return BinaryOperator::CreateNot(Xor); } - case ICmpInst::ICMP_NE: // icmp eq i1 A, B -> A^B + case ICmpInst::ICMP_NE: // icmp ne i1 A, B -> A^B return BinaryOperator::CreateXor(Op0, Op1); case ICmpInst::ICMP_UGT: std::swap(Op0, Op1); // Change icmp ugt -> icmp ult // FALL THROUGH - case ICmpInst::ICMP_ULT:{ // icmp ult i1 A, B -> ~A & B - Value *Not = Builder->CreateNot(Op0, I.getName()+"tmp"); + case ICmpInst::ICMP_ULT:{ // icmp ult i1 A, B -> ~A & B + Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateAnd(Not, Op1); } case ICmpInst::ICMP_SGT: std::swap(Op0, Op1); // Change icmp sgt -> icmp slt // FALL THROUGH case ICmpInst::ICMP_SLT: { // icmp slt i1 A, B -> A & ~B - Value *Not = Builder->CreateNot(Op1, I.getName()+"tmp"); + Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return 
BinaryOperator::CreateAnd(Not, Op0); } case ICmpInst::ICMP_UGE: std::swap(Op0, Op1); // Change icmp uge -> icmp ule // FALL THROUGH - case ICmpInst::ICMP_ULE: { // icmp ule i1 A, B -> ~A | B - Value *Not = Builder->CreateNot(Op0, I.getName()+"tmp"); + case ICmpInst::ICMP_ULE: { // icmp ule i1 A, B -> ~A | B + Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op1); } case ICmpInst::ICMP_SGE: std::swap(Op0, Op1); // Change icmp sge -> icmp sle // FALL THROUGH - case ICmpInst::ICMP_SLE: { // icmp sle i1 A, B -> A | ~B - Value *Not = Builder->CreateNot(Op1, I.getName()+"tmp"); + case ICmpInst::ICMP_SLE: { // icmp sle i1 A, B -> A | ~B + Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op0); } } } + if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) + return NewICmp; + unsigned BitWidth = 0; if (Ty->isIntOrIntVectorTy()) BitWidth = Ty->getScalarSizeInBits(); @@ -2853,6 +3266,19 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return Res; } + // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) + if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT) + if (auto *SI = dyn_cast<SelectInst>(Op0)) { + SelectPatternResult SPR = matchSelectPattern(SI, A, B); + if (SPR.Flavor == SPF_SMIN) { + if (isKnownPositive(A, DL)) + return new ICmpInst(I.getPredicate(), B, CI); + if (isKnownPositive(B, DL)) + return new ICmpInst(I.getPredicate(), A, CI); + } + } + + // The following transforms are only 'worth it' if the only user of the // subtraction is the icmp. if (Op0->hasOneUse()) { @@ -2882,30 +3308,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return new ICmpInst(ICmpInst::ICMP_SLE, A, B); } - // If we have an icmp le or icmp ge instruction, turn it into the - // appropriate icmp lt or icmp gt instruction. This allows us to rely on - // them being folded in the code below. The SimplifyICmpInst code has - // already handled the edge cases for us, so we just assert on them. - switch (I.getPredicate()) { - default: break; - case ICmpInst::ICMP_ULE: - assert(!CI->isMaxValue(false)); // A <=u MAX -> TRUE - return new ICmpInst(ICmpInst::ICMP_ULT, Op0, - Builder->getInt(CI->getValue()+1)); - case ICmpInst::ICMP_SLE: - assert(!CI->isMaxValue(true)); // A <=s MAX -> TRUE - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, - Builder->getInt(CI->getValue()+1)); - case ICmpInst::ICMP_UGE: - assert(!CI->isMinValue(false)); // A >=u MIN -> TRUE - return new ICmpInst(ICmpInst::ICMP_UGT, Op0, - Builder->getInt(CI->getValue()-1)); - case ICmpInst::ICMP_SGE: - assert(!CI->isMinValue(true)); // A >=s MIN -> TRUE - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, - Builder->getInt(CI->getValue()-1)); - } - if (I.isEquality()) { ConstantInt *CI2; if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || @@ -2925,6 +3327,42 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // bits, if it is a sign bit comparison, it only demands the sign bit. bool UnusedBit; isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); + + // Canonicalize icmp instructions based on dominating conditions. + BasicBlock *Parent = I.getParent(); + BasicBlock *Dom = Parent->getSinglePredecessor(); + auto *BI = Dom ? 
dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; + ICmpInst::Predicate Pred; + BasicBlock *TrueBB, *FalseBB; + ConstantInt *CI2; + if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)), + TrueBB, FalseBB)) && + TrueBB != FalseBB) { + ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(), + CI->getValue()); + ConstantRange DominatingCR = + (Parent == TrueBB) + ? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue()) + : ConstantRange::makeExactICmpRegion( + CmpInst::getInversePredicate(Pred), CI2->getValue()); + ConstantRange Intersection = DominatingCR.intersectWith(CR); + ConstantRange Difference = DominatingCR.difference(CR); + if (Intersection.isEmptySet()) + return replaceInstUsesWith(I, Builder->getFalse()); + if (Difference.isEmptySet()) + return replaceInstUsesWith(I, Builder->getTrue()); + // Canonicalizing a sign bit comparison that gets used in a branch, + // pessimizes codegen by generating branch on zero instruction instead + // of a test and branch. So we avoid canonicalizing in such situations + // because test and branch instruction has better branch displacement + // than compare and branch instruction. + if (!isBranchOnSignBitCheck(I, isSignBit) && !I.isEquality()) { + if (auto *AI = Intersection.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI)); + if (auto *AD = Difference.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD)); + } + } } // See if we can fold the comparison based on range information we can get @@ -2975,7 +3413,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { default: llvm_unreachable("Unknown icmp opcode!"); case ICmpInst::ICMP_EQ: { if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); // If all bits are known zero except for one, then we know at most one // bit is set. If the comparison is against zero, then this is a check @@ -3019,7 +3457,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } case ICmpInst::ICMP_NE: { if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); // If all bits are known zero except for one, then we know at most one // bit is set. 
If the comparison is against zero, then this is a check @@ -3063,9 +3501,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } case ICmpInst::ICMP_ULT: if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { @@ -3081,9 +3519,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_UGT: if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); @@ -3100,9 +3538,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_SLT: if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { @@ -3113,9 +3551,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_SGT: if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); @@ -3128,30 +3566,30 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { case ICmpInst::ICMP_SGE: assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_SLE: assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sgt(Op1Max)) // A <=s 
B -> false if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_UGE: assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_ULE: assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; } @@ -3179,12 +3617,22 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // See if we are doing a comparison between a constant and an instruction that // can be folded into the comparison. if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + Value *A = nullptr, *B = nullptr; // Since the RHS is a ConstantInt (CI), if the left hand side is an // instruction, see if that instruction also has constants so that the // instruction can be folded into the icmp if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) if (Instruction *Res = visitICmpInstWithInstAndIntCst(I, LHSI, CI)) return Res; + + // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) + if (I.isEquality() && CI->isZero() && + match(Op0, m_UDiv(m_Value(A), m_Value(B)))) { + ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_EQ + ? ICmpInst::ICMP_UGT + : ICmpInst::ICMP_ULE; + return new ICmpInst(Pred, B, A); + } } // Handle icmp with constant (but not simple integer constant) RHS @@ -3354,10 +3802,14 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // Analyze the case when either Op0 or Op1 is an add instruction. // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Add) - A = BO0->getOperand(0), B = BO0->getOperand(1); - if (BO1 && BO1->getOpcode() == Instruction::Add) - C = BO1->getOperand(0), D = BO1->getOperand(1); + if (BO0 && BO0->getOpcode() == Instruction::Add) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Add) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } // icmp (X+cst) < 0 --> X < -cst if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) @@ -3474,11 +3926,18 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // Analyze the case when either Op0 or Op1 is a sub instruction. // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). 
- A = nullptr; B = nullptr; C = nullptr; D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Sub) - A = BO0->getOperand(0), B = BO0->getOperand(1); - if (BO1 && BO1->getOpcode() == Instruction::Sub) - C = BO1->getOperand(0), D = BO1->getOperand(1); + A = nullptr; + B = nullptr; + C = nullptr; + D = nullptr; + if (BO0 && BO0->getOpcode() == Instruction::Sub) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Sub) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. if (A == Op1 && NoOp0WrapProblem) @@ -3525,9 +3984,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { default: break; case ICmpInst::ICMP_EQ: - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); case ICmpInst::ICMP_NE: - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), @@ -3654,8 +4113,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Constant *Overflow; if (OptimizeOverflowCheck(OCF_UNSIGNED_ADD, A, B, *AddI, Result, Overflow)) { - ReplaceInstUsesWith(*AddI, Result); - return ReplaceInstUsesWith(I, Overflow); + replaceInstUsesWith(*AddI, Result); + return replaceInstUsesWith(I, Overflow); } } @@ -3834,7 +4293,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return Changed ? &I : nullptr; } -/// FoldFCmp_IntToFP_Cst - Fold fcmp ([us]itofp x, cst) if possible. +/// Fold fcmp ([us]itofp x, cst) if possible. Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, Instruction *LHSI, Constant *RHSC) { @@ -3864,10 +4323,10 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven); if (RHS.compare(RHSRoundInt) != APFloat::cmpEqual) { if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ) - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE); - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); } } @@ -3933,9 +4392,9 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, Pred = ICmpInst::ICMP_NE; break; case FCmpInst::FCMP_ORD: - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); case FCmpInst::FCMP_UNO: - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); } // Now we know that the APFloat is a normal number, zero or inf. @@ -3953,8 +4412,8 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, if (SMax.compare(RHS) == APFloat::cmpLessThan) { // smax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return ReplaceInstUsesWith(I, Builder->getTrue()); - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getFalse()); } } else { // If the RHS value is > UnsignedMax, fold the comparison. 
This handles @@ -3965,8 +4424,8 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, if (UMax.compare(RHS) == APFloat::cmpLessThan) { // umax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) - return ReplaceInstUsesWith(I, Builder->getTrue()); - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getFalse()); } } @@ -3978,8 +4437,8 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // smin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return ReplaceInstUsesWith(I, Builder->getTrue()); - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getFalse()); } } else { // See if the RHS value is < UnsignedMin. @@ -3989,8 +4448,8 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // umin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) - return ReplaceInstUsesWith(I, Builder->getTrue()); - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getFalse()); } } @@ -4012,14 +4471,14 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, switch (Pred) { default: llvm_unreachable("Unexpected integer comparison!"); case ICmpInst::ICMP_NE: // (float)int != 4.4 --> true - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); case ICmpInst::ICMP_EQ: // (float)int == 4.4 --> false - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); case ICmpInst::ICMP_ULE: // (float)int <= 4.4 --> int <= 4 // (float)int <= -4.4 --> false if (RHS.isNegative()) - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); break; case ICmpInst::ICMP_SLE: // (float)int <= 4.4 --> int <= 4 @@ -4031,7 +4490,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // (float)int < -4.4 --> false // (float)int < 4.4 --> int <= 4 if (RHS.isNegative()) - return ReplaceInstUsesWith(I, Builder->getFalse()); + return replaceInstUsesWith(I, Builder->getFalse()); Pred = ICmpInst::ICMP_ULE; break; case ICmpInst::ICMP_SLT: @@ -4044,7 +4503,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // (float)int > 4.4 --> int > 4 // (float)int > -4.4 --> true if (RHS.isNegative()) - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); break; case ICmpInst::ICMP_SGT: // (float)int > 4.4 --> int > 4 @@ -4056,7 +4515,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // (float)int >= -4.4 --> true // (float)int >= 4.4 --> int > 4 if (RHS.isNegative()) - return ReplaceInstUsesWith(I, Builder->getTrue()); + return replaceInstUsesWith(I, Builder->getTrue()); Pred = ICmpInst::ICMP_UGT; break; case ICmpInst::ICMP_SGE: @@ -4089,7 +4548,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC, &I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' if (Op0 == Op1) { @@ -4208,39 +4667,33 @@ 
Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; CallInst *CI = cast<CallInst>(LHSI); - const Function *F = CI->getCalledFunction(); - if (!F) + Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + if (IID != Intrinsic::fabs) break; // Various optimization for fabs compared with zero. - LibFunc::Func Func; - if (F->getIntrinsicID() == Intrinsic::fabs || - (TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && - (Func == LibFunc::fabs || Func == LibFunc::fabsf || - Func == LibFunc::fabsl))) { - switch (I.getPredicate()) { - default: - break; - // fabs(x) < 0 --> false - case FCmpInst::FCMP_OLT: - return ReplaceInstUsesWith(I, Builder->getFalse()); - // fabs(x) > 0 --> x != 0 - case FCmpInst::FCMP_OGT: - return new FCmpInst(FCmpInst::FCMP_ONE, CI->getArgOperand(0), RHSC); - // fabs(x) <= 0 --> x == 0 - case FCmpInst::FCMP_OLE: - return new FCmpInst(FCmpInst::FCMP_OEQ, CI->getArgOperand(0), RHSC); - // fabs(x) >= 0 --> !isnan(x) - case FCmpInst::FCMP_OGE: - return new FCmpInst(FCmpInst::FCMP_ORD, CI->getArgOperand(0), RHSC); - // fabs(x) == 0 --> x == 0 - // fabs(x) != 0 --> x != 0 - case FCmpInst::FCMP_OEQ: - case FCmpInst::FCMP_UEQ: - case FCmpInst::FCMP_ONE: - case FCmpInst::FCMP_UNE: - return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC); - } + switch (I.getPredicate()) { + default: + break; + // fabs(x) < 0 --> false + case FCmpInst::FCMP_OLT: + llvm_unreachable("handled by SimplifyFCmpInst"); + // fabs(x) > 0 --> x != 0 + case FCmpInst::FCMP_OGT: + return new FCmpInst(FCmpInst::FCMP_ONE, CI->getArgOperand(0), RHSC); + // fabs(x) <= 0 --> x == 0 + case FCmpInst::FCMP_OLE: + return new FCmpInst(FCmpInst::FCMP_OEQ, CI->getArgOperand(0), RHSC); + // fabs(x) >= 0 --> !isnan(x) + case FCmpInst::FCMP_OGE: + return new FCmpInst(FCmpInst::FCMP_ORD, CI->getArgOperand(0), RHSC); + // fabs(x) == 0 --> x == 0 + // fabs(x) != 0 --> x != 0 + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UNE: + return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC); } } } diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index e4e506509d39..aa421ff594fb 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -138,7 +138,7 @@ IntrinsicIDToOverflowCheckFlavor(unsigned ID) { /// \brief An IRBuilder inserter that adds new instructions to the instcombine /// worklist. class LLVM_LIBRARY_VISIBILITY InstCombineIRInserter - : public IRBuilderDefaultInserter<true> { + : public IRBuilderDefaultInserter { InstCombineWorklist &Worklist; AssumptionCache *AC; @@ -148,7 +148,7 @@ public: void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, BasicBlock::iterator InsertPt) const { - IRBuilderDefaultInserter<true>::InsertHelper(I, Name, BB, InsertPt); + IRBuilderDefaultInserter::InsertHelper(I, Name, BB, InsertPt); Worklist.Add(I); using namespace llvm::PatternMatch; @@ -171,12 +171,14 @@ public: /// \brief An IRBuilder that automatically inserts new instructions into the /// worklist. - typedef IRBuilder<true, TargetFolder, InstCombineIRInserter> BuilderTy; + typedef IRBuilder<TargetFolder, InstCombineIRInserter> BuilderTy; BuilderTy *Builder; private: // Mode in which we are running the combiner. const bool MinimizeSize; + /// Enable combines that trigger rarely but are costly in compiletime. 
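As a standalone sanity check of the fabs(x)-versus-zero folds in visitFCmpInst above (not part of the imported patch): the sketch below spells out the ordered-compare forms explicitly, since C++ != on doubles is an unordered compare. The sample values are illustrative only.

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const double nan = std::numeric_limits<double>::quiet_NaN();
  const double inf = std::numeric_limits<double>::infinity();
  for (double x : {0.0, -0.0, 1.5, -2.25, inf, -inf, nan}) {
    assert((std::fabs(x) > 0.0)  == (!std::isnan(x) && x != 0.0)); // fabs(x) > 0  --> x one 0
    assert((std::fabs(x) <= 0.0) == (x == 0.0));                   // fabs(x) <= 0 --> x oeq 0
    assert((std::fabs(x) >= 0.0) == !std::isnan(x));               // fabs(x) >= 0 --> ord(x, 0)
    assert((std::fabs(x) == 0.0) == (x == 0.0));                   // fabs(x) == 0 --> x oeq 0
  }
}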
+ const bool ExpensiveCombines; AliasAnalysis *AA; @@ -195,11 +197,12 @@ private: public: InstCombiner(InstCombineWorklist &Worklist, BuilderTy *Builder, - bool MinimizeSize, AliasAnalysis *AA, + bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, AssumptionCache *AC, TargetLibraryInfo *TLI, DominatorTree *DT, const DataLayout &DL, LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), - AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL), LI(LI), MadeIRChange(false) {} + ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), + DL(DL), LI(LI), MadeIRChange(false) {} /// \brief Run the combiner over the entire worklist until it is empty. /// @@ -327,6 +330,8 @@ public: Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI); Instruction *visitExtractValueInst(ExtractValueInst &EV); Instruction *visitLandingPadInst(LandingPadInst &LI); + Instruction *visitVAStartInst(VAStartInst &I); + Instruction *visitVACopyInst(VACopyInst &I); // visitInstruction - Specify what to return for unhandled instructions... Instruction *visitInstruction(Instruction &I) { return nullptr; } @@ -390,6 +395,7 @@ private: Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); + Instruction *foldCastedBitwiseLogic(BinaryOperator &I); public: /// \brief Inserts an instruction \p New before instruction \p Old @@ -417,7 +423,7 @@ public: /// replaceable with another preexisting expression. Here we add all uses of /// I to the worklist, replace all uses of I with the new value, then return /// I, so that the inst combiner will know that I was modified. - Instruction *ReplaceInstUsesWith(Instruction &I, Value *V) { + Instruction *replaceInstUsesWith(Instruction &I, Value *V) { // If there are no uses to replace, then we return nullptr to indicate that // no changes were made to the program. if (I.use_empty()) return nullptr; @@ -451,16 +457,16 @@ public: /// When dealing with an instruction that has side effects or produces a void /// value, we can't rely on DCE to delete the instruction. Instead, visit /// methods should return the value returned by this function. - Instruction *EraseInstFromFunction(Instruction &I) { + Instruction *eraseInstFromFunction(Instruction &I) { DEBUG(dbgs() << "IC: ERASE " << I << '\n'); assert(I.use_empty() && "Cannot erase instruction that is used!"); // Make sure that we reprocess all operands now that we reduced their // use counts. if (I.getNumOperands() < 8) { - for (User::op_iterator i = I.op_begin(), e = I.op_end(); i != e; ++i) - if (Instruction *Op = dyn_cast<Instruction>(*i)) - Worklist.Add(Op); + for (Use &Operand : I.operands()) + if (auto *Inst = dyn_cast<Instruction>(Operand)) + Worklist.Add(Inst); } Worklist.Remove(&I); I.eraseFromParent(); @@ -515,12 +521,12 @@ private: Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth, Instruction *CxtI); - bool SimplifyDemandedBits(Use &U, APInt DemandedMask, APInt &KnownZero, + bool SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth = 0); /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence. 
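For the "r1 = shr x, c1; r2 = shl r1, c2" sequence named in the comment above, the simplest case (both shift amounts equal) reduces to masking off the low bits, which is what the demanded-bits reasoning exploits. A minimal standalone check, not part of the patch, using 32-bit unsigned values chosen only for illustration:

#include <cassert>
#include <cstdint>

int main() {
  for (std::uint32_t x : {0u, 1u, 0x1234u, 0xFFFFFFFFu, 0x80000001u})
    for (unsigned c = 0; c < 32; ++c)
      // Shifting right then left by the same amount just clears the low c bits.
      assert(((x >> c) << c) == (x & ~((1u << c) - 1u)));
}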
Value *SimplifyShrShlDemandedBits(Instruction *Lsr, Instruction *Sftl, - APInt DemandedMask, APInt &KnownZero, + const APInt &DemandedMask, APInt &KnownZero, APInt &KnownOne); /// \brief Tries to simplify operands to an integer instruction based on its @@ -556,7 +562,7 @@ private: Value *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, bool isSigned, bool Inside); Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); - Instruction *MatchBSwapOrBitReverse(BinaryOperator &I); + Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); Instruction *SimplifyMemTransfer(MemIntrinsic *MI); Instruction *SimplifyMemSet(MemSetInst *MI); diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index dd2889de405e..d312983ed51b 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -205,11 +205,11 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { // Now make everything use the getelementptr instead of the original // allocation. - return IC.ReplaceInstUsesWith(AI, GEP); + return IC.replaceInstUsesWith(AI, GEP); } if (isa<UndefValue>(AI.getArraySize())) - return IC.ReplaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); + return IC.replaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); // Ensure that the alloca array size argument has type intptr_t, so that // any casting is exposed early. @@ -271,7 +271,7 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { EntryAI->setAlignment(MaxAlign); if (AI.getType() != EntryAI->getType()) return new BitCastInst(EntryAI, AI.getType()); - return ReplaceInstUsesWith(AI, EntryAI); + return replaceInstUsesWith(AI, EntryAI); } } } @@ -291,12 +291,12 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) - EraseInstFromFunction(*ToDelete[i]); + eraseInstFromFunction(*ToDelete[i]); Constant *TheSrc = cast<Constant>(Copy->getSource()); Constant *Cast = ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, AI.getType()); - Instruction *NewI = ReplaceInstUsesWith(AI, Cast); - EraseInstFromFunction(*Copy); + Instruction *NewI = replaceInstUsesWith(AI, Cast); + eraseInstFromFunction(*Copy); ++NumGlobalCopies; return NewI; } @@ -326,7 +326,8 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT LoadInst *NewLoad = IC.Builder->CreateAlignedLoad( IC.Builder->CreateBitCast(Ptr, NewTy->getPointerTo(AS)), - LI.getAlignment(), LI.getName() + Suffix); + LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); + NewLoad->setAtomic(LI.getOrdering(), LI.getSynchScope()); MDBuilder MDB(NewLoad->getContext()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; @@ -398,7 +399,8 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value StoreInst *NewStore = IC.Builder->CreateAlignedStore( V, IC.Builder->CreateBitCast(Ptr, V->getType()->getPointerTo(AS)), - SI.getAlignment()); + SI.getAlignment(), SI.isVolatile()); + NewStore->setAtomic(SI.getOrdering(), SI.getSynchScope()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; MDNode *N = MDPair.second; @@ -438,7 +440,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value return NewStore; } -/// 
\brief Combine loads to match the type of value their uses after looking +/// \brief Combine loads to match the type of their uses' value after looking /// through intervening bitcasts. /// /// The core idea here is that if the result of a load is used in an operation, @@ -456,9 +458,9 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value /// later. However, it is risky in case some backend or other part of LLVM is /// relying on the exact type loaded to select appropriate atomic operations. static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { - // FIXME: We could probably with some care handle both volatile and atomic - // loads here but it isn't clear that this is important. - if (!LI.isSimple()) + // FIXME: We could probably with some care handle both volatile and ordered + // atomic loads here but it isn't clear that this is important. + if (!LI.isUnordered()) return nullptr; if (LI.use_empty()) @@ -486,7 +488,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { auto *SI = cast<StoreInst>(*UI++); IC.Builder->SetInsertPoint(SI); combineStoreToNewValue(IC, *SI, NewLoad); - IC.EraseInstFromFunction(*SI); + IC.eraseInstFromFunction(*SI); } assert(LI.use_empty() && "Failed to remove all users of the load!"); // Return the old load so the combiner can delete it safely. @@ -503,7 +505,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { if (CI->isNoopCast(DL)) { LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy()); CI->replaceAllUsesWith(NewLoad); - IC.EraseInstFromFunction(*CI); + IC.eraseInstFromFunction(*CI); return &LI; } } @@ -523,16 +525,17 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { if (!T->isAggregateType()) return nullptr; + StringRef Name = LI.getName(); assert(LI.getAlignment() && "Alignment must be set at this point"); if (auto *ST = dyn_cast<StructType>(T)) { // If the struct only have one element, we unpack. 
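The load-canonicalization hunks above lean on the idea that memory has no intrinsic type, so loading bytes as one type and reinterpreting the bits matches loading them directly as the other type. A small standalone illustration (not LLVM code; it assumes a 4-byte IEEE float, as on common targets, and an arbitrary byte pattern):

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  unsigned char mem[4] = {0xDB, 0x0F, 0x49, 0x40};

  float asFloat;
  std::memcpy(&asFloat, mem, sizeof asFloat);        // load as float...
  std::uint32_t viaFloat;
  std::memcpy(&viaFloat, &asFloat, sizeof viaFloat); // ...then "bitcast" the result

  std::uint32_t direct;
  std::memcpy(&direct, mem, sizeof direct);          // load as uint32_t directly

  assert(viaFloat == direct);
}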
- unsigned Count = ST->getNumElements(); - if (Count == 1) { + auto NumElements = ST->getNumElements(); + if (NumElements == 1) { LoadInst *NewLoad = combineLoadToNewType(IC, LI, ST->getTypeAtIndex(0U), ".unpack"); - return IC.ReplaceInstUsesWith(LI, IC.Builder->CreateInsertValue( - UndefValue::get(T), NewLoad, 0, LI.getName())); + return IC.replaceInstUsesWith(LI, IC.Builder->CreateInsertValue( + UndefValue::get(T), NewLoad, 0, Name)); } // We don't want to break loads with padding here as we'd loose @@ -542,38 +545,67 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { if (SL->hasPadding()) return nullptr; - auto Name = LI.getName(); - SmallString<16> LoadName = Name; - LoadName += ".unpack"; - SmallString<16> EltName = Name; - EltName += ".elt"; + auto Align = LI.getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(ST); + auto *Addr = LI.getPointerOperand(); - Value *V = UndefValue::get(T); - auto *IdxType = Type::getInt32Ty(ST->getContext()); + auto *IdxType = Type::getInt32Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - for (unsigned i = 0; i < Count; i++) { + + Value *V = UndefValue::get(T); + for (unsigned i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), EltName); - auto *L = IC.Builder->CreateAlignedLoad(Ptr, LI.getAlignment(), - LoadName); + auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + Name + ".elt"); + auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); + auto *L = IC.Builder->CreateAlignedLoad(Ptr, EltAlign, Name + ".unpack"); V = IC.Builder->CreateInsertValue(V, L, i); } V->setName(Name); - return IC.ReplaceInstUsesWith(LI, V); + return IC.replaceInstUsesWith(LI, V); } if (auto *AT = dyn_cast<ArrayType>(T)) { - // If the array only have one element, we unpack. 
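The struct-unpacking loop above derives each element load's alignment with MinAlign(Align, SL->getElementOffset(i)). A rough standalone sketch of that computation, using a hypothetical minAlign helper and made-up offsets; it is meant to mirror the intent of llvm::MinAlign, not to reproduce it:

#include <cstdint>
#include <cstdio>

// Largest power of two dividing both the aggregate's alignment and the
// element's byte offset; this is the alignment that is safe for the split load.
std::uint64_t minAlign(std::uint64_t align, std::uint64_t offset) {
  std::uint64_t x = align | offset;
  return x & (~x + 1); // lowest set bit
}

int main() {
  // A struct loaded with 8-byte alignment whose elements sit at offsets 0, 4, 8, 12.
  const std::uint64_t offsets[] = {0, 4, 8, 12};
  for (std::uint64_t off : offsets)
    std::printf("offset %2llu -> element alignment %llu\n",
                (unsigned long long)off,
                (unsigned long long)minAlign(8, off));
  // Prints 8, 4, 8, 4: the elements at offsets 4 and 12 only get 4-byte alignment.
}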
- if (AT->getNumElements() == 1) { - LoadInst *NewLoad = combineLoadToNewType(IC, LI, AT->getElementType(), - ".unpack"); - return IC.ReplaceInstUsesWith(LI, IC.Builder->CreateInsertValue( - UndefValue::get(T), NewLoad, 0, LI.getName())); + auto *ET = AT->getElementType(); + auto NumElements = AT->getNumElements(); + if (NumElements == 1) { + LoadInst *NewLoad = combineLoadToNewType(IC, LI, ET, ".unpack"); + return IC.replaceInstUsesWith(LI, IC.Builder->CreateInsertValue( + UndefValue::get(T), NewLoad, 0, Name)); } + + const DataLayout &DL = IC.getDataLayout(); + auto EltSize = DL.getTypeAllocSize(ET); + auto Align = LI.getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(T); + + auto *Addr = LI.getPointerOperand(); + auto *IdxType = Type::getInt64Ty(T->getContext()); + auto *Zero = ConstantInt::get(IdxType, 0); + + Value *V = UndefValue::get(T); + uint64_t Offset = 0; + for (uint64_t i = 0; i < NumElements; i++) { + Value *Indices[2] = { + Zero, + ConstantInt::get(IdxType, i), + }; + auto *Ptr = IC.Builder->CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + Name + ".elt"); + auto *L = IC.Builder->CreateAlignedLoad(Ptr, MinAlign(Align, Offset), + Name + ".unpack"); + V = IC.Builder->CreateInsertValue(V, L, i); + Offset += EltSize; + } + + V->setName(Name); + return IC.replaceInstUsesWith(LI, V); } return nullptr; @@ -610,7 +642,7 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, } if (GlobalAlias *GA = dyn_cast<GlobalAlias>(P)) { - if (GA->mayBeOverridden()) + if (GA->isInterposable()) return false; Worklist.push_back(GA->getAliasee()); continue; @@ -638,7 +670,7 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, if (!GV->hasDefinitiveInitializer() || !GV->isConstant()) return false; - uint64_t InitSize = DL.getTypeAllocSize(GV->getType()->getElementType()); + uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType()); if (InitSize > MaxSize) return false; continue; @@ -695,10 +727,8 @@ static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI, return false; SmallVector<Value *, 4> Ops(GEPI->idx_begin(), GEPI->idx_begin() + Idx); - Type *AllocTy = GetElementPtrInst::getIndexedType( - cast<PointerType>(GEPI->getOperand(0)->getType()->getScalarType()) - ->getElementType(), - Ops); + Type *AllocTy = + GetElementPtrInst::getIndexedType(GEPI->getSourceElementType(), Ops); if (!AllocTy || !AllocTy->isSized()) return false; const DataLayout &DL = IC.getDataLayout(); @@ -781,10 +811,6 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return &LI; } - // None of the following transforms are legal for volatile/atomic loads. - // FIXME: Some of it is okay for atomic loads; needs refactoring. - if (!LI.isSimple()) return nullptr; - if (Instruction *Res = unpackLoadToAggregate(*this, LI)) return Res; @@ -793,10 +819,12 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // separated by a few arithmetic operations. 
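Before the FindAvailableLoadedValue-based code that follows, here is a toy model (not the real analysis) of the backwards scan the comment above describes: look for the nearest prior store to the same address, and give up on a possible clobber or once a scan limit is hit. MemOp, forwardLoad, and the scan limit are invented for the sketch, and aliasing is reduced to textual address equality.

#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

struct MemOp {
  enum Kind { Store, Load, Clobber } kind;
  std::string addr;
  int value = 0; // meaningful for stores
};

std::optional<int> forwardLoad(const std::vector<MemOp> &ops,
                               std::size_t loadIdx,
                               std::size_t maxInstsToScan = 6) {
  const std::string &addr = ops[loadIdx].addr;
  std::size_t scanned = 0;
  for (std::size_t i = loadIdx; i > 0 && scanned < maxInstsToScan; ++scanned) {
    const MemOp &op = ops[--i];
    if (op.kind == MemOp::Store && op.addr == addr)
      return op.value;     // the stored value is still available at the load
    if (op.kind == MemOp::Clobber)
      return std::nullopt; // something may have rewritten the address
    // Loads, and stores to other (distinct) addresses, do not block forwarding.
  }
  return std::nullopt;
}

int main() {
  std::vector<MemOp> block = {
      {MemOp::Store, "p", 42}, {MemOp::Load, "q"}, {MemOp::Load, "p"}};
  assert(forwardLoad(block, 2).value_or(-1) == 42);
}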
BasicBlock::iterator BBI(LI); AAMDNodes AATags; + bool IsLoadCSE = false; if (Value *AvailableVal = - FindAvailableLoadedValue(Op, LI.getParent(), BBI, - DefMaxInstsToScan, AA, &AATags)) { - if (LoadInst *NLI = dyn_cast<LoadInst>(AvailableVal)) { + FindAvailableLoadedValue(&LI, LI.getParent(), BBI, + DefMaxInstsToScan, AA, &AATags, &IsLoadCSE)) { + if (IsLoadCSE) { + LoadInst *NLI = cast<LoadInst>(AvailableVal); unsigned KnownIDs[] = { LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, LLVMContext::MD_noalias, LLVMContext::MD_range, @@ -807,11 +835,15 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { combineMetadata(NLI, &LI, KnownIDs); }; - return ReplaceInstUsesWith( + return replaceInstUsesWith( LI, Builder->CreateBitOrPointerCast(AvailableVal, LI.getType(), LI.getName() + ".cast")); } + // None of the following transforms are legal for volatile/ordered atomic + // loads. Most of them do apply for unordered atomics. + if (!LI.isUnordered()) return nullptr; + // load(gep null, ...) -> unreachable if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { const Value *GEPI0 = GEPI->getOperand(0); @@ -823,7 +855,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // CFG. new StoreInst(UndefValue::get(LI.getType()), Constant::getNullValue(Op->getType()), &LI); - return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); + return replaceInstUsesWith(LI, UndefValue::get(LI.getType())); } } @@ -836,7 +868,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // unreachable instruction directly because we cannot modify the CFG. new StoreInst(UndefValue::get(LI.getType()), Constant::getNullValue(Op->getType()), &LI); - return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); + return replaceInstUsesWith(LI, UndefValue::get(LI.getType())); } if (Op->hasOneUse()) { @@ -853,14 +885,17 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { if (SelectInst *SI = dyn_cast<SelectInst>(Op)) { // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). unsigned Align = LI.getAlignment(); - if (isSafeToLoadUnconditionally(SI->getOperand(1), SI, Align) && - isSafeToLoadUnconditionally(SI->getOperand(2), SI, Align)) { + if (isSafeToLoadUnconditionally(SI->getOperand(1), Align, DL, SI) && + isSafeToLoadUnconditionally(SI->getOperand(2), Align, DL, SI)) { LoadInst *V1 = Builder->CreateLoad(SI->getOperand(1), SI->getOperand(1)->getName()+".val"); LoadInst *V2 = Builder->CreateLoad(SI->getOperand(2), SI->getOperand(2)->getName()+".val"); + assert(LI.isUnordered() && "implied by above"); V1->setAlignment(Align); + V1->setAtomic(LI.getOrdering(), LI.getSynchScope()); V2->setAlignment(Align); + V2->setAtomic(LI.getOrdering(), LI.getSynchScope()); return SelectInst::Create(SI->getCondition(), V1, V2); } @@ -882,6 +917,61 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return nullptr; } +/// \brief Look for extractelement/insertvalue sequence that acts like a bitcast. +/// +/// \returns underlying value that was "cast", or nullptr otherwise. +/// +/// For example, if we have: +/// +/// %E0 = extractelement <2 x double> %U, i32 0 +/// %V0 = insertvalue [2 x double] undef, double %E0, 0 +/// %E1 = extractelement <2 x double> %U, i32 1 +/// %V1 = insertvalue [2 x double] %V0, double %E1, 1 +/// +/// and the layout of a <2 x double> is isomorphic to a [2 x double], +/// then %V1 can be safely approximated by a conceptual "bitcast" of %U. +/// Note that %U may contain non-undef values where %V1 has undef. 
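To see why the extractelement/insertvalue chain in the comment above behaves like a bitcast when the layouts match, a standalone C++ sketch follows (Vec2 stands in for <2 x double>; the layout-isomorphism assumption is checked with a static_assert, and the values are arbitrary):

#include <array>
#include <cassert>
#include <cstring>

struct Vec2 { double e0, e1; }; // stand-in for a <2 x double> value

int main() {
  Vec2 u = {1.5, -2.25};

  // Element-wise extract + insert, as in the IR sequence above.
  std::array<double, 2> viaElements = {u.e0, u.e1};

  // One whole-value, bit-for-bit copy ("bitcast") of the same bytes.
  std::array<double, 2> viaBytes;
  static_assert(sizeof u == sizeof viaBytes, "layouts assumed isomorphic");
  std::memcpy(viaBytes.data(), &u, sizeof u);

  assert(std::memcmp(viaElements.data(), viaBytes.data(), sizeof viaBytes) == 0);
}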
+static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) { + Value *U = nullptr; + while (auto *IV = dyn_cast<InsertValueInst>(V)) { + auto *E = dyn_cast<ExtractElementInst>(IV->getInsertedValueOperand()); + if (!E) + return nullptr; + auto *W = E->getVectorOperand(); + if (!U) + U = W; + else if (U != W) + return nullptr; + auto *CI = dyn_cast<ConstantInt>(E->getIndexOperand()); + if (!CI || IV->getNumIndices() != 1 || CI->getZExtValue() != *IV->idx_begin()) + return nullptr; + V = IV->getAggregateOperand(); + } + if (!isa<UndefValue>(V) ||!U) + return nullptr; + + auto *UT = cast<VectorType>(U->getType()); + auto *VT = V->getType(); + // Check that types UT and VT are bitwise isomorphic. + const auto &DL = IC.getDataLayout(); + if (DL.getTypeStoreSizeInBits(UT) != DL.getTypeStoreSizeInBits(VT)) { + return nullptr; + } + if (auto *AT = dyn_cast<ArrayType>(VT)) { + if (AT->getNumElements() != UT->getNumElements()) + return nullptr; + } else { + auto *ST = cast<StructType>(VT); + if (ST->getNumElements() != UT->getNumElements()) + return nullptr; + for (const auto *EltT : ST->elements()) { + if (EltT != UT->getElementType()) + return nullptr; + } + } + return U; +} + /// \brief Combine stores to match the type of value being stored. /// /// The core idea here is that the memory does not have any intrinsic type and @@ -903,9 +993,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { /// the store instruction as otherwise there is no way to signal whether it was /// combined or not: IC.EraseInstFromFunction returns a null pointer. static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { - // FIXME: We could probably with some care handle both volatile and atomic - // stores here but it isn't clear that this is important. - if (!SI.isSimple()) + // FIXME: We could probably with some care handle both volatile and ordered + // atomic stores here but it isn't clear that this is important. + if (!SI.isUnordered()) return false; Value *V = SI.getValueOperand(); @@ -917,8 +1007,13 @@ static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { return true; } - // FIXME: We should also canonicalize loads of vectors when their elements are - // cast to other types. + if (Value *U = likeBitCastFromVector(IC, V)) { + combineStoreToNewValue(IC, SI, U); + return true; + } + + // FIXME: We should also canonicalize stores of vectors when their elements + // are cast to other types. 
return false; } @@ -950,11 +1045,16 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { if (SL->hasPadding()) return false; + auto Align = SI.getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(ST); + SmallString<16> EltName = V->getName(); EltName += ".elt"; auto *Addr = SI.getPointerOperand(); SmallString<16> AddrName = Addr->getName(); AddrName += ".repack"; + auto *IdxType = Type::getInt32Ty(ST->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); for (unsigned i = 0; i < Count; i++) { @@ -962,9 +1062,11 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), AddrName); + auto *Ptr = IC.Builder->CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + AddrName); auto *Val = IC.Builder->CreateExtractValue(V, i, EltName); - IC.Builder->CreateStore(Val, Ptr); + auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); + IC.Builder->CreateAlignedStore(Val, Ptr, EltAlign); } return true; @@ -972,11 +1074,43 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { if (auto *AT = dyn_cast<ArrayType>(T)) { // If the array only have one element, we unpack. - if (AT->getNumElements() == 1) { + auto NumElements = AT->getNumElements(); + if (NumElements == 1) { V = IC.Builder->CreateExtractValue(V, 0); combineStoreToNewValue(IC, SI, V); return true; } + + const DataLayout &DL = IC.getDataLayout(); + auto EltSize = DL.getTypeAllocSize(AT->getElementType()); + auto Align = SI.getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(T); + + SmallString<16> EltName = V->getName(); + EltName += ".elt"; + auto *Addr = SI.getPointerOperand(); + SmallString<16> AddrName = Addr->getName(); + AddrName += ".repack"; + + auto *IdxType = Type::getInt64Ty(T->getContext()); + auto *Zero = ConstantInt::get(IdxType, 0); + + uint64_t Offset = 0; + for (uint64_t i = 0; i < NumElements; i++) { + Value *Indices[2] = { + Zero, + ConstantInt::get(IdxType, i), + }; + auto *Ptr = IC.Builder->CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + AddrName); + auto *Val = IC.Builder->CreateExtractValue(V, i, EltName); + auto EltAlign = MinAlign(Align, Offset); + IC.Builder->CreateAlignedStore(Val, Ptr, EltAlign); + Offset += EltSize; + } + + return true; } return false; @@ -1017,7 +1151,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // Try to canonicalize the stored type. if (combineStoreToValueType(*this, SI)) - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); // Attempt to improve the alignment. unsigned KnownAlign = getOrEnforceKnownAlignment( @@ -1033,7 +1167,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // Try to canonicalize the stored type. if (unpackStoreToAggregate(*this, SI)) - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { @@ -1049,11 +1183,11 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // alloca dead. 
if (Ptr->hasOneUse()) { if (isa<AllocaInst>(Ptr)) - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr)) { if (isa<AllocaInst>(GEP->getOperand(0))) { if (GEP->getOperand(0)->hasOneUse()) - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); } } } @@ -1079,7 +1213,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { SI.getOperand(1))) { ++NumDeadStore; ++BBI; - EraseInstFromFunction(*PrevSI); + eraseInstFromFunction(*PrevSI); continue; } break; @@ -1091,7 +1225,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) { if (LI == Val && equivalentAddressValues(LI->getOperand(0), Ptr)) { assert(SI.isUnordered() && "can't eliminate ordering operation"); - return EraseInstFromFunction(SI); + return eraseInstFromFunction(SI); } // Otherwise, this is a load from some other location. Stores before it @@ -1116,11 +1250,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // store undef, Ptr -> noop if (isa<UndefValue>(Val)) - return EraseInstFromFunction(SI); - - // The code below needs to be audited and adjusted for unordered atomics - if (!SI.isSimple()) - return nullptr; + return eraseInstFromFunction(SI); // If this store is the last instruction in the basic block (possibly // excepting debug info instructions), and if the block ends with an @@ -1147,6 +1277,9 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { /// into a phi node with a store in the successor. /// bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { + assert(SI.isUnordered() && + "this code has not been auditted for volatile or ordered store case"); + BasicBlock *StoreBB = SI.getParent(); // Check to see if the successor block has exactly two incoming edges. If @@ -1268,7 +1401,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { } // Nuke the old stores. - EraseInstFromFunction(SI); - EraseInstFromFunction(*OtherStore); + eraseInstFromFunction(SI); + eraseInstFromFunction(*OtherStore); return true; } diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 160792b0a000..788097f33f12 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -45,28 +45,28 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. - if (BinaryOperator *I = dyn_cast<BinaryOperator>(V)) - if (I->isLogicalShift() && - isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, - IC.getAssumptionCache(), &CxtI, - IC.getDominatorTree())) { - // We know that this is an exact/nuw shift and that the input is a - // non-zero context as well. - if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { - I->setOperand(0, V2); - MadeChange = true; - } + BinaryOperator *I = dyn_cast<BinaryOperator>(V); + if (I && I->isLogicalShift() && + isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, + IC.getAssumptionCache(), &CxtI, + IC.getDominatorTree())) { + // We know that this is an exact/nuw shift and that the input is a + // non-zero context as well. 
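The exact/nuw reasoning in the comment above rests on a simple fact: if x is a power of two and x >> b is non-zero, no set bit was shifted out, so shifting back left recovers x. A standalone check over 32-bit values, included only as illustration:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned k = 0; k < 32; ++k) {
    std::uint32_t x = 1u << k;          // a power of two
    for (unsigned b = 0; b < 32; ++b) {
      std::uint32_t r = x >> b;
      if (r != 0)
        assert((r << b) == x);          // "exact": nothing was lost
    }
  }
}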
+ if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { + I->setOperand(0, V2); + MadeChange = true; + } - if (I->getOpcode() == Instruction::LShr && !I->isExact()) { - I->setIsExact(); - MadeChange = true; - } + if (I->getOpcode() == Instruction::LShr && !I->isExact()) { + I->setIsExact(); + MadeChange = true; + } - if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) { - I->setHasNoUnsignedWrap(); - MadeChange = true; - } + if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) { + I->setHasNoUnsignedWrap(); + MadeChange = true; } + } // TODO: Lots more we could do here: // If V is a phi node, we can call this on each of its operands. @@ -177,13 +177,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // X * -1 == 0 - X if (match(Op1, m_AllOnes())) { @@ -323,7 +323,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO)) if (SDiv->isExact()) { if (Op1BO == Op1C) - return ReplaceInstUsesWith(I, Op0BO); + return replaceInstUsesWith(I, Op0BO); return BinaryOperator::CreateNeg(Op0BO); } @@ -374,10 +374,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2, 0, &I)) - BoolCast = Op0, OtherOp = Op1; - else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) - BoolCast = Op1, OtherOp = Op0; + if (MaskedValueIsZero(Op0, Negative2, 0, &I)) { + BoolCast = Op0; + OtherOp = Op1; + } else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) { + BoolCast = Op1; + OtherOp = Op0; + } if (BoolCast) { Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()), @@ -536,14 +539,14 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) std::swap(Op0, Op1); if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -574,7 +577,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { // Try to simplify "MDC * Constant" if (isFMulOrFDivWithConstant(Op0)) if (Value *V = foldFMulConst(cast<Instruction>(Op0), C, &I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (MDC +/- C1) * C => (MDC * C) +/- (C1 * C) Instruction *FAddSub = dyn_cast<Instruction>(Op0); @@ -612,11 +615,22 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } - // sqrt(X) * sqrt(X) -> X - if (AllowReassociate && (Op0 == Op1)) - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) - if (II->getIntrinsicID() == Intrinsic::sqrt) - return ReplaceInstUsesWith(I, II->getOperand(0)); + if (Op0 == Op1) { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) { + // sqrt(X) * sqrt(X) -> X + if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt) + return 
replaceInstUsesWith(I, II->getOperand(0)); + + // fabs(X) * fabs(X) -> X * X + if (II->getIntrinsicID() == Intrinsic::fabs) { + Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0), + II->getOperand(0), + I.getName()); + FMulVal->copyFastMathFlags(&I); + return FMulVal; + } + } + } // Under unsafe algebra do: // X * log2(0.5*Y) = X*log2(Y) - X @@ -641,7 +655,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *FMulVal = Builder->CreateFMul(OpX, Log2); Value *FSub = Builder->CreateFSub(FMulVal, OpX); FSub->takeName(&I); - return ReplaceInstUsesWith(I, FSub); + return replaceInstUsesWith(I, FSub); } } @@ -661,7 +675,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (N1) { Value *FMul = Builder->CreateFMul(N0, N1); FMul->takeName(&I); - return ReplaceInstUsesWith(I, FMul); + return replaceInstUsesWith(I, FMul); } if (Opnd0->hasOneUse()) { @@ -669,7 +683,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *T = Builder->CreateFMul(N0, Opnd1); Value *Neg = Builder->CreateFNeg(T); Neg->takeName(&I); - return ReplaceInstUsesWith(I, Neg); + return replaceInstUsesWith(I, Neg); } } @@ -698,7 +712,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *R = Builder->CreateFMul(T, Y); R->takeName(&I); - return ReplaceInstUsesWith(I, R); + return replaceInstUsesWith(I, R); } } } @@ -1043,10 +1057,10 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) @@ -1116,27 +1130,43 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) return Common; - // sdiv X, -1 == -X - if (match(Op1, m_AllOnes())) - return BinaryOperator::CreateNeg(Op0); + const APInt *Op1C; + if (match(Op1, m_APInt(Op1C))) { + // sdiv X, -1 == -X + if (Op1C->isAllOnesValue()) + return BinaryOperator::CreateNeg(Op0); - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { - // sdiv X, C --> ashr exact X, log2(C) - if (I.isExact() && RHS->getValue().isNonNegative() && - RHS->getValue().isPowerOf2()) { - Value *ShAmt = llvm::ConstantInt::get(RHS->getType(), - RHS->getValue().exactLogBase2()); + // sdiv exact X, C --> ashr exact X, log2(C) + if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) { + Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2()); return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName()); } + + // If the dividend is sign-extended and the constant divisor is small enough + // to fit in the source type, shrink the division to the narrower type: + // (sext X) sdiv C --> sext (X sdiv C) + Value *Op0Src; + if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) && + Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) { + + // In the general case, we need to make sure that the dividend is not the + // minimum signed value because dividing 
that by -1 is UB. But here, we + // know that the -1 divisor case is already handled above. + + Constant *NarrowDivisor = + ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType()); + Value *NarrowOp = Builder->CreateSDiv(Op0Src, NarrowDivisor); + return new SExtInst(NarrowOp, Op0->getType()); + } } if (Constant *RHS = dyn_cast<Constant>(Op1)) { @@ -1214,11 +1244,11 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) @@ -1363,8 +1393,17 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; } else if (isa<PHINode>(Op0I)) { - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + using namespace llvm::PatternMatch; + const APInt *Op1Int; + if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() && + (I.getOpcode() == Instruction::URem || + !Op1Int->isMinSignedValue())) { + // FoldOpIntoPhi will speculate instructions to the end of the PHI's + // predecessor blocks, so do this only if we know the srem or urem + // will not fault. + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } } // See if we can fold away this rem instruction. @@ -1380,10 +1419,10 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) return common; @@ -1405,7 +1444,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (match(Op0, m_One())) { Value *Cmp = Builder->CreateICmpNE(Op1, Op0); Value *Ext = Builder->CreateZExt(Cmp, I.getType()); - return ReplaceInstUsesWith(I, Ext); + return replaceInstUsesWith(I, Ext); } return nullptr; @@ -1415,10 +1454,10 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer rem common cases if (Instruction *Common = commonIRemTransforms(I)) @@ -1490,11 +1529,11 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index f1aa98b5e359..79a4912332ff 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -15,8 +15,11 @@ #include "llvm/ADT/STLExtras.h" #include 
"llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" @@ -32,15 +35,6 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { Type *LHSType = LHSVal->getType(); Type *RHSType = RHSVal->getType(); - bool isNUW = false, isNSW = false, isExact = false; - if (OverflowingBinaryOperator *BO = - dyn_cast<OverflowingBinaryOperator>(FirstInst)) { - isNUW = BO->hasNoUnsignedWrap(); - isNSW = BO->hasNoSignedWrap(); - } else if (PossiblyExactOperator *PEO = - dyn_cast<PossiblyExactOperator>(FirstInst)) - isExact = PEO->isExact(); - // Scan to see if all operands are the same opcode, and all have one use. for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { Instruction *I = dyn_cast<Instruction>(PN.getIncomingValue(i)); @@ -56,13 +50,6 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { if (CI->getPredicate() != cast<CmpInst>(FirstInst)->getPredicate()) return nullptr; - if (isNUW) - isNUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); - if (isNSW) - isNSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - if (isExact) - isExact = cast<PossiblyExactOperator>(I)->isExact(); - // Keep track of which operand needs a phi node. if (I->getOperand(0) != LHSVal) LHSVal = nullptr; if (I->getOperand(1) != RHSVal) RHSVal = nullptr; @@ -121,9 +108,12 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { BinaryOperator *BinOp = cast<BinaryOperator>(FirstInst); BinaryOperator *NewBinOp = BinaryOperator::Create(BinOp->getOpcode(), LHSVal, RHSVal); - if (isNUW) NewBinOp->setHasNoUnsignedWrap(); - if (isNSW) NewBinOp->setHasNoSignedWrap(); - if (isExact) NewBinOp->setIsExact(); + + NewBinOp->copyIRFlags(PN.getIncomingValue(0)); + + for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) + NewBinOp->andIRFlags(PN.getIncomingValue(i)); + NewBinOp->setDebugLoc(FirstInst->getDebugLoc()); return NewBinOp; } @@ -494,7 +484,6 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { // code size and simplifying code. Constant *ConstantOp = nullptr; Type *CastSrcTy = nullptr; - bool isNUW = false, isNSW = false, isExact = false; if (isa<CastInst>(FirstInst)) { CastSrcTy = FirstInst->getOperand(0)->getType(); @@ -511,14 +500,6 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { ConstantOp = dyn_cast<Constant>(FirstInst->getOperand(1)); if (!ConstantOp) return FoldPHIArgBinOpIntoPHI(PN); - - if (OverflowingBinaryOperator *BO = - dyn_cast<OverflowingBinaryOperator>(FirstInst)) { - isNUW = BO->hasNoUnsignedWrap(); - isNSW = BO->hasNoSignedWrap(); - } else if (PossiblyExactOperator *PEO = - dyn_cast<PossiblyExactOperator>(FirstInst)) - isExact = PEO->isExact(); } else { return nullptr; // Cannot fold this operation. } @@ -534,13 +515,6 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { } else if (I->getOperand(1) != ConstantOp) { return nullptr; } - - if (isNUW) - isNUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); - if (isNSW) - isNSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - if (isExact) - isExact = cast<PossiblyExactOperator>(I)->isExact(); } // Okay, they are all the same operation. 
Create a new PHI node of the @@ -581,9 +555,11 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(FirstInst)) { BinOp = BinaryOperator::Create(BinOp->getOpcode(), PhiVal, ConstantOp); - if (isNUW) BinOp->setHasNoUnsignedWrap(); - if (isNSW) BinOp->setHasNoSignedWrap(); - if (isExact) BinOp->setIsExact(); + BinOp->copyIRFlags(PN.getIncomingValue(0)); + + for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) + BinOp->andIRFlags(PN.getIncomingValue(i)); + BinOp->setDebugLoc(FirstInst->getDebugLoc()); return BinOp; } @@ -641,6 +617,16 @@ static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, return true; } +/// Return an existing non-zero constant if this phi node has one, otherwise +/// return constant 1. +static ConstantInt *GetAnyNonZeroConstInt(PHINode &PN) { + assert(isa<IntegerType>(PN.getType()) && "Expect only intger type phi"); + for (Value *V : PN.operands()) + if (auto *ConstVA = dyn_cast<ConstantInt>(V)) + if (!ConstVA->isZeroValue()) + return ConstVA; + return ConstantInt::get(cast<IntegerType>(PN.getType()), 1); +} namespace { struct PHIUsageRecord { @@ -768,7 +754,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // If we have no users, they must be all self uses, just nuke the PHI. if (PHIUsers.empty()) - return ReplaceInstUsesWith(FirstPhi, UndefValue::get(FirstPhi.getType())); + return replaceInstUsesWith(FirstPhi, UndefValue::get(FirstPhi.getType())); // If this phi node is transformable, create new PHIs for all the pieces // extracted out of it. First, sort the users by their offset and size. @@ -864,22 +850,22 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { } // Replace the use of this piece with the PHI node. - ReplaceInstUsesWith(*PHIUsers[UserI].Inst, EltPHI); + replaceInstUsesWith(*PHIUsers[UserI].Inst, EltPHI); } // Replace all the remaining uses of the PHI nodes (self uses and the lshrs) // with undefs. Value *Undef = UndefValue::get(FirstPhi.getType()); for (unsigned i = 1, e = PHIsToSlice.size(); i != e; ++i) - ReplaceInstUsesWith(*PHIsToSlice[i], Undef); - return ReplaceInstUsesWith(FirstPhi, Undef); + replaceInstUsesWith(*PHIsToSlice[i], Undef); + return replaceInstUsesWith(FirstPhi, Undef); } // PHINode simplification // Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (Value *V = SimplifyInstruction(&PN, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(PN, V); + return replaceInstUsesWith(PN, V); if (Instruction *Result = FoldPHIArgZextsIntoPHI(PN)) return Result; @@ -905,7 +891,7 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { SmallPtrSet<PHINode*, 16> PotentiallyDeadPHIs; PotentiallyDeadPHIs.insert(&PN); if (DeadPHICycle(PU, PotentiallyDeadPHIs)) - return ReplaceInstUsesWith(PN, UndefValue::get(PN.getType())); + return replaceInstUsesWith(PN, UndefValue::get(PN.getType())); } // If this phi has a single use, and if that use just computes a value for @@ -917,7 +903,30 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (PHIUser->hasOneUse() && (isa<BinaryOperator>(PHIUser) || isa<GetElementPtrInst>(PHIUser)) && PHIUser->user_back() == &PN) { - return ReplaceInstUsesWith(PN, UndefValue::get(PN.getType())); + return replaceInstUsesWith(PN, UndefValue::get(PN.getType())); + } + // When a PHI is used only to be compared with zero, it is safe to replace + // an incoming value proved as known nonzero with any non-zero constant. 
+ // For example, in the code below, the incoming value %v can be replaced + // with any non-zero constant based on the fact that the PHI is only used to + // be compared with zero and %v is a known non-zero value: + // %v = select %cond, 1, 2 + // %p = phi [%v, BB] ... + // icmp eq, %p, 0 + auto *CmpInst = dyn_cast<ICmpInst>(PHIUser); + // FIXME: To be simple, handle only integer type for now. + if (CmpInst && isa<IntegerType>(PN.getType()) && CmpInst->isEquality() && + match(CmpInst->getOperand(1), m_Zero())) { + ConstantInt *NonZeroConst = nullptr; + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { + Instruction *CtxI = PN.getIncomingBlock(i)->getTerminator(); + Value *VA = PN.getIncomingValue(i); + if (isKnownNonZero(VA, DL, 0, AC, CtxI, DT)) { + if (!NonZeroConst) + NonZeroConst = GetAnyNonZeroConstInt(PN); + PN.setIncomingValue(i, NonZeroConst); + } + } } } @@ -951,7 +960,7 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (InValNo == NumIncomingVals) { SmallPtrSet<PHINode*, 16> ValueEqualPHIs; if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs)) - return ReplaceInstUsesWith(PN, NonPhiInVal); + return replaceInstUsesWith(PN, NonPhiInVal); } } } diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 51219bcb0b7b..d7eed790e2ab 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -116,25 +116,41 @@ static Constant *GetSelectFoldableConstant(Instruction *I) { } } -/// Here we have (select c, TI, FI), and we know that TI and FI -/// have the same opcode and only one use each. Try to simplify this. +/// We have (select c, TI, FI), and we know that TI and FI have the same opcode. Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI) { - if (TI->getNumOperands() == 1) { - // If this is a non-volatile load or a cast from the same type, - // merge. - if (TI->isCast()) { - Type *FIOpndTy = FI->getOperand(0)->getType(); - if (TI->getOperand(0)->getType() != FIOpndTy) + // If this is a cast from the same type, merge. + if (TI->getNumOperands() == 1 && TI->isCast()) { + Type *FIOpndTy = FI->getOperand(0)->getType(); + if (TI->getOperand(0)->getType() != FIOpndTy) + return nullptr; + + // The select condition may be a vector. We may only change the operand + // type if the vector width remains the same (and matches the condition). + Type *CondTy = SI.getCondition()->getType(); + if (CondTy->isVectorTy()) { + if (!FIOpndTy->isVectorTy()) return nullptr; - // The select condition may be a vector. We may only change the operand - // type if the vector width remains the same (and matches the condition). - Type *CondTy = SI.getCondition()->getType(); - if (CondTy->isVectorTy() && (!FIOpndTy->isVectorTy() || - CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements())) + if (CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements()) return nullptr; - } else { - return nullptr; // unknown unary op. + + // TODO: If the backend knew how to deal with casts better, we could + // remove this limitation. For now, there's too much potential to create + // worse codegen by promoting the select ahead of size-altering casts + // (PR28160). + // + // Note that ValueTracking's matchSelectPattern() looks through casts + // without checking 'hasOneUse' when it matches min/max patterns, so this + // transform may end up happening anyway. 
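The gated transform above hoists a common cast out of both select arms. The underlying identity is easy to check standalone; the sketch below uses sext (int32 to int64) with arbitrary sample values and is not part of the patch:

#include <cassert>
#include <cstdint>

int main() {
  std::int32_t xs[] = {-7, 0, 123456, -2147483647};
  for (std::int32_t x : xs)
    for (std::int32_t y : xs)
      for (bool c : {false, true}) {
        // select c, (sext x), (sext y)  ==  sext (select c, x, y)
        std::int64_t castBoth = c ? std::int64_t(x) : std::int64_t(y);
        std::int64_t castOnce = std::int64_t(c ? x : y);
        assert(castBoth == castOnce);
      }
}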
+ if (TI->getOpcode() != Instruction::BitCast && + (!TI->hasOneUse() || !FI->hasOneUse())) + return nullptr; + + } else if (!TI->hasOneUse() || !FI->hasOneUse()) { + // TODO: The one-use restrictions for a scalar select could be eased if + // the fold of a select in visitLoadInst() was enhanced to match a pattern + // that includes a cast. + return nullptr; } // Fold this by inserting a select from the input values. @@ -144,8 +160,13 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, TI->getType()); } - // Only handle binary operators here. - if (!isa<BinaryOperator>(TI)) + // TODO: This function ends awkwardly in unreachable - fix to be more normal. + + // Only handle binary operators with one-use here. As with the cast case + // above, it may be possible to relax the one-use constraint, but that needs + // be examined carefully since it may not reduce the total number of + // instructions. + if (!isa<BinaryOperator>(TI) || !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; // Figure out if the operations have any operands in common. @@ -231,12 +252,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI); BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(), FalseVal, NewSel); - if (isa<PossiblyExactOperator>(BO)) - BO->setIsExact(TVI_BO->isExact()); - if (isa<OverflowingBinaryOperator>(BO)) { - BO->setHasNoUnsignedWrap(TVI_BO->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap(TVI_BO->hasNoSignedWrap()); - } + BO->copyIRFlags(TVI_BO); return BO; } } @@ -266,12 +282,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI); BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(), TrueVal, NewSel); - if (isa<PossiblyExactOperator>(BO)) - BO->setIsExact(FVI_BO->isExact()); - if (isa<OverflowingBinaryOperator>(BO)) { - BO->setHasNoUnsignedWrap(FVI_BO->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap(FVI_BO->hasNoSignedWrap()); - } + BO->copyIRFlags(FVI_BO); return BO; } } @@ -353,7 +364,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, /// %1 = icmp ne i32 %x, 0 /// %2 = select i1 %1, i32 %0, i32 32 /// \code -/// +/// /// into: /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, @@ -519,10 +530,10 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // Check if we can express the operation with a single or. if (C2->isAllOnesValue()) - return ReplaceInstUsesWith(SI, Builder->CreateOr(AShr, C1)); + return replaceInstUsesWith(SI, Builder->CreateOr(AShr, C1)); Value *And = Builder->CreateAnd(AShr, C2->getValue()-C1->getValue()); - return ReplaceInstUsesWith(SI, Builder->CreateAdd(And, C1)); + return replaceInstUsesWith(SI, Builder->CreateAdd(And, C1)); } } } @@ -585,15 +596,15 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, V = Builder->CreateOr(X, *Y); if (V) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); } } if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); return Changed ? 
&SI : nullptr; } @@ -642,11 +653,14 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, Value *A, Value *B, Instruction &Outer, SelectPatternFlavor SPF2, Value *C) { + if (Outer.getType() != Inner->getType()) + return nullptr; + if (C == A || C == B) { // MAX(MAX(A, B), B) -> MAX(A, B) // MIN(MIN(a, b), a) -> MIN(a, b) if (SPF1 == SPF2) - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); // MAX(MIN(a, b), a) -> a // MIN(MAX(a, b), a) -> a @@ -654,14 +668,14 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, (SPF1 == SPF_SMAX && SPF2 == SPF_SMIN) || (SPF1 == SPF_UMIN && SPF2 == SPF_UMAX) || (SPF1 == SPF_UMAX && SPF2 == SPF_UMIN)) - return ReplaceInstUsesWith(Outer, C); + return replaceInstUsesWith(Outer, C); } if (SPF1 == SPF2) { if (ConstantInt *CB = dyn_cast<ConstantInt>(B)) { if (ConstantInt *CC = dyn_cast<ConstantInt>(C)) { - APInt ACB = CB->getValue(); - APInt ACC = CC->getValue(); + const APInt &ACB = CB->getValue(); + const APInt &ACC = CC->getValue(); // MIN(MIN(A, 23), 97) -> MIN(A, 23) // MAX(MAX(A, 97), 23) -> MAX(A, 97) @@ -669,7 +683,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, (SPF1 == SPF_SMIN && ACB.sle(ACC)) || (SPF1 == SPF_UMAX && ACB.uge(ACC)) || (SPF1 == SPF_SMAX && ACB.sge(ACC))) - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); // MIN(MIN(A, 97), 23) -> MIN(A, 23) // MAX(MAX(A, 23), 97) -> MAX(A, 97) @@ -687,7 +701,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, // ABS(ABS(X)) -> ABS(X) // NABS(NABS(X)) -> NABS(X) if (SPF1 == SPF2 && (SPF1 == SPF_ABS || SPF1 == SPF_NABS)) { - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); } // ABS(NABS(X)) -> ABS(X) @@ -697,7 +711,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, SelectInst *SI = cast<SelectInst>(Inner); Value *NewSI = Builder->CreateSelect( SI->getCondition(), SI->getFalseValue(), SI->getTrueValue()); - return ReplaceInstUsesWith(Outer, NewSI); + return replaceInstUsesWith(Outer, NewSI); } auto IsFreeOrProfitableToInvert = @@ -742,7 +756,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern( Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); - return ReplaceInstUsesWith(Outer, NewOuter); + return replaceInstUsesWith(Outer, NewOuter); } return nullptr; @@ -823,76 +837,156 @@ static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal, return V; } +/// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))). +/// This is even legal for FP. 
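The identity that foldAddSubSelect (below) implements can be checked directly; here is a standalone sketch over 64-bit integers small enough not to overflow (the comment above notes the fold is legal for FP as well):

#include <cassert>
#include <cstdint>

int main() {
  std::int64_t vals[] = {-5, 0, 7, 1000};
  for (std::int64_t x : vals)
    for (std::int64_t y : vals)
      for (bool c : {false, true}) {
        // select c, (x + y), (x - y)  ==  x + (select c, y, -y)
        std::int64_t original = c ? (x + y) : (x - y);
        std::int64_t folded = x + (c ? y : -y);
        assert(original == folded);
      }
}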
+static Instruction *foldAddSubSelect(SelectInst &SI, + InstCombiner::BuilderTy &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + auto *TI = dyn_cast<Instruction>(TrueVal); + auto *FI = dyn_cast<Instruction>(FalseVal); + if (!TI || !FI || !TI->hasOneUse() || !FI->hasOneUse()) + return nullptr; + + Instruction *AddOp = nullptr, *SubOp = nullptr; + if ((TI->getOpcode() == Instruction::Sub && + FI->getOpcode() == Instruction::Add) || + (TI->getOpcode() == Instruction::FSub && + FI->getOpcode() == Instruction::FAdd)) { + AddOp = FI; + SubOp = TI; + } else if ((FI->getOpcode() == Instruction::Sub && + TI->getOpcode() == Instruction::Add) || + (FI->getOpcode() == Instruction::FSub && + TI->getOpcode() == Instruction::FAdd)) { + AddOp = TI; + SubOp = FI; + } + + if (AddOp) { + Value *OtherAddOp = nullptr; + if (SubOp->getOperand(0) == AddOp->getOperand(0)) { + OtherAddOp = AddOp->getOperand(1); + } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { + OtherAddOp = AddOp->getOperand(0); + } + + if (OtherAddOp) { + // So at this point we know we have (Y -> OtherAddOp): + // select C, (add X, Y), (sub X, Z) + Value *NegVal; // Compute -Z + if (SI.getType()->isFPOrFPVectorTy()) { + NegVal = Builder.CreateFNeg(SubOp->getOperand(1)); + if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) { + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + NegInst->setFastMathFlags(Flags); + } + } else { + NegVal = Builder.CreateNeg(SubOp->getOperand(1)); + } + + Value *NewTrueOp = OtherAddOp; + Value *NewFalseOp = NegVal; + if (AddOp != TI) + std::swap(NewTrueOp, NewFalseOp); + Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp, + SI.getName() + ".p"); + + if (SI.getType()->isFPOrFPVectorTy()) { + Instruction *RI = + BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); + + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + RI->setFastMathFlags(Flags); + return RI; + } else + return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); + } + } + return nullptr; +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); + Type *SelType = SI.getType(); if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); - if (SI.getType()->isIntegerTy(1)) { - if (ConstantInt *C = dyn_cast<ConstantInt>(TrueVal)) { - if (C->getZExtValue()) { - // Change: A = select B, true, C --> A = or B, C - return BinaryOperator::CreateOr(CondVal, FalseVal); - } + if (SelType->getScalarType()->isIntegerTy(1) && + TrueVal->getType() == CondVal->getType()) { + if (match(TrueVal, m_One())) { + // Change: A = select B, true, C --> A = or B, C + return BinaryOperator::CreateOr(CondVal, FalseVal); + } + if (match(TrueVal, m_Zero())) { // Change: A = select B, false, C --> A = and !B, C - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); + Value *NotCond = Builder->CreateNot(CondVal, "not." 
+ CondVal->getName()); return BinaryOperator::CreateAnd(NotCond, FalseVal); } - if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) { - if (!C->getZExtValue()) { - // Change: A = select B, C, false --> A = and B, C - return BinaryOperator::CreateAnd(CondVal, TrueVal); - } + if (match(FalseVal, m_Zero())) { + // Change: A = select B, C, false --> A = and B, C + return BinaryOperator::CreateAnd(CondVal, TrueVal); + } + if (match(FalseVal, m_One())) { // Change: A = select B, C, true --> A = or !B, C - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); + Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); return BinaryOperator::CreateOr(NotCond, TrueVal); } - // select a, b, a -> a&b - // select a, a, b -> a|b + // select a, a, b -> a | b + // select a, b, a -> a & b if (CondVal == TrueVal) return BinaryOperator::CreateOr(CondVal, FalseVal); if (CondVal == FalseVal) return BinaryOperator::CreateAnd(CondVal, TrueVal); - // select a, ~a, b -> (~a)&b - // select a, b, ~a -> (~a)|b + // select a, ~a, b -> (~a) & b + // select a, b, ~a -> (~a) | b if (match(TrueVal, m_Not(m_Specific(CondVal)))) return BinaryOperator::CreateAnd(TrueVal, FalseVal); if (match(FalseVal, m_Not(m_Specific(CondVal)))) return BinaryOperator::CreateOr(TrueVal, FalseVal); } - // Selecting between two integer constants? - if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) - if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) { - // select C, 1, 0 -> zext C to int - if (FalseValC->isZero() && TrueValC->getValue() == 1) - return new ZExtInst(CondVal, SI.getType()); - - // select C, -1, 0 -> sext C to int - if (FalseValC->isZero() && TrueValC->isAllOnesValue()) - return new SExtInst(CondVal, SI.getType()); - - // select C, 0, 1 -> zext !C to int - if (TrueValC->isZero() && FalseValC->getValue() == 1) { - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); - return new ZExtInst(NotCond, SI.getType()); - } + // Selecting between two integer or vector splat integer constants? + // + // Note that we don't handle a scalar select of vectors: + // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0> + // because that may need 3 instructions to splat the condition value: + // extend, insertelement, shufflevector. + if (CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { + // select C, 1, 0 -> zext C to int + if (match(TrueVal, m_One()) && match(FalseVal, m_Zero())) + return new ZExtInst(CondVal, SelType); + + // select C, -1, 0 -> sext C to int + if (match(TrueVal, m_AllOnes()) && match(FalseVal, m_Zero())) + return new SExtInst(CondVal, SelType); + + // select C, 0, 1 -> zext !C to int + if (match(TrueVal, m_Zero()) && match(FalseVal, m_One())) { + Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); + return new ZExtInst(NotCond, SelType); + } - // select C, 0, -1 -> sext !C to int - if (TrueValC->isZero() && FalseValC->isAllOnesValue()) { - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); - return new SExtInst(NotCond, SI.getType()); - } + // select C, 0, -1 -> sext !C to int + if (match(TrueVal, m_Zero()) && match(FalseVal, m_AllOnes())) { + Value *NotCond = Builder->CreateNot(CondVal, "not." 
+ CondVal->getName()); + return new SExtInst(NotCond, SelType); + } + } + if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) + if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) if (Value *V = foldSelectICmpAnd(SI, TrueValC, FalseValC, Builder)) - return ReplaceInstUsesWith(SI, V); - } + return replaceInstUsesWith(SI, V); // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { @@ -907,7 +1001,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); } // Transform (X une Y) ? X : Y -> X if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { @@ -919,7 +1013,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, TrueVal); + return replaceInstUsesWith(SI, TrueVal); } // Canonicalize to use ordered comparisons by swapping the select @@ -950,7 +1044,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); } // Transform (X une Y) ? Y : X -> Y if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { @@ -962,7 +1056,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, TrueVal); + return replaceInstUsesWith(SI, TrueVal); } // Canonicalize to use ordered comparisons by swapping the select @@ -991,77 +1085,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Result = visitSelectInstWithICmp(SI, ICI)) return Result; - if (Instruction *TI = dyn_cast<Instruction>(TrueVal)) - if (Instruction *FI = dyn_cast<Instruction>(FalseVal)) - if (TI->hasOneUse() && FI->hasOneUse()) { - Instruction *AddOp = nullptr, *SubOp = nullptr; - - // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) - if (TI->getOpcode() == FI->getOpcode()) - if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) - return IV; - - // Turn select C, (X+Y), (X-Y) --> (X+(select C, Y, (-Y))). This is - // even legal for FP. 
- if ((TI->getOpcode() == Instruction::Sub && - FI->getOpcode() == Instruction::Add) || - (TI->getOpcode() == Instruction::FSub && - FI->getOpcode() == Instruction::FAdd)) { - AddOp = FI; SubOp = TI; - } else if ((FI->getOpcode() == Instruction::Sub && - TI->getOpcode() == Instruction::Add) || - (FI->getOpcode() == Instruction::FSub && - TI->getOpcode() == Instruction::FAdd)) { - AddOp = TI; SubOp = FI; - } + if (Instruction *Add = foldAddSubSelect(SI, *Builder)) + return Add; - if (AddOp) { - Value *OtherAddOp = nullptr; - if (SubOp->getOperand(0) == AddOp->getOperand(0)) { - OtherAddOp = AddOp->getOperand(1); - } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { - OtherAddOp = AddOp->getOperand(0); - } - - if (OtherAddOp) { - // So at this point we know we have (Y -> OtherAddOp): - // select C, (add X, Y), (sub X, Z) - Value *NegVal; // Compute -Z - if (SI.getType()->isFPOrFPVectorTy()) { - NegVal = Builder->CreateFNeg(SubOp->getOperand(1)); - if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) { - FastMathFlags Flags = AddOp->getFastMathFlags(); - Flags &= SubOp->getFastMathFlags(); - NegInst->setFastMathFlags(Flags); - } - } else { - NegVal = Builder->CreateNeg(SubOp->getOperand(1)); - } - - Value *NewTrueOp = OtherAddOp; - Value *NewFalseOp = NegVal; - if (AddOp != TI) - std::swap(NewTrueOp, NewFalseOp); - Value *NewSel = - Builder->CreateSelect(CondVal, NewTrueOp, - NewFalseOp, SI.getName() + ".p"); - - if (SI.getType()->isFPOrFPVectorTy()) { - Instruction *RI = - BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); - - FastMathFlags Flags = AddOp->getFastMathFlags(); - Flags &= SubOp->getFastMathFlags(); - RI->setFastMathFlags(Flags); - return RI; - } else - return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); - } - } - } + // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) + auto *TI = dyn_cast<Instruction>(TrueVal); + auto *FI = dyn_cast<Instruction>(FalseVal); + if (TI && FI && TI->getOpcode() == FI->getOpcode()) + if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) + return IV; // See if we can fold the select into one of our operands. - if (SI.getType()->isIntOrIntVectorTy() || SI.getType()->isFPOrFPVectorTy()) { + if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) { if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; @@ -1073,7 +1108,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (SelectPatternResult::isMinOrMax(SPF)) { // Canonicalize so that type casts are outside select patterns. 
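Hoisting the cast out of a min/max select pattern is sound because sign- and zero-extension preserve the relevant ordering, so the compare-and-select can be done in the narrow type and the result cast once afterwards. A small self-check of that claim for smax over sign-extended i8 values (plain C++, exhaustive over the 8-bit range):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
      // smax over sign-extended i8 values equals sign-extending the smax of
      // the narrow values, because sext preserves order; so the cast can be
      // hoisted outside the min/max select pattern.
      for (int a = -128; a <= 127; ++a) {
        for (int b = -128; b <= 127; ++b) {
          int8_t na = static_cast<int8_t>(a);
          int8_t nb = static_cast<int8_t>(b);
          int castFirst = std::max(static_cast<int>(na), static_cast<int>(nb));
          int maxFirst = static_cast<int>(std::max(na, nb));
          assert(castFirst == maxFirst);
        }
      }
      return 0;
    }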
if (LHS->getType()->getPrimitiveSizeInBits() != - SI.getType()->getPrimitiveSizeInBits()) { + SelType->getPrimitiveSizeInBits()) { CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); Value *Cmp; @@ -1088,8 +1123,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NewSI = Builder->CreateCast(CastOp, Builder->CreateSelect(Cmp, LHS, RHS), - SI.getType()); - return ReplaceInstUsesWith(SI, NewSI); + SelType); + return replaceInstUsesWith(SI, NewSI); } } @@ -1132,7 +1167,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { : Builder->CreateICmpULT(NewLHS, NewRHS); Value *NewSI = Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); - return ReplaceInstUsesWith(SI, NewSI); + return replaceInstUsesWith(SI, NewSI); } } } @@ -1195,18 +1230,36 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return &SI; } - if (VectorType* VecTy = dyn_cast<VectorType>(SI.getType())) { + if (VectorType* VecTy = dyn_cast<VectorType>(SelType)) { unsigned VWidth = VecTy->getNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) { if (V != &SI) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); return &SI; } if (isa<ConstantAggregateZero>(CondVal)) { - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); + } + } + + // See if we can determine the result of this select based on a dominating + // condition. + BasicBlock *Parent = SI.getParent(); + if (BasicBlock *Dom = Parent->getSinglePredecessor()) { + auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator()); + if (PBI && PBI->isConditional() && + PBI->getSuccessor(0) != PBI->getSuccessor(1) && + (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { + bool CondIsFalse = PBI->getSuccessor(1) == Parent; + Optional<bool> Implication = isImpliedCondition( + PBI->getCondition(), SI.getCondition(), DL, CondIsFalse); + if (Implication) { + Value *V = *Implication ? TrueVal : FalseVal; + return replaceInstUsesWith(SI, V); + } } } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 0c7defa5fff8..08e16a7ee1af 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -55,6 +55,51 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { return nullptr; } +/// Return true if we can simplify two logical (either left or right) shifts +/// that have constant shift amounts. +static bool canEvaluateShiftedShift(unsigned FirstShiftAmt, + bool IsFirstShiftLeft, + Instruction *SecondShift, InstCombiner &IC, + Instruction *CxtI) { + assert(SecondShift->isLogicalShift() && "Unexpected instruction type"); + + // We need constant shifts. + auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1)); + if (!SecondShiftConst) + return false; + + unsigned SecondShiftAmt = SecondShiftConst->getZExtValue(); + bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl; + + // We can always fold shl(c1) + shl(c2) -> shl(c1+c2). + // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2). + if (IsFirstShiftLeft == IsSecondShiftLeft) + return true; + + // We can always fold lshr(c) + shl(c) -> and(c2). + // We can always fold shl(c) + lshr(c) -> and(c2). 
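Those four always-safe cases are easy to sanity-check outside the compiler: same-direction shifts by constants simply add their amounts, and a right shift followed by a left shift of the same amount (or the reverse) just clears the low (or high) bits, i.e. it is a single and. A standalone check, assuming 32-bit unsigned values:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t vals[] = {0u, 1u, 0x12345678u, 0xFFFFFFFFu};
      for (uint32_t x : vals) {
        // shl(c1) + shl(c2) -> shl(c1 + c2)
        assert(((x << 2) << 3) == (x << 5));
        // lshr(c1) + lshr(c2) -> lshr(c1 + c2)
        assert(((x >> 2) >> 3) == (x >> 5));
        // lshr(c) + shl(c) -> and that clears the low c bits
        assert(((x >> 3) << 3) == (x & ~7u));
        // shl(c) + lshr(c) -> and that clears the high c bits
        assert(((x << 3) >> 3) == (x & 0x1FFFFFFFu));
      }
      return 0;
    }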
+ if (FirstShiftAmt == SecondShiftAmt) + return true; + + unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits(); + + // If the 2nd shift is bigger than the 1st, we can fold: + // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or + // shl(c1) + lshr(c2) -> lshr(c3) + and(c4), + // but it isn't profitable unless we know the and'd out bits are already zero. + // Also check that the 2nd shift is valid (less than the type width) or we'll + // crash trying to produce the bit mask for the 'and'. + if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) { + unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt + : SecondShiftAmt - FirstShiftAmt; + APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift; + if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI)) + return true; + } + + return false; +} + /// See if we can compute the specified value, but shifted /// logically to the left or right by some number of bits. This should return /// true if the expression can be computed for the same cost as the current @@ -67,7 +112,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { /// where the client will ask if E can be computed shifted right by 64-bits. If /// this succeeds, the GetShiftedValue function will be called to produce the /// value. -static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, +static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, InstCombiner &IC, Instruction *CxtI) { // We can always evaluate constants shifted. if (isa<Constant>(V)) @@ -81,8 +126,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // the value which means that we don't care if the shift has multiple uses. // TODO: Handle opposite shift by exact value. ConstantInt *CI = nullptr; - if ((isLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || - (!isLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { + if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || + (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { if (CI->getZExtValue() == NumBits) { // TODO: Check that the input bits are already zero with MaskedValueIsZero #if 0 @@ -111,64 +156,19 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Or: case Instruction::Xor: // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. - return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC, I) && - CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC, I); - - case Instruction::Shl: { - // We can often fold the shift into shifts-by-a-constant. - CI = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!CI) return false; - - // We can always fold shl(c1)+shl(c2) -> shl(c1+c2). - if (isLeftShift) return true; + return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && + CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); - // We can always turn shl(c)+shr(c) -> and(c2). - if (CI->getValue() == NumBits) return true; + case Instruction::Shl: + case Instruction::LShr: + return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI); - unsigned TypeWidth = I->getType()->getScalarSizeInBits(); - - // We can turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but it isn't - // profitable unless we know the and'd out bits are already zero. 
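Both the removed shl/lshr cases and the new canEvaluateShiftedShift encode the same observation: when the second shift amount is larger than the first, the pair can be rewritten as one shift plus an and, but the and only disappears (and the rewrite only pays off) when the masked-out bits are already known zero. A standalone check of the identity, using an assumed 1-bit-then-3-bit shift pair on 32-bit values:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t vals[] = {0u, 1u, 0x80000001u, 0xDEADBEEFu};
      for (uint32_t x : vals) {
        // lshr(1) then shl(3) == shl(2) plus an and clearing the bit that
        // the lshr discarded (bit 0 of x, which would land at bit 2 after
        // the shl(2)).
        assert(((x >> 1) << 3) == ((x << 2) & ~4u));

        // When bit 0 is already known zero, the and is a no-op and the pair
        // folds to the single shift, the profitable case the pass checks
        // with MaskedValueIsZero.
        uint32_t evenX = x & ~1u;
        assert(((evenX >> 1) << 3) == (evenX << 2));
      }
      return 0;
    }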
- if (CI->getZExtValue() > NumBits) { - unsigned LowBits = TypeWidth - CI->getZExtValue(); - if (IC.MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, - 0, CxtI)) - return true; - } - - return false; - } - case Instruction::LShr: { - // We can often fold the shift into shifts-by-a-constant. - CI = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!CI) return false; - - // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2). - if (!isLeftShift) return true; - - // We can always turn lshr(c)+shl(c) -> and(c2). - if (CI->getValue() == NumBits) return true; - - unsigned TypeWidth = I->getType()->getScalarSizeInBits(); - - // We can always turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but it isn't - // profitable unless we know the and'd out bits are already zero. - if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) { - unsigned LowBits = CI->getZExtValue() - NumBits; - if (IC.MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, - 0, CxtI)) - return true; - } - - return false; - } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, - IC, SI) && - CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC, SI); + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && + CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -176,8 +176,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (Value *IncValue : PN->incoming_values()) - if (!CanEvaluateShifted(IncValue, NumBits, isLeftShift, - IC, PN)) + if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) return false; return true; } @@ -257,6 +256,8 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, BO->setHasNoSignedWrap(false); return BO; } + // FIXME: This is almost identical to the SHL case. Refactor both cases into + // a helper function. case Instruction::LShr: { BinaryOperator *BO = cast<BinaryOperator>(I); unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); @@ -340,7 +341,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); - return ReplaceInstUsesWith( + return replaceInstUsesWith( I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL)); } @@ -356,7 +357,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, if (BO->getOpcode() == Instruction::Mul && isLeftShift) if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1))) return BinaryOperator::CreateMul(BO->getOperand(0), - ConstantExpr::getShl(BOOp, Op1)); + ConstantExpr::getShl(BOOp, Op1)); // Try to fold constant and into select arguments. if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) @@ -573,7 +574,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // saturates. if (AmtSum >= TypeBits) { if (I.getOpcode() != Instruction::AShr) - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); AmtSum = TypeBits-1; // Saturate to 31 for i32 ashr. 
} @@ -694,12 +695,12 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) return V; @@ -710,11 +711,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && MaskedValueIsZero(I.getOperand(0), - APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), - 0, &I)) { - I.setHasNoUnsignedWrap(); - return &I; - } + APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0, + &I)) { + I.setHasNoUnsignedWrap(); + return &I; + } // If the shifted out value is all signbits, this is a NSW shift. if (!I.hasNoSignedWrap() && @@ -736,11 +737,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { Instruction *InstCombiner::visitLShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; @@ -780,11 +781,11 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { Instruction *InstCombiner::visitAShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; @@ -813,8 +814,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt), - 0, &I)){ + MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)) { I.setIsExact(); return &I; } diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 743d51483ea1..f3268d2c3471 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -22,10 +22,9 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" -/// ShrinkDemandedConstant - Check to see if the specified operand of the -/// specified instruction is a constant integer. If so, check to see if there -/// are any bits set in the constant that are not demanded. If so, shrink the -/// constant and return true. +/// Check to see if the specified operand of the specified instruction is a +/// constant integer. If so, check to see if there are any bits set in the +/// constant that are not demanded. If so, shrink the constant and return true. 
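The helper documented above (and defined next) only drops constant bits that no user ever reads: if downstream code demands just some of the result bits, any constant bits outside that demanded set can be cleared without changing the result. A toy model of the idea, assuming an and whose result is consumed through a low-byte mask; the consumer function is illustrative only:

    #include <cassert>
    #include <cstdint>

    // Hypothetical consumer that only demands the low 8 bits of its input.
    static uint32_t consumer(uint32_t v) { return v & 0xFFu; }

    int main() {
      const uint32_t vals[] = {0u, 0x1ABu, 0xFFFFFFFFu};
      for (uint32_t x : vals) {
        // The constant 0x1FF has bit 8 set, but that bit is never demanded,
        // so it can be shrunk to 0xFF without changing any demanded bit of
        // the result.
        assert(consumer(x & 0x1FFu) == consumer(x & 0xFFu));
      }
      return 0;
    }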
static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, APInt Demanded) { assert(I && "No instruction?"); @@ -49,9 +48,8 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, -/// SimplifyDemandedInstructionBits - Inst is an integer instruction that -/// SimplifyDemandedBits knows about. See if the instruction has any -/// properties that allow us to simplify its operands. +/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if +/// the instruction has any properties that allow us to simplify its operands. bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); @@ -61,14 +59,14 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { 0, &Inst); if (!V) return false; if (V == &Inst) return true; - ReplaceInstUsesWith(Inst, V); + replaceInstUsesWith(Inst, V); return true; } -/// SimplifyDemandedBits - This form of SimplifyDemandedBits simplifies the -/// specified instruction operand if possible, updating it in place. It returns -/// true if it made any change and false otherwise. -bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, +/// This form of SimplifyDemandedBits simplifies the specified instruction +/// operand if possible, updating it in place. It returns true if it made any +/// change and false otherwise. +bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth) { auto *UserI = dyn_cast<Instruction>(U.getUser()); @@ -80,21 +78,22 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, } -/// SimplifyDemandedUseBits - This function attempts to replace V with a simpler -/// value based on the demanded bits. When this function is called, it is known -/// that only the bits set in DemandedMask of the result of V are ever used -/// downstream. Consequently, depending on the mask and V, it may be possible -/// to replace V with a constant or one of its operands. In such cases, this -/// function does the replacement and returns true. In all other cases, it -/// returns false after analyzing the expression and setting KnownOne and known -/// to be one in the expression. KnownZero contains all the bits that are known -/// to be zero in the expression. These are provided to potentially allow the -/// caller (which might recursively be SimplifyDemandedBits itself) to simplify -/// the expression. KnownOne and KnownZero always follow the invariant that -/// KnownOne & KnownZero == 0. That is, a bit can't be both 1 and 0. Note that -/// the bits in KnownOne and KnownZero may only be accurate for those bits set -/// in DemandedMask. Note also that the bitwidth of V, DemandedMask, KnownZero -/// and KnownOne must all be the same. +/// This function attempts to replace V with a simpler value based on the +/// demanded bits. When this function is called, it is known that only the bits +/// set in DemandedMask of the result of V are ever used downstream. +/// Consequently, depending on the mask and V, it may be possible to replace V +/// with a constant or one of its operands. In such cases, this function does +/// the replacement and returns true. In all other cases, it returns false after +/// analyzing the expression and setting KnownOne and known to be one in the +/// expression. KnownZero contains all the bits that are known to be zero in the +/// expression. 
These are provided to potentially allow the caller (which might +/// recursively be SimplifyDemandedBits itself) to simplify the expression. +/// KnownOne and KnownZero always follow the invariant that: +/// KnownOne & KnownZero == 0. +/// That is, a bit can't be both 1 and 0. Note that the bits in KnownOne and +/// KnownZero may only be accurate for those bits set in DemandedMask. Note also +/// that the bitwidth of V, DemandedMask, KnownZero and KnownOne must all be the +/// same. /// /// This returns null if it did not change anything and it permits no /// simplification. This returns V itself if it did some simplification of V's @@ -768,6 +767,34 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // TODO: Could compute known zero/one bits based on the input. break; } + case Intrinsic::x86_mmx_pmovmskb: + case Intrinsic::x86_sse_movmsk_ps: + case Intrinsic::x86_sse2_movmsk_pd: + case Intrinsic::x86_sse2_pmovmskb_128: + case Intrinsic::x86_avx_movmsk_ps_256: + case Intrinsic::x86_avx_movmsk_pd_256: + case Intrinsic::x86_avx2_pmovmskb: { + // MOVMSK copies the vector elements' sign bits to the low bits + // and zeros the high bits. + unsigned ArgWidth; + if (II->getIntrinsicID() == Intrinsic::x86_mmx_pmovmskb) { + ArgWidth = 8; // Arg is x86_mmx, but treated as <8 x i8>. + } else { + auto Arg = II->getArgOperand(0); + auto ArgType = cast<VectorType>(Arg->getType()); + ArgWidth = ArgType->getNumElements(); + } + + // If we don't need any of low bits then return zero, + // we know that DemandedMask is non-zero already. + APInt DemandedElts = DemandedMask.zextOrTrunc(ArgWidth); + if (DemandedElts == 0) + return ConstantInt::getNullValue(VTy); + + // We know that the upper bits are set to zero. + KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - ArgWidth); + return nullptr; + } case Intrinsic::x86_sse42_crc32_64_64: KnownZero = APInt::getHighBitsSet(64, 32); return nullptr; @@ -802,7 +829,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, /// As with SimplifyDemandedUseBits, it returns NULL if the simplification was /// not successful. Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, - Instruction *Shl, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne) { + Instruction *Shl, + const APInt &DemandedMask, + APInt &KnownZero, + APInt &KnownOne) { const APInt &ShlOp1 = cast<ConstantInt>(Shl->getOperand(1))->getValue(); const APInt &ShrOp1 = cast<ConstantInt>(Shr->getOperand(1))->getValue(); @@ -865,10 +895,10 @@ Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, return nullptr; } -/// SimplifyDemandedVectorElts - The specified value produces a vector with -/// any number of elements. DemandedElts contains the set of elements that are -/// actually used by the caller. This method analyzes which elements of the -/// operand are undef and returns that information in UndefElts. +/// The specified value produces a vector with any number of elements. +/// DemandedElts contains the set of elements that are actually used by the +/// caller. This method analyzes which elements of the operand are undef and +/// returns that information in UndefElts. 
/// /// If the information about demanded elements can be used to simplify the /// operation, the operation is simplified, then the resultant value is @@ -876,7 +906,7 @@ Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth) { - unsigned VWidth = cast<VectorType>(V->getType())->getNumElements(); + unsigned VWidth = V->getType()->getVectorNumElements(); APInt EltMask(APInt::getAllOnesValue(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); @@ -1179,16 +1209,42 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, switch (II->getIntrinsicID()) { default: break; - // Binary vector operations that work column-wise. A dest element is a - // function of the corresponding input elements from the two inputs. + // Unary scalar-as-vector operations that work column-wise. + case Intrinsic::x86_sse_rcp_ss: + case Intrinsic::x86_sse_rsqrt_ss: + case Intrinsic::x86_sse_sqrt_ss: + case Intrinsic::x86_sse2_sqrt_sd: + case Intrinsic::x86_xop_vfrcz_ss: + case Intrinsic::x86_xop_vfrcz_sd: + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + // If lowest element of a scalar op isn't used then use Arg0. + if (DemandedElts.getLoBits(1) != 1) + return II->getArgOperand(0); + // TODO: If only low elt lower SQRT to FSQRT (with rounding/exceptions + // checks). + break; + + // Binary scalar-as-vector operations that work column-wise. A dest element + // is a function of the corresponding input elements from the two inputs. + case Intrinsic::x86_sse_add_ss: case Intrinsic::x86_sse_sub_ss: case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse_div_ss: case Intrinsic::x86_sse_min_ss: case Intrinsic::x86_sse_max_ss: + case Intrinsic::x86_sse_cmp_ss: + case Intrinsic::x86_sse2_add_sd: case Intrinsic::x86_sse2_sub_sd: case Intrinsic::x86_sse2_mul_sd: + case Intrinsic::x86_sse2_div_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: + case Intrinsic::x86_sse2_cmp_sd: + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } @@ -1201,11 +1257,15 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, if (DemandedElts == 1) { switch (II->getIntrinsicID()) { default: break; + case Intrinsic::x86_sse_add_ss: case Intrinsic::x86_sse_sub_ss: case Intrinsic::x86_sse_mul_ss: + case Intrinsic::x86_sse_div_ss: + case Intrinsic::x86_sse2_add_sd: case Intrinsic::x86_sse2_sub_sd: case Intrinsic::x86_sse2_mul_sd: - // TODO: Lower MIN/MAX/ABS/etc + case Intrinsic::x86_sse2_div_sd: + // TODO: Lower MIN/MAX/etc. Value *LHS = II->getArgOperand(0); Value *RHS = II->getArgOperand(1); // Extract the element as scalars. 
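For these 'scalar in a vector' SSE intrinsics only lane 0 carries the arithmetic result, so once the demanded-elements analysis proves that nothing else is read, the call can be replaced by ordinary scalar FP arithmetic on the extracted lane-0 operands, which is what the surrounding switch does for add/sub/mul/div. A plain C++ model of why that is sound; the array type and function name stand in for <4 x float> and x86_sse_add_ss and are illustrative only:

    #include <array>
    #include <cassert>

    using Vec4 = std::array<float, 4>;

    // Models the *_ss semantics: lane 0 is a[0] + b[0], lanes 1..3 pass
    // through from the first operand.
    static Vec4 addSS(const Vec4 &a, const Vec4 &b) {
      return {a[0] + b[0], a[1], a[2], a[3]};
    }

    int main() {
      Vec4 a = {1.5f, 2.0f, 3.0f, 4.0f};
      Vec4 b = {0.25f, -1.0f, 7.0f, 8.0f};
      // If only element 0 of the result is demanded, the vector intrinsic
      // can be replaced by a plain scalar fadd on the extracted lane-0
      // operands.
      float onlyDemandedLane = addSS(a, b)[0];
      assert(onlyDemandedLane == a[0] + b[0]);
      return 0;
    }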
@@ -1216,6 +1276,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, switch (II->getIntrinsicID()) { default: llvm_unreachable("Case stmts out of sync!"); + case Intrinsic::x86_sse_add_ss: + case Intrinsic::x86_sse2_add_sd: + TmpV = InsertNewInstWith(BinaryOperator::CreateFAdd(LHS, RHS, + II->getName()), *II); + break; case Intrinsic::x86_sse_sub_ss: case Intrinsic::x86_sse2_sub_sd: TmpV = InsertNewInstWith(BinaryOperator::CreateFSub(LHS, RHS, @@ -1226,6 +1291,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, TmpV = InsertNewInstWith(BinaryOperator::CreateFMul(LHS, RHS, II->getName()), *II); break; + case Intrinsic::x86_sse_div_ss: + case Intrinsic::x86_sse2_div_sd: + TmpV = InsertNewInstWith(BinaryOperator::CreateFDiv(LHS, RHS, + II->getName()), *II); + break; } Instruction *New = @@ -1238,6 +1308,10 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } } + // If lowest element of a scalar op isn't used then use Arg0. + if (DemandedElts.getLoBits(1) != 1) + return II->getArgOperand(0); + // Output elements are undefined if both are undefined. Consider things // like undef&0. The result is known zero, not undef. UndefElts &= UndefElts2; diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index bc4c0ebae790..a76138756148 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -62,21 +62,31 @@ static bool cheapToScalarize(Value *V, bool isConstant) { return false; } -// If we have a PHI node with a vector type that has only 2 uses: feed +// If we have a PHI node with a vector type that is only used to feed // itself and be an operand of extractelement at a constant location, // try to replace the PHI of the vector type with a PHI of a scalar type. Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { - // Verify that the PHI node has exactly 2 uses. Otherwise return NULL. - if (!PN->hasNUses(2)) - return nullptr; + SmallVector<Instruction *, 2> Extracts; + // The users we want the PHI to have are: + // 1) The EI ExtractElement (we already know this) + // 2) Possibly more ExtractElements with the same index. + // 3) Another operand, which will feed back into the PHI. + Instruction *PHIUser = nullptr; + for (auto U : PN->users()) { + if (ExtractElementInst *EU = dyn_cast<ExtractElementInst>(U)) { + if (EI.getIndexOperand() == EU->getIndexOperand()) + Extracts.push_back(EU); + else + return nullptr; + } else if (!PHIUser) { + PHIUser = cast<Instruction>(U); + } else { + return nullptr; + } + } - // If so, it's known at this point that one operand is PHI and the other is - // an extractelement node. Find the PHI user that is not the extractelement - // node. - auto iu = PN->user_begin(); - Instruction *PHIUser = dyn_cast<Instruction>(*iu); - if (PHIUser == cast<Instruction>(&EI)) - PHIUser = cast<Instruction>(*(++iu)); + if (!PHIUser) + return nullptr; // Verify that this PHI user has one use, which is the PHI itself, // and that it is a binary operation which is cheap to scalarize. 
@@ -106,7 +116,8 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { B0->getOperand(opId)->getName() + ".Elt"), *B0); Value *newPHIUser = InsertNewInstWith( - BinaryOperator::Create(B0->getOpcode(), scalarPHI, Op), *B0); + BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(), + scalarPHI, Op, B0), *B0); scalarPHI->addIncoming(newPHIUser, inBB); } else { // Scalarize PHI input: @@ -125,19 +136,23 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { scalarPHI->addIncoming(newEI, inBB); } } - return ReplaceInstUsesWith(EI, scalarPHI); + + for (auto E : Extracts) + replaceInstUsesWith(*E, scalarPHI); + + return &EI; } Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { if (Value *V = SimplifyExtractElementInst( EI.getVectorOperand(), EI.getIndexOperand(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(EI, V); + return replaceInstUsesWith(EI, V); // If vector val is constant with all elements the same, replace EI with // that element. We handle a known element # below. if (Constant *C = dyn_cast<Constant>(EI.getOperand(0))) if (cheapToScalarize(C, false)) - return ReplaceInstUsesWith(EI, C->getAggregateElement(0U)); + return replaceInstUsesWith(EI, C->getAggregateElement(0U)); // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. @@ -193,12 +208,13 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Value *newEI1 = Builder->CreateExtractElement(BO->getOperand(1), EI.getOperand(1), EI.getName()+".rhs"); - return BinaryOperator::Create(BO->getOpcode(), newEI0, newEI1); + return BinaryOperator::CreateWithCopiedFlags(BO->getOpcode(), + newEI0, newEI1, BO); } } else if (InsertElementInst *IE = dyn_cast<InsertElementInst>(I)) { // Extracting the inserted element? if (IE->getOperand(2) == EI.getOperand(1)) - return ReplaceInstUsesWith(EI, IE->getOperand(1)); + return replaceInstUsesWith(EI, IE->getOperand(1)); // If the inserted and extracted elements are constants, they must not // be the same value, extract from the pre-inserted value instead. if (isa<Constant>(IE->getOperand(2)) && isa<Constant>(EI.getOperand(1))) { @@ -216,7 +232,7 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { SVI->getOperand(0)->getType()->getVectorNumElements(); if (SrcIdx < 0) - return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType())); + return replaceInstUsesWith(EI, UndefValue::get(EI.getType())); if (SrcIdx < (int)LHSWidth) Src = SVI->getOperand(0); else { @@ -417,7 +433,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, continue; auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1)); NewExt->insertAfter(WideVec); - IC.ReplaceInstUsesWith(*OldExt, NewExt); + IC.replaceInstUsesWith(*OldExt, NewExt); } } @@ -546,7 +562,7 @@ Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) { } if (IsRedundant) - return ReplaceInstUsesWith(I, I.getOperand(0)); + return replaceInstUsesWith(I, I.getOperand(0)); return nullptr; } @@ -557,7 +573,7 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { // Inserting an undef or into an undefined place, remove this. 
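The extractelement combines above all follow from the element-wise nature of the operations involved: extracting a lane of a binary operator equals applying the operator to the extracted lanes, and extracting the lane that an insertelement just wrote yields the inserted scalar. A small plain C++ model, with std::array standing in for a vector value and illustrative names:

    #include <array>
    #include <cassert>

    using Vec4 = std::array<int, 4>;

    static Vec4 addVec(const Vec4 &a, const Vec4 &b) {
      Vec4 r = {0, 0, 0, 0};
      for (int i = 0; i < 4; ++i)
        r[i] = a[i] + b[i];
      return r;
    }

    int main() {
      Vec4 a = {1, 2, 3, 4};
      Vec4 b = {10, 20, 30, 40};

      // extractelement (add A, B), i == add (extractelement A, i),
      //                                     (extractelement B, i)
      for (int i = 0; i < 4; ++i)
        assert(addVec(a, b)[i] == a[i] + b[i]);

      // extractelement (insertelement V, S, idx), idx == S: reading back
      // the lane that was just written yields the inserted scalar.
      Vec4 v = a;
      v[2] = 99;           // insertelement a, 99, 2
      assert(v[2] == 99);  // extractelement ..., 2
      return 0;
    }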
if (isa<UndefValue>(ScalarOp) || isa<UndefValue>(IdxOp)) - ReplaceInstUsesWith(IE, VecOp); + replaceInstUsesWith(IE, VecOp); // If the inserted element was extracted from some other vector, and if the // indexes are constant, try to turn this into a shufflevector operation. @@ -571,15 +587,15 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue(); if (ExtractedIdx >= NumExtractVectorElts) // Out of range extract. - return ReplaceInstUsesWith(IE, VecOp); + return replaceInstUsesWith(IE, VecOp); if (InsertedIdx >= NumInsertVectorElts) // Out of range insert. - return ReplaceInstUsesWith(IE, UndefValue::get(IE.getType())); + return replaceInstUsesWith(IE, UndefValue::get(IE.getType())); // If we are extracting a value from a vector, then inserting it right // back into the same place, just use the input vector. if (EI->getOperand(0) == VecOp && ExtractedIdx == InsertedIdx) - return ReplaceInstUsesWith(IE, VecOp); + return replaceInstUsesWith(IE, VecOp); // If this insertelement isn't used by some other insertelement, turn it // (and any insertelements it points to), into one big shuffle. @@ -605,7 +621,7 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { if (V != &IE) - return ReplaceInstUsesWith(IE, V); + return replaceInstUsesWith(IE, V); return &IE; } @@ -910,7 +926,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // Undefined shuffle mask -> undefined value. if (isa<UndefValue>(SVI.getOperand(2))) - return ReplaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); + return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); unsigned VWidth = cast<VectorType>(SVI.getType())->getNumElements(); @@ -918,7 +934,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { if (V != &SVI) - return ReplaceInstUsesWith(SVI, V); + return replaceInstUsesWith(SVI, V); LHS = SVI.getOperand(0); RHS = SVI.getOperand(1); MadeChange = true; @@ -933,7 +949,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // shuffle(undef,undef,mask) -> undef. Value *Result = (VWidth == LHSWidth) ? LHS : UndefValue::get(SVI.getType()); - return ReplaceInstUsesWith(SVI, Result); + return replaceInstUsesWith(SVI, Result); } // Remap any references to RHS to use LHS. @@ -967,13 +983,13 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { recognizeIdentityMask(Mask, isLHSID, isRHSID); // Eliminate identity shuffles. - if (isLHSID) return ReplaceInstUsesWith(SVI, LHS); - if (isRHSID) return ReplaceInstUsesWith(SVI, RHS); + if (isLHSID) return replaceInstUsesWith(SVI, LHS); + if (isRHSID) return replaceInstUsesWith(SVI, RHS); } if (isa<UndefValue>(RHS) && CanEvaluateShuffled(LHS, Mask)) { Value *V = EvaluateInDifferentElementOrder(LHS, Mask); - return ReplaceInstUsesWith(SVI, V); + return replaceInstUsesWith(SVI, V); } // SROA generates shuffle+bitcast when the extracted sub-vector is bitcast to @@ -1060,7 +1076,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { NewBC, ConstantInt::get(Int32Ty, BegIdx), SVI.getName() + ".extract"); // The shufflevector isn't being replaced: the bitcast that used it // is. 
InstCombine will visit the newly-created instructions. - ReplaceInstUsesWith(*BC, Ext); + replaceInstUsesWith(*BC, Ext); MadeChange = true; } } @@ -1251,8 +1267,8 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // corresponding argument. bool isLHSID, isRHSID; recognizeIdentityMask(newMask, isLHSID, isRHSID); - if (isLHSID && VWidth == LHSOp0Width) return ReplaceInstUsesWith(SVI, newLHS); - if (isRHSID && VWidth == RHSOp0Width) return ReplaceInstUsesWith(SVI, newRHS); + if (isLHSID && VWidth == LHSOp0Width) return replaceInstUsesWith(SVI, newLHS); + if (isRHSID && VWidth == RHSOp0Width) return replaceInstUsesWith(SVI, newRHS); return MadeChange ? &SVI : nullptr; } diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index 903a0b5f5400..51c3262b5d14 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -39,7 +39,9 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" @@ -76,6 +78,10 @@ STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumFactor , "Number of factorizations"); STATISTIC(NumReassoc , "Number of reassociations"); +static cl::opt<bool> +EnableExpensiveCombines("expensive-combines", + cl::desc("Enable expensive instruction combines")); + Value *InstCombiner::EmitGEPOffset(User *GEP) { return llvm::EmitGEPOffset(Builder, DL, GEP); } @@ -120,33 +126,23 @@ bool InstCombiner::ShouldChangeType(Type *From, Type *To) const { // all other opcodes, the function conservatively returns false. static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) { OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I); - if (!OBO || !OBO->hasNoSignedWrap()) { + if (!OBO || !OBO->hasNoSignedWrap()) return false; - } // We reason about Add and Sub Only. Instruction::BinaryOps Opcode = I.getOpcode(); - if (Opcode != Instruction::Add && - Opcode != Instruction::Sub) { + if (Opcode != Instruction::Add && Opcode != Instruction::Sub) return false; - } - - ConstantInt *CB = dyn_cast<ConstantInt>(B); - ConstantInt *CC = dyn_cast<ConstantInt>(C); - if (!CB || !CC) { + const APInt *BVal, *CVal; + if (!match(B, m_APInt(BVal)) || !match(C, m_APInt(CVal))) return false; - } - const APInt &BVal = CB->getValue(); - const APInt &CVal = CC->getValue(); bool Overflow = false; - - if (Opcode == Instruction::Add) { - BVal.sadd_ov(CVal, Overflow); - } else { - BVal.ssub_ov(CVal, Overflow); - } + if (Opcode == Instruction::Add) + BVal->sadd_ov(*CVal, Overflow); + else + BVal->ssub_ov(*CVal, Overflow); return !Overflow; } @@ -166,6 +162,49 @@ static void ClearSubclassDataAfterReassociation(BinaryOperator &I) { I.setFastMathFlags(FMF); } +/// Combine constant operands of associative operations either before or after a +/// cast to eliminate one of the associative operations: +/// (op (cast (op X, C2)), C1) --> (cast (op X, op (C1, C2))) +/// (op (cast (op X, C2)), C1) --> (op (cast X), op (C1, C2)) +static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1) { + auto *Cast = dyn_cast<CastInst>(BinOp1->getOperand(0)); + if (!Cast || !Cast->hasOneUse()) + return false; + + // TODO: Enhance logic for other casts and remove this check. 
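simplifyAssocCastAssoc folds the two constants of a zext-separated pair of bitwise ops into one constant applied after the cast; restricting it to zext with xor/and/or keeps the fold trivially correct, since those ops act bit-wise on the low bits that the zext preserves and leave the zero high bits untouched. A worked instance of the pattern, assuming an i8 value zero-extended to i32 and xor constants, exhaustively checked:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int v = 0; v <= 255; ++v) {
        uint8_t x = static_cast<uint8_t>(v);
        // xor (zext (xor X, 0x0F)), 0xF0  -->  xor (zext X), 0xFF
        uint32_t twoOps =
            static_cast<uint32_t>(static_cast<uint8_t>(x ^ 0x0F)) ^ 0xF0u;
        uint32_t folded = static_cast<uint32_t>(x) ^ 0xFFu;
        assert(twoOps == folded);
      }
      return 0;
    }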
+ auto CastOpcode = Cast->getOpcode(); + if (CastOpcode != Instruction::ZExt) + return false; + + // TODO: Enhance logic for other BinOps and remove this check. + auto AssocOpcode = BinOp1->getOpcode(); + if (AssocOpcode != Instruction::Xor && AssocOpcode != Instruction::And && + AssocOpcode != Instruction::Or) + return false; + + auto *BinOp2 = dyn_cast<BinaryOperator>(Cast->getOperand(0)); + if (!BinOp2 || !BinOp2->hasOneUse() || BinOp2->getOpcode() != AssocOpcode) + return false; + + Constant *C1, *C2; + if (!match(BinOp1->getOperand(1), m_Constant(C1)) || + !match(BinOp2->getOperand(1), m_Constant(C2))) + return false; + + // TODO: This assumes a zext cast. + // Eg, if it was a trunc, we'd cast C1 to the source type because casting C2 + // to the destination type might lose bits. + + // Fold the constants together in the destination type: + // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC) + Type *DestTy = C1->getType(); + Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy); + Constant *FoldedC = ConstantExpr::get(AssocOpcode, C1, CastC2); + Cast->setOperand(0, BinOp2->getOperand(0)); + BinOp1->setOperand(1, FoldedC); + return true; +} + /// This performs a few simplifications for operators that are associative or /// commutative: /// @@ -253,6 +292,12 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { } if (I.isAssociative() && I.isCommutative()) { + if (simplifyAssocCastAssoc(&I)) { + Changed = true; + ++NumReassoc; + continue; + } + // Transform: "(A op B) op C" ==> "(C op A) op B" if "C op A" simplifies. if (Op0 && Op0->getOpcode() == Opcode) { Value *A = Op0->getOperand(0); @@ -919,10 +964,10 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { for (auto UI = PN->user_begin(), E = PN->user_end(); UI != E;) { Instruction *User = cast<Instruction>(*UI++); if (User == &I) continue; - ReplaceInstUsesWith(*User, NewPN); - EraseInstFromFunction(*User); + replaceInstUsesWith(*User, NewPN); + eraseInstFromFunction(*User); } - return ReplaceInstUsesWith(I, NewPN); + return replaceInstUsesWith(I, NewPN); } /// Given a pointer type and a constant offset, determine whether or not there @@ -1334,8 +1379,8 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - if (Value *V = SimplifyGEPInst(Ops, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(GEP, V); + if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops, DL, TLI, DT, AC)) + return replaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1349,19 +1394,18 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { for (User::op_iterator I = GEP.op_begin() + 1, E = GEP.op_end(); I != E; ++I, ++GTI) { // Skip indices into struct types. - SequentialType *SeqTy = dyn_cast<SequentialType>(*GTI); - if (!SeqTy) + if (isa<StructType>(*GTI)) continue; // Index type should have the same width as IntPtr Type *IndexTy = (*I)->getType(); Type *NewIndexType = IndexTy->isVectorTy() ? VectorType::get(IntPtrTy, IndexTy->getVectorNumElements()) : IntPtrTy; - + // If the element type has zero size then any index over it is equivalent // to an index of zero, so replace it with zero if it is not zero already. 
- if (SeqTy->getElementType()->isSized() && - DL.getTypeAllocSize(SeqTy->getElementType()) == 0) + Type *EltTy = GTI.getIndexedType(); + if (EltTy->isSized() && DL.getTypeAllocSize(EltTy) == 0) if (!isa<Constant>(*I) || !cast<Constant>(*I)->isNullValue()) { *I = Constant::getNullValue(NewIndexType); MadeChange = true; @@ -1393,7 +1437,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (Op1 == &GEP) return nullptr; - signed DI = -1; + int DI = -1; for (auto I = PN->op_begin()+1, E = PN->op_end(); I !=E; ++I) { GetElementPtrInst *Op2 = dyn_cast<GetElementPtrInst>(*I); @@ -1405,7 +1449,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { return nullptr; // Keep track of the type as we walk the GEP. - Type *CurTy = Op1->getOperand(0)->getType()->getScalarType(); + Type *CurTy = nullptr; for (unsigned J = 0, F = Op1->getNumOperands(); J != F; ++J) { if (Op1->getOperand(J)->getType() != Op2->getOperand(J)->getType()) @@ -1436,7 +1480,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Sink down a layer of the type for the next iteration. if (J > 0) { - if (CompositeType *CT = dyn_cast<CompositeType>(CurTy)) { + if (J == 1) { + CurTy = Op1->getSourceElementType(); + } else if (CompositeType *CT = dyn_cast<CompositeType>(CurTy)) { CurTy = CT->getTypeAtIndex(Op1->getOperand(J)); } else { CurTy = nullptr; @@ -1565,8 +1611,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL.getPointerSizeInBits(AS)) { - Type *PtrTy = GEP.getPointerOperandType(); - Type *Ty = PtrTy->getPointerElementType(); + Type *Ty = GEP.getSourceElementType(); uint64_t TyAllocSize = DL.getTypeAllocSize(Ty); bool Matched = false; @@ -1629,9 +1674,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // // This occurs when the program declares an array extern like "int X[];" if (HasZeroPointerIndex) { - PointerType *CPTy = cast<PointerType>(PtrOp->getType()); if (ArrayType *CATy = - dyn_cast<ArrayType>(CPTy->getElementType())) { + dyn_cast<ArrayType>(GEP.getSourceElementType())) { // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? if (CATy->getElementType() == StrippedPtrTy->getElementType()) { // -> GEP i8* X, ... 
@@ -1688,7 +1732,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // %t = getelementptr i32* bitcast ([2 x i32]* %str to i32*), i32 %V // into: %t1 = getelementptr [2 x i32]* %str, i32 0, i32 %V; bitcast Type *SrcElTy = StrippedPtrTy->getElementType(); - Type *ResElTy = PtrOp->getType()->getPointerElementType(); + Type *ResElTy = GEP.getSourceElementType(); if (SrcElTy->isArrayTy() && DL.getTypeAllocSize(SrcElTy->getArrayElementType()) == DL.getTypeAllocSize(ResElTy)) { @@ -1822,7 +1866,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (I != BCI) { I->takeName(BCI); BCI->getParent()->getInstList().insert(BCI->getIterator(), I); - ReplaceInstUsesWith(*BCI, I); + replaceInstUsesWith(*BCI, I); } return &GEP; } @@ -1844,7 +1888,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { : Builder->CreateGEP(nullptr, Operand, NewIndices); if (NGEP->getType() == GEP.getType()) - return ReplaceInstUsesWith(GEP, NGEP); + return replaceInstUsesWith(GEP, NGEP); NGEP->takeName(&GEP); if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) @@ -1857,6 +1901,20 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { return nullptr; } +static bool isNeverEqualToUnescapedAlloc(Value *V, const TargetLibraryInfo *TLI, + Instruction *AI) { + if (isa<ConstantPointerNull>(V)) + return true; + if (auto *LI = dyn_cast<LoadInst>(V)) + return isa<GlobalVariable>(LI->getPointerOperand()); + // Two distinct allocations will never be equal. + // We rely on LookThroughBitCast in isAllocLikeFn being false, since looking + // through bitcasts of V can cause + // the result statement below to be true, even when AI and V (ex: + // i8* ->i32* ->i8* of AI) are the same allocations. + return isAllocLikeFn(V, TLI) && V != AI; +} + static bool isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakVH> &Users, const TargetLibraryInfo *TLI) { @@ -1881,7 +1939,12 @@ isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakVH> &Users, case Instruction::ICmp: { ICmpInst *ICI = cast<ICmpInst>(I); // We can fold eq/ne comparisons with null to false/true, respectively. - if (!ICI->isEquality() || !isa<ConstantPointerNull>(ICI->getOperand(1))) + // We also fold comparisons in some conditions provided the alloc has + // not escaped (see isNeverEqualToUnescapedAlloc). + if (!ICI->isEquality()) + return false; + unsigned OtherIndex = (ICI->getOperand(0) == PI) ? 1 : 0; + if (!isNeverEqualToUnescapedAlloc(ICI->getOperand(OtherIndex), TLI, AI)) return false; Users.emplace_back(I); continue; @@ -1941,23 +2004,40 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { SmallVector<WeakVH, 64> Users; if (isAllocSiteRemovable(&MI, Users, TLI)) { for (unsigned i = 0, e = Users.size(); i != e; ++i) { - Instruction *I = cast_or_null<Instruction>(&*Users[i]); - if (!I) continue; + // Lowering all @llvm.objectsize calls first because they may + // use a bitcast/GEP of the alloca we are removing. + if (!Users[i]) + continue; + + Instruction *I = cast<Instruction>(&*Users[i]); + + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + if (II->getIntrinsicID() == Intrinsic::objectsize) { + uint64_t Size; + if (!getObjectSize(II->getArgOperand(0), Size, DL, TLI)) { + ConstantInt *CI = cast<ConstantInt>(II->getArgOperand(1)); + Size = CI->isZero() ? -1ULL : 0; + } + replaceInstUsesWith(*I, ConstantInt::get(I->getType(), Size)); + eraseInstFromFunction(*I); + Users[i] = nullptr; // Skip examining in the next loop. 
+ } + } + } + for (unsigned i = 0, e = Users.size(); i != e; ++i) { + if (!Users[i]) + continue; + + Instruction *I = cast<Instruction>(&*Users[i]); if (ICmpInst *C = dyn_cast<ICmpInst>(I)) { - ReplaceInstUsesWith(*C, + replaceInstUsesWith(*C, ConstantInt::get(Type::getInt1Ty(C->getContext()), C->isFalseWhenEqual())); } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I)) { - ReplaceInstUsesWith(*I, UndefValue::get(I->getType())); - } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { - if (II->getIntrinsicID() == Intrinsic::objectsize) { - ConstantInt *CI = cast<ConstantInt>(II->getArgOperand(1)); - uint64_t DontKnow = CI->isZero() ? -1ULL : 0; - ReplaceInstUsesWith(*I, ConstantInt::get(I->getType(), DontKnow)); - } + replaceInstUsesWith(*I, UndefValue::get(I->getType())); } - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); } if (InvokeInst *II = dyn_cast<InvokeInst>(&MI)) { @@ -1967,7 +2047,7 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { InvokeInst::Create(F, II->getNormalDest(), II->getUnwindDest(), None, "", II->getParent()); } - return EraseInstFromFunction(MI); + return eraseInstFromFunction(MI); } return nullptr; } @@ -2038,13 +2118,13 @@ Instruction *InstCombiner::visitFree(CallInst &FI) { // Insert a new store to null because we cannot modify the CFG here. Builder->CreateStore(ConstantInt::getTrue(FI.getContext()), UndefValue::get(Type::getInt1PtrTy(FI.getContext()))); - return EraseInstFromFunction(FI); + return eraseInstFromFunction(FI); } // If we have 'free null' delete the instruction. This can happen in stl code // when lots of inlining happens. if (isa<ConstantPointerNull>(Op)) - return EraseInstFromFunction(FI); + return eraseInstFromFunction(FI); // If we optimize for code size, try to move the call to free before the null // test so that simplify cfg can remove the empty block and dead code @@ -2145,6 +2225,7 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { unsigned LeadingKnownOnes = KnownOne.countLeadingOnes(); // Compute the number of leading bits we can ignore. + // TODO: A better way to determine this would use ComputeNumSignBits(). for (auto &C : SI.cases()) { LeadingKnownZeros = std::min( LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros()); @@ -2154,17 +2235,15 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { unsigned NewWidth = BitWidth - std::max(LeadingKnownZeros, LeadingKnownOnes); - // Truncate the condition operand if the new type is equal to or larger than - // the largest legal integer type. We need to be conservative here since - // x86 generates redundant zero-extension instructions if the operand is - // truncated to i8 or i16. + // Shrink the condition operand if the new type is smaller than the old type. + // This may produce a non-standard type for the switch, but that's ok because + // the backend should extend back to a legal type for the target. 
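The shrink is sound because the comparisons a switch performs only depend on bits that can actually differ between the condition and its case values; once the leading bits are known equal on both sides (all zero or all one), they can be dropped and the switch done at the narrower width. A small model with a 32-bit condition whose upper 24 bits are known zero, compared after an assumed truncation to 8 bits; the dispatch functions are illustrative only:

    #include <cassert>
    #include <cstdint>

    static int dispatchWide(uint32_t x) {
      switch (x) {            // i32 condition; every case fits in 8 bits
      case 3:   return 30;
      case 17:  return 170;
      default:  return -1;
      }
    }

    static int dispatchNarrow(uint32_t x) {
      switch (static_cast<uint8_t>(x)) {  // truncated condition and cases
      case 3:   return 30;
      case 17:  return 170;
      default:  return -1;
      }
    }

    int main() {
      // Equivalent only because the upper 24 bits of x are known zero here,
      // mirroring the KnownZero/KnownOne width computation in
      // visitSwitchInst.
      const uint32_t vals[] = {0u, 3u, 17u, 200u};
      for (uint32_t x : vals)
        assert(dispatchWide(x) == dispatchNarrow(x));
      return 0;
    }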
bool TruncCond = false; - if (NewWidth > 0 && BitWidth > NewWidth && - NewWidth >= DL.getLargestLegalIntTypeSize()) { + if (NewWidth > 0 && NewWidth < BitWidth) { TruncCond = true; IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); Builder->SetInsertPoint(&SI); - Value *NewCond = Builder->CreateTrunc(SI.getCondition(), Ty, "trunc"); + Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc"); SI.setCondition(NewCond); for (auto &C : SI.cases()) @@ -2172,28 +2251,27 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { SI.getContext(), C.getCaseValue()->getValue().trunc(NewWidth))); } - if (Instruction *I = dyn_cast<Instruction>(Cond)) { - if (I->getOpcode() == Instruction::Add) - if (ConstantInt *AddRHS = dyn_cast<ConstantInt>(I->getOperand(1))) { - // change 'switch (X+4) case 1:' into 'switch (X) case -3' - // Skip the first item since that's the default case. - for (SwitchInst::CaseIt i = SI.case_begin(), e = SI.case_end(); - i != e; ++i) { - ConstantInt* CaseVal = i.getCaseValue(); - Constant *LHS = CaseVal; - if (TruncCond) - LHS = LeadingKnownZeros - ? ConstantExpr::getZExt(CaseVal, Cond->getType()) - : ConstantExpr::getSExt(CaseVal, Cond->getType()); - Constant* NewCaseVal = ConstantExpr::getSub(LHS, AddRHS); - assert(isa<ConstantInt>(NewCaseVal) && - "Result of expression should be constant"); - i.setValue(cast<ConstantInt>(NewCaseVal)); - } - SI.setCondition(I->getOperand(0)); - Worklist.Add(I); - return &SI; + ConstantInt *AddRHS = nullptr; + if (match(Cond, m_Add(m_Value(), m_ConstantInt(AddRHS)))) { + Instruction *I = cast<Instruction>(Cond); + // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. + for (SwitchInst::CaseIt i = SI.case_begin(), e = SI.case_end(); i != e; + ++i) { + ConstantInt *CaseVal = i.getCaseValue(); + Constant *LHS = CaseVal; + if (TruncCond) { + LHS = LeadingKnownZeros + ? ConstantExpr::getZExt(CaseVal, Cond->getType()) + : ConstantExpr::getSExt(CaseVal, Cond->getType()); } + Constant *NewCaseVal = ConstantExpr::getSub(LHS, AddRHS); + assert(isa<ConstantInt>(NewCaseVal) && + "Result of expression should be constant"); + i.setValue(cast<ConstantInt>(NewCaseVal)); + } + SI.setCondition(I->getOperand(0)); + Worklist.Add(I); + return &SI; } return TruncCond ? &SI : nullptr; @@ -2203,11 +2281,11 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { Value *Agg = EV.getAggregateOperand(); if (!EV.hasIndices()) - return ReplaceInstUsesWith(EV, Agg); + return replaceInstUsesWith(EV, Agg); if (Value *V = SimplifyExtractValueInst(Agg, EV.getIndices(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(EV, V); + return replaceInstUsesWith(EV, V); if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) { // We're extracting from an insertvalue instruction, compare the indices @@ -2233,7 +2311,7 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { // %B = insertvalue { i32, { i32 } } %A, i32 42, 1, 0 // %C = extractvalue { i32, { i32 } } %B, 1, 0 // with "i32 42" - return ReplaceInstUsesWith(EV, IV->getInsertedValueOperand()); + return replaceInstUsesWith(EV, IV->getInsertedValueOperand()); if (exti == exte) { // The extract list is a prefix of the insert list. i.e. replace // %I = insertvalue { i32, { i32 } } %A, i32 42, 1, 0 @@ -2273,8 +2351,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { case Intrinsic::sadd_with_overflow: if (*EV.idx_begin() == 0) { // Normal result. 
Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - ReplaceInstUsesWith(*II, UndefValue::get(II->getType())); - EraseInstFromFunction(*II); + replaceInstUsesWith(*II, UndefValue::get(II->getType())); + eraseInstFromFunction(*II); return BinaryOperator::CreateAdd(LHS, RHS); } @@ -2290,8 +2368,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { case Intrinsic::ssub_with_overflow: if (*EV.idx_begin() == 0) { // Normal result. Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - ReplaceInstUsesWith(*II, UndefValue::get(II->getType())); - EraseInstFromFunction(*II); + replaceInstUsesWith(*II, UndefValue::get(II->getType())); + eraseInstFromFunction(*II); return BinaryOperator::CreateSub(LHS, RHS); } break; @@ -2299,8 +2377,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { case Intrinsic::smul_with_overflow: if (*EV.idx_begin() == 0) { // Normal result. Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - ReplaceInstUsesWith(*II, UndefValue::get(II->getType())); - EraseInstFromFunction(*II); + replaceInstUsesWith(*II, UndefValue::get(II->getType())); + eraseInstFromFunction(*II); return BinaryOperator::CreateMul(LHS, RHS); } break; @@ -2330,8 +2408,8 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { Value *GEP = Builder->CreateInBoundsGEP(L->getType(), L->getPointerOperand(), Indices); // Returning the load directly will cause the main loop to insert it in - // the wrong spot, so use ReplaceInstUsesWith(). - return ReplaceInstUsesWith(EV, Builder->CreateLoad(GEP)); + // the wrong spot, so use replaceInstUsesWith(). + return replaceInstUsesWith(EV, Builder->CreateLoad(GEP)); } // We could simplify extracts from other values. Note that nested extracts may // already be simplified implicitly by the above: extract (extract (insert) ) @@ -2348,8 +2426,10 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) { switch (Personality) { case EHPersonality::GNU_C: - // The GCC C EH personality only exists to support cleanups, so it's not - // clear what the semantics of catch clauses are. + case EHPersonality::GNU_C_SjLj: + case EHPersonality::Rust: + // The GCC C EH and Rust personality only exists to support cleanups, so + // it's not clear what the semantics of catch clauses are. return false; case EHPersonality::Unknown: return false; @@ -2358,6 +2438,7 @@ static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) { // match foreign exceptions (or didn't, before gcc-4.7). return false; case EHPersonality::GNU_CXX: + case EHPersonality::GNU_CXX_SjLj: case EHPersonality::GNU_ObjC: case EHPersonality::MSVC_X86SEH: case EHPersonality::MSVC_Win64SEH: @@ -2701,12 +2782,15 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { &DestBlock->getParent()->getEntryBlock()) return false; + // Do not sink into catchswitch blocks. + if (isa<CatchSwitchInst>(DestBlock->getTerminator())) + return false; + // Do not sink convergent call instructions. if (auto *CI = dyn_cast<CallInst>(I)) { if (CI->isConvergent()) return false; } - // We can only sink load instructions if there is nothing between the load and // the end of block that could change the value. if (I->mayReadFromMemory()) { @@ -2731,7 +2815,7 @@ bool InstCombiner::run() { // Check to see if we can DCE the instruction. 
if (isInstructionTriviallyDead(I, TLI)) { DEBUG(dbgs() << "IC: DCE: " << *I << '\n'); - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); ++NumDeadInst; MadeIRChange = true; continue; @@ -2744,17 +2828,17 @@ bool InstCombiner::run() { DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I << '\n'); // Add operands to the worklist. - ReplaceInstUsesWith(*I, C); + replaceInstUsesWith(*I, C); ++NumConstProp; - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); MadeIRChange = true; continue; } } - // In general, it is possible for computeKnownBits to determine all bits in a - // value even when the operands are not all constants. - if (!I->use_empty() && I->getType()->isIntegerTy()) { + // In general, it is possible for computeKnownBits to determine all bits in + // a value even when the operands are not all constants. + if (ExpensiveCombines && !I->use_empty() && I->getType()->isIntegerTy()) { unsigned BitWidth = I->getType()->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); @@ -2765,9 +2849,9 @@ bool InstCombiner::run() { " from: " << *I << '\n'); // Add operands to the worklist. - ReplaceInstUsesWith(*I, C); + replaceInstUsesWith(*I, C); ++NumConstProp; - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); MadeIRChange = true; continue; } @@ -2800,6 +2884,7 @@ bool InstCombiner::run() { if (UserIsSuccessor && UserParent->getSinglePredecessor()) { // Okay, the CFG is simple enough, try to sink this instruction. if (TryToSinkInstruction(I, UserParent)) { + DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); MadeIRChange = true; // We'll add uses of the sunk instruction below, but since sinking // can expose opportunities for it's *operands* add them to the @@ -2852,7 +2937,7 @@ bool InstCombiner::run() { InstParent->getInstList().insert(InsertPos, Result); - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); } else { #ifndef NDEBUG DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n' @@ -2862,7 +2947,7 @@ bool InstCombiner::run() { // If the instruction was modified, it's possible that it is now dead. // if so, remove it. if (isInstructionTriviallyDead(I, TLI)) { - EraseInstFromFunction(*I); + eraseInstFromFunction(*I); } else { Worklist.Add(I); Worklist.AddUsersToWorkList(*I); @@ -3002,35 +3087,20 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // Do a depth-first traversal of the function, populate the worklist with // the reachable instructions. Ignore blocks that are not reachable. Keep // track of which blocks we visit. - SmallPtrSet<BasicBlock *, 64> Visited; + SmallPtrSet<BasicBlock *, 32> Visited; MadeIRChange |= AddReachableCodeToWorklist(&F.front(), DL, Visited, ICWorklist, TLI); // Do a quick scan over the function. If we find any blocks that are // unreachable, remove any instructions inside of them. This prevents // the instcombine code from having to deal with some bad special cases. - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - if (Visited.count(&*BB)) + for (BasicBlock &BB : F) { + if (Visited.count(&BB)) continue; - // Delete the instructions backwards, as it has a reduced likelihood of - // having to update as many def-use and use-def chains. - Instruction *EndInst = BB->getTerminator(); // Last not to be deleted. - while (EndInst != BB->begin()) { - // Delete the next to last instruction. 
- Instruction *Inst = &*--EndInst->getIterator(); - if (!Inst->use_empty() && !Inst->getType()->isTokenTy()) - Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); - if (Inst->isEHPad() || Inst->getType()->isTokenTy()) { - EndInst = Inst; - continue; - } - if (!isa<DbgInfoIntrinsic>(Inst)) { - ++NumDeadInst; - MadeIRChange = true; - } - Inst->eraseFromParent(); - } + unsigned NumDeadInstInBB = removeAllNonTerminatorAndEHPadInstructions(&BB); + MadeIRChange |= NumDeadInstInBB > 0; + NumDeadInst += NumDeadInstInBB; } return MadeIRChange; @@ -3040,12 +3110,14 @@ static bool combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, + bool ExpensiveCombines = true, LoopInfo *LI = nullptr) { auto &DL = F.getParent()->getDataLayout(); + ExpensiveCombines |= EnableExpensiveCombines; /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. - IRBuilder<true, TargetFolder, InstCombineIRInserter> Builder( + IRBuilder<TargetFolder, InstCombineIRInserter> Builder( F.getContext(), TargetFolder(DL), InstCombineIRInserter(Worklist, &AC)); // Lower dbg.declare intrinsics otherwise their value may be clobbered @@ -3059,14 +3131,11 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " << F.getName() << "\n"); - bool Changed = false; - if (prepareICWorklistFromFunction(F, DL, &TLI, Worklist)) - Changed = true; + bool Changed = prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombiner IC(Worklist, &Builder, F.optForMinSize(), + InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines, AA, &AC, &TLI, &DT, DL, LI); - if (IC.run()) - Changed = true; + Changed |= IC.run(); if (!Changed) break; @@ -3076,45 +3145,26 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, } PreservedAnalyses InstCombinePass::run(Function &F, - AnalysisManager<Function> *AM) { - auto &AC = AM->getResult<AssumptionAnalysis>(F); - auto &DT = AM->getResult<DominatorTreeAnalysis>(F); - auto &TLI = AM->getResult<TargetLibraryAnalysis>(F); + AnalysisManager<Function> &AM) { + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - auto *LI = AM->getCachedResult<LoopAnalysis>(F); + auto *LI = AM.getCachedResult<LoopAnalysis>(F); // FIXME: The AliasAnalysis is not yet supported in the new pass manager - if (!combineInstructionsOverFunction(F, Worklist, nullptr, AC, TLI, DT, LI)) + if (!combineInstructionsOverFunction(F, Worklist, nullptr, AC, TLI, DT, + ExpensiveCombines, LI)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); // Mark all the analyses that instcombine updates as preserved. - // FIXME: Need a way to preserve CFG analyses here! + // FIXME: This should also 'preserve the CFG'. PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); return PA; } -namespace { -/// \brief The legacy pass manager's instcombine pass. -/// -/// This is a basic whole-function wrapper around the instcombine utility. It -/// will try to combine all instructions in the function. 
-class InstructionCombiningPass : public FunctionPass { - InstCombineWorklist Worklist; - -public: - static char ID; // Pass identification, replacement for typeid - - InstructionCombiningPass() : FunctionPass(ID) { - initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; -}; -} - void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesCFG(); AU.addRequired<AAResultsWrapperPass>(); @@ -3122,11 +3172,13 @@ void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } bool InstructionCombiningPass::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; // Required analyses. @@ -3139,7 +3191,8 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; - return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, LI); + return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, + ExpensiveCombines, LI); } char InstructionCombiningPass::ID = 0; @@ -3162,6 +3215,6 @@ void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { initializeInstructionCombiningPassPass(*unwrap(R)); } -FunctionPass *llvm::createInstructionCombiningPass() { - return new InstructionCombiningPass(); +FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines) { + return new InstructionCombiningPass(ExpensiveCombines); } diff --git a/lib/Transforms/InstCombine/Makefile b/lib/Transforms/InstCombine/Makefile deleted file mode 100644 index 0c488e78b6d9..000000000000 --- a/lib/Transforms/InstCombine/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Transforms/InstCombine/Makefile -----------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. 
-LIBRARYNAME = LLVMInstCombine -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common - diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp index a9df5e5898ae..43d1b377f858 100644 --- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -13,14 +13,11 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" @@ -48,6 +45,7 @@ #include "llvm/Support/Endian.h" #include "llvm/Support/SwapByteOrder.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/ASanStackFrameLayout.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -65,17 +63,23 @@ using namespace llvm; static const uint64_t kDefaultShadowScale = 3; static const uint64_t kDefaultShadowOffset32 = 1ULL << 29; -static const uint64_t kIOSShadowOffset32 = 1ULL << 30; static const uint64_t kDefaultShadowOffset64 = 1ULL << 44; +static const uint64_t kIOSShadowOffset32 = 1ULL << 30; +static const uint64_t kIOSShadowOffset64 = 0x120200000; +static const uint64_t kIOSSimShadowOffset32 = 1ULL << 30; +static const uint64_t kIOSSimShadowOffset64 = kDefaultShadowOffset64; static const uint64_t kSmallX86_64ShadowOffset = 0x7FFF8000; // < 2G. static const uint64_t kLinuxKasan_ShadowOffset64 = 0xdffffc0000000000; static const uint64_t kPPC64_ShadowOffset64 = 1ULL << 41; +static const uint64_t kSystemZ_ShadowOffset64 = 1ULL << 52; static const uint64_t kMIPS32_ShadowOffset32 = 0x0aaa0000; static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 37; static const uint64_t kAArch64_ShadowOffset64 = 1ULL << 36; static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; static const uint64_t kWindowsShadowOffset32 = 3ULL << 28; +// TODO(wwchrome): Experimental for asan Win64, may change. +static const uint64_t kWindowsShadowOffset64 = 0x1ULL << 45; // 32TB. 
static const size_t kMinStackMallocSize = 1 << 6; // 64B static const size_t kMaxStackMallocSize = 1 << 16; // 64K @@ -89,11 +93,15 @@ static const char *const kAsanReportErrorTemplate = "__asan_report_"; static const char *const kAsanRegisterGlobalsName = "__asan_register_globals"; static const char *const kAsanUnregisterGlobalsName = "__asan_unregister_globals"; +static const char *const kAsanRegisterImageGlobalsName = + "__asan_register_image_globals"; +static const char *const kAsanUnregisterImageGlobalsName = + "__asan_unregister_image_globals"; static const char *const kAsanPoisonGlobalsName = "__asan_before_dynamic_init"; static const char *const kAsanUnpoisonGlobalsName = "__asan_after_dynamic_init"; static const char *const kAsanInitName = "__asan_init"; static const char *const kAsanVersionCheckName = - "__asan_version_mismatch_check_v6"; + "__asan_version_mismatch_check_v8"; static const char *const kAsanPtrCmp = "__sanitizer_ptr_cmp"; static const char *const kAsanPtrSub = "__sanitizer_ptr_sub"; static const char *const kAsanHandleNoReturnName = "__asan_handle_no_return"; @@ -101,13 +109,16 @@ static const int kMaxAsanStackMallocSizeClass = 10; static const char *const kAsanStackMallocNameTemplate = "__asan_stack_malloc_"; static const char *const kAsanStackFreeNameTemplate = "__asan_stack_free_"; static const char *const kAsanGenPrefix = "__asan_gen_"; +static const char *const kODRGenPrefix = "__odr_asan_gen_"; static const char *const kSanCovGenPrefix = "__sancov_gen_"; static const char *const kAsanPoisonStackMemoryName = "__asan_poison_stack_memory"; static const char *const kAsanUnpoisonStackMemoryName = "__asan_unpoison_stack_memory"; +static const char *const kAsanGlobalsRegisteredFlagName = + "__asan_globals_registered"; -static const char *const kAsanOptionDetectUAR = +static const char *const kAsanOptionDetectUseAfterReturn = "__asan_option_detect_stack_use_after_return"; static const char *const kAsanAllocaPoison = "__asan_alloca_poison"; @@ -154,8 +165,11 @@ static cl::opt<int> ClMaxInsnsToInstrumentPerBB( static cl::opt<bool> ClStack("asan-stack", cl::desc("Handle stack memory"), cl::Hidden, cl::init(true)); static cl::opt<bool> ClUseAfterReturn("asan-use-after-return", - cl::desc("Check return-after-free"), + cl::desc("Check stack-use-after-return"), cl::Hidden, cl::init(true)); +static cl::opt<bool> ClUseAfterScope("asan-use-after-scope", + cl::desc("Check stack-use-after-scope"), + cl::Hidden, cl::init(false)); // This flag may need to be replaced with -f[no]asan-globals. static cl::opt<bool> ClGlobals("asan-globals", cl::desc("Handle global objects"), cl::Hidden, @@ -192,10 +206,14 @@ static cl::opt<bool> ClSkipPromotableAllocas( // These flags allow to change the shadow mapping. // The shadow mapping looks like -// Shadow = (Mem >> scale) + (1 << offset_log) +// Shadow = (Mem >> scale) + offset static cl::opt<int> ClMappingScale("asan-mapping-scale", cl::desc("scale of asan shadow mapping"), cl::Hidden, cl::init(0)); +static cl::opt<unsigned long long> ClMappingOffset( + "asan-mapping-offset", + cl::desc("offset of asan shadow mapping [EXPERIMENTAL]"), cl::Hidden, + cl::init(0)); // Optimization flags. Not user visible, used mostly for testing // and benchmarking the tool. 
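For readers following the mapping change in the hunk above: the comment now states the mapping as Shadow = (Mem >> scale) + offset, and the new -asan-mapping-offset flag allows that offset to be overridden experimentally. Below is a minimal C++ sketch of that computation, using constant values copied from this file's hunks; the helper name, the main() driver, and the sample address are illustrative assumptions, not part of the imported patch or of ASan's API.

#include <cstdint>
#include <cstdio>

// Constants as defined earlier in AddressSanitizer.cpp in this diff.
static const uint64_t kDefaultShadowScale = 3;               // 8 app bytes map to 1 shadow byte
static const uint64_t kSmallX86_64ShadowOffset = 0x7FFF8000; // Linux x86-64, non-kasan

// Shadow = (Mem >> Scale) + Offset. Per the comment later in this diff, the
// pass can emit an OR instead of the ADD when the offset is a power of two
// and the target allows it (Mapping.OrShadowOffset).
static uint64_t memToShadowSketch(uint64_t Mem, uint64_t Scale, uint64_t Offset) {
  return (Mem >> Scale) + Offset;
}

int main() {
  uint64_t Addr = 0x7fffffffd000ULL; // hypothetical application address
  std::printf("shadow byte at 0x%llx\n",
              (unsigned long long)memToShadowSketch(Addr, kDefaultShadowScale,
                                                    kSmallX86_64ShadowOffset));
  return 0;
}

With scale 3, each shadow byte describes an 8-byte application granule; the per-target offsets added in this change (for example kSystemZ_ShadowOffset64 and kWindowsShadowOffset64) plug into the same formula unchanged.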
@@ -211,11 +229,6 @@ static cl::opt<bool> ClOptStack( "asan-opt-stack", cl::desc("Don't instrument scalar stack variables"), cl::Hidden, cl::init(false)); -static cl::opt<bool> ClCheckLifetime( - "asan-check-lifetime", - cl::desc("Use llvm.lifetime intrinsics to insert extra checks"), cl::Hidden, - cl::init(false)); - static cl::opt<bool> ClDynamicAllocaStack( "asan-stack-dynamic-alloca", cl::desc("Use dynamic alloca to represent stack variables"), cl::Hidden, @@ -226,6 +239,19 @@ static cl::opt<uint32_t> ClForceExperiment( cl::desc("Force optimization experiment (for testing)"), cl::Hidden, cl::init(0)); +static cl::opt<bool> + ClUsePrivateAliasForGlobals("asan-use-private-alias", + cl::desc("Use private aliases for global" + " variables"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> + ClUseMachOGlobalsSection("asan-globals-live-support", + cl::desc("Use linker features to support dead " + "code stripping of globals " + "(Mach-O only)"), + cl::Hidden, cl::init(false)); + // Debug flags. static cl::opt<int> ClDebug("asan-debug", cl::desc("debug"), cl::Hidden, cl::init(0)); @@ -334,11 +360,13 @@ struct ShadowMapping { static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, bool IsKasan) { bool IsAndroid = TargetTriple.isAndroid(); - bool IsIOS = TargetTriple.isiOS(); + bool IsIOS = TargetTriple.isiOS() || TargetTriple.isWatchOS(); bool IsFreeBSD = TargetTriple.isOSFreeBSD(); bool IsLinux = TargetTriple.isOSLinux(); bool IsPPC64 = TargetTriple.getArch() == llvm::Triple::ppc64 || TargetTriple.getArch() == llvm::Triple::ppc64le; + bool IsSystemZ = TargetTriple.getArch() == llvm::Triple::systemz; + bool IsX86 = TargetTriple.getArch() == llvm::Triple::x86; bool IsX86_64 = TargetTriple.getArch() == llvm::Triple::x86_64; bool IsMIPS32 = TargetTriple.getArch() == llvm::Triple::mips || TargetTriple.getArch() == llvm::Triple::mipsel; @@ -359,7 +387,8 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, else if (IsFreeBSD) Mapping.Offset = kFreeBSD_ShadowOffset32; else if (IsIOS) - Mapping.Offset = kIOSShadowOffset32; + // If we're targeting iOS and x86, the binary is built for iOS simulator. + Mapping.Offset = IsX86 ? kIOSSimShadowOffset32 : kIOSShadowOffset32; else if (IsWindows) Mapping.Offset = kWindowsShadowOffset32; else @@ -367,6 +396,8 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, } else { // LongSize == 64 if (IsPPC64) Mapping.Offset = kPPC64_ShadowOffset64; + else if (IsSystemZ) + Mapping.Offset = kSystemZ_ShadowOffset64; else if (IsFreeBSD) Mapping.Offset = kFreeBSD_ShadowOffset64; else if (IsLinux && IsX86_64) { @@ -374,8 +405,13 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, Mapping.Offset = kLinuxKasan_ShadowOffset64; else Mapping.Offset = kSmallX86_64ShadowOffset; + } else if (IsWindows && IsX86_64) { + Mapping.Offset = kWindowsShadowOffset64; } else if (IsMIPS64) Mapping.Offset = kMIPS64_ShadowOffset64; + else if (IsIOS) + // If we're targeting iOS and x86, the binary is built for iOS simulator. + Mapping.Offset = IsX86_64 ? 
kIOSSimShadowOffset64 : kIOSShadowOffset64; else if (IsAArch64) Mapping.Offset = kAArch64_ShadowOffset64; else @@ -383,14 +419,20 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, } Mapping.Scale = kDefaultShadowScale; - if (ClMappingScale) { + if (ClMappingScale.getNumOccurrences() > 0) { Mapping.Scale = ClMappingScale; } + if (ClMappingOffset.getNumOccurrences() > 0) { + Mapping.Offset = ClMappingOffset; + } + // OR-ing shadow offset if more efficient (at least on x86) if the offset // is a power of two, but on ppc64 we have to use add since the shadow - // offset is not necessary 1/8-th of the address space. - Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 + // offset is not necessary 1/8-th of the address space. On SystemZ, + // we could OR the constant in a single instruction, but it's more + // efficient to load it once and use indexed addressing. + Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ && !(Mapping.Offset & (Mapping.Offset - 1)); return Mapping; @@ -404,9 +446,11 @@ static size_t RedzoneSizeForScale(int MappingScale) { /// AddressSanitizer: instrument the code in module to find memory bugs. struct AddressSanitizer : public FunctionPass { - explicit AddressSanitizer(bool CompileKernel = false, bool Recover = false) + explicit AddressSanitizer(bool CompileKernel = false, bool Recover = false, + bool UseAfterScope = false) : FunctionPass(ID), CompileKernel(CompileKernel || ClEnableKasan), - Recover(Recover || ClRecover) { + Recover(Recover || ClRecover), + UseAfterScope(UseAfterScope || ClUseAfterScope) { initializeAddressSanitizerPass(*PassRegistry::getPassRegistry()); } const char *getPassName() const override { @@ -417,19 +461,20 @@ struct AddressSanitizer : public FunctionPass { AU.addRequired<TargetLibraryInfoWrapperPass>(); } uint64_t getAllocaSizeInBytes(AllocaInst *AI) const { + uint64_t ArraySize = 1; + if (AI->isArrayAllocation()) { + ConstantInt *CI = dyn_cast<ConstantInt>(AI->getArraySize()); + assert(CI && "non-constant array size"); + ArraySize = CI->getZExtValue(); + } Type *Ty = AI->getAllocatedType(); uint64_t SizeInBytes = AI->getModule()->getDataLayout().getTypeAllocSize(Ty); - return SizeInBytes; + return SizeInBytes * ArraySize; } /// Check if we want (and can) handle this alloca. bool isInterestingAlloca(AllocaInst &AI); - // Check if we have dynamic alloca. - bool isDynamicAlloca(AllocaInst &AI) const { - return AI.isArrayAllocation() || !AI.isStaticAlloca(); - } - /// If it is an interesting memory access, return the PointerOperand /// and set IsWrite/Alignment. Otherwise return nullptr. 
Value *isInterestingMemoryAccess(Instruction *I, bool *IsWrite, @@ -483,6 +528,7 @@ struct AddressSanitizer : public FunctionPass { int LongSize; bool CompileKernel; bool Recover; + bool UseAfterScope; Type *IntptrTy; ShadowMapping Mapping; DominatorTree *DT; @@ -519,6 +565,7 @@ class AddressSanitizerModule : public ModulePass { bool InstrumentGlobals(IRBuilder<> &IRB, Module &M); bool ShouldInstrumentGlobal(GlobalVariable *G); + bool ShouldUseMachOGlobalsSection() const; void poisonOneInitializer(Function &GlobalInit, GlobalValue *ModuleName); void createInitializerPoisonCalls(Module &M, GlobalValue *ModuleName); size_t MinRedzoneSizeForGlobal() const { @@ -536,6 +583,8 @@ class AddressSanitizerModule : public ModulePass { Function *AsanUnpoisonGlobals; Function *AsanRegisterGlobals; Function *AsanUnregisterGlobals; + Function *AsanRegisterImageGlobals; + Function *AsanUnregisterImageGlobals; }; // Stack poisoning does not play well with exception handling. @@ -680,7 +729,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { } StackAlignment = std::max(StackAlignment, AI.getAlignment()); - if (ASan.isDynamicAlloca(AI)) + if (!AI.isStaticAlloca()) DynamicAllocaVec.push_back(&AI); else AllocaVec.push_back(&AI); @@ -692,7 +741,8 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { Intrinsic::ID ID = II.getIntrinsicID(); if (ID == Intrinsic::stackrestore) StackRestoreVec.push_back(&II); if (ID == Intrinsic::localescape) LocalEscapeCall = &II; - if (!ClCheckLifetime) return; + if (!ASan.UseAfterScope) + return; if (ID != Intrinsic::lifetime_start && ID != Intrinsic::lifetime_end) return; // Found lifetime intrinsic, add ASan instrumentation if necessary. @@ -707,7 +757,8 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { return; // Find alloca instruction that corresponds to llvm.lifetime argument. AllocaInst *AI = findAllocaForValue(II.getArgOperand(1)); - if (!AI) return; + if (!AI || !ASan.isInterestingAlloca(*AI)) + return; bool DoPoison = (ID == Intrinsic::lifetime_end); AllocaPoisonCall APC = {&II, AI, SizeValue, DoPoison}; AllocaPoisonCallVec.push_back(APC); @@ -760,9 +811,10 @@ INITIALIZE_PASS_END( "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, false) FunctionPass *llvm::createAddressSanitizerFunctionPass(bool CompileKernel, - bool Recover) { + bool Recover, + bool UseAfterScope) { assert(!CompileKernel || Recover); - return new AddressSanitizer(CompileKernel, Recover); + return new AddressSanitizer(CompileKernel, Recover, UseAfterScope); } char AddressSanitizerModule::ID = 0; @@ -792,7 +844,7 @@ static GlobalVariable *createPrivateGlobalForString(Module &M, StringRef Str, GlobalVariable *GV = new GlobalVariable(M, StrConst->getType(), true, GlobalValue::PrivateLinkage, StrConst, kAsanGenPrefix); - if (AllowMerging) GV->setUnnamedAddr(true); + if (AllowMerging) GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); GV->setAlignment(1); // Strings may not be merged w/o setting align 1. 
return GV; } @@ -809,13 +861,23 @@ static GlobalVariable *createPrivateGlobalForSourceLoc(Module &M, auto GV = new GlobalVariable(M, LocStruct->getType(), true, GlobalValue::PrivateLinkage, LocStruct, kAsanGenPrefix); - GV->setUnnamedAddr(true); + GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); return GV; } -static bool GlobalWasGeneratedByAsan(GlobalVariable *G) { - return G->getName().find(kAsanGenPrefix) == 0 || - G->getName().find(kSanCovGenPrefix) == 0; +/// \brief Check if \p G has been created by a trusted compiler pass. +static bool GlobalWasGeneratedByCompiler(GlobalVariable *G) { + // Do not instrument asan globals. + if (G->getName().startswith(kAsanGenPrefix) || + G->getName().startswith(kSanCovGenPrefix) || + G->getName().startswith(kODRGenPrefix)) + return true; + + // Do not instrument gcov counter arrays. + if (G->getName() == "__llvm_gcov_ctr") + return true; + + return false; } Value *AddressSanitizer::memToShadow(Value *Shadow, IRBuilder<> &IRB) { @@ -858,7 +920,7 @@ bool AddressSanitizer::isInterestingAlloca(AllocaInst &AI) { bool IsInteresting = (AI.getAllocatedType()->isSized() && // alloca() may be called with 0 size, ignore it. - getAllocaSizeInBytes(&AI) > 0 && + ((!AI.isStaticAlloca()) || getAllocaSizeInBytes(&AI) > 0) && // We are only interested in allocas not promotable to registers. // Promotable allocas are common under -O0. (!ClSkipPromotableAllocas || !isAllocaPromotable(&AI)) && @@ -907,6 +969,14 @@ Value *AddressSanitizer::isInterestingMemoryAccess(Instruction *I, PtrOperand = XCHG->getPointerOperand(); } + // Do not instrument acesses from different address spaces; we cannot deal + // with them. + if (PtrOperand) { + Type *PtrTy = cast<PointerType>(PtrOperand->getType()->getScalarType()); + if (PtrTy->getPointerAddressSpace() != 0) + return nullptr; + } + // Treat memory accesses to promotable allocas as non-interesting since they // will not cause memory violations. This greatly speeds up the instrumented // executable at -O0. @@ -948,9 +1018,9 @@ void AddressSanitizer::instrumentPointerComparisonOrSubtraction( IRBuilder<> IRB(I); Function *F = isa<ICmpInst>(I) ? 
AsanPtrCmpFunction : AsanPtrSubFunction; Value *Param[2] = {I->getOperand(0), I->getOperand(1)}; - for (int i = 0; i < 2; i++) { - if (Param[i]->getType()->isPointerTy()) - Param[i] = IRB.CreatePointerCast(Param[i], IntptrTy); + for (Value *&i : Param) { + if (i->getType()->isPointerTy()) + i = IRB.CreatePointerCast(i, IntptrTy); } IRB.CreateCall(F, Param); } @@ -1048,7 +1118,7 @@ Instruction *AddressSanitizer::generateCrashCode(Instruction *InsertBefore, Value *AddressSanitizer::createSlowPathCmp(IRBuilder<> &IRB, Value *AddrLong, Value *ShadowValue, uint32_t TypeSize) { - size_t Granularity = 1 << Mapping.Scale; + size_t Granularity = static_cast<size_t>(1) << Mapping.Scale; // Addr & (Granularity - 1) Value *LastAccessedByte = IRB.CreateAnd(AddrLong, ConstantInt::get(IntptrTy, Granularity - 1)); @@ -1091,7 +1161,7 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, IRB.CreateLoad(IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); Value *Cmp = IRB.CreateICmpNE(ShadowValue, CmpVal); - size_t Granularity = 1 << Mapping.Scale; + size_t Granularity = 1ULL << Mapping.Scale; TerminatorInst *CrashTerm = nullptr; if (ClAlwaysSlowPath || (TypeSize < 8 * Granularity)) { @@ -1184,13 +1254,13 @@ void AddressSanitizerModule::createInitializerPoisonCalls( } bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { - Type *Ty = cast<PointerType>(G->getType())->getElementType(); + Type *Ty = G->getValueType(); DEBUG(dbgs() << "GLOBAL: " << *G << "\n"); if (GlobalsMD.get(G).IsBlacklisted) return false; if (!Ty->isSized()) return false; if (!G->hasInitializer()) return false; - if (GlobalWasGeneratedByAsan(G)) return false; // Our own global. + if (GlobalWasGeneratedByCompiler(G)) return false; // Our own globals. // Touch only those globals that will not be defined in other modules. // Don't handle ODR linkage types and COMDATs since other modules may be built // without ASan. @@ -1207,12 +1277,12 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { if (G->getAlignment() > MinRedzoneSizeForGlobal()) return false; if (G->hasSection()) { - StringRef Section(G->getSection()); + StringRef Section = G->getSection(); // Globals from llvm.metadata aren't emitted, do not instrument them. if (Section == "llvm.metadata") return false; // Do not instrument globals from special LLVM sections. - if (Section.find("__llvm") != StringRef::npos) return false; + if (Section.find("__llvm") != StringRef::npos || Section.find("__LLVM") != StringRef::npos) return false; // Do not instrument function pointers to initialization and termination // routines: dynamic linker will not properly handle redzones. @@ -1271,8 +1341,29 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { return true; } +// On Mach-O platforms, we emit global metadata in a separate section of the +// binary in order to allow the linker to properly dead strip. This is only +// supported on recent versions of ld64. 
+bool AddressSanitizerModule::ShouldUseMachOGlobalsSection() const { + if (!ClUseMachOGlobalsSection) + return false; + + if (!TargetTriple.isOSBinFormatMachO()) + return false; + + if (TargetTriple.isMacOSX() && !TargetTriple.isMacOSXVersionLT(10, 11)) + return true; + if (TargetTriple.isiOS() /* or tvOS */ && !TargetTriple.isOSVersionLT(9)) + return true; + if (TargetTriple.isWatchOS() && !TargetTriple.isOSVersionLT(2)) + return true; + + return false; +} + void AddressSanitizerModule::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); + // Declare our poisoning and unpoisoning functions. AsanPoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); @@ -1280,6 +1371,7 @@ void AddressSanitizerModule::initializeCallbacks(Module &M) { AsanUnpoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( kAsanUnpoisonGlobalsName, IRB.getVoidTy(), nullptr)); AsanUnpoisonGlobals->setLinkage(Function::ExternalLinkage); + // Declare functions that register/unregister globals. AsanRegisterGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); @@ -1288,6 +1380,18 @@ void AddressSanitizerModule::initializeCallbacks(Module &M) { M.getOrInsertFunction(kAsanUnregisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanUnregisterGlobals->setLinkage(Function::ExternalLinkage); + + // Declare the functions that find globals in a shared object and then invoke + // the (un)register function on them. + AsanRegisterImageGlobals = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(kAsanRegisterImageGlobalsName, + IRB.getVoidTy(), IntptrTy, nullptr)); + AsanRegisterImageGlobals->setLinkage(Function::ExternalLinkage); + + AsanUnregisterImageGlobals = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(kAsanUnregisterImageGlobalsName, + IRB.getVoidTy(), IntptrTy, nullptr)); + AsanUnregisterImageGlobals->setLinkage(Function::ExternalLinkage); } // This function replaces all global variables with new variables that have @@ -1313,10 +1417,11 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { // const char *module_name; // size_t has_dynamic_init; // void *source_location; + // size_t odr_indicator; // We initialize an array of such structures and pass it to a run-time call. StructType *GlobalStructTy = StructType::get(IntptrTy, IntptrTy, IntptrTy, IntptrTy, IntptrTy, - IntptrTy, IntptrTy, nullptr); + IntptrTy, IntptrTy, IntptrTy, nullptr); SmallVector<Constant *, 16> Initializers(n); bool HasDynamicallyInitializedGlobals = false; @@ -1332,14 +1437,14 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { GlobalVariable *G = GlobalsToChange[i]; auto MD = GlobalsMD.get(G); + StringRef NameForGlobal = G->getName(); // Create string holding the global name (use global name from metadata // if it's available, otherwise just write the name of global variable). GlobalVariable *Name = createPrivateGlobalForString( - M, MD.Name.empty() ? G->getName() : MD.Name, + M, MD.Name.empty() ? 
NameForGlobal : MD.Name, /*AllowMerging*/ true); - PointerType *PtrTy = cast<PointerType>(G->getType()); - Type *Ty = PtrTy->getElementType(); + Type *Ty = G->getValueType(); uint64_t SizeInBytes = DL.getTypeAllocSize(Ty); uint64_t MinRZ = MinRedzoneSizeForGlobal(); // MinRZ <= RZ <= kMaxGlobalRedzone @@ -1384,41 +1489,125 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { SourceLoc = ConstantInt::get(IntptrTy, 0); } + Constant *ODRIndicator = ConstantExpr::getNullValue(IRB.getInt8PtrTy()); + GlobalValue *InstrumentedGlobal = NewGlobal; + + bool CanUsePrivateAliases = TargetTriple.isOSBinFormatELF(); + if (CanUsePrivateAliases && ClUsePrivateAliasForGlobals) { + // Create local alias for NewGlobal to avoid crash on ODR between + // instrumented and non-instrumented libraries. + auto *GA = GlobalAlias::create(GlobalValue::InternalLinkage, + NameForGlobal + M.getName(), NewGlobal); + + // With local aliases, we need to provide another externally visible + // symbol __odr_asan_XXX to detect ODR violation. + auto *ODRIndicatorSym = + new GlobalVariable(M, IRB.getInt8Ty(), false, Linkage, + Constant::getNullValue(IRB.getInt8Ty()), + kODRGenPrefix + NameForGlobal, nullptr, + NewGlobal->getThreadLocalMode()); + + // Set meaningful attributes for indicator symbol. + ODRIndicatorSym->setVisibility(NewGlobal->getVisibility()); + ODRIndicatorSym->setDLLStorageClass(NewGlobal->getDLLStorageClass()); + ODRIndicatorSym->setAlignment(1); + ODRIndicator = ODRIndicatorSym; + InstrumentedGlobal = GA; + } + Initializers[i] = ConstantStruct::get( - GlobalStructTy, ConstantExpr::getPointerCast(NewGlobal, IntptrTy), + GlobalStructTy, + ConstantExpr::getPointerCast(InstrumentedGlobal, IntptrTy), ConstantInt::get(IntptrTy, SizeInBytes), ConstantInt::get(IntptrTy, SizeInBytes + RightRedzoneSize), ConstantExpr::getPointerCast(Name, IntptrTy), ConstantExpr::getPointerCast(ModuleName, IntptrTy), - ConstantInt::get(IntptrTy, MD.IsDynInit), SourceLoc, nullptr); + ConstantInt::get(IntptrTy, MD.IsDynInit), SourceLoc, + ConstantExpr::getPointerCast(ODRIndicator, IntptrTy), nullptr); if (ClInitializers && MD.IsDynInit) HasDynamicallyInitializedGlobals = true; DEBUG(dbgs() << "NEW GLOBAL: " << *NewGlobal << "\n"); } - ArrayType *ArrayOfGlobalStructTy = ArrayType::get(GlobalStructTy, n); - GlobalVariable *AllGlobals = new GlobalVariable( - M, ArrayOfGlobalStructTy, false, GlobalVariable::InternalLinkage, - ConstantArray::get(ArrayOfGlobalStructTy, Initializers), ""); + + GlobalVariable *AllGlobals = nullptr; + GlobalVariable *RegisteredFlag = nullptr; + + // On recent Mach-O platforms, we emit the global metadata in a way that + // allows the linker to properly strip dead globals. + if (ShouldUseMachOGlobalsSection()) { + // RegisteredFlag serves two purposes. First, we can pass it to dladdr() + // to look up the loaded image that contains it. Second, we can store in it + // whether registration has already occurred, to prevent duplicate + // registration. + // + // Common linkage allows us to coalesce needles defined in each object + // file so that there's only one per shared library. + RegisteredFlag = new GlobalVariable( + M, IntptrTy, false, GlobalVariable::CommonLinkage, + ConstantInt::get(IntptrTy, 0), kAsanGlobalsRegisteredFlagName); + + // We also emit a structure which binds the liveness of the global + // variable to the metadata struct. 
+ StructType *LivenessTy = StructType::get(IntptrTy, IntptrTy, nullptr); + + for (size_t i = 0; i < n; i++) { + GlobalVariable *Metadata = new GlobalVariable( + M, GlobalStructTy, false, GlobalVariable::InternalLinkage, + Initializers[i], ""); + Metadata->setSection("__DATA,__asan_globals,regular"); + Metadata->setAlignment(1); // don't leave padding in between + + auto LivenessBinder = ConstantStruct::get(LivenessTy, + Initializers[i]->getAggregateElement(0u), + ConstantExpr::getPointerCast(Metadata, IntptrTy), + nullptr); + GlobalVariable *Liveness = new GlobalVariable( + M, LivenessTy, false, GlobalVariable::InternalLinkage, + LivenessBinder, ""); + Liveness->setSection("__DATA,__asan_liveness,regular,live_support"); + } + } else { + // On all other platfoms, we just emit an array of global metadata + // structures. + ArrayType *ArrayOfGlobalStructTy = ArrayType::get(GlobalStructTy, n); + AllGlobals = new GlobalVariable( + M, ArrayOfGlobalStructTy, false, GlobalVariable::InternalLinkage, + ConstantArray::get(ArrayOfGlobalStructTy, Initializers), ""); + } // Create calls for poisoning before initializers run and unpoisoning after. if (HasDynamicallyInitializedGlobals) createInitializerPoisonCalls(M, ModuleName); - IRB.CreateCall(AsanRegisterGlobals, - {IRB.CreatePointerCast(AllGlobals, IntptrTy), - ConstantInt::get(IntptrTy, n)}); - // We also need to unregister globals at the end, e.g. when a shared library + // Create a call to register the globals with the runtime. + if (ShouldUseMachOGlobalsSection()) { + IRB.CreateCall(AsanRegisterImageGlobals, + {IRB.CreatePointerCast(RegisteredFlag, IntptrTy)}); + } else { + IRB.CreateCall(AsanRegisterGlobals, + {IRB.CreatePointerCast(AllGlobals, IntptrTy), + ConstantInt::get(IntptrTy, n)}); + } + + // We also need to unregister globals at the end, e.g., when a shared library // gets closed. Function *AsanDtorFunction = Function::Create(FunctionType::get(Type::getVoidTy(*C), false), GlobalValue::InternalLinkage, kAsanModuleDtorName, &M); BasicBlock *AsanDtorBB = BasicBlock::Create(*C, "", AsanDtorFunction); IRBuilder<> IRB_Dtor(ReturnInst::Create(*C, AsanDtorBB)); - IRB_Dtor.CreateCall(AsanUnregisterGlobals, - {IRB.CreatePointerCast(AllGlobals, IntptrTy), - ConstantInt::get(IntptrTy, n)}); + + if (ShouldUseMachOGlobalsSection()) { + IRB_Dtor.CreateCall(AsanUnregisterImageGlobals, + {IRB.CreatePointerCast(RegisteredFlag, IntptrTy)}); + } else { + IRB_Dtor.CreateCall(AsanUnregisterGlobals, + {IRB.CreatePointerCast(AllGlobals, IntptrTy), + ConstantInt::get(IntptrTy, n)}); + } + appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority); DEBUG(dbgs() << M); @@ -1467,7 +1656,7 @@ void AddressSanitizer::initializeCallbacks(Module &M) { IRB.getVoidTy(), IntptrTy, IntptrTy, ExpType, nullptr)); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { - const std::string Suffix = TypeStr + itostr(1 << AccessSizeIndex); + const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex); AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, @@ -1608,6 +1797,8 @@ bool AddressSanitizer::runOnFunction(Function &F) { bool IsWrite; unsigned Alignment; uint64_t TypeSize; + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); // Fill the set of memory operations to instrument. 
for (auto &BB : F) { @@ -1636,6 +1827,8 @@ bool AddressSanitizer::runOnFunction(Function &F) { TempsToInstrument.clear(); if (CS.doesNotReturn()) NoReturnCalls.push_back(CS.getInstruction()); } + if (CallInst *CI = dyn_cast<CallInst>(&Inst)) + maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI); continue; } ToInstrument.push_back(&Inst); @@ -1648,8 +1841,6 @@ bool AddressSanitizer::runOnFunction(Function &F) { CompileKernel || (ClInstrumentationWithCallsThreshold >= 0 && ToInstrument.size() > (unsigned)ClInstrumentationWithCallsThreshold); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); const DataLayout &DL = F.getParent()->getDataLayout(); ObjectSizeOffsetVisitor ObjSizeVis(DL, TLI, F.getContext(), /*RoundToAlign=*/true); @@ -1713,12 +1904,15 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { M.getOrInsertFunction(kAsanStackFreeNameTemplate + Suffix, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); } - AsanPoisonStackMemoryFunc = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanPoisonStackMemoryName, IRB.getVoidTy(), - IntptrTy, IntptrTy, nullptr)); - AsanUnpoisonStackMemoryFunc = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), - IntptrTy, IntptrTy, nullptr)); + if (ASan.UseAfterScope) { + AsanPoisonStackMemoryFunc = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(kAsanPoisonStackMemoryName, IRB.getVoidTy(), + IntptrTy, IntptrTy, nullptr)); + AsanUnpoisonStackMemoryFunc = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), + IntptrTy, IntptrTy, nullptr)); + } + AsanAllocaPoisonFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanAllocasUnpoisonFunc = @@ -1825,13 +2019,21 @@ void FunctionStackPoisoner::poisonStack() { assert(AllocaVec.size() > 0 || DynamicAllocaVec.size() > 0); // Insert poison calls for lifetime intrinsics for alloca. - bool HavePoisonedAllocas = false; + bool HavePoisonedStaticAllocas = false; for (const auto &APC : AllocaPoisonCallVec) { assert(APC.InsBefore); assert(APC.AI); + assert(ASan.isInterestingAlloca(*APC.AI)); + bool IsDynamicAlloca = !(*APC.AI).isStaticAlloca(); + if (!ClInstrumentAllocas && IsDynamicAlloca) + continue; + IRBuilder<> IRB(APC.InsBefore); poisonAlloca(APC.AI, APC.Size, IRB, APC.DoPoison); - HavePoisonedAllocas |= APC.DoPoison; + // Dynamic allocas will be unpoisoned unconditionally below in + // unpoisonDynamicAllocas. + // Flag that we need unpoison static allocas. + HavePoisonedStaticAllocas |= (APC.DoPoison && !IsDynamicAlloca); } if (ClInstrumentAllocas && DynamicAllocaVec.size() > 0) { @@ -1846,7 +2048,7 @@ void FunctionStackPoisoner::poisonStack() { int StackMallocIdx = -1; DebugLoc EntryDebugLocation; - if (auto SP = getDISubprogram(&F)) + if (auto SP = F.getSubprogram()) EntryDebugLocation = DebugLoc::get(SP->getScopeLine(), 0, SP); Instruction *InsBefore = AllocaVec[0]; @@ -1878,7 +2080,7 @@ void FunctionStackPoisoner::poisonStack() { // i.e. 32 bytes on 64-bit platforms and 16 bytes in 32-bit platforms. 
size_t MinHeaderSize = ASan.LongSize / 2; ASanStackFrameLayout L; - ComputeASanStackFrameLayout(SVD, 1UL << Mapping.Scale, MinHeaderSize, &L); + ComputeASanStackFrameLayout(SVD, 1ULL << Mapping.Scale, MinHeaderSize, &L); DEBUG(dbgs() << L.DescriptionString << " --- " << L.FrameSize << "\n"); uint64_t LocalStackSize = L.FrameSize; bool DoStackMalloc = ClUseAfterReturn && !ASan.CompileKernel && @@ -1904,13 +2106,13 @@ void FunctionStackPoisoner::poisonStack() { // ? __asan_stack_malloc_N(LocalStackSize) // : nullptr; // void *LocalStackBase = (FakeStack) ? FakeStack : alloca(LocalStackSize); - Constant *OptionDetectUAR = F.getParent()->getOrInsertGlobal( - kAsanOptionDetectUAR, IRB.getInt32Ty()); - Value *UARIsEnabled = - IRB.CreateICmpNE(IRB.CreateLoad(OptionDetectUAR), + Constant *OptionDetectUseAfterReturn = F.getParent()->getOrInsertGlobal( + kAsanOptionDetectUseAfterReturn, IRB.getInt32Ty()); + Value *UseAfterReturnIsEnabled = + IRB.CreateICmpNE(IRB.CreateLoad(OptionDetectUseAfterReturn), Constant::getNullValue(IRB.getInt32Ty())); Instruction *Term = - SplitBlockAndInsertIfThen(UARIsEnabled, InsBefore, false); + SplitBlockAndInsertIfThen(UseAfterReturnIsEnabled, InsBefore, false); IRBuilder<> IRBIf(Term); IRBIf.SetCurrentDebugLocation(EntryDebugLocation); StackMallocIdx = StackMallocSizeClass(LocalStackSize); @@ -1920,7 +2122,7 @@ void FunctionStackPoisoner::poisonStack() { ConstantInt::get(IntptrTy, LocalStackSize)); IRB.SetInsertPoint(InsBefore); IRB.SetCurrentDebugLocation(EntryDebugLocation); - FakeStack = createPHI(IRB, UARIsEnabled, FakeStackValue, Term, + FakeStack = createPHI(IRB, UseAfterReturnIsEnabled, FakeStackValue, Term, ConstantInt::get(IntptrTy, 0)); Value *NoFakeStack = @@ -1977,6 +2179,16 @@ void FunctionStackPoisoner::poisonStack() { Value *ShadowBase = ASan.memToShadow(LocalStackBase, IRB); poisonRedZones(L.ShadowBytes, IRB, ShadowBase, true); + auto UnpoisonStack = [&](IRBuilder<> &IRB) { + if (HavePoisonedStaticAllocas) { + // If we poisoned some allocas in llvm.lifetime analysis, + // unpoison whole stack frame now. + poisonAlloca(LocalStackBase, LocalStackSize, IRB, false); + } else { + poisonRedZones(L.ShadowBytes, IRB, ShadowBase, false); + } + }; + // (Un)poison the stack before all ret instructions. for (auto Ret : RetVec) { IRBuilder<> IRBRet(Ret); @@ -2021,13 +2233,9 @@ void FunctionStackPoisoner::poisonStack() { } IRBuilder<> IRBElse(ElseTerm); - poisonRedZones(L.ShadowBytes, IRBElse, ShadowBase, false); - } else if (HavePoisonedAllocas) { - // If we poisoned some allocas in llvm.lifetime analysis, - // unpoison whole stack frame now. 
- poisonAlloca(LocalStackBase, LocalStackSize, IRBRet, false); + UnpoisonStack(IRBElse); } else { - poisonRedZones(L.ShadowBytes, IRBRet, ShadowBase, false); + UnpoisonStack(IRBRet); } } diff --git a/lib/Transforms/Instrumentation/BoundsChecking.cpp b/lib/Transforms/Instrumentation/BoundsChecking.cpp index fd3dfd9af033..d4c8369fa9d3 100644 --- a/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -36,7 +36,7 @@ STATISTIC(ChecksAdded, "Bounds checks added"); STATISTIC(ChecksSkipped, "Bounds checks skipped"); STATISTIC(ChecksUnable, "Bounds checks unable to add"); -typedef IRBuilder<true, TargetFolder> BuilderTy; +typedef IRBuilder<TargetFolder> BuilderTy; namespace { struct BoundsChecking : public FunctionPass { @@ -185,9 +185,8 @@ bool BoundsChecking::runOnFunction(Function &F) { } bool MadeChange = false; - for (std::vector<Instruction*>::iterator i = WorkList.begin(), - e = WorkList.end(); i != e; ++i) { - Inst = *i; + for (Instruction *i : WorkList) { + Inst = i; Builder->SetInsertPoint(Inst); if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { diff --git a/lib/Transforms/Instrumentation/CFGMST.h b/lib/Transforms/Instrumentation/CFGMST.h index c47fdbf68996..3cd7351cad62 100644 --- a/lib/Transforms/Instrumentation/CFGMST.h +++ b/lib/Transforms/Instrumentation/CFGMST.h @@ -21,7 +21,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include <string> #include <utility> #include <vector> diff --git a/lib/Transforms/Instrumentation/CMakeLists.txt b/lib/Transforms/Instrumentation/CMakeLists.txt index cae1e5af7ac7..57a569b3791e 100644 --- a/lib/Transforms/Instrumentation/CMakeLists.txt +++ b/lib/Transforms/Instrumentation/CMakeLists.txt @@ -4,12 +4,13 @@ add_llvm_library(LLVMInstrumentation DataFlowSanitizer.cpp GCOVProfiling.cpp MemorySanitizer.cpp + IndirectCallPromotion.cpp Instrumentation.cpp InstrProfiling.cpp PGOInstrumentation.cpp - SafeStack.cpp SanitizerCoverage.cpp ThreadSanitizer.cpp + EfficiencySanitizer.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index d459fc50d136..b34d5b8c45a7 100644 --- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -134,7 +134,7 @@ namespace { StringRef GetGlobalTypeString(const GlobalValue &G) { // Types of GlobalVariables are always pointer types. - Type *GType = G.getType()->getElementType(); + Type *GType = G.getValueType(); // For now we support blacklisting struct types only. 
if (StructType *SGType = dyn_cast<StructType>(GType)) { if (!SGType->isLiteral()) @@ -166,7 +166,7 @@ class DFSanABIList { if (isIn(*GA.getParent(), Category)) return true; - if (isa<FunctionType>(GA.getType()->getElementType())) + if (isa<FunctionType>(GA.getValueType())) return SCL->inSection("fun", GA.getName(), Category); return SCL->inSection("global", GA.getName(), Category) || @@ -791,25 +791,20 @@ bool DataFlowSanitizer::runOnModule(Module &M) { } } - for (std::vector<Function *>::iterator i = FnsToInstrument.begin(), - e = FnsToInstrument.end(); - i != e; ++i) { - if (!*i || (*i)->isDeclaration()) + for (Function *i : FnsToInstrument) { + if (!i || i->isDeclaration()) continue; - removeUnreachableBlocks(**i); + removeUnreachableBlocks(*i); - DFSanFunction DFSF(*this, *i, FnsWithNativeABI.count(*i)); + DFSanFunction DFSF(*this, i, FnsWithNativeABI.count(i)); // DFSanVisitor may create new basic blocks, which confuses df_iterator. // Build a copy of the list before iterating over it. - llvm::SmallVector<BasicBlock *, 4> BBList( - depth_first(&(*i)->getEntryBlock())); + llvm::SmallVector<BasicBlock *, 4> BBList(depth_first(&i->getEntryBlock())); - for (llvm::SmallVector<BasicBlock *, 4>::iterator i = BBList.begin(), - e = BBList.end(); - i != e; ++i) { - Instruction *Inst = &(*i)->front(); + for (BasicBlock *i : BBList) { + Instruction *Inst = &i->front(); while (1) { // DFSanVisitor may split the current basic block, changing the current // instruction's next pointer and moving the next instruction to the @@ -1066,11 +1061,10 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, SmallVector<Value *, 2> Objs; GetUnderlyingObjects(Addr, Objs, Pos->getModule()->getDataLayout()); bool AllConstants = true; - for (SmallVector<Value *, 2>::iterator i = Objs.begin(), e = Objs.end(); - i != e; ++i) { - if (isa<Function>(*i) || isa<BlockAddress>(*i)) + for (Value *Obj : Objs) { + if (isa<Function>(Obj) || isa<BlockAddress>(Obj)) continue; - if (isa<GlobalVariable>(*i) && cast<GlobalVariable>(*i)->isConstant()) + if (isa<GlobalVariable>(Obj) && cast<GlobalVariable>(Obj)->isConstant()) continue; AllConstants = false; @@ -1412,10 +1406,6 @@ void DFSanVisitor::visitCallSite(CallSite CS) { if (F == DFSF.DFS.DFSanVarargWrapperFn) return; - assert(!(cast<FunctionType>( - CS.getCalledValue()->getType()->getPointerElementType())->isVarArg() && - dyn_cast<InvokeInst>(CS.getInstruction()))); - IRBuilder<> IRB(CS.getInstruction()); DenseMap<Value *, Function *>::iterator i = diff --git a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp new file mode 100644 index 000000000000..fb80f87369f9 --- /dev/null +++ b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp @@ -0,0 +1,901 @@ +//===-- EfficiencySanitizer.cpp - performance tuner -----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file is a part of EfficiencySanitizer, a family of performance tuners +// that detects multiple performance issues via separate sub-tools. +// +// The instrumentation phase is straightforward: +// - Take action on every memory access: either inlined instrumentation, +// or Inserted calls to our run-time library. +// - Optimizations may apply to avoid instrumenting some of the accesses. 
+// - Turn mem{set,cpy,move} instrinsics into library calls. +// The rest is handled by the run-time library. +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "esan" + +// The tool type must be just one of these ClTool* options, as the tools +// cannot be combined due to shadow memory constraints. +static cl::opt<bool> + ClToolCacheFrag("esan-cache-frag", cl::init(false), + cl::desc("Detect data cache fragmentation"), cl::Hidden); +static cl::opt<bool> + ClToolWorkingSet("esan-working-set", cl::init(false), + cl::desc("Measure the working set size"), cl::Hidden); +// Each new tool will get its own opt flag here. +// These are converted to EfficiencySanitizerOptions for use +// in the code. + +static cl::opt<bool> ClInstrumentLoadsAndStores( + "esan-instrument-loads-and-stores", cl::init(true), + cl::desc("Instrument loads and stores"), cl::Hidden); +static cl::opt<bool> ClInstrumentMemIntrinsics( + "esan-instrument-memintrinsics", cl::init(true), + cl::desc("Instrument memintrinsics (memset/memcpy/memmove)"), cl::Hidden); +static cl::opt<bool> ClInstrumentFastpath( + "esan-instrument-fastpath", cl::init(true), + cl::desc("Instrument fastpath"), cl::Hidden); +static cl::opt<bool> ClAuxFieldInfo( + "esan-aux-field-info", cl::init(true), + cl::desc("Generate binary with auxiliary struct field information"), + cl::Hidden); + +// Experiments show that the performance difference can be 2x or more, +// and accuracy loss is typically negligible, so we turn this on by default. +static cl::opt<bool> ClAssumeIntraCacheLine( + "esan-assume-intra-cache-line", cl::init(true), + cl::desc("Assume each memory access touches just one cache line, for " + "better performance but with a potential loss of accuracy."), + cl::Hidden); + +STATISTIC(NumInstrumentedLoads, "Number of instrumented loads"); +STATISTIC(NumInstrumentedStores, "Number of instrumented stores"); +STATISTIC(NumFastpaths, "Number of instrumented fastpaths"); +STATISTIC(NumAccessesWithIrregularSize, + "Number of accesses with a size outside our targeted callout sizes"); +STATISTIC(NumIgnoredStructs, "Number of ignored structs"); +STATISTIC(NumIgnoredGEPs, "Number of ignored GEP instructions"); +STATISTIC(NumInstrumentedGEPs, "Number of instrumented GEP instructions"); +STATISTIC(NumAssumedIntraCacheLine, + "Number of accesses assumed to be intra-cache-line"); + +static const uint64_t EsanCtorAndDtorPriority = 0; +static const char *const EsanModuleCtorName = "esan.module_ctor"; +static const char *const EsanModuleDtorName = "esan.module_dtor"; +static const char *const EsanInitName = "__esan_init"; +static const char *const EsanExitName = "__esan_exit"; + +// We need to specify the tool to the runtime earlier than +// the ctor is called in some cases, so we set a global variable. 
+static const char *const EsanWhichToolName = "__esan_which_tool"; + +// We must keep these Shadow* constants consistent with the esan runtime. +// FIXME: Try to place these shadow constants, the names of the __esan_* +// interface functions, and the ToolType enum into a header shared between +// llvm and compiler-rt. +static const uint64_t ShadowMask = 0x00000fffffffffffull; +static const uint64_t ShadowOffs[3] = { // Indexed by scale + 0x0000130000000000ull, + 0x0000220000000000ull, + 0x0000440000000000ull, +}; +// This array is indexed by the ToolType enum. +static const int ShadowScale[] = { + 0, // ESAN_None. + 2, // ESAN_CacheFrag: 4B:1B, so 4 to 1 == >>2. + 6, // ESAN_WorkingSet: 64B:1B, so 64 to 1 == >>6. +}; + +// MaxStructCounterNameSize is a soft size limit to avoid insanely long +// names for those extremely large structs. +static const unsigned MaxStructCounterNameSize = 512; + +namespace { + +static EfficiencySanitizerOptions +OverrideOptionsFromCL(EfficiencySanitizerOptions Options) { + if (ClToolCacheFrag) + Options.ToolType = EfficiencySanitizerOptions::ESAN_CacheFrag; + else if (ClToolWorkingSet) + Options.ToolType = EfficiencySanitizerOptions::ESAN_WorkingSet; + + // Direct opt invocation with no params will have the default ESAN_None. + // We run the default tool in that case. + if (Options.ToolType == EfficiencySanitizerOptions::ESAN_None) + Options.ToolType = EfficiencySanitizerOptions::ESAN_CacheFrag; + + return Options; +} + +// Create a constant for Str so that we can pass it to the run-time lib. +static GlobalVariable *createPrivateGlobalForString(Module &M, StringRef Str, + bool AllowMerging) { + Constant *StrConst = ConstantDataArray::getString(M.getContext(), Str); + // We use private linkage for module-local strings. If they can be merged + // with another one, we set the unnamed_addr attribute. + GlobalVariable *GV = + new GlobalVariable(M, StrConst->getType(), true, + GlobalValue::PrivateLinkage, StrConst, ""); + if (AllowMerging) + GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + GV->setAlignment(1); // Strings may not be merged w/o setting align 1. + return GV; +} + +/// EfficiencySanitizer: instrument each module to find performance issues. 
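+/// Both tools map an application address to shadow memory roughly as
+/// ((Addr & ShadowMask) + Offs) >> Scale, where Scale is 2 for the
+/// cache-fragmentation tool (4B:1B) and 6 for the working-set tool (64B:1B);
+/// see ShadowScale[] above and appToShadow() below.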
+class EfficiencySanitizer : public ModulePass { +public: + EfficiencySanitizer( + const EfficiencySanitizerOptions &Opts = EfficiencySanitizerOptions()) + : ModulePass(ID), Options(OverrideOptionsFromCL(Opts)) {} + const char *getPassName() const override; + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnModule(Module &M) override; + static char ID; + +private: + bool initOnModule(Module &M); + void initializeCallbacks(Module &M); + bool shouldIgnoreStructType(StructType *StructTy); + void createStructCounterName( + StructType *StructTy, SmallString<MaxStructCounterNameSize> &NameStr); + void createCacheFragAuxGV( + Module &M, const DataLayout &DL, StructType *StructTy, + GlobalVariable *&TypeNames, GlobalVariable *&Offsets, GlobalVariable *&Size); + GlobalVariable *createCacheFragInfoGV(Module &M, const DataLayout &DL, + Constant *UnitName); + Constant *createEsanInitToolInfoArg(Module &M, const DataLayout &DL); + void createDestructor(Module &M, Constant *ToolInfoArg); + bool runOnFunction(Function &F, Module &M); + bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL); + bool instrumentMemIntrinsic(MemIntrinsic *MI); + bool instrumentGetElementPtr(Instruction *I, Module &M); + bool insertCounterUpdate(Instruction *I, StructType *StructTy, + unsigned CounterIdx); + unsigned getFieldCounterIdx(StructType *StructTy) { + return 0; + } + unsigned getArrayCounterIdx(StructType *StructTy) { + return StructTy->getNumElements(); + } + unsigned getStructCounterSize(StructType *StructTy) { + // The struct counter array includes: + // - one counter for each struct field, + // - one counter for the struct access within an array. + return (StructTy->getNumElements()/*field*/ + 1/*array*/); + } + bool shouldIgnoreMemoryAccess(Instruction *I); + int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL); + Value *appToShadow(Value *Shadow, IRBuilder<> &IRB); + bool instrumentFastpath(Instruction *I, const DataLayout &DL, bool IsStore, + Value *Addr, unsigned Alignment); + // Each tool has its own fastpath routine: + bool instrumentFastpathCacheFrag(Instruction *I, const DataLayout &DL, + Value *Addr, unsigned Alignment); + bool instrumentFastpathWorkingSet(Instruction *I, const DataLayout &DL, + Value *Addr, unsigned Alignment); + + EfficiencySanitizerOptions Options; + LLVMContext *Ctx; + Type *IntptrTy; + // Our slowpath involves callouts to the runtime library. + // Access sizes are powers of two: 1, 2, 4, 8, 16. + static const size_t NumberOfAccessSizes = 5; + Function *EsanAlignedLoad[NumberOfAccessSizes]; + Function *EsanAlignedStore[NumberOfAccessSizes]; + Function *EsanUnalignedLoad[NumberOfAccessSizes]; + Function *EsanUnalignedStore[NumberOfAccessSizes]; + // For irregular sizes of any alignment: + Function *EsanUnalignedLoadN, *EsanUnalignedStoreN; + Function *MemmoveFn, *MemcpyFn, *MemsetFn; + Function *EsanCtorFunction; + Function *EsanDtorFunction; + // Remember the counter variable for each struct type to avoid + // recomputing the variable name later during instrumentation. 
+ std::map<Type *, GlobalVariable *> StructTyMap; +}; +} // namespace + +char EfficiencySanitizer::ID = 0; +INITIALIZE_PASS_BEGIN( + EfficiencySanitizer, "esan", + "EfficiencySanitizer: finds performance issues.", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END( + EfficiencySanitizer, "esan", + "EfficiencySanitizer: finds performance issues.", false, false) + +const char *EfficiencySanitizer::getPassName() const { + return "EfficiencySanitizer"; +} + +void EfficiencySanitizer::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<TargetLibraryInfoWrapperPass>(); +} + +ModulePass * +llvm::createEfficiencySanitizerPass(const EfficiencySanitizerOptions &Options) { + return new EfficiencySanitizer(Options); +} + +void EfficiencySanitizer::initializeCallbacks(Module &M) { + IRBuilder<> IRB(M.getContext()); + // Initialize the callbacks. + for (size_t Idx = 0; Idx < NumberOfAccessSizes; ++Idx) { + const unsigned ByteSize = 1U << Idx; + std::string ByteSizeStr = utostr(ByteSize); + // We'll inline the most common (i.e., aligned and frequent sizes) + // load + store instrumentation: these callouts are for the slowpath. + SmallString<32> AlignedLoadName("__esan_aligned_load" + ByteSizeStr); + EsanAlignedLoad[Idx] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + AlignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + SmallString<32> AlignedStoreName("__esan_aligned_store" + ByteSizeStr); + EsanAlignedStore[Idx] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + AlignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + SmallString<32> UnalignedLoadName("__esan_unaligned_load" + ByteSizeStr); + EsanUnalignedLoad[Idx] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + UnalignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + SmallString<32> UnalignedStoreName("__esan_unaligned_store" + ByteSizeStr); + EsanUnalignedStore[Idx] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + UnalignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + } + EsanUnalignedLoadN = checkSanitizerInterfaceFunction( + M.getOrInsertFunction("__esan_unaligned_loadN", IRB.getVoidTy(), + IRB.getInt8PtrTy(), IntptrTy, nullptr)); + EsanUnalignedStoreN = checkSanitizerInterfaceFunction( + M.getOrInsertFunction("__esan_unaligned_storeN", IRB.getVoidTy(), + IRB.getInt8PtrTy(), IntptrTy, nullptr)); + MemmoveFn = checkSanitizerInterfaceFunction( + M.getOrInsertFunction("memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IntptrTy, nullptr)); + MemcpyFn = checkSanitizerInterfaceFunction( + M.getOrInsertFunction("memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IntptrTy, nullptr)); + MemsetFn = checkSanitizerInterfaceFunction( + M.getOrInsertFunction("memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt32Ty(), IntptrTy, nullptr)); +} + +bool EfficiencySanitizer::shouldIgnoreStructType(StructType *StructTy) { + if (StructTy == nullptr || StructTy->isOpaque() /* no struct body */) + return true; + return false; +} + +void EfficiencySanitizer::createStructCounterName( + StructType *StructTy, SmallString<MaxStructCounterNameSize> &NameStr) { + // Append NumFields and field type ids to avoid struct conflicts + // with the same name but different fields. 
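+  // For example (illustrative), a named struct with three fields becomes
+  // something like "struct.A#3#<tyid>#<tyid>#<tyid>", where the trailing ids
+  // are the fields' Type::getTypeID() values appended in reverse order.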
+ if (StructTy->hasName()) + NameStr += StructTy->getName(); + else + NameStr += "struct.anon"; + // We allow the actual size of the StructCounterName to be larger than + // MaxStructCounterNameSize and append #NumFields and at least one + // field type id. + // Append #NumFields. + NameStr += "#"; + Twine(StructTy->getNumElements()).toVector(NameStr); + // Append struct field type ids in the reverse order. + for (int i = StructTy->getNumElements() - 1; i >= 0; --i) { + NameStr += "#"; + Twine(StructTy->getElementType(i)->getTypeID()).toVector(NameStr); + if (NameStr.size() >= MaxStructCounterNameSize) + break; + } + if (StructTy->isLiteral()) { + // End with # for literal struct. + NameStr += "#"; + } +} + +// Create global variables with auxiliary information (e.g., struct field size, +// offset, and type name) for better user report. +void EfficiencySanitizer::createCacheFragAuxGV( + Module &M, const DataLayout &DL, StructType *StructTy, + GlobalVariable *&TypeName, GlobalVariable *&Offset, + GlobalVariable *&Size) { + auto *Int8PtrTy = Type::getInt8PtrTy(*Ctx); + auto *Int32Ty = Type::getInt32Ty(*Ctx); + // FieldTypeName. + auto *TypeNameArrayTy = ArrayType::get(Int8PtrTy, StructTy->getNumElements()); + TypeName = new GlobalVariable(M, TypeNameArrayTy, true, + GlobalVariable::InternalLinkage, nullptr); + SmallVector<Constant *, 16> TypeNameVec; + // FieldOffset. + auto *OffsetArrayTy = ArrayType::get(Int32Ty, StructTy->getNumElements()); + Offset = new GlobalVariable(M, OffsetArrayTy, true, + GlobalVariable::InternalLinkage, nullptr); + SmallVector<Constant *, 16> OffsetVec; + // FieldSize + auto *SizeArrayTy = ArrayType::get(Int32Ty, StructTy->getNumElements()); + Size = new GlobalVariable(M, SizeArrayTy, true, + GlobalVariable::InternalLinkage, nullptr); + SmallVector<Constant *, 16> SizeVec; + for (unsigned i = 0; i < StructTy->getNumElements(); ++i) { + Type *Ty = StructTy->getElementType(i); + std::string Str; + raw_string_ostream StrOS(Str); + Ty->print(StrOS); + TypeNameVec.push_back( + ConstantExpr::getPointerCast( + createPrivateGlobalForString(M, StrOS.str(), true), + Int8PtrTy)); + OffsetVec.push_back( + ConstantInt::get(Int32Ty, + DL.getStructLayout(StructTy)->getElementOffset(i))); + SizeVec.push_back(ConstantInt::get(Int32Ty, + DL.getTypeAllocSize(Ty))); + } + TypeName->setInitializer(ConstantArray::get(TypeNameArrayTy, TypeNameVec)); + Offset->setInitializer(ConstantArray::get(OffsetArrayTy, OffsetVec)); + Size->setInitializer(ConstantArray::get(SizeArrayTy, SizeVec)); +} + +// Create the global variable for the cache-fragmentation tool. +GlobalVariable *EfficiencySanitizer::createCacheFragInfoGV( + Module &M, const DataLayout &DL, Constant *UnitName) { + assert(Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag); + + auto *Int8PtrTy = Type::getInt8PtrTy(*Ctx); + auto *Int8PtrPtrTy = Int8PtrTy->getPointerTo(); + auto *Int32Ty = Type::getInt32Ty(*Ctx); + auto *Int32PtrTy = Type::getInt32PtrTy(*Ctx); + auto *Int64Ty = Type::getInt64Ty(*Ctx); + auto *Int64PtrTy = Type::getInt64PtrTy(*Ctx); + // This structure should be kept consistent with the StructInfo struct + // in the runtime library. + // struct StructInfo { + // const char *StructName; + // u32 Size; + // u32 NumFields; + // u32 *FieldOffset; // auxiliary struct field info. + // u32 *FieldSize; // auxiliary struct field info. + // const char **FieldTypeName; // auxiliary struct field info. 
+ // u64 *FieldCounters; + // u64 *ArrayCounter; + // }; + auto *StructInfoTy = + StructType::get(Int8PtrTy, Int32Ty, Int32Ty, Int32PtrTy, Int32PtrTy, + Int8PtrPtrTy, Int64PtrTy, Int64PtrTy, nullptr); + auto *StructInfoPtrTy = StructInfoTy->getPointerTo(); + // This structure should be kept consistent with the CacheFragInfo struct + // in the runtime library. + // struct CacheFragInfo { + // const char *UnitName; + // u32 NumStructs; + // StructInfo *Structs; + // }; + auto *CacheFragInfoTy = + StructType::get(Int8PtrTy, Int32Ty, StructInfoPtrTy, nullptr); + + std::vector<StructType *> Vec = M.getIdentifiedStructTypes(); + unsigned NumStructs = 0; + SmallVector<Constant *, 16> Initializers; + + for (auto &StructTy : Vec) { + if (shouldIgnoreStructType(StructTy)) { + ++NumIgnoredStructs; + continue; + } + ++NumStructs; + + // StructName. + SmallString<MaxStructCounterNameSize> CounterNameStr; + createStructCounterName(StructTy, CounterNameStr); + GlobalVariable *StructCounterName = createPrivateGlobalForString( + M, CounterNameStr, /*AllowMerging*/true); + + // Counters. + // We create the counter array with StructCounterName and weak linkage + // so that the structs with the same name and layout from different + // compilation units will be merged into one. + auto *CounterArrayTy = ArrayType::get(Int64Ty, + getStructCounterSize(StructTy)); + GlobalVariable *Counters = + new GlobalVariable(M, CounterArrayTy, false, + GlobalVariable::WeakAnyLinkage, + ConstantAggregateZero::get(CounterArrayTy), + CounterNameStr); + + // Remember the counter variable for each struct type. + StructTyMap.insert(std::pair<Type *, GlobalVariable *>(StructTy, Counters)); + + // We pass the field type name array, offset array, and size array to + // the runtime for better reporting. + GlobalVariable *TypeName = nullptr, *Offset = nullptr, *Size = nullptr; + if (ClAuxFieldInfo) + createCacheFragAuxGV(M, DL, StructTy, TypeName, Offset, Size); + + Constant *FieldCounterIdx[2]; + FieldCounterIdx[0] = ConstantInt::get(Int32Ty, 0); + FieldCounterIdx[1] = ConstantInt::get(Int32Ty, + getFieldCounterIdx(StructTy)); + Constant *ArrayCounterIdx[2]; + ArrayCounterIdx[0] = ConstantInt::get(Int32Ty, 0); + ArrayCounterIdx[1] = ConstantInt::get(Int32Ty, + getArrayCounterIdx(StructTy)); + Initializers.push_back( + ConstantStruct::get( + StructInfoTy, + ConstantExpr::getPointerCast(StructCounterName, Int8PtrTy), + ConstantInt::get(Int32Ty, + DL.getStructLayout(StructTy)->getSizeInBytes()), + ConstantInt::get(Int32Ty, StructTy->getNumElements()), + Offset == nullptr ? ConstantPointerNull::get(Int32PtrTy) : + ConstantExpr::getPointerCast(Offset, Int32PtrTy), + Size == nullptr ? ConstantPointerNull::get(Int32PtrTy) : + ConstantExpr::getPointerCast(Size, Int32PtrTy), + TypeName == nullptr ? ConstantPointerNull::get(Int8PtrPtrTy) : + ConstantExpr::getPointerCast(TypeName, Int8PtrPtrTy), + ConstantExpr::getGetElementPtr(CounterArrayTy, Counters, + FieldCounterIdx), + ConstantExpr::getGetElementPtr(CounterArrayTy, Counters, + ArrayCounterIdx), + nullptr)); + } + // Structs. 
+ Constant *StructInfo; + if (NumStructs == 0) { + StructInfo = ConstantPointerNull::get(StructInfoPtrTy); + } else { + auto *StructInfoArrayTy = ArrayType::get(StructInfoTy, NumStructs); + StructInfo = ConstantExpr::getPointerCast( + new GlobalVariable(M, StructInfoArrayTy, false, + GlobalVariable::InternalLinkage, + ConstantArray::get(StructInfoArrayTy, Initializers)), + StructInfoPtrTy); + } + + auto *CacheFragInfoGV = new GlobalVariable( + M, CacheFragInfoTy, true, GlobalVariable::InternalLinkage, + ConstantStruct::get(CacheFragInfoTy, + UnitName, + ConstantInt::get(Int32Ty, NumStructs), + StructInfo, + nullptr)); + return CacheFragInfoGV; +} + +// Create the tool-specific argument passed to EsanInit and EsanExit. +Constant *EfficiencySanitizer::createEsanInitToolInfoArg(Module &M, + const DataLayout &DL) { + // This structure contains tool-specific information about each compilation + // unit (module) and is passed to the runtime library. + GlobalVariable *ToolInfoGV = nullptr; + + auto *Int8PtrTy = Type::getInt8PtrTy(*Ctx); + // Compilation unit name. + auto *UnitName = ConstantExpr::getPointerCast( + createPrivateGlobalForString(M, M.getModuleIdentifier(), true), + Int8PtrTy); + + // Create the tool-specific variable. + if (Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag) + ToolInfoGV = createCacheFragInfoGV(M, DL, UnitName); + + if (ToolInfoGV != nullptr) + return ConstantExpr::getPointerCast(ToolInfoGV, Int8PtrTy); + + // Create the null pointer if no tool-specific variable created. + return ConstantPointerNull::get(Int8PtrTy); +} + +void EfficiencySanitizer::createDestructor(Module &M, Constant *ToolInfoArg) { + PointerType *Int8PtrTy = Type::getInt8PtrTy(*Ctx); + EsanDtorFunction = Function::Create(FunctionType::get(Type::getVoidTy(*Ctx), + false), + GlobalValue::InternalLinkage, + EsanModuleDtorName, &M); + ReturnInst::Create(*Ctx, BasicBlock::Create(*Ctx, "", EsanDtorFunction)); + IRBuilder<> IRB_Dtor(EsanDtorFunction->getEntryBlock().getTerminator()); + Function *EsanExit = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(EsanExitName, IRB_Dtor.getVoidTy(), + Int8PtrTy, nullptr)); + EsanExit->setLinkage(Function::ExternalLinkage); + IRB_Dtor.CreateCall(EsanExit, {ToolInfoArg}); + appendToGlobalDtors(M, EsanDtorFunction, EsanCtorAndDtorPriority); +} + +bool EfficiencySanitizer::initOnModule(Module &M) { + Ctx = &M.getContext(); + const DataLayout &DL = M.getDataLayout(); + IRBuilder<> IRB(M.getContext()); + IntegerType *OrdTy = IRB.getInt32Ty(); + PointerType *Int8PtrTy = Type::getInt8PtrTy(*Ctx); + IntptrTy = DL.getIntPtrType(M.getContext()); + // Create the variable passed to EsanInit and EsanExit. + Constant *ToolInfoArg = createEsanInitToolInfoArg(M, DL); + // Constructor + // We specify the tool type both in the EsanWhichToolName global + // and as an arg to the init routine as a sanity check. 
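+  // The ctor created here ends up calling __esan_init(ToolType, ToolInfo) and
+  // the matching dtor (created below) calls __esan_exit(ToolInfo); both are
+  // registered at EsanCtorAndDtorPriority.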
+ std::tie(EsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions( + M, EsanModuleCtorName, EsanInitName, /*InitArgTypes=*/{OrdTy, Int8PtrTy}, + /*InitArgs=*/{ + ConstantInt::get(OrdTy, static_cast<int>(Options.ToolType)), + ToolInfoArg}); + appendToGlobalCtors(M, EsanCtorFunction, EsanCtorAndDtorPriority); + + createDestructor(M, ToolInfoArg); + + new GlobalVariable(M, OrdTy, true, + GlobalValue::WeakAnyLinkage, + ConstantInt::get(OrdTy, + static_cast<int>(Options.ToolType)), + EsanWhichToolName); + + return true; +} + +Value *EfficiencySanitizer::appToShadow(Value *Shadow, IRBuilder<> &IRB) { + // Shadow = ((App & Mask) + Offs) >> Scale + Shadow = IRB.CreateAnd(Shadow, ConstantInt::get(IntptrTy, ShadowMask)); + uint64_t Offs; + int Scale = ShadowScale[Options.ToolType]; + if (Scale <= 2) + Offs = ShadowOffs[Scale]; + else + Offs = ShadowOffs[0] << Scale; + Shadow = IRB.CreateAdd(Shadow, ConstantInt::get(IntptrTy, Offs)); + if (Scale > 0) + Shadow = IRB.CreateLShr(Shadow, Scale); + return Shadow; +} + +bool EfficiencySanitizer::shouldIgnoreMemoryAccess(Instruction *I) { + if (Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag) { + // We'd like to know about cache fragmentation in vtable accesses and + // constant data references, so we do not currently ignore anything. + return false; + } else if (Options.ToolType == EfficiencySanitizerOptions::ESAN_WorkingSet) { + // TODO: the instrumentation disturbs the data layout on the stack, so we + // may want to add an option to ignore stack references (if we can + // distinguish them) to reduce overhead. + } + // TODO(bruening): future tools will be returning true for some cases. + return false; +} + +bool EfficiencySanitizer::runOnModule(Module &M) { + bool Res = initOnModule(M); + initializeCallbacks(M); + for (auto &F : M) { + Res |= runOnFunction(F, M); + } + return Res; +} + +bool EfficiencySanitizer::runOnFunction(Function &F, Module &M) { + // This is required to prevent instrumenting the call to __esan_init from + // within the module constructor. 
+ if (&F == EsanCtorFunction) + return false; + SmallVector<Instruction *, 8> LoadsAndStores; + SmallVector<Instruction *, 8> MemIntrinCalls; + SmallVector<Instruction *, 8> GetElementPtrs; + bool Res = false; + const DataLayout &DL = M.getDataLayout(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + + for (auto &BB : F) { + for (auto &Inst : BB) { + if ((isa<LoadInst>(Inst) || isa<StoreInst>(Inst) || + isa<AtomicRMWInst>(Inst) || isa<AtomicCmpXchgInst>(Inst)) && + !shouldIgnoreMemoryAccess(&Inst)) + LoadsAndStores.push_back(&Inst); + else if (isa<MemIntrinsic>(Inst)) + MemIntrinCalls.push_back(&Inst); + else if (isa<GetElementPtrInst>(Inst)) + GetElementPtrs.push_back(&Inst); + else if (CallInst *CI = dyn_cast<CallInst>(&Inst)) + maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI); + } + } + + if (ClInstrumentLoadsAndStores) { + for (auto Inst : LoadsAndStores) { + Res |= instrumentLoadOrStore(Inst, DL); + } + } + + if (ClInstrumentMemIntrinsics) { + for (auto Inst : MemIntrinCalls) { + Res |= instrumentMemIntrinsic(cast<MemIntrinsic>(Inst)); + } + } + + if (Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag) { + for (auto Inst : GetElementPtrs) { + Res |= instrumentGetElementPtr(Inst, M); + } + } + + return Res; +} + +bool EfficiencySanitizer::instrumentLoadOrStore(Instruction *I, + const DataLayout &DL) { + IRBuilder<> IRB(I); + bool IsStore; + Value *Addr; + unsigned Alignment; + if (LoadInst *Load = dyn_cast<LoadInst>(I)) { + IsStore = false; + Alignment = Load->getAlignment(); + Addr = Load->getPointerOperand(); + } else if (StoreInst *Store = dyn_cast<StoreInst>(I)) { + IsStore = true; + Alignment = Store->getAlignment(); + Addr = Store->getPointerOperand(); + } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) { + IsStore = true; + Alignment = 0; + Addr = RMW->getPointerOperand(); + } else if (AtomicCmpXchgInst *Xchg = dyn_cast<AtomicCmpXchgInst>(I)) { + IsStore = true; + Alignment = 0; + Addr = Xchg->getPointerOperand(); + } else + llvm_unreachable("Unsupported mem access type"); + + Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType(); + const uint32_t TypeSizeBytes = DL.getTypeStoreSizeInBits(OrigTy) / 8; + Value *OnAccessFunc = nullptr; + + // Convert 0 to the default alignment. + if (Alignment == 0) + Alignment = DL.getPrefTypeAlignment(OrigTy); + + if (IsStore) + NumInstrumentedStores++; + else + NumInstrumentedLoads++; + int Idx = getMemoryAccessFuncIndex(Addr, DL); + if (Idx < 0) { + OnAccessFunc = IsStore ? EsanUnalignedStoreN : EsanUnalignedLoadN; + IRB.CreateCall(OnAccessFunc, + {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()), + ConstantInt::get(IntptrTy, TypeSizeBytes)}); + } else { + if (ClInstrumentFastpath && + instrumentFastpath(I, DL, IsStore, Addr, Alignment)) { + NumFastpaths++; + return true; + } + if (Alignment == 0 || (Alignment % TypeSizeBytes) == 0) + OnAccessFunc = IsStore ? EsanAlignedStore[Idx] : EsanAlignedLoad[Idx]; + else + OnAccessFunc = IsStore ? EsanUnalignedStore[Idx] : EsanUnalignedLoad[Idx]; + IRB.CreateCall(OnAccessFunc, + IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy())); + } + return true; +} + +// It's simplest to replace the memset/memmove/memcpy intrinsics with +// calls that the runtime library intercepts. +// Our pass is late enough that calls should not turn back into intrinsics. 
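+// Roughly: an @llvm.memset.* call becomes a plain call to memset() (and
+// likewise @llvm.memcpy.*/@llvm.memmove.* become memcpy()/memmove()), with the
+// length cast to the pointer-sized integer type, and the original intrinsic is
+// erased.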
+bool EfficiencySanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { + IRBuilder<> IRB(MI); + bool Res = false; + if (isa<MemSetInst>(MI)) { + IRB.CreateCall( + MemsetFn, + {IRB.CreatePointerCast(MI->getArgOperand(0), IRB.getInt8PtrTy()), + IRB.CreateIntCast(MI->getArgOperand(1), IRB.getInt32Ty(), false), + IRB.CreateIntCast(MI->getArgOperand(2), IntptrTy, false)}); + MI->eraseFromParent(); + Res = true; + } else if (isa<MemTransferInst>(MI)) { + IRB.CreateCall( + isa<MemCpyInst>(MI) ? MemcpyFn : MemmoveFn, + {IRB.CreatePointerCast(MI->getArgOperand(0), IRB.getInt8PtrTy()), + IRB.CreatePointerCast(MI->getArgOperand(1), IRB.getInt8PtrTy()), + IRB.CreateIntCast(MI->getArgOperand(2), IntptrTy, false)}); + MI->eraseFromParent(); + Res = true; + } else + llvm_unreachable("Unsupported mem intrinsic type"); + return Res; +} + +bool EfficiencySanitizer::instrumentGetElementPtr(Instruction *I, Module &M) { + GetElementPtrInst *GepInst = dyn_cast<GetElementPtrInst>(I); + bool Res = false; + if (GepInst == nullptr || GepInst->getNumIndices() == 1) { + ++NumIgnoredGEPs; + return false; + } + Type *SourceTy = GepInst->getSourceElementType(); + StructType *StructTy; + ConstantInt *Idx; + // Check if GEP calculates address from a struct array. + if (isa<StructType>(SourceTy)) { + StructTy = cast<StructType>(SourceTy); + Idx = dyn_cast<ConstantInt>(GepInst->getOperand(1)); + if ((Idx == nullptr || Idx->getSExtValue() != 0) && + !shouldIgnoreStructType(StructTy) && StructTyMap.count(StructTy) != 0) + Res |= insertCounterUpdate(I, StructTy, getArrayCounterIdx(StructTy)); + } + // Iterate all (except the first and the last) idx within each GEP instruction + // for possible nested struct field address calculation. + for (unsigned i = 1; i < GepInst->getNumIndices(); ++i) { + SmallVector<Value *, 8> IdxVec(GepInst->idx_begin(), + GepInst->idx_begin() + i); + Type *Ty = GetElementPtrInst::getIndexedType(SourceTy, IdxVec); + unsigned CounterIdx = 0; + if (isa<ArrayType>(Ty)) { + ArrayType *ArrayTy = cast<ArrayType>(Ty); + StructTy = dyn_cast<StructType>(ArrayTy->getElementType()); + if (shouldIgnoreStructType(StructTy) || StructTyMap.count(StructTy) == 0) + continue; + // The last counter for struct array access. + CounterIdx = getArrayCounterIdx(StructTy); + } else if (isa<StructType>(Ty)) { + StructTy = cast<StructType>(Ty); + if (shouldIgnoreStructType(StructTy) || StructTyMap.count(StructTy) == 0) + continue; + // Get the StructTy's subfield index. + Idx = cast<ConstantInt>(GepInst->getOperand(i+1)); + assert(Idx->getSExtValue() >= 0 && + Idx->getSExtValue() < StructTy->getNumElements()); + CounterIdx = getFieldCounterIdx(StructTy) + Idx->getSExtValue(); + } + Res |= insertCounterUpdate(I, StructTy, CounterIdx); + } + if (Res) + ++NumInstrumentedGEPs; + else + ++NumIgnoredGEPs; + return Res; +} + +bool EfficiencySanitizer::insertCounterUpdate(Instruction *I, + StructType *StructTy, + unsigned CounterIdx) { + GlobalVariable *CounterArray = StructTyMap[StructTy]; + if (CounterArray == nullptr) + return false; + IRBuilder<> IRB(I); + Constant *Indices[2]; + // Xref http://llvm.org/docs/LangRef.html#i-getelementptr and + // http://llvm.org/docs/GetElementPtr.html. + // The first index of the GEP instruction steps through the first operand, + // i.e., the array itself. + Indices[0] = ConstantInt::get(IRB.getInt32Ty(), 0); + // The second index is the index within the array. 
+ Indices[1] = ConstantInt::get(IRB.getInt32Ty(), CounterIdx); + Constant *Counter = + ConstantExpr::getGetElementPtr( + ArrayType::get(IRB.getInt64Ty(), getStructCounterSize(StructTy)), + CounterArray, Indices); + Value *Load = IRB.CreateLoad(Counter); + IRB.CreateStore(IRB.CreateAdd(Load, ConstantInt::get(IRB.getInt64Ty(), 1)), + Counter); + return true; +} + +int EfficiencySanitizer::getMemoryAccessFuncIndex(Value *Addr, + const DataLayout &DL) { + Type *OrigPtrTy = Addr->getType(); + Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType(); + assert(OrigTy->isSized()); + // The size is always a multiple of 8. + uint32_t TypeSizeBytes = DL.getTypeStoreSizeInBits(OrigTy) / 8; + if (TypeSizeBytes != 1 && TypeSizeBytes != 2 && TypeSizeBytes != 4 && + TypeSizeBytes != 8 && TypeSizeBytes != 16) { + // Irregular sizes do not have per-size call targets. + NumAccessesWithIrregularSize++; + return -1; + } + size_t Idx = countTrailingZeros(TypeSizeBytes); + assert(Idx < NumberOfAccessSizes); + return Idx; +} + +bool EfficiencySanitizer::instrumentFastpath(Instruction *I, + const DataLayout &DL, bool IsStore, + Value *Addr, unsigned Alignment) { + if (Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag) { + return instrumentFastpathCacheFrag(I, DL, Addr, Alignment); + } else if (Options.ToolType == EfficiencySanitizerOptions::ESAN_WorkingSet) { + return instrumentFastpathWorkingSet(I, DL, Addr, Alignment); + } + return false; +} + +bool EfficiencySanitizer::instrumentFastpathCacheFrag(Instruction *I, + const DataLayout &DL, + Value *Addr, + unsigned Alignment) { + // Do nothing. + return true; // Return true to avoid slowpath instrumentation. +} + +bool EfficiencySanitizer::instrumentFastpathWorkingSet( + Instruction *I, const DataLayout &DL, Value *Addr, unsigned Alignment) { + assert(ShadowScale[Options.ToolType] == 6); // The code below assumes this + IRBuilder<> IRB(I); + Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType(); + const uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); + // Bail to the slowpath if the access might touch multiple cache lines. + // An access aligned to its size is guaranteed to be intra-cache-line. + // getMemoryAccessFuncIndex has already ruled out a size larger than 16 + // and thus larger than a cache line for platforms this tool targets + // (and our shadow memory setup assumes 64-byte cache lines). + assert(TypeSize <= 128); + if (!(TypeSize == 8 || + (Alignment % (TypeSize / 8)) == 0)) { + if (ClAssumeIntraCacheLine) + ++NumAssumedIntraCacheLine; + else + return false; + } + + // We inline instrumentation to set the corresponding shadow bits for + // each cache line touched by the application. Here we handle a single + // load or store where we've already ruled out the possibility that it + // might touch more than one cache line and thus we simply update the + // shadow memory for a single cache line. + // Our shadow memory model is fine with races when manipulating shadow values. + // We generate the following code: + // + // const char BitMask = 0x81; + // char *ShadowAddr = appToShadow(AppAddr); + // if ((*ShadowAddr & BitMask) != BitMask) + // *ShadowAddr |= Bitmask; + // + Value *AddrPtr = IRB.CreatePointerCast(Addr, IntptrTy); + Value *ShadowPtr = appToShadow(AddrPtr, IRB); + Type *ShadowTy = IntegerType::get(*Ctx, 8U); + Type *ShadowPtrTy = PointerType::get(ShadowTy, 0); + // The bottom bit is used for the current sampling period's working set. + // The top bit is used for the total working set. 
We set both on each + // memory access, if they are not already set. + Value *ValueMask = ConstantInt::get(ShadowTy, 0x81); // 10000001B + + Value *OldValue = IRB.CreateLoad(IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); + // The AND and CMP will be turned into a TEST instruction by the compiler. + Value *Cmp = IRB.CreateICmpNE(IRB.CreateAnd(OldValue, ValueMask), ValueMask); + TerminatorInst *CmpTerm = SplitBlockAndInsertIfThen(Cmp, I, false); + // FIXME: do I need to call SetCurrentDebugLocation? + IRB.SetInsertPoint(CmpTerm); + // We use OR to set the shadow bits to avoid corrupting the middle 6 bits, + // which are used by the runtime library. + Value *NewVal = IRB.CreateOr(OldValue, ValueMask); + IRB.CreateStore(NewVal, IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); + IRB.SetInsertPoint(I); + + return true; +} diff --git a/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/lib/Transforms/Instrumentation/GCOVProfiling.cpp index ffde7f8d9bae..b4070b602768 100644 --- a/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -14,7 +14,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" @@ -35,6 +34,8 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/GCOVProfiler.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <memory> @@ -68,86 +69,93 @@ GCOVOptions GCOVOptions::getDefault() { } namespace { - class GCOVFunction; +class GCOVFunction; + +class GCOVProfiler { +public: + GCOVProfiler() : GCOVProfiler(GCOVOptions::getDefault()) {} + GCOVProfiler(const GCOVOptions &Opts) : Options(Opts) { + assert((Options.EmitNotes || Options.EmitData) && + "GCOVProfiler asked to do nothing?"); + ReversedVersion[0] = Options.Version[3]; + ReversedVersion[1] = Options.Version[2]; + ReversedVersion[2] = Options.Version[1]; + ReversedVersion[3] = Options.Version[0]; + ReversedVersion[4] = '\0'; + } + bool runOnModule(Module &M); + +private: + // Create the .gcno files for the Module based on DebugInfo. + void emitProfileNotes(); + + // Modify the program to track transitions along edges and call into the + // profiling runtime to emit .gcda files when run. + bool emitProfileArcs(); + + // Get pointers to the functions in the runtime library. + Constant *getStartFileFunc(); + Constant *getIncrementIndirectCounterFunc(); + Constant *getEmitFunctionFunc(); + Constant *getEmitArcsFunc(); + Constant *getSummaryInfoFunc(); + Constant *getEndFileFunc(); + + // Create or retrieve an i32 state value that is used to represent the + // pred block number for certain non-trivial edges. + GlobalVariable *getEdgeStateValue(); + + // Produce a table of pointers to counters, by predecessor and successor + // block number. + GlobalVariable *buildEdgeLookupTable(Function *F, GlobalVariable *Counter, + const UniqueVector<BasicBlock *> &Preds, + const UniqueVector<BasicBlock *> &Succs); + + // Add the function to write out all our counters to the global destructor + // list. 
+ Function * + insertCounterWriteout(ArrayRef<std::pair<GlobalVariable *, MDNode *>>); + Function *insertFlush(ArrayRef<std::pair<GlobalVariable *, MDNode *>>); + void insertIndirectCounterIncrement(); + + std::string mangleName(const DICompileUnit *CU, const char *NewStem); - class GCOVProfiler : public ModulePass { - public: - static char ID; - GCOVProfiler() : GCOVProfiler(GCOVOptions::getDefault()) {} - GCOVProfiler(const GCOVOptions &Opts) : ModulePass(ID), Options(Opts) { - assert((Options.EmitNotes || Options.EmitData) && - "GCOVProfiler asked to do nothing?"); - ReversedVersion[0] = Options.Version[3]; - ReversedVersion[1] = Options.Version[2]; - ReversedVersion[2] = Options.Version[1]; - ReversedVersion[3] = Options.Version[0]; - ReversedVersion[4] = '\0'; - initializeGCOVProfilerPass(*PassRegistry::getPassRegistry()); - } - const char *getPassName() const override { - return "GCOV Profiler"; - } + GCOVOptions Options; - private: - bool runOnModule(Module &M) override; - - // Create the .gcno files for the Module based on DebugInfo. - void emitProfileNotes(); - - // Modify the program to track transitions along edges and call into the - // profiling runtime to emit .gcda files when run. - bool emitProfileArcs(); - - // Get pointers to the functions in the runtime library. - Constant *getStartFileFunc(); - Constant *getIncrementIndirectCounterFunc(); - Constant *getEmitFunctionFunc(); - Constant *getEmitArcsFunc(); - Constant *getSummaryInfoFunc(); - Constant *getDeleteWriteoutFunctionListFunc(); - Constant *getDeleteFlushFunctionListFunc(); - Constant *getEndFileFunc(); - - // Create or retrieve an i32 state value that is used to represent the - // pred block number for certain non-trivial edges. - GlobalVariable *getEdgeStateValue(); - - // Produce a table of pointers to counters, by predecessor and successor - // block number. - GlobalVariable *buildEdgeLookupTable(Function *F, - GlobalVariable *Counter, - const UniqueVector<BasicBlock *>&Preds, - const UniqueVector<BasicBlock*>&Succs); - - // Add the function to write out all our counters to the global destructor - // list. - Function *insertCounterWriteout(ArrayRef<std::pair<GlobalVariable*, - MDNode*> >); - Function *insertFlush(ArrayRef<std::pair<GlobalVariable*, MDNode*> >); - void insertIndirectCounterIncrement(); - - std::string mangleName(const DICompileUnit *CU, const char *NewStem); - - GCOVOptions Options; - - // Reversed, NUL-terminated copy of Options.Version. - char ReversedVersion[5]; - // Checksum, produced by hash of EdgeDestinations - SmallVector<uint32_t, 4> FileChecksums; - - Module *M; - LLVMContext *Ctx; - SmallVector<std::unique_ptr<GCOVFunction>, 16> Funcs; - DenseMap<DISubprogram *, Function *> FnMap; - }; + // Reversed, NUL-terminated copy of Options.Version. 
+ char ReversedVersion[5]; + // Checksum, produced by hash of EdgeDestinations + SmallVector<uint32_t, 4> FileChecksums; + + Module *M; + LLVMContext *Ctx; + SmallVector<std::unique_ptr<GCOVFunction>, 16> Funcs; +}; + +class GCOVProfilerLegacyPass : public ModulePass { +public: + static char ID; + GCOVProfilerLegacyPass() + : GCOVProfilerLegacyPass(GCOVOptions::getDefault()) {} + GCOVProfilerLegacyPass(const GCOVOptions &Opts) + : ModulePass(ID), Profiler(Opts) { + initializeGCOVProfilerLegacyPassPass(*PassRegistry::getPassRegistry()); + } + const char *getPassName() const override { return "GCOV Profiler"; } + + bool runOnModule(Module &M) override { return Profiler.runOnModule(M); } + +private: + GCOVProfiler Profiler; +}; } -char GCOVProfiler::ID = 0; -INITIALIZE_PASS(GCOVProfiler, "insert-gcov-profiling", +char GCOVProfilerLegacyPass::ID = 0; +INITIALIZE_PASS(GCOVProfilerLegacyPass, "insert-gcov-profiling", "Insert instrumentation for GCOV profiling", false, false) ModulePass *llvm::createGCOVProfilerPass(const GCOVOptions &Options) { - return new GCOVProfiler(Options); + return new GCOVProfilerLegacyPass(Options); } static StringRef getFunctionName(const DISubprogram *SP) { @@ -257,10 +265,9 @@ namespace { void writeOut() { uint32_t Len = 3; SmallVector<StringMapEntry<GCOVLines *> *, 32> SortedLinesByFile; - for (StringMap<GCOVLines *>::iterator I = LinesByFile.begin(), - E = LinesByFile.end(); I != E; ++I) { - Len += I->second->length(); - SortedLinesByFile.push_back(&*I); + for (auto &I : LinesByFile) { + Len += I.second->length(); + SortedLinesByFile.push_back(&I); } writeBytes(LinesTag, 4); @@ -272,10 +279,8 @@ namespace { StringMapEntry<GCOVLines *> *RHS) { return LHS->getKey() < RHS->getKey(); }); - for (SmallVectorImpl<StringMapEntry<GCOVLines *> *>::iterator - I = SortedLinesByFile.begin(), E = SortedLinesByFile.end(); - I != E; ++I) - (*I)->getValue()->writeOut(); + for (auto &I : SortedLinesByFile) + I->getValue()->writeOut(); write(0); write(0); } @@ -450,28 +455,32 @@ bool GCOVProfiler::runOnModule(Module &M) { this->M = &M; Ctx = &M.getContext(); - FnMap.clear(); - for (Function &F : M) { - if (DISubprogram *SP = F.getSubprogram()) - FnMap[SP] = &F; - } - if (Options.EmitNotes) emitProfileNotes(); if (Options.EmitData) return emitProfileArcs(); return false; } -static bool functionHasLines(Function *F) { +PreservedAnalyses GCOVProfilerPass::run(Module &M, + AnalysisManager<Module> &AM) { + + GCOVProfiler Profiler(GCOVOpts); + + if (!Profiler.runOnModule(M)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +static bool functionHasLines(Function &F) { // Check whether this function actually has any source lines. Not only // do these waste space, they also can crash gcov. - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { - for (BasicBlock::iterator I = BB->begin(), IE = BB->end(); - I != IE; ++I) { + for (auto &BB : F) { + for (auto &I : BB) { // Debug intrinsic locations correspond to the location of the // declaration, not necessarily any statements or expressions. 
- if (isa<DbgInfoIntrinsic>(I)) continue; + if (isa<DbgInfoIntrinsic>(&I)) continue; - const DebugLoc &Loc = I->getDebugLoc(); + const DebugLoc &Loc = I.getDebugLoc(); if (!Loc) continue; @@ -504,27 +513,27 @@ void GCOVProfiler::emitProfileNotes() { std::string EdgeDestinations; unsigned FunctionIdent = 0; - for (auto *SP : CU->getSubprograms()) { - Function *F = FnMap[SP]; - if (!F) continue; + for (auto &F : M->functions()) { + DISubprogram *SP = F.getSubprogram(); + if (!SP) continue; if (!functionHasLines(F)) continue; // gcov expects every function to start with an entry block that has a // single successor, so split the entry block to make sure of that. - BasicBlock &EntryBlock = F->getEntryBlock(); + BasicBlock &EntryBlock = F.getEntryBlock(); BasicBlock::iterator It = EntryBlock.begin(); while (isa<AllocaInst>(*It) || isa<DbgInfoIntrinsic>(*It)) ++It; EntryBlock.splitBasicBlock(It); - Funcs.push_back(make_unique<GCOVFunction>(SP, F, &out, FunctionIdent++, + Funcs.push_back(make_unique<GCOVFunction>(SP, &F, &out, FunctionIdent++, Options.UseCfgChecksum, Options.ExitBlockBeforeBody)); GCOVFunction &Func = *Funcs.back(); - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { - GCOVBlock &Block = Func.getBlock(&*BB); - TerminatorInst *TI = BB->getTerminator(); + for (auto &BB : F) { + GCOVBlock &Block = Func.getBlock(&BB); + TerminatorInst *TI = BB.getTerminator(); if (int successors = TI->getNumSuccessors()) { for (int i = 0; i != successors; ++i) { Block.addEdge(Func.getBlock(TI->getSuccessor(i))); @@ -534,13 +543,12 @@ void GCOVProfiler::emitProfileNotes() { } uint32_t Line = 0; - for (BasicBlock::iterator I = BB->begin(), IE = BB->end(); - I != IE; ++I) { + for (auto &I : BB) { // Debug intrinsic locations correspond to the location of the // declaration, not necessarily any statements or expressions. - if (isa<DbgInfoIntrinsic>(I)) continue; + if (isa<DbgInfoIntrinsic>(&I)) continue; - const DebugLoc &Loc = I->getDebugLoc(); + const DebugLoc &Loc = I.getDebugLoc(); if (!Loc) continue; @@ -581,16 +589,15 @@ bool GCOVProfiler::emitProfileArcs() { bool Result = false; bool InsertIndCounterIncrCode = false; for (unsigned i = 0, e = CU_Nodes->getNumOperands(); i != e; ++i) { - auto *CU = cast<DICompileUnit>(CU_Nodes->getOperand(i)); SmallVector<std::pair<GlobalVariable *, MDNode *>, 8> CountersBySP; - for (auto *SP : CU->getSubprograms()) { - Function *F = FnMap[SP]; - if (!F) continue; + for (auto &F : M->functions()) { + DISubprogram *SP = F.getSubprogram(); + if (!SP) continue; if (!functionHasLines(F)) continue; if (!Result) Result = true; unsigned Edges = 0; - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { - TerminatorInst *TI = BB->getTerminator(); + for (auto &BB : F) { + TerminatorInst *TI = BB.getTerminator(); if (isa<ReturnInst>(TI)) ++Edges; else @@ -610,12 +617,12 @@ bool GCOVProfiler::emitProfileArcs() { UniqueVector<BasicBlock *> ComplexEdgeSuccs; unsigned Edge = 0; - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { - TerminatorInst *TI = BB->getTerminator(); + for (auto &BB : F) { + TerminatorInst *TI = BB.getTerminator(); int Successors = isa<ReturnInst>(TI) ? 
1 : TI->getNumSuccessors(); if (Successors) { if (Successors == 1) { - IRBuilder<> Builder(&*BB->getFirstInsertionPt()); + IRBuilder<> Builder(&*BB.getFirstInsertionPt()); Value *Counter = Builder.CreateConstInBoundsGEP2_64(Counters, 0, Edge); Value *Count = Builder.CreateLoad(Counter); @@ -626,16 +633,13 @@ bool GCOVProfiler::emitProfileArcs() { Value *Sel = Builder.CreateSelect(BI->getCondition(), Builder.getInt64(Edge), Builder.getInt64(Edge + 1)); - SmallVector<Value *, 2> Idx; - Idx.push_back(Builder.getInt64(0)); - Idx.push_back(Sel); - Value *Counter = Builder.CreateInBoundsGEP(Counters->getValueType(), - Counters, Idx); + Value *Counter = Builder.CreateInBoundsGEP( + Counters->getValueType(), Counters, {Builder.getInt64(0), Sel}); Value *Count = Builder.CreateLoad(Counter); Count = Builder.CreateAdd(Count, Builder.getInt64(1)); Builder.CreateStore(Count, Counter); } else { - ComplexEdgePreds.insert(&*BB); + ComplexEdgePreds.insert(&BB); for (int i = 0; i != Successors; ++i) ComplexEdgeSuccs.insert(TI->getSuccessor(i)); } @@ -646,7 +650,7 @@ bool GCOVProfiler::emitProfileArcs() { if (!ComplexEdgePreds.empty()) { GlobalVariable *EdgeTable = - buildEdgeLookupTable(F, Counters, + buildEdgeLookupTable(&F, Counters, ComplexEdgePreds, ComplexEdgeSuccs); GlobalVariable *EdgeState = getEdgeStateValue(); @@ -679,7 +683,7 @@ bool GCOVProfiler::emitProfileArcs() { FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); Function *F = Function::Create(FTy, GlobalValue::InternalLinkage, "__llvm_gcov_init", M); - F->setUnnamedAddr(true); + F->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); F->setLinkage(GlobalValue::InternalLinkage); F->addFnAttr(Attribute::NoInline); if (Options.NoRedZone) @@ -732,8 +736,8 @@ GlobalVariable *GCOVProfiler::buildEdgeLookupTable( EdgeTable[i] = NullValue; unsigned Edge = 0; - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { - TerminatorInst *TI = BB->getTerminator(); + for (BasicBlock &BB : *F) { + TerminatorInst *TI = BB.getTerminator(); int Successors = isa<ReturnInst>(TI) ? 
1 : TI->getNumSuccessors(); if (Successors > 1 && !isa<BranchInst>(TI) && !isa<ReturnInst>(TI)) { for (int i = 0; i != Successors; ++i) { @@ -742,7 +746,7 @@ GlobalVariable *GCOVProfiler::buildEdgeLookupTable( Value *Counter = Builder.CreateConstInBoundsGEP2_64(Counters, 0, Edge + i); EdgeTable[((Succs.idFor(Succ) - 1) * Preds.size()) + - (Preds.idFor(&*BB) - 1)] = cast<Constant>(Counter); + (Preds.idFor(&BB) - 1)] = cast<Constant>(Counter); } } Edge += Successors; @@ -754,7 +758,7 @@ GlobalVariable *GCOVProfiler::buildEdgeLookupTable( ConstantArray::get(EdgeTableTy, makeArrayRef(&EdgeTable[0],TableSize)), "__llvm_gcda_edge_table"); - EdgeTableGV->setUnnamedAddr(true); + EdgeTableGV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); return EdgeTableGV; } @@ -805,16 +809,6 @@ Constant *GCOVProfiler::getSummaryInfoFunc() { return M->getOrInsertFunction("llvm_gcda_summary_info", FTy); } -Constant *GCOVProfiler::getDeleteWriteoutFunctionListFunc() { - FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); - return M->getOrInsertFunction("llvm_delete_writeout_function_list", FTy); -} - -Constant *GCOVProfiler::getDeleteFlushFunctionListFunc() { - FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); - return M->getOrInsertFunction("llvm_delete_flush_function_list", FTy); -} - Constant *GCOVProfiler::getEndFileFunc() { FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); return M->getOrInsertFunction("llvm_gcda_end_file", FTy); @@ -828,7 +822,7 @@ GlobalVariable *GCOVProfiler::getEdgeStateValue() { ConstantInt::get(Type::getInt32Ty(*Ctx), 0xffffffff), "__llvm_gcov_global_state_pred"); - GV->setUnnamedAddr(true); + GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); } return GV; } @@ -840,7 +834,7 @@ Function *GCOVProfiler::insertCounterWriteout( if (!WriteoutF) WriteoutF = Function::Create(WriteoutFTy, GlobalValue::InternalLinkage, "__llvm_gcov_writeout", M); - WriteoutF->setUnnamedAddr(true); + WriteoutF->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); WriteoutF->addFnAttr(Attribute::NoInline); if (Options.NoRedZone) WriteoutF->addFnAttr(Attribute::NoRedZone); @@ -884,7 +878,7 @@ Function *GCOVProfiler::insertCounterWriteout( GlobalVariable *GV = CountersBySP[j].first; unsigned Arcs = - cast<ArrayType>(GV->getType()->getElementType())->getNumElements(); + cast<ArrayType>(GV->getValueType())->getNumElements(); Builder.CreateCall(EmitArcs, {Builder.getInt32(Arcs), Builder.CreateConstGEP2_64(GV, 0, 0)}); } @@ -900,7 +894,7 @@ Function *GCOVProfiler::insertCounterWriteout( void GCOVProfiler::insertIndirectCounterIncrement() { Function *Fn = cast<Function>(GCOVProfiler::getIncrementIndirectCounterFunc()); - Fn->setUnnamedAddr(true); + Fn->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); Fn->setLinkage(GlobalValue::InternalLinkage); Fn->addFnAttr(Attribute::NoInline); if (Options.NoRedZone) @@ -957,7 +951,7 @@ insertFlush(ArrayRef<std::pair<GlobalVariable*, MDNode*> > CountersBySP) { "__llvm_gcov_flush", M); else FlushF->setLinkage(GlobalValue::InternalLinkage); - FlushF->setUnnamedAddr(true); + FlushF->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); FlushF->addFnAttr(Attribute::NoInline); if (Options.NoRedZone) FlushF->addFnAttr(Attribute::NoRedZone); @@ -972,11 +966,9 @@ insertFlush(ArrayRef<std::pair<GlobalVariable*, MDNode*> > CountersBySP) { Builder.CreateCall(WriteoutF, {}); // Zero out the counters. 
- for (ArrayRef<std::pair<GlobalVariable *, MDNode *> >::iterator - I = CountersBySP.begin(), E = CountersBySP.end(); - I != E; ++I) { - GlobalVariable *GV = I->first; - Constant *Null = Constant::getNullValue(GV->getType()->getElementType()); + for (const auto &I : CountersBySP) { + GlobalVariable *GV = I.first; + Constant *Null = Constant::getNullValue(GV->getValueType()); Builder.CreateStore(Null, GV); } diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp new file mode 100644 index 000000000000..202b94b19c4c --- /dev/null +++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -0,0 +1,661 @@ +//===-- IndirectCallPromotion.cpp - Promote indirect calls to direct calls ===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the transformation that promotes indirect calls to +// conditional direct calls when the indirect-call value profile metadata is +// available. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Triple.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/IndirectCallPromotionAnalysis.h" +#include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/ProfileData/InstrProfReader.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/PGOInstrumentation.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <string> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "pgo-icall-prom" + +STATISTIC(NumOfPGOICallPromotion, "Number of indirect call promotions."); +STATISTIC(NumOfPGOICallsites, "Number of indirect call candidate sites."); + +// Command line option to disable indirect-call promotion with the default as +// false. This is for debug purpose. +static cl::opt<bool> DisableICP("disable-icp", cl::init(false), cl::Hidden, + cl::desc("Disable indirect call promotion")); + +// Set the cutoff value for the promotion. If the value is other than 0, we +// stop the transformation once the total number of promotions equals the cutoff +// value. +// For debug use only. +static cl::opt<unsigned> + ICPCutOff("icp-cutoff", cl::init(0), cl::Hidden, cl::ZeroOrMore, + cl::desc("Max number of promotions for this compilaiton")); + +// If ICPCSSkip is non zero, the first ICPCSSkip callsites will be skipped. +// For debug use only. +static cl::opt<unsigned> + ICPCSSkip("icp-csskip", cl::init(0), cl::Hidden, cl::ZeroOrMore, + cl::desc("Skip Callsite up to this number for this compilaiton")); + +// Set if the pass is called in LTO optimization. The difference for LTO mode +// is the pass won't prefix the source module name to the internal linkage +// symbols. 
+static cl::opt<bool> ICPLTOMode("icp-lto", cl::init(false), cl::Hidden,
+                                cl::desc("Run indirect-call promotion in LTO "
+                                         "mode"));
+
+// If the option is set to true, only call instructions will be considered for
+// transformation -- invoke instructions will be ignored.
+static cl::opt<bool>
+    ICPCallOnly("icp-call-only", cl::init(false), cl::Hidden,
+                cl::desc("Run indirect-call promotion for call instructions "
+                         "only"));
+
+// If the option is set to true, only invoke instructions will be considered for
+// transformation -- call instructions will be ignored.
+static cl::opt<bool> ICPInvokeOnly("icp-invoke-only", cl::init(false),
+                                   cl::Hidden,
+                                   cl::desc("Run indirect-call promotion for "
+                                            "invoke instruction only"));
+
+// Dump the function level IR if the transformation happened in this
+// function. For debug use only.
+static cl::opt<bool>
+    ICPDUMPAFTER("icp-dumpafter", cl::init(false), cl::Hidden,
+                 cl::desc("Dump IR after transformation happens"));
+
+namespace {
+class PGOIndirectCallPromotionLegacyPass : public ModulePass {
+public:
+  static char ID;
+
+  PGOIndirectCallPromotionLegacyPass(bool InLTO = false)
+      : ModulePass(ID), InLTO(InLTO) {
+    initializePGOIndirectCallPromotionLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  const char *getPassName() const override {
+    return "PGOIndirectCallPromotion";
+  }
+
+private:
+  bool runOnModule(Module &M) override;
+
+  // If this pass is called in LTO, we need special handling of the PGOFuncName
+  // for static variables due to LTO's internalization.
+  bool InLTO;
+};
+} // end anonymous namespace
+
+char PGOIndirectCallPromotionLegacyPass::ID = 0;
+INITIALIZE_PASS(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom",
+                "Use PGO instrumentation profile to promote indirect calls to "
+                "direct calls.",
+                false, false)
+
+ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO) {
+  return new PGOIndirectCallPromotionLegacyPass(InLTO);
+}
+
+namespace {
+// The class for the main data structure used to promote indirect calls to
+// conditional direct calls.
+class ICallPromotionFunc {
+private:
+  Function &F;
+  Module *M;
+
+  // Symtab that maps indirect call profile values to function names and
+  // definitions.
+  InstrProfSymtab *Symtab;
+
+  enum TargetStatus {
+    OK,                   // Should be able to promote.
+    NotAvailableInModule, // Cannot find the target in current module.
+    ReturnTypeMismatch,   // Return type mismatch b/w target and indirect-call.
+    NumArgsMismatch,      // Number of arguments does not match.
+    ArgTypeMismatch       // Type mismatch in the arguments (cannot bitcast).
+  };
+
+  // Test if we can legally promote this indirect call to a direct call of
+  // Target.
+  TargetStatus isPromotionLegal(Instruction *Inst, uint64_t Target,
+                                Function *&F);
+
+  // A struct that records the direct target and its call count.
+  struct PromotionCandidate {
+    Function *TargetFunction;
+    uint64_t Count;
+    PromotionCandidate(Function *F, uint64_t C) : TargetFunction(F), Count(C) {}
+  };
+
+  // Check if the indirect-call call site should be promoted. Return the number
+  // of promotions. Inst is the candidate indirect call, ValueDataRef
+  // contains the array of value profile data for profiled targets,
+  // TotalCount is the total profiled count of call executions, and
+  // NumCandidates is the number of candidate entries in ValueDataRef.
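+  // The returned candidates are in profile-count order (hottest first), and
+  // tryToPromote() below then attempts them one at a time.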
+ std::vector<PromotionCandidate> getPromotionCandidatesForCallSite( + Instruction *Inst, const ArrayRef<InstrProfValueData> &ValueDataRef, + uint64_t TotalCount, uint32_t NumCandidates); + + // Main function that transforms Inst (either a indirect-call instruction, or + // an invoke instruction , to a conditional call to F. This is like: + // if (Inst.CalledValue == F) + // F(...); + // else + // Inst(...); + // end + // TotalCount is the profile count value that the instruction executes. + // Count is the profile count value that F is the target function. + // These two values are being used to update the branch weight. + void promote(Instruction *Inst, Function *F, uint64_t Count, + uint64_t TotalCount); + + // Promote a list of targets for one indirect-call callsite. Return + // the number of promotions. + uint32_t tryToPromote(Instruction *Inst, + const std::vector<PromotionCandidate> &Candidates, + uint64_t &TotalCount); + + static const char *StatusToString(const TargetStatus S) { + switch (S) { + case OK: + return "OK to promote"; + case NotAvailableInModule: + return "Cannot find the target"; + case ReturnTypeMismatch: + return "Return type mismatch"; + case NumArgsMismatch: + return "The number of arguments mismatch"; + case ArgTypeMismatch: + return "Argument Type mismatch"; + } + llvm_unreachable("Should not reach here"); + } + + // Noncopyable + ICallPromotionFunc(const ICallPromotionFunc &other) = delete; + ICallPromotionFunc &operator=(const ICallPromotionFunc &other) = delete; + +public: + ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab) + : F(Func), M(Modu), Symtab(Symtab) { + } + bool processFunction(); +}; +} // end anonymous namespace + +ICallPromotionFunc::TargetStatus +ICallPromotionFunc::isPromotionLegal(Instruction *Inst, uint64_t Target, + Function *&TargetFunction) { + Function *DirectCallee = Symtab->getFunction(Target); + if (DirectCallee == nullptr) + return NotAvailableInModule; + // Check the return type. + Type *CallRetType = Inst->getType(); + if (!CallRetType->isVoidTy()) { + Type *FuncRetType = DirectCallee->getReturnType(); + if (FuncRetType != CallRetType && + !CastInst::isBitCastable(FuncRetType, CallRetType)) + return ReturnTypeMismatch; + } + + // Check if the arguments are compatible with the parameters + FunctionType *DirectCalleeType = DirectCallee->getFunctionType(); + unsigned ParamNum = DirectCalleeType->getFunctionNumParams(); + CallSite CS(Inst); + unsigned ArgNum = CS.arg_size(); + + if (ParamNum != ArgNum && !DirectCalleeType->isVarArg()) + return NumArgsMismatch; + + for (unsigned I = 0; I < ParamNum; ++I) { + Type *PTy = DirectCalleeType->getFunctionParamType(I); + Type *ATy = CS.getArgument(I)->getType(); + if (PTy == ATy) + continue; + if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy)) + return ArgTypeMismatch; + } + + DEBUG(dbgs() << " #" << NumOfPGOICallPromotion << " Promote the icall to " + << Symtab->getFuncName(Target) << "\n"); + TargetFunction = DirectCallee; + return OK; +} + +// Indirect-call promotion heuristic. The direct targets are sorted based on +// the count. Stop at the first target that is not promoted. 
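Before the heuristic announced in the comment above is implemented, it helps to spell out, in source-level terms, the shape of the transformation that promote() produces. A hedged C++ caricature (Target, Foo and the argument are placeholders; the real work of course happens on IR, not source):

    // Hypothetical hot target taken from the value profile.
    static int Target(int X) { return X + 1; }

    int callThrough(int (*Foo)(int), int X) {
      // Before promotion the whole call is one opaque indirect call:
      //   return Foo(X);
      // After promotion against the hottest profiled target:
      if (Foo == &Target)
        return Target(X);  // direct call; weight Count, now visible to the inliner
      return Foo(X);       // original indirect call as fallback; weight TotalCount - Count
    }

    int main() { return callThrough(&Target, 41) == 42 ? 0 : 1; }

The candidate-selection heuristic itself follows.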
+std::vector<ICallPromotionFunc::PromotionCandidate> +ICallPromotionFunc::getPromotionCandidatesForCallSite( + Instruction *Inst, const ArrayRef<InstrProfValueData> &ValueDataRef, + uint64_t TotalCount, uint32_t NumCandidates) { + std::vector<PromotionCandidate> Ret; + + DEBUG(dbgs() << " \nWork on callsite #" << NumOfPGOICallsites << *Inst + << " Num_targets: " << ValueDataRef.size() + << " Num_candidates: " << NumCandidates << "\n"); + NumOfPGOICallsites++; + if (ICPCSSkip != 0 && NumOfPGOICallsites <= ICPCSSkip) { + DEBUG(dbgs() << " Skip: User options.\n"); + return Ret; + } + + for (uint32_t I = 0; I < NumCandidates; I++) { + uint64_t Count = ValueDataRef[I].Count; + assert(Count <= TotalCount); + uint64_t Target = ValueDataRef[I].Value; + DEBUG(dbgs() << " Candidate " << I << " Count=" << Count + << " Target_func: " << Target << "\n"); + + if (ICPInvokeOnly && dyn_cast<CallInst>(Inst)) { + DEBUG(dbgs() << " Not promote: User options.\n"); + break; + } + if (ICPCallOnly && dyn_cast<InvokeInst>(Inst)) { + DEBUG(dbgs() << " Not promote: User option.\n"); + break; + } + if (ICPCutOff != 0 && NumOfPGOICallPromotion >= ICPCutOff) { + DEBUG(dbgs() << " Not promote: Cutoff reached.\n"); + break; + } + Function *TargetFunction = nullptr; + TargetStatus Status = isPromotionLegal(Inst, Target, TargetFunction); + if (Status != OK) { + StringRef TargetFuncName = Symtab->getFuncName(Target); + const char *Reason = StatusToString(Status); + DEBUG(dbgs() << " Not promote: " << Reason << "\n"); + emitOptimizationRemarkMissed( + F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(), + Twine("Cannot promote indirect call to ") + + (TargetFuncName.empty() ? Twine(Target) : Twine(TargetFuncName)) + + Twine(" with count of ") + Twine(Count) + ": " + Reason); + break; + } + Ret.push_back(PromotionCandidate(TargetFunction, Count)); + TotalCount -= Count; + } + return Ret; +} + +// Create a diamond structure for If_Then_Else. Also update the profile +// count. Do the fix-up for the invoke instruction. +static void createIfThenElse(Instruction *Inst, Function *DirectCallee, + uint64_t Count, uint64_t TotalCount, + BasicBlock **DirectCallBB, + BasicBlock **IndirectCallBB, + BasicBlock **MergeBB) { + CallSite CS(Inst); + Value *OrigCallee = CS.getCalledValue(); + + IRBuilder<> BBBuilder(Inst); + LLVMContext &Ctx = Inst->getContext(); + Value *BCI1 = + BBBuilder.CreateBitCast(OrigCallee, Type::getInt8PtrTy(Ctx), ""); + Value *BCI2 = + BBBuilder.CreateBitCast(DirectCallee, Type::getInt8PtrTy(Ctx), ""); + Value *PtrCmp = BBBuilder.CreateICmpEQ(BCI1, BCI2, ""); + + uint64_t ElseCount = TotalCount - Count; + uint64_t MaxCount = (Count >= ElseCount ? Count : ElseCount); + uint64_t Scale = calculateCountScale(MaxCount); + MDBuilder MDB(Inst->getContext()); + MDNode *BranchWeights = MDB.createBranchWeights( + scaleBranchCount(Count, Scale), scaleBranchCount(ElseCount, Scale)); + TerminatorInst *ThenTerm, *ElseTerm; + SplitBlockAndInsertIfThenElse(PtrCmp, Inst, &ThenTerm, &ElseTerm, + BranchWeights); + *DirectCallBB = ThenTerm->getParent(); + (*DirectCallBB)->setName("if.true.direct_targ"); + *IndirectCallBB = ElseTerm->getParent(); + (*IndirectCallBB)->setName("if.false.orig_indirect"); + *MergeBB = Inst->getParent(); + (*MergeBB)->setName("if.end.icp"); + + // Special handing of Invoke instructions. + InvokeInst *II = dyn_cast<InvokeInst>(Inst); + if (!II) + return; + + // We don't need branch instructions for invoke. 
+ ThenTerm->eraseFromParent(); + ElseTerm->eraseFromParent(); + + // Add jump from Merge BB to the NormalDest. This is needed for the newly + // created direct invoke stmt -- as its NormalDst will be fixed up to MergeBB. + BranchInst::Create(II->getNormalDest(), *MergeBB); +} + +// Find the PHI in BB that have the CallResult as the operand. +static bool getCallRetPHINode(BasicBlock *BB, Instruction *Inst) { + BasicBlock *From = Inst->getParent(); + for (auto &I : *BB) { + PHINode *PHI = dyn_cast<PHINode>(&I); + if (!PHI) + continue; + int IX = PHI->getBasicBlockIndex(From); + if (IX == -1) + continue; + Value *V = PHI->getIncomingValue(IX); + if (dyn_cast<Instruction>(V) == Inst) + return true; + } + return false; +} + +// This method fixes up PHI nodes in BB where BB is the UnwindDest of an +// invoke instruction. In BB, there may be PHIs with incoming block being +// OrigBB (the MergeBB after if-then-else splitting). After moving the invoke +// instructions to its own BB, OrigBB is no longer the predecessor block of BB. +// Instead two new predecessors are added: IndirectCallBB and DirectCallBB, +// so the PHI node's incoming BBs need to be fixed up accordingly. +static void fixupPHINodeForUnwind(Instruction *Inst, BasicBlock *BB, + BasicBlock *OrigBB, + BasicBlock *IndirectCallBB, + BasicBlock *DirectCallBB) { + for (auto &I : *BB) { + PHINode *PHI = dyn_cast<PHINode>(&I); + if (!PHI) + continue; + int IX = PHI->getBasicBlockIndex(OrigBB); + if (IX == -1) + continue; + Value *V = PHI->getIncomingValue(IX); + PHI->addIncoming(V, IndirectCallBB); + PHI->setIncomingBlock(IX, DirectCallBB); + } +} + +// This method fixes up PHI nodes in BB where BB is the NormalDest of an +// invoke instruction. In BB, there may be PHIs with incoming block being +// OrigBB (the MergeBB after if-then-else splitting). After moving the invoke +// instructions to its own BB, a new incoming edge will be added to the original +// NormalDstBB from the IndirectCallBB. +static void fixupPHINodeForNormalDest(Instruction *Inst, BasicBlock *BB, + BasicBlock *OrigBB, + BasicBlock *IndirectCallBB, + Instruction *NewInst) { + for (auto &I : *BB) { + PHINode *PHI = dyn_cast<PHINode>(&I); + if (!PHI) + continue; + int IX = PHI->getBasicBlockIndex(OrigBB); + if (IX == -1) + continue; + Value *V = PHI->getIncomingValue(IX); + if (dyn_cast<Instruction>(V) == Inst) { + PHI->setIncomingBlock(IX, IndirectCallBB); + PHI->addIncoming(NewInst, OrigBB); + continue; + } + PHI->addIncoming(V, IndirectCallBB); + } +} + +// Add a bitcast instruction to the direct-call return value if needed. +static Instruction *insertCallRetCast(const Instruction *Inst, + Instruction *DirectCallInst, + Function *DirectCallee) { + if (Inst->getType()->isVoidTy()) + return DirectCallInst; + + Type *CallRetType = Inst->getType(); + Type *FuncRetType = DirectCallee->getReturnType(); + if (FuncRetType == CallRetType) + return DirectCallInst; + + BasicBlock *InsertionBB; + if (CallInst *CI = dyn_cast<CallInst>(DirectCallInst)) + InsertionBB = CI->getParent(); + else + InsertionBB = (dyn_cast<InvokeInst>(DirectCallInst))->getNormalDest(); + + return (new BitCastInst(DirectCallInst, CallRetType, "", + InsertionBB->getTerminator())); +} + +// Create a DirectCall instruction in the DirectCallBB. +// Parameter Inst is the indirect-call (invoke) instruction. +// DirectCallee is the decl of the direct-call (invoke) target. +// DirecallBB is the BB that the direct-call (invoke) instruction is inserted. 
+// MergeBB is the bottom BB of the if-then-else-diamond after the +// transformation. For invoke instruction, the edges from DirectCallBB and +// IndirectCallBB to MergeBB are removed before this call (during +// createIfThenElse). +static Instruction *createDirectCallInst(const Instruction *Inst, + Function *DirectCallee, + BasicBlock *DirectCallBB, + BasicBlock *MergeBB) { + Instruction *NewInst = Inst->clone(); + if (CallInst *CI = dyn_cast<CallInst>(NewInst)) { + CI->setCalledFunction(DirectCallee); + CI->mutateFunctionType(DirectCallee->getFunctionType()); + } else { + // Must be an invoke instruction. Direct invoke's normal destination is + // fixed up to MergeBB. MergeBB is the place where return cast is inserted. + // Also since IndirectCallBB does not have an edge to MergeBB, there is no + // need to insert new PHIs into MergeBB. + InvokeInst *II = dyn_cast<InvokeInst>(NewInst); + assert(II); + II->setCalledFunction(DirectCallee); + II->mutateFunctionType(DirectCallee->getFunctionType()); + II->setNormalDest(MergeBB); + } + + DirectCallBB->getInstList().insert(DirectCallBB->getFirstInsertionPt(), + NewInst); + + // Clear the value profile data. + NewInst->setMetadata(LLVMContext::MD_prof, 0); + CallSite NewCS(NewInst); + FunctionType *DirectCalleeType = DirectCallee->getFunctionType(); + unsigned ParamNum = DirectCalleeType->getFunctionNumParams(); + for (unsigned I = 0; I < ParamNum; ++I) { + Type *ATy = NewCS.getArgument(I)->getType(); + Type *PTy = DirectCalleeType->getParamType(I); + if (ATy != PTy) { + BitCastInst *BI = new BitCastInst(NewCS.getArgument(I), PTy, "", NewInst); + NewCS.setArgument(I, BI); + } + } + + return insertCallRetCast(Inst, NewInst, DirectCallee); +} + +// Create a PHI to unify the return values of calls. +static void insertCallRetPHI(Instruction *Inst, Instruction *CallResult, + Function *DirectCallee) { + if (Inst->getType()->isVoidTy()) + return; + + BasicBlock *RetValBB = CallResult->getParent(); + + BasicBlock *PHIBB; + if (InvokeInst *II = dyn_cast<InvokeInst>(CallResult)) + RetValBB = II->getNormalDest(); + + PHIBB = RetValBB->getSingleSuccessor(); + if (getCallRetPHINode(PHIBB, Inst)) + return; + + PHINode *CallRetPHI = PHINode::Create(Inst->getType(), 0); + PHIBB->getInstList().push_front(CallRetPHI); + Inst->replaceAllUsesWith(CallRetPHI); + CallRetPHI->addIncoming(Inst, Inst->getParent()); + CallRetPHI->addIncoming(CallResult, RetValBB); +} + +// This function does the actual indirect-call promotion transformation: +// For an indirect-call like: +// Ret = (*Foo)(Args); +// It transforms to: +// if (Foo == DirectCallee) +// Ret1 = DirectCallee(Args); +// else +// Ret2 = (*Foo)(Args); +// Ret = phi(Ret1, Ret2); +// It adds type casts for the args do not match the parameters and the return +// value. Branch weights metadata also updated. +void ICallPromotionFunc::promote(Instruction *Inst, Function *DirectCallee, + uint64_t Count, uint64_t TotalCount) { + assert(DirectCallee != nullptr); + BasicBlock *BB = Inst->getParent(); + // Just to suppress the non-debug build warning. + (void)BB; + DEBUG(dbgs() << "\n\n== Basic Block Before ==\n"); + DEBUG(dbgs() << *BB << "\n"); + + BasicBlock *DirectCallBB, *IndirectCallBB, *MergeBB; + createIfThenElse(Inst, DirectCallee, Count, TotalCount, &DirectCallBB, + &IndirectCallBB, &MergeBB); + + Instruction *NewInst = + createDirectCallInst(Inst, DirectCallee, DirectCallBB, MergeBB); + + // Move Inst from MergeBB to IndirectCallBB. 
+ Inst->removeFromParent(); + IndirectCallBB->getInstList().insert(IndirectCallBB->getFirstInsertionPt(), + Inst); + + if (InvokeInst *II = dyn_cast<InvokeInst>(Inst)) { + // At this point, the original indirect invoke instruction has the original + // UnwindDest and NormalDest. For the direct invoke instruction, the + // NormalDest points to MergeBB, and MergeBB jumps to the original + // NormalDest. MergeBB might have a new bitcast instruction for the return + // value. The PHIs are with the original NormalDest. Since we now have two + // incoming edges to NormalDest and UnwindDest, we have to do some fixups. + // + // UnwindDest will not use the return value. So pass nullptr here. + fixupPHINodeForUnwind(Inst, II->getUnwindDest(), MergeBB, IndirectCallBB, + DirectCallBB); + // We don't need to update the operand from NormalDest for DirectCallBB. + // Pass nullptr here. + fixupPHINodeForNormalDest(Inst, II->getNormalDest(), MergeBB, + IndirectCallBB, NewInst); + } + + insertCallRetPHI(Inst, NewInst, DirectCallee); + + DEBUG(dbgs() << "\n== Basic Blocks After ==\n"); + DEBUG(dbgs() << *BB << *DirectCallBB << *IndirectCallBB << *MergeBB << "\n"); + + emitOptimizationRemark( + F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(), + Twine("Promote indirect call to ") + DirectCallee->getName() + + " with count " + Twine(Count) + " out of " + Twine(TotalCount)); +} + +// Promote indirect-call to conditional direct-call for one callsite. +uint32_t ICallPromotionFunc::tryToPromote( + Instruction *Inst, const std::vector<PromotionCandidate> &Candidates, + uint64_t &TotalCount) { + uint32_t NumPromoted = 0; + + for (auto &C : Candidates) { + uint64_t Count = C.Count; + promote(Inst, C.TargetFunction, Count, TotalCount); + assert(TotalCount >= Count); + TotalCount -= Count; + NumOfPGOICallPromotion++; + NumPromoted++; + } + return NumPromoted; +} + +// Traverse all the indirect-call callsite and get the value profile +// annotation to perform indirect-call promotion. +bool ICallPromotionFunc::processFunction() { + bool Changed = false; + ICallPromotionAnalysis ICallAnalysis; + for (auto &I : findIndirectCallSites(F)) { + uint32_t NumVals, NumCandidates; + uint64_t TotalCount; + auto ICallProfDataRef = ICallAnalysis.getPromotionCandidatesForInstruction( + I, NumVals, TotalCount, NumCandidates); + if (!NumCandidates) + continue; + auto PromotionCandidates = getPromotionCandidatesForCallSite( + I, ICallProfDataRef, TotalCount, NumCandidates); + uint32_t NumPromoted = tryToPromote(I, PromotionCandidates, TotalCount); + if (NumPromoted == 0) + continue; + + Changed = true; + // Adjust the MD.prof metadata. First delete the old one. + I->setMetadata(LLVMContext::MD_prof, 0); + // If all promoted, we don't need the MD.prof metadata. + if (TotalCount == 0 || NumPromoted == NumVals) + continue; + // Otherwise we need update with the un-promoted records back. + annotateValueSite(*M, *I, ICallProfDataRef.slice(NumPromoted), TotalCount, + IPVK_IndirectCallTarget, NumCandidates); + } + return Changed; +} + +// A wrapper function that does the actual work. 
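One detail of processFunction above worth making explicit: when only some of the profiled targets are promoted, the !prof metadata is rebuilt from the tail of the value-data array together with the reduced total count. A small sketch of that bookkeeping with plain containers (ArrayRef::slice is imitated with iterators; every name here is hypothetical):

    #include <cstdint>
    #include <vector>

    struct ValueData { uint64_t Value; uint64_t Count; };

    // Returns the records that still need to be attached as !prof metadata
    // after the first NumPromoted entries were turned into direct calls, and
    // updates TotalCount to the executions left on the indirect call.
    std::vector<ValueData> leftoverRecords(const std::vector<ValueData> &Records,
                                           uint32_t NumPromoted,
                                           uint64_t &TotalCount) {
      for (uint32_t I = 0; I < NumPromoted; ++I)
        TotalCount -= Records[I].Count;  // now carried by direct-call branches
      if (TotalCount == 0 || NumPromoted == Records.size())
        return {};                       // nothing left worth annotating
      return std::vector<ValueData>(Records.begin() + NumPromoted, Records.end());
    }

The module-level wrapper announced by the comment above follows.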
+static bool promoteIndirectCalls(Module &M, bool InLTO) { + if (DisableICP) + return false; + InstrProfSymtab Symtab; + Symtab.create(M, InLTO); + bool Changed = false; + for (auto &F : M) { + if (F.isDeclaration()) + continue; + if (F.hasFnAttribute(Attribute::OptimizeNone)) + continue; + ICallPromotionFunc ICallPromotion(F, &M, &Symtab); + bool FuncChanged = ICallPromotion.processFunction(); + if (ICPDUMPAFTER && FuncChanged) { + DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs())); + DEBUG(dbgs() << "\n"); + } + Changed |= FuncChanged; + if (ICPCutOff != 0 && NumOfPGOICallPromotion >= ICPCutOff) { + DEBUG(dbgs() << " Stop: Cutoff reached.\n"); + break; + } + } + return Changed; +} + +bool PGOIndirectCallPromotionLegacyPass::runOnModule(Module &M) { + // Command-line option has the priority for InLTO. + return promoteIndirectCalls(M, InLTO | ICPLTOMode); +} + +PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, AnalysisManager<Module> &AM) { + if (!promoteIndirectCalls(M, InLTO | ICPLTOMode)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp index 28483e7e9b69..b11c6be696f3 100644 --- a/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -13,12 +13,12 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/InstrProfiling.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/ProfileData/InstrProf.h" -#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; @@ -27,121 +27,112 @@ using namespace llvm; namespace { -class InstrProfiling : public ModulePass { +cl::opt<bool> DoNameCompression("enable-name-compression", + cl::desc("Enable name string compression"), + cl::init(true)); + +cl::opt<bool> ValueProfileStaticAlloc( + "vp-static-alloc", + cl::desc("Do static counter allocation for value profiler"), + cl::init(true)); +cl::opt<double> NumCountersPerValueSite( + "vp-counters-per-site", + cl::desc("The average number of profile counters allocated " + "per value profiling site."), + // This is set to a very small value because in real programs, only + // a very small percentage of value sites have non-zero targets, e.g, 1/30. + // For those sites with non-zero profile, the average number of targets + // is usually smaller than 2. 
+ cl::init(1.0)); + +class InstrProfilingLegacyPass : public ModulePass { + InstrProfiling InstrProf; + public: static char ID; - - InstrProfiling() : ModulePass(ID) {} - - InstrProfiling(const InstrProfOptions &Options) - : ModulePass(ID), Options(Options) {} - + InstrProfilingLegacyPass() : ModulePass(ID), InstrProf() {} + InstrProfilingLegacyPass(const InstrProfOptions &Options) + : ModulePass(ID), InstrProf(Options) {} const char *getPassName() const override { return "Frontend instrumentation-based coverage lowering"; } - bool runOnModule(Module &M) override; + bool runOnModule(Module &M) override { return InstrProf.run(M); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); } +}; -private: - InstrProfOptions Options; - Module *M; - typedef struct PerFunctionProfileData { - uint32_t NumValueSites[IPVK_Last+1]; - GlobalVariable* RegionCounters; - GlobalVariable* DataVar; - PerFunctionProfileData() : RegionCounters(nullptr), DataVar(nullptr) { - memset(NumValueSites, 0, sizeof(uint32_t) * (IPVK_Last+1)); - } - } PerFunctionProfileData; - DenseMap<GlobalVariable *, PerFunctionProfileData> ProfileDataMap; - std::vector<Value *> UsedVars; - - bool isMachO() const { - return Triple(M->getTargetTriple()).isOSBinFormatMachO(); - } - - /// Get the section name for the counter variables. - StringRef getCountersSection() const { - return getInstrProfCountersSectionName(isMachO()); - } - - /// Get the section name for the name variables. - StringRef getNameSection() const { - return getInstrProfNameSectionName(isMachO()); - } - - /// Get the section name for the profile data variables. - StringRef getDataSection() const { - return getInstrProfDataSectionName(isMachO()); - } - - /// Get the section name for the coverage mapping data. - StringRef getCoverageSection() const { - return getInstrProfCoverageSectionName(isMachO()); - } - - /// Count the number of instrumented value sites for the function. - void computeNumValueSiteCounts(InstrProfValueProfileInst *Ins); - - /// Replace instrprof_value_profile with a call to runtime library. - void lowerValueProfileInst(InstrProfValueProfileInst *Ins); - - /// Replace instrprof_increment with an increment of the appropriate value. - void lowerIncrement(InstrProfIncrementInst *Inc); +} // anonymous namespace - /// Force emitting of name vars for unused functions. - void lowerCoverageData(GlobalVariable *CoverageNamesVar); +PreservedAnalyses InstrProfiling::run(Module &M, AnalysisManager<Module> &AM) { + if (!run(M)) + return PreservedAnalyses::all(); - /// Get the region counters for an increment, creating them if necessary. - /// - /// If the counter array doesn't yet exist, the profile data variables - /// referring to them will also be created. - GlobalVariable *getOrCreateRegionCounters(InstrProfIncrementInst *Inc); + return PreservedAnalyses::none(); +} - /// Emit runtime registration functions for each profile data variable. - void emitRegistration(); +char InstrProfilingLegacyPass::ID = 0; +INITIALIZE_PASS(InstrProfilingLegacyPass, "instrprof", + "Frontend instrumentation-based coverage lowering.", false, + false) - /// Emit the necessary plumbing to pull in the runtime initialization. - void emitRuntimeHook(); +ModulePass * +llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options) { + return new InstrProfilingLegacyPass(Options); +} - /// Add uses of our data variables and runtime hook. 
- void emitUses(); +bool InstrProfiling::isMachO() const { + return Triple(M->getTargetTriple()).isOSBinFormatMachO(); +} - /// Create a static initializer for our data, on platforms that need it, - /// and for any profile output file that was specified. - void emitInitialization(); -}; +/// Get the section name for the counter variables. +StringRef InstrProfiling::getCountersSection() const { + return getInstrProfCountersSectionName(isMachO()); +} -} // anonymous namespace +/// Get the section name for the name variables. +StringRef InstrProfiling::getNameSection() const { + return getInstrProfNameSectionName(isMachO()); +} -char InstrProfiling::ID = 0; -INITIALIZE_PASS(InstrProfiling, "instrprof", - "Frontend instrumentation-based coverage lowering.", false, - false) +/// Get the section name for the profile data variables. +StringRef InstrProfiling::getDataSection() const { + return getInstrProfDataSectionName(isMachO()); +} -ModulePass *llvm::createInstrProfilingPass(const InstrProfOptions &Options) { - return new InstrProfiling(Options); +/// Get the section name for the coverage mapping data. +StringRef InstrProfiling::getCoverageSection() const { + return getInstrProfCoverageSectionName(isMachO()); } -bool InstrProfiling::runOnModule(Module &M) { +bool InstrProfiling::run(Module &M) { bool MadeChange = false; this->M = &M; + NamesVar = nullptr; + NamesSize = 0; ProfileDataMap.clear(); UsedVars.clear(); // We did not know how many value sites there would be inside // the instrumented function. This is counting the number of instrumented // target value sites to enter it as field in the profile data variable. - for (Function &F : M) + for (Function &F : M) { + InstrProfIncrementInst *FirstProfIncInst = nullptr; for (BasicBlock &BB : F) - for (auto I = BB.begin(), E = BB.end(); I != E;) - if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(I++)) + for (auto I = BB.begin(), E = BB.end(); I != E; I++) + if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(I)) computeNumValueSiteCounts(Ind); + else if (FirstProfIncInst == nullptr) + FirstProfIncInst = dyn_cast<InstrProfIncrementInst>(I); + + // Value profiling intrinsic lowering requires per-function profile data + // variable to be created first. 
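// (The call below is made purely for its side effect: getOrCreateRegionCounters
// also creates the per-function profile data variable that the value-profiling
// lowering relies on; the returned counter variable is deliberately discarded,
// hence the cast to void.)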
+ if (FirstProfIncInst != nullptr) + static_cast<void>(getOrCreateRegionCounters(FirstProfIncInst)); + } for (Function &F : M) for (BasicBlock &BB : F) @@ -157,7 +148,7 @@ bool InstrProfiling::runOnModule(Module &M) { } if (GlobalVariable *CoverageNamesVar = - M.getNamedGlobal(getCoverageNamesVarName())) { + M.getNamedGlobal(getCoverageUnusedNamesVarName())) { lowerCoverageData(CoverageNamesVar); MadeChange = true; } @@ -165,6 +156,8 @@ bool InstrProfiling::runOnModule(Module &M) { if (!MadeChange) return false; + emitVNodes(); + emitNameData(); emitRegistration(); emitRuntimeHook(); emitUses(); @@ -204,7 +197,7 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { GlobalVariable *Name = Ind->getName(); auto It = ProfileDataMap.find(Name); assert(It != ProfileDataMap.end() && It->second.DataVar && - "value profiling detected in function with no counter incerement"); + "value profiling detected in function with no counter incerement"); GlobalVariable *DataVar = It->second.DataVar; uint64_t ValueKind = Ind->getValueKind()->getZExtValue(); @@ -213,9 +206,9 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { Index += It->second.NumValueSites[Kind]; IRBuilder<> Builder(Ind); - Value* Args[3] = {Ind->getTargetValue(), - Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), - Builder.getInt32(Index)}; + Value *Args[3] = {Ind->getTargetValue(), + Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), + Builder.getInt32(Index)}; Ind->replaceAllUsesWith( Builder.CreateCall(getOrInsertValueProfilingCall(*M), Args)); Ind->eraseFromParent(); @@ -243,9 +236,8 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { assert(isa<GlobalVariable>(V) && "Missing reference to function name"); GlobalVariable *Name = cast<GlobalVariable>(V); - // Move the name variable to the right section. - Name->setSection(getNameSection()); - Name->setAlignment(1); + Name->setLinkage(GlobalValue::PrivateLinkage); + ReferencedNames.push_back(Name); } } @@ -261,22 +253,77 @@ static inline bool shouldRecordFunctionAddr(Function *F) { if (!F->hasLinkOnceLinkage() && !F->hasLocalLinkage() && !F->hasAvailableExternallyLinkage()) return true; + // Prohibit function address recording if the function is both internal and + // COMDAT. This avoids the profile data variable referencing internal symbols + // in COMDAT. + if (F->hasLocalLinkage() && F->hasComdat()) + return false; // Check uses of this function for other than direct calls or invokes to it. - return F->hasAddressTaken(); + // Inline virtual functions have linkeOnceODR linkage. When a key method + // exists, the vtable will only be emitted in the TU where the key method + // is defined. In a TU where vtable is not available, the function won't + // be 'addresstaken'. If its address is not recorded here, the profile data + // with missing address may be picked by the linker leading to missing + // indirect call target info. + return F->hasAddressTaken() || F->hasLinkOnceLinkage(); +} + +static inline bool needsComdatForCounter(Function &F, Module &M) { + + if (F.hasComdat()) + return true; + + Triple TT(M.getTargetTriple()); + if (!TT.isOSBinFormatELF()) + return false; + + // See createPGOFuncNameVar for more details. To avoid link errors, profile + // counters for function with available_externally linkage needs to be changed + // to linkonce linkage. On ELF based systems, this leads to weak symbols to be + // created. 
Without using comdat, duplicate entries won't be removed by the + // linker leading to increased data segement size and raw profile size. Even + // worse, since the referenced counter from profile per-function data object + // will be resolved to the common strong definition, the profile counts for + // available_externally functions will end up being duplicated in raw profile + // data. This can result in distorted profile as the counts of those dups + // will be accumulated by the profile merger. + GlobalValue::LinkageTypes Linkage = F.getLinkage(); + if (Linkage != GlobalValue::ExternalWeakLinkage && + Linkage != GlobalValue::AvailableExternallyLinkage) + return false; + + return true; } -static inline Comdat *getOrCreateProfileComdat(Module &M, +static inline Comdat *getOrCreateProfileComdat(Module &M, Function &F, InstrProfIncrementInst *Inc) { + if (!needsComdatForCounter(F, M)) + return nullptr; + // COFF format requires a COMDAT section to have a key symbol with the same - // name. The linker targeting COFF also requires that the COMDAT section + // name. The linker targeting COFF also requires that the COMDAT // a section is associated to must precede the associating section. For this - // reason, we must choose the name var's name as the name of the comdat. + // reason, we must choose the counter var's name as the name of the comdat. StringRef ComdatPrefix = (Triple(M.getTargetTriple()).isOSBinFormatCOFF() - ? getInstrProfNameVarPrefix() + ? getInstrProfCountersVarPrefix() : getInstrProfComdatPrefix()); return M.getOrInsertComdat(StringRef(getVarName(Inc, ComdatPrefix))); } +static bool needsRuntimeRegistrationOfSectionRange(const Module &M) { + // Don't do this for Darwin. compiler-rt uses linker magic. + if (Triple(M.getTargetTriple()).isOSDarwin()) + return false; + + // Use linker script magic to get data/cnts/name start/end. + if (Triple(M.getTargetTriple()).isOSLinux() || + Triple(M.getTargetTriple()).isOSFreeBSD() || + Triple(M.getTargetTriple()).isPS4CPU()) + return false; + + return true; +} + GlobalVariable * InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { GlobalVariable *NamePtr = Inc->getName(); @@ -294,11 +341,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { // linking. Function *Fn = Inc->getParent()->getParent(); Comdat *ProfileVarsComdat = nullptr; - if (Fn->hasComdat()) - ProfileVarsComdat = getOrCreateProfileComdat(*M, Inc); - NamePtr->setSection(getNameSection()); - NamePtr->setAlignment(1); - NamePtr->setComdat(ProfileVarsComdat); + ProfileVarsComdat = getOrCreateProfileComdat(*M, *Fn, Inc); uint64_t NumCounters = Inc->getNumCounters()->getZExtValue(); LLVMContext &Ctx = M->getContext(); @@ -314,27 +357,51 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { CounterPtr->setAlignment(8); CounterPtr->setComdat(ProfileVarsComdat); - // Create data variable. auto *Int8PtrTy = Type::getInt8PtrTy(Ctx); + // Allocate statically the array of pointers to value profile nodes for + // the current function. 
+ Constant *ValuesPtrExpr = ConstantPointerNull::get(Int8PtrTy); + if (ValueProfileStaticAlloc && !needsRuntimeRegistrationOfSectionRange(*M)) { + + uint64_t NS = 0; + for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) + NS += PD.NumValueSites[Kind]; + if (NS) { + ArrayType *ValuesTy = ArrayType::get(Type::getInt64Ty(Ctx), NS); + + auto *ValuesVar = + new GlobalVariable(*M, ValuesTy, false, NamePtr->getLinkage(), + Constant::getNullValue(ValuesTy), + getVarName(Inc, getInstrProfValuesVarPrefix())); + ValuesVar->setVisibility(NamePtr->getVisibility()); + ValuesVar->setSection(getInstrProfValuesSectionName(isMachO())); + ValuesVar->setAlignment(8); + ValuesVar->setComdat(ProfileVarsComdat); + ValuesPtrExpr = + ConstantExpr::getBitCast(ValuesVar, llvm::Type::getInt8PtrTy(Ctx)); + } + } + + // Create data variable. auto *Int16Ty = Type::getInt16Ty(Ctx); - auto *Int16ArrayTy = ArrayType::get(Int16Ty, IPVK_Last+1); + auto *Int16ArrayTy = ArrayType::get(Int16Ty, IPVK_Last + 1); Type *DataTypes[] = { - #define INSTR_PROF_DATA(Type, LLVMType, Name, Init) LLVMType, - #include "llvm/ProfileData/InstrProfData.inc" +#define INSTR_PROF_DATA(Type, LLVMType, Name, Init) LLVMType, +#include "llvm/ProfileData/InstrProfData.inc" }; auto *DataTy = StructType::get(Ctx, makeArrayRef(DataTypes)); - Constant *FunctionAddr = shouldRecordFunctionAddr(Fn) ? - ConstantExpr::getBitCast(Fn, Int8PtrTy) : - ConstantPointerNull::get(Int8PtrTy); + Constant *FunctionAddr = shouldRecordFunctionAddr(Fn) + ? ConstantExpr::getBitCast(Fn, Int8PtrTy) + : ConstantPointerNull::get(Int8PtrTy); - Constant *Int16ArrayVals[IPVK_Last+1]; + Constant *Int16ArrayVals[IPVK_Last + 1]; for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) Int16ArrayVals[Kind] = ConstantInt::get(Int16Ty, PD.NumValueSites[Kind]); Constant *DataVals[] = { - #define INSTR_PROF_DATA(Type, LLVMType, Name, Init) Init, - #include "llvm/ProfileData/InstrProfData.inc" +#define INSTR_PROF_DATA(Type, LLVMType, Name, Init) Init, +#include "llvm/ProfileData/InstrProfData.inc" }; auto *Data = new GlobalVariable(*M, DataTy, false, NamePtr->getLinkage(), ConstantStruct::get(DataTy, DataVals), @@ -350,28 +417,99 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { // Mark the data variable as used so that it isn't stripped out. UsedVars.push_back(Data); + // Now that the linkage set by the FE has been passed to the data and counter + // variables, reset Name variable's linkage and visibility to private so that + // it can be removed later by the compiler. + NamePtr->setLinkage(GlobalValue::PrivateLinkage); + // Collect the referenced names to be used by emitNameData. + ReferencedNames.push_back(NamePtr); return CounterPtr; } -void InstrProfiling::emitRegistration() { - // Don't do this for Darwin. compiler-rt uses linker magic. - if (Triple(M->getTargetTriple()).isOSDarwin()) +void InstrProfiling::emitVNodes() { + if (!ValueProfileStaticAlloc) return; - // Use linker script magic to get data/cnts/name start/end. - if (Triple(M->getTargetTriple()).isOSLinux() || - Triple(M->getTargetTriple()).isOSFreeBSD()) + // For now only support this on platforms that do + // not require runtime registration to discover + // named section start/end. 
+ if (needsRuntimeRegistrationOfSectionRange(*M)) + return; + + size_t TotalNS = 0; + for (auto &PD : ProfileDataMap) { + for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) + TotalNS += PD.second.NumValueSites[Kind]; + } + + if (!TotalNS) + return; + + uint64_t NumCounters = TotalNS * NumCountersPerValueSite; +// Heuristic for small programs with very few total value sites. +// The default value of vp-counters-per-site is chosen based on +// the observation that large apps usually have a low percentage +// of value sites that actually have any profile data, and thus +// the average number of counters per site is low. For small +// apps with very few sites, this may not be true. Bump up the +// number of counters in this case. +#define INSTR_PROF_MIN_VAL_COUNTS 10 + if (NumCounters < INSTR_PROF_MIN_VAL_COUNTS) + NumCounters = std::max(INSTR_PROF_MIN_VAL_COUNTS, (int)NumCounters * 2); + + auto &Ctx = M->getContext(); + Type *VNodeTypes[] = { +#define INSTR_PROF_VALUE_NODE(Type, LLVMType, Name, Init) LLVMType, +#include "llvm/ProfileData/InstrProfData.inc" + }; + auto *VNodeTy = StructType::get(Ctx, makeArrayRef(VNodeTypes)); + + ArrayType *VNodesTy = ArrayType::get(VNodeTy, NumCounters); + auto *VNodesVar = new GlobalVariable( + *M, VNodesTy, false, llvm::GlobalValue::PrivateLinkage, + Constant::getNullValue(VNodesTy), getInstrProfVNodesVarName()); + VNodesVar->setSection(getInstrProfVNodesSectionName(isMachO())); + UsedVars.push_back(VNodesVar); +} + +void InstrProfiling::emitNameData() { + std::string UncompressedData; + + if (ReferencedNames.empty()) + return; + + std::string CompressedNameStr; + if (Error E = collectPGOFuncNameStrings(ReferencedNames, CompressedNameStr, + DoNameCompression)) { + llvm::report_fatal_error(toString(std::move(E)), false); + } + + auto &Ctx = M->getContext(); + auto *NamesVal = llvm::ConstantDataArray::getString( + Ctx, StringRef(CompressedNameStr), false); + NamesVar = new llvm::GlobalVariable(*M, NamesVal->getType(), true, + llvm::GlobalValue::PrivateLinkage, + NamesVal, getInstrProfNamesVarName()); + NamesSize = CompressedNameStr.size(); + NamesVar->setSection(getNameSection()); + UsedVars.push_back(NamesVar); +} + +void InstrProfiling::emitRegistration() { + if (!needsRuntimeRegistrationOfSectionRange(*M)) return; // Construct the function. 
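// (In outline, the function constructed below is an internal void() routine,
// named via getInstrProfRegFuncsName(), whose body invokes the runtime
// registration hook once for every variable collected in UsedVars except the
// names blob; if a names blob exists, it is registered separately through the
// function named by getInstrProfNamesRegFuncName(), together with NamesSize.)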
auto *VoidTy = Type::getVoidTy(M->getContext()); auto *VoidPtrTy = Type::getInt8PtrTy(M->getContext()); + auto *Int64Ty = Type::getInt64Ty(M->getContext()); auto *RegisterFTy = FunctionType::get(VoidTy, false); auto *RegisterF = Function::Create(RegisterFTy, GlobalValue::InternalLinkage, getInstrProfRegFuncsName(), M); - RegisterF->setUnnamedAddr(true); - if (Options.NoRedZone) RegisterF->addFnAttr(Attribute::NoRedZone); + RegisterF->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + if (Options.NoRedZone) + RegisterF->addFnAttr(Attribute::NoRedZone); auto *RuntimeRegisterTy = FunctionType::get(VoidTy, VoidPtrTy, false); auto *RuntimeRegisterF = @@ -380,7 +518,20 @@ void InstrProfiling::emitRegistration() { IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", RegisterF)); for (Value *Data : UsedVars) - IRB.CreateCall(RuntimeRegisterF, IRB.CreateBitCast(Data, VoidPtrTy)); + if (Data != NamesVar) + IRB.CreateCall(RuntimeRegisterF, IRB.CreateBitCast(Data, VoidPtrTy)); + + if (NamesVar) { + Type *ParamTypes[] = {VoidPtrTy, Int64Ty}; + auto *NamesRegisterTy = + FunctionType::get(VoidTy, makeArrayRef(ParamTypes), false); + auto *NamesRegisterF = + Function::Create(NamesRegisterTy, GlobalVariable::ExternalLinkage, + getInstrProfNamesRegFuncName(), M); + IRB.CreateCall(NamesRegisterF, {IRB.CreateBitCast(NamesVar, VoidPtrTy), + IRB.getInt64(NamesSize)}); + } + IRB.CreateRetVoid(); } @@ -392,7 +543,8 @@ void InstrProfiling::emitRuntimeHook() { return; // If the module's provided its own runtime, we don't need to do anything. - if (M->getGlobalVariable(getInstrProfRuntimeHookVarName())) return; + if (M->getGlobalVariable(getInstrProfRuntimeHookVarName())) + return; // Declare an external variable that will pull in the runtime initialization. auto *Int32Ty = Type::getInt32Ty(M->getContext()); @@ -405,8 +557,11 @@ void InstrProfiling::emitRuntimeHook() { GlobalValue::LinkOnceODRLinkage, getInstrProfRuntimeHookVarUseFuncName(), M); User->addFnAttr(Attribute::NoInline); - if (Options.NoRedZone) User->addFnAttr(Attribute::NoRedZone); + if (Options.NoRedZone) + User->addFnAttr(Attribute::NoRedZone); User->setVisibility(GlobalValue::HiddenVisibility); + if (Triple(M->getTargetTriple()).supportsCOMDAT()) + User->setComdat(M->getOrInsertComdat(User->getName())); IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", User)); auto *Load = IRB.CreateLoad(Var); @@ -448,16 +603,18 @@ void InstrProfiling::emitInitialization() { std::string InstrProfileOutput = Options.InstrProfileOutput; Constant *RegisterF = M->getFunction(getInstrProfRegFuncsName()); - if (!RegisterF && InstrProfileOutput.empty()) return; + if (!RegisterF && InstrProfileOutput.empty()) + return; // Create the initialization function. auto *VoidTy = Type::getVoidTy(M->getContext()); auto *F = Function::Create(FunctionType::get(VoidTy, false), GlobalValue::InternalLinkage, getInstrProfInitFuncName(), M); - F->setUnnamedAddr(true); + F->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); F->addFnAttr(Attribute::NoInline); - if (Options.NoRedZone) F->addFnAttr(Attribute::NoRedZone); + if (Options.NoRedZone) + F->addFnAttr(Attribute::NoRedZone); // Add the basic block and the necessary calls. 
IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", F)); diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp index a05a5fa09f9a..2963d08752c4 100644 --- a/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -59,15 +59,16 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) { initializeAddressSanitizerPass(Registry); initializeAddressSanitizerModulePass(Registry); initializeBoundsCheckingPass(Registry); - initializeGCOVProfilerPass(Registry); - initializePGOInstrumentationGenPass(Registry); - initializePGOInstrumentationUsePass(Registry); - initializeInstrProfilingPass(Registry); + initializeGCOVProfilerLegacyPassPass(Registry); + initializePGOInstrumentationGenLegacyPassPass(Registry); + initializePGOInstrumentationUseLegacyPassPass(Registry); + initializePGOIndirectCallPromotionLegacyPassPass(Registry); + initializeInstrProfilingLegacyPassPass(Registry); initializeMemorySanitizerPass(Registry); initializeThreadSanitizerPass(Registry); initializeSanitizerCoverageModulePass(Registry); initializeDataFlowSanitizerPass(Registry); - initializeSafeStackPass(Registry); + initializeEfficiencySanitizerPass(Registry); } /// LLVMInitializeInstrumentation - C binding for diff --git a/lib/Transforms/Instrumentation/Makefile b/lib/Transforms/Instrumentation/Makefile deleted file mode 100644 index 6cbc7a9cd88a..000000000000 --- a/lib/Transforms/Instrumentation/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Transforms/Instrumentation/Makefile -------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. 
-LIBRARYNAME = LLVMInstrumentation -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common - diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 34aaa7f27d6e..970f9ab86e82 100644 --- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -91,7 +91,6 @@ //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" @@ -109,9 +108,9 @@ #include "llvm/IR/Type.h" #include "llvm/IR/ValueMap.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -191,6 +190,12 @@ static cl::opt<bool> ClCheckConstantShadow("msan-check-constant-shadow", cl::desc("Insert checks for constant shadow values"), cl::Hidden, cl::init(false)); +// This is off by default because of a bug in gold: +// https://sourceware.org/bugzilla/show_bug.cgi?id=19002 +static cl::opt<bool> ClWithComdat("msan-with-comdat", + cl::desc("Place MSan constructors in comdat sections"), + cl::Hidden, cl::init(false)); + static const char *const kMsanModuleCtorName = "msan.module_ctor"; static const char *const kMsanInitName = "__msan_init"; @@ -312,6 +317,9 @@ class MemorySanitizer : public FunctionPass { TrackOrigins(std::max(TrackOrigins, (int)ClTrackOrigins)), WarningFn(nullptr) {} const char *getPassName() const override { return "MemorySanitizer"; } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + } bool runOnFunction(Function &F) override; bool doInitialization(Module &M) override; static char ID; // Pass identification, replacement for typeid. 
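One small functional fix in the MemorySanitizer hunks that follow is the rounding in TypeSizeToSizeIndex: a size in bits has to be rounded up to whole bytes before taking the ceiling log2, otherwise sizes of 9 to 15 bits land in the 1-byte bucket instead of the 2-byte one. A standalone illustration of the corrected arithmetic (assuming, as the surrounding code suggests, that the index selects a power-of-two access size in bytes; log2Ceil stands in for llvm::Log2_32_Ceil):

    #include <cassert>
    #include <cstdint>

    // Ceiling log2 for 32-bit values, a stand-in for llvm::Log2_32_Ceil.
    unsigned log2Ceil(uint32_t V) {
      unsigned L = 0;
      while ((1u << L) < V)
        ++L;
      return L;
    }

    unsigned typeSizeToSizeIndex(unsigned TypeSizeInBits) {
      if (TypeSizeInBits <= 8)
        return 0;
      // Round the bit size up to whole bytes first; TypeSizeInBits / 8 alone
      // under-counts for sizes that are not a multiple of 8.
      return log2Ceil((TypeSizeInBits + 7) / 8);
    }

    int main() {
      assert(typeSizeToSizeIndex(8) == 0);   // 1 byte
      assert(typeSizeToSizeIndex(9) == 1);   // rounds up to 2 bytes
      assert(typeSizeToSizeIndex(32) == 2);  // 4 bytes
      assert(typeSizeToSizeIndex(33) == 3);  // rounds up to the 8-byte bucket
      return 0;
    }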
@@ -374,13 +382,18 @@ class MemorySanitizer : public FunctionPass { friend struct VarArgAMD64Helper; friend struct VarArgMIPS64Helper; friend struct VarArgAArch64Helper; + friend struct VarArgPowerPC64Helper; }; } // anonymous namespace char MemorySanitizer::ID = 0; -INITIALIZE_PASS(MemorySanitizer, "msan", - "MemorySanitizer: detects uninitialized reads.", - false, false) +INITIALIZE_PASS_BEGIN( + MemorySanitizer, "msan", + "MemorySanitizer: detects uninitialized reads.", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END( + MemorySanitizer, "msan", + "MemorySanitizer: detects uninitialized reads.", false, false) FunctionPass *llvm::createMemorySanitizerPass(int TrackOrigins) { return new MemorySanitizer(TrackOrigins); @@ -540,8 +553,14 @@ bool MemorySanitizer::doInitialization(Module &M) { createSanitizerCtorAndInitFunctions(M, kMsanModuleCtorName, kMsanInitName, /*InitArgTypes=*/{}, /*InitArgs=*/{}); + if (ClWithComdat) { + Comdat *MsanCtorComdat = M.getOrInsertComdat(kMsanModuleCtorName); + MsanCtorFunction->setComdat(MsanCtorComdat); + appendToGlobalCtors(M, MsanCtorFunction, 0, MsanCtorFunction); + } else { + appendToGlobalCtors(M, MsanCtorFunction, 0); + } - appendToGlobalCtors(M, MsanCtorFunction, 0); if (TrackOrigins) new GlobalVariable(M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage, @@ -591,7 +610,7 @@ CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, unsigned TypeSizeToSizeIndex(unsigned TypeSize) { if (TypeSize <= 8) return 0; - return Log2_32_Ceil(TypeSize / 8); + return Log2_32_Ceil((TypeSize + 7) / 8); } /// This class does all the work for a given function. Store and Load @@ -606,6 +625,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { SmallVector<PHINode *, 16> ShadowPHINodes, OriginPHINodes; ValueMap<Value*, Value*> ShadowMap, OriginMap; std::unique_ptr<VarArgHelper> VAHelper; + const TargetLibraryInfo *TLI; // The following flags disable parts of MSan instrumentation based on // blacklist contents and command-line options. @@ -623,7 +643,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { : Shadow(S), Origin(O), OrigIns(I) { } }; SmallVector<ShadowOriginAndInsertPoint, 16> InstrumentationList; - SmallVector<Instruction*, 16> StoreList; + SmallVector<StoreInst *, 16> StoreList; MemorySanitizerVisitor(Function &F, MemorySanitizer &MS) : F(F), MS(MS), VAHelper(CreateVarArgHelper(F, MS, *this)) { @@ -635,6 +655,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // FIXME: Consider using SpecialCaseList to specify a list of functions that // must always return fully initialized values. For now, we hardcode "main". CheckReturnValue = SanitizeFunction && (F.getName() == "main"); + TLI = &MS.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); DEBUG(if (!InsertChecks) dbgs() << "MemorySanitizer is not inserting checks into '" @@ -731,26 +752,26 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } void materializeStores(bool InstrumentWithCalls) { - for (auto Inst : StoreList) { - StoreInst &SI = *dyn_cast<StoreInst>(Inst); - - IRBuilder<> IRB(&SI); - Value *Val = SI.getValueOperand(); - Value *Addr = SI.getPointerOperand(); - Value *Shadow = SI.isAtomic() ? getCleanShadow(Val) : getShadow(Val); + for (StoreInst *SI : StoreList) { + IRBuilder<> IRB(SI); + Value *Val = SI->getValueOperand(); + Value *Addr = SI->getPointerOperand(); + Value *Shadow = SI->isAtomic() ? 
getCleanShadow(Val) : getShadow(Val); Value *ShadowPtr = getShadowPtr(Addr, Shadow->getType(), IRB); StoreInst *NewSI = - IRB.CreateAlignedStore(Shadow, ShadowPtr, SI.getAlignment()); + IRB.CreateAlignedStore(Shadow, ShadowPtr, SI->getAlignment()); DEBUG(dbgs() << " STORE: " << *NewSI << "\n"); (void)NewSI; - if (ClCheckAccessAddress) insertShadowCheck(Addr, &SI); + if (ClCheckAccessAddress) + insertShadowCheck(Addr, SI); - if (SI.isAtomic()) SI.setOrdering(addReleaseOrdering(SI.getOrdering())); + if (SI->isAtomic()) + SI->setOrdering(addReleaseOrdering(SI->getOrdering())); - if (MS.TrackOrigins && !SI.isAtomic()) - storeOrigin(IRB, Addr, Shadow, getOrigin(Val), SI.getAlignment(), + if (MS.TrackOrigins && !SI->isAtomic()) + storeOrigin(IRB, Addr, Shadow, getOrigin(Val), SI->getAlignment(), InstrumentWithCalls); } } @@ -1142,7 +1163,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(A, getCleanOrigin()); } } - ArgOffset += RoundUpToAlignment(Size, kShadowTLSAlignment); + ArgOffset += alignTo(Size, kShadowTLSAlignment); } assert(*ShadowPtr && "Could not find shadow for an argument"); return *ShadowPtr; @@ -1210,34 +1231,34 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { AtomicOrdering addReleaseOrdering(AtomicOrdering a) { switch (a) { - case NotAtomic: - return NotAtomic; - case Unordered: - case Monotonic: - case Release: - return Release; - case Acquire: - case AcquireRelease: - return AcquireRelease; - case SequentiallyConsistent: - return SequentiallyConsistent; + case AtomicOrdering::NotAtomic: + return AtomicOrdering::NotAtomic; + case AtomicOrdering::Unordered: + case AtomicOrdering::Monotonic: + case AtomicOrdering::Release: + return AtomicOrdering::Release; + case AtomicOrdering::Acquire: + case AtomicOrdering::AcquireRelease: + return AtomicOrdering::AcquireRelease; + case AtomicOrdering::SequentiallyConsistent: + return AtomicOrdering::SequentiallyConsistent; } llvm_unreachable("Unknown ordering"); } AtomicOrdering addAcquireOrdering(AtomicOrdering a) { switch (a) { - case NotAtomic: - return NotAtomic; - case Unordered: - case Monotonic: - case Acquire: - return Acquire; - case Release: - case AcquireRelease: - return AcquireRelease; - case SequentiallyConsistent: - return SequentiallyConsistent; + case AtomicOrdering::NotAtomic: + return AtomicOrdering::NotAtomic; + case AtomicOrdering::Unordered: + case AtomicOrdering::Monotonic: + case AtomicOrdering::Acquire: + return AtomicOrdering::Acquire; + case AtomicOrdering::Release: + case AtomicOrdering::AcquireRelease: + return AtomicOrdering::AcquireRelease; + case AtomicOrdering::SequentiallyConsistent: + return AtomicOrdering::SequentiallyConsistent; } llvm_unreachable("Unknown ordering"); } @@ -1603,7 +1624,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { for (unsigned Idx = 0; Idx < NumElements; ++Idx) { if (ConstantInt *Elt = dyn_cast<ConstantInt>(ConstArg->getAggregateElement(Idx))) { - APInt V = Elt->getValue(); + const APInt &V = Elt->getValue(); APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros(); Elements.push_back(ConstantInt::get(EltTy, V2)); } else { @@ -1613,7 +1634,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { ShadowMul = ConstantVector::get(Elements); } else { if (ConstantInt *Elt = dyn_cast<ConstantInt>(ConstArg)) { - APInt V = Elt->getValue(); + const APInt &V = Elt->getValue(); APInt V2 = APInt(V.getBitWidth(), 1) << V.countTrailingZeros(); ShadowMul = 
ConstantInt::get(Ty, V2); } else { @@ -2123,6 +2144,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return CreateShadowCast(IRB, S2, T, /* Signed */ true); } + // Given a vector, extract its first element, and return all + // zeroes if it is zero, and all ones otherwise. + Value *LowerElementShadowExtend(IRBuilder<> &IRB, Value *S, Type *T) { + Value *S1 = IRB.CreateExtractElement(S, (uint64_t)0); + Value *S2 = IRB.CreateICmpNE(S1, getCleanShadow(S1)); + return CreateShadowCast(IRB, S2, T, /* Signed */ true); + } + Value *VariableShadowExtend(IRBuilder<> &IRB, Value *S) { Type *T = S->getType(); assert(T->isVectorTy()); @@ -2270,15 +2299,39 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOriginForNaryOp(I); } + // \brief Instrument compare-packed intrinsic. + // Basically, an or followed by sext(icmp ne 0) to end up with all-zeros or + // all-ones shadow. + void handleVectorComparePackedIntrinsic(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Type *ResTy = getShadowTy(&I); + Value *S0 = IRB.CreateOr(getShadow(&I, 0), getShadow(&I, 1)); + Value *S = IRB.CreateSExt( + IRB.CreateICmpNE(S0, Constant::getNullValue(ResTy)), ResTy); + setShadow(&I, S); + setOriginForNaryOp(I); + } + + // \brief Instrument compare-scalar intrinsic. + // This handles both cmp* intrinsics which return the result in the first + // element of a vector, and comi* which return the result as i32. + void handleVectorCompareScalarIntrinsic(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *S0 = IRB.CreateOr(getShadow(&I, 0), getShadow(&I, 1)); + Value *S = LowerElementShadowExtend(IRB, S0, getShadowTy(&I)); + setShadow(&I, S); + setOriginForNaryOp(I); + } + void visitIntrinsicInst(IntrinsicInst &I) { switch (I.getIntrinsicID()) { case llvm::Intrinsic::bswap: handleBswap(I); break; - case llvm::Intrinsic::x86_avx512_cvtsd2usi64: - case llvm::Intrinsic::x86_avx512_cvtsd2usi: - case llvm::Intrinsic::x86_avx512_cvtss2usi64: - case llvm::Intrinsic::x86_avx512_cvtss2usi: + case llvm::Intrinsic::x86_avx512_vcvtsd2usi64: + case llvm::Intrinsic::x86_avx512_vcvtsd2usi32: + case llvm::Intrinsic::x86_avx512_vcvtss2usi64: + case llvm::Intrinsic::x86_avx512_vcvtss2usi32: case llvm::Intrinsic::x86_avx512_cvttss2usi64: case llvm::Intrinsic::x86_avx512_cvttss2usi: case llvm::Intrinsic::x86_avx512_cvttsd2usi64: @@ -2303,8 +2356,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { case llvm::Intrinsic::x86_sse_cvttss2si: handleVectorConvertIntrinsic(I, 1); break; - case llvm::Intrinsic::x86_sse2_cvtdq2pd: - case llvm::Intrinsic::x86_sse2_cvtps2pd: case llvm::Intrinsic::x86_sse_cvtps2pi: case llvm::Intrinsic::x86_sse_cvttps2pi: handleVectorConvertIntrinsic(I, 2); @@ -2413,6 +2464,43 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { handleVectorPmaddIntrinsic(I, 16); break; + case llvm::Intrinsic::x86_sse_cmp_ss: + case llvm::Intrinsic::x86_sse2_cmp_sd: + case llvm::Intrinsic::x86_sse_comieq_ss: + case llvm::Intrinsic::x86_sse_comilt_ss: + case llvm::Intrinsic::x86_sse_comile_ss: + case llvm::Intrinsic::x86_sse_comigt_ss: + case llvm::Intrinsic::x86_sse_comige_ss: + case llvm::Intrinsic::x86_sse_comineq_ss: + case llvm::Intrinsic::x86_sse_ucomieq_ss: + case llvm::Intrinsic::x86_sse_ucomilt_ss: + case llvm::Intrinsic::x86_sse_ucomile_ss: + case llvm::Intrinsic::x86_sse_ucomigt_ss: + case llvm::Intrinsic::x86_sse_ucomige_ss: + case llvm::Intrinsic::x86_sse_ucomineq_ss: + case llvm::Intrinsic::x86_sse2_comieq_sd: + case 
llvm::Intrinsic::x86_sse2_comilt_sd: + case llvm::Intrinsic::x86_sse2_comile_sd: + case llvm::Intrinsic::x86_sse2_comigt_sd: + case llvm::Intrinsic::x86_sse2_comige_sd: + case llvm::Intrinsic::x86_sse2_comineq_sd: + case llvm::Intrinsic::x86_sse2_ucomieq_sd: + case llvm::Intrinsic::x86_sse2_ucomilt_sd: + case llvm::Intrinsic::x86_sse2_ucomile_sd: + case llvm::Intrinsic::x86_sse2_ucomigt_sd: + case llvm::Intrinsic::x86_sse2_ucomige_sd: + case llvm::Intrinsic::x86_sse2_ucomineq_sd: + handleVectorCompareScalarIntrinsic(I); + break; + + case llvm::Intrinsic::x86_sse_cmp_ps: + case llvm::Intrinsic::x86_sse2_cmp_pd: + // FIXME: For x86_avx_cmp_pd_256 and x86_avx_cmp_ps_256 this function + // generates reasonably looking IR that fails in the backend with "Do not + // know how to split the result of this operator!". + handleVectorComparePackedIntrinsic(I); + break; + default: if (!handleUnknownIntrinsic(I)) visitInstruction(I); @@ -2450,6 +2538,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { AttributeSet::FunctionIndex, B)); } + + maybeMarkSanitizerLibraryCallNoBuiltin(Call, TLI); } IRBuilder<> IRB(&I); @@ -2498,7 +2588,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { (void)Store; assert(Size != 0 && Store != nullptr); DEBUG(dbgs() << " Param:" << *Store << "\n"); - ArgOffset += RoundUpToAlignment(Size, 8); + ArgOffset += alignTo(Size, 8); } DEBUG(dbgs() << " done with call args\n"); @@ -2811,14 +2901,19 @@ struct VarArgAMD64Helper : public VarArgHelper { ArgIt != End; ++ArgIt) { Value *A = *ArgIt; unsigned ArgNo = CS.getArgumentNo(ArgIt); + bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams(); bool IsByVal = CS.paramHasAttr(ArgNo + 1, Attribute::ByVal); if (IsByVal) { // ByVal arguments always go to the overflow area. + // Fixed arguments passed through the overflow area will be stepped + // over by va_start, so don't count them towards the offset. + if (IsFixed) + continue; assert(A->getType()->isPointerTy()); Type *RealTy = A->getType()->getPointerElementType(); uint64_t ArgSize = DL.getTypeAllocSize(RealTy); Value *Base = getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); - OverflowOffset += RoundUpToAlignment(ArgSize, 8); + OverflowOffset += alignTo(ArgSize, 8); IRB.CreateMemCpy(Base, MSV.getShadowPtr(A, IRB.getInt8Ty(), IRB), ArgSize, kShadowTLSAlignment); } else { @@ -2838,10 +2933,16 @@ struct VarArgAMD64Helper : public VarArgHelper { FpOffset += 16; break; case AK_Memory: + if (IsFixed) + continue; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); - OverflowOffset += RoundUpToAlignment(ArgSize, 8); + OverflowOffset += alignTo(ArgSize, 8); } + // Take fixed arguments into account for GpOffset and FpOffset, + // but don't actually store shadows for them. 
+ if (IsFixed) + continue; IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } } @@ -2952,20 +3053,22 @@ struct VarArgMIPS64Helper : public VarArgHelper { void visitCallSite(CallSite &CS, IRBuilder<> &IRB) override { unsigned VAArgOffset = 0; const DataLayout &DL = F.getParent()->getDataLayout(); - for (CallSite::arg_iterator ArgIt = CS.arg_begin() + 1, End = CS.arg_end(); + for (CallSite::arg_iterator ArgIt = CS.arg_begin() + + CS.getFunctionType()->getNumParams(), End = CS.arg_end(); ArgIt != End; ++ArgIt) { + llvm::Triple TargetTriple(F.getParent()->getTargetTriple()); Value *A = *ArgIt; Value *Base; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); -#if defined(__MIPSEB__) || defined(MIPSEB) - // Adjusting the shadow for argument with size < 8 to match the placement - // of bits in big endian system - if (ArgSize < 8) - VAArgOffset += (8 - ArgSize); -#endif + if (TargetTriple.getArch() == llvm::Triple::mips64) { + // Adjusting the shadow for argument with size < 8 to match the placement + // of bits in big endian system + if (ArgSize < 8) + VAArgOffset += (8 - ArgSize); + } Base = getShadowPtrForVAArgument(A->getType(), IRB, VAArgOffset); VAArgOffset += ArgSize; - VAArgOffset = RoundUpToAlignment(VAArgOffset, 8); + VAArgOffset = alignTo(VAArgOffset, 8); IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } @@ -3038,13 +3141,13 @@ struct VarArgMIPS64Helper : public VarArgHelper { /// \brief AArch64-specific implementation of VarArgHelper. struct VarArgAArch64Helper : public VarArgHelper { - static const unsigned kAArch64GrArgSize = 56; + static const unsigned kAArch64GrArgSize = 64; static const unsigned kAArch64VrArgSize = 128; static const unsigned AArch64GrBegOffset = 0; static const unsigned AArch64GrEndOffset = kAArch64GrArgSize; // Make VR space aligned to 16 bytes. - static const unsigned AArch64VrBegOffset = AArch64GrEndOffset + 8; + static const unsigned AArch64VrBegOffset = AArch64GrEndOffset; static const unsigned AArch64VrEndOffset = AArch64VrBegOffset + kAArch64VrArgSize; static const unsigned AArch64VAEndOffset = AArch64VrEndOffset; @@ -3089,9 +3192,11 @@ struct VarArgAArch64Helper : public VarArgHelper { unsigned OverflowOffset = AArch64VAEndOffset; const DataLayout &DL = F.getParent()->getDataLayout(); - for (CallSite::arg_iterator ArgIt = CS.arg_begin() + 1, End = CS.arg_end(); + for (CallSite::arg_iterator ArgIt = CS.arg_begin(), End = CS.arg_end(); ArgIt != End; ++ArgIt) { Value *A = *ArgIt; + unsigned ArgNo = CS.getArgumentNo(ArgIt); + bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams(); ArgKind AK = classifyArgument(A); if (AK == AK_GeneralPurpose && GrOffset >= AArch64GrEndOffset) AK = AK_Memory; @@ -3108,11 +3213,19 @@ struct VarArgAArch64Helper : public VarArgHelper { VrOffset += 16; break; case AK_Memory: + // Don't count fixed arguments in the overflow area - va_start will + // skip right over them. + if (IsFixed) + continue; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); - OverflowOffset += RoundUpToAlignment(ArgSize, 8); + OverflowOffset += alignTo(ArgSize, 8); break; } + // Count Gp/Vr fixed arguments to their respective offsets, but don't + // bother to actually store a shadow. 
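The AMD64, MIPS64 and AArch64 helpers above now share two changes: fixed (non-variadic) arguments advance the bookkeeping but get no shadow stored, since va_start steps right over them, and RoundUpToAlignment calls have become alignTo. A rough, self-contained model of the overflow-area bookkeeping, assuming the same 8-byte slot alignment:

#include <cstdint>

// llvm::alignTo(V, 8) rounds V up to the next multiple of 8; reproduced here
// so the sketch stands alone.
static uint64_t alignTo8(uint64_t V) { return (V + 7) & ~uint64_t(7); }

// Returns the overflow-area offset after one argument has been processed;
// fixed arguments leave it untouched, variadic ones consume an aligned slot.
uint64_t advanceOverflowOffset(uint64_t Offset, uint64_t ArgSize, bool IsFixed) {
  if (IsFixed)
    return Offset;                     // va_start skips fixed arguments
  return Offset + alignTo8(ArgSize);   // the shadow was stored at Offset
}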
+ if (IsFixed) + continue; IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } Constant *OverflowSize = @@ -3271,6 +3384,163 @@ struct VarArgAArch64Helper : public VarArgHelper { } }; +/// \brief PowerPC64-specific implementation of VarArgHelper. +struct VarArgPowerPC64Helper : public VarArgHelper { + Function &F; + MemorySanitizer &MS; + MemorySanitizerVisitor &MSV; + Value *VAArgTLSCopy; + Value *VAArgSize; + + SmallVector<CallInst*, 16> VAStartInstrumentationList; + + VarArgPowerPC64Helper(Function &F, MemorySanitizer &MS, + MemorySanitizerVisitor &MSV) + : F(F), MS(MS), MSV(MSV), VAArgTLSCopy(nullptr), + VAArgSize(nullptr) {} + + void visitCallSite(CallSite &CS, IRBuilder<> &IRB) override { + // For PowerPC, we need to deal with alignment of stack arguments - + // they are mostly aligned to 8 bytes, but vectors and i128 arrays + // are aligned to 16 bytes, byvals can be aligned to 8 or 16 bytes, + // and QPX vectors are aligned to 32 bytes. For that reason, we + // compute current offset from stack pointer (which is always properly + // aligned), and offset for the first vararg, then subtract them. + unsigned VAArgBase; + llvm::Triple TargetTriple(F.getParent()->getTargetTriple()); + // Parameter save area starts at 48 bytes from frame pointer for ABIv1, + // and 32 bytes for ABIv2. This is usually determined by target + // endianness, but in theory could be overriden by function attribute. + // For simplicity, we ignore it here (it'd only matter for QPX vectors). + if (TargetTriple.getArch() == llvm::Triple::ppc64) + VAArgBase = 48; + else + VAArgBase = 32; + unsigned VAArgOffset = VAArgBase; + const DataLayout &DL = F.getParent()->getDataLayout(); + for (CallSite::arg_iterator ArgIt = CS.arg_begin(), End = CS.arg_end(); + ArgIt != End; ++ArgIt) { + Value *A = *ArgIt; + unsigned ArgNo = CS.getArgumentNo(ArgIt); + bool IsFixed = ArgNo < CS.getFunctionType()->getNumParams(); + bool IsByVal = CS.paramHasAttr(ArgNo + 1, Attribute::ByVal); + if (IsByVal) { + assert(A->getType()->isPointerTy()); + Type *RealTy = A->getType()->getPointerElementType(); + uint64_t ArgSize = DL.getTypeAllocSize(RealTy); + uint64_t ArgAlign = CS.getParamAlignment(ArgNo + 1); + if (ArgAlign < 8) + ArgAlign = 8; + VAArgOffset = alignTo(VAArgOffset, ArgAlign); + if (!IsFixed) { + Value *Base = getShadowPtrForVAArgument(RealTy, IRB, + VAArgOffset - VAArgBase); + IRB.CreateMemCpy(Base, MSV.getShadowPtr(A, IRB.getInt8Ty(), IRB), + ArgSize, kShadowTLSAlignment); + } + VAArgOffset += alignTo(ArgSize, 8); + } else { + Value *Base; + uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); + uint64_t ArgAlign = 8; + if (A->getType()->isArrayTy()) { + // Arrays are aligned to element size, except for long double + // arrays, which are aligned to 8 bytes. + Type *ElementTy = A->getType()->getArrayElementType(); + if (!ElementTy->isPPC_FP128Ty()) + ArgAlign = DL.getTypeAllocSize(ElementTy); + } else if (A->getType()->isVectorTy()) { + // Vectors are naturally aligned. 
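The new PowerPC64 helper mirrors how the parameter save area is laid out: the base sits 48 bytes from the frame pointer under ABIv1 (big-endian ppc64) and 32 under ABIv2 (ppc64le), every slot is at least 8-byte aligned, arrays and vectors may demand more, and on big-endian targets an argument smaller than 8 bytes occupies the high bytes of its slot. A simplified model of the per-argument offset computation, assuming those rules:

#include <cstdint>

static uint64_t alignUp(uint64_t V, uint64_t A) { return (V + A - 1) / A * A; }

// Offset, relative to the save-area base, at which one argument's shadow lands.
uint64_t ppc64ArgShadowOffset(uint64_t Offset, uint64_t ArgSize,
                              uint64_t ArgAlign, bool IsBigEndian) {
  if (ArgAlign < 8)
    ArgAlign = 8;                      // minimum slot alignment
  Offset = alignUp(Offset, ArgAlign);  // pad up to the argument's alignment
  if (IsBigEndian && ArgSize < 8)
    Offset += 8 - ArgSize;             // small args sit in the slot's high bytes
  return Offset;                       // the caller then advances past the slot
}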
+ ArgAlign = DL.getTypeAllocSize(A->getType()); + } + if (ArgAlign < 8) + ArgAlign = 8; + VAArgOffset = alignTo(VAArgOffset, ArgAlign); + if (DL.isBigEndian()) { + // Adjusting the shadow for argument with size < 8 to match the placement + // of bits in big endian system + if (ArgSize < 8) + VAArgOffset += (8 - ArgSize); + } + if (!IsFixed) { + Base = getShadowPtrForVAArgument(A->getType(), IRB, + VAArgOffset - VAArgBase); + IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); + } + VAArgOffset += ArgSize; + VAArgOffset = alignTo(VAArgOffset, 8); + } + if (IsFixed) + VAArgBase = VAArgOffset; + } + + Constant *TotalVAArgSize = ConstantInt::get(IRB.getInt64Ty(), + VAArgOffset - VAArgBase); + // Here using VAArgOverflowSizeTLS as VAArgSizeTLS to avoid creation of + // a new class member i.e. it is the total size of all VarArgs. + IRB.CreateStore(TotalVAArgSize, MS.VAArgOverflowSizeTLS); + } + + /// \brief Compute the shadow address for a given va_arg. + Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, + int ArgOffset) { + Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); + Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); + return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), + "_msarg"); + } + + void visitVAStartInst(VAStartInst &I) override { + IRBuilder<> IRB(&I); + VAStartInstrumentationList.push_back(&I); + Value *VAListTag = I.getArgOperand(0); + Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); + IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), + /* size */8, /* alignment */8, false); + } + + void visitVACopyInst(VACopyInst &I) override { + IRBuilder<> IRB(&I); + Value *VAListTag = I.getArgOperand(0); + Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); + // Unpoison the whole __va_list_tag. + // FIXME: magic ABI constants. + IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), + /* size */8, /* alignment */8, false); + } + + void finalizeInstrumentation() override { + assert(!VAArgSize && !VAArgTLSCopy && + "finalizeInstrumentation called twice"); + IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + VAArgSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); + Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), + VAArgSize); + + if (!VAStartInstrumentationList.empty()) { + // If there is a va_start in this function, make a backup copy of + // va_arg_tls somewhere in the function entry block. + VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); + IRB.CreateMemCpy(VAArgTLSCopy, MS.VAArgTLS, CopySize, 8); + } + + // Instrument va_start. + // Copy va_list shadow from the backup copy of the TLS contents. + for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { + CallInst *OrigInst = VAStartInstrumentationList[i]; + IRBuilder<> IRB(OrigInst->getNextNode()); + Value *VAListTag = OrigInst->getArgOperand(0); + Value *RegSaveAreaPtrPtr = + IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), + Type::getInt64PtrTy(*MS.C)); + Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); + Value *RegSaveAreaShadowPtr = + MSV.getShadowPtr(RegSaveAreaPtr, IRB.getInt8Ty(), IRB); + IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, CopySize, 8); + } + } +}; + /// \brief A no-op implementation of VarArgHelper. 
struct VarArgNoOpHelper : public VarArgHelper { VarArgNoOpHelper(Function &F, MemorySanitizer &MS, @@ -3297,6 +3567,9 @@ VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, return new VarArgMIPS64Helper(Func, Msan, Visitor); else if (TargetTriple.getArch() == llvm::Triple::aarch64) return new VarArgAArch64Helper(Func, Msan, Visitor); + else if (TargetTriple.getArch() == llvm::Triple::ppc64 || + TargetTriple.getArch() == llvm::Triple::ppc64le) + return new VarArgPowerPC64Helper(Func, Msan, Visitor); else return new VarArgNoOpHelper(Func, Msan, Visitor); } diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 4b59b93b325f..f54d8ad48146 100644 --- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -25,9 +25,12 @@ // // This file contains two passes: // (1) Pass PGOInstrumentationGen which instruments the IR to generate edge -// count profile, and +// count profile, and generates the instrumentation for indirect call +// profiling. // (2) Pass PGOInstrumentationUse which reads the edge count profile and -// annotates the branch weights. +// annotates the branch weights. It also reads the indirect call value +// profiling records and annotate the indirect call instructions. +// // To get the precise counter information, These two passes need to invoke at // the same compilation point (so they see the same IR). For pass // PGOInstrumentationGen, the real work is done in instrumentOneFunc(). For @@ -45,14 +48,16 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/PGOInstrumentation.h" #include "CFGMST.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Triple.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" @@ -62,10 +67,13 @@ #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/ProfileData/InstrProfReader.h" +#include "llvm/ProfileData/ProfileCommon.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/Debug.h" #include "llvm/Support/JamCRC.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <algorithm> #include <string> #include <utility> #include <vector> @@ -81,6 +89,7 @@ STATISTIC(NumOfPGOSplit, "Number of critical edge splits."); STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts."); STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile."); STATISTIC(NumOfPGOMissing, "Number of functions without profile."); +STATISTIC(NumOfPGOICall, "Number of indirect call value instrumentations."); // Command line option to specify the file to read profile from. This is // mainly used for testing. @@ -90,13 +99,37 @@ static cl::opt<std::string> cl::desc("Specify the path of profile data file. This is" "mainly for test purpose.")); +// Command line option to disable value profiling. The default is false: +// i.e. value profiling is enabled by default. This is for debug purpose. 
+static cl::opt<bool> DisableValueProfiling("disable-vp", cl::init(false), + cl::Hidden, + cl::desc("Disable Value Profiling")); + +// Command line option to set the maximum number of VP annotations to write to +// the metadata for a single indirect call callsite. +static cl::opt<unsigned> MaxNumAnnotations( + "icp-max-annotations", cl::init(3), cl::Hidden, cl::ZeroOrMore, + cl::desc("Max number of annotations for a single indirect " + "call callsite")); + +// Command line option to enable/disable the warning about missing profile +// information. +static cl::opt<bool> NoPGOWarnMissing("no-pgo-warn-missing", cl::init(false), + cl::Hidden); + +// Command line option to enable/disable the warning about a hash mismatch in +// the profile data. +static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), + cl::Hidden); + namespace { -class PGOInstrumentationGen : public ModulePass { +class PGOInstrumentationGenLegacyPass : public ModulePass { public: static char ID; - PGOInstrumentationGen() : ModulePass(ID) { - initializePGOInstrumentationGenPass(*PassRegistry::getPassRegistry()); + PGOInstrumentationGenLegacyPass() : ModulePass(ID) { + initializePGOInstrumentationGenLegacyPassPass( + *PassRegistry::getPassRegistry()); } const char *getPassName() const override { @@ -111,16 +144,17 @@ private: } }; -class PGOInstrumentationUse : public ModulePass { +class PGOInstrumentationUseLegacyPass : public ModulePass { public: static char ID; // Provide the profile filename as the parameter. - PGOInstrumentationUse(std::string Filename = "") - : ModulePass(ID), ProfileFileName(Filename) { + PGOInstrumentationUseLegacyPass(std::string Filename = "") + : ModulePass(ID), ProfileFileName(std::move(Filename)) { if (!PGOTestProfileFile.empty()) ProfileFileName = PGOTestProfileFile; - initializePGOInstrumentationUsePass(*PassRegistry::getPassRegistry()); + initializePGOInstrumentationUseLegacyPassPass( + *PassRegistry::getPassRegistry()); } const char *getPassName() const override { @@ -129,37 +163,36 @@ public: private: std::string ProfileFileName; - std::unique_ptr<IndexedInstrProfReader> PGOReader; - bool runOnModule(Module &M) override; + bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<BlockFrequencyInfoWrapperPass>(); } }; } // end anonymous namespace -char PGOInstrumentationGen::ID = 0; -INITIALIZE_PASS_BEGIN(PGOInstrumentationGen, "pgo-instr-gen", +char PGOInstrumentationGenLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(PGOInstrumentationGenLegacyPass, "pgo-instr-gen", "PGO instrumentation.", false, false) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_END(PGOInstrumentationGen, "pgo-instr-gen", +INITIALIZE_PASS_END(PGOInstrumentationGenLegacyPass, "pgo-instr-gen", "PGO instrumentation.", false, false) -ModulePass *llvm::createPGOInstrumentationGenPass() { - return new PGOInstrumentationGen(); +ModulePass *llvm::createPGOInstrumentationGenLegacyPass() { + return new PGOInstrumentationGenLegacyPass(); } -char PGOInstrumentationUse::ID = 0; -INITIALIZE_PASS_BEGIN(PGOInstrumentationUse, "pgo-instr-use", +char PGOInstrumentationUseLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(PGOInstrumentationUseLegacyPass, "pgo-instr-use", "Read PGO instrumentation profile.", false, false) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_END(PGOInstrumentationUse, 
"pgo-instr-use", +INITIALIZE_PASS_END(PGOInstrumentationUseLegacyPass, "pgo-instr-use", "Read PGO instrumentation profile.", false, false) -ModulePass *llvm::createPGOInstrumentationUsePass(StringRef Filename) { - return new PGOInstrumentationUse(Filename.str()); +ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename) { + return new PGOInstrumentationUseLegacyPass(Filename.str()); } namespace { @@ -225,7 +258,7 @@ public: // Dump edges and BB information. void dumpInfo(std::string Str = "") const { MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + " Hash: " + - Twine(FunctionHash) + "\t" + Str); + Twine(FunctionHash) + "\t" + Str); } FuncPGOInstrumentation(Function &Func, bool CreateGlobalVar = false, @@ -247,7 +280,7 @@ public: if (CreateGlobalVar) FuncNameVar = createPGOFuncNameVar(F, FuncName); - }; + } }; // Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index @@ -305,7 +338,7 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) { return InstrBB; } -// Visit all edge and instrument the edges not in MST. +// Visit all edge and instrument the edges not in MST, and do value profiling. // Critical edges will be split. static void instrumentOneFunc(Function &F, Module *M, BranchProbabilityInfo *BPI, @@ -318,6 +351,7 @@ static void instrumentOneFunc(Function &F, Module *M, } uint32_t I = 0; + Type *I8PtrTy = Type::getInt8PtrTy(M->getContext()); for (auto &E : FuncInfo.MST.AllEdges) { BasicBlock *InstrBB = FuncInfo.getInstrBB(E.get()); if (!InstrBB) @@ -326,13 +360,34 @@ static void instrumentOneFunc(Function &F, Module *M, IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt()); assert(Builder.GetInsertPoint() != InstrBB->end() && "Cannot get the Instrumentation point"); - Type *I8PtrTy = Type::getInt8PtrTy(M->getContext()); Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment), {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), Builder.getInt64(FuncInfo.FunctionHash), Builder.getInt32(NumCounters), Builder.getInt32(I++)}); } + + if (DisableValueProfiling) + return; + + unsigned NumIndirectCallSites = 0; + for (auto &I : findIndirectCallSites(F)) { + CallSite CS(I); + Value *Callee = CS.getCalledValue(); + DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " + << NumIndirectCallSites << "\n"); + IRBuilder<> Builder(I); + assert(Builder.GetInsertPoint() != I->getParent()->end() && + "Cannot get the Instrumentation point"); + Builder.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), + {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), + Builder.getInt64(FuncInfo.FunctionHash), + Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()), + Builder.getInt32(llvm::InstrProfValueKind::IPVK_IndirectCallTarget), + Builder.getInt32(NumIndirectCallSites++)}); + } + NumOfPGOICall += NumIndirectCallSites; } // This class represents a CFG edge in profile use compilation. 
@@ -352,7 +407,8 @@ struct PGOUseEdge : public PGOEdge { const std::string infoString() const { if (!CountValid) return PGOEdge::infoString(); - return (Twine(PGOEdge::infoString()) + " Count=" + Twine(CountValue)).str(); + return (Twine(PGOEdge::infoString()) + " Count=" + Twine(CountValue)) + .str(); } }; @@ -399,6 +455,33 @@ static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) { } class PGOUseFunc { +public: + PGOUseFunc(Function &Func, Module *Modu, BranchProbabilityInfo *BPI = nullptr, + BlockFrequencyInfo *BFI = nullptr) + : F(Func), M(Modu), FuncInfo(Func, false, BPI, BFI), + FreqAttr(FFA_Normal) {} + + // Read counts for the instrumented BB from profile. + bool readCounters(IndexedInstrProfReader *PGOReader); + + // Populate the counts for all BBs. + void populateCounters(); + + // Set the branch weights based on the count values. + void setBranchWeights(); + + // Annotate the indirect call sites. + void annotateIndirectCallSites(); + + // The hotness of the function from the profile count. + enum FuncFreqAttr { FFA_Normal, FFA_Cold, FFA_Hot }; + + // Return the function hotness from the profile. + FuncFreqAttr getFuncFreqAttr() const { return FreqAttr; } + + // Return the profile record for this function; + InstrProfRecord &getProfileRecord() { return ProfileRecord; } + private: Function &F; Module *M; @@ -414,6 +497,12 @@ private: // compilation. uint64_t ProgramMaxCount; + // ProfileRecord for this function. + InstrProfRecord ProfileRecord; + + // Function hotness info derived from profile. + FuncFreqAttr FreqAttr; + // Find the Instrumented BB and set the value. void setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile); @@ -427,7 +516,7 @@ private: // Set the hot/cold inline hints based on the count values. // FIXME: This function should be removed once the functionality in // the inliner is implemented. - void applyFunctionAttributes(uint64_t EntryCount, uint64_t MaxCount) { + void markFunctionAttributes(uint64_t EntryCount, uint64_t MaxCount) { if (ProgramMaxCount == 0) return; // Threshold of the hot functions. @@ -435,24 +524,10 @@ private: // Threshold of the cold functions. const BranchProbability ColdFunctionThreshold(2, 10000); if (EntryCount >= HotFunctionThreshold.scale(ProgramMaxCount)) - F.addFnAttr(llvm::Attribute::InlineHint); + FreqAttr = FFA_Hot; else if (MaxCount <= ColdFunctionThreshold.scale(ProgramMaxCount)) - F.addFnAttr(llvm::Attribute::Cold); + FreqAttr = FFA_Cold; } - -public: - PGOUseFunc(Function &Func, Module *Modu, BranchProbabilityInfo *BPI = nullptr, - BlockFrequencyInfo *BFI = nullptr) - : F(Func), M(Modu), FuncInfo(Func, false, BPI, BFI) {} - - // Read counts for the instrumented BB from profile. - bool readCounters(IndexedInstrProfReader *PGOReader); - - // Populate the counts for all BBs. - void populateCounters(); - - // Set the branch weights based on the count values. - void setBranchWeights(); }; // Visit all the edges and assign the count value for the instrumented @@ -511,21 +586,32 @@ void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) { // Return true if the profile are successfully read, and false on errors. 
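readCounters, defined next, switches the profile lookup from ErrorOr<InstrProfRecord> to the new Expected<T>/Error machinery: the error must be taken and consumed (here with handleAllErrors) before the value can be used. A minimal sketch of that pattern using the same headers the pass relies on; parseCount is a made-up helper, not part of LLVM:

#include "llvm/ADT/StringRef.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include <system_error>

using namespace llvm;

// Hypothetical helper, defined only so the sketch compiles on its own.
static Expected<int> parseCount(StringRef S) {
  int N;
  if (S.getAsInteger(10, N))
    return errorCodeToError(std::make_error_code(std::errc::invalid_argument));
  return N;
}

static int parseOrZero(StringRef S, LLVMContext &Ctx, StringRef File) {
  Expected<int> R = parseCount(S);
  if (Error E = R.takeError()) {
    // Every contained error must be handled, mirroring readCounters below.
    handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) {
      Ctx.diagnose(DiagnosticInfoPGOProfile(File.data(), EI.message(), DS_Warning));
    });
    return 0;
  }
  return *R;
}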
bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) { auto &Ctx = M->getContext(); - ErrorOr<InstrProfRecord> Result = + Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord(FuncInfo.FuncName, FuncInfo.FunctionHash); - if (std::error_code EC = Result.getError()) { - if (EC == instrprof_error::unknown_function) - NumOfPGOMissing++; - else if (EC == instrprof_error::hash_mismatch || - EC == llvm::instrprof_error::malformed) - NumOfPGOMismatch++; - - std::string Msg = EC.message() + std::string(" ") + F.getName().str(); - Ctx.diagnose( - DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); + if (Error E = Result.takeError()) { + handleAllErrors(std::move(E), [&](const InstrProfError &IPE) { + auto Err = IPE.get(); + bool SkipWarning = false; + if (Err == instrprof_error::unknown_function) { + NumOfPGOMissing++; + SkipWarning = NoPGOWarnMissing; + } else if (Err == instrprof_error::hash_mismatch || + Err == instrprof_error::malformed) { + NumOfPGOMismatch++; + SkipWarning = NoPGOWarnMismatch; + } + + if (SkipWarning) + return; + + std::string Msg = IPE.message() + std::string(" ") + F.getName().str(); + Ctx.diagnose( + DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); + }); return false; } - std::vector<uint64_t> &CountFromProfile = Result.get().Counts; + ProfileRecord = std::move(Result.get()); + std::vector<uint64_t> &CountFromProfile = ProfileRecord.Counts; NumOfPGOFunc++; DEBUG(dbgs() << CountFromProfile.size() << " counts\n"); @@ -605,16 +691,17 @@ void PGOUseFunc::populateCounters() { } DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n"); +#ifndef NDEBUG // Assert every BB has a valid counter. + for (auto &BB : F) + assert(getBBInfo(&BB).CountValid && "BB count is not valid"); +#endif uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue; + F.setEntryCount(FuncEntryCount); uint64_t FuncMaxCount = FuncEntryCount; - for (auto &BB : F) { - assert(getBBInfo(&BB).CountValid && "BB count is not valid"); - uint64_t Count = getBBInfo(&BB).CountValue; - if (Count > FuncMaxCount) - FuncMaxCount = Count; - } - applyFunctionAttributes(FuncEntryCount, FuncMaxCount); + for (auto &BB : F) + FuncMaxCount = std::max(FuncMaxCount, getBBInfo(&BB).CountValue); + markFunctionAttributes(FuncEntryCount, FuncMaxCount); DEBUG(FuncInfo.dumpInfo("after reading profile.")); } @@ -642,7 +729,7 @@ void PGOUseFunc::setBranchWeights() { const PGOUseEdge *E = BBCountInfo.OutEdges[s]; const BasicBlock *SrcBB = E->SrcBB; const BasicBlock *DestBB = E->DestBB; - if (DestBB == 0) + if (DestBB == nullptr) continue; unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB); uint64_t EdgeCount = E->CountValue; @@ -663,56 +750,204 @@ void PGOUseFunc::setBranchWeights() { dbgs() << "\n";); } } + +// Traverse all the indirect callsites and annotate the instructions. +void PGOUseFunc::annotateIndirectCallSites() { + if (DisableValueProfiling) + return; + + // Create the PGOFuncName meta data. 
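annotateIndirectCallSites, which continues below, walks the indirect call sites in source order and attaches the matching value-profile data from the record. Conceptually, the data one instrumented site accumulates, and the trimming applied when it is annotated, can be modelled like this (a standalone sketch, not LLVM's InstrProfRecord):

#include <algorithm>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

struct IndirectCallSiteProfile {
  std::map<uint64_t, uint64_t> TargetCounts; // callee address -> call count

  void record(uint64_t Target) { ++TargetCounts[Target]; }

  // Keep only the hottest targets, as -icp-max-annotations (default 3) does
  // when the metadata is written.
  std::vector<std::pair<uint64_t, uint64_t>> hottest(size_t MaxAnnotations) const {
    std::vector<std::pair<uint64_t, uint64_t>> V(TargetCounts.begin(),
                                                 TargetCounts.end());
    std::sort(V.begin(), V.end(),
              [](const std::pair<uint64_t, uint64_t> &A,
                 const std::pair<uint64_t, uint64_t> &B) {
                return A.second > B.second;
              });
    if (V.size() > MaxAnnotations)
      V.resize(MaxAnnotations);
    return V;
  }
};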
+ createPGOFuncNameMetadata(F, FuncInfo.FuncName); + + unsigned IndirectCallSiteIndex = 0; + auto IndirectCallSites = findIndirectCallSites(F); + unsigned NumValueSites = + ProfileRecord.getNumValueSites(IPVK_IndirectCallTarget); + if (NumValueSites != IndirectCallSites.size()) { + std::string Msg = + std::string("Inconsistent number of indirect call sites: ") + + F.getName().str(); + auto &Ctx = M->getContext(); + Ctx.diagnose( + DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); + return; + } + + for (auto &I : IndirectCallSites) { + DEBUG(dbgs() << "Read one indirect call instrumentation: Index=" + << IndirectCallSiteIndex << " out of " << NumValueSites + << "\n"); + annotateValueSite(*M, *I, ProfileRecord, IPVK_IndirectCallTarget, + IndirectCallSiteIndex, MaxNumAnnotations); + IndirectCallSiteIndex++; + } +} } // end anonymous namespace -bool PGOInstrumentationGen::runOnModule(Module &M) { +// Create a COMDAT variable IR_LEVEL_PROF_VARNAME to make the runtime +// aware this is an ir_level profile so it can set the version flag. +static void createIRLevelProfileFlagVariable(Module &M) { + Type *IntTy64 = Type::getInt64Ty(M.getContext()); + uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF); + auto IRLevelVersionVariable = new GlobalVariable( + M, IntTy64, true, GlobalVariable::ExternalLinkage, + Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), + INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR)); + IRLevelVersionVariable->setVisibility(GlobalValue::DefaultVisibility); + Triple TT(M.getTargetTriple()); + if (!TT.supportsCOMDAT()) + IRLevelVersionVariable->setLinkage(GlobalValue::WeakAnyLinkage); + else + IRLevelVersionVariable->setComdat(M.getOrInsertComdat( + StringRef(INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR)))); +} + +static bool InstrumentAllFunctions( + Module &M, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, + function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) { + createIRLevelProfileFlagVariable(M); for (auto &F : M) { if (F.isDeclaration()) continue; - BranchProbabilityInfo *BPI = - &(getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI()); - BlockFrequencyInfo *BFI = - &(getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI()); + auto *BPI = LookupBPI(F); + auto *BFI = LookupBFI(F); instrumentOneFunc(F, &M, BPI, BFI); } return true; } -static void setPGOCountOnFunc(PGOUseFunc &Func, - IndexedInstrProfReader *PGOReader) { - if (Func.readCounters(PGOReader)) { - Func.populateCounters(); - Func.setBranchWeights(); - } +bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) { + if (skipModule(M)) + return false; + + auto LookupBPI = [this](Function &F) { + return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI(); + }; + auto LookupBFI = [this](Function &F) { + return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); + }; + return InstrumentAllFunctions(M, LookupBPI, LookupBFI); } -bool PGOInstrumentationUse::runOnModule(Module &M) { +PreservedAnalyses PGOInstrumentationGen::run(Module &M, + AnalysisManager<Module> &AM) { + + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto LookupBPI = [&FAM](Function &F) { + return &FAM.getResult<BranchProbabilityAnalysis>(F); + }; + + auto LookupBFI = [&FAM](Function &F) { + return &FAM.getResult<BlockFrequencyAnalysis>(F); + }; + + if (!InstrumentAllFunctions(M, LookupBPI, LookupBFI)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +static bool annotateAllFunctions( + Module &M, 
StringRef ProfileFileName, + function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, + function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) { DEBUG(dbgs() << "Read in profile counters: "); auto &Ctx = M.getContext(); // Read the counter array from file. auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName); - if (std::error_code EC = ReaderOrErr.getError()) { - Ctx.diagnose( - DiagnosticInfoPGOProfile(ProfileFileName.data(), EC.message())); + if (Error E = ReaderOrErr.takeError()) { + handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) { + Ctx.diagnose( + DiagnosticInfoPGOProfile(ProfileFileName.data(), EI.message())); + }); return false; } - PGOReader = std::move(ReaderOrErr.get()); + std::unique_ptr<IndexedInstrProfReader> PGOReader = + std::move(ReaderOrErr.get()); if (!PGOReader) { Ctx.diagnose(DiagnosticInfoPGOProfile(ProfileFileName.data(), - "Cannot get PGOReader")); + StringRef("Cannot get PGOReader"))); + return false; + } + // TODO: might need to change the warning once the clang option is finalized. + if (!PGOReader->isIRLevelProfile()) { + Ctx.diagnose(DiagnosticInfoPGOProfile( + ProfileFileName.data(), "Not an IR level instrumentation profile")); return false; } + std::vector<Function *> HotFunctions; + std::vector<Function *> ColdFunctions; for (auto &F : M) { if (F.isDeclaration()) continue; - BranchProbabilityInfo *BPI = - &(getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI()); - BlockFrequencyInfo *BFI = - &(getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI()); + auto *BPI = LookupBPI(F); + auto *BFI = LookupBFI(F); PGOUseFunc Func(F, &M, BPI, BFI); - setPGOCountOnFunc(Func, PGOReader.get()); + if (!Func.readCounters(PGOReader.get())) + continue; + Func.populateCounters(); + Func.setBranchWeights(); + Func.annotateIndirectCallSites(); + PGOUseFunc::FuncFreqAttr FreqAttr = Func.getFuncFreqAttr(); + if (FreqAttr == PGOUseFunc::FFA_Cold) + ColdFunctions.push_back(&F); + else if (FreqAttr == PGOUseFunc::FFA_Hot) + HotFunctions.push_back(&F); + } + M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext())); + // Set function hotness attribute from the profile. + // We have to apply these attributes at the end because their presence + // can affect the BranchProbabilityInfo of any callers, resulting in an + // inconsistent MST between prof-gen and prof-use. 
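The hot/cold bucketing used above, and consumed by the attribute loops that follow, reduces to two threshold comparisons against counts scaled from the program-wide maximum; the attributes are applied only after every function has been processed so that BranchProbabilityInfo stays consistent between prof-gen and prof-use. A standalone sketch with the thresholds passed in (their exact values are the BranchProbability constants in the pass):

#include <cstdint>

enum FuncFreqAttr { FFA_Normal, FFA_Cold, FFA_Hot };

FuncFreqAttr classifyFunction(uint64_t EntryCount, uint64_t MaxCount,
                              uint64_t HotThreshold, uint64_t ColdThreshold) {
  if (EntryCount >= HotThreshold)
    return FFA_Hot;    // later mapped to llvm::Attribute::InlineHint
  if (MaxCount <= ColdThreshold)
    return FFA_Cold;   // later mapped to llvm::Attribute::Cold
  return FFA_Normal;
}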
+ for (auto &F : HotFunctions) { + F->addFnAttr(llvm::Attribute::InlineHint); + DEBUG(dbgs() << "Set inline attribute to function: " << F->getName() + << "\n"); + } + for (auto &F : ColdFunctions) { + F->addFnAttr(llvm::Attribute::Cold); + DEBUG(dbgs() << "Set cold attribute to function: " << F->getName() << "\n"); } + return true; } + +PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename) + : ProfileFileName(std::move(Filename)) { + if (!PGOTestProfileFile.empty()) + ProfileFileName = PGOTestProfileFile; +} + +PreservedAnalyses PGOInstrumentationUse::run(Module &M, + AnalysisManager<Module> &AM) { + + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto LookupBPI = [&FAM](Function &F) { + return &FAM.getResult<BranchProbabilityAnalysis>(F); + }; + + auto LookupBFI = [&FAM](Function &F) { + return &FAM.getResult<BlockFrequencyAnalysis>(F); + }; + + if (!annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) { + if (skipModule(M)) + return false; + + auto LookupBPI = [this](Function &F) { + return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI(); + }; + auto LookupBFI = [this](Function &F) { + return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); + }; + + return annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI); +} diff --git a/lib/Transforms/Instrumentation/SafeStack.cpp b/lib/Transforms/Instrumentation/SafeStack.cpp deleted file mode 100644 index abed465f102d..000000000000 --- a/lib/Transforms/Instrumentation/SafeStack.cpp +++ /dev/null @@ -1,760 +0,0 @@ -//===-- SafeStack.cpp - Safe Stack Insertion ------------------------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This pass splits the stack into the safe stack (kept as-is for LLVM backend) -// and the unsafe stack (explicitly allocated and managed through the runtime -// support library). 
-// -// http://clang.llvm.org/docs/SafeStack.html -// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/Instrumentation.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Triple.h" -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/CodeGen/Passes.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DIBuilder.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Module.h" -#include "llvm/Pass.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/Format.h" -#include "llvm/Support/MathExtras.h" -#include "llvm/Support/raw_os_ostream.h" -#include "llvm/Target/TargetLowering.h" -#include "llvm/Target/TargetSubtargetInfo.h" -#include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/ModuleUtils.h" - -using namespace llvm; - -#define DEBUG_TYPE "safestack" - -enum UnsafeStackPtrStorageVal { ThreadLocalUSP, SingleThreadUSP }; - -static cl::opt<UnsafeStackPtrStorageVal> USPStorage("safe-stack-usp-storage", - cl::Hidden, cl::init(ThreadLocalUSP), - cl::desc("Type of storage for the unsafe stack pointer"), - cl::values(clEnumValN(ThreadLocalUSP, "thread-local", - "Thread-local storage"), - clEnumValN(SingleThreadUSP, "single-thread", - "Non-thread-local storage"), - clEnumValEnd)); - -namespace llvm { - -STATISTIC(NumFunctions, "Total number of functions"); -STATISTIC(NumUnsafeStackFunctions, "Number of functions with unsafe stack"); -STATISTIC(NumUnsafeStackRestorePointsFunctions, - "Number of functions that use setjmp or exceptions"); - -STATISTIC(NumAllocas, "Total number of allocas"); -STATISTIC(NumUnsafeStaticAllocas, "Number of unsafe static allocas"); -STATISTIC(NumUnsafeDynamicAllocas, "Number of unsafe dynamic allocas"); -STATISTIC(NumUnsafeByValArguments, "Number of unsafe byval arguments"); -STATISTIC(NumUnsafeStackRestorePoints, "Number of setjmps and landingpads"); - -} // namespace llvm - -namespace { - -/// Rewrite an SCEV expression for a memory access address to an expression that -/// represents offset from the given alloca. -/// -/// The implementation simply replaces all mentions of the alloca with zero. -class AllocaOffsetRewriter : public SCEVRewriteVisitor<AllocaOffsetRewriter> { - const Value *AllocaPtr; - -public: - AllocaOffsetRewriter(ScalarEvolution &SE, const Value *AllocaPtr) - : SCEVRewriteVisitor(SE), AllocaPtr(AllocaPtr) {} - - const SCEV *visitUnknown(const SCEVUnknown *Expr) { - if (Expr->getValue() == AllocaPtr) - return SE.getZero(Expr->getType()); - return Expr; - } -}; - -/// The SafeStack pass splits the stack of each function into the safe -/// stack, which is only accessed through memory safe dereferences (as -/// determined statically), and the unsafe stack, which contains all -/// local variables that are accessed in ways that we can't prove to -/// be safe. -class SafeStack : public FunctionPass { - const TargetMachine *TM; - const TargetLoweringBase *TL; - const DataLayout *DL; - ScalarEvolution *SE; - - Type *StackPtrTy; - Type *IntPtrTy; - Type *Int32Ty; - Type *Int8Ty; - - Value *UnsafeStackPtr = nullptr; - - /// Unsafe stack alignment. Each stack frame must ensure that the stack is - /// aligned to this value. 
We need to re-align the unsafe stack if the - /// alignment of any object on the stack exceeds this value. - /// - /// 16 seems like a reasonable upper bound on the alignment of objects that we - /// might expect to appear on the stack on most common targets. - enum { StackAlignment = 16 }; - - /// \brief Build a value representing a pointer to the unsafe stack pointer. - Value *getOrCreateUnsafeStackPtr(IRBuilder<> &IRB, Function &F); - - /// \brief Find all static allocas, dynamic allocas, return instructions and - /// stack restore points (exception unwind blocks and setjmp calls) in the - /// given function and append them to the respective vectors. - void findInsts(Function &F, SmallVectorImpl<AllocaInst *> &StaticAllocas, - SmallVectorImpl<AllocaInst *> &DynamicAllocas, - SmallVectorImpl<Argument *> &ByValArguments, - SmallVectorImpl<ReturnInst *> &Returns, - SmallVectorImpl<Instruction *> &StackRestorePoints); - - /// \brief Calculate the allocation size of a given alloca. Returns 0 if the - /// size can not be statically determined. - uint64_t getStaticAllocaAllocationSize(const AllocaInst* AI); - - /// \brief Allocate space for all static allocas in \p StaticAllocas, - /// replace allocas with pointers into the unsafe stack and generate code to - /// restore the stack pointer before all return instructions in \p Returns. - /// - /// \returns A pointer to the top of the unsafe stack after all unsafe static - /// allocas are allocated. - Value *moveStaticAllocasToUnsafeStack(IRBuilder<> &IRB, Function &F, - ArrayRef<AllocaInst *> StaticAllocas, - ArrayRef<Argument *> ByValArguments, - ArrayRef<ReturnInst *> Returns); - - /// \brief Generate code to restore the stack after all stack restore points - /// in \p StackRestorePoints. - /// - /// \returns A local variable in which to maintain the dynamic top of the - /// unsafe stack if needed. - AllocaInst * - createStackRestorePoints(IRBuilder<> &IRB, Function &F, - ArrayRef<Instruction *> StackRestorePoints, - Value *StaticTop, bool NeedDynamicTop); - - /// \brief Replace all allocas in \p DynamicAllocas with code to allocate - /// space dynamically on the unsafe stack and store the dynamic unsafe stack - /// top to \p DynamicTop if non-null. - void moveDynamicAllocasToUnsafeStack(Function &F, Value *UnsafeStackPtr, - AllocaInst *DynamicTop, - ArrayRef<AllocaInst *> DynamicAllocas); - - bool IsSafeStackAlloca(const Value *AllocaPtr, uint64_t AllocaSize); - - bool IsMemIntrinsicSafe(const MemIntrinsic *MI, const Use &U, - const Value *AllocaPtr, uint64_t AllocaSize); - bool IsAccessSafe(Value *Addr, uint64_t Size, const Value *AllocaPtr, - uint64_t AllocaSize); - -public: - static char ID; // Pass identification, replacement for typeid. 
- SafeStack(const TargetMachine *TM) - : FunctionPass(ID), TM(TM), TL(nullptr), DL(nullptr) { - initializeSafeStackPass(*PassRegistry::getPassRegistry()); - } - SafeStack() : SafeStack(nullptr) {} - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<ScalarEvolutionWrapperPass>(); - } - - bool doInitialization(Module &M) override { - DL = &M.getDataLayout(); - - StackPtrTy = Type::getInt8PtrTy(M.getContext()); - IntPtrTy = DL->getIntPtrType(M.getContext()); - Int32Ty = Type::getInt32Ty(M.getContext()); - Int8Ty = Type::getInt8Ty(M.getContext()); - - return false; - } - - bool runOnFunction(Function &F) override; -}; // class SafeStack - -uint64_t SafeStack::getStaticAllocaAllocationSize(const AllocaInst* AI) { - uint64_t Size = DL->getTypeAllocSize(AI->getAllocatedType()); - if (AI->isArrayAllocation()) { - auto C = dyn_cast<ConstantInt>(AI->getArraySize()); - if (!C) - return 0; - Size *= C->getZExtValue(); - } - return Size; -} - -bool SafeStack::IsAccessSafe(Value *Addr, uint64_t AccessSize, - const Value *AllocaPtr, uint64_t AllocaSize) { - AllocaOffsetRewriter Rewriter(*SE, AllocaPtr); - const SCEV *Expr = Rewriter.visit(SE->getSCEV(Addr)); - - uint64_t BitWidth = SE->getTypeSizeInBits(Expr->getType()); - ConstantRange AccessStartRange = SE->getUnsignedRange(Expr); - ConstantRange SizeRange = - ConstantRange(APInt(BitWidth, 0), APInt(BitWidth, AccessSize)); - ConstantRange AccessRange = AccessStartRange.add(SizeRange); - ConstantRange AllocaRange = - ConstantRange(APInt(BitWidth, 0), APInt(BitWidth, AllocaSize)); - bool Safe = AllocaRange.contains(AccessRange); - - DEBUG(dbgs() << "[SafeStack] " - << (isa<AllocaInst>(AllocaPtr) ? "Alloca " : "ByValArgument ") - << *AllocaPtr << "\n" - << " Access " << *Addr << "\n" - << " SCEV " << *Expr - << " U: " << SE->getUnsignedRange(Expr) - << ", S: " << SE->getSignedRange(Expr) << "\n" - << " Range " << AccessRange << "\n" - << " AllocaRange " << AllocaRange << "\n" - << " " << (Safe ? "safe" : "unsafe") << "\n"); - - return Safe; -} - -bool SafeStack::IsMemIntrinsicSafe(const MemIntrinsic *MI, const Use &U, - const Value *AllocaPtr, - uint64_t AllocaSize) { - // All MemIntrinsics have destination address in Arg0 and size in Arg2. - if (MI->getRawDest() != U) return true; - const auto *Len = dyn_cast<ConstantInt>(MI->getLength()); - // Non-constant size => unsafe. FIXME: try SCEV getRange. - if (!Len) return false; - return IsAccessSafe(U, Len->getZExtValue(), AllocaPtr, AllocaSize); -} - -/// Check whether a given allocation must be put on the safe -/// stack or not. The function analyzes all uses of AI and checks whether it is -/// only accessed in a memory safe way (as decided statically). -bool SafeStack::IsSafeStackAlloca(const Value *AllocaPtr, uint64_t AllocaSize) { - // Go through all uses of this alloca and check whether all accesses to the - // allocated object are statically known to be memory safe and, hence, the - // object can be placed on the safe stack. - SmallPtrSet<const Value *, 16> Visited; - SmallVector<const Value *, 8> WorkList; - WorkList.push_back(AllocaPtr); - - // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc. 
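The core of IsAccessSafe above is a range-containment test: SCEV bounds the access start offset relative to the alloca, and the whole accessed byte range must fit inside the allocation. Stripped of the SCEV machinery, the check is roughly:

#include <cstdint>

// Safe only if every possible access [Start, Start + AccessSize) stays inside
// [0, AllocaSize); MaxStartOffset is the largest start SCEV cannot rule out.
bool accessFitsInAlloca(uint64_t MaxStartOffset, uint64_t AccessSize,
                        uint64_t AllocaSize) {
  if (MaxStartOffset + AccessSize < MaxStartOffset)
    return false;                       // overflow means safety cannot be proven
  return MaxStartOffset + AccessSize <= AllocaSize;
}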
- while (!WorkList.empty()) { - const Value *V = WorkList.pop_back_val(); - for (const Use &UI : V->uses()) { - auto I = cast<const Instruction>(UI.getUser()); - assert(V == UI.get()); - - switch (I->getOpcode()) { - case Instruction::Load: { - if (!IsAccessSafe(UI, DL->getTypeStoreSize(I->getType()), AllocaPtr, - AllocaSize)) - return false; - break; - } - case Instruction::VAArg: - // "va-arg" from a pointer is safe. - break; - case Instruction::Store: { - if (V == I->getOperand(0)) { - // Stored the pointer - conservatively assume it may be unsafe. - DEBUG(dbgs() << "[SafeStack] Unsafe alloca: " << *AllocaPtr - << "\n store of address: " << *I << "\n"); - return false; - } - - if (!IsAccessSafe(UI, DL->getTypeStoreSize(I->getOperand(0)->getType()), - AllocaPtr, AllocaSize)) - return false; - break; - } - case Instruction::Ret: { - // Information leak. - return false; - } - - case Instruction::Call: - case Instruction::Invoke: { - ImmutableCallSite CS(I); - - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) - continue; - } - - if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) { - if (!IsMemIntrinsicSafe(MI, UI, AllocaPtr, AllocaSize)) { - DEBUG(dbgs() << "[SafeStack] Unsafe alloca: " << *AllocaPtr - << "\n unsafe memintrinsic: " << *I - << "\n"); - return false; - } - continue; - } - - // LLVM 'nocapture' attribute is only set for arguments whose address - // is not stored, passed around, or used in any other non-trivial way. - // We assume that passing a pointer to an object as a 'nocapture - // readnone' argument is safe. - // FIXME: a more precise solution would require an interprocedural - // analysis here, which would look at all uses of an argument inside - // the function being called. - ImmutableCallSite::arg_iterator B = CS.arg_begin(), E = CS.arg_end(); - for (ImmutableCallSite::arg_iterator A = B; A != E; ++A) - if (A->get() == V) - if (!(CS.doesNotCapture(A - B) && (CS.doesNotAccessMemory(A - B) || - CS.doesNotAccessMemory()))) { - DEBUG(dbgs() << "[SafeStack] Unsafe alloca: " << *AllocaPtr - << "\n unsafe call: " << *I << "\n"); - return false; - } - continue; - } - - default: - if (Visited.insert(I).second) - WorkList.push_back(cast<const Instruction>(I)); - } - } - } - - // All uses of the alloca are safe, we can place it on the safe stack. - return true; -} - -Value *SafeStack::getOrCreateUnsafeStackPtr(IRBuilder<> &IRB, Function &F) { - // Check if there is a target-specific location for the unsafe stack pointer. - if (TL) - if (Value *V = TL->getSafeStackPointerLocation(IRB)) - return V; - - // Otherwise, assume the target links with compiler-rt, which provides a - // thread-local variable with a magic name. - Module &M = *F.getParent(); - const char *UnsafeStackPtrVar = "__safestack_unsafe_stack_ptr"; - auto UnsafeStackPtr = - dyn_cast_or_null<GlobalVariable>(M.getNamedValue(UnsafeStackPtrVar)); - - bool UseTLS = USPStorage == ThreadLocalUSP; - - if (!UnsafeStackPtr) { - auto TLSModel = UseTLS ? - GlobalValue::InitialExecTLSModel : - GlobalValue::NotThreadLocal; - // The global variable is not defined yet, define it ourselves. - // We use the initial-exec TLS model because we do not support the - // variable living anywhere other than in the main executable. 
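getOrCreateUnsafeStackPtr, continued below, ties the pass to a runtime-provided pointer; when the target has no dedicated location it falls back to the thread-local variable named here. The runtime half of that contract, sketched under the assumption that compiler-rt supplies the actual definition:

#include <cstddef>

extern "C" __thread void *__safestack_unsafe_stack_ptr; // provided by the runtime

// Illustrative shape of the code the pass emits around a frame that needs
// unsafe storage: carve the frame in the prologue, restore it on return.
void frame_sketch(size_t UnsafeBytes) {
  void *Saved = __safestack_unsafe_stack_ptr;      // load the current top
  __safestack_unsafe_stack_ptr =
      static_cast<char *>(Saved) - UnsafeBytes;    // the unsafe stack grows down
  // ... unsafe locals live just below Saved ...
  __safestack_unsafe_stack_ptr = Saved;            // restore in the epilogue
}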
- UnsafeStackPtr = new GlobalVariable( - M, StackPtrTy, false, GlobalValue::ExternalLinkage, nullptr, - UnsafeStackPtrVar, nullptr, TLSModel); - } else { - // The variable exists, check its type and attributes. - if (UnsafeStackPtr->getValueType() != StackPtrTy) - report_fatal_error(Twine(UnsafeStackPtrVar) + " must have void* type"); - if (UseTLS != UnsafeStackPtr->isThreadLocal()) - report_fatal_error(Twine(UnsafeStackPtrVar) + " must " + - (UseTLS ? "" : "not ") + "be thread-local"); - } - return UnsafeStackPtr; -} - -void SafeStack::findInsts(Function &F, - SmallVectorImpl<AllocaInst *> &StaticAllocas, - SmallVectorImpl<AllocaInst *> &DynamicAllocas, - SmallVectorImpl<Argument *> &ByValArguments, - SmallVectorImpl<ReturnInst *> &Returns, - SmallVectorImpl<Instruction *> &StackRestorePoints) { - for (Instruction &I : instructions(&F)) { - if (auto AI = dyn_cast<AllocaInst>(&I)) { - ++NumAllocas; - - uint64_t Size = getStaticAllocaAllocationSize(AI); - if (IsSafeStackAlloca(AI, Size)) - continue; - - if (AI->isStaticAlloca()) { - ++NumUnsafeStaticAllocas; - StaticAllocas.push_back(AI); - } else { - ++NumUnsafeDynamicAllocas; - DynamicAllocas.push_back(AI); - } - } else if (auto RI = dyn_cast<ReturnInst>(&I)) { - Returns.push_back(RI); - } else if (auto CI = dyn_cast<CallInst>(&I)) { - // setjmps require stack restore. - if (CI->getCalledFunction() && CI->canReturnTwice()) - StackRestorePoints.push_back(CI); - } else if (auto LP = dyn_cast<LandingPadInst>(&I)) { - // Exception landing pads require stack restore. - StackRestorePoints.push_back(LP); - } else if (auto II = dyn_cast<IntrinsicInst>(&I)) { - if (II->getIntrinsicID() == Intrinsic::gcroot) - llvm::report_fatal_error( - "gcroot intrinsic not compatible with safestack attribute"); - } - } - for (Argument &Arg : F.args()) { - if (!Arg.hasByValAttr()) - continue; - uint64_t Size = - DL->getTypeStoreSize(Arg.getType()->getPointerElementType()); - if (IsSafeStackAlloca(&Arg, Size)) - continue; - - ++NumUnsafeByValArguments; - ByValArguments.push_back(&Arg); - } -} - -AllocaInst * -SafeStack::createStackRestorePoints(IRBuilder<> &IRB, Function &F, - ArrayRef<Instruction *> StackRestorePoints, - Value *StaticTop, bool NeedDynamicTop) { - if (StackRestorePoints.empty()) - return nullptr; - - // We need the current value of the shadow stack pointer to restore - // after longjmp or exception catching. - - // FIXME: On some platforms this could be handled by the longjmp/exception - // runtime itself. - - AllocaInst *DynamicTop = nullptr; - if (NeedDynamicTop) - // If we also have dynamic alloca's, the stack pointer value changes - // throughout the function. For now we store it in an alloca. - DynamicTop = IRB.CreateAlloca(StackPtrTy, /*ArraySize=*/nullptr, - "unsafe_stack_dynamic_ptr"); - - if (!StaticTop) - // We need the original unsafe stack pointer value, even if there are - // no unsafe static allocas. - StaticTop = IRB.CreateLoad(UnsafeStackPtr, false, "unsafe_stack_ptr"); - - if (NeedDynamicTop) - IRB.CreateStore(StaticTop, DynamicTop); - - // Restore current stack pointer after longjmp/exception catch. - for (Instruction *I : StackRestorePoints) { - ++NumUnsafeStackRestorePoints; - - IRB.SetInsertPoint(I->getNextNode()); - Value *CurrentTop = DynamicTop ? 
IRB.CreateLoad(DynamicTop) : StaticTop; - IRB.CreateStore(CurrentTop, UnsafeStackPtr); - } - - return DynamicTop; -} - -Value *SafeStack::moveStaticAllocasToUnsafeStack( - IRBuilder<> &IRB, Function &F, ArrayRef<AllocaInst *> StaticAllocas, - ArrayRef<Argument *> ByValArguments, ArrayRef<ReturnInst *> Returns) { - if (StaticAllocas.empty() && ByValArguments.empty()) - return nullptr; - - DIBuilder DIB(*F.getParent()); - - // We explicitly compute and set the unsafe stack layout for all unsafe - // static alloca instructions. We save the unsafe "base pointer" in the - // prologue into a local variable and restore it in the epilogue. - - // Load the current stack pointer (we'll also use it as a base pointer). - // FIXME: use a dedicated register for it ? - Instruction *BasePointer = - IRB.CreateLoad(UnsafeStackPtr, false, "unsafe_stack_ptr"); - assert(BasePointer->getType() == StackPtrTy); - - for (ReturnInst *RI : Returns) { - IRB.SetInsertPoint(RI); - IRB.CreateStore(BasePointer, UnsafeStackPtr); - } - - // Compute maximum alignment among static objects on the unsafe stack. - unsigned MaxAlignment = 0; - for (Argument *Arg : ByValArguments) { - Type *Ty = Arg->getType()->getPointerElementType(); - unsigned Align = std::max((unsigned)DL->getPrefTypeAlignment(Ty), - Arg->getParamAlignment()); - if (Align > MaxAlignment) - MaxAlignment = Align; - } - for (AllocaInst *AI : StaticAllocas) { - Type *Ty = AI->getAllocatedType(); - unsigned Align = - std::max((unsigned)DL->getPrefTypeAlignment(Ty), AI->getAlignment()); - if (Align > MaxAlignment) - MaxAlignment = Align; - } - - if (MaxAlignment > StackAlignment) { - // Re-align the base pointer according to the max requested alignment. - assert(isPowerOf2_32(MaxAlignment)); - IRB.SetInsertPoint(BasePointer->getNextNode()); - BasePointer = cast<Instruction>(IRB.CreateIntToPtr( - IRB.CreateAnd(IRB.CreatePtrToInt(BasePointer, IntPtrTy), - ConstantInt::get(IntPtrTy, ~uint64_t(MaxAlignment - 1))), - StackPtrTy)); - } - - int64_t StaticOffset = 0; // Current stack top. - IRB.SetInsertPoint(BasePointer->getNextNode()); - - for (Argument *Arg : ByValArguments) { - Type *Ty = Arg->getType()->getPointerElementType(); - - uint64_t Size = DL->getTypeStoreSize(Ty); - if (Size == 0) - Size = 1; // Don't create zero-sized stack objects. - - // Ensure the object is properly aligned. - unsigned Align = std::max((unsigned)DL->getPrefTypeAlignment(Ty), - Arg->getParamAlignment()); - - // Add alignment. - // NOTE: we ensure that BasePointer itself is aligned to >= Align. - StaticOffset += Size; - StaticOffset = RoundUpToAlignment(StaticOffset, Align); - - Value *Off = IRB.CreateGEP(BasePointer, // BasePointer is i8* - ConstantInt::get(Int32Ty, -StaticOffset)); - Value *NewArg = IRB.CreateBitCast(Off, Arg->getType(), - Arg->getName() + ".unsafe-byval"); - - // Replace alloc with the new location. - replaceDbgDeclare(Arg, BasePointer, BasePointer->getNextNode(), DIB, - /*Deref=*/true, -StaticOffset); - Arg->replaceAllUsesWith(NewArg); - IRB.SetInsertPoint(cast<Instruction>(NewArg)->getNextNode()); - IRB.CreateMemCpy(Off, Arg, Size, Arg->getParamAlignment()); - } - - // Allocate space for every unsafe static AllocaInst on the unsafe stack. - for (AllocaInst *AI : StaticAllocas) { - IRB.SetInsertPoint(AI); - - Type *Ty = AI->getAllocatedType(); - uint64_t Size = getStaticAllocaAllocationSize(AI); - if (Size == 0) - Size = 1; // Don't create zero-sized stack objects. - - // Ensure the object is properly aligned. 
- unsigned Align = - std::max((unsigned)DL->getPrefTypeAlignment(Ty), AI->getAlignment()); - - // Add alignment. - // NOTE: we ensure that BasePointer itself is aligned to >= Align. - StaticOffset += Size; - StaticOffset = RoundUpToAlignment(StaticOffset, Align); - - Value *Off = IRB.CreateGEP(BasePointer, // BasePointer is i8* - ConstantInt::get(Int32Ty, -StaticOffset)); - Value *NewAI = IRB.CreateBitCast(Off, AI->getType(), AI->getName()); - if (AI->hasName() && isa<Instruction>(NewAI)) - cast<Instruction>(NewAI)->takeName(AI); - - // Replace alloc with the new location. - replaceDbgDeclareForAlloca(AI, BasePointer, DIB, /*Deref=*/true, -StaticOffset); - AI->replaceAllUsesWith(NewAI); - AI->eraseFromParent(); - } - - // Re-align BasePointer so that our callees would see it aligned as - // expected. - // FIXME: no need to update BasePointer in leaf functions. - StaticOffset = RoundUpToAlignment(StaticOffset, StackAlignment); - - // Update shadow stack pointer in the function epilogue. - IRB.SetInsertPoint(BasePointer->getNextNode()); - - Value *StaticTop = - IRB.CreateGEP(BasePointer, ConstantInt::get(Int32Ty, -StaticOffset), - "unsafe_stack_static_top"); - IRB.CreateStore(StaticTop, UnsafeStackPtr); - return StaticTop; -} - -void SafeStack::moveDynamicAllocasToUnsafeStack( - Function &F, Value *UnsafeStackPtr, AllocaInst *DynamicTop, - ArrayRef<AllocaInst *> DynamicAllocas) { - DIBuilder DIB(*F.getParent()); - - for (AllocaInst *AI : DynamicAllocas) { - IRBuilder<> IRB(AI); - - // Compute the new SP value (after AI). - Value *ArraySize = AI->getArraySize(); - if (ArraySize->getType() != IntPtrTy) - ArraySize = IRB.CreateIntCast(ArraySize, IntPtrTy, false); - - Type *Ty = AI->getAllocatedType(); - uint64_t TySize = DL->getTypeAllocSize(Ty); - Value *Size = IRB.CreateMul(ArraySize, ConstantInt::get(IntPtrTy, TySize)); - - Value *SP = IRB.CreatePtrToInt(IRB.CreateLoad(UnsafeStackPtr), IntPtrTy); - SP = IRB.CreateSub(SP, Size); - - // Align the SP value to satisfy the AllocaInst, type and stack alignments. - unsigned Align = std::max( - std::max((unsigned)DL->getPrefTypeAlignment(Ty), AI->getAlignment()), - (unsigned)StackAlignment); - - assert(isPowerOf2_32(Align)); - Value *NewTop = IRB.CreateIntToPtr( - IRB.CreateAnd(SP, ConstantInt::get(IntPtrTy, ~uint64_t(Align - 1))), - StackPtrTy); - - // Save the stack pointer. - IRB.CreateStore(NewTop, UnsafeStackPtr); - if (DynamicTop) - IRB.CreateStore(NewTop, DynamicTop); - - Value *NewAI = IRB.CreatePointerCast(NewTop, AI->getType()); - if (AI->hasName() && isa<Instruction>(NewAI)) - NewAI->takeName(AI); - - replaceDbgDeclareForAlloca(AI, NewAI, DIB, /*Deref=*/true); - AI->replaceAllUsesWith(NewAI); - AI->eraseFromParent(); - } - - if (!DynamicAllocas.empty()) { - // Now go through the instructions again, replacing stacksave/stackrestore. 
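The static frame layout performed just above in moveStaticAllocasToUnsafeStack amounts to one bookkeeping step per object: reserve its bytes below the saved base pointer, round the running offset up to the object's alignment, and address the object with a negative offset from the base. A self-contained model of that step:

#include <cstdint>

uint64_t placeUnsafeObject(uint64_t StaticOffset, uint64_t Size, uint64_t Align,
                           int64_t &ObjectOffsetFromBase) {
  StaticOffset += Size;                                       // reserve the bytes
  StaticOffset = (StaticOffset + Align - 1) / Align * Align;  // round up to Align
  ObjectOffsetFromBase = -static_cast<int64_t>(StaticOffset); // negative GEP offset
  return StaticOffset;                                        // running frame size
}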
- for (inst_iterator It = inst_begin(&F), Ie = inst_end(&F); It != Ie;) { - Instruction *I = &*(It++); - auto II = dyn_cast<IntrinsicInst>(I); - if (!II) - continue; - - if (II->getIntrinsicID() == Intrinsic::stacksave) { - IRBuilder<> IRB(II); - Instruction *LI = IRB.CreateLoad(UnsafeStackPtr); - LI->takeName(II); - II->replaceAllUsesWith(LI); - II->eraseFromParent(); - } else if (II->getIntrinsicID() == Intrinsic::stackrestore) { - IRBuilder<> IRB(II); - Instruction *SI = IRB.CreateStore(II->getArgOperand(0), UnsafeStackPtr); - SI->takeName(II); - assert(II->use_empty()); - II->eraseFromParent(); - } - } - } -} - -bool SafeStack::runOnFunction(Function &F) { - DEBUG(dbgs() << "[SafeStack] Function: " << F.getName() << "\n"); - - if (!F.hasFnAttribute(Attribute::SafeStack)) { - DEBUG(dbgs() << "[SafeStack] safestack is not requested" - " for this function\n"); - return false; - } - - if (F.isDeclaration()) { - DEBUG(dbgs() << "[SafeStack] function definition" - " is not available\n"); - return false; - } - - TL = TM ? TM->getSubtargetImpl(F)->getTargetLowering() : nullptr; - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - - { - // Make sure the regular stack protector won't run on this function - // (safestack attribute takes precedence). - AttrBuilder B; - B.addAttribute(Attribute::StackProtect) - .addAttribute(Attribute::StackProtectReq) - .addAttribute(Attribute::StackProtectStrong); - F.removeAttributes( - AttributeSet::FunctionIndex, - AttributeSet::get(F.getContext(), AttributeSet::FunctionIndex, B)); - } - - ++NumFunctions; - - SmallVector<AllocaInst *, 16> StaticAllocas; - SmallVector<AllocaInst *, 4> DynamicAllocas; - SmallVector<Argument *, 4> ByValArguments; - SmallVector<ReturnInst *, 4> Returns; - - // Collect all points where stack gets unwound and needs to be restored - // This is only necessary because the runtime (setjmp and unwind code) is - // not aware of the unsafe stack and won't unwind/restore it prorerly. - // To work around this problem without changing the runtime, we insert - // instrumentation to restore the unsafe stack pointer when necessary. - SmallVector<Instruction *, 4> StackRestorePoints; - - // Find all static and dynamic alloca instructions that must be moved to the - // unsafe stack, all return instructions and stack restore points. - findInsts(F, StaticAllocas, DynamicAllocas, ByValArguments, Returns, - StackRestorePoints); - - if (StaticAllocas.empty() && DynamicAllocas.empty() && - ByValArguments.empty() && StackRestorePoints.empty()) - return false; // Nothing to do in this function. - - if (!StaticAllocas.empty() || !DynamicAllocas.empty() || - !ByValArguments.empty()) - ++NumUnsafeStackFunctions; // This function has the unsafe stack. - - if (!StackRestorePoints.empty()) - ++NumUnsafeStackRestorePointsFunctions; - - IRBuilder<> IRB(&F.front(), F.begin()->getFirstInsertionPt()); - UnsafeStackPtr = getOrCreateUnsafeStackPtr(IRB, F); - - // The top of the unsafe stack after all unsafe static allocas are allocated. - Value *StaticTop = moveStaticAllocasToUnsafeStack(IRB, F, StaticAllocas, - ByValArguments, Returns); - - // Safe stack object that stores the current unsafe stack top. It is updated - // as unsafe dynamic (non-constant-sized) allocas are allocated and freed. - // This is only needed if we need to restore stack pointer after longjmp - // or exceptions, and we have dynamic allocations. - // FIXME: a better alternative might be to store the unsafe stack pointer - // before setjmp / invoke instructions. 
- AllocaInst *DynamicTop = createStackRestorePoints( - IRB, F, StackRestorePoints, StaticTop, !DynamicAllocas.empty()); - - // Handle dynamic allocas. - moveDynamicAllocasToUnsafeStack(F, UnsafeStackPtr, DynamicTop, - DynamicAllocas); - - DEBUG(dbgs() << "[SafeStack] safestack applied\n"); - return true; -} - -} // anonymous namespace - -char SafeStack::ID = 0; -INITIALIZE_TM_PASS_BEGIN(SafeStack, "safe-stack", - "Safe Stack instrumentation pass", false, false) -INITIALIZE_TM_PASS_END(SafeStack, "safe-stack", - "Safe Stack instrumentation pass", false, false) - -FunctionPass *llvm::createSafeStackPass(const llvm::TargetMachine *TM) { - return new SafeStack(TM); -} diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 09de7a2cda2b..7d404473655d 100644 --- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -28,13 +28,15 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/EHPersonalities.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" @@ -45,6 +47,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -53,22 +56,28 @@ using namespace llvm; #define DEBUG_TYPE "sancov" -static const char *const kSanCovModuleInitName = "__sanitizer_cov_module_init"; -static const char *const kSanCovName = "__sanitizer_cov"; -static const char *const kSanCovWithCheckName = "__sanitizer_cov_with_check"; -static const char *const kSanCovIndirCallName = "__sanitizer_cov_indir_call16"; -static const char *const kSanCovTraceEnter = "__sanitizer_cov_trace_func_enter"; -static const char *const kSanCovTraceBB = "__sanitizer_cov_trace_basic_block"; -static const char *const kSanCovTraceCmp = "__sanitizer_cov_trace_cmp"; -static const char *const kSanCovTraceSwitch = "__sanitizer_cov_trace_switch"; -static const char *const kSanCovModuleCtorName = "sancov.module_ctor"; -static const uint64_t kSanCtorAndDtorPriority = 2; - -static cl::opt<int> ClCoverageLevel("sanitizer-coverage-level", - cl::desc("Sanitizer Coverage. 
0: none, 1: entry block, 2: all blocks, " - "3: all blocks and critical edges, " - "4: above plus indirect calls"), - cl::Hidden, cl::init(0)); +static const char *const SanCovModuleInitName = "__sanitizer_cov_module_init"; +static const char *const SanCovName = "__sanitizer_cov"; +static const char *const SanCovWithCheckName = "__sanitizer_cov_with_check"; +static const char *const SanCovIndirCallName = "__sanitizer_cov_indir_call16"; +static const char *const SanCovTracePCIndirName = + "__sanitizer_cov_trace_pc_indir"; +static const char *const SanCovTraceEnterName = + "__sanitizer_cov_trace_func_enter"; +static const char *const SanCovTraceBBName = + "__sanitizer_cov_trace_basic_block"; +static const char *const SanCovTracePCName = "__sanitizer_cov_trace_pc"; +static const char *const SanCovTraceCmpName = "__sanitizer_cov_trace_cmp"; +static const char *const SanCovTraceSwitchName = "__sanitizer_cov_trace_switch"; +static const char *const SanCovModuleCtorName = "sancov.module_ctor"; +static const uint64_t SanCtorAndDtorPriority = 2; + +static cl::opt<int> ClCoverageLevel( + "sanitizer-coverage-level", + cl::desc("Sanitizer Coverage. 0: none, 1: entry block, 2: all blocks, " + "3: all blocks and critical edges, " + "4: above plus indirect calls"), + cl::Hidden, cl::init(0)); static cl::opt<unsigned> ClCoverageBlockThreshold( "sanitizer-coverage-block-threshold", @@ -82,12 +91,21 @@ static cl::opt<bool> "callbacks at every basic block"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClExperimentalTracePC("sanitizer-coverage-trace-pc", + cl::desc("Experimental pc tracing"), + cl::Hidden, cl::init(false)); + static cl::opt<bool> ClExperimentalCMPTracing("sanitizer-coverage-experimental-trace-compares", cl::desc("Experimental tracing of CMP and similar " "instructions"), cl::Hidden, cl::init(false)); +static cl::opt<bool> + ClPruneBlocks("sanitizer-coverage-prune-blocks", + cl::desc("Reduce the number of instrumented blocks"), + cl::Hidden, cl::init(true)); + // Experimental 8-bit counters used as an additional search heuristic during // coverage-guided fuzzing. 
// The counters are not thread-friendly: @@ -131,22 +149,28 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) { Options.TraceBB |= ClExperimentalTracing; Options.TraceCmp |= ClExperimentalCMPTracing; Options.Use8bitCounters |= ClUse8bitCounters; + Options.TracePC |= ClExperimentalTracePC; return Options; } class SanitizerCoverageModule : public ModulePass { - public: +public: SanitizerCoverageModule( const SanitizerCoverageOptions &Options = SanitizerCoverageOptions()) - : ModulePass(ID), Options(OverrideFromCL(Options)) {} + : ModulePass(ID), Options(OverrideFromCL(Options)) { + initializeSanitizerCoverageModulePass(*PassRegistry::getPassRegistry()); + } bool runOnModule(Module &M) override; bool runOnFunction(Function &F); - static char ID; // Pass identification, replacement for typeid - const char *getPassName() const override { - return "SanitizerCoverageModule"; + static char ID; // Pass identification, replacement for typeid + const char *getPassName() const override { return "SanitizerCoverageModule"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); } - private: +private: void InjectCoverageForIndirectCalls(Function &F, ArrayRef<Instruction *> IndirCalls); void InjectTraceForCmp(Function &F, ArrayRef<Instruction *> CmpTraceTargets); @@ -162,8 +186,8 @@ class SanitizerCoverageModule : public ModulePass { } Function *SanCovFunction; Function *SanCovWithCheckFunction; - Function *SanCovIndirCallFunction; - Function *SanCovTraceEnter, *SanCovTraceBB; + Function *SanCovIndirCallFunction, *SanCovTracePCIndir; + Function *SanCovTraceEnter, *SanCovTraceBB, *SanCovTracePC; Function *SanCovTraceCmpFunction; Function *SanCovTraceSwitchFunction; InlineAsm *EmptyAsm; @@ -178,7 +202,7 @@ class SanitizerCoverageModule : public ModulePass { SanitizerCoverageOptions Options; }; -} // namespace +} // namespace bool SanitizerCoverageModule::runOnModule(Module &M) { if (Options.CoverageType == SanitizerCoverageOptions::SCK_None) @@ -195,28 +219,32 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { Int64Ty = IRB.getInt64Ty(); SanCovFunction = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kSanCovName, VoidTy, Int32PtrTy, nullptr)); + M.getOrInsertFunction(SanCovName, VoidTy, Int32PtrTy, nullptr)); SanCovWithCheckFunction = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kSanCovWithCheckName, VoidTy, Int32PtrTy, nullptr)); + M.getOrInsertFunction(SanCovWithCheckName, VoidTy, Int32PtrTy, nullptr)); + SanCovTracePCIndir = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy, nullptr)); SanCovIndirCallFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kSanCovIndirCallName, VoidTy, IntptrTy, IntptrTy, nullptr)); + SanCovIndirCallName, VoidTy, IntptrTy, IntptrTy, nullptr)); SanCovTraceCmpFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kSanCovTraceCmp, VoidTy, Int64Ty, Int64Ty, Int64Ty, nullptr)); + SanCovTraceCmpName, VoidTy, Int64Ty, Int64Ty, Int64Ty, nullptr)); SanCovTraceSwitchFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kSanCovTraceSwitch, VoidTy, Int64Ty, Int64PtrTy, nullptr)); + SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy, nullptr)); // We insert an empty inline asm after cov callbacks to avoid callback merge. 
EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), StringRef(""), StringRef(""), /*hasSideEffects=*/true); + SanCovTracePC = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(SanCovTracePCName, VoidTy, nullptr)); SanCovTraceEnter = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kSanCovTraceEnter, VoidTy, Int32PtrTy, nullptr)); + M.getOrInsertFunction(SanCovTraceEnterName, VoidTy, Int32PtrTy, nullptr)); SanCovTraceBB = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kSanCovTraceBB, VoidTy, Int32PtrTy, nullptr)); + M.getOrInsertFunction(SanCovTraceBBName, VoidTy, Int32PtrTy, nullptr)); // At this point we create a dummy array of guards because we don't // know how many elements we will need. @@ -243,7 +271,6 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { M, Int32ArrayNTy, false, GlobalValue::PrivateLinkage, Constant::getNullValue(Int32ArrayNTy), "__sancov_gen_cov"); - // Replace the dummy array with the real one. GuardArray->replaceAllUsesWith( IRB.CreatePointerCast(RealGuardArray, Int32PtrTy)); @@ -252,13 +279,12 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { GlobalVariable *RealEightBitCounterArray; if (Options.Use8bitCounters) { // Make sure the array is 16-aligned. - static const int kCounterAlignment = 16; - Type *Int8ArrayNTy = - ArrayType::get(Int8Ty, RoundUpToAlignment(N, kCounterAlignment)); + static const int CounterAlignment = 16; + Type *Int8ArrayNTy = ArrayType::get(Int8Ty, alignTo(N, CounterAlignment)); RealEightBitCounterArray = new GlobalVariable( M, Int8ArrayNTy, false, GlobalValue::PrivateLinkage, Constant::getNullValue(Int8ArrayNTy), "__sancov_gen_cov_counter"); - RealEightBitCounterArray->setAlignment(kCounterAlignment); + RealEightBitCounterArray->setAlignment(CounterAlignment); EightBitCounterArray->replaceAllUsesWith( IRB.CreatePointerCast(RealEightBitCounterArray, Int8PtrTy)); EightBitCounterArray->eraseFromParent(); @@ -271,26 +297,64 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { new GlobalVariable(M, ModNameStrConst->getType(), true, GlobalValue::PrivateLinkage, ModNameStrConst); - Function *CtorFunc; - std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( - M, kSanCovModuleCtorName, kSanCovModuleInitName, - {Int32PtrTy, IntptrTy, Int8PtrTy, Int8PtrTy}, - {IRB.CreatePointerCast(RealGuardArray, Int32PtrTy), - ConstantInt::get(IntptrTy, N), - Options.Use8bitCounters - ? IRB.CreatePointerCast(RealEightBitCounterArray, Int8PtrTy) - : Constant::getNullValue(Int8PtrTy), - IRB.CreatePointerCast(ModuleName, Int8PtrTy)}); + if (!Options.TracePC) { + Function *CtorFunc; + std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( + M, SanCovModuleCtorName, SanCovModuleInitName, + {Int32PtrTy, IntptrTy, Int8PtrTy, Int8PtrTy}, + {IRB.CreatePointerCast(RealGuardArray, Int32PtrTy), + ConstantInt::get(IntptrTy, N), + Options.Use8bitCounters + ? IRB.CreatePointerCast(RealEightBitCounterArray, Int8PtrTy) + : Constant::getNullValue(Int8PtrTy), + IRB.CreatePointerCast(ModuleName, Int8PtrTy)}); + + appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); + } + + return true; +} + +// True if block has successors and it dominates all of them. 
+static bool isFullDominator(const BasicBlock *BB, const DominatorTree *DT) { + if (succ_begin(BB) == succ_end(BB)) + return false; + + for (const BasicBlock *SUCC : make_range(succ_begin(BB), succ_end(BB))) { + if (!DT->dominates(BB, SUCC)) + return false; + } + + return true; +} - appendToGlobalCtors(M, CtorFunc, kSanCtorAndDtorPriority); +// True if block has predecessors and it postdominates all of them. +static bool isFullPostDominator(const BasicBlock *BB, + const PostDominatorTree *PDT) { + if (pred_begin(BB) == pred_end(BB)) + return false; + + for (const BasicBlock *PRED : make_range(pred_begin(BB), pred_end(BB))) { + if (!PDT->dominates(BB, PRED)) + return false; + } return true; } +static bool shouldInstrumentBlock(const Function& F, const BasicBlock *BB, const DominatorTree *DT, + const PostDominatorTree *PDT) { + if (!ClPruneBlocks || &F.getEntryBlock() == BB) + return true; + + return !(isFullDominator(BB, DT) || isFullPostDominator(BB, PDT)); +} + bool SanitizerCoverageModule::runOnFunction(Function &F) { - if (F.empty()) return false; + if (F.empty()) + return false; if (F.getName().find(".module_ctor") != std::string::npos) - return false; // Should not instrument sanitizer init functions. + return false; // Should not instrument sanitizer init functions. // Don't instrument functions using SEH for now. Splitting basic blocks like // we do for coverage breaks WinEHPrepare. // FIXME: Remove this when SEH no longer uses landingpad pattern matching. @@ -299,12 +363,19 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { return false; if (Options.CoverageType >= SanitizerCoverageOptions::SCK_Edge) SplitAllCriticalEdges(F); - SmallVector<Instruction*, 8> IndirCalls; - SmallVector<BasicBlock*, 16> AllBlocks; - SmallVector<Instruction*, 8> CmpTraceTargets; - SmallVector<Instruction*, 8> SwitchTraceTargets; + SmallVector<Instruction *, 8> IndirCalls; + SmallVector<BasicBlock *, 16> BlocksToInstrument; + SmallVector<Instruction *, 8> CmpTraceTargets; + SmallVector<Instruction *, 8> SwitchTraceTargets; + + const DominatorTree *DT = + &getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); + const PostDominatorTree *PDT = + &getAnalysis<PostDominatorTreeWrapperPass>(F).getPostDomTree(); + for (auto &BB : F) { - AllBlocks.push_back(&BB); + if (shouldInstrumentBlock(F, &BB, DT, PDT)) + BlocksToInstrument.push_back(&BB); for (auto &Inst : BB) { if (Options.IndirectCalls) { CallSite CS(&Inst); @@ -319,7 +390,8 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { } } } - InjectCoverage(F, AllBlocks); + + InjectCoverage(F, BlocksToInstrument); InjectCoverageForIndirectCalls(F, IndirCalls); InjectTraceForCmp(F, CmpTraceTargets); InjectTraceForSwitch(F, SwitchTraceTargets); @@ -346,28 +418,34 @@ bool SanitizerCoverageModule::InjectCoverage(Function &F, // On every indirect call we call a run-time function // __sanitizer_cov_indir_call* with two parameters: // - callee address, -// - global cache array that contains kCacheSize pointers (zero-initialized). +// - global cache array that contains CacheSize pointers (zero-initialized). // The cache is used to speed up recording the caller-callee pairs. // The address of the caller is passed implicitly via caller PC. -// kCacheSize is encoded in the name of the run-time function. +// CacheSize is encoded in the name of the run-time function. 
void SanitizerCoverageModule::InjectCoverageForIndirectCalls( Function &F, ArrayRef<Instruction *> IndirCalls) { - if (IndirCalls.empty()) return; - const int kCacheSize = 16; - const int kCacheAlignment = 64; // Align for better performance. - Type *Ty = ArrayType::get(IntptrTy, kCacheSize); + if (IndirCalls.empty()) + return; + const int CacheSize = 16; + const int CacheAlignment = 64; // Align for better performance. + Type *Ty = ArrayType::get(IntptrTy, CacheSize); for (auto I : IndirCalls) { IRBuilder<> IRB(I); CallSite CS(I); Value *Callee = CS.getCalledValue(); - if (isa<InlineAsm>(Callee)) continue; + if (isa<InlineAsm>(Callee)) + continue; GlobalVariable *CalleeCache = new GlobalVariable( *F.getParent(), Ty, false, GlobalValue::PrivateLinkage, Constant::getNullValue(Ty), "__sancov_gen_callee_cache"); - CalleeCache->setAlignment(kCacheAlignment); - IRB.CreateCall(SanCovIndirCallFunction, - {IRB.CreatePointerCast(Callee, IntptrTy), - IRB.CreatePointerCast(CalleeCache, IntptrTy)}); + CalleeCache->setAlignment(CacheAlignment); + if (Options.TracePC) + IRB.CreateCall(SanCovTracePCIndir, + IRB.CreatePointerCast(Callee, IntptrTy)); + else + IRB.CreateCall(SanCovIndirCallFunction, + {IRB.CreatePointerCast(Callee, IntptrTy), + IRB.CreatePointerCast(CalleeCache, IntptrTy)}); } } @@ -376,7 +454,7 @@ void SanitizerCoverageModule::InjectCoverageForIndirectCalls( // {NumCases, ValueSizeInBits, Case0Value, Case1Value, Case2Value, ... }) void SanitizerCoverageModule::InjectTraceForSwitch( - Function &F, ArrayRef<Instruction *> SwitchTraceTargets) { + Function &, ArrayRef<Instruction *> SwitchTraceTargets) { for (auto I : SwitchTraceTargets) { if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { IRBuilder<> IRB(I); @@ -391,7 +469,7 @@ void SanitizerCoverageModule::InjectTraceForSwitch( if (Cond->getType()->getScalarSizeInBits() < Int64Ty->getScalarSizeInBits()) Cond = IRB.CreateIntCast(Cond, Int64Ty, false); - for (auto It: SI->cases()) { + for (auto It : SI->cases()) { Constant *C = It.getCaseValue(); if (C->getType()->getScalarSizeInBits() < Int64Ty->getScalarSizeInBits()) @@ -409,15 +487,15 @@ void SanitizerCoverageModule::InjectTraceForSwitch( } } - void SanitizerCoverageModule::InjectTraceForCmp( - Function &F, ArrayRef<Instruction *> CmpTraceTargets) { + Function &, ArrayRef<Instruction *> CmpTraceTargets) { for (auto I : CmpTraceTargets) { if (ICmpInst *ICMP = dyn_cast<ICmpInst>(I)) { IRBuilder<> IRB(ICMP); Value *A0 = ICMP->getOperand(0); Value *A1 = ICMP->getOperand(1); - if (!A0->getType()->isIntegerTy()) continue; + if (!A0->getType()->isIntegerTy()) + continue; uint64_t TypeSize = DL->getTypeStoreSizeInBits(A0->getType()); // __sanitizer_cov_trace_cmp((type_size << 32) | predicate, A0, A1); IRB.CreateCall( @@ -430,8 +508,8 @@ void SanitizerCoverageModule::InjectTraceForCmp( } void SanitizerCoverageModule::SetNoSanitizeMetadata(Instruction *I) { - I->setMetadata( - I->getModule()->getMDKindID("nosanitize"), MDNode::get(*C, None)); + I->setMetadata(I->getModule()->getMDKindID("nosanitize"), + MDNode::get(*C, None)); } void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, @@ -448,7 +526,7 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, bool IsEntryBB = &BB == &F.getEntryBlock(); DebugLoc EntryLoc; if (IsEntryBB) { - if (auto SP = getDISubprogram(&F)) + if (auto SP = F.getSubprogram()) EntryLoc = DebugLoc::get(SP->getScopeLine(), 0, SP); // Keep static allocas and llvm.localescape calls in the entry block. 
Even // if we aren't splitting the block, it's nice for allocas to be before @@ -465,16 +543,20 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, ConstantInt::get(IntptrTy, (1 + NumberOfInstrumentedBlocks()) * 4)); Type *Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); GuardP = IRB.CreateIntToPtr(GuardP, Int32PtrTy); - if (Options.TraceBB) { + if (Options.TracePC) { + IRB.CreateCall(SanCovTracePC); // gets the PC using GET_CALLER_PC. + IRB.CreateCall(EmptyAsm, {}); // Avoids callback merge. + } else if (Options.TraceBB) { IRB.CreateCall(IsEntryBB ? SanCovTraceEnter : SanCovTraceBB, GuardP); } else if (UseCalls) { IRB.CreateCall(SanCovWithCheckFunction, GuardP); } else { LoadInst *Load = IRB.CreateLoad(GuardP); - Load->setAtomic(Monotonic); + Load->setAtomic(AtomicOrdering::Monotonic); Load->setAlignment(4); SetNoSanitizeMetadata(Load); - Value *Cmp = IRB.CreateICmpSGE(Constant::getNullValue(Load->getType()), Load); + Value *Cmp = + IRB.CreateICmpSGE(Constant::getNullValue(Load->getType()), Load); Instruction *Ins = SplitBlockAndInsertIfThen( Cmp, &*IP, false, MDBuilder(*C).createBranchWeights(1, 100000)); IRB.SetInsertPoint(Ins); @@ -499,9 +581,16 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, } char SanitizerCoverageModule::ID = 0; -INITIALIZE_PASS(SanitizerCoverageModule, "sancov", - "SanitizerCoverage: TODO." - "ModulePass", false, false) +INITIALIZE_PASS_BEGIN(SanitizerCoverageModule, "sancov", + "SanitizerCoverage: TODO." + "ModulePass", + false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_END(SanitizerCoverageModule, "sancov", + "SanitizerCoverage: TODO." + "ModulePass", + false, false) ModulePass *llvm::createSanitizerCoverageModulePass( const SanitizerCoverageOptions &Options) { return new SanitizerCoverageModule(Options); diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 9331e1d2b3fd..dcb62d3ed1b5 100644 --- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" @@ -36,11 +37,13 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; @@ -81,6 +84,7 @@ namespace { struct ThreadSanitizer : public FunctionPass { ThreadSanitizer() : FunctionPass(ID) {} const char *getPassName() const override; + void getAnalysisUsage(AnalysisUsage &AU) const override; bool runOnFunction(Function &F) override; bool doInitialization(Module &M) override; static char ID; // Pass identification, replacement for typeid. 
@@ -121,7 +125,13 @@ struct ThreadSanitizer : public FunctionPass { } // namespace char ThreadSanitizer::ID = 0; -INITIALIZE_PASS(ThreadSanitizer, "tsan", +INITIALIZE_PASS_BEGIN( + ThreadSanitizer, "tsan", + "ThreadSanitizer: detects data races.", + false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END( + ThreadSanitizer, "tsan", "ThreadSanitizer: detects data races.", false, false) @@ -129,6 +139,10 @@ const char *ThreadSanitizer::getPassName() const { return "ThreadSanitizer"; } +void ThreadSanitizer::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<TargetLibraryInfoWrapperPass>(); +} + FunctionPass *llvm::createThreadSanitizerPass() { return new ThreadSanitizer(); } @@ -243,6 +257,37 @@ static bool isVtableAccess(Instruction *I) { return false; } +// Do not instrument known races/"benign races" that come from compiler +// instrumentatin. The user has no way of suppressing them. +static bool shouldInstrumentReadWriteFromAddress(Value *Addr) { + // Peel off GEPs and BitCasts. + Addr = Addr->stripInBoundsOffsets(); + + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) { + if (GV->hasSection()) { + StringRef SectionName = GV->getSection(); + // Check if the global is in the PGO counters section. + if (SectionName.endswith(getInstrProfCountersSectionName( + /*AddSegment=*/false))) + return false; + } + + // Check if the global is in a GCOV counter array. + if (GV->getName().startswith("__llvm_gcov_ctr")) + return false; + } + + // Do not instrument acesses from different address spaces; we cannot deal + // with them. + if (Addr) { + Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType()); + if (PtrTy->getPointerAddressSpace() != 0) + return false; + } + + return true; +} + bool ThreadSanitizer::addrPointsToConstantData(Value *Addr) { // If this is a GEP, just analyze its pointer operand. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr)) @@ -281,14 +326,17 @@ void ThreadSanitizer::chooseInstructionsToInstrument( const DataLayout &DL) { SmallSet<Value*, 8> WriteTargets; // Iterate from the end. - for (SmallVectorImpl<Instruction*>::reverse_iterator It = Local.rbegin(), - E = Local.rend(); It != E; ++It) { - Instruction *I = *It; + for (Instruction *I : reverse(Local)) { if (StoreInst *Store = dyn_cast<StoreInst>(I)) { - WriteTargets.insert(Store->getPointerOperand()); + Value *Addr = Store->getPointerOperand(); + if (!shouldInstrumentReadWriteFromAddress(Addr)) + continue; + WriteTargets.insert(Addr); } else { LoadInst *Load = cast<LoadInst>(I); Value *Addr = Load->getPointerOperand(); + if (!shouldInstrumentReadWriteFromAddress(Addr)) + continue; if (WriteTargets.count(Addr)) { // We will write to this temp, so no reason to analyze the read. NumOmittedReadsBeforeWrite++; @@ -344,6 +392,8 @@ bool ThreadSanitizer::runOnFunction(Function &F) { bool HasCalls = false; bool SanitizeFunction = F.hasFnAttribute(Attribute::SanitizeThread); const DataLayout &DL = F.getParent()->getDataLayout(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); // Traverse all instructions, collect loads/stores/returns, check for calls. 
for (auto &BB : F) { @@ -355,6 +405,8 @@ bool ThreadSanitizer::runOnFunction(Function &F) { else if (isa<ReturnInst>(Inst)) RetVec.push_back(&Inst); else if (isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) { + if (CallInst *CI = dyn_cast<CallInst>(&Inst)) + maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI); if (isa<MemIntrinsic>(Inst)) MemIntrinCalls.push_back(&Inst); HasCalls = true; @@ -456,14 +508,16 @@ bool ThreadSanitizer::instrumentLoadOrStore(Instruction *I, static ConstantInt *createOrdering(IRBuilder<> *IRB, AtomicOrdering ord) { uint32_t v = 0; switch (ord) { - case NotAtomic: llvm_unreachable("unexpected atomic ordering!"); - case Unordered: // Fall-through. - case Monotonic: v = 0; break; - // case Consume: v = 1; break; // Not specified yet. - case Acquire: v = 2; break; - case Release: v = 3; break; - case AcquireRelease: v = 4; break; - case SequentiallyConsistent: v = 5; break; + case AtomicOrdering::NotAtomic: + llvm_unreachable("unexpected atomic ordering!"); + case AtomicOrdering::Unordered: // Fall-through. + case AtomicOrdering::Monotonic: v = 0; break; + // Not specified yet: + // case AtomicOrdering::Consume: v = 1; break; + case AtomicOrdering::Acquire: v = 2; break; + case AtomicOrdering::Release: v = 3; break; + case AtomicOrdering::AcquireRelease: v = 4; break; + case AtomicOrdering::SequentiallyConsistent: v = 5; break; } return IRB->getInt32(v); } @@ -496,6 +550,11 @@ bool ThreadSanitizer::instrumentMemIntrinsic(Instruction *I) { return false; } +static Value *createIntOrPtrToIntCast(Value *V, Type* Ty, IRBuilder<> &IRB) { + return isa<PointerType>(V->getType()) ? + IRB.CreatePtrToInt(V, Ty) : IRB.CreateIntCast(V, Ty, false); +} + // Both llvm and ThreadSanitizer atomic operations are based on C++11/C1x // standards. For background see C++11 standard. A slightly older, publicly // available draft of the standard (not entirely up-to-date, but close enough @@ -517,9 +576,16 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { Type *PtrTy = Ty->getPointerTo(); Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), createOrdering(&IRB, LI->getOrdering())}; - CallInst *C = CallInst::Create(TsanAtomicLoad[Idx], Args); - ReplaceInstWithInst(I, C); - + Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType(); + if (Ty == OrigTy) { + Instruction *C = CallInst::Create(TsanAtomicLoad[Idx], Args); + ReplaceInstWithInst(I, C); + } else { + // We are loading a pointer, so we need to cast the return value. 
+ Value *C = IRB.CreateCall(TsanAtomicLoad[Idx], Args); + Instruction *Cast = CastInst::Create(Instruction::IntToPtr, C, OrigTy); + ReplaceInstWithInst(I, Cast); + } } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { Value *Addr = SI->getPointerOperand(); int Idx = getMemoryAccessFuncIndex(Addr, DL); @@ -530,7 +596,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize); Type *PtrTy = Ty->getPointerTo(); Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), - IRB.CreateIntCast(SI->getValueOperand(), Ty, false), + createIntOrPtrToIntCast(SI->getValueOperand(), Ty, IRB), createOrdering(&IRB, SI->getOrdering())}; CallInst *C = CallInst::Create(TsanAtomicStore[Idx], Args); ReplaceInstWithInst(I, C); @@ -560,15 +626,26 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { const unsigned BitSize = ByteSize * 8; Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize); Type *PtrTy = Ty->getPointerTo(); + Value *CmpOperand = + createIntOrPtrToIntCast(CASI->getCompareOperand(), Ty, IRB); + Value *NewOperand = + createIntOrPtrToIntCast(CASI->getNewValOperand(), Ty, IRB); Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), - IRB.CreateIntCast(CASI->getCompareOperand(), Ty, false), - IRB.CreateIntCast(CASI->getNewValOperand(), Ty, false), + CmpOperand, + NewOperand, createOrdering(&IRB, CASI->getSuccessOrdering()), createOrdering(&IRB, CASI->getFailureOrdering())}; CallInst *C = IRB.CreateCall(TsanAtomicCAS[Idx], Args); - Value *Success = IRB.CreateICmpEQ(C, CASI->getCompareOperand()); + Value *Success = IRB.CreateICmpEQ(C, CmpOperand); + Value *OldVal = C; + Type *OrigOldValTy = CASI->getNewValOperand()->getType(); + if (Ty != OrigOldValTy) { + // The value is a pointer, so we need to cast the return value. + OldVal = IRB.CreateIntToPtr(C, OrigOldValTy); + } - Value *Res = IRB.CreateInsertValue(UndefValue::get(CASI->getType()), C, 0); + Value *Res = + IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0); Res = IRB.CreateInsertValue(Res, Success, 1); I->replaceAllUsesWith(Res); diff --git a/lib/Transforms/Makefile b/lib/Transforms/Makefile deleted file mode 100644 index c390517d07cd..000000000000 --- a/lib/Transforms/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -##===- lib/Transforms/Makefile -----------------------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../.. -PARALLEL_DIRS = Utils Instrumentation Scalar InstCombine IPO Vectorize Hello ObjCARC - -include $(LEVEL)/Makefile.config - -# No support for plugins on windows targets -ifeq ($(HOST_OS), $(filter $(HOST_OS), Cygwin MingW Minix)) - PARALLEL_DIRS := $(filter-out Hello, $(PARALLEL_DIRS)) -endif - -include $(LEVEL)/Makefile.common diff --git a/lib/Transforms/ObjCARC/BlotMapVector.h b/lib/Transforms/ObjCARC/BlotMapVector.h index d6439b698418..ef075bdccbfe 100644 --- a/lib/Transforms/ObjCARC/BlotMapVector.h +++ b/lib/Transforms/ObjCARC/BlotMapVector.h @@ -31,7 +31,7 @@ public: const_iterator begin() const { return Vector.begin(); } const_iterator end() const { return Vector.end(); } -#ifdef XDEBUG +#ifdef EXPENSIVE_CHECKS ~BlotMapVector() { assert(Vector.size() >= Map.size()); // May differ due to blotting. 
for (typename MapTy::const_iterator I = Map.begin(), E = Map.end(); I != E; diff --git a/lib/Transforms/ObjCARC/DependencyAnalysis.h b/lib/Transforms/ObjCARC/DependencyAnalysis.h index 8e042d47ee6e..8cc1232b18ca 100644 --- a/lib/Transforms/ObjCARC/DependencyAnalysis.h +++ b/lib/Transforms/ObjCARC/DependencyAnalysis.h @@ -24,6 +24,7 @@ #define LLVM_LIB_TRANSFORMS_OBJCARC_DEPENDENCYANALYSIS_H #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/ObjCARCInstKind.h" namespace llvm { class BasicBlock; diff --git a/lib/Transforms/ObjCARC/Makefile b/lib/Transforms/ObjCARC/Makefile deleted file mode 100644 index 2a34e21714f1..000000000000 --- a/lib/Transforms/ObjCARC/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Transforms/ObjCARC/Makefile ---------------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. -LIBRARYNAME = LLVMObjCARCOpts -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common - diff --git a/lib/Transforms/ObjCARC/ObjCARC.cpp b/lib/Transforms/ObjCARC/ObjCARC.cpp index d860723bb460..688dd12c408a 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.cpp +++ b/lib/Transforms/ObjCARC/ObjCARC.cpp @@ -17,7 +17,6 @@ #include "llvm-c/Core.h" #include "llvm-c/Initialization.h" #include "llvm/InitializePasses.h" -#include "llvm/Support/CommandLine.h" namespace llvm { class PassRegistry; diff --git a/lib/Transforms/ObjCARC/ObjCARC.h b/lib/Transforms/ObjCARC/ObjCARC.h index 5fd45b00af17..f02b75f0b456 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.h +++ b/lib/Transforms/ObjCARC/ObjCARC.h @@ -24,7 +24,6 @@ #define LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H #include "llvm/ADT/StringSwitch.h" -#include "llvm/ADT/Optional.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/ObjCARCAnalysisUtils.h" #include "llvm/Analysis/ObjCARCInstKind.h" diff --git a/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp b/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp index 969e77c1f888..b2c62a0e8eeb 100644 --- a/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp @@ -70,7 +70,7 @@ void ObjCARCAPElim::getAnalysisUsage(AnalysisUsage &AU) const { /// possibly produce autoreleases. bool ObjCARCAPElim::MayAutorelease(ImmutableCallSite CS, unsigned Depth) { if (const Function *Callee = CS.getCalledFunction()) { - if (Callee->isDeclaration() || Callee->mayBeOverridden()) + if (!Callee->hasExactDefinition()) return true; for (const BasicBlock &BB : *Callee) { for (const Instruction &I : BB) @@ -132,6 +132,9 @@ bool ObjCARCAPElim::runOnModule(Module &M) { if (!ModuleHasARC(M)) return false; + if (skipModule(M)) + return false; + // Find the llvm.global_ctors variable, as the first step in // identifying the global constructors. In theory, unnecessary autorelease // pools could occur anywhere, but in practice it's pretty rare. Global diff --git a/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/lib/Transforms/ObjCARC/ObjCARCContract.cpp index 1cdf5689f42a..11e2d03e17d9 100644 --- a/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -66,7 +66,7 @@ namespace { /// The inline asm string to insert between calls and RetainRV calls to make /// the optimization work on targets which need it. - const MDString *RetainRVMarker; + const MDString *RVInstMarker; /// The set of inserted objc_storeStrong calls. 
If at the end of walking the /// function we have found no alloca instructions, these calls can be marked @@ -201,6 +201,7 @@ static StoreInst *findSafeStoreForStoreStrongContraction(LoadInst *Load, // Get the location associated with Load. MemoryLocation Loc = MemoryLocation::get(Load); + auto *LocPtr = Loc.Ptr->stripPointerCasts(); // Walk down to find the store and the release, which may be in either order. for (auto I = std::next(BasicBlock::iterator(Load)), @@ -261,7 +262,7 @@ static StoreInst *findSafeStoreForStoreStrongContraction(LoadInst *Load, // Then make sure that the pointer we are storing to is Ptr. If so, we // found our Store! - if (Store->getPointerOperand() == Loc.Ptr) + if (Store->getPointerOperand()->stripPointerCasts() == LocPtr) continue; // Otherwise, we have an unknown store to some other ptr that clobbers @@ -423,20 +424,20 @@ bool ObjCARCContract::tryToPeepholeInstruction( return false; // If we succeed in our optimization, fall through. // FALLTHROUGH - case ARCInstKind::RetainRV: { + case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: { // If we're compiling for a target which needs a special inline-asm - // marker to do the retainAutoreleasedReturnValue optimization, - // insert it now. - if (!RetainRVMarker) + // marker to do the return value optimization, insert it now. + if (!RVInstMarker) return false; BasicBlock::iterator BBI = Inst->getIterator(); BasicBlock *InstParent = Inst->getParent(); - // Step up to see if the call immediately precedes the RetainRV call. + // Step up to see if the call immediately precedes the RV call. // If it's an invoke, we have to cross a block boundary. And we have // to carefully dodge no-op instructions. do { - if (&*BBI == InstParent->begin()) { + if (BBI == InstParent->begin()) { BasicBlock *Pred = InstParent->getSinglePredecessor(); if (!Pred) goto decline_rv_optimization; @@ -447,14 +448,14 @@ bool ObjCARCContract::tryToPeepholeInstruction( } while (IsNoopInstruction(&*BBI)); if (&*BBI == GetArgRCIdentityRoot(Inst)) { - DEBUG(dbgs() << "Adding inline asm marker for " - "retainAutoreleasedReturnValue optimization.\n"); + DEBUG(dbgs() << "Adding inline asm marker for the return value " + "optimization.\n"); Changed = true; - InlineAsm *IA = - InlineAsm::get(FunctionType::get(Type::getVoidTy(Inst->getContext()), - /*isVarArg=*/false), - RetainRVMarker->getString(), - /*Constraints=*/"", /*hasSideEffects=*/true); + InlineAsm *IA = InlineAsm::get( + FunctionType::get(Type::getVoidTy(Inst->getContext()), + /*isVarArg=*/false), + RVInstMarker->getString(), + /*Constraints=*/"", /*hasSideEffects=*/true); CallInst::Create(IA, "", Inst); } decline_rv_optimization: @@ -605,7 +606,7 @@ bool ObjCARCContract::runOnFunction(Function &F) { cast<GEPOperator>(Arg)->hasAllZeroIndices()) Arg = cast<GEPOperator>(Arg)->getPointerOperand(); else if (isa<GlobalAlias>(Arg) && - !cast<GlobalAlias>(Arg)->mayBeOverridden()) + !cast<GlobalAlias>(Arg)->isInterposable()) Arg = cast<GlobalAlias>(Arg)->getAliasee(); else break; @@ -650,15 +651,15 @@ bool ObjCARCContract::doInitialization(Module &M) { EP.init(&M); - // Initialize RetainRVMarker. - RetainRVMarker = nullptr; + // Initialize RVInstMarker. 
+ RVInstMarker = nullptr; if (NamedMDNode *NMD = M.getNamedMetadata("clang.arc.retainAutoreleasedReturnValueMarker")) if (NMD->getNumOperands() == 1) { const MDNode *N = NMD->getOperand(0); if (N->getNumOperands() == 1) if (const MDString *S = dyn_cast<MDString>(N->getOperand(0))) - RetainRVMarker = S; + RVInstMarker = S; } return false; diff --git a/lib/Transforms/ObjCARC/ObjCARCExpand.cpp b/lib/Transforms/ObjCARC/ObjCARCExpand.cpp index 53c19c39f97f..bb6a0a0e73db 100644 --- a/lib/Transforms/ObjCARC/ObjCARCExpand.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCExpand.cpp @@ -24,7 +24,6 @@ //===----------------------------------------------------------------------===// #include "ObjCARC.h" -#include "llvm/ADT/StringRef.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" diff --git a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index f0ee6e2be487..a6907b56cf45 100644 --- a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -889,6 +889,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { Inst->getParent(), Inst, DependingInstructions, Visited, PA); break; + case ARCInstKind::ClaimRV: case ARCInstKind::RetainRV: case ARCInstKind::AutoreleaseRV: // Don't move these; the RV optimization depends on the autoreleaseRV @@ -1459,17 +1460,13 @@ bool ObjCARCOpt::Visit(Function &F, // Use reverse-postorder on the reverse CFG for bottom-up. bool BottomUpNestingDetected = false; - for (SmallVectorImpl<BasicBlock *>::const_reverse_iterator I = - ReverseCFGPostOrder.rbegin(), E = ReverseCFGPostOrder.rend(); - I != E; ++I) - BottomUpNestingDetected |= VisitBottomUp(*I, BBStates, Retains); + for (BasicBlock *BB : reverse(ReverseCFGPostOrder)) + BottomUpNestingDetected |= VisitBottomUp(BB, BBStates, Retains); // Use reverse-postorder for top-down. bool TopDownNestingDetected = false; - for (SmallVectorImpl<BasicBlock *>::const_reverse_iterator I = - PostOrder.rbegin(), E = PostOrder.rend(); - I != E; ++I) - TopDownNestingDetected |= VisitTopDown(*I, BBStates, Releases); + for (BasicBlock *BB : reverse(PostOrder)) + TopDownNestingDetected |= VisitTopDown(BB, BBStates, Releases); return TopDownNestingDetected && BottomUpNestingDetected; } @@ -1554,9 +1551,7 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( unsigned NewCount = 0; bool FirstRelease = true; for (;;) { - for (SmallVectorImpl<Instruction *>::const_iterator - NI = NewRetains.begin(), NE = NewRetains.end(); NI != NE; ++NI) { - Instruction *NewRetain = *NI; + for (Instruction *NewRetain : NewRetains) { auto It = Retains.find(NewRetain); assert(It != Retains.end()); const RRInfo &NewRetainRRI = It->second; @@ -1630,9 +1625,7 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( if (NewReleases.empty()) break; // Back the other way. 
- for (SmallVectorImpl<Instruction *>::const_iterator - NI = NewReleases.begin(), NE = NewReleases.end(); NI != NE; ++NI) { - Instruction *NewRelease = *NI; + for (Instruction *NewRelease : NewReleases) { auto It = Releases.find(NewRelease); assert(It != Releases.end()); const RRInfo &NewReleaseRRI = It->second; diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp index 590a52da6b19..0eed0240c741 100644 --- a/lib/Transforms/Scalar/ADCE.cpp +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -22,10 +22,12 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Pass.h" +#include "llvm/ProfileData/InstrProf.h" #include "llvm/Transforms/Scalar.h" using namespace llvm; @@ -33,22 +35,70 @@ using namespace llvm; STATISTIC(NumRemoved, "Number of instructions removed"); +static void collectLiveScopes(const DILocalScope &LS, + SmallPtrSetImpl<const Metadata *> &AliveScopes) { + if (!AliveScopes.insert(&LS).second) + return; + + if (isa<DISubprogram>(LS)) + return; + + // Tail-recurse through the scope chain. + collectLiveScopes(cast<DILocalScope>(*LS.getScope()), AliveScopes); +} + +static void collectLiveScopes(const DILocation &DL, + SmallPtrSetImpl<const Metadata *> &AliveScopes) { + // Even though DILocations are not scopes, shove them into AliveScopes so we + // don't revisit them. + if (!AliveScopes.insert(&DL).second) + return; + + // Collect live scopes from the scope chain. + collectLiveScopes(*DL.getScope(), AliveScopes); + + // Tail-recurse through the inlined-at chain. + if (const DILocation *IA = DL.getInlinedAt()) + collectLiveScopes(*IA, AliveScopes); +} + +// Check if this instruction is a runtime call for value profiling and +// if it's instrumenting a constant. +static bool isInstrumentsConstant(Instruction &I) { + if (CallInst *CI = dyn_cast<CallInst>(&I)) + if (Function *Callee = CI->getCalledFunction()) + if (Callee->getName().equals(getInstrProfValueProfFuncName())) + if (isa<Constant>(CI->getArgOperand(0))) + return true; + return false; +} + static bool aggressiveDCE(Function& F) { - SmallPtrSet<Instruction*, 128> Alive; + SmallPtrSet<Instruction*, 32> Alive; SmallVector<Instruction*, 128> Worklist; // Collect the set of "root" instructions that are known live. for (Instruction &I : instructions(F)) { - if (isa<TerminatorInst>(I) || isa<DbgInfoIntrinsic>(I) || I.isEHPad() || - I.mayHaveSideEffects()) { + if (isa<TerminatorInst>(I) || I.isEHPad() || I.mayHaveSideEffects()) { + // Skip any value profile instrumentation calls if they are + // instrumenting constants. + if (isInstrumentsConstant(I)) + continue; Alive.insert(&I); Worklist.push_back(&I); } } - // Propagate liveness backwards to operands. + // Propagate liveness backwards to operands. Keep track of live debug info + // scopes. + SmallPtrSet<const Metadata *, 32> AliveScopes; while (!Worklist.empty()) { Instruction *Curr = Worklist.pop_back_val(); + + // Collect the live debug info scopes attached to this instruction. + if (const DILocation *DL = Curr->getDebugLoc()) + collectLiveScopes(*DL, AliveScopes); + for (Use &OI : Curr->operands()) { if (Instruction *Inst = dyn_cast<Instruction>(OI)) if (Alive.insert(Inst).second) @@ -61,10 +111,30 @@ static bool aggressiveDCE(Function& F) { // value of the function, and may therefore be deleted safely. 
// NOTE: We reuse the Worklist vector here for memory efficiency. for (Instruction &I : instructions(F)) { - if (!Alive.count(&I)) { - Worklist.push_back(&I); - I.dropAllReferences(); + // Check if the instruction is alive. + if (Alive.count(&I)) + continue; + + if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&I)) { + // Check if the scope of this variable location is alive. + if (AliveScopes.count(DII->getDebugLoc()->getScope())) + continue; + + // Fallthrough and drop the intrinsic. + DEBUG({ + // If intrinsic is pointing at a live SSA value, there may be an + // earlier optimization bug: if we know the location of the variable, + // why isn't the scope of the location alive? + if (Value *V = DII->getVariableLocation()) + if (Instruction *II = dyn_cast<Instruction>(V)) + if (Alive.count(II)) + dbgs() << "Dropping debug info for " << *DII << "\n"; + }); } + + // Prepare to delete. + Worklist.push_back(&I); + I.dropAllReferences(); } for (Instruction *&I : Worklist) { @@ -75,10 +145,14 @@ static bool aggressiveDCE(Function& F) { return !Worklist.empty(); } -PreservedAnalyses ADCEPass::run(Function &F) { - if (aggressiveDCE(F)) - return PreservedAnalyses::none(); - return PreservedAnalyses::all(); +PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &) { + if (!aggressiveDCE(F)) + return PreservedAnalyses::all(); + + // FIXME: This should also 'preserve the CFG'. + auto PA = PreservedAnalyses(); + PA.preserve<GlobalsAA>(); + return PA; } namespace { @@ -89,7 +163,7 @@ struct ADCELegacyPass : public FunctionPass { } bool runOnFunction(Function& F) override { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; return aggressiveDCE(F); } diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index 4b721d38adba..7f8b8ce91e79 100644 --- a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -18,6 +18,7 @@ #define AA_NAME "alignment-from-assumptions" #define DEBUG_TYPE AA_NAME +#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h" #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" @@ -25,13 +26,11 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instruction.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" @@ -67,18 +66,7 @@ struct AlignmentFromAssumptions : public FunctionPass { AU.addPreserved<ScalarEvolutionWrapperPass>(); } - // For memory transfers, we need a common alignment for both the source and - // destination. If we have a new alignment for only one operand of a transfer - // instruction, save it in these maps. If we reach the other operand through - // another assumption later, then we may change the alignment at that point. 
- DenseMap<MemTransferInst *, unsigned> NewDestAlignments, NewSrcAlignments; - - ScalarEvolution *SE; - DominatorTree *DT; - - bool extractAlignmentInfo(CallInst *I, Value *&AAPtr, const SCEV *&AlignSCEV, - const SCEV *&OffSCEV); - bool processAssumption(CallInst *I); + AlignmentFromAssumptionsPass Impl; }; } @@ -209,9 +197,10 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, return 0; } -bool AlignmentFromAssumptions::extractAlignmentInfo(CallInst *I, - Value *&AAPtr, const SCEV *&AlignSCEV, - const SCEV *&OffSCEV) { +bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I, + Value *&AAPtr, + const SCEV *&AlignSCEV, + const SCEV *&OffSCEV) { // An alignment assume must be a statement about the least-significant // bits of the pointer being zero, possibly with some offset. ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0)); @@ -302,7 +291,7 @@ bool AlignmentFromAssumptions::extractAlignmentInfo(CallInst *I, return true; } -bool AlignmentFromAssumptions::processAssumption(CallInst *ACall) { +bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { Value *AAPtr; const SCEV *AlignSCEV, *OffSCEV; if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV)) @@ -411,14 +400,26 @@ bool AlignmentFromAssumptions::processAssumption(CallInst *ACall) { } bool AlignmentFromAssumptions::runOnFunction(Function &F) { - bool Changed = false; + if (skipFunction(F)) + return false; + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + + return Impl.runImpl(F, AC, SE, DT); +} + +bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC, + ScalarEvolution *SE_, + DominatorTree *DT_) { + SE = SE_; + DT = DT_; NewDestAlignments.clear(); NewSrcAlignments.clear(); + bool Changed = false; for (auto &AssumeVH : AC.assumptions()) if (AssumeVH) Changed |= processAssumption(cast<CallInst>(AssumeVH)); @@ -426,3 +427,20 @@ bool AlignmentFromAssumptions::runOnFunction(Function &F) { return Changed; } +PreservedAnalyses +AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) { + + AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); + ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + bool Changed = runImpl(F, AC, &SE, &DT); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<AAManager>(); + PA.preserve<ScalarEvolutionAnalysis>(); + PA.preserve<GlobalsAA>(); + PA.preserve<LoopAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + return PA; +} diff --git a/lib/Transforms/Scalar/BDCE.cpp b/lib/Transforms/Scalar/BDCE.cpp index cb9b8b6fffc8..4f6225f4c7b0 100644 --- a/lib/Transforms/Scalar/BDCE.cpp +++ b/lib/Transforms/Scalar/BDCE.cpp @@ -14,11 +14,11 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/BDCE.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/DemandedBits.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/IR/CFG.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" @@ -27,6 +27,7 @@ #include 
"llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" using namespace llvm; #define DEBUG_TYPE "bdce" @@ -34,35 +35,7 @@ using namespace llvm; STATISTIC(NumRemoved, "Number of instructions removed (unused)"); STATISTIC(NumSimplified, "Number of instructions trivialized (dead bits)"); -namespace { -struct BDCE : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - BDCE() : FunctionPass(ID) { - initializeBDCEPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function& F) override; - - void getAnalysisUsage(AnalysisUsage& AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DemandedBits>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } -}; -} - -char BDCE::ID = 0; -INITIALIZE_PASS_BEGIN(BDCE, "bdce", "Bit-Tracking Dead Code Elimination", - false, false) -INITIALIZE_PASS_DEPENDENCY(DemandedBits) -INITIALIZE_PASS_END(BDCE, "bdce", "Bit-Tracking Dead Code Elimination", - false, false) - -bool BDCE::runOnFunction(Function& F) { - if (skipOptnoneFunction(F)) - return false; - DemandedBits &DB = getAnalysis<DemandedBits>(); - +static bool bitTrackingDCE(Function &F, DemandedBits &DB) { SmallVector<Instruction*, 128> Worklist; bool Changed = false; for (Instruction &I : instructions(F)) { @@ -96,7 +69,44 @@ bool BDCE::runOnFunction(Function& F) { return Changed; } -FunctionPass *llvm::createBitTrackingDCEPass() { - return new BDCE(); +PreservedAnalyses BDCEPass::run(Function &F, FunctionAnalysisManager &AM) { + auto &DB = AM.getResult<DemandedBitsAnalysis>(F); + if (!bitTrackingDCE(F, DB)) + return PreservedAnalyses::all(); + + // FIXME: This should also 'preserve the CFG'. + auto PA = PreservedAnalyses(); + PA.preserve<GlobalsAA>(); + return PA; } +namespace { +struct BDCELegacyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + BDCELegacyPass() : FunctionPass(ID) { + initializeBDCELegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + auto &DB = getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); + return bitTrackingDCE(F, DB); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<DemandedBitsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; +} + +char BDCELegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(BDCELegacyPass, "bdce", + "Bit-Tracking Dead Code Elimination", false, false) +INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) +INITIALIZE_PASS_END(BDCELegacyPass, "bdce", + "Bit-Tracking Dead Code Elimination", false, false) + +FunctionPass *llvm::createBitTrackingDCEPass() { return new BDCELegacyPass(); } diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index a0ddbd085206..9f04344b8b0a 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -10,13 +10,16 @@ add_llvm_library(LLVMScalarOpts EarlyCSE.cpp FlattenCFGPass.cpp Float2Int.cpp + GuardWidening.cpp GVN.cpp + GVNHoist.cpp InductiveRangeCheckElimination.cpp IndVarSimplify.cpp JumpThreading.cpp LICM.cpp LoadCombine.cpp LoopDeletion.cpp + LoopDataPrefetch.cpp LoopDistribute.cpp LoopIdiomRecognize.cpp LoopInstSimplify.cpp @@ -24,11 +27,14 @@ add_llvm_library(LLVMScalarOpts LoopLoadElimination.cpp LoopRerollPass.cpp LoopRotation.cpp + LoopSimplifyCFG.cpp LoopStrengthReduce.cpp LoopUnrollPass.cpp LoopUnswitch.cpp + LoopVersioningLICM.cpp 
LowerAtomic.cpp LowerExpectIntrinsic.cpp + LowerGuardIntrinsic.cpp MemCpyOptimizer.cpp MergedLoadStoreMotion.cpp NaryReassociate.cpp @@ -40,7 +46,6 @@ add_llvm_library(LLVMScalarOpts SCCP.cpp SROA.cpp Scalar.cpp - ScalarReplAggregates.cpp Scalarizer.cpp SeparateConstOffsetFromGEP.cpp SimplifyCFGPass.cpp diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp index 84f7f5fff5b5..913e939c2bd4 100644 --- a/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -33,20 +33,20 @@ // %0 = load i64* inttoptr (i64 big_constant to i64*) //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/ConstantHoisting.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include <tuple> using namespace llvm; +using namespace consthoist; #define DEBUG_TYPE "consthoist" @@ -54,75 +54,12 @@ STATISTIC(NumConstantsHoisted, "Number of constants hoisted"); STATISTIC(NumConstantsRebased, "Number of constants rebased"); namespace { -struct ConstantUser; -struct RebasedConstantInfo; - -typedef SmallVector<ConstantUser, 8> ConstantUseListType; -typedef SmallVector<RebasedConstantInfo, 4> RebasedConstantListType; - -/// \brief Keeps track of the user of a constant and the operand index where the -/// constant is used. -struct ConstantUser { - Instruction *Inst; - unsigned OpndIdx; - - ConstantUser(Instruction *Inst, unsigned Idx) : Inst(Inst), OpndIdx(Idx) { } -}; - -/// \brief Keeps track of a constant candidate and its uses. -struct ConstantCandidate { - ConstantUseListType Uses; - ConstantInt *ConstInt; - unsigned CumulativeCost; - - ConstantCandidate(ConstantInt *ConstInt) - : ConstInt(ConstInt), CumulativeCost(0) { } - - /// \brief Add the user to the use list and update the cost. - void addUser(Instruction *Inst, unsigned Idx, unsigned Cost) { - CumulativeCost += Cost; - Uses.push_back(ConstantUser(Inst, Idx)); - } -}; - -/// \brief This represents a constant that has been rebased with respect to a -/// base constant. The difference to the base constant is recorded in Offset. -struct RebasedConstantInfo { - ConstantUseListType Uses; - Constant *Offset; - - RebasedConstantInfo(ConstantUseListType &&Uses, Constant *Offset) - : Uses(std::move(Uses)), Offset(Offset) { } -}; - -/// \brief A base constant and all its rebased constants. -struct ConstantInfo { - ConstantInt *BaseConstant; - RebasedConstantListType RebasedConstants; -}; - /// \brief The constant hoisting pass. -class ConstantHoisting : public FunctionPass { - typedef DenseMap<ConstantInt *, unsigned> ConstCandMapType; - typedef std::vector<ConstantCandidate> ConstCandVecType; - - const TargetTransformInfo *TTI; - DominatorTree *DT; - BasicBlock *Entry; - - /// Keeps track of constant candidates found in the function. - ConstCandVecType ConstCandVec; - - /// Keep track of cast instructions we already cloned. - SmallDenseMap<Instruction *, Instruction *> ClonedCastMap; - - /// These are the final constants we decided to hoist. 
- SmallVector<ConstantInfo, 8> ConstantVec; +class ConstantHoistingLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid - ConstantHoisting() : FunctionPass(ID), TTI(nullptr), DT(nullptr), - Entry(nullptr) { - initializeConstantHoistingPass(*PassRegistry::getPassRegistry()); + ConstantHoistingLegacyPass() : FunctionPass(ID) { + initializeConstantHoistingLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &Fn) override; @@ -135,67 +72,36 @@ public: AU.addRequired<TargetTransformInfoWrapperPass>(); } -private: - /// \brief Initialize the pass. - void setup(Function &Fn) { - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn); - Entry = &Fn.getEntryBlock(); - } + void releaseMemory() override { Impl.releaseMemory(); } - /// \brief Cleanup. - void cleanup() { - ConstantVec.clear(); - ClonedCastMap.clear(); - ConstCandVec.clear(); - - TTI = nullptr; - DT = nullptr; - Entry = nullptr; - } - - Instruction *findMatInsertPt(Instruction *Inst, unsigned Idx = ~0U) const; - Instruction *findConstantInsertionPoint(const ConstantInfo &ConstInfo) const; - void collectConstantCandidates(ConstCandMapType &ConstCandMap, - Instruction *Inst, unsigned Idx, - ConstantInt *ConstInt); - void collectConstantCandidates(ConstCandMapType &ConstCandMap, - Instruction *Inst); - void collectConstantCandidates(Function &Fn); - void findAndMakeBaseConstant(ConstCandVecType::iterator S, - ConstCandVecType::iterator E); - void findBaseConstants(); - void emitBaseConstants(Instruction *Base, Constant *Offset, - const ConstantUser &ConstUser); - bool emitBaseConstants(); - void deleteDeadCastInst() const; - bool optimizeConstants(Function &Fn); +private: + ConstantHoistingPass Impl; }; } -char ConstantHoisting::ID = 0; -INITIALIZE_PASS_BEGIN(ConstantHoisting, "consthoist", "Constant Hoisting", - false, false) +char ConstantHoistingLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(ConstantHoistingLegacyPass, "consthoist", + "Constant Hoisting", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(ConstantHoisting, "consthoist", "Constant Hoisting", - false, false) +INITIALIZE_PASS_END(ConstantHoistingLegacyPass, "consthoist", + "Constant Hoisting", false, false) FunctionPass *llvm::createConstantHoistingPass() { - return new ConstantHoisting(); + return new ConstantHoistingLegacyPass(); } /// \brief Perform the constant hoisting optimization for the given function. -bool ConstantHoisting::runOnFunction(Function &Fn) { - if (skipOptnoneFunction(Fn)) +bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) { + if (skipFunction(Fn)) return false; DEBUG(dbgs() << "********** Begin Constant Hoisting **********\n"); DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); - setup(Fn); - - bool MadeChange = optimizeConstants(Fn); + bool MadeChange = Impl.runImpl( + Fn, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn), + getAnalysis<DominatorTreeWrapperPass>().getDomTree(), Fn.getEntryBlock()); if (MadeChange) { DEBUG(dbgs() << "********** Function after Constant Hoisting: " @@ -204,15 +110,13 @@ bool ConstantHoisting::runOnFunction(Function &Fn) { } DEBUG(dbgs() << "********** End Constant Hoisting **********\n"); - cleanup(); - return MadeChange; } /// \brief Find the constant materialization insertion point. 
-Instruction *ConstantHoisting::findMatInsertPt(Instruction *Inst, - unsigned Idx) const { +Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, + unsigned Idx) const { // If the operand is a cast instruction, then we have to materialize the // constant before the cast instruction. if (Idx != ~0U) { @@ -237,8 +141,8 @@ Instruction *ConstantHoisting::findMatInsertPt(Instruction *Inst, } /// \brief Find an insertion point that dominates all uses. -Instruction *ConstantHoisting:: -findConstantInsertionPoint(const ConstantInfo &ConstInfo) const { +Instruction *ConstantHoistingPass::findConstantInsertionPoint( + const ConstantInfo &ConstInfo) const { assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry."); // Collect all basic blocks. SmallPtrSet<BasicBlock *, 8> BBs; @@ -272,10 +176,9 @@ findConstantInsertionPoint(const ConstantInfo &ConstInfo) const { /// The operand at index Idx is not necessarily the constant integer itself. It /// could also be a cast instruction or a constant expression that uses the // constant integer. -void ConstantHoisting::collectConstantCandidates(ConstCandMapType &ConstCandMap, - Instruction *Inst, - unsigned Idx, - ConstantInt *ConstInt) { +void ConstantHoistingPass::collectConstantCandidates( + ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx, + ConstantInt *ConstInt) { unsigned Cost; // Ask the target about the cost of materializing the constant for the given // instruction and operand index. @@ -309,8 +212,8 @@ void ConstantHoisting::collectConstantCandidates(ConstCandMapType &ConstCandMap, /// \brief Scan the instruction for expensive integer constants and record them /// in the constant candidate vector. -void ConstantHoisting::collectConstantCandidates(ConstCandMapType &ConstCandMap, - Instruction *Inst) { +void ConstantHoistingPass::collectConstantCandidates( + ConstCandMapType &ConstCandMap, Instruction *Inst) { // Skip all cast instructions. They are visited indirectly later on. if (Inst->isCast()) return; @@ -320,6 +223,18 @@ void ConstantHoisting::collectConstantCandidates(ConstCandMapType &ConstCandMap, if (isa<InlineAsm>(Call->getCalledValue())) return; + // Switch cases must remain constant, and if the value being tested is + // constant the entire thing should disappear. + if (isa<SwitchInst>(Inst)) + return; + + // Static allocas (constant size in the entry block) are handled by + // prologue/epilogue insertion so they're free anyway. We definitely don't + // want to make them non-constant. + auto AI = dyn_cast<AllocaInst>(Inst); + if (AI && AI->isStaticAlloca()) + return; + // Scan all operands. for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) { Value *Opnd = Inst->getOperand(Idx); @@ -363,25 +278,116 @@ void ConstantHoisting::collectConstantCandidates(ConstCandMapType &ConstCandMap, /// \brief Collect all integer constants in the function that cannot be folded /// into an instruction itself. -void ConstantHoisting::collectConstantCandidates(Function &Fn) { +void ConstantHoistingPass::collectConstantCandidates(Function &Fn) { ConstCandMapType ConstCandMap; for (BasicBlock &BB : Fn) for (Instruction &Inst : BB) collectConstantCandidates(ConstCandMap, &Inst); } -/// \brief Find the base constant within the given range and rebase all other -/// constants with respect to the base constant. 
-void ConstantHoisting::findAndMakeBaseConstant(ConstCandVecType::iterator S, - ConstCandVecType::iterator E) { - auto MaxCostItr = S; +// This helper function is necessary to deal with values that have different +// bit widths (APInt Operator- does not like that). If the value cannot be +// represented in uint64 we return an "empty" APInt. This is then interpreted +// as the value is not in range. +static llvm::Optional<APInt> calculateOffsetDiff(APInt V1, APInt V2) +{ + llvm::Optional<APInt> Res = None; + unsigned BW = V1.getBitWidth() > V2.getBitWidth() ? + V1.getBitWidth() : V2.getBitWidth(); + uint64_t LimVal1 = V1.getLimitedValue(); + uint64_t LimVal2 = V2.getLimitedValue(); + + if (LimVal1 == ~0ULL || LimVal2 == ~0ULL) + return Res; + + uint64_t Diff = LimVal1 - LimVal2; + return APInt(BW, Diff, true); +} + +// From a list of constants, one needs to picked as the base and the other +// constants will be transformed into an offset from that base constant. The +// question is which we can pick best? For example, consider these constants +// and their number of uses: +// +// Constants| 2 | 4 | 12 | 42 | +// NumUses | 3 | 2 | 8 | 7 | +// +// Selecting constant 12 because it has the most uses will generate negative +// offsets for constants 2 and 4 (i.e. -10 and -8 respectively). If negative +// offsets lead to less optimal code generation, then there might be better +// solutions. Suppose immediates in the range of 0..35 are most optimally +// supported by the architecture, then selecting constant 2 is most optimal +// because this will generate offsets: 0, 2, 10, 40. Offsets 0, 2 and 10 are in +// range 0..35, and thus 3 + 2 + 8 = 13 uses are in range. Selecting 12 would +// have only 8 uses in range, so choosing 2 as a base is more optimal. Thus, in +// selecting the base constant the range of the offsets is a very important +// factor too that we take into account here. This algorithm calculates a total +// costs for selecting a constant as the base and substract the costs if +// immediates are out of range. It has quadratic complexity, so we call this +// function only when we're optimising for size and there are less than 100 +// constants, we fall back to the straightforward algorithm otherwise +// which does not do all the offset calculations. +unsigned +ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S, + ConstCandVecType::iterator E, + ConstCandVecType::iterator &MaxCostItr) { unsigned NumUses = 0; - // Use the constant that has the maximum cost as base constant. 
+ + if(!Entry->getParent()->optForSize() || std::distance(S,E) > 100) { + for (auto ConstCand = S; ConstCand != E; ++ConstCand) { + NumUses += ConstCand->Uses.size(); + if (ConstCand->CumulativeCost > MaxCostItr->CumulativeCost) + MaxCostItr = ConstCand; + } + return NumUses; + } + + DEBUG(dbgs() << "== Maximize constants in range ==\n"); + int MaxCost = -1; for (auto ConstCand = S; ConstCand != E; ++ConstCand) { + auto Value = ConstCand->ConstInt->getValue(); + Type *Ty = ConstCand->ConstInt->getType(); + int Cost = 0; NumUses += ConstCand->Uses.size(); - if (ConstCand->CumulativeCost > MaxCostItr->CumulativeCost) + DEBUG(dbgs() << "= Constant: " << ConstCand->ConstInt->getValue() << "\n"); + + for (auto User : ConstCand->Uses) { + unsigned Opcode = User.Inst->getOpcode(); + unsigned OpndIdx = User.OpndIdx; + Cost += TTI->getIntImmCost(Opcode, OpndIdx, Value, Ty); + DEBUG(dbgs() << "Cost: " << Cost << "\n"); + + for (auto C2 = S; C2 != E; ++C2) { + llvm::Optional<APInt> Diff = calculateOffsetDiff( + C2->ConstInt->getValue(), + ConstCand->ConstInt->getValue()); + if (Diff) { + const int ImmCosts = + TTI->getIntImmCodeSizeCost(Opcode, OpndIdx, Diff.getValue(), Ty); + Cost -= ImmCosts; + DEBUG(dbgs() << "Offset " << Diff.getValue() << " " + << "has penalty: " << ImmCosts << "\n" + << "Adjusted cost: " << Cost << "\n"); + } + } + } + DEBUG(dbgs() << "Cumulative cost: " << Cost << "\n"); + if (Cost > MaxCost) { + MaxCost = Cost; MaxCostItr = ConstCand; + DEBUG(dbgs() << "New candidate: " << MaxCostItr->ConstInt->getValue() + << "\n"); + } } + return NumUses; +} + +/// \brief Find the base constant within the given range and rebase all other +/// constants with respect to the base constant. +void ConstantHoistingPass::findAndMakeBaseConstant( + ConstCandVecType::iterator S, ConstCandVecType::iterator E) { + auto MaxCostItr = S; + unsigned NumUses = maximizeConstantsInRange(S, E, MaxCostItr); // Don't hoist constants that have only one use. if (NumUses <= 1) @@ -404,7 +410,7 @@ void ConstantHoisting::findAndMakeBaseConstant(ConstCandVecType::iterator S, /// \brief Finds and combines constant candidates that can be easily /// rematerialized with an add from a common base constant. -void ConstantHoisting::findBaseConstants() { +void ConstantHoistingPass::findBaseConstants() { // Sort the constants by value and type. This invalidates the mapping! std::sort(ConstCandVec.begin(), ConstCandVec.end(), [](const ConstantCandidate &LHS, const ConstantCandidate &RHS) { @@ -466,8 +472,9 @@ static bool updateOperand(Instruction *Inst, unsigned Idx, Instruction *Mat) { /// \brief Emit materialization code for all rebased constants and update their /// users. -void ConstantHoisting::emitBaseConstants(Instruction *Base, Constant *Offset, - const ConstantUser &ConstUser) { +void ConstantHoistingPass::emitBaseConstants(Instruction *Base, + Constant *Offset, + const ConstantUser &ConstUser) { Instruction *Mat = Base; if (Offset) { Instruction *InsertionPt = findMatInsertPt(ConstUser.Inst, @@ -538,7 +545,7 @@ void ConstantHoisting::emitBaseConstants(Instruction *Base, Constant *Offset, /// \brief Hoist and hide the base constant behind a bitcast and emit /// materialization code for derived constants. -bool ConstantHoisting::emitBaseConstants() { +bool ConstantHoistingPass::emitBaseConstants() { bool MadeChange = false; for (auto const &ConstInfo : ConstantVec) { // Hoist and hide the base constant behind a bitcast. 
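
A rough illustration of the base-selection idea above, not the pass's actual cost model (which prices each use via TTI::getIntImmCost and charges out-of-range offsets with getIntImmCodeSizeCost): for hypothetical constants and use counts, count how many uses keep their offset inside an assumed foldable immediate range for each candidate base, and pick the best one.

#include <cstdint>
#include <iostream>
#include <vector>

struct Cand { int64_t Value; unsigned NumUses; };

int main() {
  // Hypothetical candidates (value -> number of uses); not taken from real IR.
  std::vector<Cand> Cands = {{2, 3}, {4, 2}, {12, 8}, {100, 7}};
  // Assume the target folds unsigned offsets in [0, 35] into the using
  // instruction for free; anything else needs extra materialization code.
  auto InRange = [](int64_t Off) { return Off >= 0 && Off <= 35; };

  int64_t BestBase = Cands.front().Value;
  unsigned BestUses = 0;
  for (const Cand &Base : Cands) {
    unsigned UsesInRange = 0;
    for (const Cand &C : Cands)
      if (InRange(C.Value - Base.Value))
        UsesInRange += C.NumUses;
    std::cout << "base " << Base.Value << ": " << UsesInRange
              << " uses with a foldable offset\n";
    if (UsesInRange > BestUses) {
      BestUses = UsesInRange;
      BestBase = Base.Value;
    }
  }
  // For this data the winner is 2 (13 uses in range), even though 12 has the
  // highest individual use count.
  std::cout << "chosen base: " << BestBase << "\n";
}
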
@@ -572,14 +579,18 @@ bool ConstantHoisting::emitBaseConstants() { /// \brief Check all cast instructions we made a copy of and remove them if they /// have no more users. -void ConstantHoisting::deleteDeadCastInst() const { +void ConstantHoistingPass::deleteDeadCastInst() const { for (auto const &I : ClonedCastMap) if (I.first->use_empty()) I.first->eraseFromParent(); } /// \brief Optimize expensive integer constants in the given function. -bool ConstantHoisting::optimizeConstants(Function &Fn) { +bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, + DominatorTree &DT, BasicBlock &Entry) { + this->TTI = &TTI; + this->DT = &DT; + this->Entry = &Entry; // Collect all constant candidates. collectConstantCandidates(Fn); @@ -604,3 +615,14 @@ bool ConstantHoisting::optimizeConstants(Function &Fn) { return MadeChange; } + +PreservedAnalyses ConstantHoistingPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + if (!runImpl(F, TTI, DT, F.getEntryBlock())) + return PreservedAnalyses::all(); + + // FIXME: This should also 'preserve the CFG'. + return PreservedAnalyses::none(); +} diff --git a/lib/Transforms/Scalar/ConstantProp.cpp b/lib/Transforms/Scalar/ConstantProp.cpp index c974ebb9456f..88172d19fe5a 100644 --- a/lib/Transforms/Scalar/ConstantProp.cpp +++ b/lib/Transforms/Scalar/ConstantProp.cpp @@ -61,11 +61,14 @@ FunctionPass *llvm::createConstantPropagationPass() { } bool ConstantPropagation::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + // Initialize the worklist to all of the instructions ready to process... std::set<Instruction*> WorkList; - for(inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) { - WorkList.insert(&*i); - } + for (Instruction &I: instructions(&F)) + WorkList.insert(&I); + bool Changed = false; const DataLayout &DL = F.getParent()->getDataLayout(); TargetLibraryInfo *TLI = diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 686bd4071104..c0fed0533392 100644 --- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h" #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" @@ -35,22 +36,11 @@ STATISTIC(NumMemAccess, "Number of memory access targets propagated"); STATISTIC(NumCmps, "Number of comparisons propagated"); STATISTIC(NumReturns, "Number of return values propagated"); STATISTIC(NumDeadCases, "Number of switch cases removed"); +STATISTIC(NumSDivs, "Number of sdiv converted to udiv"); +STATISTIC(NumSRems, "Number of srem converted to urem"); namespace { class CorrelatedValuePropagation : public FunctionPass { - LazyValueInfo *LVI; - - bool processSelect(SelectInst *SI); - bool processPHI(PHINode *P); - bool processMemAccess(Instruction *I); - bool processCmp(CmpInst *C); - bool processSwitch(SwitchInst *SI); - bool processCallSite(CallSite CS); - - /// Return a constant value for V usable at At and everything it - /// dominates. If no such Constant can be found, return nullptr. 
- Constant *getConstantAt(Value *V, Instruction *At); - public: static char ID; CorrelatedValuePropagation(): FunctionPass(ID) { @@ -60,7 +50,7 @@ namespace { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<LazyValueInfo>(); + AU.addRequired<LazyValueInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } }; @@ -69,7 +59,7 @@ namespace { char CorrelatedValuePropagation::ID = 0; INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation", "Value Propagation", false, false) -INITIALIZE_PASS_DEPENDENCY(LazyValueInfo) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation", "Value Propagation", false, false) @@ -78,7 +68,7 @@ Pass *llvm::createCorrelatedValuePropagationPass() { return new CorrelatedValuePropagation(); } -bool CorrelatedValuePropagation::processSelect(SelectInst *S) { +static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { if (S->getType()->isVectorTy()) return false; if (isa<Constant>(S->getOperand(0))) return false; @@ -101,7 +91,7 @@ bool CorrelatedValuePropagation::processSelect(SelectInst *S) { return true; } -bool CorrelatedValuePropagation::processPHI(PHINode *P) { +static bool processPHI(PHINode *P, LazyValueInfo *LVI) { bool Changed = false; BasicBlock *BB = P->getParent(); @@ -169,7 +159,7 @@ bool CorrelatedValuePropagation::processPHI(PHINode *P) { return Changed; } -bool CorrelatedValuePropagation::processMemAccess(Instruction *I) { +static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) { Value *Pointer = nullptr; if (LoadInst *L = dyn_cast<LoadInst>(I)) Pointer = L->getPointerOperand(); @@ -186,11 +176,11 @@ bool CorrelatedValuePropagation::processMemAccess(Instruction *I) { return true; } -/// processCmp - See if LazyValueInfo's ability to exploit edge conditions, -/// or range information is sufficient to prove this comparison. Even for -/// local conditions, this can sometimes prove conditions instcombine can't by +/// See if LazyValueInfo's ability to exploit edge conditions or range +/// information is sufficient to prove this comparison. Even for local +/// conditions, this can sometimes prove conditions instcombine can't by /// exploiting range information. -bool CorrelatedValuePropagation::processCmp(CmpInst *C) { +static bool processCmp(CmpInst *C, LazyValueInfo *LVI) { Value *Op0 = C->getOperand(0); Constant *Op1 = dyn_cast<Constant>(C->getOperand(1)); if (!Op1) return false; @@ -218,14 +208,14 @@ bool CorrelatedValuePropagation::processCmp(CmpInst *C) { return true; } -/// processSwitch - Simplify a switch instruction by removing cases which can -/// never fire. If the uselessness of a case could be determined locally then -/// constant propagation would already have figured it out. Instead, walk the -/// predecessors and statically evaluate cases based on information available -/// on that edge. Cases that cannot fire no matter what the incoming edge can -/// safely be removed. If a case fires on every incoming edge then the entire -/// switch can be removed and replaced with a branch to the case destination. -bool CorrelatedValuePropagation::processSwitch(SwitchInst *SI) { +/// Simplify a switch instruction by removing cases which can never fire. If the +/// uselessness of a case could be determined locally then constant propagation +/// would already have figured it out. 
Instead, walk the predecessors and +/// statically evaluate cases based on information available on that edge. Cases +/// that cannot fire no matter what the incoming edge can safely be removed. If +/// a case fires on every incoming edge then the entire switch can be removed +/// and replaced with a branch to the case destination. +static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { Value *Cond = SI->getCondition(); BasicBlock *BB = SI->getParent(); @@ -304,16 +294,18 @@ bool CorrelatedValuePropagation::processSwitch(SwitchInst *SI) { return Changed; } -/// processCallSite - Infer nonnull attributes for the arguments at the -/// specified callsite. -bool CorrelatedValuePropagation::processCallSite(CallSite CS) { +/// Infer nonnull attributes for the arguments at the specified callsite. +static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { SmallVector<unsigned, 4> Indices; unsigned ArgNo = 0; for (Value *V : CS.args()) { PointerType *Type = dyn_cast<PointerType>(V->getType()); - + // Try to mark pointer typed parameters as non-null. We skip the + // relatively expensive analysis for constants which are obviously either + // null or non-null to start with. if (Type && !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && + !isa<Constant>(V) && LVI->getPredicateAt(ICmpInst::ICMP_EQ, V, ConstantPointerNull::get(Type), CS.getInstruction()) == LazyValueInfo::False) @@ -334,7 +326,62 @@ bool CorrelatedValuePropagation::processCallSite(CallSite CS) { return true; } -Constant *CorrelatedValuePropagation::getConstantAt(Value *V, Instruction *At) { +// Helper function to rewrite srem and sdiv. As a policy choice, we choose not +// to waste compile time on anything where the operands are local defs. While +// LVI can sometimes reason about such cases, it's not its primary purpose. +static bool hasLocalDefs(BinaryOperator *SDI) { + for (Value *O : SDI->operands()) { + auto *I = dyn_cast<Instruction>(O); + if (I && I->getParent() == SDI->getParent()) + return true; + } + return false; +} + +static bool hasPositiveOperands(BinaryOperator *SDI, LazyValueInfo *LVI) { + Constant *Zero = ConstantInt::get(SDI->getType(), 0); + for (Value *O : SDI->operands()) { + auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SGE, O, Zero, SDI); + if (Result != LazyValueInfo::True) + return false; + } + return true; +} + +static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { + if (SDI->getType()->isVectorTy() || hasLocalDefs(SDI) || + !hasPositiveOperands(SDI, LVI)) + return false; + + ++NumSRems; + auto *BO = BinaryOperator::CreateURem(SDI->getOperand(0), SDI->getOperand(1), + SDI->getName(), SDI); + SDI->replaceAllUsesWith(BO); + SDI->eraseFromParent(); + return true; +} + +/// See if LazyValueInfo's ability to exploit edge conditions or range +/// information is sufficient to prove the both operands of this SDiv are +/// positive. If this is the case, replace the SDiv with a UDiv. Even for local +/// conditions, this can sometimes prove conditions instcombine can't by +/// exploiting range information. 
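
As a standalone soundness check, separate from the patch: the sdiv/srem rewrites here are only legal when both operands are provably non-negative, because signed and unsigned division and remainder agree exactly on non-negative values. A small sketch, where KnownRange is a hypothetical stand-in for the facts LazyValueInfo's getPredicateAt establishes:

#include <cassert>
#include <cstdint>

// Hypothetical stand-in for what LazyValueInfo proves about a value: a known
// [Lo, Hi] range. Illustrative only.
struct KnownRange { int64_t Lo, Hi; };

// The sdiv->udiv / srem->urem rewrite is sound only when both operands are
// provably non-negative; then signed and unsigned division and remainder
// agree bit for bit.
bool canUseUnsigned(KnownRange LHS, KnownRange RHS) {
  return LHS.Lo >= 0 && RHS.Lo >= 0;
}

int main() {
  assert(canUseUnsigned({0, 100}, {1, 8}));   // rewrite is fine
  assert(!canUseUnsigned({-1, 100}, {1, 8})); // LHS may be negative: keep sdiv

  // Spot-check the underlying fact for a few non-negative values.
  const int64_t As[] = {0, 7, 100};
  const int64_t Bs[] = {1, 3, 9};
  for (int64_t A : As)
    for (int64_t B : Bs) {
      assert(A / B == int64_t(uint64_t(A) / uint64_t(B)));
      assert(A % B == int64_t(uint64_t(A) % uint64_t(B)));
    }
  return 0;
}
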
+static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { + if (SDI->getType()->isVectorTy() || hasLocalDefs(SDI) || + !hasPositiveOperands(SDI, LVI)) + return false; + + ++NumSDivs; + auto *BO = BinaryOperator::CreateUDiv(SDI->getOperand(0), SDI->getOperand(1), + SDI->getName(), SDI); + BO->setIsExact(SDI->isExact()); + SDI->replaceAllUsesWith(BO); + SDI->eraseFromParent(); + + return true; +} + +static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { if (Constant *C = LVI->getConstant(V, At->getParent(), At)) return C; @@ -357,44 +404,45 @@ Constant *CorrelatedValuePropagation::getConstantAt(Value *V, Instruction *At) { ConstantInt::getFalse(C->getContext()); } -bool CorrelatedValuePropagation::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) - return false; - - LVI = &getAnalysis<LazyValueInfo>(); - +static bool runImpl(Function &F, LazyValueInfo *LVI) { bool FnChanged = false; - for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ++FI) { + for (BasicBlock &BB : F) { bool BBChanged = false; - for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ) { + for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) { Instruction *II = &*BI++; switch (II->getOpcode()) { case Instruction::Select: - BBChanged |= processSelect(cast<SelectInst>(II)); + BBChanged |= processSelect(cast<SelectInst>(II), LVI); break; case Instruction::PHI: - BBChanged |= processPHI(cast<PHINode>(II)); + BBChanged |= processPHI(cast<PHINode>(II), LVI); break; case Instruction::ICmp: case Instruction::FCmp: - BBChanged |= processCmp(cast<CmpInst>(II)); + BBChanged |= processCmp(cast<CmpInst>(II), LVI); break; case Instruction::Load: case Instruction::Store: - BBChanged |= processMemAccess(II); + BBChanged |= processMemAccess(II, LVI); break; case Instruction::Call: case Instruction::Invoke: - BBChanged |= processCallSite(CallSite(II)); + BBChanged |= processCallSite(CallSite(II), LVI); + break; + case Instruction::SRem: + BBChanged |= processSRem(cast<BinaryOperator>(II), LVI); + break; + case Instruction::SDiv: + BBChanged |= processSDiv(cast<BinaryOperator>(II), LVI); break; } } - Instruction *Term = FI->getTerminator(); + Instruction *Term = BB.getTerminator(); switch (Term->getOpcode()) { case Instruction::Switch: - BBChanged |= processSwitch(cast<SwitchInst>(Term)); + BBChanged |= processSwitch(cast<SwitchInst>(Term), LVI); break; case Instruction::Ret: { auto *RI = cast<ReturnInst>(Term); @@ -404,7 +452,7 @@ bool CorrelatedValuePropagation::runOnFunction(Function &F) { auto *RetVal = RI->getReturnValue(); if (!RetVal) break; // handle "ret void" if (isa<Constant>(RetVal)) break; // nothing to do - if (auto *C = getConstantAt(RetVal, RI)) { + if (auto *C = getConstantAt(RetVal, RI, LVI)) { ++NumReturns; RI->replaceUsesOfWith(RetVal, C); BBChanged = true; @@ -417,3 +465,28 @@ bool CorrelatedValuePropagation::runOnFunction(Function &F) { return FnChanged; } + +bool CorrelatedValuePropagation::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); + return runImpl(F, LVI); +} + +PreservedAnalyses +CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) { + + LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F); + bool Changed = runImpl(F, LVI); + + // FIXME: We need to invalidate LVI to avoid PR28400. Is there a better + // solution? 
+ AM.invalidate<LazyValueAnalysis>(F); + + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; +} diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp index b67c3c7742fd..f73809d9f045 100644 --- a/lib/Transforms/Scalar/DCE.cpp +++ b/lib/Transforms/Scalar/DCE.cpp @@ -16,13 +16,14 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/DCE.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/Pass.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -41,7 +42,7 @@ namespace { initializeDeadInstEliminationPass(*PassRegistry::getPassRegistry()); } bool runOnBasicBlock(BasicBlock &BB) override { - if (skipOptnoneFunction(BB)) + if (skipBasicBlock(BB)) return false; auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI() : nullptr; @@ -71,28 +72,6 @@ Pass *llvm::createDeadInstEliminationPass() { return new DeadInstElimination(); } - -namespace { - //===--------------------------------------------------------------------===// - // DeadCodeElimination pass implementation - // - struct DCE : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - DCE() : FunctionPass(ID) { - initializeDCEPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - } - }; -} - -char DCE::ID = 0; -INITIALIZE_PASS(DCE, "dce", "Dead Code Elimination", false, false) - static bool DCEInstruction(Instruction *I, SmallSetVector<Instruction *, 16> &WorkList, const TargetLibraryInfo *TLI) { @@ -121,13 +100,7 @@ static bool DCEInstruction(Instruction *I, return false; } -bool DCE::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) - return false; - - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI() : nullptr; - +static bool eliminateDeadCode(Function &F, TargetLibraryInfo *TLI) { bool MadeChange = false; SmallSetVector<Instruction *, 16> WorkList; // Iterate over the original function, only adding insts to the worklist @@ -150,7 +123,38 @@ bool DCE::runOnFunction(Function &F) { return MadeChange; } -FunctionPass *llvm::createDeadCodeEliminationPass() { - return new DCE(); +PreservedAnalyses DCEPass::run(Function &F, AnalysisManager<Function> &AM) { + if (eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} + +namespace { +struct DCELegacyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + DCELegacyPass() : FunctionPass(ID) { + initializeDCELegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + TargetLibraryInfo *TLI = TLIP ? 
&TLIP->getTLI() : nullptr; + + return eliminateDeadCode(F, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } +}; } +char DCELegacyPass::ID = 0; +INITIALIZE_PASS(DCELegacyPass, "dce", "Dead Code Elimination", false, false) + +FunctionPass *llvm::createDeadCodeEliminationPass() { + return new DCELegacyPass(); +} diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index 36ad0a5f7b91..ed58a87ae1a8 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -15,7 +15,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/DeadStoreElimination.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" @@ -34,9 +35,12 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <map> using namespace llvm; #define DEBUG_TYPE "dse" @@ -44,90 +48,35 @@ using namespace llvm; STATISTIC(NumRedundantStores, "Number of redundant stores deleted"); STATISTIC(NumFastStores, "Number of stores deleted"); STATISTIC(NumFastOther , "Number of other instrs removed"); +STATISTIC(NumCompletePartials, "Number of stores dead by later partials"); -namespace { - struct DSE : public FunctionPass { - AliasAnalysis *AA; - MemoryDependenceAnalysis *MD; - DominatorTree *DT; - const TargetLibraryInfo *TLI; - - static char ID; // Pass identification, replacement for typeid - DSE() : FunctionPass(ID), AA(nullptr), MD(nullptr), DT(nullptr) { - initializeDSEPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipOptnoneFunction(F)) - return false; - - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - MD = &getAnalysis<MemoryDependenceAnalysis>(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); +static cl::opt<bool> +EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking", + cl::init(true), cl::Hidden, + cl::desc("Enable partial-overwrite tracking in DSE")); - bool Changed = false; - for (BasicBlock &I : F) - // Only check non-dead blocks. Dead blocks may have strange pointer - // cycles that will confuse alias analysis. 
- if (DT->isReachableFromEntry(&I)) - Changed |= runOnBasicBlock(I); - - AA = nullptr; MD = nullptr; DT = nullptr; - return Changed; - } - - bool runOnBasicBlock(BasicBlock &BB); - bool MemoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI); - bool HandleFree(CallInst *F); - bool handleEndBlock(BasicBlock &BB); - void RemoveAccessedObjects(const MemoryLocation &LoadedLoc, - SmallSetVector<Value *, 16> &DeadStackObjects, - const DataLayout &DL); - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<MemoryDependenceAnalysis>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<MemoryDependenceAnalysis>(); - } - }; -} - -char DSE::ID = 0; -INITIALIZE_PASS_BEGIN(DSE, "dse", "Dead Store Elimination", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(DSE, "dse", "Dead Store Elimination", false, false) - -FunctionPass *llvm::createDeadStoreEliminationPass() { return new DSE(); } //===----------------------------------------------------------------------===// // Helper functions //===----------------------------------------------------------------------===// -/// DeleteDeadInstruction - Delete this instruction. Before we do, go through -/// and zero out all the operands of this instruction. If any of them become -/// dead, delete them and the computation tree that feeds them. -/// +/// Delete this instruction. Before we do, go through and zero out all the +/// operands of this instruction. If any of them become dead, delete them and +/// the computation tree that feeds them. /// If ValueSet is non-null, remove any deleted instructions from it as well. -/// -static void DeleteDeadInstruction(Instruction *I, - MemoryDependenceAnalysis &MD, - const TargetLibraryInfo &TLI, - SmallSetVector<Value*, 16> *ValueSet = nullptr) { +static void +deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, + MemoryDependenceResults &MD, const TargetLibraryInfo &TLI, + SmallSetVector<Value *, 16> *ValueSet = nullptr) { SmallVector<Instruction*, 32> NowDeadInsts; NowDeadInsts.push_back(I); --NumFastOther; + // Keeping the iterator straight is a pain, so we let this routine tell the + // caller what the next instruction is after we're done mucking about. + BasicBlock::iterator NewIter = *BBI; + // Before we touch this instruction, remove it from memdep! do { Instruction *DeadInst = NowDeadInsts.pop_back_val(); @@ -150,15 +99,19 @@ static void DeleteDeadInstruction(Instruction *I, NowDeadInsts.push_back(OpI); } - DeadInst->eraseFromParent(); + + if (NewIter == DeadInst->getIterator()) + NewIter = DeadInst->eraseFromParent(); + else + DeadInst->eraseFromParent(); if (ValueSet) ValueSet->remove(DeadInst); } while (!NowDeadInsts.empty()); + *BBI = NewIter; } - -/// hasMemoryWrite - Does this instruction write some memory? This only returns -/// true for things that we can analyze with other helpers below. +/// Does this instruction write some memory? This only returns true for things +/// that we can analyze with other helpers below. 
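
A small self-contained illustration of the iterator discipline deleteDeadInstruction now follows above: when erasing elements from a list while walking it, advance through the iterator that erase() hands back, just as the patch threads the caller's BasicBlock::iterator through and uses the iterator returned by eraseFromParent(). Sketch with std::list standing in for the instruction list:

#include <cassert>
#include <list>

int main() {
  std::list<int> BB = {1, 2, 3, 4};

  // Erase every even element while keeping the walking iterator valid by
  // continuing from the iterator erase() returns.
  for (auto It = BB.begin(); It != BB.end();) {
    if (*It % 2 == 0)
      It = BB.erase(It); // advance via the returned iterator
    else
      ++It;
  }

  assert(BB == (std::list<int>{1, 3}));
  return 0;
}
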
static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) { if (isa<StoreInst>(I)) return true; @@ -176,30 +129,23 @@ static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) { } if (auto CS = CallSite(I)) { if (Function *F = CS.getCalledFunction()) { - if (TLI.has(LibFunc::strcpy) && - F->getName() == TLI.getName(LibFunc::strcpy)) { + StringRef FnName = F->getName(); + if (TLI.has(LibFunc::strcpy) && FnName == TLI.getName(LibFunc::strcpy)) return true; - } - if (TLI.has(LibFunc::strncpy) && - F->getName() == TLI.getName(LibFunc::strncpy)) { + if (TLI.has(LibFunc::strncpy) && FnName == TLI.getName(LibFunc::strncpy)) return true; - } - if (TLI.has(LibFunc::strcat) && - F->getName() == TLI.getName(LibFunc::strcat)) { + if (TLI.has(LibFunc::strcat) && FnName == TLI.getName(LibFunc::strcat)) return true; - } - if (TLI.has(LibFunc::strncat) && - F->getName() == TLI.getName(LibFunc::strncat)) { + if (TLI.has(LibFunc::strncat) && FnName == TLI.getName(LibFunc::strncat)) return true; - } } } return false; } -/// getLocForWrite - Return a Location stored to by the specified instruction. -/// If isRemovable returns true, this function and getLocForRead completely -/// describe the memory operations for this instruction. +/// Return a Location stored to by the specified instruction. If isRemovable +/// returns true, this function and getLocForRead completely describe the memory +/// operations for this instruction. static MemoryLocation getLocForWrite(Instruction *Inst, AliasAnalysis &AA) { if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) return MemoryLocation::get(SI); @@ -228,8 +174,8 @@ static MemoryLocation getLocForWrite(Instruction *Inst, AliasAnalysis &AA) { } } -/// getLocForRead - Return the location read by the specified "hasMemoryWrite" -/// instruction if any. +/// Return the location read by the specified "hasMemoryWrite" instruction if +/// any. static MemoryLocation getLocForRead(Instruction *Inst, const TargetLibraryInfo &TLI) { assert(hasMemoryWrite(Inst, TLI) && "Unknown instruction case"); @@ -241,9 +187,8 @@ static MemoryLocation getLocForRead(Instruction *Inst, return MemoryLocation(); } - -/// isRemovable - If the value of this instruction and the memory it writes to -/// is unused, may we delete this instruction? +/// If the value of this instruction and the memory it writes to is unused, may +/// we delete this instruction? static bool isRemovable(Instruction *I) { // Don't remove volatile/atomic stores. if (StoreInst *SI = dyn_cast<StoreInst>(I)) @@ -275,9 +220,9 @@ static bool isRemovable(Instruction *I) { } -/// isShortenable - Returns true if this instruction can be safely shortened in +/// Returns true if the end of this instruction can be safely shortened in /// length. -static bool isShortenable(Instruction *I) { +static bool isShortenableAtTheEnd(Instruction *I) { // Don't shorten stores for now if (isa<StoreInst>(I)) return false; @@ -288,6 +233,7 @@ static bool isShortenable(Instruction *I) { case Intrinsic::memset: case Intrinsic::memcpy: // Do shorten memory intrinsics. + // FIXME: Add memmove if it's also safe to transform. return true; } } @@ -297,7 +243,16 @@ static bool isShortenable(Instruction *I) { return false; } -/// getStoredPointerOperand - Return the pointer that is being written to. +/// Returns true if the beginning of this instruction can be safely shortened +/// in length. +static bool isShortenableAtTheBeginning(Instruction *I) { + // FIXME: Handle only memset for now. 
Supporting memcpy/memmove should be + // easily done by offsetting the source address. + IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); + return II && II->getIntrinsicID() == Intrinsic::memset; +} + +/// Return the pointer that is being written to. static Value *getStoredPointerOperand(Instruction *I) { if (StoreInst *SI = dyn_cast<StoreInst>(I)) return SI->getPointerOperand(); @@ -327,46 +282,45 @@ static uint64_t getPointerSize(const Value *V, const DataLayout &DL, } namespace { - enum OverwriteResult - { - OverwriteComplete, - OverwriteEnd, - OverwriteUnknown - }; +enum OverwriteResult { + OverwriteBegin, + OverwriteComplete, + OverwriteEnd, + OverwriteUnknown +}; } -/// isOverwrite - Return 'OverwriteComplete' if a store to the 'Later' location -/// completely overwrites a store to the 'Earlier' location. -/// 'OverwriteEnd' if the end of the 'Earlier' location is completely -/// overwritten by 'Later', or 'OverwriteUnknown' if nothing can be determined +typedef DenseMap<Instruction *, + std::map<int64_t, int64_t>> InstOverlapIntervalsTy; + +/// Return 'OverwriteComplete' if a store to the 'Later' location completely +/// overwrites a store to the 'Earlier' location, 'OverwriteEnd' if the end of +/// the 'Earlier' location is completely overwritten by 'Later', +/// 'OverwriteBegin' if the beginning of the 'Earlier' location is overwritten +/// by 'Later', or 'OverwriteUnknown' if nothing can be determined. static OverwriteResult isOverwrite(const MemoryLocation &Later, const MemoryLocation &Earlier, const DataLayout &DL, const TargetLibraryInfo &TLI, - int64_t &EarlierOff, int64_t &LaterOff) { + int64_t &EarlierOff, int64_t &LaterOff, + Instruction *DepWrite, + InstOverlapIntervalsTy &IOL) { + // If we don't know the sizes of either access, then we can't do a comparison. + if (Later.Size == MemoryLocation::UnknownSize || + Earlier.Size == MemoryLocation::UnknownSize) + return OverwriteUnknown; + const Value *P1 = Earlier.Ptr->stripPointerCasts(); const Value *P2 = Later.Ptr->stripPointerCasts(); // If the start pointers are the same, we just have to compare sizes to see if // the later store was larger than the earlier store. if (P1 == P2) { - // If we don't know the sizes of either access, then we can't do a - // comparison. - if (Later.Size == MemoryLocation::UnknownSize || - Earlier.Size == MemoryLocation::UnknownSize) - return OverwriteUnknown; - // Make sure that the Later size is >= the Earlier size. if (Later.Size >= Earlier.Size) return OverwriteComplete; } - // Otherwise, we have to have size information, and the later store has to be - // larger than the earlier one. - if (Later.Size == MemoryLocation::UnknownSize || - Earlier.Size == MemoryLocation::UnknownSize) - return OverwriteUnknown; - // Check to see if the later store is to the entire object (either a global, // an alloca, or a byval/inalloca argument). If so, then it clearly // overwrites any other store to the same object. @@ -416,8 +370,68 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, uint64_t(EarlierOff - LaterOff) + Earlier.Size <= Later.Size) return OverwriteComplete; - // The other interesting case is if the later store overwrites the end of - // the earlier store + // We may now overlap, although the overlap is not complete. There might also + // be other incomplete overlaps, and together, they might cover the complete + // earlier write. 
+ // Note: The correctness of this logic depends on the fact that this function + // is not even called providing DepWrite when there are any intervening reads. + if (EnablePartialOverwriteTracking && + LaterOff < int64_t(EarlierOff + Earlier.Size) && + int64_t(LaterOff + Later.Size) >= EarlierOff) { + + // Insert our part of the overlap into the map. + auto &IM = IOL[DepWrite]; + DEBUG(dbgs() << "DSE: Partial overwrite: Earlier [" << EarlierOff << ", " << + int64_t(EarlierOff + Earlier.Size) << ") Later [" << + LaterOff << ", " << int64_t(LaterOff + Later.Size) << ")\n"); + + // Make sure that we only insert non-overlapping intervals and combine + // adjacent intervals. The intervals are stored in the map with the ending + // offset as the key (in the half-open sense) and the starting offset as + // the value. + int64_t LaterIntStart = LaterOff, LaterIntEnd = LaterOff + Later.Size; + + // Find any intervals ending at, or after, LaterIntStart which start + // before LaterIntEnd. + auto ILI = IM.lower_bound(LaterIntStart); + if (ILI != IM.end() && ILI->second <= LaterIntEnd) { + // This existing interval is overlapped with the current store somewhere + // in [LaterIntStart, LaterIntEnd]. Merge them by erasing the existing + // intervals and adjusting our start and end. + LaterIntStart = std::min(LaterIntStart, ILI->second); + LaterIntEnd = std::max(LaterIntEnd, ILI->first); + ILI = IM.erase(ILI); + + // Continue erasing and adjusting our end in case other previous + // intervals are also overlapped with the current store. + // + // |--- ealier 1 ---| |--- ealier 2 ---| + // |------- later---------| + // + while (ILI != IM.end() && ILI->second <= LaterIntEnd) { + assert(ILI->second > LaterIntStart && "Unexpected interval"); + LaterIntEnd = std::max(LaterIntEnd, ILI->first); + ILI = IM.erase(ILI); + } + } + + IM[LaterIntEnd] = LaterIntStart; + + ILI = IM.begin(); + if (ILI->second <= EarlierOff && + ILI->first >= int64_t(EarlierOff + Earlier.Size)) { + DEBUG(dbgs() << "DSE: Full overwrite from partials: Earlier [" << + EarlierOff << ", " << + int64_t(EarlierOff + Earlier.Size) << + ") Composite Later [" << + ILI->second << ", " << ILI->first << ")\n"); + ++NumCompletePartials; + return OverwriteComplete; + } + } + + // Another interesting case is if the later store overwrites the end of the + // earlier store. // // |--earlier--| // |-- later --| @@ -429,11 +443,25 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size)) return OverwriteEnd; + // Finally, we also need to check if the later store overwrites the beginning + // of the earlier store. + // + // |--earlier--| + // |-- later --| + // + // In this case we may want to move the destination address and trim the size + // of earlier to avoid generating writes to addresses which will definitely + // be overwritten later. + if (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff) { + assert (int64_t(LaterOff + Later.Size) < int64_t(EarlierOff + Earlier.Size) + && "Expect to be handled as OverwriteComplete" ); + return OverwriteBegin; + } // Otherwise, they don't completely overlap. return OverwriteUnknown; } -/// isPossibleSelfRead - If 'Inst' might be a self read (i.e. a noop copy of a +/// If 'Inst' might be a self read (i.e. a noop copy of a /// memory region into an identical pointer) then it doesn't actually make its /// input dead in the traditional sense. 
Consider this case: /// @@ -478,192 +506,13 @@ static bool isPossibleSelfRead(Instruction *Inst, } -//===----------------------------------------------------------------------===// -// DSE Pass -//===----------------------------------------------------------------------===// - -bool DSE::runOnBasicBlock(BasicBlock &BB) { - const DataLayout &DL = BB.getModule()->getDataLayout(); - bool MadeChange = false; - - // Do a top-down walk on the BB. - for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) { - Instruction *Inst = &*BBI++; - - // Handle 'free' calls specially. - if (CallInst *F = isFreeCall(Inst, TLI)) { - MadeChange |= HandleFree(F); - continue; - } - - // If we find something that writes memory, get its memory dependence. - if (!hasMemoryWrite(Inst, *TLI)) - continue; - - // If we're storing the same value back to a pointer that we just - // loaded from, then the store can be removed. - if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - - auto RemoveDeadInstAndUpdateBBI = [&](Instruction *DeadInst) { - // DeleteDeadInstruction can delete the current instruction. Save BBI - // in case we need it. - WeakVH NextInst(&*BBI); - - DeleteDeadInstruction(DeadInst, *MD, *TLI); - - if (!NextInst) // Next instruction deleted. - BBI = BB.begin(); - else if (BBI != BB.begin()) // Revisit this instruction if possible. - --BBI; - ++NumRedundantStores; - MadeChange = true; - }; - - if (LoadInst *DepLoad = dyn_cast<LoadInst>(SI->getValueOperand())) { - if (SI->getPointerOperand() == DepLoad->getPointerOperand() && - isRemovable(SI) && - MemoryIsNotModifiedBetween(DepLoad, SI)) { - - DEBUG(dbgs() << "DSE: Remove Store Of Load from same pointer:\n " - << "LOAD: " << *DepLoad << "\n STORE: " << *SI << '\n'); - - RemoveDeadInstAndUpdateBBI(SI); - continue; - } - } - - // Remove null stores into the calloc'ed objects - Constant *StoredConstant = dyn_cast<Constant>(SI->getValueOperand()); - - if (StoredConstant && StoredConstant->isNullValue() && - isRemovable(SI)) { - Instruction *UnderlyingPointer = dyn_cast<Instruction>( - GetUnderlyingObject(SI->getPointerOperand(), DL)); - - if (UnderlyingPointer && isCallocLikeFn(UnderlyingPointer, TLI) && - MemoryIsNotModifiedBetween(UnderlyingPointer, SI)) { - DEBUG(dbgs() - << "DSE: Remove null store to the calloc'ed object:\n DEAD: " - << *Inst << "\n OBJECT: " << *UnderlyingPointer << '\n'); - - RemoveDeadInstAndUpdateBBI(SI); - continue; - } - } - } - - MemDepResult InstDep = MD->getDependency(Inst); - - // Ignore any store where we can't find a local dependence. - // FIXME: cross-block DSE would be fun. :) - if (!InstDep.isDef() && !InstDep.isClobber()) - continue; - - // Figure out what location is being stored to. - MemoryLocation Loc = getLocForWrite(Inst, *AA); - - // If we didn't get a useful location, fail. - if (!Loc.Ptr) - continue; - - while (InstDep.isDef() || InstDep.isClobber()) { - // Get the memory clobbered by the instruction we depend on. MemDep will - // skip any instructions that 'Loc' clearly doesn't interact with. If we - // end up depending on a may- or must-aliased load, then we can't optimize - // away the store and we bail out. However, if we depend on on something - // that overwrites the memory location we *can* potentially optimize it. - // - // Find out what memory location the dependent instruction stores. - Instruction *DepWrite = InstDep.getInst(); - MemoryLocation DepLoc = getLocForWrite(DepWrite, *AA); - // If we didn't get a useful location, or if it isn't a size, bail out. 
- if (!DepLoc.Ptr) - break; - - // If we find a write that is a) removable (i.e., non-volatile), b) is - // completely obliterated by the store to 'Loc', and c) which we know that - // 'Inst' doesn't load from, then we can remove it. - if (isRemovable(DepWrite) && - !isPossibleSelfRead(Inst, Loc, DepWrite, *TLI, *AA)) { - int64_t InstWriteOffset, DepWriteOffset; - OverwriteResult OR = - isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset); - if (OR == OverwriteComplete) { - DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " - << *DepWrite << "\n KILLER: " << *Inst << '\n'); - - // Delete the store and now-dead instructions that feed it. - DeleteDeadInstruction(DepWrite, *MD, *TLI); - ++NumFastStores; - MadeChange = true; - - // DeleteDeadInstruction can delete the current instruction in loop - // cases, reset BBI. - BBI = Inst->getIterator(); - if (BBI != BB.begin()) - --BBI; - break; - } else if (OR == OverwriteEnd && isShortenable(DepWrite)) { - // TODO: base this on the target vector size so that if the earlier - // store was too small to get vector writes anyway then its likely - // a good idea to shorten it - // Power of 2 vector writes are probably always a bad idea to optimize - // as any store/memset/memcpy is likely using vector instructions so - // shortening it to not vector size is likely to be slower - MemIntrinsic* DepIntrinsic = cast<MemIntrinsic>(DepWrite); - unsigned DepWriteAlign = DepIntrinsic->getAlignment(); - if (llvm::isPowerOf2_64(InstWriteOffset) || - ((DepWriteAlign != 0) && InstWriteOffset % DepWriteAlign == 0)) { - - DEBUG(dbgs() << "DSE: Remove Dead Store:\n OW END: " - << *DepWrite << "\n KILLER (offset " - << InstWriteOffset << ", " - << DepLoc.Size << ")" - << *Inst << '\n'); - - Value* DepWriteLength = DepIntrinsic->getLength(); - Value* TrimmedLength = ConstantInt::get(DepWriteLength->getType(), - InstWriteOffset - - DepWriteOffset); - DepIntrinsic->setLength(TrimmedLength); - MadeChange = true; - } - } - } - - // If this is a may-aliased store that is clobbering the store value, we - // can keep searching past it for another must-aliased pointer that stores - // to the same location. For example, in: - // store -> P - // store -> Q - // store -> P - // we can remove the first store to P even though we don't know if P and Q - // alias. - if (DepWrite == &BB.front()) break; - - // Can't look past this instruction if it might read 'Loc'. - if (AA->getModRefInfo(DepWrite, Loc) & MRI_Ref) - break; - - InstDep = MD->getPointerDependencyFrom(Loc, false, - DepWrite->getIterator(), &BB); - } - } - - // If this block ends in a return, unwind, or unreachable, all allocas are - // dead at its end, which means stores to them are also dead. - if (BB.getTerminator()->getNumSuccessors() == 0) - MadeChange |= handleEndBlock(BB); - - return MadeChange; -} - /// Returns true if the memory which is accessed by the second instruction is not /// modified between the first and the second instruction. /// Precondition: Second instruction must be dominated by the first /// instruction. -bool DSE::MemoryIsNotModifiedBetween(Instruction *FirstI, - Instruction *SecondI) { +static bool memoryIsNotModifiedBetween(Instruction *FirstI, + Instruction *SecondI, + AliasAnalysis *AA) { SmallVector<BasicBlock *, 16> WorkList; SmallPtrSet<BasicBlock *, 8> Visited; BasicBlock::iterator FirstBBI(FirstI); @@ -718,7 +567,7 @@ bool DSE::MemoryIsNotModifiedBetween(Instruction *FirstI, /// Find all blocks that will unconditionally lead to the block BB and append /// them to F. 
-static void FindUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks, +static void findUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks, BasicBlock *BB, DominatorTree *DT) { for (pred_iterator I = pred_begin(BB), E = pred_end(BB); I != E; ++I) { BasicBlock *Pred = *I; @@ -732,9 +581,11 @@ static void FindUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks, } } -/// HandleFree - Handle frees of entire structures whose dependency is a store +/// Handle frees of entire structures whose dependency is a store /// to a field of that structure. -bool DSE::HandleFree(CallInst *F) { +static bool handleFree(CallInst *F, AliasAnalysis *AA, + MemoryDependenceResults *MD, DominatorTree *DT, + const TargetLibraryInfo *TLI) { bool MadeChange = false; MemoryLocation Loc = MemoryLocation(F->getOperand(0)); @@ -761,10 +612,9 @@ bool DSE::HandleFree(CallInst *F) { if (!AA->isMustAlias(F->getArgOperand(0), DepPointer)) break; - auto Next = ++Dependency->getIterator(); - - // DCE instructions only used to calculate that store - DeleteDeadInstruction(Dependency, *MD, *TLI); + // DCE instructions only used to calculate that store. + BasicBlock::iterator BBI(Dependency); + deleteDeadInstruction(Dependency, &BBI, *MD, *TLI); ++NumFastStores; MadeChange = true; @@ -773,23 +623,53 @@ bool DSE::HandleFree(CallInst *F) { // s[0] = 0; // s[1] = 0; // This has just been deleted. // free(s); - Dep = MD->getPointerDependencyFrom(Loc, false, Next, BB); + Dep = MD->getPointerDependencyFrom(Loc, false, BBI, BB); } if (Dep.isNonLocal()) - FindUnconditionalPreds(Blocks, BB, DT); + findUnconditionalPreds(Blocks, BB, DT); } return MadeChange; } -/// handleEndBlock - Remove dead stores to stack-allocated locations in the -/// function end block. Ex: +/// Check to see if the specified location may alias any of the stack objects in +/// the DeadStackObjects set. If so, they become live because the location is +/// being loaded. +static void removeAccessedObjects(const MemoryLocation &LoadedLoc, + SmallSetVector<Value *, 16> &DeadStackObjects, + const DataLayout &DL, AliasAnalysis *AA, + const TargetLibraryInfo *TLI) { + const Value *UnderlyingPointer = GetUnderlyingObject(LoadedLoc.Ptr, DL); + + // A constant can't be in the dead pointer set. + if (isa<Constant>(UnderlyingPointer)) + return; + + // If the kill pointer can be easily reduced to an alloca, don't bother doing + // extraneous AA queries. + if (isa<AllocaInst>(UnderlyingPointer) || isa<Argument>(UnderlyingPointer)) { + DeadStackObjects.remove(const_cast<Value*>(UnderlyingPointer)); + return; + } + + // Remove objects that could alias LoadedLoc. + DeadStackObjects.remove_if([&](Value *I) { + // See if the loaded location could alias the stack location. + MemoryLocation StackLoc(I, getPointerSize(I, DL, *TLI)); + return !AA->isNoAlias(StackLoc, LoadedLoc); + }); +} + +/// Remove dead stores to stack-allocated locations in the function end block. +/// Ex: /// %A = alloca i32 /// ... /// store i32 1, i32* %A /// ret void -bool DSE::handleEndBlock(BasicBlock &BB) { +static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, + MemoryDependenceResults *MD, + const TargetLibraryInfo *TLI) { bool MadeChange = false; // Keep track of all of the stack objects that are dead at the end of the @@ -828,15 +708,14 @@ bool DSE::handleEndBlock(BasicBlock &BB) { // Stores to stack values are valid candidates for removal. 
bool AllDead = true; - for (SmallVectorImpl<Value *>::iterator I = Pointers.begin(), - E = Pointers.end(); I != E; ++I) - if (!DeadStackObjects.count(*I)) { + for (Value *Pointer : Pointers) + if (!DeadStackObjects.count(Pointer)) { AllDead = false; break; } if (AllDead) { - Instruction *Dead = &*BBI++; + Instruction *Dead = &*BBI; DEBUG(dbgs() << "DSE: Dead Store at End of Block:\n DEAD: " << *Dead << "\n Objects: "; @@ -849,7 +728,7 @@ bool DSE::handleEndBlock(BasicBlock &BB) { dbgs() << '\n'); // DCE instructions only used to calculate that store. - DeleteDeadInstruction(Dead, *MD, *TLI, &DeadStackObjects); + deleteDeadInstruction(Dead, &BBI, *MD, *TLI, &DeadStackObjects); ++NumFastStores; MadeChange = true; continue; @@ -858,8 +737,7 @@ bool DSE::handleEndBlock(BasicBlock &BB) { // Remove any dead non-memory-mutating instructions. if (isInstructionTriviallyDead(&*BBI, TLI)) { - Instruction *Inst = &*BBI++; - DeleteDeadInstruction(Inst, *MD, *TLI, &DeadStackObjects); + deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, &DeadStackObjects); ++NumFastOther; MadeChange = true; continue; @@ -873,7 +751,7 @@ bool DSE::handleEndBlock(BasicBlock &BB) { } if (auto CS = CallSite(&*BBI)) { - // Remove allocation function calls from the list of dead stack objects; + // Remove allocation function calls from the list of dead stack objects; // there can't be any references before the definition. if (isAllocLikeFn(&*BBI, TLI)) DeadStackObjects.remove(&*BBI); @@ -900,6 +778,14 @@ bool DSE::handleEndBlock(BasicBlock &BB) { continue; } + // We can remove the dead stores, irrespective of the fence and its ordering + // (release/acquire/seq_cst). Fences only constraints the ordering of + // already visible stores, it does not make a store visible to other + // threads. So, skipping over a fence does not change a store from being + // dead. + if (isa<FenceInst>(*BBI)) + continue; + MemoryLocation LoadedLoc; // If we encounter a use of the pointer, it is no longer considered dead @@ -922,7 +808,7 @@ bool DSE::handleEndBlock(BasicBlock &BB) { // Remove any allocas from the DeadPointer set that are loaded, as this // makes any stores above the access live. - RemoveAccessedObjects(LoadedLoc, DeadStackObjects, DL); + removeAccessedObjects(LoadedLoc, DeadStackObjects, DL, AA, TLI); // If all of the allocas were clobbered by the access then we're not going // to find anything else to process. @@ -933,29 +819,285 @@ bool DSE::handleEndBlock(BasicBlock &BB) { return MadeChange; } -/// RemoveAccessedObjects - Check to see if the specified location may alias any -/// of the stack objects in the DeadStackObjects set. If so, they become live -/// because the location is being loaded. -void DSE::RemoveAccessedObjects(const MemoryLocation &LoadedLoc, - SmallSetVector<Value *, 16> &DeadStackObjects, - const DataLayout &DL) { - const Value *UnderlyingPointer = GetUnderlyingObject(LoadedLoc.Ptr, DL); +static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, + AliasAnalysis *AA, MemoryDependenceResults *MD, + const DataLayout &DL, + const TargetLibraryInfo *TLI) { + // Must be a store instruction. + StoreInst *SI = dyn_cast<StoreInst>(Inst); + if (!SI) + return false; - // A constant can't be in the dead pointer set. - if (isa<Constant>(UnderlyingPointer)) - return; + // If we're storing the same value back to a pointer that we just loaded from, + // then the store can be removed. 
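
For a source-level picture of the case the comment above describes (illustrative C++, not IR): the store writes back the value just loaded, so it is a no-op, and dropping it is only legal when nothing between the load and the store can write through the pointer, which is what memoryIsNotModifiedBetween verifies.

#include <cassert>

// Stores the value just loaded straight back; when nothing in between can
// modify *p, the final store is redundant and DSE deletes it.
int reload_then_store(int *p) {
  int x = *p;    // load
  int y = x + 1; // unrelated work that does not touch *p
  *p = x;        // dead: *p already holds x
  return y;
}

int main() {
  int v = 41;
  assert(reload_then_store(&v) == 42);
  assert(v == 41); // with or without the dead store, *p is unchanged
  return 0;
}
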
+ if (LoadInst *DepLoad = dyn_cast<LoadInst>(SI->getValueOperand())) { + if (SI->getPointerOperand() == DepLoad->getPointerOperand() && + isRemovable(SI) && memoryIsNotModifiedBetween(DepLoad, SI, AA)) { - // If the kill pointer can be easily reduced to an alloca, don't bother doing - // extraneous AA queries. - if (isa<AllocaInst>(UnderlyingPointer) || isa<Argument>(UnderlyingPointer)) { - DeadStackObjects.remove(const_cast<Value*>(UnderlyingPointer)); - return; + DEBUG(dbgs() << "DSE: Remove Store Of Load from same pointer:\n LOAD: " + << *DepLoad << "\n STORE: " << *SI << '\n'); + + deleteDeadInstruction(SI, &BBI, *MD, *TLI); + ++NumRedundantStores; + return true; + } } - // Remove objects that could alias LoadedLoc. - DeadStackObjects.remove_if([&](Value *I) { - // See if the loaded location could alias the stack location. - MemoryLocation StackLoc(I, getPointerSize(I, DL, *TLI)); - return !AA->isNoAlias(StackLoc, LoadedLoc); - }); + // Remove null stores into the calloc'ed objects + Constant *StoredConstant = dyn_cast<Constant>(SI->getValueOperand()); + if (StoredConstant && StoredConstant->isNullValue() && isRemovable(SI)) { + Instruction *UnderlyingPointer = + dyn_cast<Instruction>(GetUnderlyingObject(SI->getPointerOperand(), DL)); + + if (UnderlyingPointer && isCallocLikeFn(UnderlyingPointer, TLI) && + memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA)) { + DEBUG( + dbgs() << "DSE: Remove null store to the calloc'ed object:\n DEAD: " + << *Inst << "\n OBJECT: " << *UnderlyingPointer << '\n'); + + deleteDeadInstruction(SI, &BBI, *MD, *TLI); + ++NumRedundantStores; + return true; + } + } + return false; +} + +static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, + MemoryDependenceResults *MD, DominatorTree *DT, + const TargetLibraryInfo *TLI) { + const DataLayout &DL = BB.getModule()->getDataLayout(); + bool MadeChange = false; + + // A map of interval maps representing partially-overwritten value parts. + InstOverlapIntervalsTy IOL; + + // Do a top-down walk on the BB. + for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) { + // Handle 'free' calls specially. + if (CallInst *F = isFreeCall(&*BBI, TLI)) { + MadeChange |= handleFree(F, AA, MD, DT, TLI); + // Increment BBI after handleFree has potentially deleted instructions. + // This ensures we maintain a valid iterator. + ++BBI; + continue; + } + + Instruction *Inst = &*BBI++; + + // Check to see if Inst writes to memory. If not, continue. + if (!hasMemoryWrite(Inst, *TLI)) + continue; + + // eliminateNoopStore will update in iterator, if necessary. + if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI)) { + MadeChange = true; + continue; + } + + // If we find something that writes memory, get its memory dependence. + MemDepResult InstDep = MD->getDependency(Inst); + + // Ignore any store where we can't find a local dependence. + // FIXME: cross-block DSE would be fun. :) + if (!InstDep.isDef() && !InstDep.isClobber()) + continue; + + // Figure out what location is being stored to. + MemoryLocation Loc = getLocForWrite(Inst, *AA); + + // If we didn't get a useful location, fail. + if (!Loc.Ptr) + continue; + + while (InstDep.isDef() || InstDep.isClobber()) { + // Get the memory clobbered by the instruction we depend on. MemDep will + // skip any instructions that 'Loc' clearly doesn't interact with. If we + // end up depending on a may- or must-aliased load, then we can't optimize + // away the store and we bail out. 
However, if we depend on something + // that overwrites the memory location we *can* potentially optimize it. + // + // Find out what memory location the dependent instruction stores. + Instruction *DepWrite = InstDep.getInst(); + MemoryLocation DepLoc = getLocForWrite(DepWrite, *AA); + // If we didn't get a useful location, or if it isn't a size, bail out. + if (!DepLoc.Ptr) + break; + + // If we find a write that is a) removable (i.e., non-volatile), b) is + // completely obliterated by the store to 'Loc', and c) which we know that + // 'Inst' doesn't load from, then we can remove it. + if (isRemovable(DepWrite) && + !isPossibleSelfRead(Inst, Loc, DepWrite, *TLI, *AA)) { + int64_t InstWriteOffset, DepWriteOffset; + OverwriteResult OR = + isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset, + DepWrite, IOL); + if (OR == OverwriteComplete) { + DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " + << *DepWrite << "\n KILLER: " << *Inst << '\n'); + + // Delete the store and now-dead instructions that feed it. + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI); + ++NumFastStores; + MadeChange = true; + + // We erased DepWrite; start over. + InstDep = MD->getDependency(Inst); + continue; + } else if ((OR == OverwriteEnd && isShortenableAtTheEnd(DepWrite)) || + ((OR == OverwriteBegin && + isShortenableAtTheBeginning(DepWrite)))) { + // TODO: base this on the target vector size so that if the earlier + // store was too small to get vector writes anyway then its likely + // a good idea to shorten it + // Power of 2 vector writes are probably always a bad idea to optimize + // as any store/memset/memcpy is likely using vector instructions so + // shortening it to not vector size is likely to be slower + MemIntrinsic *DepIntrinsic = cast<MemIntrinsic>(DepWrite); + unsigned DepWriteAlign = DepIntrinsic->getAlignment(); + bool IsOverwriteEnd = (OR == OverwriteEnd); + if (!IsOverwriteEnd) + InstWriteOffset = int64_t(InstWriteOffset + Loc.Size); + + if ((llvm::isPowerOf2_64(InstWriteOffset) && + DepWriteAlign <= InstWriteOffset) || + ((DepWriteAlign != 0) && InstWriteOffset % DepWriteAlign == 0)) { + + DEBUG(dbgs() << "DSE: Remove Dead Store:\n OW " + << (IsOverwriteEnd ? "END" : "BEGIN") << ": " + << *DepWrite << "\n KILLER (offset " + << InstWriteOffset << ", " << DepLoc.Size << ")" + << *Inst << '\n'); + + int64_t NewLength = + IsOverwriteEnd + ? InstWriteOffset - DepWriteOffset + : DepLoc.Size - (InstWriteOffset - DepWriteOffset); + + Value *DepWriteLength = DepIntrinsic->getLength(); + Value *TrimmedLength = + ConstantInt::get(DepWriteLength->getType(), NewLength); + DepIntrinsic->setLength(TrimmedLength); + + if (!IsOverwriteEnd) { + int64_t OffsetMoved = (InstWriteOffset - DepWriteOffset); + Value *Indices[1] = { + ConstantInt::get(DepWriteLength->getType(), OffsetMoved)}; + GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds( + DepIntrinsic->getRawDest(), Indices, "", DepWrite); + DepIntrinsic->setDest(NewDestGEP); + } + MadeChange = true; + } + } + } + + // If this is a may-aliased store that is clobbering the store value, we + // can keep searching past it for another must-aliased pointer that stores + // to the same location. For example, in: + // store -> P + // store -> Q + // store -> P + // we can remove the first store to P even though we don't know if P and Q + // alias. + if (DepWrite == &BB.front()) break; + + // Can't look past this instruction if it might read 'Loc'. 
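// A sketch of the OverwriteBegin shortening performed above, assuming the
// pointer is sufficiently aligned so the power-of-two/alignment checks pass:
#include <string.h>

void init(int *p) {
  memset(p, 0, 32);   /* earlier write of 32 bytes                   */
  p[0] = 1;           /* later store completely covers bytes [0, 4)  */
}
/* The memset's destination can be advanced by four bytes and its length
   trimmed from 32 to 28, since its first four bytes are now dead.       */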
+ if (AA->getModRefInfo(DepWrite, Loc) & MRI_Ref) + break; + + InstDep = MD->getPointerDependencyFrom(Loc, false, + DepWrite->getIterator(), &BB); + } + } + + // If this block ends in a return, unwind, or unreachable, all allocas are + // dead at its end, which means stores to them are also dead. + if (BB.getTerminator()->getNumSuccessors() == 0) + MadeChange |= handleEndBlock(BB, AA, MD, TLI); + + return MadeChange; +} + +static bool eliminateDeadStores(Function &F, AliasAnalysis *AA, + MemoryDependenceResults *MD, DominatorTree *DT, + const TargetLibraryInfo *TLI) { + bool MadeChange = false; + for (BasicBlock &BB : F) + // Only check non-dead blocks. Dead blocks may have strange pointer + // cycles that will confuse alias analysis. + if (DT->isReachableFromEntry(&BB)) + MadeChange |= eliminateDeadStores(BB, AA, MD, DT, TLI); + return MadeChange; +} + +//===----------------------------------------------------------------------===// +// DSE Pass +//===----------------------------------------------------------------------===// +PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { + AliasAnalysis *AA = &AM.getResult<AAManager>(F); + DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); + MemoryDependenceResults *MD = &AM.getResult<MemoryDependenceAnalysis>(F); + const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F); + + if (!eliminateDeadStores(F, AA, MD, DT, TLI)) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<GlobalsAA>(); + PA.preserve<MemoryDependenceAnalysis>(); + return PA; +} + +namespace { +/// A legacy pass for the legacy pass manager that wraps \c DSEPass. +class DSELegacyPass : public FunctionPass { +public: + DSELegacyPass() : FunctionPass(ID) { + initializeDSELegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + MemoryDependenceResults *MD = + &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + + return eliminateDeadStores(F, AA, MD, DT, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<MemoryDependenceWrapperPass>(); + } + + static char ID; // Pass identification, replacement for typeid +}; +} // end anonymous namespace + +char DSELegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false, + false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false, + false) + +FunctionPass *llvm::createDeadStoreEliminationPass() { + return new DSELegacyPass(); } diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index 7ef062e71ff3..9d0ef42e0396 100644 --- 
a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -16,8 +16,8 @@ #include "llvm/ADT/Hashing.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -40,6 +40,7 @@ using namespace llvm::PatternMatch; STATISTIC(NumSimplify, "Number of instructions simplified or DCE'd"); STATISTIC(NumCSE, "Number of instructions CSE'd"); +STATISTIC(NumCSECVP, "Number of compare instructions CVP'd"); STATISTIC(NumCSELoad, "Number of load instructions CSE'd"); STATISTIC(NumCSECall, "Number of call instructions CSE'd"); STATISTIC(NumDSE, "Number of trivial dead stores removed"); @@ -97,15 +98,6 @@ unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { if (BinOp->isCommutative() && BinOp->getOperand(0) > BinOp->getOperand(1)) std::swap(LHS, RHS); - if (isa<OverflowingBinaryOperator>(BinOp)) { - // Hash the overflow behavior - unsigned Overflow = - BinOp->hasNoSignedWrap() * OverflowingBinaryOperator::NoSignedWrap | - BinOp->hasNoUnsignedWrap() * - OverflowingBinaryOperator::NoUnsignedWrap; - return hash_combine(BinOp->getOpcode(), Overflow, LHS, RHS); - } - return hash_combine(BinOp->getOpcode(), LHS, RHS); } @@ -152,7 +144,7 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { if (LHSI->getOpcode() != RHSI->getOpcode()) return false; - if (LHSI->isIdenticalTo(RHSI)) + if (LHSI->isIdenticalToWhenDefined(RHSI)) return true; // If we're not strictly identical, we still might be a commutable instruction @@ -164,15 +156,6 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { "same opcode, but different instruction type?"); BinaryOperator *RHSBinOp = cast<BinaryOperator>(RHSI); - // Check overflow attributes - if (isa<OverflowingBinaryOperator>(LHSBinOp)) { - assert(isa<OverflowingBinaryOperator>(RHSBinOp) && - "same opcode, but different operator type?"); - if (LHSBinOp->hasNoUnsignedWrap() != RHSBinOp->hasNoUnsignedWrap() || - LHSBinOp->hasNoSignedWrap() != RHSBinOp->hasNoSignedWrap()) - return false; - } - // Commuted equality return LHSBinOp->getOperand(0) == RHSBinOp->getOperand(1) && LHSBinOp->getOperand(1) == RHSBinOp->getOperand(0); @@ -296,16 +279,18 @@ public: /// present the table; it is the responsibility of the consumer to inspect /// the atomicity/volatility if needed. struct LoadValue { - Value *Data; + Instruction *DefInst; unsigned Generation; int MatchingId; bool IsAtomic; + bool IsInvariant; LoadValue() - : Data(nullptr), Generation(0), MatchingId(-1), IsAtomic(false) {} - LoadValue(Value *Data, unsigned Generation, unsigned MatchingId, - bool IsAtomic) - : Data(Data), Generation(Generation), MatchingId(MatchingId), - IsAtomic(IsAtomic) {} + : DefInst(nullptr), Generation(0), MatchingId(-1), IsAtomic(false), + IsInvariant(false) {} + LoadValue(Instruction *Inst, unsigned Generation, unsigned MatchingId, + bool IsAtomic, bool IsInvariant) + : DefInst(Inst), Generation(Generation), MatchingId(MatchingId), + IsAtomic(IsAtomic), IsInvariant(IsInvariant) {} }; typedef RecyclingAllocator<BumpPtrAllocator, ScopedHashTableVal<Value *, LoadValue>> @@ -318,7 +303,8 @@ public: /// values. /// /// It uses the same generation count as loads. 
- typedef ScopedHashTable<CallValue, std::pair<Value *, unsigned>> CallHTType; + typedef ScopedHashTable<CallValue, std::pair<Instruction *, unsigned>> + CallHTType; CallHTType AvailableCalls; /// \brief This is the current generation of the memory value. @@ -354,7 +340,7 @@ private: // Contains all the needed information to create a stack for doing a depth // first tranversal of the tree. This includes scopes for values, loads, and // calls as well as the generation. There is a child iterator so that the - // children do not need to be store spearately. + // children do not need to be store separately. class StackNode { public: StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads, @@ -446,7 +432,12 @@ private: return true; } - + bool isInvariantLoad() const { + if (auto *LI = dyn_cast<LoadInst>(Inst)) + return LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr; + return false; + } + bool isMatchingMemLoc(const ParseMemoryInst &Inst) const { return (getPointerOperand() == Inst.getPointerOperand() && getMatchingId() == Inst.getMatchingId()); @@ -500,6 +491,7 @@ private: } bool EarlyCSE::processNode(DomTreeNode *Node) { + bool Changed = false; BasicBlock *BB = Node->getBlock(); // If this block has a single predecessor, then the predecessor is the parent @@ -513,7 +505,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // If this node has a single predecessor which ends in a conditional branch, // we can infer the value of the branch condition given that we took this - // path. We need the single predeccesor to ensure there's not another path + // path. We need the single predecessor to ensure there's not another path // which reaches this block where the condition might hold a different // value. Since we're adding this to the scoped hash table (like any other // def), it will have been popped if we encounter a future merge block. @@ -530,9 +522,13 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '" << CondInst->getName() << "' as " << *ConditionalConstant << " in " << BB->getName() << "\n"); - // Replace all dominated uses with the known value - replaceDominatedUsesWith(CondInst, ConditionalConstant, DT, - BasicBlockEdge(Pred, BB)); + // Replace all dominated uses with the known value. + if (unsigned Count = + replaceDominatedUsesWith(CondInst, ConditionalConstant, DT, + BasicBlockEdge(Pred, BB))) { + Changed = true; + NumCSECVP = NumCSECVP + Count; + } } /// LastStore - Keep track of the last non-volatile store that we saw... for @@ -541,7 +537,6 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { /// stores which can occur in bitfield code among other things. Instruction *LastStore = nullptr; - bool Changed = false; const DataLayout &DL = BB->getModule()->getDataLayout(); // See if any instructions in the block can be eliminated. If so, do it. If @@ -567,15 +562,38 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { continue; } + if (match(Inst, m_Intrinsic<Intrinsic::experimental_guard>())) { + if (auto *CondI = + dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0))) { + // The condition we're on guarding here is true for all dominated + // locations. + if (SimpleValue::canHandle(CondI)) + AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext())); + } + + // Guard intrinsics read all memory, but don't write any memory. + // Accordingly, don't update the generation but consume the last store (to + // avoid an incorrect DSE). 
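// A rough source-level sketch of the branch-condition handling above; the
// guard-intrinsic case is analogous, since execution only continues past a
// guard when its condition held. Assumes the comparison stays a single value
// feeding both uses:
int pick(int x, int n, int a, int b) {
  bool c = (x < n);
  if (c)
    return c ? a : b;   /* dominated use of 'c': folded to true, so 'a' */
  return b;
}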
+ LastStore = nullptr; + continue; + } + // If the instruction can be simplified (e.g. X+0 = X) then replace it with // its simpler value. if (Value *V = SimplifyInstruction(Inst, DL, &TLI, &DT, &AC)) { DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V << '\n'); - Inst->replaceAllUsesWith(V); - Inst->eraseFromParent(); - Changed = true; - ++NumSimplify; - continue; + if (!Inst->use_empty()) { + Inst->replaceAllUsesWith(V); + Changed = true; + } + if (isInstructionTriviallyDead(Inst, &TLI)) { + Inst->eraseFromParent(); + Changed = true; + } + if (Changed) { + ++NumSimplify; + continue; + } } // If this is a simple instruction that we can value number, process it. @@ -583,6 +601,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // See if the instruction has an available value. If so, use it. if (Value *V = AvailableValues.lookup(Inst)) { DEBUG(dbgs() << "EarlyCSE CSE: " << *Inst << " to: " << *V << '\n'); + if (auto *I = dyn_cast<Instruction>(V)) + I->andIRFlags(Inst); Inst->replaceAllUsesWith(V); Inst->eraseFromParent(); Changed = true; @@ -606,18 +626,25 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { } // If we have an available version of this load, and if it is the right - // generation, replace this instruction. + // generation or the load is known to be from an invariant location, + // replace this instruction. + // + // A dominating invariant load implies that the location loaded from is + // unchanging beginning at the point of the invariant load, so the load + // we're CSE'ing _away_ does not need to be invariant, only the available + // load we're CSE'ing _to_ does. LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand()); - if (InVal.Data != nullptr && InVal.Generation == CurrentGeneration && + if (InVal.DefInst != nullptr && + (InVal.Generation == CurrentGeneration || InVal.IsInvariant) && InVal.MatchingId == MemInst.getMatchingId() && // We don't yet handle removing loads with ordering of any kind. !MemInst.isVolatile() && MemInst.isUnordered() && // We can't replace an atomic load with one which isn't also atomic. InVal.IsAtomic >= MemInst.isAtomic()) { - Value *Op = getOrCreateResult(InVal.Data, Inst->getType()); + Value *Op = getOrCreateResult(InVal.DefInst, Inst->getType()); if (Op != nullptr) { DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst - << " to: " << *InVal.Data << '\n'); + << " to: " << *InVal.DefInst << '\n'); if (!Inst->use_empty()) Inst->replaceAllUsesWith(Op); Inst->eraseFromParent(); @@ -631,7 +658,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { AvailableLoads.insert( MemInst.getPointerOperand(), LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(), - MemInst.isAtomic())); + MemInst.isAtomic(), MemInst.isInvariantLoad())); LastStore = nullptr; continue; } @@ -649,7 +676,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { if (CallValue::canHandle(Inst)) { // If we have an available version of this call, and if it is the right // generation, replace this instruction. - std::pair<Value *, unsigned> InVal = AvailableCalls.lookup(Inst); + std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(Inst); if (InVal.first != nullptr && InVal.second == CurrentGeneration) { DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst << " to: " << *InVal.first << '\n'); @@ -663,7 +690,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // Otherwise, remember that we have this instruction. 
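// A sketch of the call-CSE case recorded here, assuming the callee is known
// via TargetLibraryInfo to only read memory and no store intervenes (so the
// generation still matches):
#include <string.h>

size_t twice(const char *s) {
  size_t a = strlen(s);
  size_t b = strlen(s);   /* same arguments, same generation: replaced by 'a' */
  return a + b;
}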
AvailableCalls.insert( - Inst, std::pair<Value *, unsigned>(Inst, CurrentGeneration)); + Inst, std::pair<Instruction *, unsigned>(Inst, CurrentGeneration)); continue; } @@ -673,7 +700,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // to advance the generation. We do need to prevent DSE across the fence, // but that's handled above. if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) - if (FI->getOrdering() == Release) { + if (FI->getOrdering() == AtomicOrdering::Release) { assert(Inst->mayReadFromMemory() && "relied on to prevent DSE above"); continue; } @@ -685,8 +712,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // the store originally was. if (MemInst.isValid() && MemInst.isStore()) { LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand()); - if (InVal.Data && - InVal.Data == getOrCreateResult(Inst, InVal.Data->getType()) && + if (InVal.DefInst && + InVal.DefInst == getOrCreateResult(Inst, InVal.DefInst->getType()) && InVal.Generation == CurrentGeneration && InVal.MatchingId == MemInst.getMatchingId() && // We don't yet handle removing stores with ordering of any kind. @@ -743,7 +770,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { AvailableLoads.insert( MemInst.getPointerOperand(), LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(), - MemInst.isAtomic())); + MemInst.isAtomic(), /*IsInvariant=*/false)); // Remember that this was the last unordered store we saw for DSE. We // don't yet handle DSE on ordered or volatile stores since we don't @@ -818,11 +845,11 @@ bool EarlyCSE::run() { } PreservedAnalyses EarlyCSEPass::run(Function &F, - AnalysisManager<Function> *AM) { - auto &TLI = AM->getResult<TargetLibraryAnalysis>(F); - auto &TTI = AM->getResult<TargetIRAnalysis>(F); - auto &DT = AM->getResult<DominatorTreeAnalysis>(F); - auto &AC = AM->getResult<AssumptionAnalysis>(F); + AnalysisManager<Function> &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); EarlyCSE CSE(TLI, TTI, DT, AC); @@ -833,6 +860,7 @@ PreservedAnalyses EarlyCSEPass::run(Function &F, // FIXME: Bundle this with other CFG-preservation. 
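// A minimal sketch of how the new entry point above might be driven, assuming
// the usual PassBuilder registration boilerplate:
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"

static void runEarlyCSE(llvm::Function &F) {
  llvm::FunctionAnalysisManager FAM;
  llvm::PassBuilder PB;
  PB.registerFunctionAnalyses(FAM);   // provides the TLI/TTI/DT/AC queried above

  llvm::FunctionPassManager FPM;
  FPM.addPass(llvm::EarlyCSEPass());
  FPM.run(F, FAM);
}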
PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<GlobalsAA>(); return PA; } @@ -853,7 +881,7 @@ public: } bool runOnFunction(Function &F) override { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); diff --git a/lib/Transforms/Scalar/Float2Int.cpp b/lib/Transforms/Scalar/Float2Int.cpp index 7f5d78656b50..7aa6dc6992b6 100644 --- a/lib/Transforms/Scalar/Float2Int.cpp +++ b/lib/Transforms/Scalar/Float2Int.cpp @@ -13,15 +13,13 @@ //===----------------------------------------------------------------------===// #define DEBUG_TYPE "float2int" + +#include "llvm/Transforms/Scalar/Float2Int.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/EquivalenceClasses.h" -#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" @@ -53,41 +51,31 @@ MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden, "(default=64)")); namespace { - struct Float2Int : public FunctionPass { + struct Float2IntLegacyPass : public FunctionPass { static char ID; // Pass identification, replacement for typeid - Float2Int() : FunctionPass(ID) { - initializeFloat2IntPass(*PassRegistry::getPassRegistry()); + Float2IntLegacyPass() : FunctionPass(ID) { + initializeFloat2IntLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + return Impl.runImpl(F); } - bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addPreserved<GlobalsAAWrapperPass>(); } - void findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots); - ConstantRange seen(Instruction *I, ConstantRange R); - ConstantRange badRange(); - ConstantRange unknownRange(); - ConstantRange validateRange(ConstantRange R); - void walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots); - void walkForwards(); - bool validateAndTransform(); - Value *convert(Instruction *I, Type *ToTy); - void cleanup(); - - MapVector<Instruction*, ConstantRange > SeenInsts; - SmallPtrSet<Instruction*,8> Roots; - EquivalenceClasses<Instruction*> ECs; - MapVector<Instruction*, Value*> ConvertedInsts; - LLVMContext *Ctx; + private: + Float2IntPass Impl; }; } -char Float2Int::ID = 0; -INITIALIZE_PASS_BEGIN(Float2Int, "float2int", "Float to int", false, false) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_END(Float2Int, "float2int", "Float to int", false, false) +char Float2IntLegacyPass::ID = 0; +INITIALIZE_PASS(Float2IntLegacyPass, "float2int", "Float to int", false, false) // Given a FCmp predicate, return a matching ICmp predicate if one // exists, otherwise return BAD_ICMP_PREDICATE. @@ -129,7 +117,7 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) { // Find the roots - instructions that convert from the FP domain to // integer domain. 
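// A sketch of the kind of code float2int rewrites, assuming the input ranges
// are narrow enough (16-bit here) that every conversion is exact and the
// computed ranges validate:
int sum_via_float(short a, short b) {
  float x = (float)a;    /* the backward walk stops at these int-to-FP casts */
  float y = (float)b;
  return (int)(x + y);   /* root (FP-to-int cast); the fadd in between can be
                            rewritten as an integer add                      */
}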
-void Float2Int::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) { +void Float2IntPass::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) { for (auto &I : instructions(F)) { if (isa<VectorType>(I.getType())) continue; @@ -149,7 +137,7 @@ void Float2Int::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) { } // Helper - mark I as having been traversed, having range R. -ConstantRange Float2Int::seen(Instruction *I, ConstantRange R) { +ConstantRange Float2IntPass::seen(Instruction *I, ConstantRange R) { DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n"); if (SeenInsts.find(I) != SeenInsts.end()) SeenInsts.find(I)->second = R; @@ -159,13 +147,13 @@ ConstantRange Float2Int::seen(Instruction *I, ConstantRange R) { } // Helper - get a range representing a poison value. -ConstantRange Float2Int::badRange() { +ConstantRange Float2IntPass::badRange() { return ConstantRange(MaxIntegerBW + 1, true); } -ConstantRange Float2Int::unknownRange() { +ConstantRange Float2IntPass::unknownRange() { return ConstantRange(MaxIntegerBW + 1, false); } -ConstantRange Float2Int::validateRange(ConstantRange R) { +ConstantRange Float2IntPass::validateRange(ConstantRange R) { if (R.getBitWidth() > MaxIntegerBW + 1) return badRange(); return R; @@ -185,7 +173,7 @@ ConstantRange Float2Int::validateRange(ConstantRange R) { // Breadth-first walk of the use-def graph; determine the set of nodes // we care about and eagerly determine if some of them are poisonous. -void Float2Int::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) { +void Float2IntPass::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) { std::deque<Instruction*> Worklist(Roots.begin(), Roots.end()); while (!Worklist.empty()) { Instruction *I = Worklist.back(); @@ -246,8 +234,8 @@ void Float2Int::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) { // Walk forwards down the list of seen instructions, so we visit defs before // uses. -void Float2Int::walkForwards() { - for (auto &It : make_range(SeenInsts.rbegin(), SeenInsts.rend())) { +void Float2IntPass::walkForwards() { + for (auto &It : reverse(SeenInsts)) { if (It.second != unknownRange()) continue; @@ -318,7 +306,7 @@ void Float2Int::walkForwards() { // Instead, we ask APFloat to round itself to an integral value - this // preserves sign-of-zero - then compare the result with the original. // - APFloat F = CF->getValueAPF(); + const APFloat &F = CF->getValueAPF(); // First, weed out obviously incorrect values. Non-finite numbers // can't be represented and neither can negative zero, unless @@ -357,7 +345,7 @@ void Float2Int::walkForwards() { } // If there is a valid transform to be done, do it. -bool Float2Int::validateAndTransform() { +bool Float2IntPass::validateAndTransform() { bool MadeChange = false; // Iterate over every disjoint partition of the def-use graph. @@ -439,7 +427,7 @@ bool Float2Int::validateAndTransform() { return MadeChange; } -Value *Float2Int::convert(Instruction *I, Type *ToTy) { +Value *Float2IntPass::convert(Instruction *I, Type *ToTy) { if (ConvertedInsts.find(I) != ConvertedInsts.end()) // Already converted this instruction. return ConvertedInsts[I]; @@ -511,15 +499,12 @@ Value *Float2Int::convert(Instruction *I, Type *ToTy) { } // Perform dead code elimination on the instructions we just modified. 
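// The sentinel ranges used above can be read as follows; a sketch assuming
// the default MaxIntegerBW of 64:
#include "llvm/IR/ConstantRange.h"
using llvm::ConstantRange;

ConstantRange Poison(65, /*isFullSet=*/true);    // badRange(): would need more than 64 bits
ConstantRange NotSeen(65, /*isFullSet=*/false);  // unknownRange(): empty, nothing recorded yet
// validateRange() collapses any range wider than 65 bits back to the poison value.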
-void Float2Int::cleanup() { - for (auto &I : make_range(ConvertedInsts.rbegin(), ConvertedInsts.rend())) +void Float2IntPass::cleanup() { + for (auto &I : reverse(ConvertedInsts)) I.first->eraseFromParent(); } -bool Float2Int::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) - return false; - +bool Float2IntPass::runImpl(Function &F) { DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n"); // Clear out all state. ECs = EquivalenceClasses<Instruction*>(); @@ -540,4 +525,17 @@ bool Float2Int::runOnFunction(Function &F) { return Modified; } -FunctionPass *llvm::createFloat2IntPass() { return new Float2Int(); } +namespace llvm { +FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); } + +PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &) { + if (!runImpl(F)) + return PreservedAnalyses::all(); + else { + // FIXME: This should also 'preserve the CFG'. + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; + } +} +} // End namespace llvm diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp index a028b8c444ba..a35a1062cbcd 100644 --- a/lib/Transforms/Scalar/GVN.cpp +++ b/lib/Transforms/Scalar/GVN.cpp @@ -15,7 +15,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Hashing.h" @@ -44,7 +44,6 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -53,6 +52,7 @@ #include "llvm/Transforms/Utils/SSAUpdater.h" #include <vector> using namespace llvm; +using namespace llvm::gvn; using namespace PatternMatch; #define DEBUG_TYPE "gvn" @@ -74,106 +74,167 @@ static cl::opt<uint32_t> MaxRecurseDepth("max-recurse-depth", cl::Hidden, cl::init(1000), cl::ZeroOrMore, cl::desc("Max recurse depth (default = 1000)")); -//===----------------------------------------------------------------------===// -// ValueTable Class -//===----------------------------------------------------------------------===// - -/// This class holds the mapping between values and value numbers. It is used -/// as an efficient mechanism to determine the expression-wise equivalence of -/// two values. 
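// A source-level sketch of what a value number buys, assuming GVN runs on the
// straight-line code below: syntactically different but equivalent expressions
// receive the same number, so the later one is replaced by the earlier one.
int g(int a, int b) {
  int x = a + b;
  int y = b + a;   /* commutative operands are sorted when the expression is
                      built, so 'y' gets the same value number as 'x' and is
                      replaced by it                                          */
  return x ^ y;
}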
-namespace { - struct Expression { - uint32_t opcode; - Type *type; - SmallVector<uint32_t, 4> varargs; +struct llvm::GVN::Expression { + uint32_t opcode; + Type *type; + SmallVector<uint32_t, 4> varargs; - Expression(uint32_t o = ~2U) : opcode(o) { } + Expression(uint32_t o = ~2U) : opcode(o) {} - bool operator==(const Expression &other) const { - if (opcode != other.opcode) - return false; - if (opcode == ~0U || opcode == ~1U) - return true; - if (type != other.type) - return false; - if (varargs != other.varargs) - return false; + bool operator==(const Expression &other) const { + if (opcode != other.opcode) + return false; + if (opcode == ~0U || opcode == ~1U) return true; - } - - friend hash_code hash_value(const Expression &Value) { - return hash_combine(Value.opcode, Value.type, - hash_combine_range(Value.varargs.begin(), - Value.varargs.end())); - } - }; + if (type != other.type) + return false; + if (varargs != other.varargs) + return false; + return true; + } - class ValueTable { - DenseMap<Value*, uint32_t> valueNumbering; - DenseMap<Expression, uint32_t> expressionNumbering; - AliasAnalysis *AA; - MemoryDependenceAnalysis *MD; - DominatorTree *DT; - - uint32_t nextValueNumber; - - Expression create_expression(Instruction* I); - Expression create_cmp_expression(unsigned Opcode, - CmpInst::Predicate Predicate, - Value *LHS, Value *RHS); - Expression create_extractvalue_expression(ExtractValueInst* EI); - uint32_t lookup_or_add_call(CallInst* C); - public: - ValueTable() : nextValueNumber(1) { } - uint32_t lookup_or_add(Value *V); - uint32_t lookup(Value *V) const; - uint32_t lookup_or_add_cmp(unsigned Opcode, CmpInst::Predicate Pred, - Value *LHS, Value *RHS); - bool exists(Value *V) const; - void add(Value *V, uint32_t num); - void clear(); - void erase(Value *v); - void setAliasAnalysis(AliasAnalysis* A) { AA = A; } - AliasAnalysis *getAliasAnalysis() const { return AA; } - void setMemDep(MemoryDependenceAnalysis* M) { MD = M; } - void setDomTree(DominatorTree* D) { DT = D; } - uint32_t getNextUnusedValueNumber() { return nextValueNumber; } - void verifyRemoved(const Value *) const; - }; -} + friend hash_code hash_value(const Expression &Value) { + return hash_combine( + Value.opcode, Value.type, + hash_combine_range(Value.varargs.begin(), Value.varargs.end())); + } +}; namespace llvm { -template <> struct DenseMapInfo<Expression> { - static inline Expression getEmptyKey() { - return ~0U; - } +template <> struct DenseMapInfo<GVN::Expression> { + static inline GVN::Expression getEmptyKey() { return ~0U; } - static inline Expression getTombstoneKey() { - return ~1U; - } + static inline GVN::Expression getTombstoneKey() { return ~1U; } - static unsigned getHashValue(const Expression e) { + static unsigned getHashValue(const GVN::Expression &e) { using llvm::hash_value; return static_cast<unsigned>(hash_value(e)); } - static bool isEqual(const Expression &LHS, const Expression &RHS) { + static bool isEqual(const GVN::Expression &LHS, const GVN::Expression &RHS) { return LHS == RHS; } }; +} // End llvm namespace. + +/// Represents a particular available value that we know how to materialize. +/// Materialization of an AvailableValue never fails. An AvailableValue is +/// implicitly associated with a rematerialization point which is the +/// location of the instruction from which it was formed. +struct llvm::gvn::AvailableValue { + enum ValType { + SimpleVal, // A simple offsetted value that is accessed. + LoadVal, // A value produced by a load. 
+ MemIntrin, // A memory intrinsic which is loaded from. + UndefVal // A UndefValue representing a value from dead block (which + // is not yet physically removed from the CFG). + }; -} + /// V - The value that is live out of the block. + PointerIntPair<Value *, 2, ValType> Val; + + /// Offset - The byte offset in Val that is interesting for the load query. + unsigned Offset; + + static AvailableValue get(Value *V, unsigned Offset = 0) { + AvailableValue Res; + Res.Val.setPointer(V); + Res.Val.setInt(SimpleVal); + Res.Offset = Offset; + return Res; + } + + static AvailableValue getMI(MemIntrinsic *MI, unsigned Offset = 0) { + AvailableValue Res; + Res.Val.setPointer(MI); + Res.Val.setInt(MemIntrin); + Res.Offset = Offset; + return Res; + } + + static AvailableValue getLoad(LoadInst *LI, unsigned Offset = 0) { + AvailableValue Res; + Res.Val.setPointer(LI); + Res.Val.setInt(LoadVal); + Res.Offset = Offset; + return Res; + } + + static AvailableValue getUndef() { + AvailableValue Res; + Res.Val.setPointer(nullptr); + Res.Val.setInt(UndefVal); + Res.Offset = 0; + return Res; + } + + bool isSimpleValue() const { return Val.getInt() == SimpleVal; } + bool isCoercedLoadValue() const { return Val.getInt() == LoadVal; } + bool isMemIntrinValue() const { return Val.getInt() == MemIntrin; } + bool isUndefValue() const { return Val.getInt() == UndefVal; } + + Value *getSimpleValue() const { + assert(isSimpleValue() && "Wrong accessor"); + return Val.getPointer(); + } + + LoadInst *getCoercedLoadValue() const { + assert(isCoercedLoadValue() && "Wrong accessor"); + return cast<LoadInst>(Val.getPointer()); + } + + MemIntrinsic *getMemIntrinValue() const { + assert(isMemIntrinValue() && "Wrong accessor"); + return cast<MemIntrinsic>(Val.getPointer()); + } + + /// Emit code at the specified insertion point to adjust the value defined + /// here to the specified type. This handles various coercion cases. + Value *MaterializeAdjustedValue(LoadInst *LI, Instruction *InsertPt, + GVN &gvn) const; +}; + +/// Represents an AvailableValue which can be rematerialized at the end of +/// the associated BasicBlock. +struct llvm::gvn::AvailableValueInBlock { + /// BB - The basic block in question. + BasicBlock *BB; + + /// AV - The actual available value + AvailableValue AV; + + static AvailableValueInBlock get(BasicBlock *BB, AvailableValue &&AV) { + AvailableValueInBlock Res; + Res.BB = BB; + Res.AV = std::move(AV); + return Res; + } + + static AvailableValueInBlock get(BasicBlock *BB, Value *V, + unsigned Offset = 0) { + return get(BB, AvailableValue::get(V, Offset)); + } + static AvailableValueInBlock getUndef(BasicBlock *BB) { + return get(BB, AvailableValue::getUndef()); + } + + /// Emit code at the end of this block to adjust the value defined here to + /// the specified type. This handles various coercion cases. 
+ Value *MaterializeAdjustedValue(LoadInst *LI, GVN &gvn) const { + return AV.MaterializeAdjustedValue(LI, BB->getTerminator(), gvn); + } +}; //===----------------------------------------------------------------------===// // ValueTable Internal Functions //===----------------------------------------------------------------------===// -Expression ValueTable::create_expression(Instruction *I) { +GVN::Expression GVN::ValueTable::createExpr(Instruction *I) { Expression e; e.type = I->getType(); e.opcode = I->getOpcode(); for (Instruction::op_iterator OI = I->op_begin(), OE = I->op_end(); OI != OE; ++OI) - e.varargs.push_back(lookup_or_add(*OI)); + e.varargs.push_back(lookupOrAdd(*OI)); if (I->isCommutative()) { // Ensure that commutative instructions that only differ by a permutation // of their operands get the same value number by sorting the operand value @@ -201,15 +262,15 @@ Expression ValueTable::create_expression(Instruction *I) { return e; } -Expression ValueTable::create_cmp_expression(unsigned Opcode, - CmpInst::Predicate Predicate, - Value *LHS, Value *RHS) { +GVN::Expression GVN::ValueTable::createCmpExpr(unsigned Opcode, + CmpInst::Predicate Predicate, + Value *LHS, Value *RHS) { assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && "Not a comparison!"); Expression e; e.type = CmpInst::makeCmpResultType(LHS->getType()); - e.varargs.push_back(lookup_or_add(LHS)); - e.varargs.push_back(lookup_or_add(RHS)); + e.varargs.push_back(lookupOrAdd(LHS)); + e.varargs.push_back(lookupOrAdd(RHS)); // Sort the operand value numbers so x<y and y>x get the same value number. if (e.varargs[0] > e.varargs[1]) { @@ -220,7 +281,7 @@ Expression ValueTable::create_cmp_expression(unsigned Opcode, return e; } -Expression ValueTable::create_extractvalue_expression(ExtractValueInst *EI) { +GVN::Expression GVN::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { assert(EI && "Not an ExtractValueInst?"); Expression e; e.type = EI->getType(); @@ -252,8 +313,8 @@ Expression ValueTable::create_extractvalue_expression(ExtractValueInst *EI) { // Intrinsic recognized. Grab its args to finish building the expression. 
assert(I->getNumArgOperands() == 2 && "Expect two args for recognised intrinsics."); - e.varargs.push_back(lookup_or_add(I->getArgOperand(0))); - e.varargs.push_back(lookup_or_add(I->getArgOperand(1))); + e.varargs.push_back(lookupOrAdd(I->getArgOperand(0))); + e.varargs.push_back(lookupOrAdd(I->getArgOperand(1))); return e; } } @@ -263,7 +324,7 @@ Expression ValueTable::create_extractvalue_expression(ExtractValueInst *EI) { e.opcode = EI->getOpcode(); for (Instruction::op_iterator OI = EI->op_begin(), OE = EI->op_end(); OI != OE; ++OI) - e.varargs.push_back(lookup_or_add(*OI)); + e.varargs.push_back(lookupOrAdd(*OI)); for (ExtractValueInst::idx_iterator II = EI->idx_begin(), IE = EI->idx_end(); II != IE; ++II) @@ -276,20 +337,32 @@ Expression ValueTable::create_extractvalue_expression(ExtractValueInst *EI) { // ValueTable External Functions //===----------------------------------------------------------------------===// +GVN::ValueTable::ValueTable() : nextValueNumber(1) {} +GVN::ValueTable::ValueTable(const ValueTable &Arg) + : valueNumbering(Arg.valueNumbering), + expressionNumbering(Arg.expressionNumbering), AA(Arg.AA), MD(Arg.MD), + DT(Arg.DT), nextValueNumber(Arg.nextValueNumber) {} +GVN::ValueTable::ValueTable(ValueTable &&Arg) + : valueNumbering(std::move(Arg.valueNumbering)), + expressionNumbering(std::move(Arg.expressionNumbering)), + AA(std::move(Arg.AA)), MD(std::move(Arg.MD)), DT(std::move(Arg.DT)), + nextValueNumber(std::move(Arg.nextValueNumber)) {} +GVN::ValueTable::~ValueTable() {} + /// add - Insert a value into the table with a specified value number. -void ValueTable::add(Value *V, uint32_t num) { +void GVN::ValueTable::add(Value *V, uint32_t num) { valueNumbering.insert(std::make_pair(V, num)); } -uint32_t ValueTable::lookup_or_add_call(CallInst *C) { +uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { if (AA->doesNotAccessMemory(C)) { - Expression exp = create_expression(C); + Expression exp = createExpr(C); uint32_t &e = expressionNumbering[exp]; if (!e) e = nextValueNumber++; valueNumbering[C] = e; return e; } else if (AA->onlyReadsMemory(C)) { - Expression exp = create_expression(C); + Expression exp = createExpr(C); uint32_t &e = expressionNumbering[exp]; if (!e) { e = nextValueNumber++; @@ -318,21 +391,21 @@ uint32_t ValueTable::lookup_or_add_call(CallInst *C) { } for (unsigned i = 0, e = C->getNumArgOperands(); i < e; ++i) { - uint32_t c_vn = lookup_or_add(C->getArgOperand(i)); - uint32_t cd_vn = lookup_or_add(local_cdep->getArgOperand(i)); + uint32_t c_vn = lookupOrAdd(C->getArgOperand(i)); + uint32_t cd_vn = lookupOrAdd(local_cdep->getArgOperand(i)); if (c_vn != cd_vn) { valueNumbering[C] = nextValueNumber; return nextValueNumber++; } } - uint32_t v = lookup_or_add(local_cdep); + uint32_t v = lookupOrAdd(local_cdep); valueNumbering[C] = v; return v; } // Non-local case. - const MemoryDependenceAnalysis::NonLocalDepInfo &deps = + const MemoryDependenceResults::NonLocalDepInfo &deps = MD->getNonLocalCallDependency(CallSite(C)); // FIXME: Move the checking logic to MemDep! 
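// A sketch of the non-local case handled here, assuming the callee is pure and
// the only write between the two calls is to a non-escaping local that the
// callee cannot read:
extern int pure_len(const char *s) __attribute__((pure));

int two_calls(const char *s, int c) {
  int a = pure_len(s);
  int local = 0;
  if (c)
    local = a * 2;       /* store to a non-escaping local: not a clobber       */
  int b = pure_len(s);   /* the dependence walk resolves to the first call, so
                            'b' receives the same value number as 'a'          */
  return a + b + local;
}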
CallInst* cdep = nullptr; @@ -372,15 +445,15 @@ uint32_t ValueTable::lookup_or_add_call(CallInst *C) { return nextValueNumber++; } for (unsigned i = 0, e = C->getNumArgOperands(); i < e; ++i) { - uint32_t c_vn = lookup_or_add(C->getArgOperand(i)); - uint32_t cd_vn = lookup_or_add(cdep->getArgOperand(i)); + uint32_t c_vn = lookupOrAdd(C->getArgOperand(i)); + uint32_t cd_vn = lookupOrAdd(cdep->getArgOperand(i)); if (c_vn != cd_vn) { valueNumbering[C] = nextValueNumber; return nextValueNumber++; } } - uint32_t v = lookup_or_add(cdep); + uint32_t v = lookupOrAdd(cdep); valueNumbering[C] = v; return v; @@ -391,11 +464,11 @@ uint32_t ValueTable::lookup_or_add_call(CallInst *C) { } /// Returns true if a value number exists for the specified value. -bool ValueTable::exists(Value *V) const { return valueNumbering.count(V) != 0; } +bool GVN::ValueTable::exists(Value *V) const { return valueNumbering.count(V) != 0; } /// lookup_or_add - Returns the value number for the specified value, assigning /// it a new number if it did not have one before. -uint32_t ValueTable::lookup_or_add(Value *V) { +uint32_t GVN::ValueTable::lookupOrAdd(Value *V) { DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V); if (VI != valueNumbering.end()) return VI->second; @@ -409,7 +482,7 @@ uint32_t ValueTable::lookup_or_add(Value *V) { Expression exp; switch (I->getOpcode()) { case Instruction::Call: - return lookup_or_add_call(cast<CallInst>(I)); + return lookupOrAddCall(cast<CallInst>(I)); case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -448,10 +521,10 @@ uint32_t ValueTable::lookup_or_add(Value *V) { case Instruction::ShuffleVector: case Instruction::InsertValue: case Instruction::GetElementPtr: - exp = create_expression(I); + exp = createExpr(I); break; case Instruction::ExtractValue: - exp = create_extractvalue_expression(cast<ExtractValueInst>(I)); + exp = createExtractvalueExpr(cast<ExtractValueInst>(I)); break; default: valueNumbering[V] = nextValueNumber; @@ -466,7 +539,7 @@ uint32_t ValueTable::lookup_or_add(Value *V) { /// Returns the value number of the specified value. Fails if /// the value has not yet been numbered. -uint32_t ValueTable::lookup(Value *V) const { +uint32_t GVN::ValueTable::lookup(Value *V) const { DenseMap<Value*, uint32_t>::const_iterator VI = valueNumbering.find(V); assert(VI != valueNumbering.end() && "Value not numbered?"); return VI->second; @@ -476,30 +549,30 @@ uint32_t ValueTable::lookup(Value *V) const { /// assigning it a new number if it did not have one before. Useful when /// we deduced the result of a comparison, but don't immediately have an /// instruction realizing that comparison to hand. -uint32_t ValueTable::lookup_or_add_cmp(unsigned Opcode, - CmpInst::Predicate Predicate, - Value *LHS, Value *RHS) { - Expression exp = create_cmp_expression(Opcode, Predicate, LHS, RHS); +uint32_t GVN::ValueTable::lookupOrAddCmp(unsigned Opcode, + CmpInst::Predicate Predicate, + Value *LHS, Value *RHS) { + Expression exp = createCmpExpr(Opcode, Predicate, LHS, RHS); uint32_t& e = expressionNumbering[exp]; if (!e) e = nextValueNumber++; return e; } /// Remove all entries from the ValueTable. -void ValueTable::clear() { +void GVN::ValueTable::clear() { valueNumbering.clear(); expressionNumbering.clear(); nextValueNumber = 1; } /// Remove a value from the value numbering. 
-void ValueTable::erase(Value *V) { +void GVN::ValueTable::erase(Value *V) { valueNumbering.erase(V); } /// verifyRemoved - Verify that the value is removed from all internal data /// structures. -void ValueTable::verifyRemoved(const Value *V) const { +void GVN::ValueTable::verifyRemoved(const Value *V) const { for (DenseMap<Value*, uint32_t>::const_iterator I = valueNumbering.begin(), E = valueNumbering.end(); I != E; ++I) { assert(I->first != V && "Inst still occurs in value numbering map!"); @@ -510,251 +583,26 @@ void ValueTable::verifyRemoved(const Value *V) const { // GVN Pass //===----------------------------------------------------------------------===// -namespace { - class GVN; - struct AvailableValueInBlock { - /// BB - The basic block in question. - BasicBlock *BB; - enum ValType { - SimpleVal, // A simple offsetted value that is accessed. - LoadVal, // A value produced by a load. - MemIntrin, // A memory intrinsic which is loaded from. - UndefVal // A UndefValue representing a value from dead block (which - // is not yet physically removed from the CFG). - }; - - /// V - The value that is live out of the block. - PointerIntPair<Value *, 2, ValType> Val; - - /// Offset - The byte offset in Val that is interesting for the load query. - unsigned Offset; - - static AvailableValueInBlock get(BasicBlock *BB, Value *V, - unsigned Offset = 0) { - AvailableValueInBlock Res; - Res.BB = BB; - Res.Val.setPointer(V); - Res.Val.setInt(SimpleVal); - Res.Offset = Offset; - return Res; - } - - static AvailableValueInBlock getMI(BasicBlock *BB, MemIntrinsic *MI, - unsigned Offset = 0) { - AvailableValueInBlock Res; - Res.BB = BB; - Res.Val.setPointer(MI); - Res.Val.setInt(MemIntrin); - Res.Offset = Offset; - return Res; - } - - static AvailableValueInBlock getLoad(BasicBlock *BB, LoadInst *LI, - unsigned Offset = 0) { - AvailableValueInBlock Res; - Res.BB = BB; - Res.Val.setPointer(LI); - Res.Val.setInt(LoadVal); - Res.Offset = Offset; - return Res; - } - - static AvailableValueInBlock getUndef(BasicBlock *BB) { - AvailableValueInBlock Res; - Res.BB = BB; - Res.Val.setPointer(nullptr); - Res.Val.setInt(UndefVal); - Res.Offset = 0; - return Res; - } - - bool isSimpleValue() const { return Val.getInt() == SimpleVal; } - bool isCoercedLoadValue() const { return Val.getInt() == LoadVal; } - bool isMemIntrinValue() const { return Val.getInt() == MemIntrin; } - bool isUndefValue() const { return Val.getInt() == UndefVal; } - - Value *getSimpleValue() const { - assert(isSimpleValue() && "Wrong accessor"); - return Val.getPointer(); - } - - LoadInst *getCoercedLoadValue() const { - assert(isCoercedLoadValue() && "Wrong accessor"); - return cast<LoadInst>(Val.getPointer()); - } - - MemIntrinsic *getMemIntrinValue() const { - assert(isMemIntrinValue() && "Wrong accessor"); - return cast<MemIntrinsic>(Val.getPointer()); - } - - /// Emit code into this block to adjust the value defined here to the - /// specified type. This handles various coercion cases. - Value *MaterializeAdjustedValue(LoadInst *LI, GVN &gvn) const; - }; - - class GVN : public FunctionPass { - bool NoLoads; - MemoryDependenceAnalysis *MD; - DominatorTree *DT; - const TargetLibraryInfo *TLI; - AssumptionCache *AC; - SetVector<BasicBlock *> DeadBlocks; - - ValueTable VN; - - /// A mapping from value numbers to lists of Value*'s that - /// have that value number. Use findLeader to query it. 
- struct LeaderTableEntry { - Value *Val; - const BasicBlock *BB; - LeaderTableEntry *Next; - }; - DenseMap<uint32_t, LeaderTableEntry> LeaderTable; - BumpPtrAllocator TableAllocator; - - // Block-local map of equivalent values to their leader, does not - // propagate to any successors. Entries added mid-block are applied - // to the remaining instructions in the block. - SmallMapVector<llvm::Value *, llvm::Constant *, 4> ReplaceWithConstMap; - SmallVector<Instruction*, 8> InstrsToErase; - - typedef SmallVector<NonLocalDepResult, 64> LoadDepVect; - typedef SmallVector<AvailableValueInBlock, 64> AvailValInBlkVect; - typedef SmallVector<BasicBlock*, 64> UnavailBlkVect; - - public: - static char ID; // Pass identification, replacement for typeid - explicit GVN(bool noloads = false) - : FunctionPass(ID), NoLoads(noloads), MD(nullptr) { - initializeGVNPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - /// This removes the specified instruction from - /// our various maps and marks it for deletion. - void markInstructionForDeletion(Instruction *I) { - VN.erase(I); - InstrsToErase.push_back(I); - } - - DominatorTree &getDominatorTree() const { return *DT; } - AliasAnalysis *getAliasAnalysis() const { return VN.getAliasAnalysis(); } - MemoryDependenceAnalysis &getMemDep() const { return *MD; } - private: - /// Push a new Value to the LeaderTable onto the list for its value number. - void addToLeaderTable(uint32_t N, Value *V, const BasicBlock *BB) { - LeaderTableEntry &Curr = LeaderTable[N]; - if (!Curr.Val) { - Curr.Val = V; - Curr.BB = BB; - return; - } - - LeaderTableEntry *Node = TableAllocator.Allocate<LeaderTableEntry>(); - Node->Val = V; - Node->BB = BB; - Node->Next = Curr.Next; - Curr.Next = Node; - } - - /// Scan the list of values corresponding to a given - /// value number, and remove the given instruction if encountered. - void removeFromLeaderTable(uint32_t N, Instruction *I, BasicBlock *BB) { - LeaderTableEntry* Prev = nullptr; - LeaderTableEntry* Curr = &LeaderTable[N]; - - while (Curr && (Curr->Val != I || Curr->BB != BB)) { - Prev = Curr; - Curr = Curr->Next; - } - - if (!Curr) - return; - - if (Prev) { - Prev->Next = Curr->Next; - } else { - if (!Curr->Next) { - Curr->Val = nullptr; - Curr->BB = nullptr; - } else { - LeaderTableEntry* Next = Curr->Next; - Curr->Val = Next->Val; - Curr->BB = Next->BB; - Curr->Next = Next->Next; - } - } - } - - // List of critical edges to be split between iterations. 
- SmallVector<std::pair<TerminatorInst*, unsigned>, 4> toSplit; - - // This transformation requires dominator postdominator info - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - if (!NoLoads) - AU.addRequired<MemoryDependenceAnalysis>(); - AU.addRequired<AAResultsWrapperPass>(); - - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } - - - // Helper functions of redundant load elimination - bool processLoad(LoadInst *L); - bool processNonLocalLoad(LoadInst *L); - bool processAssumeIntrinsic(IntrinsicInst *II); - void AnalyzeLoadAvailability(LoadInst *LI, LoadDepVect &Deps, - AvailValInBlkVect &ValuesPerBlock, - UnavailBlkVect &UnavailableBlocks); - bool PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, - UnavailBlkVect &UnavailableBlocks); - - // Other helper routines - bool processInstruction(Instruction *I); - bool processBlock(BasicBlock *BB); - void dump(DenseMap<uint32_t, Value*> &d); - bool iterateOnFunction(Function &F); - bool performPRE(Function &F); - bool performScalarPRE(Instruction *I); - bool performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, - unsigned int ValNo); - Value *findLeader(const BasicBlock *BB, uint32_t num); - void cleanupGlobalSets(); - void verifyRemoved(const Instruction *I) const; - bool splitCriticalEdges(); - BasicBlock *splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ); - bool replaceOperandsWithConsts(Instruction *I) const; - bool propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, - bool DominatesByEdge); - bool processFoldableCondBr(BranchInst *BI); - void addDeadBlock(BasicBlock *BB); - void assignValNumForDeadCode(); - }; - - char GVN::ID = 0; -} - -// The public interface to this file... -FunctionPass *llvm::createGVNPass(bool NoLoads) { - return new GVN(NoLoads); +PreservedAnalyses GVN::run(Function &F, AnalysisManager<Function> &AM) { + // FIXME: The order of evaluation of these 'getResult' calls is very + // significant! Re-ordering these variables will cause GVN when run alone to + // be less effective! We should fix memdep and basic-aa to not exhibit this + // behavior, but until then don't change the order here. 
+ auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto &MemDep = AM.getResult<MemoryDependenceAnalysis>(F); + bool Changed = runImpl(F, AC, DT, TLI, AA, &MemDep); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<GlobalsAA>(); + return PA; } -INITIALIZE_PASS_BEGIN(GVN, "gvn", "Global Value Numbering", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_END(GVN, "gvn", "Global Value Numbering", false, false) - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void GVN::dump(DenseMap<uint32_t, Value*>& d) { errs() << "{\n"; for (DenseMap<uint32_t, Value*>::iterator I = d.begin(), @@ -764,7 +612,6 @@ void GVN::dump(DenseMap<uint32_t, Value*>& d) { } errs() << "}\n"; } -#endif /// Return true if we can prove that the value /// we're analyzing is fully available in the specified block. As we go, keep @@ -875,38 +722,45 @@ static bool CanCoerceMustAliasedValueToLoad(Value *StoredVal, static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, IRBuilder<> &IRB, const DataLayout &DL) { - if (!CanCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL)) - return nullptr; + assert(CanCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) && + "precondition violation - materialization can't fail"); + + if (auto *CExpr = dyn_cast<ConstantExpr>(StoredVal)) + StoredVal = ConstantFoldConstantExpression(CExpr, DL); // If this is already the right type, just return it. Type *StoredValTy = StoredVal->getType(); - uint64_t StoreSize = DL.getTypeSizeInBits(StoredValTy); - uint64_t LoadSize = DL.getTypeSizeInBits(LoadedTy); + uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy); + uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy); // If the store and reload are the same size, we can always reuse it. - if (StoreSize == LoadSize) { + if (StoredValSize == LoadedValSize) { // Pointer to Pointer -> use bitcast. if (StoredValTy->getScalarType()->isPointerTy() && - LoadedTy->getScalarType()->isPointerTy()) - return IRB.CreateBitCast(StoredVal, LoadedTy); + LoadedTy->getScalarType()->isPointerTy()) { + StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy); + } else { + // Convert source pointers to integers, which can be bitcast. + if (StoredValTy->getScalarType()->isPointerTy()) { + StoredValTy = DL.getIntPtrType(StoredValTy); + StoredVal = IRB.CreatePtrToInt(StoredVal, StoredValTy); + } - // Convert source pointers to integers, which can be bitcast. 
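// A sketch of the same-size coercion case above; the memcpy keeps the
// int-to-float reinterpretation within defined behaviour and is the
// assumption of this example:
#include <string.h>

float bits_to_float(int i) {
  float f;
  memcpy(&f, &i, sizeof f);   /* becomes a store of the i32 bits to f's slot  */
  return f;                   /* the float load is satisfied by that store,
                                 materialized as an i32-to-float bitcast      */
}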
- if (StoredValTy->getScalarType()->isPointerTy()) { - StoredValTy = DL.getIntPtrType(StoredValTy); - StoredVal = IRB.CreatePtrToInt(StoredVal, StoredValTy); - } + Type *TypeToCastTo = LoadedTy; + if (TypeToCastTo->getScalarType()->isPointerTy()) + TypeToCastTo = DL.getIntPtrType(TypeToCastTo); - Type *TypeToCastTo = LoadedTy; - if (TypeToCastTo->getScalarType()->isPointerTy()) - TypeToCastTo = DL.getIntPtrType(TypeToCastTo); + if (StoredValTy != TypeToCastTo) + StoredVal = IRB.CreateBitCast(StoredVal, TypeToCastTo); - if (StoredValTy != TypeToCastTo) - StoredVal = IRB.CreateBitCast(StoredVal, TypeToCastTo); + // Cast to pointer if the load needs a pointer type. + if (LoadedTy->getScalarType()->isPointerTy()) + StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy); + } - // Cast to pointer if the load needs a pointer type. - if (LoadedTy->getScalarType()->isPointerTy()) - StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy); + if (auto *CExpr = dyn_cast<ConstantExpr>(StoredVal)) + StoredVal = ConstantFoldConstantExpression(CExpr, DL); return StoredVal; } @@ -914,7 +768,8 @@ static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, // If the loaded value is smaller than the available value, then we can // extract out a piece from it. If the available value is too small, then we // can't do anything. - assert(StoreSize >= LoadSize && "CanCoerceMustAliasedValueToLoad fail"); + assert(StoredValSize >= LoadedValSize && + "CanCoerceMustAliasedValueToLoad fail"); // Convert source pointers to integers, which can be manipulated. if (StoredValTy->getScalarType()->isPointerTy()) { @@ -924,29 +779,35 @@ static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, // Convert vectors and fp to integer, which can be manipulated. if (!StoredValTy->isIntegerTy()) { - StoredValTy = IntegerType::get(StoredValTy->getContext(), StoreSize); + StoredValTy = IntegerType::get(StoredValTy->getContext(), StoredValSize); StoredVal = IRB.CreateBitCast(StoredVal, StoredValTy); } // If this is a big-endian system, we need to shift the value down to the low // bits so that a truncate will work. if (DL.isBigEndian()) { - StoredVal = IRB.CreateLShr(StoredVal, StoreSize - LoadSize, "tmp"); + uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy) - + DL.getTypeStoreSizeInBits(LoadedTy); + StoredVal = IRB.CreateLShr(StoredVal, ShiftAmt, "tmp"); } // Truncate the integer to the right size now. - Type *NewIntTy = IntegerType::get(StoredValTy->getContext(), LoadSize); + Type *NewIntTy = IntegerType::get(StoredValTy->getContext(), LoadedValSize); StoredVal = IRB.CreateTrunc(StoredVal, NewIntTy, "trunc"); - if (LoadedTy == NewIntTy) - return StoredVal; + if (LoadedTy != NewIntTy) { + // If the result is a pointer, inttoptr. + if (LoadedTy->getScalarType()->isPointerTy()) + StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy, "inttoptr"); + else + // Otherwise, bitcast. + StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy, "bitcast"); + } - // If the result is a pointer, inttoptr. - if (LoadedTy->getScalarType()->isPointerTy()) - return IRB.CreateIntToPtr(StoredVal, LoadedTy, "inttoptr"); + if (auto *CExpr = dyn_cast<ConstantExpr>(StoredVal)) + StoredVal = ConstantFoldConstantExpression(CExpr, DL); - // Otherwise, bitcast. 
- return IRB.CreateBitCast(StoredVal, LoadedTy, "bitcast"); + return StoredVal; } /// This function is called when we have a @@ -1067,10 +928,15 @@ static int AnalyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL); unsigned LoadSize = DL.getTypeStoreSize(LoadTy); - unsigned Size = MemoryDependenceAnalysis::getLoadLoadClobberFullWidthSize( + unsigned Size = MemoryDependenceResults::getLoadLoadClobberFullWidthSize( LoadBase, LoadOffs, LoadSize, DepLI); if (Size == 0) return -1; + // Check non-obvious conditions enforced by MDA which we rely on for being + // able to materialize this potentially available value + assert(DepLI->isSimple() && "Cannot widen volatile/atomic load!"); + assert(DepLI->getType()->isIntegerTy() && "Can't widen non-integer load"); + return AnalyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, Size*8, DL); } @@ -1117,7 +983,7 @@ static int AnalyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src, OffsetCst); Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); - if (ConstantFoldLoadFromConstPtr(Src, DL)) + if (ConstantFoldLoadFromConstPtr(Src, LoadTy, DL)) return Offset; return -1; } @@ -1173,9 +1039,9 @@ static Value *GetLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, const DataLayout &DL = SrcVal->getModule()->getDataLayout(); // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to // widen SrcVal out to a larger load. - unsigned SrcValSize = DL.getTypeStoreSize(SrcVal->getType()); + unsigned SrcValStoreSize = DL.getTypeStoreSize(SrcVal->getType()); unsigned LoadSize = DL.getTypeStoreSize(LoadTy); - if (Offset+LoadSize > SrcValSize) { + if (Offset+LoadSize > SrcValStoreSize) { assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!"); assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load"); // If we have a load/load clobber an DepLI can be widened to cover this @@ -1207,8 +1073,7 @@ static Value *GetLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, // system, we need to shift down to get the relevant bits. 
Value *RV = NewLoad; if (DL.isBigEndian()) - RV = Builder.CreateLShr(RV, - NewLoadSize*8-SrcVal->getType()->getPrimitiveSizeInBits()); + RV = Builder.CreateLShr(RV, (NewLoadSize - SrcValStoreSize) * 8); RV = Builder.CreateTrunc(RV, SrcVal->getType()); SrcVal->replaceAllUsesWith(RV); @@ -1279,7 +1144,7 @@ static Value *GetMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, Src = ConstantExpr::getGetElementPtr(Type::getInt8Ty(Src->getContext()), Src, OffsetCst); Src = ConstantExpr::getBitCast(Src, PointerType::get(LoadTy, AS)); - return ConstantFoldLoadFromConstPtr(Src, DL); + return ConstantFoldLoadFromConstPtr(Src, LoadTy, DL); } @@ -1294,7 +1159,8 @@ static Value *ConstructSSAForLoadSet(LoadInst *LI, if (ValuesPerBlock.size() == 1 && gvn.getDominatorTree().properlyDominates(ValuesPerBlock[0].BB, LI->getParent())) { - assert(!ValuesPerBlock[0].isUndefValue() && "Dead BB dominate this block"); + assert(!ValuesPerBlock[0].AV.isUndefValue() && + "Dead BB dominate this block"); return ValuesPerBlock[0].MaterializeAdjustedValue(LI, gvn); } @@ -1316,15 +1182,16 @@ static Value *ConstructSSAForLoadSet(LoadInst *LI, return SSAUpdate.GetValueInMiddleOfBlock(LI->getParent()); } -Value *AvailableValueInBlock::MaterializeAdjustedValue(LoadInst *LI, - GVN &gvn) const { +Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI, + Instruction *InsertPt, + GVN &gvn) const { Value *Res; Type *LoadTy = LI->getType(); const DataLayout &DL = LI->getModule()->getDataLayout(); if (isSimpleValue()) { Res = getSimpleValue(); if (Res->getType() != LoadTy) { - Res = GetStoreValueForLoad(Res, Offset, LoadTy, BB->getTerminator(), DL); + Res = GetStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL); DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset << " " << *getSimpleValue() << '\n' @@ -1335,16 +1202,15 @@ Value *AvailableValueInBlock::MaterializeAdjustedValue(LoadInst *LI, if (Load->getType() == LoadTy && Offset == 0) { Res = Load; } else { - Res = GetLoadValueForLoad(Load, Offset, LoadTy, BB->getTerminator(), - gvn); - + Res = GetLoadValueForLoad(Load, Offset, LoadTy, InsertPt, gvn); + DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset << " " << *getCoercedLoadValue() << '\n' << *Res << '\n' << "\n\n\n"); } } else if (isMemIntrinValue()) { Res = GetMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy, - BB->getTerminator(), DL); + InsertPt, DL); DEBUG(dbgs() << "GVN COERCED NONLOCAL MEM INTRIN:\nOffset: " << Offset << " " << *getMemIntrinValue() << '\n' << *Res << '\n' << "\n\n\n"); @@ -1353,6 +1219,7 @@ Value *AvailableValueInBlock::MaterializeAdjustedValue(LoadInst *LI, DEBUG(dbgs() << "GVN COERCED NONLOCAL Undef:\n";); return UndefValue::get(LoadTy); } + assert(Res && "failed to materialize?"); return Res; } @@ -1362,7 +1229,134 @@ static bool isLifetimeStart(const Instruction *Inst) { return false; } -void GVN::AnalyzeLoadAvailability(LoadInst *LI, LoadDepVect &Deps, +bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, + Value *Address, AvailableValue &Res) { + + assert((DepInfo.isDef() || DepInfo.isClobber()) && + "expected a local dependence"); + assert(LI->isUnordered() && "rules below are incorrect for ordered access"); + + const DataLayout &DL = LI->getModule()->getDataLayout(); + + if (DepInfo.isClobber()) { + // If the dependence is to a store that writes to a superset of the bits + // read by the load, we can extract the bits we need for the load from the + // stored value. 
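The endianness adjustment above (shifting by the difference between the store size and the load size before truncating) can be checked with a small standalone program, independent of LLVM. It serializes a 32-bit store into bytes in either byte order and reloads the first two bytes as a 16-bit value: on little-endian layouts a plain truncate of the stored value matches, while on big-endian layouts the stored value must first be shifted right by (32 - 16) bits. This only illustrates the arithmetic; it is not code from the patch:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Store a 32-bit value to memory in the given byte order, then reload the
// first two bytes as a 16-bit value, mirroring a load narrower than the store.
static uint16_t reloadLow16(uint32_t Stored, bool BigEndian) {
  uint8_t Mem[4];
  for (int i = 0; i < 4; ++i)
    Mem[i] = uint8_t(Stored >> (BigEndian ? 8 * (3 - i) : 8 * i));
  return BigEndian ? uint16_t((Mem[0] << 8) | Mem[1])
                   : uint16_t(Mem[0] | (Mem[1] << 8));
}

int main() {
  uint32_t V = 0x11223344;
  // Little-endian: the first bytes hold the low bits, a truncate suffices.
  assert(reloadLow16(V, /*BigEndian=*/false) == uint16_t(V));
  // Big-endian: shift right by (StoreSize - LoadSize) bits, then truncate.
  assert(reloadLow16(V, /*BigEndian=*/true) == uint16_t(V >> (32 - 16)));
  std::puts("forwarding arithmetic ok");
  return 0;
}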
+ if (StoreInst *DepSI = dyn_cast<StoreInst>(DepInfo.getInst())) { + // Can't forward from non-atomic to atomic without violating memory model. + if (Address && LI->isAtomic() <= DepSI->isAtomic()) { + int Offset = + AnalyzeLoadFromClobberingStore(LI->getType(), Address, DepSI); + if (Offset != -1) { + Res = AvailableValue::get(DepSI->getValueOperand(), Offset); + return true; + } + } + } + + // Check to see if we have something like this: + // load i32* P + // load i8* (P+1) + // if we have this, replace the later with an extraction from the former. + if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInfo.getInst())) { + // If this is a clobber and L is the first instruction in its block, then + // we have the first instruction in the entry block. + // Can't forward from non-atomic to atomic without violating memory model. + if (DepLI != LI && Address && LI->isAtomic() <= DepLI->isAtomic()) { + int Offset = + AnalyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL); + + if (Offset != -1) { + Res = AvailableValue::getLoad(DepLI, Offset); + return true; + } + } + } + + // If the clobbering value is a memset/memcpy/memmove, see if we can + // forward a value on from it. + if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInfo.getInst())) { + if (Address && !LI->isAtomic()) { + int Offset = AnalyzeLoadFromClobberingMemInst(LI->getType(), Address, + DepMI, DL); + if (Offset != -1) { + Res = AvailableValue::getMI(DepMI, Offset); + return true; + } + } + } + // Nothing known about this clobber, have to be conservative + DEBUG( + // fast print dep, using operator<< on instruction is too slow. + dbgs() << "GVN: load "; + LI->printAsOperand(dbgs()); + Instruction *I = DepInfo.getInst(); + dbgs() << " is clobbered by " << *I << '\n'; + ); + return false; + } + assert(DepInfo.isDef() && "follows from above"); + + Instruction *DepInst = DepInfo.getInst(); + + // Loading the allocation -> undef. + if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) || + // Loading immediately after lifetime begin -> undef. + isLifetimeStart(DepInst)) { + Res = AvailableValue::get(UndefValue::get(LI->getType())); + return true; + } + + // Loading from calloc (which zero initializes memory) -> zero + if (isCallocLikeFn(DepInst, TLI)) { + Res = AvailableValue::get(Constant::getNullValue(LI->getType())); + return true; + } + + if (StoreInst *S = dyn_cast<StoreInst>(DepInst)) { + // Reject loads and stores that are to the same address but are of + // different types if we have to. If the stored value is larger or equal to + // the loaded value, we can reuse it. + if (S->getValueOperand()->getType() != LI->getType() && + !CanCoerceMustAliasedValueToLoad(S->getValueOperand(), + LI->getType(), DL)) + return false; + + // Can't forward from non-atomic to atomic without violating memory model. + if (S->isAtomic() < LI->isAtomic()) + return false; + + Res = AvailableValue::get(S->getValueOperand()); + return true; + } + + if (LoadInst *LD = dyn_cast<LoadInst>(DepInst)) { + // If the types mismatch and we can't handle it, reject reuse of the load. + // If the stored value is larger or equal to the loaded value, we can reuse + // it. + if (LD->getType() != LI->getType() && + !CanCoerceMustAliasedValueToLoad(LD, LI->getType(), DL)) + return false; + + // Can't forward from non-atomic to atomic without violating memory model. 
+ if (LD->isAtomic() < LI->isAtomic()) + return false; + + Res = AvailableValue::getLoad(LD); + return true; + } + + // Unknown def - must be conservative + DEBUG( + // fast print dep, using operator<< on instruction is too slow. + dbgs() << "GVN: load "; + LI->printAsOperand(dbgs()); + dbgs() << " has unknown def " << *DepInst << '\n'; + ); + return false; +} + +void GVN::AnalyzeLoadAvailability(LoadInst *LI, LoadDepVect &Deps, AvailValInBlkVect &ValuesPerBlock, UnavailBlkVect &UnavailableBlocks) { @@ -1371,7 +1365,6 @@ void GVN::AnalyzeLoadAvailability(LoadInst *LI, LoadDepVect &Deps, // dependencies that produce an unknown value for the load (such as a call // that could potentially clobber the load). unsigned NumDeps = Deps.size(); - const DataLayout &DL = LI->getModule()->getDataLayout(); for (unsigned i = 0, e = NumDeps; i != e; ++i) { BasicBlock *DepBB = Deps[i].getBB(); MemDepResult DepInfo = Deps[i].getResult(); @@ -1388,122 +1381,28 @@ void GVN::AnalyzeLoadAvailability(LoadInst *LI, LoadDepVect &Deps, continue; } - if (DepInfo.isClobber()) { - // The address being loaded in this non-local block may not be the same as - // the pointer operand of the load if PHI translation occurs. Make sure - // to consider the right address. - Value *Address = Deps[i].getAddress(); - - // If the dependence is to a store that writes to a superset of the bits - // read by the load, we can extract the bits we need for the load from the - // stored value. - if (StoreInst *DepSI = dyn_cast<StoreInst>(DepInfo.getInst())) { - if (Address) { - int Offset = - AnalyzeLoadFromClobberingStore(LI->getType(), Address, DepSI); - if (Offset != -1) { - ValuesPerBlock.push_back(AvailableValueInBlock::get(DepBB, - DepSI->getValueOperand(), - Offset)); - continue; - } - } - } - - // Check to see if we have something like this: - // load i32* P - // load i8* (P+1) - // if we have this, replace the later with an extraction from the former. - if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInfo.getInst())) { - // If this is a clobber and L is the first instruction in its block, then - // we have the first instruction in the entry block. - if (DepLI != LI && Address) { - int Offset = - AnalyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL); - - if (Offset != -1) { - ValuesPerBlock.push_back(AvailableValueInBlock::getLoad(DepBB,DepLI, - Offset)); - continue; - } - } - } - - // If the clobbering value is a memset/memcpy/memmove, see if we can - // forward a value on from it. - if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInfo.getInst())) { - if (Address) { - int Offset = AnalyzeLoadFromClobberingMemInst(LI->getType(), Address, - DepMI, DL); - if (Offset != -1) { - ValuesPerBlock.push_back(AvailableValueInBlock::getMI(DepBB, DepMI, - Offset)); - continue; - } - } - } - - UnavailableBlocks.push_back(DepBB); - continue; - } - - // DepInfo.isDef() here - - Instruction *DepInst = DepInfo.getInst(); - - // Loading the allocation -> undef. - if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) || - // Loading immediately after lifetime begin -> undef. 
- isLifetimeStart(DepInst)) { - ValuesPerBlock.push_back(AvailableValueInBlock::get(DepBB, - UndefValue::get(LI->getType()))); - continue; - } - - // Loading from calloc (which zero initializes memory) -> zero - if (isCallocLikeFn(DepInst, TLI)) { - ValuesPerBlock.push_back(AvailableValueInBlock::get( - DepBB, Constant::getNullValue(LI->getType()))); - continue; - } - - if (StoreInst *S = dyn_cast<StoreInst>(DepInst)) { - // Reject loads and stores that are to the same address but are of - // different types if we have to. - if (S->getValueOperand()->getType() != LI->getType()) { - // If the stored value is larger or equal to the loaded value, we can - // reuse it. - if (!CanCoerceMustAliasedValueToLoad(S->getValueOperand(), - LI->getType(), DL)) { - UnavailableBlocks.push_back(DepBB); - continue; - } - } + // The address being loaded in this non-local block may not be the same as + // the pointer operand of the load if PHI translation occurs. Make sure + // to consider the right address. + Value *Address = Deps[i].getAddress(); + AvailableValue AV; + if (AnalyzeLoadAvailability(LI, DepInfo, Address, AV)) { + // subtlety: because we know this was a non-local dependency, we know + // it's safe to materialize anywhere between the instruction within + // DepInfo and the end of it's block. ValuesPerBlock.push_back(AvailableValueInBlock::get(DepBB, - S->getValueOperand())); - continue; - } - - if (LoadInst *LD = dyn_cast<LoadInst>(DepInst)) { - // If the types mismatch and we can't handle it, reject reuse of the load. - if (LD->getType() != LI->getType()) { - // If the stored value is larger or equal to the loaded value, we can - // reuse it. - if (!CanCoerceMustAliasedValueToLoad(LD, LI->getType(), DL)) { - UnavailableBlocks.push_back(DepBB); - continue; - } - } - ValuesPerBlock.push_back(AvailableValueInBlock::getLoad(DepBB, LD)); - continue; + std::move(AV))); + } else { + UnavailableBlocks.push_back(DepBB); } - - UnavailableBlocks.push_back(DepBB); } + + assert(NumDeps == ValuesPerBlock.size() + UnavailableBlocks.size() && + "post condition violation"); } -bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, +bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, UnavailBlkVect &UnavailableBlocks) { // Okay, we have *some* definitions of the value. This means that the value // is available in some of our (transitive) predecessors. Lets think about @@ -1661,16 +1560,17 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // parent's availability map. However, in doing so, we risk getting into // ordering issues. If a block hasn't been processed yet, we would be // marking a value as AVAIL-IN, which isn't what we intend. - VN.lookup_or_add(I); + VN.lookupOrAdd(I); } for (const auto &PredLoad : PredLoads) { BasicBlock *UnavailablePred = PredLoad.first; Value *LoadPtr = PredLoad.second; - Instruction *NewLoad = new LoadInst(LoadPtr, LI->getName()+".pre", false, - LI->getAlignment(), - UnavailablePred->getTerminator()); + auto *NewLoad = new LoadInst(LoadPtr, LI->getName()+".pre", + LI->isVolatile(), LI->getAlignment(), + LI->getOrdering(), LI->getSynchScope(), + UnavailablePred->getTerminator()); // Transfer the old load's AA tags to the new load. 
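The PerformLoadPRE path that begins here is easiest to picture at source level: when a load is available in some predecessors but not all, GVN inserts a copy of the load into the predecessors where it is missing, after which the original load is fully redundant and can be removed. The sketch below is illustrative C++ (not IR, and not code from the patch); the function names are invented:

// Before: the load of *p is available on the 'then' path but not on the
// 'else' path, so the final load is only partially redundant.
int beforePRE(bool c, int *p) {
  int sum = 0;
  if (c)
    sum = *p;
  return sum + *p;
}

// After load PRE: the load is inserted in the predecessor where it was
// unavailable, and the original (now fully redundant) load goes away.
int afterPRE(bool c, int *p) {
  int sum = 0, v;
  if (c) {
    v = *p;
    sum = v;
  } else {
    v = *p; // new load in the previously unavailable predecessor
  }
  return sum + v;
}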
AAMDNodes Tags; @@ -1682,6 +1582,8 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, NewLoad->setMetadata(LLVMContext::MD_invariant_load, MD); if (auto *InvGroupMD = LI->getMetadata(LLVMContext::MD_invariant_group)) NewLoad->setMetadata(LLVMContext::MD_invariant_group, InvGroupMD); + if (auto *RangeMD = LI->getMetadata(LLVMContext::MD_range)) + NewLoad->setMetadata(LLVMContext::MD_range, RangeMD); // Transfer DebugLoc. NewLoad->setDebugLoc(LI->getDebugLoc()); @@ -1846,30 +1748,29 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { } static void patchReplacementInstruction(Instruction *I, Value *Repl) { + auto *ReplInst = dyn_cast<Instruction>(Repl); + if (!ReplInst) + return; + // Patch the replacement so that it is not more restrictive than the value // being replaced. - BinaryOperator *Op = dyn_cast<BinaryOperator>(I); - BinaryOperator *ReplOp = dyn_cast<BinaryOperator>(Repl); - if (Op && ReplOp) - ReplOp->andIRFlags(Op); - - if (Instruction *ReplInst = dyn_cast<Instruction>(Repl)) { - // FIXME: If both the original and replacement value are part of the - // same control-flow region (meaning that the execution of one - // guarantees the execution of the other), then we can combine the - // noalias scopes here and do better than the general conservative - // answer used in combineMetadata(). - - // In general, GVN unifies expressions over different control-flow - // regions, and so we need a conservative combination of the noalias - // scopes. - static const unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_range, - LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, - LLVMContext::MD_invariant_group}; - combineMetadata(ReplInst, I, KnownIDs); - } + ReplInst->andIRFlags(I); + + // FIXME: If both the original and replacement value are part of the + // same control-flow region (meaning that the execution of one + // guarantees the execution of the other), then we can combine the + // noalias scopes here and do better than the general conservative + // answer used in combineMetadata(). + + // In general, GVN unifies expressions over different control-flow + // regions, and so we need a conservative combination of the noalias + // scopes. + static const unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, + LLVMContext::MD_invariant_group}; + combineMetadata(ReplInst, I, KnownIDs); } static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { @@ -1883,7 +1784,8 @@ bool GVN::processLoad(LoadInst *L) { if (!MD) return false; - if (!L->isSimple()) + // This code hasn't been audited for ordered or volatile memory access + if (!L->isUnordered()) return false; if (L->use_empty()) { @@ -1893,84 +1795,14 @@ bool GVN::processLoad(LoadInst *L) { // ... to a pointer that has been loaded from before... MemDepResult Dep = MD->getDependency(L); - const DataLayout &DL = L->getModule()->getDataLayout(); - - // If we have a clobber and target data is around, see if this is a clobber - // that we can fix up through code synthesis. 
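The conservative merging done through combineMetadata can be approximated as an intersection: when an instruction is replaced by an equivalent one from a different control-flow region, the survivor may only keep facts both instructions agree on. The sketch below is a deliberate simplification in plain C++ (the real helper knows how to merge certain kinds, for example widening range metadata, rather than dropping them):

#include <cassert>
#include <map>
#include <string>

using MDMap = std::map<std::string, std::string>; // kind -> payload (simplified)

// Keep only metadata present on both instructions with identical payloads.
static MDMap intersectMetadata(const MDMap &Repl, const MDMap &Orig) {
  MDMap Out;
  for (const auto &KV : Repl) {
    auto It = Orig.find(KV.first);
    if (It != Orig.end() && It->second == KV.second)
      Out.insert(KV);
  }
  return Out;
}

int main() {
  MDMap Repl = {{"tbaa", "int"}, {"range", "[0,8)"}};
  MDMap Orig = {{"tbaa", "int"}, {"range", "[0,16)"}, {"nonnull", ""}};
  MDMap Merged = intersectMetadata(Repl, Orig);
  assert(Merged.count("tbaa") == 1);    // both agree, keep it
  assert(Merged.count("range") == 0);   // payloads differ, drop (real code widens)
  assert(Merged.count("nonnull") == 0); // known only on one side, drop
  return 0;
}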
- if (Dep.isClobber()) { - // Check to see if we have something like this: - // store i32 123, i32* %P - // %A = bitcast i32* %P to i8* - // %B = gep i8* %A, i32 1 - // %C = load i8* %B - // - // We could do that by recognizing if the clobber instructions are obviously - // a common base + constant offset, and if the previous store (or memset) - // completely covers this load. This sort of thing can happen in bitfield - // access code. - Value *AvailVal = nullptr; - if (StoreInst *DepSI = dyn_cast<StoreInst>(Dep.getInst())) { - int Offset = AnalyzeLoadFromClobberingStore( - L->getType(), L->getPointerOperand(), DepSI); - if (Offset != -1) - AvailVal = GetStoreValueForLoad(DepSI->getValueOperand(), Offset, - L->getType(), L, DL); - } - - // Check to see if we have something like this: - // load i32* P - // load i8* (P+1) - // if we have this, replace the later with an extraction from the former. - if (LoadInst *DepLI = dyn_cast<LoadInst>(Dep.getInst())) { - // If this is a clobber and L is the first instruction in its block, then - // we have the first instruction in the entry block. - if (DepLI == L) - return false; - - int Offset = AnalyzeLoadFromClobberingLoad( - L->getType(), L->getPointerOperand(), DepLI, DL); - if (Offset != -1) - AvailVal = GetLoadValueForLoad(DepLI, Offset, L->getType(), L, *this); - } - - // If the clobbering value is a memset/memcpy/memmove, see if we can forward - // a value on from it. - if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(Dep.getInst())) { - int Offset = AnalyzeLoadFromClobberingMemInst( - L->getType(), L->getPointerOperand(), DepMI, DL); - if (Offset != -1) - AvailVal = GetMemInstValueForLoad(DepMI, Offset, L->getType(), L, DL); - } - - if (AvailVal) { - DEBUG(dbgs() << "GVN COERCED INST:\n" << *Dep.getInst() << '\n' - << *AvailVal << '\n' << *L << "\n\n\n"); - - // Replace the load! - L->replaceAllUsesWith(AvailVal); - if (AvailVal->getType()->getScalarType()->isPointerTy()) - MD->invalidateCachedPointerInfo(AvailVal); - markInstructionForDeletion(L); - ++NumGVNLoad; - return true; - } - - // If the value isn't available, don't do anything! - DEBUG( - // fast print dep, using operator<< on instruction is too slow. - dbgs() << "GVN: load "; - L->printAsOperand(dbgs()); - Instruction *I = Dep.getInst(); - dbgs() << " is clobbered by " << *I << '\n'; - ); - return false; - } // If it is defined in another block, try harder. if (Dep.isNonLocal()) return processNonLocalLoad(L); - if (!Dep.isDef()) { + // Only handle the local case below + if (!Dep.isDef() && !Dep.isClobber()) { + // This might be a NonFuncLocal or an Unknown DEBUG( // fast print dep, using operator<< on instruction is too slow. dbgs() << "GVN: load "; @@ -1980,86 +1812,18 @@ bool GVN::processLoad(LoadInst *L) { return false; } - Instruction *DepInst = Dep.getInst(); - if (StoreInst *DepSI = dyn_cast<StoreInst>(DepInst)) { - Value *StoredVal = DepSI->getValueOperand(); - - // The store and load are to a must-aliased pointer, but they may not - // actually have the same type. See if we know how to reuse the stored - // value (depending on its type). - if (StoredVal->getType() != L->getType()) { - IRBuilder<> Builder(L); - StoredVal = - CoerceAvailableValueToLoadType(StoredVal, L->getType(), Builder, DL); - if (!StoredVal) - return false; - - DEBUG(dbgs() << "GVN COERCED STORE:\n" << *DepSI << '\n' << *StoredVal - << '\n' << *L << "\n\n\n"); - } - - // Remove it! 
- L->replaceAllUsesWith(StoredVal); - if (StoredVal->getType()->getScalarType()->isPointerTy()) - MD->invalidateCachedPointerInfo(StoredVal); - markInstructionForDeletion(L); - ++NumGVNLoad; - return true; - } - - if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInst)) { - Value *AvailableVal = DepLI; - - // The loads are of a must-aliased pointer, but they may not actually have - // the same type. See if we know how to reuse the previously loaded value - // (depending on its type). - if (DepLI->getType() != L->getType()) { - IRBuilder<> Builder(L); - AvailableVal = - CoerceAvailableValueToLoadType(DepLI, L->getType(), Builder, DL); - if (!AvailableVal) - return false; - - DEBUG(dbgs() << "GVN COERCED LOAD:\n" << *DepLI << "\n" << *AvailableVal - << "\n" << *L << "\n\n\n"); - } - - // Remove it! - patchAndReplaceAllUsesWith(L, AvailableVal); - if (DepLI->getType()->getScalarType()->isPointerTy()) - MD->invalidateCachedPointerInfo(DepLI); - markInstructionForDeletion(L); - ++NumGVNLoad; - return true; - } - - // If this load really doesn't depend on anything, then we must be loading an - // undef value. This can happen when loading for a fresh allocation with no - // intervening stores, for example. - if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI)) { - L->replaceAllUsesWith(UndefValue::get(L->getType())); - markInstructionForDeletion(L); - ++NumGVNLoad; - return true; - } + AvailableValue AV; + if (AnalyzeLoadAvailability(L, Dep, L->getPointerOperand(), AV)) { + Value *AvailableValue = AV.MaterializeAdjustedValue(L, L, *this); - // If this load occurs either right after a lifetime begin, - // then the loaded value is undefined. - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(DepInst)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start) { - L->replaceAllUsesWith(UndefValue::get(L->getType())); - markInstructionForDeletion(L); - ++NumGVNLoad; - return true; - } - } - - // If this load follows a calloc (which zero initializes memory), - // then the loaded value is zero - if (isCallocLikeFn(DepInst, TLI)) { - L->replaceAllUsesWith(Constant::getNullValue(L->getType())); + // Replace the load! + patchAndReplaceAllUsesWith(L, AvailableValue); markInstructionForDeletion(L); ++NumGVNLoad; + // Tell MDA to rexamine the reused pointer since we might have more + // information after forwarding it. + if (MD && AvailableValue->getType()->getScalarType()->isPointerTy()) + MD->invalidateCachedPointerInfo(AvailableValue); return true; } @@ -2105,9 +1869,8 @@ static bool isOnlyReachableViaThisEdge(const BasicBlockEdge &E, // GVN runs all such loops have preheaders, which means that Dst will have // been changed to have only one predecessor, namely Src. const BasicBlock *Pred = E.getEnd()->getSinglePredecessor(); - const BasicBlock *Src = E.getStart(); - assert((!Pred || Pred == Src) && "No edge between these basic blocks!"); - (void)Src; + assert((!Pred || Pred == E.getStart()) && + "No edge between these basic blocks!"); return Pred != nullptr; } @@ -2133,7 +1896,8 @@ bool GVN::replaceOperandsWithConsts(Instruction *Instr) const { /// The given values are known to be equal in every block /// dominated by 'Root'. Exploit this, for example by replacing 'LHS' with /// 'RHS' everywhere in the scope. Returns whether a change was made. -/// If DominatesByEdge is false, then it means that it is dominated by Root.End. +/// If DominatesByEdge is false, then it means that we will propagate the RHS +/// value starting from the end of Root.Start. 
bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, bool DominatesByEdge) { SmallVector<std::pair<Value*, Value*>, 4> Worklist; @@ -2141,7 +1905,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, bool Changed = false; // For speed, compute a conservative fast approximation to // DT->dominates(Root, Root.getEnd()); - bool RootDominatesEnd = isOnlyReachableViaThisEdge(Root, DT); + const bool RootDominatesEnd = isOnlyReachableViaThisEdge(Root, DT); while (!Worklist.empty()) { std::pair<Value*, Value*> Item = Worklist.pop_back_val(); @@ -2164,12 +1928,12 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, // right-hand side, ensure the longest lived term is on the right-hand side, // so the shortest lived term will be replaced by the longest lived. // This tends to expose more simplifications. - uint32_t LVN = VN.lookup_or_add(LHS); + uint32_t LVN = VN.lookupOrAdd(LHS); if ((isa<Argument>(LHS) && isa<Argument>(RHS)) || (isa<Instruction>(LHS) && isa<Instruction>(RHS))) { // Move the 'oldest' value to the right-hand side, using the value number // as a proxy for age. - uint32_t RVN = VN.lookup_or_add(RHS); + uint32_t RVN = VN.lookupOrAdd(RHS); if (LVN < RVN) { std::swap(LHS, RHS); LVN = RVN; @@ -2195,7 +1959,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, unsigned NumReplacements = DominatesByEdge ? replaceDominatedUsesWith(LHS, RHS, *DT, Root) - : replaceDominatedUsesWith(LHS, RHS, *DT, Root.getEnd()); + : replaceDominatedUsesWith(LHS, RHS, *DT, Root.getStart()); Changed |= NumReplacements > 0; NumGVNEqProp += NumReplacements; @@ -2245,7 +2009,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, // Floating point -0.0 and 0.0 compare equal, so we can only // propagate values if we know that we have a constant and that // its value is non-zero. - + // FIXME: We should do this optimization if 'no signed zeros' is // applicable via an instruction-level fast-math-flag or some other // indicator that relaxed FP semantics are being used. @@ -2253,7 +2017,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, if (isa<ConstantFP>(Op1) && !cast<ConstantFP>(Op1)->isZero()) Worklist.push_back(std::make_pair(Op0, Op1)); } - + // If "A >= B" is known true, replace "A < B" with false everywhere. CmpInst::Predicate NotPred = Cmp->getInversePredicate(); Constant *NotVal = ConstantInt::get(Cmp->getType(), isKnownFalse); @@ -2261,7 +2025,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, // out the value number that it would have and use that to find an // appropriate instruction (if any). uint32_t NextNum = VN.getNextUnusedValueNumber(); - uint32_t Num = VN.lookup_or_add_cmp(Cmp->getOpcode(), NotPred, Op0, Op1); + uint32_t Num = VN.lookupOrAddCmp(Cmp->getOpcode(), NotPred, Op0, Op1); // If the number we were assigned was brand new then there is no point in // looking for an instruction realizing it: there cannot be one! if (Num < NextNum) { @@ -2271,7 +2035,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, DominatesByEdge ? 
replaceDominatedUsesWith(NotCmp, NotVal, *DT, Root) : replaceDominatedUsesWith(NotCmp, NotVal, *DT, - Root.getEnd()); + Root.getStart()); Changed |= NumReplacements > 0; NumGVNEqProp += NumReplacements; } @@ -2303,12 +2067,21 @@ bool GVN::processInstruction(Instruction *I) { // "%z = and i32 %x, %y" becomes "%z = and i32 %x, %x" which we now simplify. const DataLayout &DL = I->getModule()->getDataLayout(); if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { - I->replaceAllUsesWith(V); - if (MD && V->getType()->getScalarType()->isPointerTy()) - MD->invalidateCachedPointerInfo(V); - markInstructionForDeletion(I); - ++NumGVNSimpl; - return true; + bool Changed = false; + if (!I->use_empty()) { + I->replaceAllUsesWith(V); + Changed = true; + } + if (isInstructionTriviallyDead(I, TLI)) { + markInstructionForDeletion(I); + Changed = true; + } + if (Changed) { + if (MD && V->getType()->getScalarType()->isPointerTy()) + MD->invalidateCachedPointerInfo(V); + ++NumGVNSimpl; + return true; + } } if (IntrinsicInst *IntrinsicI = dyn_cast<IntrinsicInst>(I)) @@ -2319,7 +2092,7 @@ bool GVN::processInstruction(Instruction *I) { if (processLoad(LI)) return true; - unsigned Num = VN.lookup_or_add(LI); + unsigned Num = VN.lookupOrAdd(LI); addToLeaderTable(Num, LI, LI->getParent()); return false; } @@ -2383,7 +2156,7 @@ bool GVN::processInstruction(Instruction *I) { return false; uint32_t NextNum = VN.getNextUnusedValueNumber(); - unsigned Num = VN.lookup_or_add(I); + unsigned Num = VN.lookupOrAdd(I); // Allocations are always uniquely numbered, so we can save time and memory // by fast failing them. @@ -2422,18 +2195,16 @@ bool GVN::processInstruction(Instruction *I) { } /// runOnFunction - This is the main transformation entry point for a function. -bool GVN::runOnFunction(Function& F) { - if (skipOptnoneFunction(F)) - return false; - - if (!NoLoads) - MD = &getAnalysis<MemoryDependenceAnalysis>(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - VN.setAliasAnalysis(&getAnalysis<AAResultsWrapperPass>().getAAResults()); - VN.setMemDep(MD); +bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, + const TargetLibraryInfo &RunTLI, AAResults &RunAA, + MemoryDependenceResults *RunMD) { + AC = &RunAC; + DT = &RunDT; VN.setDomTree(DT); + TLI = &RunTLI; + VN.setAliasAnalysis(&RunAA); + MD = RunMD; + VN.setMemDep(MD); bool Changed = false; bool ShouldContinue = true; @@ -2476,7 +2247,7 @@ bool GVN::runOnFunction(Function& F) { cleanupGlobalSets(); // Do not cleanup DeadBlocks in cleanupGlobalSets() as it's called for each - // iteration. + // iteration. 
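The effect of the equality propagation above is easy to see at source level: once an edge establishes LHS == RHS, uses of the shorter-lived value in the dominated region are rewritten to the longer-lived one (typically a constant), and the inverted comparison is folded to false. Illustrative C++ only, not code from the patch:

// Before: inside the guarded region the edge condition guarantees x == 42.
int beforeProp(int x) {
  if (x == 42) {
    int a = x + 1;      // x can be replaced by the constant here
    int b = (x != 42);  // the inverse comparison is known false here
    return a + b;
  }
  return 0;
}

// After equality propagation in the dominated region.
int afterProp(int x) {
  if (x == 42) {
    int a = 42 + 1;
    int b = 0;
    return a + b;
  }
  return 0;
}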
DeadBlocks.clear(); return Changed; @@ -2576,8 +2347,6 @@ bool GVN::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, } bool GVN::performScalarPRE(Instruction *CurInst) { - SmallVector<std::pair<Value*, BasicBlock*>, 8> predMap; - if (isa<AllocaInst>(CurInst) || isa<TerminatorInst>(CurInst) || isa<PHINode>(CurInst) || CurInst->getType()->isVoidTy() || CurInst->mayReadFromMemory() || CurInst->mayHaveSideEffects() || @@ -2608,8 +2377,8 @@ bool GVN::performScalarPRE(Instruction *CurInst) { unsigned NumWithout = 0; BasicBlock *PREPred = nullptr; BasicBlock *CurrentBlock = CurInst->getParent(); - predMap.clear(); + SmallVector<std::pair<Value *, BasicBlock *>, 8> predMap; for (BasicBlock *P : predecessors(CurrentBlock)) { // We're not interested in PRE where the block is its // own predecessor, or in blocks with predecessors @@ -2702,7 +2471,7 @@ bool GVN::performScalarPRE(Instruction *CurInst) { DEBUG(verifyRemoved(CurInst)); CurInst->eraseFromParent(); ++NumGVNInstr; - + return true; } @@ -2825,7 +2594,7 @@ void GVN::addDeadBlock(BasicBlock *BB) { SmallVector<BasicBlock *, 8> Dom; DT->getDescendants(D, Dom); DeadBlocks.insert(Dom.begin(), Dom.end()); - + // Figure out the dominance-frontier(D). for (BasicBlock *B : Dom) { for (BasicBlock *S : successors(B)) { @@ -2883,13 +2652,13 @@ void GVN::addDeadBlock(BasicBlock *BB) { // If the given branch is recognized as a foldable branch (i.e. conditional // branch with constant condition), it will perform following analyses and // transformation. -// 1) If the dead out-coming edge is a critical-edge, split it. Let +// 1) If the dead out-coming edge is a critical-edge, split it. Let // R be the target of the dead out-coming edge. // 1) Identify the set of dead blocks implied by the branch's dead outcoming // edge. The result of this step will be {X| X is dominated by R} // 2) Identify those blocks which haves at least one dead predecessor. The // result of this step will be dominance-frontier(R). -// 3) Update the PHIs in DF(R) by replacing the operands corresponding to +// 3) Update the PHIs in DF(R) by replacing the operands corresponding to // dead blocks with "UndefVal" in an hope these PHIs will optimized away. // // Return true iff *NEW* dead code are found. @@ -2905,8 +2674,8 @@ bool GVN::processFoldableCondBr(BranchInst *BI) { if (!Cond) return false; - BasicBlock *DeadRoot = Cond->getZExtValue() ? - BI->getSuccessor(1) : BI->getSuccessor(0); + BasicBlock *DeadRoot = + Cond->getZExtValue() ? BI->getSuccessor(1) : BI->getSuccessor(0); if (DeadBlocks.count(DeadRoot)) return false; @@ -2924,8 +2693,62 @@ bool GVN::processFoldableCondBr(BranchInst *BI) { void GVN::assignValNumForDeadCode() { for (BasicBlock *BB : DeadBlocks) { for (Instruction &Inst : *BB) { - unsigned ValNum = VN.lookup_or_add(&Inst); + unsigned ValNum = VN.lookupOrAdd(&Inst); addToLeaderTable(ValNum, &Inst, BB); } } } + +class llvm::gvn::GVNLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + explicit GVNLegacyPass(bool NoLoads = false) + : FunctionPass(ID), NoLoads(NoLoads) { + initializeGVNLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + return Impl.runImpl( + F, getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), + getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), + getAnalysis<AAResultsWrapperPass>().getAAResults(), + NoLoads ? 
nullptr + : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + if (!NoLoads) + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + +private: + bool NoLoads; + GVN Impl; +}; + +char GVNLegacyPass::ID = 0; + +// The public interface to this file... +FunctionPass *llvm::createGVNPass(bool NoLoads) { + return new GVNLegacyPass(NoLoads); +} + +INITIALIZE_PASS_BEGIN(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_END(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp new file mode 100644 index 000000000000..cce1db3874b7 --- /dev/null +++ b/lib/Transforms/Scalar/GVNHoist.cpp @@ -0,0 +1,825 @@ +//===- GVNHoist.cpp - Hoist scalar and load expressions -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass hoists expressions from branches to a common dominator. It uses +// GVN (global value numbering) to discover expressions computing the same +// values. The primary goal is to reduce the code size, and in some +// cases reduce critical path (by exposing more ILP). +// Hoisting may affect the performance in some cases. To mitigate that, hoisting +// is disabled in the following cases. +// 1. Scalars across calls. +// 2. geps when corresponding load/store cannot be hoisted. 
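A source-level picture of what the new pass aims for (illustrative C++, not IR, and not code from the patch): an expression computed identically on both sides of a branch is moved to the common dominator, so it is emitted once:

// Before hoisting: "a * b + c" is computed in both branches.
int beforeHoist(bool p, int a, int b, int c) {
  if (p)
    return a * b + c + 1;
  return a * b + c - 1;
}

// After hoisting: the common subexpression lives in the dominator.
int afterHoist(bool p, int a, int b, int c) {
  int t = a * b + c; // hoisted to the common dominator of both branches
  return p ? t + 1 : t - 1;
}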
+//===----------------------------------------------------------------------===// + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Utils/MemorySSA.h" + +using namespace llvm; + +#define DEBUG_TYPE "gvn-hoist" + +STATISTIC(NumHoisted, "Number of instructions hoisted"); +STATISTIC(NumRemoved, "Number of instructions removed"); +STATISTIC(NumLoadsHoisted, "Number of loads hoisted"); +STATISTIC(NumLoadsRemoved, "Number of loads removed"); +STATISTIC(NumStoresHoisted, "Number of stores hoisted"); +STATISTIC(NumStoresRemoved, "Number of stores removed"); +STATISTIC(NumCallsHoisted, "Number of calls hoisted"); +STATISTIC(NumCallsRemoved, "Number of calls removed"); + +static cl::opt<int> + MaxHoistedThreshold("gvn-max-hoisted", cl::Hidden, cl::init(-1), + cl::desc("Max number of instructions to hoist " + "(default unlimited = -1)")); +static cl::opt<int> MaxNumberOfBBSInPath( + "gvn-hoist-max-bbs", cl::Hidden, cl::init(4), + cl::desc("Max number of basic blocks on the path between " + "hoisting locations (default = 4, unlimited = -1)")); + +namespace { + +// Provides a sorting function based on the execution order of two instructions. +struct SortByDFSIn { +private: + DenseMap<const BasicBlock *, unsigned> &DFSNumber; + +public: + SortByDFSIn(DenseMap<const BasicBlock *, unsigned> &D) : DFSNumber(D) {} + + // Returns true when A executes before B. + bool operator()(const Instruction *A, const Instruction *B) const { + // FIXME: libc++ has a std::sort() algorithm that will call the compare + // function on the same element. Once PR20837 is fixed and some more years + // pass by and all the buildbots have moved to a corrected std::sort(), + // enable the following assert: + // + // assert(A != B); + + const BasicBlock *BA = A->getParent(); + const BasicBlock *BB = B->getParent(); + unsigned NA = DFSNumber[BA]; + unsigned NB = DFSNumber[BB]; + if (NA < NB) + return true; + if (NA == NB) { + // Sort them in the order they occur in the same basic block. + BasicBlock::const_iterator AI(A), BI(B); + return std::distance(AI, BI) < 0; + } + return false; + } +}; + +// A map from a pair of VNs to all the instructions with those VNs. +typedef DenseMap<std::pair<unsigned, unsigned>, SmallVector<Instruction *, 4>> + VNtoInsns; +// An invalid value number Used when inserting a single value number into +// VNtoInsns. +enum : unsigned { InvalidVN = ~2U }; + +// Records all scalar instructions candidate for code hoisting. +class InsnInfo { + VNtoInsns VNtoScalars; + +public: + // Inserts I and its value number in VNtoScalars. + void insert(Instruction *I, GVN::ValueTable &VN) { + // Scalar instruction. + unsigned V = VN.lookupOrAdd(I); + VNtoScalars[{V, InvalidVN}].push_back(I); + } + + const VNtoInsns &getVNTable() const { return VNtoScalars; } +}; + +// Records all load instructions candidate for code hoisting. +class LoadInfo { + VNtoInsns VNtoLoads; + +public: + // Insert Load and the value number of its memory address in VNtoLoads. + void insert(LoadInst *Load, GVN::ValueTable &VN) { + if (Load->isSimple()) { + unsigned V = VN.lookupOrAdd(Load->getPointerOperand()); + VNtoLoads[{V, InvalidVN}].push_back(Load); + } + } + + const VNtoInsns &getVNTable() const { return VNtoLoads; } +}; + +// Records all store instructions candidate for code hoisting. 
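The bookkeeping above reduces to two small pieces: a table keyed by a pair of value numbers that groups candidate instructions, and an ordering of candidates by the DFS-in number of their blocks, with ties broken by position within the block. A self-contained sketch with toy types (these structs are invented for illustration and are not the LLVM classes):

#include <algorithm>
#include <cstdio>
#include <map>
#include <utility>
#include <vector>

// Toy stand-in: an instruction is identified by its block's DFS-in number,
// its position within that block, and the value number of its expression.
struct ToyInsn { unsigned BlockDFSIn, PosInBlock, VN; const char *Name; };

enum : unsigned { InvalidVN = ~2U }; // same sentinel used by the pass

int main() {
  std::vector<ToyInsn> Insns = {
      {3, 0, 7, "c"}, {1, 2, 7, "a"}, {2, 5, 7, "b"}, {2, 1, 9, "d"}};

  // Group hoisting candidates by (value number, InvalidVN), like VNtoInsns.
  std::map<std::pair<unsigned, unsigned>, std::vector<ToyInsn>> Table;
  for (const ToyInsn &I : Insns)
    Table[{I.VN, InvalidVN}].push_back(I);

  // Put each group into execution order, like SortByDFSIn.
  for (auto &Entry : Table) {
    std::sort(Entry.second.begin(), Entry.second.end(),
              [](const ToyInsn &A, const ToyInsn &B) {
                return A.BlockDFSIn != B.BlockDFSIn
                           ? A.BlockDFSIn < B.BlockDFSIn
                           : A.PosInBlock < B.PosInBlock;
              });
    for (const ToyInsn &I : Entry.second)
      std::printf("VN %u: %s\n", Entry.first.first, I.Name);
  }
  return 0;
}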
+class StoreInfo { + VNtoInsns VNtoStores; + +public: + // Insert the Store and a hash number of the store address and the stored + // value in VNtoStores. + void insert(StoreInst *Store, GVN::ValueTable &VN) { + if (!Store->isSimple()) + return; + // Hash the store address and the stored value. + Value *Ptr = Store->getPointerOperand(); + Value *Val = Store->getValueOperand(); + VNtoStores[{VN.lookupOrAdd(Ptr), VN.lookupOrAdd(Val)}].push_back(Store); + } + + const VNtoInsns &getVNTable() const { return VNtoStores; } +}; + +// Records all call instructions candidate for code hoisting. +class CallInfo { + VNtoInsns VNtoCallsScalars; + VNtoInsns VNtoCallsLoads; + VNtoInsns VNtoCallsStores; + +public: + // Insert Call and its value numbering in one of the VNtoCalls* containers. + void insert(CallInst *Call, GVN::ValueTable &VN) { + // A call that doesNotAccessMemory is handled as a Scalar, + // onlyReadsMemory will be handled as a Load instruction, + // all other calls will be handled as stores. + unsigned V = VN.lookupOrAdd(Call); + auto Entry = std::make_pair(V, InvalidVN); + + if (Call->doesNotAccessMemory()) + VNtoCallsScalars[Entry].push_back(Call); + else if (Call->onlyReadsMemory()) + VNtoCallsLoads[Entry].push_back(Call); + else + VNtoCallsStores[Entry].push_back(Call); + } + + const VNtoInsns &getScalarVNTable() const { return VNtoCallsScalars; } + + const VNtoInsns &getLoadVNTable() const { return VNtoCallsLoads; } + + const VNtoInsns &getStoreVNTable() const { return VNtoCallsStores; } +}; + +typedef DenseMap<const BasicBlock *, bool> BBSideEffectsSet; +typedef SmallVector<Instruction *, 4> SmallVecInsn; +typedef SmallVectorImpl<Instruction *> SmallVecImplInsn; + +// This pass hoists common computations across branches sharing common +// dominator. The primary goal is to reduce the code size, and in some +// cases reduce critical path (by exposing more ILP). +class GVNHoist { +public: + GVN::ValueTable VN; + DominatorTree *DT; + AliasAnalysis *AA; + MemoryDependenceResults *MD; + const bool OptForMinSize; + DenseMap<const BasicBlock *, unsigned> DFSNumber; + BBSideEffectsSet BBSideEffects; + MemorySSA *MSSA; + int HoistedCtr; + + enum InsKind { Unknown, Scalar, Load, Store }; + + GVNHoist(DominatorTree *Dt, AliasAnalysis *Aa, MemoryDependenceResults *Md, + bool OptForMinSize) + : DT(Dt), AA(Aa), MD(Md), OptForMinSize(OptForMinSize), HoistedCtr(0) {} + + // Return true when there are exception handling in BB. + bool hasEH(const BasicBlock *BB) { + auto It = BBSideEffects.find(BB); + if (It != BBSideEffects.end()) + return It->second; + + if (BB->isEHPad() || BB->hasAddressTaken()) { + BBSideEffects[BB] = true; + return true; + } + + if (BB->getTerminator()->mayThrow()) { + BBSideEffects[BB] = true; + return true; + } + + BBSideEffects[BB] = false; + return false; + } + + // Return true when all paths from A to the end of the function pass through + // either B or C. + bool hoistingFromAllPaths(const BasicBlock *A, const BasicBlock *B, + const BasicBlock *C) { + // We fully copy the WL in order to be able to remove items from it. + SmallPtrSet<const BasicBlock *, 2> WL; + WL.insert(B); + WL.insert(C); + + for (auto It = df_begin(A), E = df_end(A); It != E;) { + // There exists a path from A to the exit of the function if we are still + // iterating in DF traversal and we removed all instructions from the work + // list. + if (WL.empty()) + return false; + + const BasicBlock *BB = *It; + if (WL.erase(BB)) { + // Stop DFS traversal when BB is in the work list. 
+ It.skipChildren(); + continue; + } + + // Check for end of function, calls that do not return, etc. + if (!isGuaranteedToTransferExecutionToSuccessor(BB->getTerminator())) + return false; + + // Increment DFS traversal when not skipping children. + ++It; + } + + return true; + } + + /* Return true when I1 appears before I2 in the instructions of BB. */ + bool firstInBB(BasicBlock *BB, const Instruction *I1, const Instruction *I2) { + for (Instruction &I : *BB) { + if (&I == I1) + return true; + if (&I == I2) + return false; + } + + llvm_unreachable("I1 and I2 not found in BB"); + } + // Return true when there are users of Def in BB. + bool hasMemoryUseOnPath(MemoryAccess *Def, const BasicBlock *BB, + const Instruction *OldPt) { + const BasicBlock *DefBB = Def->getBlock(); + const BasicBlock *OldBB = OldPt->getParent(); + + for (User *U : Def->users()) + if (auto *MU = dyn_cast<MemoryUse>(U)) { + BasicBlock *UBB = MU->getBlock(); + // Only analyze uses in BB. + if (BB != UBB) + continue; + + // A use in the same block as the Def is on the path. + if (UBB == DefBB) { + assert(MSSA->locallyDominates(Def, MU) && "def not dominating use"); + return true; + } + + if (UBB != OldBB) + return true; + + // It is only harmful to hoist when the use is before OldPt. + if (firstInBB(UBB, MU->getMemoryInst(), OldPt)) + return true; + } + + return false; + } + + // Return true when there are exception handling or loads of memory Def + // between OldPt and NewPt. + + // Decrement by 1 NBBsOnAllPaths for each block between HoistPt and BB, and + // return true when the counter NBBsOnAllPaths reaces 0, except when it is + // initialized to -1 which is unlimited. + bool hasEHOrLoadsOnPath(const Instruction *NewPt, const Instruction *OldPt, + MemoryAccess *Def, int &NBBsOnAllPaths) { + const BasicBlock *NewBB = NewPt->getParent(); + const BasicBlock *OldBB = OldPt->getParent(); + assert(DT->dominates(NewBB, OldBB) && "invalid path"); + assert(DT->dominates(Def->getBlock(), NewBB) && + "def does not dominate new hoisting point"); + + // Walk all basic blocks reachable in depth-first iteration on the inverse + // CFG from OldBB to NewBB. These blocks are all the blocks that may be + // executed between the execution of NewBB and OldBB. Hoisting an expression + // from OldBB into NewBB has to be safe on all execution paths. + for (auto I = idf_begin(OldBB), E = idf_end(OldBB); I != E;) { + if (*I == NewBB) { + // Stop traversal when reaching HoistPt. + I.skipChildren(); + continue; + } + + // Impossible to hoist with exceptions on the path. + if (hasEH(*I)) + return true; + + // Check that we do not move a store past loads. + if (hasMemoryUseOnPath(Def, *I, OldPt)) + return true; + + // Stop walk once the limit is reached. + if (NBBsOnAllPaths == 0) + return true; + + // -1 is unlimited number of blocks on all paths. + if (NBBsOnAllPaths != -1) + --NBBsOnAllPaths; + + ++I; + } + + return false; + } + + // Return true when there are exception handling between HoistPt and BB. + // Decrement by 1 NBBsOnAllPaths for each block between HoistPt and BB, and + // return true when the counter NBBsOnAllPaths reaches 0, except when it is + // initialized to -1 which is unlimited. + bool hasEHOnPath(const BasicBlock *HoistPt, const BasicBlock *BB, + int &NBBsOnAllPaths) { + assert(DT->dominates(HoistPt, BB) && "Invalid path"); + + // Walk all basic blocks reachable in depth-first iteration on + // the inverse CFG from BBInsn to NewHoistPt. 
These blocks are all the + // blocks that may be executed between the execution of NewHoistPt and + // BBInsn. Hoisting an expression from BBInsn into NewHoistPt has to be safe + // on all execution paths. + for (auto I = idf_begin(BB), E = idf_end(BB); I != E;) { + if (*I == HoistPt) { + // Stop traversal when reaching NewHoistPt. + I.skipChildren(); + continue; + } + + // Impossible to hoist with exceptions on the path. + if (hasEH(*I)) + return true; + + // Stop walk once the limit is reached. + if (NBBsOnAllPaths == 0) + return true; + + // -1 is unlimited number of blocks on all paths. + if (NBBsOnAllPaths != -1) + --NBBsOnAllPaths; + + ++I; + } + + return false; + } + + // Return true when it is safe to hoist a memory load or store U from OldPt + // to NewPt. + bool safeToHoistLdSt(const Instruction *NewPt, const Instruction *OldPt, + MemoryUseOrDef *U, InsKind K, int &NBBsOnAllPaths) { + + // In place hoisting is safe. + if (NewPt == OldPt) + return true; + + const BasicBlock *NewBB = NewPt->getParent(); + const BasicBlock *OldBB = OldPt->getParent(); + const BasicBlock *UBB = U->getBlock(); + + // Check for dependences on the Memory SSA. + MemoryAccess *D = U->getDefiningAccess(); + BasicBlock *DBB = D->getBlock(); + if (DT->properlyDominates(NewBB, DBB)) + // Cannot move the load or store to NewBB above its definition in DBB. + return false; + + if (NewBB == DBB && !MSSA->isLiveOnEntryDef(D)) + if (auto *UD = dyn_cast<MemoryUseOrDef>(D)) + if (firstInBB(DBB, NewPt, UD->getMemoryInst())) + // Cannot move the load or store to NewPt above its definition in D. + return false; + + // Check for unsafe hoistings due to side effects. + if (K == InsKind::Store) { + if (hasEHOrLoadsOnPath(NewPt, OldPt, D, NBBsOnAllPaths)) + return false; + } else if (hasEHOnPath(NewBB, OldBB, NBBsOnAllPaths)) + return false; + + if (UBB == NewBB) { + if (DT->properlyDominates(DBB, NewBB)) + return true; + assert(UBB == DBB); + assert(MSSA->locallyDominates(D, U)); + } + + // No side effects: it is safe to hoist. + return true; + } + + // Return true when it is safe to hoist scalar instructions from BB1 and BB2 + // to HoistBB. + bool safeToHoistScalar(const BasicBlock *HoistBB, const BasicBlock *BB1, + const BasicBlock *BB2, int &NBBsOnAllPaths) { + // Check that the hoisted expression is needed on all paths. When HoistBB + // already contains an instruction to be hoisted, the expression is needed + // on all paths. Enable scalar hoisting at -Oz as it is safe to hoist + // scalars to a place where they are partially needed. + if (!OptForMinSize && BB1 != HoistBB && + !hoistingFromAllPaths(HoistBB, BB1, BB2)) + return false; + + if (hasEHOnPath(HoistBB, BB1, NBBsOnAllPaths) || + hasEHOnPath(HoistBB, BB2, NBBsOnAllPaths)) + return false; + + // Safe to hoist scalars from BB1 and BB2 to HoistBB. + return true; + } + + // Each element of a hoisting list contains the basic block where to hoist and + // a list of instructions to be hoisted. + typedef std::pair<BasicBlock *, SmallVecInsn> HoistingPointInfo; + typedef SmallVector<HoistingPointInfo, 4> HoistingPointList; + + // Partition InstructionsToHoist into a set of candidates which can share a + // common hoisting point. The partitions are collected in HPL. IsScalar is + // true when the instructions in InstructionsToHoist are scalars. IsLoad is + // true when the InstructionsToHoist are loads, false when they are stores. 
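Both hasEHOnPath and hasEHOrLoadsOnPath follow the same pattern: a depth-first walk over predecessors starting at the old location, pruning the walk when the prospective hoisting point is reached, failing on a hazard, and giving up once a basic-block budget runs out, with -1 meaning unlimited. A standalone sketch over a toy CFG (plain adjacency lists and an explicit stack; only an approximation of the iterator-based walk above):

#include <cassert>
#include <vector>

struct ToyBlock { std::vector<int> Preds; bool Hazard = false; };

// Walk the blocks that may execute between HoistPt and From: DFS over
// predecessors that prunes at HoistPt, fails on a hazard, and respects a
// budget of visited blocks (-1 = unlimited).
static bool pathIsSafe(const std::vector<ToyBlock> &CFG, int From, int HoistPt,
                       int Budget) {
  std::vector<bool> Visited(CFG.size(), false);
  std::vector<int> Stack{From};
  while (!Stack.empty()) {
    int BB = Stack.back();
    Stack.pop_back();
    if (BB == HoistPt || Visited[BB])
      continue; // prune at the hoisting point / already seen
    Visited[BB] = true;
    if (CFG[BB].Hazard)
      return false; // e.g. a block whose terminator may throw
    if (Budget == 0)
      return false; // too many blocks on the paths, give up
    if (Budget > 0)
      --Budget;
    for (int P : CFG[BB].Preds)
      Stack.push_back(P);
  }
  return true;
}

int main() {
  // Diamond: 0 -> {1, 2} -> 3; try hoisting from block 3 up to block 0.
  std::vector<ToyBlock> CFG(4);
  CFG[1].Preds = {0};
  CFG[2].Preds = {0};
  CFG[3].Preds = {1, 2};
  assert(pathIsSafe(CFG, 3, 0, -1));
  CFG[2].Hazard = true; // an EH-like hazard on one of the paths
  assert(!pathIsSafe(CFG, 3, 0, -1));
  return 0;
}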
+ void partitionCandidates(SmallVecImplInsn &InstructionsToHoist, + HoistingPointList &HPL, InsKind K) { + // No need to sort for two instructions. + if (InstructionsToHoist.size() > 2) { + SortByDFSIn Pred(DFSNumber); + std::sort(InstructionsToHoist.begin(), InstructionsToHoist.end(), Pred); + } + + int NBBsOnAllPaths = MaxNumberOfBBSInPath; + + SmallVecImplInsn::iterator II = InstructionsToHoist.begin(); + SmallVecImplInsn::iterator Start = II; + Instruction *HoistPt = *II; + BasicBlock *HoistBB = HoistPt->getParent(); + MemoryUseOrDef *UD; + if (K != InsKind::Scalar) + UD = cast<MemoryUseOrDef>(MSSA->getMemoryAccess(HoistPt)); + + for (++II; II != InstructionsToHoist.end(); ++II) { + Instruction *Insn = *II; + BasicBlock *BB = Insn->getParent(); + BasicBlock *NewHoistBB; + Instruction *NewHoistPt; + + if (BB == HoistBB) { + NewHoistBB = HoistBB; + NewHoistPt = firstInBB(BB, Insn, HoistPt) ? Insn : HoistPt; + } else { + NewHoistBB = DT->findNearestCommonDominator(HoistBB, BB); + if (NewHoistBB == BB) + NewHoistPt = Insn; + else if (NewHoistBB == HoistBB) + NewHoistPt = HoistPt; + else + NewHoistPt = NewHoistBB->getTerminator(); + } + + if (K == InsKind::Scalar) { + if (safeToHoistScalar(NewHoistBB, HoistBB, BB, NBBsOnAllPaths)) { + // Extend HoistPt to NewHoistPt. + HoistPt = NewHoistPt; + HoistBB = NewHoistBB; + continue; + } + } else { + // When NewBB already contains an instruction to be hoisted, the + // expression is needed on all paths. + // Check that the hoisted expression is needed on all paths: it is + // unsafe to hoist loads to a place where there may be a path not + // loading from the same address: for instance there may be a branch on + // which the address of the load may not be initialized. + if ((HoistBB == NewHoistBB || BB == NewHoistBB || + hoistingFromAllPaths(NewHoistBB, HoistBB, BB)) && + // Also check that it is safe to move the load or store from HoistPt + // to NewHoistPt, and from Insn to NewHoistPt. + safeToHoistLdSt(NewHoistPt, HoistPt, UD, K, NBBsOnAllPaths) && + safeToHoistLdSt(NewHoistPt, Insn, + cast<MemoryUseOrDef>(MSSA->getMemoryAccess(Insn)), + K, NBBsOnAllPaths)) { + // Extend HoistPt to NewHoistPt. + HoistPt = NewHoistPt; + HoistBB = NewHoistBB; + continue; + } + } + + // At this point it is not safe to extend the current hoisting to + // NewHoistPt: save the hoisting list so far. + if (std::distance(Start, II) > 1) + HPL.push_back({HoistBB, SmallVecInsn(Start, II)}); + + // Start over from BB. + Start = II; + if (K != InsKind::Scalar) + UD = cast<MemoryUseOrDef>(MSSA->getMemoryAccess(*Start)); + HoistPt = Insn; + HoistBB = BB; + NBBsOnAllPaths = MaxNumberOfBBSInPath; + } + + // Save the last partition. + if (std::distance(Start, II) > 1) + HPL.push_back({HoistBB, SmallVecInsn(Start, II)}); + } + + // Initialize HPL from Map. + void computeInsertionPoints(const VNtoInsns &Map, HoistingPointList &HPL, + InsKind K) { + for (const auto &Entry : Map) { + if (MaxHoistedThreshold != -1 && ++HoistedCtr > MaxHoistedThreshold) + return; + + const SmallVecInsn &V = Entry.second; + if (V.size() < 2) + continue; + + // Compute the insertion point and the list of expressions to be hoisted. + SmallVecInsn InstructionsToHoist; + for (auto I : V) + if (!hasEH(I->getParent())) + InstructionsToHoist.push_back(I); + + if (!InstructionsToHoist.empty()) + partitionCandidates(InstructionsToHoist, HPL, K); + } + } + + // Return true when all operands of Instr are available at insertion point + // HoistPt. 
When limiting the number of hoisted expressions, one could hoist + // a load without hoisting its access function. So before hoisting any + // expression, make sure that all its operands are available at insert point. + bool allOperandsAvailable(const Instruction *I, + const BasicBlock *HoistPt) const { + for (const Use &Op : I->operands()) + if (const auto *Inst = dyn_cast<Instruction>(&Op)) + if (!DT->dominates(Inst->getParent(), HoistPt)) + return false; + + return true; + } + + Instruction *firstOfTwo(Instruction *I, Instruction *J) const { + for (Instruction &I1 : *I->getParent()) + if (&I1 == I || &I1 == J) + return &I1; + llvm_unreachable("Both I and J must be from same BB"); + } + + // Replace the use of From with To in Insn. + void replaceUseWith(Instruction *Insn, Value *From, Value *To) const { + for (Value::use_iterator UI = From->use_begin(), UE = From->use_end(); + UI != UE;) { + Use &U = *UI++; + if (U.getUser() == Insn) { + U.set(To); + return; + } + } + llvm_unreachable("should replace exactly once"); + } + + bool makeOperandsAvailable(Instruction *Repl, BasicBlock *HoistPt) const { + // Check whether the GEP of a ld/st can be synthesized at HoistPt. + GetElementPtrInst *Gep = nullptr; + Instruction *Val = nullptr; + if (auto *Ld = dyn_cast<LoadInst>(Repl)) + Gep = dyn_cast<GetElementPtrInst>(Ld->getPointerOperand()); + if (auto *St = dyn_cast<StoreInst>(Repl)) { + Gep = dyn_cast<GetElementPtrInst>(St->getPointerOperand()); + Val = dyn_cast<Instruction>(St->getValueOperand()); + // Check that the stored value is available. + if (Val) { + if (isa<GetElementPtrInst>(Val)) { + // Check whether we can compute the GEP at HoistPt. + if (!allOperandsAvailable(Val, HoistPt)) + return false; + } else if (!DT->dominates(Val->getParent(), HoistPt)) + return false; + } + } + + // Check whether we can compute the Gep at HoistPt. + if (!Gep || !allOperandsAvailable(Gep, HoistPt)) + return false; + + // Copy the gep before moving the ld/st. + Instruction *ClonedGep = Gep->clone(); + ClonedGep->insertBefore(HoistPt->getTerminator()); + replaceUseWith(Repl, Gep, ClonedGep); + + // Also copy Val when it is a GEP. + if (Val && isa<GetElementPtrInst>(Val)) { + Instruction *ClonedVal = Val->clone(); + ClonedVal->insertBefore(HoistPt->getTerminator()); + replaceUseWith(Repl, Val, ClonedVal); + } + + return true; + } + + std::pair<unsigned, unsigned> hoist(HoistingPointList &HPL) { + unsigned NI = 0, NL = 0, NS = 0, NC = 0, NR = 0; + for (const HoistingPointInfo &HP : HPL) { + // Find out whether we already have one of the instructions in HoistPt, + // in which case we do not have to move it. + BasicBlock *HoistPt = HP.first; + const SmallVecInsn &InstructionsToHoist = HP.second; + Instruction *Repl = nullptr; + for (Instruction *I : InstructionsToHoist) + if (I->getParent() == HoistPt) { + // If there are two instructions in HoistPt to be hoisted in place: + // update Repl to be the first one, such that we can rename the uses + // of the second based on the first. + Repl = !Repl ? I : firstOfTwo(Repl, I); + } + + if (Repl) { + // Repl is already in HoistPt: it remains in place. + assert(allOperandsAvailable(Repl, HoistPt) && + "instruction depends on operands that are not available"); + } else { + // When we do not find Repl in HoistPt, select the first in the list + // and move it to HoistPt. + Repl = InstructionsToHoist.front(); + + // We can move Repl in HoistPt only when all operands are available. + // The order in which hoistings are done may influence the availability + // of operands. 
+ if (!allOperandsAvailable(Repl, HoistPt) && + !makeOperandsAvailable(Repl, HoistPt)) + continue; + Repl->moveBefore(HoistPt->getTerminator()); + } + + if (isa<LoadInst>(Repl)) + ++NL; + else if (isa<StoreInst>(Repl)) + ++NS; + else if (isa<CallInst>(Repl)) + ++NC; + else // Scalar + ++NI; + + // Remove and rename all other instructions. + for (Instruction *I : InstructionsToHoist) + if (I != Repl) { + ++NR; + if (isa<LoadInst>(Repl)) + ++NumLoadsRemoved; + else if (isa<StoreInst>(Repl)) + ++NumStoresRemoved; + else if (isa<CallInst>(Repl)) + ++NumCallsRemoved; + I->replaceAllUsesWith(Repl); + I->eraseFromParent(); + } + } + + NumHoisted += NL + NS + NC + NI; + NumRemoved += NR; + NumLoadsHoisted += NL; + NumStoresHoisted += NS; + NumCallsHoisted += NC; + return {NI, NL + NC + NS}; + } + + // Hoist all expressions. Returns Number of scalars hoisted + // and number of non-scalars hoisted. + std::pair<unsigned, unsigned> hoistExpressions(Function &F) { + InsnInfo II; + LoadInfo LI; + StoreInfo SI; + CallInfo CI; + for (BasicBlock *BB : depth_first(&F.getEntryBlock())) { + for (Instruction &I1 : *BB) { + if (auto *Load = dyn_cast<LoadInst>(&I1)) + LI.insert(Load, VN); + else if (auto *Store = dyn_cast<StoreInst>(&I1)) + SI.insert(Store, VN); + else if (auto *Call = dyn_cast<CallInst>(&I1)) { + if (auto *Intr = dyn_cast<IntrinsicInst>(Call)) { + if (isa<DbgInfoIntrinsic>(Intr) || + Intr->getIntrinsicID() == Intrinsic::assume) + continue; + } + if (Call->mayHaveSideEffects()) { + if (!OptForMinSize) + break; + // We may continue hoisting across calls which write to memory. + if (Call->mayThrow()) + break; + } + CI.insert(Call, VN); + } else if (OptForMinSize || !isa<GetElementPtrInst>(&I1)) + // Do not hoist scalars past calls that may write to memory because + // that could result in spills later. geps are handled separately. + // TODO: We can relax this for targets like AArch64 as they have more + // registers than X86. + II.insert(&I1, VN); + } + } + + HoistingPointList HPL; + computeInsertionPoints(II.getVNTable(), HPL, InsKind::Scalar); + computeInsertionPoints(LI.getVNTable(), HPL, InsKind::Load); + computeInsertionPoints(SI.getVNTable(), HPL, InsKind::Store); + computeInsertionPoints(CI.getScalarVNTable(), HPL, InsKind::Scalar); + computeInsertionPoints(CI.getLoadVNTable(), HPL, InsKind::Load); + computeInsertionPoints(CI.getStoreVNTable(), HPL, InsKind::Store); + return hoist(HPL); + } + + bool run(Function &F) { + VN.setDomTree(DT); + VN.setAliasAnalysis(AA); + VN.setMemDep(MD); + bool Res = false; + + unsigned I = 0; + for (const BasicBlock *BB : depth_first(&F.getEntryBlock())) + DFSNumber.insert({BB, ++I}); + + // FIXME: use lazy evaluation of VN to avoid the fix-point computation. + while (1) { + // FIXME: only compute MemorySSA once. We need to update the analysis in + // the same time as transforming the code. + MemorySSA M(F, AA, DT); + MSSA = &M; + + auto HoistStat = hoistExpressions(F); + if (HoistStat.first + HoistStat.second == 0) { + return Res; + } + if (HoistStat.second > 0) { + // To address a limitation of the current GVN, we need to rerun the + // hoisting after we hoisted loads in order to be able to hoist all + // scalars dependent on the hoisted loads. Same for stores. 
+ VN.clear(); + } + Res = true; + } + + return Res; + } +}; + +class GVNHoistLegacyPass : public FunctionPass { +public: + static char ID; + + GVNHoistLegacyPass() : FunctionPass(ID) { + initializeGVNHoistLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto &MD = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + + GVNHoist G(&DT, &AA, &MD, F.optForMinSize()); + return G.run(F); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + } +}; +} // namespace + +PreservedAnalyses GVNHoistPass::run(Function &F, + AnalysisManager<Function> &AM) { + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + AliasAnalysis &AA = AM.getResult<AAManager>(F); + MemoryDependenceResults &MD = AM.getResult<MemoryDependenceAnalysis>(F); + + GVNHoist G(&DT, &AA, &MD, F.optForMinSize()); + if (!G.run(F)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + return PA; +} + +char GVNHoistLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(GVNHoistLegacyPass, "gvn-hoist", + "Early GVN Hoisting of Expressions", false, false) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(GVNHoistLegacyPass, "gvn-hoist", + "Early GVN Hoisting of Expressions", false, false) + +FunctionPass *llvm::createGVNHoistPass() { return new GVNHoistLegacyPass(); } diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp new file mode 100644 index 000000000000..7686e65efed9 --- /dev/null +++ b/lib/Transforms/Scalar/GuardWidening.cpp @@ -0,0 +1,691 @@ +//===- GuardWidening.cpp - ---- Guard widening ----------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the guard widening pass. The semantics of the +// @llvm.experimental.guard intrinsic lets LLVM transform it so that it fails +// more often that it did before the transform. This optimization is called +// "widening" and can be used hoist and common runtime checks in situations like +// these: +// +// %cmp0 = 7 u< Length +// call @llvm.experimental.guard(i1 %cmp0) [ "deopt"(...) ] +// call @unknown_side_effects() +// %cmp1 = 9 u< Length +// call @llvm.experimental.guard(i1 %cmp1) [ "deopt"(...) ] +// ... +// +// => +// +// %cmp0 = 9 u< Length +// call @llvm.experimental.guard(i1 %cmp0) [ "deopt"(...) ] +// call @unknown_side_effects() +// ... +// +// If %cmp0 is false, @llvm.experimental.guard will "deoptimize" back to a +// generic implementation of the same function, which will have the correct +// semantics from that point onward. It is always _legal_ to deoptimize (so +// replacing %cmp0 with false is "correct"), though it may not always be +// profitable to do so. +// +// NB! This pass is a work in progress. It hasn't been tuned to be "production +// ready" yet. 
It is known to have quadriatic running time and will not scale +// to large numbers of guards +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/GuardWidening.h" +#include "llvm/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" + +using namespace llvm; + +#define DEBUG_TYPE "guard-widening" + +namespace { + +class GuardWideningImpl { + DominatorTree &DT; + PostDominatorTree &PDT; + LoopInfo &LI; + + /// The set of guards whose conditions have been widened into dominating + /// guards. + SmallVector<IntrinsicInst *, 16> EliminatedGuards; + + /// The set of guards which have been widened to include conditions to other + /// guards. + DenseSet<IntrinsicInst *> WidenedGuards; + + /// Try to eliminate guard \p Guard by widening it into an earlier dominating + /// guard. \p DFSI is the DFS iterator on the dominator tree that is + /// currently visiting the block containing \p Guard, and \p GuardsPerBlock + /// maps BasicBlocks to the set of guards seen in that block. + bool eliminateGuardViaWidening( + IntrinsicInst *Guard, const df_iterator<DomTreeNode *> &DFSI, + const DenseMap<BasicBlock *, SmallVector<IntrinsicInst *, 8>> & + GuardsPerBlock); + + /// Used to keep track of which widening potential is more effective. + enum WideningScore { + /// Don't widen. + WS_IllegalOrNegative, + + /// Widening is performance neutral as far as the cycles spent in check + /// conditions goes (but can still help, e.g., code layout, having less + /// deopt state). + WS_Neutral, + + /// Widening is profitable. + WS_Positive, + + /// Widening is very profitable. Not significantly different from \c + /// WS_Positive, except by the order. + WS_VeryPositive + }; + + static StringRef scoreTypeToString(WideningScore WS); + + /// Compute the score for widening the condition in \p DominatedGuard + /// (contained in \p DominatedGuardLoop) into \p DominatingGuard (contained in + /// \p DominatingGuardLoop). + WideningScore computeWideningScore(IntrinsicInst *DominatedGuard, + Loop *DominatedGuardLoop, + IntrinsicInst *DominatingGuard, + Loop *DominatingGuardLoop); + + /// Helper to check if \p V can be hoisted to \p InsertPos. + bool isAvailableAt(Value *V, Instruction *InsertPos) { + SmallPtrSet<Instruction *, 8> Visited; + return isAvailableAt(V, InsertPos, Visited); + } + + bool isAvailableAt(Value *V, Instruction *InsertPos, + SmallPtrSetImpl<Instruction *> &Visited); + + /// Helper to hoist \p V to \p InsertPos. Guaranteed to succeed if \c + /// isAvailableAt returned true. + void makeAvailableAt(Value *V, Instruction *InsertPos); + + /// Common helper used by \c widenGuard and \c isWideningCondProfitable. Try + /// to generate an expression computing the logical AND of \p Cond0 and \p + /// Cond1. Return true if the expression computing the AND is only as + /// expensive as computing one of the two. If \p InsertPt is true then + /// actually generate the resulting expression, make it available at \p + /// InsertPt and return it in \p Result (else no change to the IR is made). 
+ bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt, + Value *&Result); + + /// Represents a range check of the form \c Base + \c Offset u< \c Length, + /// with the constraint that \c Length is not negative. \c CheckInst is the + /// pre-existing instruction in the IR that computes the result of this range + /// check. + class RangeCheck { + Value *Base; + ConstantInt *Offset; + Value *Length; + ICmpInst *CheckInst; + + public: + explicit RangeCheck(Value *Base, ConstantInt *Offset, Value *Length, + ICmpInst *CheckInst) + : Base(Base), Offset(Offset), Length(Length), CheckInst(CheckInst) {} + + void setBase(Value *NewBase) { Base = NewBase; } + void setOffset(ConstantInt *NewOffset) { Offset = NewOffset; } + + Value *getBase() const { return Base; } + ConstantInt *getOffset() const { return Offset; } + const APInt &getOffsetValue() const { return getOffset()->getValue(); } + Value *getLength() const { return Length; }; + ICmpInst *getCheckInst() const { return CheckInst; } + + void print(raw_ostream &OS, bool PrintTypes = false) { + OS << "Base: "; + Base->printAsOperand(OS, PrintTypes); + OS << " Offset: "; + Offset->printAsOperand(OS, PrintTypes); + OS << " Length: "; + Length->printAsOperand(OS, PrintTypes); + } + + LLVM_DUMP_METHOD void dump() { + print(dbgs()); + dbgs() << "\n"; + } + }; + + /// Parse \p CheckCond into a conjunction (logical-and) of range checks; and + /// append them to \p Checks. Returns true on success, may clobber \c Checks + /// on failure. + bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks) { + SmallPtrSet<Value *, 8> Visited; + return parseRangeChecks(CheckCond, Checks, Visited); + } + + bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks, + SmallPtrSetImpl<Value *> &Visited); + + /// Combine the checks in \p Checks into a smaller set of checks and append + /// them into \p CombinedChecks. Return true on success (i.e. all of checks + /// in \p Checks were combined into \p CombinedChecks). Clobbers \p Checks + /// and \p CombinedChecks on success and on failure. + bool combineRangeChecks(SmallVectorImpl<RangeCheck> &Checks, + SmallVectorImpl<RangeCheck> &CombinedChecks); + + /// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of + /// computing only one of the two expressions? + bool isWideningCondProfitable(Value *Cond0, Value *Cond1) { + Value *ResultUnused; + return widenCondCommon(Cond0, Cond1, /*InsertPt=*/nullptr, ResultUnused); + } + + /// Widen \p ToWiden to fail if \p NewCondition is false (in addition to + /// whatever it is already checking). + void widenGuard(IntrinsicInst *ToWiden, Value *NewCondition) { + Value *Result; + widenCondCommon(ToWiden->getArgOperand(0), NewCondition, ToWiden, Result); + ToWiden->setArgOperand(0, Result); + } + +public: + explicit GuardWideningImpl(DominatorTree &DT, PostDominatorTree &PDT, + LoopInfo &LI) + : DT(DT), PDT(PDT), LI(LI) {} + + /// The entry point for this pass. 
+ bool run(); +}; + +struct GuardWideningLegacyPass : public FunctionPass { + static char ID; + GuardWideningPass Impl; + + GuardWideningLegacyPass() : FunctionPass(ID) { + initializeGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + return GuardWideningImpl( + getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(), + getAnalysis<LoopInfoWrapperPass>().getLoopInfo()).run(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + } +}; + +} + +bool GuardWideningImpl::run() { + using namespace llvm::PatternMatch; + + DenseMap<BasicBlock *, SmallVector<IntrinsicInst *, 8>> GuardsInBlock; + bool Changed = false; + + for (auto DFI = df_begin(DT.getRootNode()), DFE = df_end(DT.getRootNode()); + DFI != DFE; ++DFI) { + auto *BB = (*DFI)->getBlock(); + auto &CurrentList = GuardsInBlock[BB]; + + for (auto &I : *BB) + if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>())) + CurrentList.push_back(cast<IntrinsicInst>(&I)); + + for (auto *II : CurrentList) + Changed |= eliminateGuardViaWidening(II, DFI, GuardsInBlock); + } + + for (auto *II : EliminatedGuards) + if (!WidenedGuards.count(II)) + II->eraseFromParent(); + + return Changed; +} + +bool GuardWideningImpl::eliminateGuardViaWidening( + IntrinsicInst *GuardInst, const df_iterator<DomTreeNode *> &DFSI, + const DenseMap<BasicBlock *, SmallVector<IntrinsicInst *, 8>> & + GuardsInBlock) { + IntrinsicInst *BestSoFar = nullptr; + auto BestScoreSoFar = WS_IllegalOrNegative; + auto *GuardInstLoop = LI.getLoopFor(GuardInst->getParent()); + + // In the set of dominating guards, find the one we can merge GuardInst with + // for the most profit. + for (unsigned i = 0, e = DFSI.getPathLength(); i != e; ++i) { + auto *CurBB = DFSI.getPath(i)->getBlock(); + auto *CurLoop = LI.getLoopFor(CurBB); + assert(GuardsInBlock.count(CurBB) && "Must have been populated by now!"); + const auto &GuardsInCurBB = GuardsInBlock.find(CurBB)->second; + + auto I = GuardsInCurBB.begin(); + auto E = GuardsInCurBB.end(); + +#ifndef NDEBUG + { + unsigned Index = 0; + for (auto &I : *CurBB) { + if (Index == GuardsInCurBB.size()) + break; + if (GuardsInCurBB[Index] == &I) + Index++; + } + assert(Index == GuardsInCurBB.size() && + "Guards expected to be in order!"); + } +#endif + + assert((i == (e - 1)) == (GuardInst->getParent() == CurBB) && "Bad DFS?"); + + if (i == (e - 1)) { + // Corner case: make sure we're only looking at guards strictly dominating + // GuardInst when visiting GuardInst->getParent(). 
+ auto NewEnd = std::find(I, E, GuardInst); + assert(NewEnd != E && "GuardInst not in its own block?"); + E = NewEnd; + } + + for (auto *Candidate : make_range(I, E)) { + auto Score = + computeWideningScore(GuardInst, GuardInstLoop, Candidate, CurLoop); + DEBUG(dbgs() << "Score between " << *GuardInst->getArgOperand(0) + << " and " << *Candidate->getArgOperand(0) << " is " + << scoreTypeToString(Score) << "\n"); + if (Score > BestScoreSoFar) { + BestScoreSoFar = Score; + BestSoFar = Candidate; + } + } + } + + if (BestScoreSoFar == WS_IllegalOrNegative) { + DEBUG(dbgs() << "Did not eliminate guard " << *GuardInst << "\n"); + return false; + } + + assert(BestSoFar != GuardInst && "Should have never visited same guard!"); + assert(DT.dominates(BestSoFar, GuardInst) && "Should be!"); + + DEBUG(dbgs() << "Widening " << *GuardInst << " into " << *BestSoFar + << " with score " << scoreTypeToString(BestScoreSoFar) << "\n"); + widenGuard(BestSoFar, GuardInst->getArgOperand(0)); + GuardInst->setArgOperand(0, ConstantInt::getTrue(GuardInst->getContext())); + EliminatedGuards.push_back(GuardInst); + WidenedGuards.insert(BestSoFar); + return true; +} + +GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( + IntrinsicInst *DominatedGuard, Loop *DominatedGuardLoop, + IntrinsicInst *DominatingGuard, Loop *DominatingGuardLoop) { + bool HoistingOutOfLoop = false; + + if (DominatingGuardLoop != DominatedGuardLoop) { + if (DominatingGuardLoop && + !DominatingGuardLoop->contains(DominatedGuardLoop)) + return WS_IllegalOrNegative; + + HoistingOutOfLoop = true; + } + + if (!isAvailableAt(DominatedGuard->getArgOperand(0), DominatingGuard)) + return WS_IllegalOrNegative; + + bool HoistingOutOfIf = + !PDT.dominates(DominatedGuard->getParent(), DominatingGuard->getParent()); + + if (isWideningCondProfitable(DominatedGuard->getArgOperand(0), + DominatingGuard->getArgOperand(0))) + return HoistingOutOfLoop ? WS_VeryPositive : WS_Positive; + + if (HoistingOutOfLoop) + return WS_Positive; + + return HoistingOutOfIf ? WS_IllegalOrNegative : WS_Neutral; +} + +bool GuardWideningImpl::isAvailableAt(Value *V, Instruction *Loc, + SmallPtrSetImpl<Instruction *> &Visited) { + auto *Inst = dyn_cast<Instruction>(V); + if (!Inst || DT.dominates(Inst, Loc) || Visited.count(Inst)) + return true; + + if (!isSafeToSpeculativelyExecute(Inst, Loc, &DT) || + Inst->mayReadFromMemory()) + return false; + + Visited.insert(Inst); + + // We only want to go _up_ the dominance chain when recursing. 
+ assert(!isa<PHINode>(Loc) && + "PHIs should return false for isSafeToSpeculativelyExecute"); + assert(DT.isReachableFromEntry(Inst->getParent()) && + "We did a DFS from the block entry!"); + return all_of(Inst->operands(), + [&](Value *Op) { return isAvailableAt(Op, Loc, Visited); }); +} + +void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) { + auto *Inst = dyn_cast<Instruction>(V); + if (!Inst || DT.dominates(Inst, Loc)) + return; + + assert(isSafeToSpeculativelyExecute(Inst, Loc, &DT) && + !Inst->mayReadFromMemory() && "Should've checked with isAvailableAt!"); + + for (Value *Op : Inst->operands()) + makeAvailableAt(Op, Loc); + + Inst->moveBefore(Loc); +} + +bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, + Instruction *InsertPt, Value *&Result) { + using namespace llvm::PatternMatch; + + { + // L >u C0 && L >u C1 -> L >u max(C0, C1) + ConstantInt *RHS0, *RHS1; + Value *LHS; + ICmpInst::Predicate Pred0, Pred1; + if (match(Cond0, m_ICmp(Pred0, m_Value(LHS), m_ConstantInt(RHS0))) && + match(Cond1, m_ICmp(Pred1, m_Specific(LHS), m_ConstantInt(RHS1)))) { + + ConstantRange CR0 = + ConstantRange::makeExactICmpRegion(Pred0, RHS0->getValue()); + ConstantRange CR1 = + ConstantRange::makeExactICmpRegion(Pred1, RHS1->getValue()); + + // SubsetIntersect is a subset of the actual mathematical intersection of + // CR0 and CR1, while SupersetIntersect is a superset of the actual + // mathematical intersection. If these two ConstantRanges are equal, then + // we know we were able to represent the actual mathematical intersection + // of CR0 and CR1, and can use the same to generate an icmp instruction. + // + // Given what we're doing here and the semantics of guards, it would + // actually be correct to just use SubsetIntersect, but that may be too + // aggressive in cases we care about. + auto SubsetIntersect = CR0.inverse().unionWith(CR1.inverse()).inverse(); + auto SupersetIntersect = CR0.intersectWith(CR1); + + APInt NewRHSAP; + CmpInst::Predicate Pred; + if (SubsetIntersect == SupersetIntersect && + SubsetIntersect.getEquivalentICmp(Pred, NewRHSAP)) { + if (InsertPt) { + ConstantInt *NewRHS = ConstantInt::get(Cond0->getContext(), NewRHSAP); + Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); + } + return true; + } + } + } + + { + SmallVector<GuardWideningImpl::RangeCheck, 4> Checks, CombinedChecks; + if (parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) && + combineRangeChecks(Checks, CombinedChecks)) { + if (InsertPt) { + Result = nullptr; + for (auto &RC : CombinedChecks) { + makeAvailableAt(RC.getCheckInst(), InsertPt); + if (Result) + Result = BinaryOperator::CreateAnd(RC.getCheckInst(), Result, "", + InsertPt); + else + Result = RC.getCheckInst(); + } + + Result->setName("wide.chk"); + } + return true; + } + } + + // Base case -- just logical-and the two conditions together. + + if (InsertPt) { + makeAvailableAt(Cond0, InsertPt); + makeAvailableAt(Cond1, InsertPt); + + Result = BinaryOperator::CreateAnd(Cond0, Cond1, "wide.chk", InsertPt); + } + + // We were not able to compute Cond0 AND Cond1 for the price of one. 
+ return false; +} + +bool GuardWideningImpl::parseRangeChecks( + Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, + SmallPtrSetImpl<Value *> &Visited) { + if (!Visited.insert(CheckCond).second) + return true; + + using namespace llvm::PatternMatch; + + { + Value *AndLHS, *AndRHS; + if (match(CheckCond, m_And(m_Value(AndLHS), m_Value(AndRHS)))) + return parseRangeChecks(AndLHS, Checks) && + parseRangeChecks(AndRHS, Checks); + } + + auto *IC = dyn_cast<ICmpInst>(CheckCond); + if (!IC || !IC->getOperand(0)->getType()->isIntegerTy() || + (IC->getPredicate() != ICmpInst::ICMP_ULT && + IC->getPredicate() != ICmpInst::ICMP_UGT)) + return false; + + Value *CmpLHS = IC->getOperand(0), *CmpRHS = IC->getOperand(1); + if (IC->getPredicate() == ICmpInst::ICMP_UGT) + std::swap(CmpLHS, CmpRHS); + + auto &DL = IC->getModule()->getDataLayout(); + + GuardWideningImpl::RangeCheck Check( + CmpLHS, cast<ConstantInt>(ConstantInt::getNullValue(CmpRHS->getType())), + CmpRHS, IC); + + if (!isKnownNonNegative(Check.getLength(), DL)) + return false; + + // What we have in \c Check now is a correct interpretation of \p CheckCond. + // Try to see if we can move some constant offsets into the \c Offset field. + + bool Changed; + auto &Ctx = CheckCond->getContext(); + + do { + Value *OpLHS; + ConstantInt *OpRHS; + Changed = false; + +#ifndef NDEBUG + auto *BaseInst = dyn_cast<Instruction>(Check.getBase()); + assert((!BaseInst || DT.isReachableFromEntry(BaseInst->getParent())) && + "Unreachable instruction?"); +#endif + + if (match(Check.getBase(), m_Add(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { + Check.setBase(OpLHS); + APInt NewOffset = Check.getOffsetValue() + OpRHS->getValue(); + Check.setOffset(ConstantInt::get(Ctx, NewOffset)); + Changed = true; + } else if (match(Check.getBase(), + m_Or(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { + unsigned BitWidth = OpLHS->getType()->getScalarSizeInBits(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(OpLHS, KnownZero, KnownOne, DL); + if ((OpRHS->getValue() & KnownZero) == OpRHS->getValue()) { + Check.setBase(OpLHS); + APInt NewOffset = Check.getOffsetValue() + OpRHS->getValue(); + Check.setOffset(ConstantInt::get(Ctx, NewOffset)); + Changed = true; + } + } + } while (Changed); + + Checks.push_back(Check); + return true; +} + +bool GuardWideningImpl::combineRangeChecks( + SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, + SmallVectorImpl<GuardWideningImpl::RangeCheck> &RangeChecksOut) { + unsigned OldCount = Checks.size(); + while (!Checks.empty()) { + // Pick all of the range checks with a specific base and length, and try to + // merge them. + Value *CurrentBase = Checks.front().getBase(); + Value *CurrentLength = Checks.front().getLength(); + + SmallVector<GuardWideningImpl::RangeCheck, 3> CurrentChecks; + + auto IsCurrentCheck = [&](GuardWideningImpl::RangeCheck &RC) { + return RC.getBase() == CurrentBase && RC.getLength() == CurrentLength; + }; + + std::copy_if(Checks.begin(), Checks.end(), + std::back_inserter(CurrentChecks), IsCurrentCheck); + Checks.erase(remove_if(Checks, IsCurrentCheck), Checks.end()); + + assert(CurrentChecks.size() != 0 && "We know we have at least one!"); + + if (CurrentChecks.size() < 3) { + RangeChecksOut.insert(RangeChecksOut.end(), CurrentChecks.begin(), + CurrentChecks.end()); + continue; + } + + // CurrentChecks.size() will typically be 3 here, but so far there has been + // no need to hard-code that fact. 
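  // A concrete (hypothetical) instance of the case handled below: with the
  // three collected checks  (I + 0) u< L,  (I + 1) u< L,  (I + 2) u< L,
  // sorting by offset gives offsets {0, 1, 2}; keeping only the first and
  // the last check (offsets 0 and 2) is enough to imply the middle one,
  // provided the span of offsets cannot wrap around -- which is exactly
  // what the checks following the sort verify (see the informal proof
  // sketch further down).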
+ + std::sort(CurrentChecks.begin(), CurrentChecks.end(), + [&](const GuardWideningImpl::RangeCheck &LHS, + const GuardWideningImpl::RangeCheck &RHS) { + return LHS.getOffsetValue().slt(RHS.getOffsetValue()); + }); + + // Note: std::sort should not invalidate the ChecksStart iterator. + + ConstantInt *MinOffset = CurrentChecks.front().getOffset(), + *MaxOffset = CurrentChecks.back().getOffset(); + + unsigned BitWidth = MaxOffset->getValue().getBitWidth(); + if ((MaxOffset->getValue() - MinOffset->getValue()) + .ugt(APInt::getSignedMinValue(BitWidth))) + return false; + + APInt MaxDiff = MaxOffset->getValue() - MinOffset->getValue(); + const APInt &HighOffset = MaxOffset->getValue(); + auto OffsetOK = [&](const GuardWideningImpl::RangeCheck &RC) { + return (HighOffset - RC.getOffsetValue()).ult(MaxDiff); + }; + + if (MaxDiff.isMinValue() || + !std::all_of(std::next(CurrentChecks.begin()), CurrentChecks.end(), + OffsetOK)) + return false; + + // We have a series of f+1 checks as: + // + // I+k_0 u< L ... Chk_0 + // I_k_1 u< L ... Chk_1 + // ... + // I_k_f u< L ... Chk_(f+1) + // + // with forall i in [0,f): k_f-k_i u< k_f-k_0 ... Precond_0 + // k_f-k_0 u< INT_MIN+k_f ... Precond_1 + // k_f != k_0 ... Precond_2 + // + // Claim: + // Chk_0 AND Chk_(f+1) implies all the other checks + // + // Informal proof sketch: + // + // We will show that the integer range [I+k_0,I+k_f] does not unsigned-wrap + // (i.e. going from I+k_0 to I+k_f does not cross the -1,0 boundary) and + // thus I+k_f is the greatest unsigned value in that range. + // + // This combined with Ckh_(f+1) shows that everything in that range is u< L. + // Via Precond_0 we know that all of the indices in Chk_0 through Chk_(f+1) + // lie in [I+k_0,I+k_f], this proving our claim. + // + // To see that [I+k_0,I+k_f] is not a wrapping range, note that there are + // two possibilities: I+k_0 u< I+k_f or I+k_0 >u I+k_f (they can't be equal + // since k_0 != k_f). In the former case, [I+k_0,I+k_f] is not a wrapping + // range by definition, and the latter case is impossible: + // + // 0-----I+k_f---I+k_0----L---INT_MAX,INT_MIN------------------(-1) + // xxxxxx xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + // + // For Chk_0 to succeed, we'd have to have k_f-k_0 (the range highlighted + // with 'x' above) to be at least >u INT_MIN. + + RangeChecksOut.emplace_back(CurrentChecks.front()); + RangeChecksOut.emplace_back(CurrentChecks.back()); + } + + assert(RangeChecksOut.size() <= OldCount && "We pessimized!"); + return RangeChecksOut.size() != OldCount; +} + +PreservedAnalyses GuardWideningPass::run(Function &F, + AnalysisManager<Function> &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + bool Changed = GuardWideningImpl(DT, PDT, LI).run(); + return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); +} + +StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { + switch (WS) { + case WS_IllegalOrNegative: + return "IllegalOrNegative"; + case WS_Neutral: + return "Neutral"; + case WS_Positive: + return "Positive"; + case WS_VeryPositive: + return "VeryPositive"; + } + + llvm_unreachable("Fully covered switch above!"); +} + +char GuardWideningLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(GuardWideningLegacyPass, "guard-widening", "Widen guards", + false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(GuardWideningLegacyPass, "guard-widening", "Widen guards", + false, false) + +FunctionPass *llvm::createGuardWideningPass() { + return new GuardWideningLegacyPass(); +} diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index ec5e15f0b8f8..542cf38e43bb 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -24,13 +24,14 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/IndVarSimplify.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -69,9 +70,6 @@ static cl::opt<bool> VerifyIndvars( "verify-indvars", cl::Hidden, cl::desc("Verify the ScalarEvolution result after running indvars")); -static cl::opt<bool> ReduceLiveIVs("liv-reduce", cl::Hidden, - cl::desc("Reduce live induction variables.")); - enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, AlwaysRepl }; static cl::opt<ReplaceExitVal> ReplaceExitValue( @@ -87,42 +85,16 @@ static cl::opt<ReplaceExitVal> ReplaceExitValue( namespace { struct RewritePhi; -class IndVarSimplify : public LoopPass { - LoopInfo *LI; - ScalarEvolution *SE; - DominatorTree *DT; - TargetLibraryInfo *TLI; +class IndVarSimplify { + LoopInfo *LI; + ScalarEvolution *SE; + DominatorTree *DT; + const DataLayout &DL; + TargetLibraryInfo *TLI; const TargetTransformInfo *TTI; SmallVector<WeakVH, 16> DeadInsts; - bool Changed; -public: - - static char ID; // Pass identification, replacement for typeid - IndVarSimplify() - : LoopPass(ID), LI(nullptr), SE(nullptr), DT(nullptr), Changed(false) { - initializeIndVarSimplifyPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addRequiredID(LCSSAID); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreservedID(LoopSimplifyID); - AU.addPreservedID(LCSSAID); - AU.setPreservesCFG(); - } - -private: - void releaseMemory() override { - DeadInsts.clear(); - } + bool Changed = false; bool isValidRewrite(Value *FromVal, Value *ToVal); @@ -133,6 +105,7 @@ private: bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet); void 
rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter); + void rewriteFirstIterationLoopExitValues(Loop *L); Value *linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, PHINode *IndVar, SCEVExpander &Rewriter); @@ -141,22 +114,15 @@ private: Value *expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L, Instruction *InsertPt, Type *Ty); -}; -} -char IndVarSimplify::ID = 0; -INITIALIZE_PASS_BEGIN(IndVarSimplify, "indvars", - "Induction Variable Simplification", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSA) -INITIALIZE_PASS_END(IndVarSimplify, "indvars", - "Induction Variable Simplification", false, false) +public: + IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, + const DataLayout &DL, TargetLibraryInfo *TLI, + TargetTransformInfo *TTI) + : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) {} -Pass *llvm::createIndVarSimplifyPass() { - return new IndVarSimplify(); + bool run(Loop *L); +}; } /// Return true if the SCEV expansion generated by the rewriter can replace the @@ -504,10 +470,9 @@ struct RewritePhi { unsigned Ith; // Ith incoming value. Value *Val; // Exit value after expansion. bool HighCost; // High Cost when expansion. - bool SafePhi; // LCSSASafePhiForRAUW. - RewritePhi(PHINode *P, unsigned I, Value *V, bool H, bool S) - : PN(P), Ith(I), Val(V), HighCost(H), SafePhi(S) {} + RewritePhi(PHINode *P, unsigned I, Value *V, bool H) + : PN(P), Ith(I), Val(V), HighCost(H) {} }; } @@ -550,9 +515,7 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { // Find all values that are computed inside the loop, but used outside of it. // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan // the exit blocks of the loop to find them. - for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { - BasicBlock *ExitBB = ExitBlocks[i]; - + for (BasicBlock *ExitBB : ExitBlocks) { // If there are no PHI nodes in this exit block, then no values defined // inside the loop are used on this path, skip it. PHINode *PN = dyn_cast<PHINode>(ExitBB->begin()); @@ -560,29 +523,13 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { unsigned NumPreds = PN->getNumIncomingValues(); - // We would like to be able to RAUW single-incoming value PHI nodes. We - // have to be certain this is safe even when this is an LCSSA PHI node. - // While the computed exit value is no longer varying in *this* loop, the - // exit block may be an exit block for an outer containing loop as well, - // the exit value may be varying in the outer loop, and thus it may still - // require an LCSSA PHI node. The safe case is when this is - // single-predecessor PHI node (LCSSA) and the exit block containing it is - // part of the enclosing loop, or this is the outer most loop of the nest. - // In either case the exit value could (at most) be varying in the same - // loop body as the phi node itself. Thus if it is in turn used outside of - // an enclosing loop it will only be via a separate LCSSA node. - bool LCSSASafePhiForRAUW = - NumPreds == 1 && - (!L->getParentLoop() || L->getParentLoop() == LI->getLoopFor(ExitBB)); - // Iterate over all of the PHI nodes. 
BasicBlock::iterator BBI = ExitBB->begin(); while ((PN = dyn_cast<PHINode>(BBI++))) { if (PN->use_empty()) continue; // dead use, don't replace it - // SCEV only supports integer expressions for now. - if (!PN->getType()->isIntegerTy() && !PN->getType()->isPointerTy()) + if (!SE->isSCEVable(PN->getType())) continue; // It's necessary to tell ScalarEvolution about this explicitly so that @@ -669,8 +616,7 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { } // Collect all the candidate PHINodes to be rewritten. - RewritePhiSet.push_back( - RewritePhi(PN, i, ExitVal, HighCost, LCSSASafePhiForRAUW)); + RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost); } } } @@ -699,9 +645,9 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { if (isInstructionTriviallyDead(Inst, TLI)) DeadInsts.push_back(Inst); - // If we determined that this PHI is safe to replace even if an LCSSA - // PHI, do so. - if (Phi.SafePhi) { + // Replace PN with ExitVal if that is legal and does not break LCSSA. + if (PN->getNumIncomingValues() == 1 && + LI->replacementPreservesLCSSAForm(PN, ExitVal)) { PN->replaceAllUsesWith(ExitVal); PN->eraseFromParent(); } @@ -712,6 +658,80 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { Rewriter.clearInsertPoint(); } +//===---------------------------------------------------------------------===// +// rewriteFirstIterationLoopExitValues: Rewrite loop exit values if we know +// they will exit at the first iteration. +//===---------------------------------------------------------------------===// + +/// Check to see if this loop has loop invariant conditions which lead to loop +/// exits. If so, we know that if the exit path is taken, it is at the first +/// loop iteration. This lets us predict exit values of PHI nodes that live in +/// loop header. +void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { + // Verify the input to the pass is already in LCSSA form. + assert(L->isLCSSAForm(*DT)); + + SmallVector<BasicBlock *, 8> ExitBlocks; + L->getUniqueExitBlocks(ExitBlocks); + auto *LoopHeader = L->getHeader(); + assert(LoopHeader && "Invalid loop"); + + for (auto *ExitBB : ExitBlocks) { + BasicBlock::iterator BBI = ExitBB->begin(); + // If there are no more PHI nodes in this exit block, then no more + // values defined inside the loop are used on this path. + while (auto *PN = dyn_cast<PHINode>(BBI++)) { + for (unsigned IncomingValIdx = 0, E = PN->getNumIncomingValues(); + IncomingValIdx != E; ++IncomingValIdx) { + auto *IncomingBB = PN->getIncomingBlock(IncomingValIdx); + + // We currently only support loop exits from loop header. If the + // incoming block is not loop header, we need to recursively check + // all conditions starting from loop header are loop invariants. + // Additional support might be added in the future. + if (IncomingBB != LoopHeader) + continue; + + // Get condition that leads to the exit path. + auto *TermInst = IncomingBB->getTerminator(); + + Value *Cond = nullptr; + if (auto *BI = dyn_cast<BranchInst>(TermInst)) { + // Must be a conditional branch, otherwise the block + // should not be in the loop. + Cond = BI->getCondition(); + } else if (auto *SI = dyn_cast<SwitchInst>(TermInst)) + Cond = SI->getCondition(); + else + continue; + + if (!L->isLoopInvariant(Cond)) + continue; + + auto *ExitVal = + dyn_cast<PHINode>(PN->getIncomingValue(IncomingValIdx)); + + // Only deal with PHIs. 
+ if (!ExitVal) + continue; + + // If ExitVal is a PHI on the loop header, then we know its + // value along this exit because the exit can only be taken + // on the first iteration. + auto *LoopPreheader = L->getLoopPreheader(); + assert(LoopPreheader && "Invalid loop"); + int PreheaderIdx = ExitVal->getBasicBlockIndex(LoopPreheader); + if (PreheaderIdx != -1) { + assert(ExitVal->getParent() == LoopHeader && + "ExitVal must be in loop header"); + PN->setIncomingValue(IncomingValIdx, + ExitVal->getIncomingValue(PreheaderIdx)); + } + } + } + } +} + /// Check whether it is possible to delete the loop after rewriting exit /// value. If it is possible, ignore ReplaceExitValue and do rewriting /// aggressively. @@ -1240,6 +1260,12 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { if (UsePhi->getNumOperands() != 1) truncateIVUse(DU, DT, LI); else { + // Widening the PHI requires us to insert a trunc. The logical place + // for this trunc is in the same BB as the PHI. This is not possible if + // the BB is terminated by a catchswitch. + if (isa<CatchSwitchInst>(UsePhi->getParent()->getTerminator())) + return nullptr; + PHINode *WidePhi = PHINode::Create(DU.WideDef->getType(), 1, UsePhi->getName() + ".wide", UsePhi); @@ -1317,8 +1343,7 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { // Reuse the IV increment that SCEVExpander created as long as it dominates // NarrowUse. Instruction *WideUse = nullptr; - if (WideAddRec == WideIncExpr - && Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) + if (WideAddRec == WideIncExpr && Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) WideUse = WideInc; else { WideUse = cloneIVUser(DU, WideAddRec); @@ -1355,8 +1380,7 @@ void WidenIV::pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef) { if (!Widened.insert(NarrowUser).second) continue; - NarrowIVUsers.push_back( - NarrowIVDefUse(NarrowDef, NarrowUser, WideDef, NeverNegative)); + NarrowIVUsers.emplace_back(NarrowDef, NarrowUser, WideDef, NeverNegative); } } @@ -1391,9 +1415,10 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { // An AddRec must have loop-invariant operands. Since this AddRec is // materialized by a loop header phi, the expression cannot have any post-loop // operands, so they must dominate the loop header. - assert(SE->properlyDominates(AddRec->getStart(), L->getHeader()) && - SE->properlyDominates(AddRec->getStepRecurrence(*SE), L->getHeader()) - && "Loop header phi recurrence inputs do not dominate the loop"); + assert( + SE->properlyDominates(AddRec->getStart(), L->getHeader()) && + SE->properlyDominates(AddRec->getStepRecurrence(*SE), L->getHeader()) && + "Loop header phi recurrence inputs do not dominate the loop"); // The rewriter provides a value for the desired IV expression. This may // either find an existing phi or materialize a new one. Either way, we @@ -1463,8 +1488,6 @@ public: : SE(SCEV), TTI(TTI), IVPhi(IV) { DT = DTree; WI.NarrowIV = IVPhi; - if (ReduceLiveIVs) - setSplitOverflowIntrinsics(); } // Implement the interface used by simplifyUsersOfIV. 
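To make the new rewriteFirstIterationLoopExitValues hunk above more concrete, here is a rough source-level sketch (hypothetical function, not taken from the patch) of the situation it targets: the exit condition tested in the loop header is loop invariant, so if that exit is taken at all it is taken on the first iteration, and any header phi flowing out through it still holds its preheader value.

// Cond never changes inside the loop, so the early return can only fire
// before X has been updated; the LCSSA phi for X in the exit block can be
// rewritten to X's preheader value (0 here).
int sketch(bool Cond, int N) {
  int X = 0;
  for (int i = 0; i < N; ++i) {
    if (Cond)
      return X;   // conceptually becomes: return 0;
    X += i;
  }
  return -1;
}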
@@ -1729,6 +1752,7 @@ static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, const SCEV *BestInit = nullptr; BasicBlock *LatchBlock = L->getLoopLatch(); assert(LatchBlock && "needsLFTR should guarantee a loop latch"); + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { PHINode *Phi = cast<PHINode>(I); @@ -1747,8 +1771,7 @@ static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, // AR may be wider than BECount. With eq/ne tests overflow is immaterial. // AR may not be a narrower type, or we may never exit. uint64_t PhiWidth = SE->getTypeSizeInBits(AR->getType()); - if (PhiWidth < BCWidth || - !L->getHeader()->getModule()->getDataLayout().isLegalInteger(PhiWidth)) + if (PhiWidth < BCWidth || !DL.isLegalInteger(PhiWidth)) continue; const SCEV *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)); @@ -1767,8 +1790,8 @@ static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, // the loop test. In this case we assume that performing LFTR could not // increase the number of undef users. if (ICmpInst *Cond = getLoopTest(L)) { - if (Phi != getLoopPhiForCounter(Cond->getOperand(0), L, DT) - && Phi != getLoopPhiForCounter(Cond->getOperand(1), L, DT)) { + if (Phi != getLoopPhiForCounter(Cond->getOperand(0), L, DT) && + Phi != getLoopPhiForCounter(Cond->getOperand(1), L, DT)) { continue; } } @@ -1810,9 +1833,7 @@ static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, // finds a valid pointer IV. Sign extend BECount in order to materialize a // GEP. Avoid running SCEVExpander on a new pointer value, instead reusing // the existing GEPs whenever possible. - if (IndVar->getType()->isPointerTy() - && !IVCount->getType()->isPointerTy()) { - + if (IndVar->getType()->isPointerTy() && !IVCount->getType()->isPointerTy()) { // IVOffset will be the new GEP offset that is interpreted by GEP as a // signed value. IVCount on the other hand represents the loop trip count, // which is an unsigned value. FindLoopCounter only allows induction @@ -1833,13 +1854,13 @@ static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, // We could handle pointer IVs other than i8*, but we need to compensate for // gep index scaling. See canExpandBackedgeTakenCount comments. assert(SE->getSizeOfExpr(IntegerType::getInt64Ty(IndVar->getContext()), - cast<PointerType>(GEPBase->getType())->getElementType())->isOne() - && "unit stride pointer IV must be i8*"); + cast<PointerType>(GEPBase->getType()) + ->getElementType())->isOne() && + "unit stride pointer IV must be i8*"); IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); return Builder.CreateGEP(nullptr, GEPBase, GEPOffset, "lftr.limit"); - } - else { + } else { // In any other case, convert both IVInit and IVCount to integers before // comparing. This may result in SCEV expension of pointers, but in practice // SCEV will fold the pointer arithmetic away as such: @@ -1913,8 +1934,9 @@ linearFunctionTestReplace(Loop *L, } Value *ExitCnt = genLoopLimit(IndVar, IVCount, L, Rewriter, SE); - assert(ExitCnt->getType()->isPointerTy() == IndVar->getType()->isPointerTy() - && "genLoopLimit missed a cast"); + assert(ExitCnt->getType()->isPointerTy() == + IndVar->getType()->isPointerTy() && + "genLoopLimit missed a cast"); // Insert a new icmp_ne or icmp_eq instruction before the branch. 
BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); @@ -2074,9 +2096,9 @@ void IndVarSimplify::sinkUnusedInvariants(Loop *L) { // IndVarSimplify driver. Manage several subpasses of IV simplification. //===----------------------------------------------------------------------===// -bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipOptnoneFunction(L)) - return false; +bool IndVarSimplify::run(Loop *L) { + // We need (and expect!) the incoming loop to be in LCSSA. + assert(L->isRecursivelyLCSSAForm(*DT) && "LCSSA required to run indvars!"); // If LoopSimplify form is not available, stay out of trouble. Some notes: // - LSR currently only supports LoopSimplify-form loops. Indvars' @@ -2089,18 +2111,6 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { if (!L->isLoopSimplifyForm()) return false; - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - TLI = TLIP ? &TLIP->getTLI() : nullptr; - auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); - TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr; - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - - DeadInsts.clear(); - Changed = false; - // If there are any floating-point recurrences, attempt to // transform them to use integer recurrences. rewriteNonIntegerIVs(L); @@ -2172,6 +2182,11 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { // loop may be sunk below the loop to reduce register pressure. sinkUnusedInvariants(L); + // rewriteFirstIterationLoopExitValues does not rely on the computation of + // trip count and therefore can further simplify exit values in addition to + // rewriteLoopExitValues. + rewriteFirstIterationLoopExitValues(L); + // Clean up dead instructions. Changed |= DeleteDeadPHIs(L->getHeader(), TLI); @@ -2197,3 +2212,69 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { return Changed; } + +PreservedAnalyses IndVarSimplifyPass::run(Loop &L, AnalysisManager<Loop> &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function *F = L.getHeader()->getParent(); + const DataLayout &DL = F->getParent()->getDataLayout(); + + auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); + auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); + + assert((LI && SE && DT) && + "Analyses required for indvarsimplify not available!"); + + // Optional analyses. + auto *TTI = FAM.getCachedResult<TargetIRAnalysis>(*F); + auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(*F); + + IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI); + if (!IVS.run(&L)) + return PreservedAnalyses::all(); + + // FIXME: This should also 'preserve the CFG'. 
+ return getLoopPassPreservedAnalyses(); +} + +namespace { +struct IndVarSimplifyLegacyPass : public LoopPass { + static char ID; // Pass identification, replacement for typeid + IndVarSimplifyLegacyPass() : LoopPass(ID) { + initializeIndVarSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + auto *TLI = TLIP ? &TLIP->getTLI() : nullptr; + auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); + auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr; + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + + IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI); + return IVS.run(L); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + getLoopAnalysisUsage(AU); + } +}; +} + +char IndVarSimplifyLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(IndVarSimplifyLegacyPass, "indvars", + "Induction Variable Simplification", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(IndVarSimplifyLegacyPass, "indvars", + "Induction Variable Simplification", false, false) + +Pass *llvm::createIndVarSimplifyPass() { + return new IndVarSimplifyLegacyPass(); +} diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index dea61f6ff3d7..ec7f09a2d598 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -67,7 +67,6 @@ #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" #include "llvm/Transforms/Utils/UnrollLoop.h" -#include <array> using namespace llvm; @@ -114,24 +113,22 @@ class InductiveRangeCheck { RANGE_CHECK_UNKNOWN = (unsigned)-1 }; - static const char *rangeCheckKindToStr(RangeCheckKind); + static StringRef rangeCheckKindToStr(RangeCheckKind); - const SCEV *Offset; - const SCEV *Scale; - Value *Length; - BranchInst *Branch; - RangeCheckKind Kind; + const SCEV *Offset = nullptr; + const SCEV *Scale = nullptr; + Value *Length = nullptr; + Use *CheckUse = nullptr; + RangeCheckKind Kind = RANGE_CHECK_UNKNOWN; static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, Value *&Index, Value *&Length); - static InductiveRangeCheck::RangeCheckKind - parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition, - const SCEV *&Index, Value *&UpperLimit); - - InductiveRangeCheck() : - Offset(nullptr), Scale(nullptr), Length(nullptr), Branch(nullptr) { } + static void + extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse, + SmallVectorImpl<InductiveRangeCheck> &Checks, + SmallPtrSetImpl<Value *> &Visited); public: const SCEV *getOffset() const { return Offset; } @@ -150,9 +147,9 @@ public: Length->print(OS); else OS << "(null)"; - OS << "\n Branch: "; - getBranch()->print(OS); - OS << "\n"; + OS << "\n CheckUse: "; + getCheckUse()->getUser()->print(OS); + OS << " Operand: " << getCheckUse()->getOperandNo() << "\n"; } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -161,7 +158,7 @@ public: } #endif - BranchInst *getBranch() const { return Branch; } + Use *getCheckUse() const { return CheckUse; } /// 
Represents an signed integer range [Range.getBegin(), Range.getEnd()). If /// R.getEnd() sle R.getBegin(), then R denotes the empty range. @@ -180,8 +177,6 @@ public: const SCEV *getEnd() const { return End; } }; - typedef SpecificBumpPtrAllocator<InductiveRangeCheck> AllocatorTy; - /// This is the value the condition of the branch needs to evaluate to for the /// branch to take the hot successor (see (1) above). bool getPassingDirection() { return true; } @@ -190,19 +185,20 @@ public: /// check is redundant and can be constant-folded away. The induction /// variable is not required to be the canonical {0,+,1} induction variable. Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE, - const SCEVAddRecExpr *IndVar, - IRBuilder<> &B) const; - - /// Create an inductive range check out of BI if possible, else return - /// nullptr. - static InductiveRangeCheck *create(AllocatorTy &Alloc, BranchInst *BI, - Loop *L, ScalarEvolution &SE, - BranchProbabilityInfo &BPI); + const SCEVAddRecExpr *IndVar) const; + + /// Parse out a set of inductive range checks from \p BI and append them to \p + /// Checks. + /// + /// NB! There may be conditions feeding into \p BI that aren't inductive range + /// checks, and hence don't end up in \p Checks. + static void + extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE, + BranchProbabilityInfo &BPI, + SmallVectorImpl<InductiveRangeCheck> &Checks); }; class InductiveRangeCheckElimination : public LoopPass { - InductiveRangeCheck::AllocatorTy Allocator; - public: static char ID; InductiveRangeCheckElimination() : LoopPass(ID) { @@ -211,11 +207,8 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addRequiredID(LCSSAID); - AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<BranchProbabilityInfoWrapperPass>(); + getLoopAnalysisUsage(AU); } bool runOnLoop(Loop *L, LPPassManager &LPM) override; @@ -226,15 +219,12 @@ char InductiveRangeCheckElimination::ID = 0; INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce", "Inductive range check elimination", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSA) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_END(InductiveRangeCheckElimination, "irce", "Inductive range check elimination", false, false) -const char *InductiveRangeCheck::rangeCheckKindToStr( +StringRef InductiveRangeCheck::rangeCheckKindToStr( InductiveRangeCheck::RangeCheckKind RCK) { switch (RCK) { case InductiveRangeCheck::RANGE_CHECK_UNKNOWN: @@ -253,11 +243,9 @@ const char *InductiveRangeCheck::rangeCheckKindToStr( llvm_unreachable("unknown range check type!"); } -/// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` -/// cannot +/// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` cannot /// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set -/// `Index` and `Length` to `nullptr`. Otherwise set `Index` to the value -/// being +/// `Index` and `Length` to `nullptr`. Otherwise set `Index` to the value being /// range checked, and set `Length` to the upper limit `Index` is being range /// checked with if (and only if) the range check type is stronger or equal to /// RANGE_CHECK_UPPER. 
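As a rough illustration of the branches these helpers look for (hypothetical source, not taken from the patch): an affine index is compared against zero and against a loop-invariant length, and the failing edge is cold, e.g.:

#include <cstdlib>

// i is an affine induction variable; the two comparisons together form a
// lower-plus-upper range check of i against Len, and the abort() path is
// the rarely taken edge that the BranchProbabilityInfo threshold filters on.
void copyChecked(int *Dst, const int *Src, int Len, int N) {
  for (int i = 0; i < N; ++i) {
    if (i < 0 || i >= Len)
      std::abort();
    Dst[i] = Src[i];
  }
}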
@@ -327,106 +315,89 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, llvm_unreachable("default clause returns!"); } -/// Parses an arbitrary condition into a range check. `Length` is set only if -/// the range check is recognized to be `RANGE_CHECK_UPPER` or stronger. -InductiveRangeCheck::RangeCheckKind -InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE, - Value *Condition, const SCEV *&Index, - Value *&Length) { +void InductiveRangeCheck::extractRangeChecksFromCond( + Loop *L, ScalarEvolution &SE, Use &ConditionUse, + SmallVectorImpl<InductiveRangeCheck> &Checks, + SmallPtrSetImpl<Value *> &Visited) { using namespace llvm::PatternMatch; - Value *A = nullptr; - Value *B = nullptr; - - if (match(Condition, m_And(m_Value(A), m_Value(B)))) { - Value *IndexA = nullptr, *IndexB = nullptr; - Value *LengthA = nullptr, *LengthB = nullptr; - ICmpInst *ICmpA = dyn_cast<ICmpInst>(A), *ICmpB = dyn_cast<ICmpInst>(B); - - if (!ICmpA || !ICmpB) - return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - - auto RCKindA = parseRangeCheckICmp(L, ICmpA, SE, IndexA, LengthA); - auto RCKindB = parseRangeCheckICmp(L, ICmpB, SE, IndexB, LengthB); - - if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN || - RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) - return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - - if (IndexA != IndexB) - return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - - if (LengthA != nullptr && LengthB != nullptr && LengthA != LengthB) - return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; - - Index = SE.getSCEV(IndexA); - if (isa<SCEVCouldNotCompute>(Index)) - return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; + Value *Condition = ConditionUse.get(); + if (!Visited.insert(Condition).second) + return; - Length = LengthA == nullptr ? LengthB : LengthA; + if (match(Condition, m_And(m_Value(), m_Value()))) { + SmallVector<InductiveRangeCheck, 8> SubChecks; + extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0), + SubChecks, Visited); + extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1), + SubChecks, Visited); + + if (SubChecks.size() == 2) { + // Handle a special case where we know how to merge two checks separately + // checking the upper and lower bounds into a full range check. + const auto &RChkA = SubChecks[0]; + const auto &RChkB = SubChecks[1]; + if ((RChkA.Length == RChkB.Length || !RChkA.Length || !RChkB.Length) && + RChkA.Offset == RChkB.Offset && RChkA.Scale == RChkB.Scale) { + + // If RChkA.Kind == RChkB.Kind then we just found two identical checks. + // But if one of them is a RANGE_CHECK_LOWER and the other is a + // RANGE_CHECK_UPPER (only possibility if they're different) then + // together they form a RANGE_CHECK_BOTH. + SubChecks[0].Kind = + (InductiveRangeCheck::RangeCheckKind)(RChkA.Kind | RChkB.Kind); + SubChecks[0].Length = RChkA.Length ? RChkA.Length : RChkB.Length; + SubChecks[0].CheckUse = &ConditionUse; + + // We updated one of the checks in place, now erase the other. 
+ SubChecks.pop_back(); + } + } - return (InductiveRangeCheck::RangeCheckKind)(RCKindA | RCKindB); + Checks.insert(Checks.end(), SubChecks.begin(), SubChecks.end()); + return; } - if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { - Value *IndexVal = nullptr; - - auto RCKind = parseRangeCheckICmp(L, ICI, SE, IndexVal, Length); + ICmpInst *ICI = dyn_cast<ICmpInst>(Condition); + if (!ICI) + return; - if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) - return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; + Value *Length = nullptr, *Index; + auto RCKind = parseRangeCheckICmp(L, ICI, SE, Index, Length); + if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) + return; - Index = SE.getSCEV(IndexVal); - if (isa<SCEVCouldNotCompute>(Index)) - return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; + const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index)); + bool IsAffineIndex = + IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine(); - return RCKind; - } + if (!IsAffineIndex) + return; - return InductiveRangeCheck::RANGE_CHECK_UNKNOWN; + InductiveRangeCheck IRC; + IRC.Length = Length; + IRC.Offset = IndexAddRec->getStart(); + IRC.Scale = IndexAddRec->getStepRecurrence(SE); + IRC.CheckUse = &ConditionUse; + IRC.Kind = RCKind; + Checks.push_back(IRC); } - -InductiveRangeCheck * -InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI, - Loop *L, ScalarEvolution &SE, - BranchProbabilityInfo &BPI) { +void InductiveRangeCheck::extractRangeChecksFromBranch( + BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo &BPI, + SmallVectorImpl<InductiveRangeCheck> &Checks) { if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()) - return nullptr; + return; BranchProbability LikelyTaken(15, 16); - if (BPI.getEdgeProbability(BI->getParent(), (unsigned) 0) < LikelyTaken) - return nullptr; - - Value *Length = nullptr; - const SCEV *IndexSCEV = nullptr; - - auto RCKind = InductiveRangeCheck::parseRangeCheck(L, SE, BI->getCondition(), - IndexSCEV, Length); - - if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) - return nullptr; - - assert(IndexSCEV && "contract with SplitRangeCheckCondition!"); - assert((!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) || Length) && - "contract with SplitRangeCheckCondition!"); - - const SCEVAddRecExpr *IndexAddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV); - bool IsAffineIndex = - IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine(); + if (BPI.getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) + return; - if (!IsAffineIndex) - return nullptr; - - InductiveRangeCheck *IRC = new (A.Allocate()) InductiveRangeCheck; - IRC->Length = Length; - IRC->Offset = IndexAddRec->getStart(); - IRC->Scale = IndexAddRec->getStepRecurrence(SE); - IRC->Branch = BI; - IRC->Kind = RCKind; - return IRC; + SmallPtrSet<Value *, 8> Visited; + InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0), + Checks, Visited); } namespace { @@ -666,7 +637,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP return None; } - BranchInst *LatchBr = dyn_cast<BranchInst>(&*Latch->rbegin()); + BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); if (!LatchBr || LatchBr->isUnconditional()) { FailureReason = "latch terminator not conditional branch"; return None; @@ -792,7 +763,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP return None; } - IRBuilder<> B(&*Preheader->rbegin()); + IRBuilder<> 
B(Preheader->getTerminator()); RightValue = B.CreateAdd(RightValue, One); } @@ -814,7 +785,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP return None; } - IRBuilder<> B(&*Preheader->rbegin()); + IRBuilder<> B(Preheader->getTerminator()); RightValue = B.CreateSub(RightValue, One); } } @@ -833,7 +804,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP const DataLayout &DL = Preheader->getModule()->getDataLayout(); Value *IndVarStartV = SCEVExpander(SE, DL, "irce") - .expandCodeFor(IndVarStart, IndVarTy, &*Preheader->rbegin()); + .expandCodeFor(IndVarStart, IndVarTy, Preheader->getTerminator()); IndVarStartV->setName("indvar.start"); LoopStructure Result; @@ -947,7 +918,7 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, for (Instruction &I : *ClonedBB) RemapInstruction(&I, Result.Map, - RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); // Exit blocks will now have one more predecessor and their PHI nodes need // to be edited to reflect that. No phi nodes need to be introduced because @@ -1055,7 +1026,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, &*BBInsertLocation); - BranchInst *PreheaderJump = cast<BranchInst>(&*Preheader->rbegin()); + BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); bool Increasing = LS.IndVarIncreasing; IRBuilder<> B(PreheaderJump); @@ -1305,9 +1276,8 @@ bool LoopConstrainer::run() { /// in which the range check can be safely elided. If it cannot compute such a /// range, returns None. Optional<InductiveRangeCheck::Range> -InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, - const SCEVAddRecExpr *IndVar, - IRBuilder<> &) const { +InductiveRangeCheck::computeSafeIterationSpace( + ScalarEvolution &SE, const SCEVAddRecExpr *IndVar) const { // IndVar is of the form "A + B * I" (where "I" is the canonical induction // variable, that may or may not exist as a real llvm::Value in the loop) and // this inductive range check is a range check on the "C + D * I" ("C" is @@ -1375,7 +1345,7 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, static Optional<InductiveRangeCheck::Range> IntersectRange(ScalarEvolution &SE, const Optional<InductiveRangeCheck::Range> &R1, - const InductiveRangeCheck::Range &R2, IRBuilder<> &B) { + const InductiveRangeCheck::Range &R2) { if (!R1.hasValue()) return R2; auto &R1Value = R1.getValue(); @@ -1392,6 +1362,9 @@ IntersectRange(ScalarEvolution &SE, } bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { + if (skipLoop(L)) + return false; + if (L->getBlocks().size() >= LoopSizeCutoff) { DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";); return false; @@ -1404,17 +1377,15 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { } LLVMContext &Context = Preheader->getContext(); - InductiveRangeCheck::AllocatorTy IRCAlloc; - SmallVector<InductiveRangeCheck *, 16> RangeChecks; + SmallVector<InductiveRangeCheck, 16> RangeChecks; ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); for (auto BBI : L->getBlocks()) if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) - if (InductiveRangeCheck *IRC = - InductiveRangeCheck::create(IRCAlloc, TBI, L, SE, BPI)) - 
RangeChecks.push_back(IRC); + InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI, + RangeChecks); if (RangeChecks.empty()) return false; @@ -1423,8 +1394,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { OS << "irce: looking at loop "; L->print(OS); OS << "irce: loop has " << RangeChecks.size() << " inductive range checks: \n"; - for (InductiveRangeCheck *IRC : RangeChecks) - IRC->print(OS); + for (InductiveRangeCheck &IRC : RangeChecks) + IRC.print(OS); }; DEBUG(PrintRecognizedRangeChecks(dbgs())); @@ -1450,14 +1421,14 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { Optional<InductiveRangeCheck::Range> SafeIterRange; Instruction *ExprInsertPt = Preheader->getTerminator(); - SmallVector<InductiveRangeCheck *, 4> RangeChecksToEliminate; + SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate; IRBuilder<> B(ExprInsertPt); - for (InductiveRangeCheck *IRC : RangeChecks) { - auto Result = IRC->computeSafeIterationSpace(SE, IndVar, B); + for (InductiveRangeCheck &IRC : RangeChecks) { + auto Result = IRC.computeSafeIterationSpace(SE, IndVar); if (Result.hasValue()) { auto MaybeSafeIterRange = - IntersectRange(SE, SafeIterRange, Result.getValue(), B); + IntersectRange(SE, SafeIterRange, Result.getValue()); if (MaybeSafeIterRange.hasValue()) { RangeChecksToEliminate.push_back(IRC); SafeIterRange = MaybeSafeIterRange.getValue(); @@ -1487,11 +1458,11 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { // Optimize away the now-redundant range checks. - for (InductiveRangeCheck *IRC : RangeChecksToEliminate) { - ConstantInt *FoldedRangeCheck = IRC->getPassingDirection() + for (InductiveRangeCheck &IRC : RangeChecksToEliminate) { + ConstantInt *FoldedRangeCheck = IRC.getPassingDirection() ? 
ConstantInt::getTrue(Context) : ConstantInt::getFalse(Context); - IRC->getBranch()->setCondition(FoldedRangeCheck); + IRC.getCheckUse()->set(FoldedRangeCheck); } } diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index dcdcfed66e64..b9e717cf763e 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -11,31 +11,25 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/JumpThreading.h" #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" -#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" -#include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -46,6 +40,7 @@ #include <algorithm> #include <memory> using namespace llvm; +using namespace jumpthreading; #define DEBUG_TYPE "jump-threading" @@ -66,17 +61,6 @@ ImplicationSearchThreshold( cl::init(3), cl::Hidden); namespace { - // These are at global scope so static functions can use them too. - typedef SmallVectorImpl<std::pair<Constant*, BasicBlock*> > PredValueInfo; - typedef SmallVector<std::pair<Constant*, BasicBlock*>, 8> PredValueInfoTy; - - // This is used to keep track of what kind of constant we're currently hoping - // to find. - enum ConstantPreference { - WantInteger, - WantBlockAddress - }; - /// This pass performs 'jump threading', which looks at blocks that have /// multiple predecessors and multiple successors. If one or more of the /// predecessors of the block can be proven to always jump to one of the @@ -94,89 +78,31 @@ namespace { /// revectored to the false side of the second if. /// class JumpThreading : public FunctionPass { - TargetLibraryInfo *TLI; - LazyValueInfo *LVI; - std::unique_ptr<BlockFrequencyInfo> BFI; - std::unique_ptr<BranchProbabilityInfo> BPI; - bool HasProfileData; -#ifdef NDEBUG - SmallPtrSet<const BasicBlock *, 16> LoopHeaders; -#else - SmallSet<AssertingVH<const BasicBlock>, 16> LoopHeaders; -#endif - DenseSet<std::pair<Value*, BasicBlock*> > RecursionSet; - - unsigned BBDupThreshold; - - // RAII helper for updating the recursion stack. - struct RecursionSetRemover { - DenseSet<std::pair<Value*, BasicBlock*> > &TheSet; - std::pair<Value*, BasicBlock*> ThePair; - - RecursionSetRemover(DenseSet<std::pair<Value*, BasicBlock*> > &S, - std::pair<Value*, BasicBlock*> P) - : TheSet(S), ThePair(P) { } - - ~RecursionSetRemover() { - TheSet.erase(ThePair); - } - }; + JumpThreadingPass Impl; + public: static char ID; // Pass identification - JumpThreading(int T = -1) : FunctionPass(ID) { - BBDupThreshold = (T == -1) ? 
BBDuplicateThreshold : unsigned(T); + JumpThreading(int T = -1) : FunctionPass(ID), Impl(T) { initializeJumpThreadingPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<LazyValueInfo>(); - AU.addPreserved<LazyValueInfo>(); + AU.addRequired<LazyValueInfoWrapperPass>(); + AU.addPreserved<LazyValueInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } - void releaseMemory() override { - BFI.reset(); - BPI.reset(); - } - - void FindLoopHeaders(Function &F); - bool ProcessBlock(BasicBlock *BB); - bool ThreadEdge(BasicBlock *BB, const SmallVectorImpl<BasicBlock*> &PredBBs, - BasicBlock *SuccBB); - bool DuplicateCondBranchOnPHIIntoPred(BasicBlock *BB, - const SmallVectorImpl<BasicBlock *> &PredBBs); - - bool ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, - PredValueInfo &Result, - ConstantPreference Preference, - Instruction *CxtI = nullptr); - bool ProcessThreadableEdges(Value *Cond, BasicBlock *BB, - ConstantPreference Preference, - Instruction *CxtI = nullptr); - - bool ProcessBranchOnPHI(PHINode *PN); - bool ProcessBranchOnXOR(BinaryOperator *BO); - bool ProcessImpliedCondition(BasicBlock *BB); - - bool SimplifyPartiallyRedundantLoad(LoadInst *LI); - bool TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB); - bool TryToUnfoldSelectInCurrBB(BasicBlock *BB); - - private: - BasicBlock *SplitBlockPreds(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, - const char *Suffix); - void UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, BasicBlock *BB, - BasicBlock *NewBB, BasicBlock *SuccBB); + void releaseMemory() override { Impl.releaseMemory(); } }; } char JumpThreading::ID = 0; INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading", "Jump Threading", false, false) -INITIALIZE_PASS_DEPENDENCY(LazyValueInfo) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(JumpThreading, "jump-threading", "Jump Threading", false, false) @@ -184,24 +110,72 @@ INITIALIZE_PASS_END(JumpThreading, "jump-threading", // Public interface to the Jump Threading pass FunctionPass *llvm::createJumpThreadingPass(int Threshold) { return new JumpThreading(Threshold); } +JumpThreadingPass::JumpThreadingPass(int T) { + BBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T); +} + /// runOnFunction - Top level algorithm. 
/// bool JumpThreading::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; + auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); + std::unique_ptr<BlockFrequencyInfo> BFI; + std::unique_ptr<BranchProbabilityInfo> BPI; + bool HasProfileData = F.getEntryCount().hasValue(); + if (HasProfileData) { + LoopInfo LI{DominatorTree(F)}; + BPI.reset(new BranchProbabilityInfo(F, LI)); + BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); + } + return Impl.runImpl(F, TLI, LVI, HasProfileData, std::move(BFI), + std::move(BPI)); +} + +PreservedAnalyses JumpThreadingPass::run(Function &F, + AnalysisManager<Function> &AM) { + + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &LVI = AM.getResult<LazyValueAnalysis>(F); + std::unique_ptr<BlockFrequencyInfo> BFI; + std::unique_ptr<BranchProbabilityInfo> BPI; + bool HasProfileData = F.getEntryCount().hasValue(); + if (HasProfileData) { + LoopInfo LI{DominatorTree(F)}; + BPI.reset(new BranchProbabilityInfo(F, LI)); + BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); + } + bool Changed = + runImpl(F, &TLI, &LVI, HasProfileData, std::move(BFI), std::move(BPI)); + + // FIXME: We need to invalidate LVI to avoid PR28400. Is there a better + // solution? + AM.invalidate<LazyValueAnalysis>(F); + + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; +} + +bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, + LazyValueInfo *LVI_, bool HasProfileData_, + std::unique_ptr<BlockFrequencyInfo> BFI_, + std::unique_ptr<BranchProbabilityInfo> BPI_) { DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - LVI = &getAnalysis<LazyValueInfo>(); + TLI = TLI_; + LVI = LVI_; BFI.reset(); BPI.reset(); // When profile data is available, we need to update edge weights after // successful jump threading, which requires both BPI and BFI being available. - HasProfileData = F.getEntryCount().hasValue(); + HasProfileData = HasProfileData_; if (HasProfileData) { - LoopInfo LI{DominatorTree(F)}; - BPI.reset(new BranchProbabilityInfo(F, LI)); - BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); + BPI = std::move(BPI_); + BFI = std::move(BFI_); } // Remove unreachable blocks from function as they may result in infinite @@ -245,10 +219,13 @@ bool JumpThreading::runOnFunction(Function &F) { // Can't thread an unconditional jump, but if the block is "almost // empty", we can replace uses of it with uses of the successor and make // this dead. + // We should not eliminate the loop header either, because eliminating + // a loop header might later prevent LoopSimplify from transforming nested + // loops into simplified form. if (BI && BI->isUnconditional() && BB != &BB->getParent()->getEntryBlock() && // If the terminator is the only non-phi instruction, try to nuke it. - BB->getFirstNonPHIOrDbg()->isTerminator()) { + BB->getFirstNonPHIOrDbg()->isTerminator() && !LoopHeaders.count(BB)) { // Since TryToSimplifyUncondBranchFromEmptyBlock may delete the // block, we have to make sure it isn't in the LoopHeaders set. We // reinsert afterward if needed. @@ -361,7 +338,7 @@ static unsigned getJumpThreadDuplicationCost(const BasicBlock *BB, /// enough to track all of these properties and keep it up-to-date as the CFG /// mutates, so we don't allow any of these transformations. 
/// -void JumpThreading::FindLoopHeaders(Function &F) { +void JumpThreadingPass::FindLoopHeaders(Function &F) { SmallVector<std::pair<const BasicBlock*,const BasicBlock*>, 32> Edges; FindFunctionBackedges(F, Edges); @@ -395,10 +372,9 @@ static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) { /// /// This returns true if there were any known values. /// -bool JumpThreading:: -ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, - ConstantPreference Preference, - Instruction *CxtI) { +bool JumpThreadingPass::ComputeValueKnownInPredecessors( + Value *V, BasicBlock *BB, PredValueInfo &Result, + ConstantPreference Preference, Instruction *CxtI) { // This method walks up use-def chains recursively. Because of this, we could // get into an infinite loop going around loops in the use-def chain. To // prevent this, keep track of what (value, block) pairs we've already visited @@ -415,7 +391,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, for (BasicBlock *Pred : predecessors(BB)) Result.push_back(std::make_pair(KC, Pred)); - return true; + return !Result.empty(); } // If V is a non-instruction value, or an instruction in a different block, @@ -465,6 +441,25 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, return !Result.empty(); } + // Handle Cast instructions. Only see through Cast when the source operand is + // PHI or Cmp and the source type is i1 to save the compilation time. + if (CastInst *CI = dyn_cast<CastInst>(I)) { + Value *Source = CI->getOperand(0); + if (!Source->getType()->isIntegerTy(1)) + return false; + if (!isa<PHINode>(Source) && !isa<CmpInst>(Source)) + return false; + ComputeValueKnownInPredecessors(Source, BB, Result, Preference, CxtI); + if (Result.empty()) + return false; + + // Convert the known values. + for (auto &R : Result) + R.first = ConstantExpr::getCast(CI->getOpcode(), R.first, CI->getType()); + + return true; + } + PredValueInfoTy LHSVals, RHSVals; // Handle some boolean conditions. @@ -705,7 +700,7 @@ static bool hasAddressTakenAndUsed(BasicBlock *BB) { /// ProcessBlock - If there are any predecessors whose control can be threaded /// through to a successor, transform them now. -bool JumpThreading::ProcessBlock(BasicBlock *BB) { +bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // If the block is trivially dead, just return and let the caller nuke it. // This simplifies other transformations. 
if (pred_empty(BB) && @@ -889,7 +884,7 @@ bool JumpThreading::ProcessBlock(BasicBlock *BB) { return false; } -bool JumpThreading::ProcessImpliedCondition(BasicBlock *BB) { +bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) { auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); if (!BI || !BI->isConditional()) return false; @@ -903,12 +898,17 @@ bool JumpThreading::ProcessImpliedCondition(BasicBlock *BB) { while (CurrentPred && Iter++ < ImplicationSearchThreshold) { auto *PBI = dyn_cast<BranchInst>(CurrentPred->getTerminator()); - if (!PBI || !PBI->isConditional() || PBI->getSuccessor(0) != CurrentBB) + if (!PBI || !PBI->isConditional()) + return false; + if (PBI->getSuccessor(0) != CurrentBB && PBI->getSuccessor(1) != CurrentBB) return false; - if (isImpliedCondition(PBI->getCondition(), Cond, DL)) { - BI->getSuccessor(1)->removePredecessor(BB); - BranchInst::Create(BI->getSuccessor(0), BI); + bool FalseDest = PBI->getSuccessor(1) == CurrentBB; + Optional<bool> Implication = + isImpliedCondition(PBI->getCondition(), Cond, DL, FalseDest); + if (Implication) { + BI->getSuccessor(*Implication ? 1 : 0)->removePredecessor(BB); + BranchInst::Create(BI->getSuccessor(*Implication ? 0 : 1), BI); BI->eraseFromParent(); return true; } @@ -923,9 +923,9 @@ bool JumpThreading::ProcessImpliedCondition(BasicBlock *BB) { /// load instruction, eliminate it by replacing it with a PHI node. This is an /// important optimization that encourages jump threading, and needs to be run /// interlaced with other jump threading tasks. -bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { - // Don't hack volatile/atomic loads. - if (!LI->isSimple()) return false; +bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { + // Don't hack volatile and ordered loads. + if (!LI->isUnordered()) return false; // If the load is defined in a block with exactly one predecessor, it can't be // partially redundant. @@ -952,10 +952,9 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { BasicBlock::iterator BBIt(LI); if (Value *AvailableVal = - FindAvailableLoadedValue(LoadedPtr, LoadBB, BBIt, DefMaxInstsToScan)) { + FindAvailableLoadedValue(LI, LoadBB, BBIt, DefMaxInstsToScan)) { // If the value of the load is locally available within the block, just use // it. This frequently occurs for reg2mem'd allocas. - //cerr << "LOAD ELIMINATED:\n" << *BBIt << *LI << "\n"; // If the returned value is the load itself, replace with an undef. This can // only happen in dead loops. @@ -994,7 +993,7 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // Scan the predecessor to see if the value is available in the pred. 
BBIt = PredBB->end(); AAMDNodes ThisAATags; - Value *PredAvailable = FindAvailableLoadedValue(LoadedPtr, PredBB, BBIt, + Value *PredAvailable = FindAvailableLoadedValue(LI, PredBB, BBIt, DefMaxInstsToScan, nullptr, &ThisAATags); if (!PredAvailable) { @@ -1056,9 +1055,10 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { if (UnavailablePred) { assert(UnavailablePred->getTerminator()->getNumSuccessors() == 1 && "Can't handle critical edge here!"); - LoadInst *NewVal = new LoadInst(LoadedPtr, LI->getName()+".pr", false, - LI->getAlignment(), - UnavailablePred->getTerminator()); + LoadInst *NewVal = + new LoadInst(LoadedPtr, LI->getName() + ".pr", false, + LI->getAlignment(), LI->getOrdering(), LI->getSynchScope(), + UnavailablePred->getTerminator()); NewVal->setDebugLoc(LI->getDebugLoc()); if (AATags) NewVal->setAAMetadata(AATags); @@ -1100,8 +1100,6 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { PN->addIncoming(PredV, I->first); } - //cerr << "PRE: " << *LI << *PN << "\n"; - LI->replaceAllUsesWith(PN); LI->eraseFromParent(); @@ -1171,9 +1169,9 @@ FindMostPopularDest(BasicBlock *BB, return MostPopularDest; } -bool JumpThreading::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, - ConstantPreference Preference, - Instruction *CxtI) { +bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, + ConstantPreference Preference, + Instruction *CxtI) { // If threading this would thread across a loop header, don't even try to // thread the edge. if (LoopHeaders.count(BB)) @@ -1279,7 +1277,7 @@ bool JumpThreading::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, /// a PHI node in the current block. See if there are any simplifications we /// can do based on inputs to the phi node. /// -bool JumpThreading::ProcessBranchOnPHI(PHINode *PN) { +bool JumpThreadingPass::ProcessBranchOnPHI(PHINode *PN) { BasicBlock *BB = PN->getParent(); // TODO: We could make use of this to do it once for blocks with common PHI @@ -1309,7 +1307,7 @@ bool JumpThreading::ProcessBranchOnPHI(PHINode *PN) { /// a xor instruction in the current block. See if there are any /// simplifications we can do based on inputs to the xor. /// -bool JumpThreading::ProcessBranchOnXOR(BinaryOperator *BO) { +bool JumpThreadingPass::ProcessBranchOnXOR(BinaryOperator *BO) { BasicBlock *BB = BO->getParent(); // If either the LHS or RHS of the xor is a constant, don't do this @@ -1437,9 +1435,9 @@ static void AddPHINodeEntriesForMappedBlock(BasicBlock *PHIBB, /// ThreadEdge - We have decided that it is safe and profitable to factor the /// blocks in PredBBs to one predecessor, then thread an edge from it to SuccBB /// across BB. Transform the IR to reflect this change. -bool JumpThreading::ThreadEdge(BasicBlock *BB, - const SmallVectorImpl<BasicBlock*> &PredBBs, - BasicBlock *SuccBB) { +bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, + const SmallVectorImpl<BasicBlock *> &PredBBs, + BasicBlock *SuccBB) { // If threading to the same block as we come from, we would infinite loop. if (SuccBB == BB) { DEBUG(dbgs() << " Not threading across BB '" << BB->getName() @@ -1593,9 +1591,9 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB, /// Create a new basic block that will be the predecessor of BB and successor of /// all blocks in Preds. When profile data is availble, update the frequency of /// this new block. 
-BasicBlock *JumpThreading::SplitBlockPreds(BasicBlock *BB, - ArrayRef<BasicBlock *> Preds, - const char *Suffix) { +BasicBlock *JumpThreadingPass::SplitBlockPreds(BasicBlock *BB, + ArrayRef<BasicBlock *> Preds, + const char *Suffix) { // Collect the frequencies of all predecessors of BB, which will be used to // update the edge weight on BB->SuccBB. BlockFrequency PredBBFreq(0); @@ -1615,10 +1613,10 @@ BasicBlock *JumpThreading::SplitBlockPreds(BasicBlock *BB, /// Update the block frequency of BB and branch weight and the metadata on the /// edge BB->SuccBB. This is done by scaling the weight of BB->SuccBB by 1 - /// Freq(PredBB->BB) / Freq(BB->SuccBB). -void JumpThreading::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, - BasicBlock *BB, - BasicBlock *NewBB, - BasicBlock *SuccBB) { +void JumpThreadingPass::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, + BasicBlock *BB, + BasicBlock *NewBB, + BasicBlock *SuccBB) { if (!HasProfileData) return; @@ -1679,8 +1677,8 @@ void JumpThreading::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, /// If we can duplicate the contents of BB up into PredBB do so now, this /// improves the odds that the branch will be on an analyzable instruction like /// a compare. -bool JumpThreading::DuplicateCondBranchOnPHIIntoPred(BasicBlock *BB, - const SmallVectorImpl<BasicBlock *> &PredBBs) { +bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred( + BasicBlock *BB, const SmallVectorImpl<BasicBlock *> &PredBBs) { assert(!PredBBs.empty() && "Can't handle an empty set"); // If BB is a loop header, then duplicating this block outside the loop would @@ -1750,13 +1748,18 @@ bool JumpThreading::DuplicateCondBranchOnPHIIntoPred(BasicBlock *BB, // phi translation. if (Value *IV = SimplifyInstruction(New, BB->getModule()->getDataLayout())) { - delete New; ValueMapping[&*BI] = IV; + if (!New->mayHaveSideEffects()) { + delete New; + New = nullptr; + } } else { + ValueMapping[&*BI] = New; + } + if (New) { // Otherwise, insert the new instruction into the block. New->setName(BI->getName()); PredBB->getInstList().insert(OldPredBranch->getIterator(), New); - ValueMapping[&*BI] = New; } } @@ -1829,7 +1832,7 @@ bool JumpThreading::DuplicateCondBranchOnPHIIntoPred(BasicBlock *BB, /// /// And expand the select into a branch structure if one of its arms allows %c /// to be folded. This later enables threading from bb1 over bb2. -bool JumpThreading::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { +bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator()); PHINode *CondLHS = dyn_cast<PHINode>(CondCmp->getOperand(0)); Constant *CondRHS = cast<Constant>(CondCmp->getOperand(1)); @@ -1907,7 +1910,7 @@ bool JumpThreading::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { /// select if the associated PHI has at least one constant. If the unfolded /// select is not jump-threaded, it will be folded again in the later /// optimizations. -bool JumpThreading::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { +bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { // If threading this would thread across a loop header, don't thread the edge. // See the comments above FindLoopHeaders for justifications and caveats. 
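As a reminder of the transformation this pass performs, a hedged source-level sketch (all names hypothetical): on the path through a(), the second condition is already known to hold, so that predecessor can be threaded directly to the foo() block instead of re-testing x.

void a(); void b(); void foo(); void bar();   // hypothetical externals

void example(int x) {
  if (x > 10)
    a();
  else
    b();
  if (x > 5)   // on the a() path this is implied by x > 10 and gets threaded
    foo();
  else
    bar();
}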
if (LoopHeaders.count(BB)) diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp index 8923ff74253c..2c0a70e44f57 100644 --- a/lib/Transforms/Scalar/LICM.cpp +++ b/lib/Transforms/Scalar/LICM.cpp @@ -30,15 +30,19 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LICM.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopPassManager.h" +#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -56,183 +60,173 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> +#include <utility> using namespace llvm; #define DEBUG_TYPE "licm" -STATISTIC(NumSunk , "Number of instructions sunk out of loop"); -STATISTIC(NumHoisted , "Number of instructions hoisted out of loop"); +STATISTIC(NumSunk, "Number of instructions sunk out of loop"); +STATISTIC(NumHoisted, "Number of instructions hoisted out of loop"); STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk"); STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk"); -STATISTIC(NumPromoted , "Number of memory locations promoted to registers"); +STATISTIC(NumPromoted, "Number of memory locations promoted to registers"); static cl::opt<bool> -DisablePromotion("disable-licm-promotion", cl::Hidden, - cl::desc("Disable memory promotion in LICM pass")); + DisablePromotion("disable-licm-promotion", cl::Hidden, + cl::desc("Disable memory promotion in LICM pass")); static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI); static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop, - const LICMSafetyInfo *SafetyInfo); -static bool hoist(Instruction &I, BasicBlock *Preheader); + const LoopSafetyInfo *SafetyInfo); +static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, + const LoopSafetyInfo *SafetyInfo); static bool sink(Instruction &I, const LoopInfo *LI, const DominatorTree *DT, const Loop *CurLoop, AliasSetTracker *CurAST, - const LICMSafetyInfo *SafetyInfo); -static bool isGuaranteedToExecute(const Instruction &Inst, - const DominatorTree *DT, - const Loop *CurLoop, - const LICMSafetyInfo *SafetyInfo); + const LoopSafetyInfo *SafetyInfo); static bool isSafeToExecuteUnconditionally(const Instruction &Inst, const DominatorTree *DT, - const TargetLibraryInfo *TLI, const Loop *CurLoop, - const LICMSafetyInfo *SafetyInfo, + const LoopSafetyInfo *SafetyInfo, const Instruction *CtxI = nullptr); static bool pointerInvalidatedByLoop(Value *V, uint64_t Size, - const AAMDNodes &AAInfo, + const AAMDNodes &AAInfo, AliasSetTracker *CurAST); static Instruction * CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, - const LICMSafetyInfo *SafetyInfo); + const LoopSafetyInfo *SafetyInfo); static bool 
canSinkOrHoistInst(Instruction &I, AliasAnalysis *AA, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, AliasSetTracker *CurAST, - LICMSafetyInfo *SafetyInfo); + LoopSafetyInfo *SafetyInfo); namespace { - struct LICM : public LoopPass { - static char ID; // Pass identification, replacement for typeid - LICM() : LoopPass(ID) { - initializeLICMPass(*PassRegistry::getPassRegistry()); - } +struct LoopInvariantCodeMotion { + bool runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, + TargetLibraryInfo *TLI, ScalarEvolution *SE, bool DeleteAST); - bool runOnLoop(Loop *L, LPPassManager &LPM) override; - - /// This transformation requires natural loop information & requires that - /// loop preheaders be inserted into the CFG... - /// - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addPreservedID(LoopSimplifyID); - AU.addRequiredID(LCSSAID); - AU.addPreservedID(LCSSAID); - AU.addRequired<AAResultsWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<BasicAAWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<SCEVAAWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } + DenseMap<Loop *, AliasSetTracker *> &getLoopToAliasSetMap() { + return LoopToAliasSetMap; + } + +private: + DenseMap<Loop *, AliasSetTracker *> LoopToAliasSetMap; - using llvm::Pass::doFinalization; + AliasSetTracker *collectAliasInfoForLoop(Loop *L, LoopInfo *LI, + AliasAnalysis *AA); +}; + +struct LegacyLICMPass : public LoopPass { + static char ID; // Pass identification, replacement for typeid + LegacyLICMPass() : LoopPass(ID) { + initializeLegacyLICMPassPass(*PassRegistry::getPassRegistry()); + } - bool doFinalization() override { - assert(LoopToAliasSetMap.empty() && "Didn't free loop alias sets"); + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) return false; - } - private: - AliasAnalysis *AA; // Current AliasAnalysis information - LoopInfo *LI; // Current LoopInfo - DominatorTree *DT; // Dominator Tree for the current Loop. + auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); + return LICM.runOnLoop(L, + &getAnalysis<AAResultsWrapperPass>().getAAResults(), + &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), + &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), + SE ? &SE->getSE() : nullptr, false); + } - TargetLibraryInfo *TLI; // TargetLibraryInfo for constant folding. + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... + /// + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + getLoopAnalysisUsage(AU); + } - // State that is updated as we process loops. - bool Changed; // Set to true when we change anything. - BasicBlock *Preheader; // The preheader block of the current loop... - Loop *CurLoop; // The current loop we are working on... - AliasSetTracker *CurAST; // AliasSet information for the current loop... - DenseMap<Loop*, AliasSetTracker*> LoopToAliasSetMap; + using llvm::Pass::doFinalization; - /// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info. 
- void cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, - Loop *L) override; + bool doFinalization() override { + assert(LICM.getLoopToAliasSetMap().empty() && + "Didn't free loop alias sets"); + return false; + } - /// deleteAnalysisValue - Simple Analysis hook. Delete value V from alias - /// set. - void deleteAnalysisValue(Value *V, Loop *L) override; +private: + LoopInvariantCodeMotion LICM; - /// Simple Analysis hook. Delete loop L from alias set map. - void deleteAnalysisLoop(Loop *L) override; - }; + /// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info. + void cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, + Loop *L) override; + + /// deleteAnalysisValue - Simple Analysis hook. Delete value V from alias + /// set. + void deleteAnalysisValue(Value *V, Loop *L) override; + + /// Simple Analysis hook. Delete loop L from alias set map. + void deleteAnalysisLoop(Loop *L) override; +}; +} + +PreservedAnalyses LICMPass::run(Loop &L, AnalysisManager<Loop> &AM) { + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function *F = L.getHeader()->getParent(); + + auto *AA = FAM.getCachedResult<AAManager>(*F); + auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); + auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(*F); + auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); + assert((AA && LI && DT && TLI && SE) && "Analyses for LICM not available"); + + LoopInvariantCodeMotion LICM; + + if (!LICM.runOnLoop(&L, AA, LI, DT, TLI, SE, true)) + return PreservedAnalyses::all(); + + // FIXME: There is no setPreservesCFG in the new PM. When that becomes + // available, it should be used here. + return getLoopPassPreservedAnalyses(); } -char LICM::ID = 0; -INITIALIZE_PASS_BEGIN(LICM, "licm", "Loop Invariant Code Motion", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSA) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +char LegacyLICMPass::ID = 0; +INITIALIZE_PASS_BEGIN(LegacyLICMPass, "licm", "Loop Invariant Code Motion", + false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) -INITIALIZE_PASS_END(LICM, "licm", "Loop Invariant Code Motion", false, false) +INITIALIZE_PASS_END(LegacyLICMPass, "licm", "Loop Invariant Code Motion", false, + false) -Pass *llvm::createLICMPass() { return new LICM(); } +Pass *llvm::createLICMPass() { return new LegacyLICMPass(); } /// Hoist expressions out of the specified loop. Note, alias info for inner /// loop is not preserved so it is not a good idea to run LICM multiple /// times on one loop. +/// We should delete AST for inner loops in the new pass manager to avoid +/// memory leak. /// -bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipOptnoneFunction(L)) - return false; - - Changed = false; - - // Get our Loop and Alias Analysis information... 
- LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); +bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AliasAnalysis *AA, + LoopInfo *LI, DominatorTree *DT, + TargetLibraryInfo *TLI, + ScalarEvolution *SE, bool DeleteAST) { + bool Changed = false; assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); - CurAST = new AliasSetTracker(*AA); - // Collect Alias info from subloops. - for (Loop *InnerL : L->getSubLoops()) { - AliasSetTracker *InnerAST = LoopToAliasSetMap[InnerL]; - assert(InnerAST && "Where is my AST?"); - - // What if InnerLoop was modified by other passes ? - CurAST->add(*InnerAST); - - // Once we've incorporated the inner loop's AST into ours, we don't need the - // subloop's anymore. - delete InnerAST; - LoopToAliasSetMap.erase(InnerL); - } - - CurLoop = L; + AliasSetTracker *CurAST = collectAliasInfoForLoop(L, LI, AA); // Get the preheader block to move instructions into... - Preheader = L->getLoopPreheader(); - - // Loop over the body of this loop, looking for calls, invokes, and stores. - // Because subloops have already been incorporated into AST, we skip blocks in - // subloops. - // - for (BasicBlock *BB : L->blocks()) { - if (LI->getLoopFor(BB) == L) // Ignore blocks in subloops. - CurAST->add(*BB); // Incorporate the specified basic block - } + BasicBlock *Preheader = L->getLoopPreheader(); // Compute loop safety information. - LICMSafetyInfo SafetyInfo; - computeLICMSafetyInfo(&SafetyInfo, CurLoop); + LoopSafetyInfo SafetyInfo; + computeLoopSafetyInfo(&SafetyInfo, L); // We want to visit all of the instructions in this loop... that are not parts // of our subloops (they have already had their invariants hoisted out of @@ -245,11 +239,11 @@ bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { // instructions, we perform another pass to hoist them out of the loop. // if (L->hasDedicatedExits()) - Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, CurLoop, + Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, L, CurAST, &SafetyInfo); if (Preheader) - Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, - CurLoop, CurAST, &SafetyInfo); + Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, L, + CurAST, &SafetyInfo); // Now that all loop invariants have been removed from the loop, promote any // memory references to scalars that we can. @@ -260,9 +254,8 @@ bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { // Loop over all of the alias sets in the tracker object. for (AliasSet &AS : *CurAST) - Changed |= promoteLoopAccessesToScalars(AS, ExitBlocks, InsertPts, - PIC, LI, DT, CurLoop, - CurAST, &SafetyInfo); + Changed |= promoteLoopAccessesToScalars( + AS, ExitBlocks, InsertPts, PIC, LI, DT, TLI, L, CurAST, &SafetyInfo); // Once we have promoted values across the loop body we have to recursively // reform LCSSA as any nested loop may now have values defined within the @@ -271,8 +264,7 @@ bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { // SSAUpdater strategy during promotion that was LCSSA aware and reformed // it as it went. if (Changed) { - auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); - formLCSSARecursively(*L, *DT, LI, SEWP ? 
&SEWP->getSE() : nullptr); + formLCSSARecursively(*L, *DT, LI, SE); } } @@ -283,50 +275,49 @@ bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { assert((!L->getParentLoop() || L->getParentLoop()->isLCSSAForm(*DT)) && "Parent loop not left in LCSSA form after LICM!"); - // Clear out loops state information for the next iteration - CurLoop = nullptr; - Preheader = nullptr; - // If this loop is nested inside of another one, save the alias information // for when we process the outer loop. - if (L->getParentLoop()) + if (L->getParentLoop() && !DeleteAST) LoopToAliasSetMap[L] = CurAST; else delete CurAST; + + if (Changed && SE) + SE->forgetLoopDispositions(L); return Changed; } /// Walk the specified region of the CFG (defined by all blocks dominated by -/// the specified block, and that are in the current loop) in reverse depth +/// the specified block, and that are in the current loop) in reverse depth /// first order w.r.t the DominatorTree. This allows us to visit uses before /// definitions, allowing us to sink a loop body in one pass without iteration. /// bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, - AliasSetTracker *CurAST, LICMSafetyInfo *SafetyInfo) { + AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo) { // Verify inputs. - assert(N != nullptr && AA != nullptr && LI != nullptr && - DT != nullptr && CurLoop != nullptr && CurAST != nullptr && - SafetyInfo != nullptr && "Unexpected input to sinkRegion"); + assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && + CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && + "Unexpected input to sinkRegion"); - // Set changed as false. - bool Changed = false; - // Get basic block BasicBlock *BB = N->getBlock(); // If this subregion is not in the top level loop at all, exit. - if (!CurLoop->contains(BB)) return Changed; + if (!CurLoop->contains(BB)) + return false; // We are processing blocks in reverse dfo, so process children first. - const std::vector<DomTreeNode*> &Children = N->getChildren(); + bool Changed = false; + const std::vector<DomTreeNode *> &Children = N->getChildren(); for (DomTreeNode *Child : Children) Changed |= sinkRegion(Child, AA, LI, DT, TLI, CurLoop, CurAST, SafetyInfo); // Only need to process the contents of this block if it is not part of a // subloop (which would already have been processed). - if (inSubLoop(BB,CurLoop,LI)) return Changed; + if (inSubLoop(BB, CurLoop, LI)) + return Changed; - for (BasicBlock::iterator II = BB->end(); II != BB->begin(); ) { + for (BasicBlock::iterator II = BB->end(); II != BB->begin();) { Instruction &I = *--II; // If the instruction is dead, we would try to sink it because it isn't used @@ -361,21 +352,23 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, /// bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, - AliasSetTracker *CurAST, LICMSafetyInfo *SafetyInfo) { + AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo) { // Verify inputs. - assert(N != nullptr && AA != nullptr && LI != nullptr && - DT != nullptr && CurLoop != nullptr && CurAST != nullptr && - SafetyInfo != nullptr && "Unexpected input to hoistRegion"); - // Set changed as false. 
- bool Changed = false; - // Get basic block + assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && + CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && + "Unexpected input to hoistRegion"); + BasicBlock *BB = N->getBlock(); + // If this subregion is not in the top level loop at all, exit. - if (!CurLoop->contains(BB)) return Changed; + if (!CurLoop->contains(BB)) + return false; + // Only need to process the contents of this block if it is not part of a // subloop (which would already have been processed). + bool Changed = false; if (!inSubLoop(BB, CurLoop, LI)) - for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E; ) { + for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) { Instruction &I = *II++; // Try constant folding this instruction. If all the operands are // constants, it is technically hoistable, but it would be better to just @@ -396,12 +389,13 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // if (CurLoop->hasLoopInvariantOperands(&I) && canSinkOrHoistInst(I, AA, DT, TLI, CurLoop, CurAST, SafetyInfo) && - isSafeToExecuteUnconditionally(I, DT, TLI, CurLoop, SafetyInfo, - CurLoop->getLoopPreheader()->getTerminator())) - Changed |= hoist(I, CurLoop->getLoopPreheader()); + isSafeToExecuteUnconditionally( + I, DT, CurLoop, SafetyInfo, + CurLoop->getLoopPreheader()->getTerminator())) + Changed |= hoist(I, DT, CurLoop, SafetyInfo); } - const std::vector<DomTreeNode*> &Children = N->getChildren(); + const std::vector<DomTreeNode *> &Children = N->getChildren(); for (DomTreeNode *Child : Children) Changed |= hoistRegion(Child, AA, LI, DT, TLI, CurLoop, CurAST, SafetyInfo); return Changed; @@ -410,7 +404,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, /// Computes loop safety information, checks loop body & header /// for the possibility of may throw exception. /// -void llvm::computeLICMSafetyInfo(LICMSafetyInfo * SafetyInfo, Loop * CurLoop) { +void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { assert(CurLoop != nullptr && "CurLoop cant be null"); BasicBlock *Header = CurLoop->getHeader(); // Setting default safety values. @@ -419,15 +413,17 @@ void llvm::computeLICMSafetyInfo(LICMSafetyInfo * SafetyInfo, Loop * CurLoop) { // Iterate over header and compute safety info. for (BasicBlock::iterator I = Header->begin(), E = Header->end(); (I != E) && !SafetyInfo->HeaderMayThrow; ++I) - SafetyInfo->HeaderMayThrow |= I->mayThrow(); - + SafetyInfo->HeaderMayThrow |= + !isGuaranteedToTransferExecutionToSuccessor(&*I); + SafetyInfo->MayThrow = SafetyInfo->HeaderMayThrow; - // Iterate over loop instructions and compute safety info. - for (Loop::block_iterator BB = CurLoop->block_begin(), - BBE = CurLoop->block_end(); (BB != BBE) && !SafetyInfo->MayThrow ; ++BB) + // Iterate over loop instructions and compute safety info. + for (Loop::block_iterator BB = CurLoop->block_begin(), + BBE = CurLoop->block_end(); + (BB != BBE) && !SafetyInfo->MayThrow; ++BB) for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); (I != E) && !SafetyInfo->MayThrow; ++I) - SafetyInfo->MayThrow |= I->mayThrow(); + SafetyInfo->MayThrow |= !isGuaranteedToTransferExecutionToSuccessor(&*I); // Compute funclet colors if we might sink/hoist in a function with a funclet // personality routine. 
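A hedged illustration of what the MayThrow/HeaderMayThrow flags computed above guard against; may_throw is a hypothetical external function:

void may_throw();   // assumed to potentially throw or not return

void fill(int *out, int n, int num, int den) {
  for (int i = 0; i < n; ++i) {
    may_throw();          // makes SafetyInfo->MayThrow true for this loop
    out[i] = num / den;   // traps if den == 0; both operands are loop-invariant,
                          // but the divide is not guaranteed to execute, so it
                          // is not hoisted to the preheader
  }
}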
@@ -443,11 +439,11 @@ void llvm::computeLICMSafetyInfo(LICMSafetyInfo * SafetyInfo, Loop * CurLoop) { /// bool canSinkOrHoistInst(Instruction &I, AliasAnalysis *AA, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, - AliasSetTracker *CurAST, LICMSafetyInfo *SafetyInfo) { + AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo) { // Loads have extra constraints we have to verify before we can hoist them. if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { if (!LI->isUnordered()) - return false; // Don't hoist volatile/atomic loads! + return false; // Don't hoist volatile/atomic loads! // Loads from constant memory are always safe to move, even if they end up // in the same alias set as something that ends up being modified. @@ -499,7 +495,8 @@ bool canSinkOrHoistInst(Instruction &I, AliasAnalysis *AA, DominatorTree *DT, break; } } - if (!FoundMod) return true; + if (!FoundMod) + return true; } // FIXME: This should use mod/ref information to see if we can hoist or @@ -518,9 +515,8 @@ bool canSinkOrHoistInst(Instruction &I, AliasAnalysis *AA, DominatorTree *DT, // TODO: Plumb the context instruction through to make hoisting and sinking // more powerful. Hoisting of loads already works due to the special casing - // above. - return isSafeToExecuteUnconditionally(I, DT, TLI, CurLoop, SafetyInfo, - nullptr); + // above. + return isSafeToExecuteUnconditionally(I, DT, CurLoop, SafetyInfo, nullptr); } /// Returns true if a PHINode is a trivially replaceable with an @@ -541,7 +537,7 @@ static bool isTriviallyReplacablePHI(const PHINode &PN, const Instruction &I) { /// blocks of the loop. /// static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop, - const LICMSafetyInfo *SafetyInfo) { + const LoopSafetyInfo *SafetyInfo) { const auto &BlockColors = SafetyInfo->BlockColors; for (const User *U : I.users()) { const Instruction *UI = cast<Instruction>(U); @@ -588,7 +584,7 @@ static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop, static Instruction * CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, - const LICMSafetyInfo *SafetyInfo) { + const LoopSafetyInfo *SafetyInfo) { Instruction *New; if (auto *CI = dyn_cast<CallInst>(&I)) { const auto &BlockColors = SafetyInfo->BlockColors; @@ -621,7 +617,8 @@ CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, } ExitBlock.getInstList().insert(ExitBlock.getFirstInsertionPt(), New); - if (!I.getName().empty()) New->setName(I.getName() + ".le"); + if (!I.getName().empty()) + New->setName(I.getName() + ".le"); // Build LCSSA PHI nodes for any in-loop operands. 
Note that this is // particularly cheap because we can rip off the PHI node that we're @@ -652,18 +649,20 @@ CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, /// static bool sink(Instruction &I, const LoopInfo *LI, const DominatorTree *DT, const Loop *CurLoop, AliasSetTracker *CurAST, - const LICMSafetyInfo *SafetyInfo) { + const LoopSafetyInfo *SafetyInfo) { DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); bool Changed = false; - if (isa<LoadInst>(I)) ++NumMovedLoads; - else if (isa<CallInst>(I)) ++NumMovedCalls; + if (isa<LoadInst>(I)) + ++NumMovedLoads; + else if (isa<CallInst>(I)) + ++NumMovedCalls; ++NumSunk; Changed = true; #ifndef NDEBUG SmallVector<BasicBlock *, 32> ExitBlocks; CurLoop->getUniqueExitBlocks(ExitBlocks); - SmallPtrSet<BasicBlock *, 32> ExitBlockSet(ExitBlocks.begin(), + SmallPtrSet<BasicBlock *, 32> ExitBlockSet(ExitBlocks.begin(), ExitBlocks.end()); #endif @@ -717,18 +716,30 @@ static bool sink(Instruction &I, const LoopInfo *LI, const DominatorTree *DT, /// When an instruction is found to only use loop invariant operands that /// is safe to hoist, this instruction is called to do the dirty work. /// -static bool hoist(Instruction &I, BasicBlock *Preheader) { - DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " - << I << "\n"); +static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, + const LoopSafetyInfo *SafetyInfo) { + auto *Preheader = CurLoop->getLoopPreheader(); + DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I + << "\n"); + + // Metadata can be dependent on conditions we are hoisting above. + // Conservatively strip all metadata on the instruction unless we were + // guaranteed to execute I if we entered the loop, in which case the metadata + // is valid in the loop preheader. + if (I.hasMetadataOtherThanDebugLoc() && + // The check on hasMetadataOtherThanDebugLoc is to prevent us from burning + // time in isGuaranteedToExecute if we don't actually have anything to + // drop. It is a compile time optimization, not required for correctness. + !isGuaranteedToExecute(I, DT, CurLoop, SafetyInfo)) + I.dropUnknownNonDebugMetadata(); + // Move the new node to the Preheader, before its terminator. I.moveBefore(Preheader->getTerminator()); - // Metadata can be dependent on the condition we are hoisting above. - // Conservatively strip all metadata on the instruction. - I.dropUnknownNonDebugMetadata(); - - if (isa<LoadInst>(I)) ++NumMovedLoads; - else if (isa<CallInst>(I)) ++NumMovedCalls; + if (isa<LoadInst>(I)) + ++NumMovedLoads; + else if (isa<CallInst>(I)) + ++NumMovedCalls; ++NumHoisted; return true; } @@ -736,134 +747,91 @@ static bool hoist(Instruction &I, BasicBlock *Preheader) { /// Only sink or hoist an instruction if it is not a trapping instruction, /// or if the instruction is known not to trap when moved to the preheader. /// or if it is a trapping instruction and is guaranteed to execute. 
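Before the promotion machinery below, a hedged source-level sketch of what promoteLoopAccessesToScalars and LoopPromoter achieve, assuming alias analysis can prove the accumulator is not aliased by the array (expressed here with the common __restrict extension):

int sum_into(const int *a, int n, int *__restrict acc) {
  // Every loop access goes through `acc`, so the location is promoted to a
  // register: one load in the preheader, a PHI inside the loop, and a single
  // store written into the loop's exit block by doExtraRewritesBeforeFinalDeletion.
  for (int i = 0; i < n; ++i)
    *acc += a[i];
  return *acc;
}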
-static bool isSafeToExecuteUnconditionally(const Instruction &Inst, +static bool isSafeToExecuteUnconditionally(const Instruction &Inst, const DominatorTree *DT, - const TargetLibraryInfo *TLI, const Loop *CurLoop, - const LICMSafetyInfo *SafetyInfo, + const LoopSafetyInfo *SafetyInfo, const Instruction *CtxI) { - if (isSafeToSpeculativelyExecute(&Inst, CtxI, DT, TLI)) + if (isSafeToSpeculativelyExecute(&Inst, CtxI, DT)) return true; return isGuaranteedToExecute(Inst, DT, CurLoop, SafetyInfo); } -static bool isGuaranteedToExecute(const Instruction &Inst, - const DominatorTree *DT, - const Loop *CurLoop, - const LICMSafetyInfo * SafetyInfo) { - - // We have to check to make sure that the instruction dominates all - // of the exit blocks. If it doesn't, then there is a path out of the loop - // which does not execute this instruction, so we can't hoist it. - - // If the instruction is in the header block for the loop (which is very - // common), it is always guaranteed to dominate the exit blocks. Since this - // is a common case, and can save some work, check it now. - if (Inst.getParent() == CurLoop->getHeader()) - // If there's a throw in the header block, we can't guarantee we'll reach - // Inst. - return !SafetyInfo->HeaderMayThrow; - - // Somewhere in this loop there is an instruction which may throw and make us - // exit the loop. - if (SafetyInfo->MayThrow) - return false; - - // Get the exit blocks for the current loop. - SmallVector<BasicBlock*, 8> ExitBlocks; - CurLoop->getExitBlocks(ExitBlocks); - - // Verify that the block dominates each of the exit blocks of the loop. - for (BasicBlock *ExitBlock : ExitBlocks) - if (!DT->dominates(Inst.getParent(), ExitBlock)) - return false; - - // As a degenerate case, if the loop is statically infinite then we haven't - // proven anything since there are no exit blocks. - if (ExitBlocks.empty()) - return false; - - return true; -} - namespace { - class LoopPromoter : public LoadAndStorePromoter { - Value *SomePtr; // Designated pointer to store to. - SmallPtrSetImpl<Value*> &PointerMustAliases; - SmallVectorImpl<BasicBlock*> &LoopExitBlocks; - SmallVectorImpl<Instruction*> &LoopInsertPts; - PredIteratorCache &PredCache; - AliasSetTracker &AST; - LoopInfo &LI; - DebugLoc DL; - int Alignment; - AAMDNodes AATags; - - Value *maybeInsertLCSSAPHI(Value *V, BasicBlock *BB) const { - if (Instruction *I = dyn_cast<Instruction>(V)) - if (Loop *L = LI.getLoopFor(I->getParent())) - if (!L->contains(BB)) { - // We need to create an LCSSA PHI node for the incoming value and - // store that. - PHINode *PN = - PHINode::Create(I->getType(), PredCache.size(BB), - I->getName() + ".lcssa", &BB->front()); - for (BasicBlock *Pred : PredCache.get(BB)) - PN->addIncoming(I, Pred); - return PN; - } - return V; - } +class LoopPromoter : public LoadAndStorePromoter { + Value *SomePtr; // Designated pointer to store to. 
+ SmallPtrSetImpl<Value *> &PointerMustAliases; + SmallVectorImpl<BasicBlock *> &LoopExitBlocks; + SmallVectorImpl<Instruction *> &LoopInsertPts; + PredIteratorCache &PredCache; + AliasSetTracker &AST; + LoopInfo &LI; + DebugLoc DL; + int Alignment; + AAMDNodes AATags; - public: - LoopPromoter(Value *SP, - ArrayRef<const Instruction *> Insts, - SSAUpdater &S, SmallPtrSetImpl<Value *> &PMA, - SmallVectorImpl<BasicBlock *> &LEB, - SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC, - AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment, - const AAMDNodes &AATags) - : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), - LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast), - LI(li), DL(dl), Alignment(alignment), AATags(AATags) {} - - bool isInstInList(Instruction *I, - const SmallVectorImpl<Instruction*> &) const override { - Value *Ptr; - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - Ptr = LI->getOperand(0); - else - Ptr = cast<StoreInst>(I)->getPointerOperand(); - return PointerMustAliases.count(Ptr); - } + Value *maybeInsertLCSSAPHI(Value *V, BasicBlock *BB) const { + if (Instruction *I = dyn_cast<Instruction>(V)) + if (Loop *L = LI.getLoopFor(I->getParent())) + if (!L->contains(BB)) { + // We need to create an LCSSA PHI node for the incoming value and + // store that. + PHINode *PN = PHINode::Create(I->getType(), PredCache.size(BB), + I->getName() + ".lcssa", &BB->front()); + for (BasicBlock *Pred : PredCache.get(BB)) + PN->addIncoming(I, Pred); + return PN; + } + return V; + } - void doExtraRewritesBeforeFinalDeletion() const override { - // Insert stores after in the loop exit blocks. Each exit block gets a - // store of the live-out values that feed them. Since we've already told - // the SSA updater about the defs in the loop and the preheader - // definition, it is all set and we can start using it. - for (unsigned i = 0, e = LoopExitBlocks.size(); i != e; ++i) { - BasicBlock *ExitBlock = LoopExitBlocks[i]; - Value *LiveInValue = SSA.GetValueInMiddleOfBlock(ExitBlock); - LiveInValue = maybeInsertLCSSAPHI(LiveInValue, ExitBlock); - Value *Ptr = maybeInsertLCSSAPHI(SomePtr, ExitBlock); - Instruction *InsertPos = LoopInsertPts[i]; - StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos); - NewSI->setAlignment(Alignment); - NewSI->setDebugLoc(DL); - if (AATags) NewSI->setAAMetadata(AATags); - } - } +public: + LoopPromoter(Value *SP, ArrayRef<const Instruction *> Insts, SSAUpdater &S, + SmallPtrSetImpl<Value *> &PMA, + SmallVectorImpl<BasicBlock *> &LEB, + SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC, + AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment, + const AAMDNodes &AATags) + : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), + LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast), + LI(li), DL(std::move(dl)), Alignment(alignment), AATags(AATags) {} + + bool isInstInList(Instruction *I, + const SmallVectorImpl<Instruction *> &) const override { + Value *Ptr; + if (LoadInst *LI = dyn_cast<LoadInst>(I)) + Ptr = LI->getOperand(0); + else + Ptr = cast<StoreInst>(I)->getPointerOperand(); + return PointerMustAliases.count(Ptr); + } - void replaceLoadWithValue(LoadInst *LI, Value *V) const override { - // Update alias analysis. - AST.copyValue(LI, V); + void doExtraRewritesBeforeFinalDeletion() const override { + // Insert stores after in the loop exit blocks. Each exit block gets a + // store of the live-out values that feed them. 
Since we've already told + // the SSA updater about the defs in the loop and the preheader + // definition, it is all set and we can start using it. + for (unsigned i = 0, e = LoopExitBlocks.size(); i != e; ++i) { + BasicBlock *ExitBlock = LoopExitBlocks[i]; + Value *LiveInValue = SSA.GetValueInMiddleOfBlock(ExitBlock); + LiveInValue = maybeInsertLCSSAPHI(LiveInValue, ExitBlock); + Value *Ptr = maybeInsertLCSSAPHI(SomePtr, ExitBlock); + Instruction *InsertPos = LoopInsertPts[i]; + StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos); + NewSI->setAlignment(Alignment); + NewSI->setDebugLoc(DL); + if (AATags) + NewSI->setAAMetadata(AATags); } - void instructionDeleted(Instruction *I) const override { - AST.deleteValue(I); - } - }; + } + + void replaceLoadWithValue(LoadInst *LI, Value *V) const override { + // Update alias analysis. + AST.copyValue(LI, V); + } + void instructionDeleted(Instruction *I) const override { AST.deleteValue(I); } +}; } // end anon namespace /// Try to promote memory values to scalars by sinking stores out of the @@ -871,32 +839,28 @@ namespace { /// the stores in the loop, looking for stores to Must pointers which are /// loop invariant. /// -bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, - SmallVectorImpl<BasicBlock*>&ExitBlocks, - SmallVectorImpl<Instruction*>&InsertPts, - PredIteratorCache &PIC, LoopInfo *LI, - DominatorTree *DT, Loop *CurLoop, - AliasSetTracker *CurAST, - LICMSafetyInfo * SafetyInfo) { +bool llvm::promoteLoopAccessesToScalars( + AliasSet &AS, SmallVectorImpl<BasicBlock *> &ExitBlocks, + SmallVectorImpl<Instruction *> &InsertPts, PredIteratorCache &PIC, + LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, + Loop *CurLoop, AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo) { // Verify inputs. - assert(LI != nullptr && DT != nullptr && - CurLoop != nullptr && CurAST != nullptr && - SafetyInfo != nullptr && + assert(LI != nullptr && DT != nullptr && CurLoop != nullptr && + CurAST != nullptr && SafetyInfo != nullptr && "Unexpected Input to promoteLoopAccessesToScalars"); - // Initially set Changed status to false. - bool Changed = false; + // We can promote this alias set if it has a store, if it is a "Must" alias // set, if the pointer is loop invariant, and if we are not eliminating any // volatile loads or stores. if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() || AS.isVolatile() || !CurLoop->isLoopInvariant(AS.begin()->getValue())) - return Changed; + return false; assert(!AS.empty() && "Must alias set should have at least one pointer element in it!"); Value *SomePtr = AS.begin()->getValue(); - BasicBlock * Preheader = CurLoop->getLoopPreheader(); + BasicBlock *Preheader = CurLoop->getLoopPreheader(); // It isn't safe to promote a load/store from the loop if the load/store is // conditional. For example, turning: @@ -909,12 +873,27 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, // // is not safe, because *P may only be valid to access if 'c' is true. // + // The safety property divides into two parts: + // 1) The memory may not be dereferenceable on entry to the loop. In this + // case, we can't insert the required load in the preheader. + // 2) The memory model does not allow us to insert a store along any dynamic + // path which did not originally have one. + // // It is safe to promote P if all uses are direct load/stores and if at // least one is guaranteed to be executed. 
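A source-level picture of that rule (illustrative C++, not from the patch): a store that only happens on some paths cannot be promoted in general, because promotion would introduce a store on paths that never had one, while an unconditional store can be.

void not_promoted(int *p, const int *a, int len) {
  for (int i = 0; i < len; ++i)
    if (a[i] > 0)
      *p += a[i];   // store executes only on some paths: not promotable in general
}

void promoted(int *p, const int *a, int len) {
  for (int i = 0; i < len; ++i)
    *p += a[i];     // store guaranteed to execute every iteration: promotable
}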
bool GuaranteedToExecute = false; - SmallVector<Instruction*, 64> LoopUses; - SmallPtrSet<Value*, 4> PointerMustAliases; + // It is also safe to promote P if we can prove that speculating a load into + // the preheader is safe (i.e. proving dereferenceability on all + // paths through the loop), and that the memory can be proven thread local + // (so that the memory model requirement doesn't apply.) We first establish + // the former, and then run a capture analysis below to establish the later. + // We can use any access within the alias set to prove dereferenceability + // since they're all must alias. + bool CanSpeculateLoad = false; + + SmallVector<Instruction *, 64> LoopUses; + SmallPtrSet<Value *, 4> PointerMustAliases; // We start with an alignment of one and try to find instructions that allow // us to prove better alignment. @@ -922,11 +901,32 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, AAMDNodes AATags; bool HasDedicatedExits = CurLoop->hasDedicatedExits(); + // Don't sink stores from loops without dedicated block exits. Exits + // containing indirect branches are not transformed by loop simplify, + // make sure we catch that. An additional load may be generated in the + // preheader for SSA updater, so also avoid sinking when no preheader + // is available. + if (!HasDedicatedExits || !Preheader) + return false; + + const DataLayout &MDL = Preheader->getModule()->getDataLayout(); + + if (SafetyInfo->MayThrow) { + // If a loop can throw, we have to insert a store along each unwind edge. + // That said, we can't actually make the unwind edge explicit. Therefore, + // we have to prove that the store is dead along the unwind edge. + // + // Currently, this code just special-cases alloca instructions. + if (!isa<AllocaInst>(GetUnderlyingObject(SomePtr, MDL))) + return false; + } + // Check that all of the pointers in the alias set have the same type. We // cannot (yet) promote a memory location that is loaded and stored in // different sizes. While we are at it, collect alignment and AA info. - for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI) { - Value *ASIV = ASI->getValue(); + bool Changed = false; + for (const auto &ASI : AS) { + Value *ASIV = ASI.getValue(); PointerMustAliases.insert(ASIV); // Check that all of the pointers in the alias set have the same type. We @@ -947,6 +947,10 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, assert(!Load->isVolatile() && "AST broken"); if (!Load->isSimple()) return Changed; + + if (!GuaranteedToExecute && !CanSpeculateLoad) + CanSpeculateLoad = isSafeToExecuteUnconditionally( + *Load, DT, CurLoop, SafetyInfo, Preheader->getTerminator()); } else if (const StoreInst *Store = dyn_cast<StoreInst>(UI)) { // Stores *of* the pointer are not interesting, only stores *to* the // pointer. @@ -955,13 +959,6 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, assert(!Store->isVolatile() && "AST broken"); if (!Store->isSimple()) return Changed; - // Don't sink stores from loops without dedicated block exits. Exits - // containing indirect branches are not transformed by loop simplify, - // make sure we catch that. An additional load may be generated in the - // preheader for SSA updater, so also avoid sinking when no preheader - // is available. 
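The new speculation path relaxes that restriction when the location is provably dereferenceable and thread local, roughly as in this illustrative sketch (hypothetical function, not from the patch):

int count_positive(const int *a, int len) {
  int acc = 0;                 // alloc-like local that is never captured
  for (int i = 0; i < len; ++i)
    if (a[i] > 0)
      acc += 1;                // conditional store, but into thread-local memory,
                               // so adding a store on every path is unobservable
  return acc;
}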
- if (!HasDedicatedExits || !Preheader) - return Changed; // Note that we only check GuaranteedToExecute inside the store case // so that we do not introduce stores where they did not exist before @@ -972,16 +969,22 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, // instruction will be executed, update the alignment. // Larger is better, with the exception of 0 being the best alignment. unsigned InstAlignment = Store->getAlignment(); - if ((InstAlignment > Alignment || InstAlignment == 0) && Alignment != 0) + if ((InstAlignment > Alignment || InstAlignment == 0) && + Alignment != 0) { if (isGuaranteedToExecute(*UI, DT, CurLoop, SafetyInfo)) { GuaranteedToExecute = true; Alignment = InstAlignment; } + } else if (!GuaranteedToExecute) { + GuaranteedToExecute = + isGuaranteedToExecute(*UI, DT, CurLoop, SafetyInfo); + } - if (!GuaranteedToExecute) - GuaranteedToExecute = isGuaranteedToExecute(*UI, DT, - CurLoop, SafetyInfo); - + if (!GuaranteedToExecute && !CanSpeculateLoad) { + CanSpeculateLoad = isDereferenceableAndAlignedPointer( + Store->getPointerOperand(), Store->getAlignment(), MDL, + Preheader->getTerminator(), DT); + } } else return Changed; // Not a load or store. @@ -997,8 +1000,17 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, } } - // If there isn't a guaranteed-to-execute instruction, we can't promote. - if (!GuaranteedToExecute) + // Check legality per comment above. Otherwise, we can't promote. + bool PromotionIsLegal = GuaranteedToExecute; + if (!PromotionIsLegal && CanSpeculateLoad) { + // If this is a thread local location, then we can insert stores along + // paths which originally didn't have them without violating the memory + // model. + Value *Object = GetUnderlyingObject(SomePtr, MDL); + PromotionIsLegal = + isAllocLikeFn(Object, TLI) && !PointerMayBeCaptured(Object, true, true); + } + if (!PromotionIsLegal) return Changed; // Figure out the loop exits and their insertion points, if this is the @@ -1017,7 +1029,8 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, return Changed; // Otherwise, this is safe to promote, lets do it! - DEBUG(dbgs() << "LICM: Promoting value stored to in loop: " <<*SomePtr<<'\n'); + DEBUG(dbgs() << "LICM: Promoting value stored to in loop: " << *SomePtr + << '\n'); Changed = true; ++NumPromoted; @@ -1028,20 +1041,19 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, DebugLoc DL = LoopUses[0]->getDebugLoc(); // We use the SSAUpdater interface to insert phi nodes as required. - SmallVector<PHINode*, 16> NewPHIs; + SmallVector<PHINode *, 16> NewPHIs; SSAUpdater SSA(&NewPHIs); - LoopPromoter Promoter(SomePtr, LoopUses, SSA, - PointerMustAliases, ExitBlocks, + LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, InsertPts, PIC, *CurAST, *LI, DL, Alignment, AATags); // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. 
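Putting the pieces together, a successful promotion of the earlier unconditional-store example effectively produces the following shape (illustrative C++ with hypothetical names):

int promoted_result(int *p, const int *a, int len) {
  int t = *p;                  // the preheader load ('<ptr>.promoted')
  for (int i = 0; i < len; ++i)
    t += a[i];                 // in-loop loads and stores rewritten to use 't'
  *p = t;                      // store placed in the loop exit block
  return t;
}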
- LoadInst *PreheaderLoad = - new LoadInst(SomePtr, SomePtr->getName()+".promoted", - Preheader->getTerminator()); + LoadInst *PreheaderLoad = new LoadInst( + SomePtr, SomePtr->getName() + ".promoted", Preheader->getTerminator()); PreheaderLoad->setAlignment(Alignment); PreheaderLoad->setDebugLoc(DL); - if (AATags) PreheaderLoad->setAAMetadata(AATags); + if (AATags) + PreheaderLoad->setAAMetadata(AATags); SSA.AddAvailableValue(Preheader, PreheaderLoad); // Rewrite all the loads in the loop and remember all the definitions from @@ -1055,10 +1067,67 @@ bool llvm::promoteLoopAccessesToScalars(AliasSet &AS, return Changed; } +/// Returns an owning pointer to an alias set which incorporates aliasing info +/// from L and all subloops of L. +/// FIXME: In new pass manager, there is no helper functions to handle loop +/// analysis such as cloneBasicBlockAnalysis. So the AST needs to be recompute +/// from scratch for every loop. Hook up with the helper functions when +/// available in the new pass manager to avoid redundant computation. +AliasSetTracker * +LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, + AliasAnalysis *AA) { + AliasSetTracker *CurAST = nullptr; + SmallVector<Loop *, 4> RecomputeLoops; + for (Loop *InnerL : L->getSubLoops()) { + auto MapI = LoopToAliasSetMap.find(InnerL); + // If the AST for this inner loop is missing it may have been merged into + // some other loop's AST and then that loop unrolled, and so we need to + // recompute it. + if (MapI == LoopToAliasSetMap.end()) { + RecomputeLoops.push_back(InnerL); + continue; + } + AliasSetTracker *InnerAST = MapI->second; + + if (CurAST != nullptr) { + // What if InnerLoop was modified by other passes ? + CurAST->add(*InnerAST); + + // Once we've incorporated the inner loop's AST into ours, we don't need + // the subloop's anymore. + delete InnerAST; + } else { + CurAST = InnerAST; + } + LoopToAliasSetMap.erase(MapI); + } + if (CurAST == nullptr) + CurAST = new AliasSetTracker(*AA); + + auto mergeLoop = [&](Loop *L) { + // Loop over the body of this loop, looking for calls, invokes, and stores. + // Because subloops have already been incorporated into AST, we skip blocks + // in subloops. + for (BasicBlock *BB : L->blocks()) + if (LI->getLoopFor(BB) == L) // Ignore blocks in subloops. + CurAST->add(*BB); // Incorporate the specified basic block + }; + + // Add everything from the sub loops that are no longer directly available. + for (Loop *InnerL : RecomputeLoops) + mergeLoop(InnerL); + + // And merge in this loop. + mergeLoop(L); + + return CurAST; +} + /// Simple analysis hook. Clone alias set info. /// -void LICM::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, Loop *L) { - AliasSetTracker *AST = LoopToAliasSetMap.lookup(L); +void LegacyLICMPass::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, + Loop *L) { + AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L); if (!AST) return; @@ -1067,8 +1136,8 @@ void LICM::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, Loop *L) { /// Simple Analysis hook. Delete value V from alias set /// -void LICM::deleteAnalysisValue(Value *V, Loop *L) { - AliasSetTracker *AST = LoopToAliasSetMap.lookup(L); +void LegacyLICMPass::deleteAnalysisValue(Value *V, Loop *L) { + AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L); if (!AST) return; @@ -1077,21 +1146,20 @@ void LICM::deleteAnalysisValue(Value *V, Loop *L) { /// Simple Analysis hook. Delete value L from alias set map. 
/// -void LICM::deleteAnalysisLoop(Loop *L) { - AliasSetTracker *AST = LoopToAliasSetMap.lookup(L); +void LegacyLICMPass::deleteAnalysisLoop(Loop *L) { + AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L); if (!AST) return; delete AST; - LoopToAliasSetMap.erase(L); + LICM.getLoopToAliasSetMap().erase(L); } - /// Return true if the body of this loop may store into the memory /// location pointed to by V. /// static bool pointerInvalidatedByLoop(Value *V, uint64_t Size, - const AAMDNodes &AAInfo, + const AAMDNodes &AAInfo, AliasSetTracker *CurAST) { // Check to see if any of the basic blocks in CurLoop invalidate *V. return CurAST->getAliasSetForPointer(V, Size, AAInfo).isMod(); @@ -1104,4 +1172,3 @@ static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI) { assert(CurLoop->contains(BB) && "Only valid if BB is IN the loop"); return LI->getLoopFor(BB) != CurLoop; } - diff --git a/lib/Transforms/Scalar/LoadCombine.cpp b/lib/Transforms/Scalar/LoadCombine.cpp index 1648878b0628..dfe51a4ce44c 100644 --- a/lib/Transforms/Scalar/LoadCombine.cpp +++ b/lib/Transforms/Scalar/LoadCombine.cpp @@ -35,10 +35,12 @@ using namespace llvm; STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining"); STATISTIC(NumLoadsCombined, "Number of loads combined"); +#define LDCOMBINE_NAME "Combine Adjacent Loads" + namespace { struct PointerOffsetPair { Value *Pointer; - uint64_t Offset; + APInt Offset; }; struct LoadPOPPair { @@ -63,12 +65,16 @@ public: using llvm::Pass::doInitialization; bool doInitialization(Function &) override; bool runOnBasicBlock(BasicBlock &BB) override; - void getAnalysisUsage(AnalysisUsage &AU) const override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } - const char *getPassName() const override { return "LoadCombine"; } + const char *getPassName() const override { return LDCOMBINE_NAME; } static char ID; - typedef IRBuilder<true, TargetFolder> BuilderTy; + typedef IRBuilder<TargetFolder> BuilderTy; private: BuilderTy *Builder; @@ -87,22 +93,25 @@ bool LoadCombine::doInitialization(Function &F) { } PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) { + auto &DL = LI.getModule()->getDataLayout(); + PointerOffsetPair POP; POP.Pointer = LI.getPointerOperand(); - POP.Offset = 0; + unsigned BitWidth = DL.getPointerSizeInBits(LI.getPointerAddressSpace()); + POP.Offset = APInt(BitWidth, 0); + while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) { if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) { - auto &DL = LI.getModule()->getDataLayout(); - unsigned BitWidth = DL.getPointerTypeSizeInBits(GEP->getType()); - APInt Offset(BitWidth, 0); - if (GEP->accumulateConstantOffset(DL, Offset)) - POP.Offset += Offset.getZExtValue(); - else + APInt LastOffset = POP.Offset; + if (!GEP->accumulateConstantOffset(DL, POP.Offset)) { // Can't handle GEPs with variable indices. 
+ POP.Offset = LastOffset; return POP; + } POP.Pointer = GEP->getPointerOperand(); - } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer)) + } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer)) { POP.Pointer = BC->getOperand(0); + } } return POP; } @@ -115,8 +124,8 @@ bool LoadCombine::combineLoads( continue; std::sort(Loads.second.begin(), Loads.second.end(), [](const LoadPOPPair &A, const LoadPOPPair &B) { - return A.POP.Offset < B.POP.Offset; - }); + return A.POP.Offset.slt(B.POP.Offset); + }); if (aggregateLoads(Loads.second)) Combined = true; } @@ -132,28 +141,31 @@ bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) { LoadInst *BaseLoad = nullptr; SmallVector<LoadPOPPair, 8> AggregateLoads; bool Combined = false; - uint64_t PrevOffset = -1ull; + bool ValidPrevOffset = false; + APInt PrevOffset; uint64_t PrevSize = 0; for (auto &L : Loads) { - if (PrevOffset == -1ull) { + if (ValidPrevOffset == false) { BaseLoad = L.Load; PrevOffset = L.POP.Offset; PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize( L.Load->getType()); AggregateLoads.push_back(L); + ValidPrevOffset = true; continue; } if (L.Load->getAlignment() > BaseLoad->getAlignment()) continue; - if (L.POP.Offset > PrevOffset + PrevSize) { + APInt PrevEnd = PrevOffset + PrevSize; + if (L.POP.Offset.sgt(PrevEnd)) { // No other load will be combinable if (combineLoads(AggregateLoads)) Combined = true; AggregateLoads.clear(); - PrevOffset = -1; + ValidPrevOffset = false; continue; } - if (L.POP.Offset != PrevOffset + PrevSize) + if (L.POP.Offset != PrevEnd) // This load is offset less than the size of the last load. // FIXME: We may want to handle this case. continue; @@ -199,7 +211,7 @@ bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) { Value *Ptr = Builder->CreateConstGEP1_64( Builder->CreatePointerCast(Loads[0].POP.Pointer, Builder->getInt8PtrTy(AddressSpace)), - Loads[0].POP.Offset); + Loads[0].POP.Offset.getSExtValue()); LoadInst *NewLoad = new LoadInst( Builder->CreatePointerCast( Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize), @@ -212,7 +224,7 @@ bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) { Value *V = Builder->CreateExtractInteger( L.Load->getModule()->getDataLayout(), NewLoad, cast<IntegerType>(L.Load->getType()), - L.POP.Offset - Loads[0].POP.Offset, "combine.extract"); + (L.POP.Offset - Loads[0].POP.Offset).getZExtValue(), "combine.extract"); L.Load->replaceAllUsesWith(V); } @@ -221,12 +233,12 @@ bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) { } bool LoadCombine::runOnBasicBlock(BasicBlock &BB) { - if (skipOptnoneFunction(BB)) + if (skipBasicBlock(BB)) return false; AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - IRBuilder<true, TargetFolder> TheBuilder( + IRBuilder<TargetFolder> TheBuilder( BB.getContext(), TargetFolder(BB.getModule()->getDataLayout())); Builder = &TheBuilder; @@ -260,23 +272,12 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) { return Combined; } -void LoadCombine::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesCFG(); - - AU.addRequired<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); -} - char LoadCombine::ID = 0; BasicBlockPass *llvm::createLoadCombinePass() { return new LoadCombine(); } -INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", "Combine Adjacent Loads", - false, false) +INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 
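For reference, this is the kind of pattern the pass targets (illustrative C++, not from the patch): four byte loads at constant offsets 0 through 3 from the same base are contiguous, so they can be merged into one 32-bit load followed by shifts and extracts of the combined value.

unsigned read_le32(const unsigned char *p) {
  return (unsigned)p[0]
       | ((unsigned)p[1] << 8)
       | ((unsigned)p[2] << 16)
       | ((unsigned)p[3] << 24);   // offsets 0..3: combinable into a single i32 load
}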
-INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_END(LoadCombine, "load-combine", "Combine Adjacent Loads", - false, false) - +INITIALIZE_PASS_END(LoadCombine, "load-combine", LDCOMBINE_NAME, false, false) diff --git a/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/lib/Transforms/Scalar/LoopDataPrefetch.cpp new file mode 100644 index 000000000000..66b59d27dfde --- /dev/null +++ b/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -0,0 +1,304 @@ +//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a Loop Data Prefetching Pass. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loop-data-prefetch" +#include "llvm/Transforms/Scalar.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +using namespace llvm; + +// By default, we limit this to creating 16 PHIs (which is a little over half +// of the allocatable register set). +static cl::opt<bool> +PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false), + cl::desc("Prefetch write addresses")); + +static cl::opt<unsigned> + PrefetchDistance("prefetch-distance", + cl::desc("Number of instructions to prefetch ahead"), + cl::Hidden); + +static cl::opt<unsigned> + MinPrefetchStride("min-prefetch-stride", + cl::desc("Min stride to add prefetches"), cl::Hidden); + +static cl::opt<unsigned> MaxPrefetchIterationsAhead( + "max-prefetch-iters-ahead", + cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden); + +STATISTIC(NumPrefetches, "Number of prefetches inserted"); + +namespace llvm { + void initializeLoopDataPrefetchPass(PassRegistry&); +} + +namespace { + + class LoopDataPrefetch : public FunctionPass { + public: + static char ID; // Pass ID, replacement for typeid + LoopDataPrefetch() : FunctionPass(ID) { + initializeLoopDataPrefetchPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + // FIXME: For some reason, preserving SE here breaks LSR (even if + // this pass changes nothing). 
+ // AU.addPreserved<ScalarEvolutionWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + } + + bool runOnFunction(Function &F) override; + + private: + bool runOnLoop(Loop *L); + + /// \brief Check if the the stride of the accesses is large enough to + /// warrant a prefetch. + bool isStrideLargeEnough(const SCEVAddRecExpr *AR); + + unsigned getMinPrefetchStride() { + if (MinPrefetchStride.getNumOccurrences() > 0) + return MinPrefetchStride; + return TTI->getMinPrefetchStride(); + } + + unsigned getPrefetchDistance() { + if (PrefetchDistance.getNumOccurrences() > 0) + return PrefetchDistance; + return TTI->getPrefetchDistance(); + } + + unsigned getMaxPrefetchIterationsAhead() { + if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0) + return MaxPrefetchIterationsAhead; + return TTI->getMaxPrefetchIterationsAhead(); + } + + AssumptionCache *AC; + LoopInfo *LI; + ScalarEvolution *SE; + const TargetTransformInfo *TTI; + const DataLayout *DL; + }; +} + +char LoopDataPrefetch::ID = 0; +INITIALIZE_PASS_BEGIN(LoopDataPrefetch, "loop-data-prefetch", + "Loop Data Prefetch", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(LoopDataPrefetch, "loop-data-prefetch", + "Loop Data Prefetch", false, false) + +FunctionPass *llvm::createLoopDataPrefetchPass() { return new LoopDataPrefetch(); } + +bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) { + unsigned TargetMinStride = getMinPrefetchStride(); + // No need to check if any stride goes. + if (TargetMinStride <= 1) + return true; + + const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)); + // If MinStride is set, don't prefetch unless we can ensure that stride is + // larger. + if (!ConstStride) + return false; + + unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue()); + return TargetMinStride <= AbsStride; +} + +bool LoopDataPrefetch::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + DL = &F.getParent()->getDataLayout(); + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + + // If PrefetchDistance is not set, don't run the pass. This gives an + // opportunity for targets to run this pass for selected subtargets only + // (whose TTI sets PrefetchDistance). + if (getPrefetchDistance() == 0) + return false; + assert(TTI->getCacheLineSize() && "Cache line size is not set for target"); + + bool MadeChange = false; + + for (Loop *I : *LI) + for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L) + MadeChange |= runOnLoop(*L); + + return MadeChange; +} + +bool LoopDataPrefetch::runOnLoop(Loop *L) { + bool MadeChange = false; + + // Only prefetch in the inner-most loop + if (!L->empty()) + return MadeChange; + + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(L, AC, EphValues); + + // Calculate the number of iterations ahead to prefetch + CodeMetrics Metrics; + for (Loop::block_iterator I = L->block_begin(), IE = L->block_end(); + I != IE; ++I) { + + // If the loop already has prefetches, then assume that the user knows + // what they are doing and don't add any more. 
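The net effect on a simple streaming loop looks roughly like the following, written with __builtin_prefetch for readability (illustrative C++ with hypothetical names; the pass itself emits the llvm.prefetch intrinsic with read/write and locality operands, at an address advanced by a number of iterations derived from the target's prefetch distance and the loop size):

void axpy(float *y, const float *x, float a, long n, long iters_ahead) {
  for (long i = 0; i < n; ++i) {
    // prefetch the load address several iterations ahead of the current one
    __builtin_prefetch(&x[i + iters_ahead], /*rw=*/0, /*locality=*/3);
    y[i] = a * x[i] + y[i];
  }
}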
+ for (BasicBlock::iterator J = (*I)->begin(), JE = (*I)->end(); + J != JE; ++J) + if (CallInst *CI = dyn_cast<CallInst>(J)) + if (Function *F = CI->getCalledFunction()) + if (F->getIntrinsicID() == Intrinsic::prefetch) + return MadeChange; + + Metrics.analyzeBasicBlock(*I, *TTI, EphValues); + } + unsigned LoopSize = Metrics.NumInsts; + if (!LoopSize) + LoopSize = 1; + + unsigned ItersAhead = getPrefetchDistance() / LoopSize; + if (!ItersAhead) + ItersAhead = 1; + + if (ItersAhead > getMaxPrefetchIterationsAhead()) + return MadeChange; + + Function *F = L->getHeader()->getParent(); + DEBUG(dbgs() << "Prefetching " << ItersAhead + << " iterations ahead (loop size: " << LoopSize << ") in " + << F->getName() << ": " << *L); + + SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads; + for (Loop::block_iterator I = L->block_begin(), IE = L->block_end(); + I != IE; ++I) { + for (BasicBlock::iterator J = (*I)->begin(), JE = (*I)->end(); + J != JE; ++J) { + Value *PtrValue; + Instruction *MemI; + + if (LoadInst *LMemI = dyn_cast<LoadInst>(J)) { + MemI = LMemI; + PtrValue = LMemI->getPointerOperand(); + } else if (StoreInst *SMemI = dyn_cast<StoreInst>(J)) { + if (!PrefetchWrites) continue; + MemI = SMemI; + PtrValue = SMemI->getPointerOperand(); + } else continue; + + unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace(); + if (PtrAddrSpace) + continue; + + if (L->isLoopInvariant(PtrValue)) + continue; + + const SCEV *LSCEV = SE->getSCEV(PtrValue); + const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV); + if (!LSCEVAddRec) + continue; + + // Check if the the stride of the accesses is large enough to warrant a + // prefetch. + if (!isStrideLargeEnough(LSCEVAddRec)) + continue; + + // We don't want to double prefetch individual cache lines. If this load + // is known to be within one cache line of some other load that has + // already been prefetched, then don't prefetch this one as well. + bool DupPref = false; + for (const auto &PrefLoad : PrefLoads) { + const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second); + if (const SCEVConstant *ConstPtrDiff = + dyn_cast<SCEVConstant>(PtrDiff)) { + int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue()); + if (PD < (int64_t) TTI->getCacheLineSize()) { + DupPref = true; + break; + } + } + } + if (DupPref) + continue; + + const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr( + SE->getConstant(LSCEVAddRec->getType(), ItersAhead), + LSCEVAddRec->getStepRecurrence(*SE))); + if (!isSafeToExpand(NextLSCEV, *SE)) + continue; + + PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec)); + + Type *I8Ptr = Type::getInt8PtrTy((*I)->getContext(), PtrAddrSpace); + SCEVExpander SCEVE(*SE, J->getModule()->getDataLayout(), "prefaddr"); + Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI); + + IRBuilder<> Builder(MemI); + Module *M = (*I)->getParent()->getParent(); + Type *I32 = Type::getInt32Ty((*I)->getContext()); + Value *PrefetchFunc = Intrinsic::getDeclaration(M, Intrinsic::prefetch); + Builder.CreateCall( + PrefetchFunc, + {PrefPtrValue, + ConstantInt::get(I32, MemI->mayReadFromMemory() ? 
0 : 1), + ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)}); + ++NumPrefetches; + DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV + << "\n"); + emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, + MemI->getDebugLoc(), "prefetched memory access"); + + + MadeChange = true; + } + } + + return MadeChange; +} + diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp index 7b1940b48c31..19b2f89555c2 100644 --- a/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/lib/Transforms/Scalar/LoopDeletion.cpp @@ -14,75 +14,28 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopDeletion.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/LoopPassManager.h" #include "llvm/IR/Dominators.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; #define DEBUG_TYPE "loop-delete" STATISTIC(NumDeleted, "Number of loops deleted"); -namespace { - class LoopDeletion : public LoopPass { - public: - static char ID; // Pass ID, replacement for typeid - LoopDeletion() : LoopPass(ID) { - initializeLoopDeletionPass(*PassRegistry::getPassRegistry()); - } - - // Possibly eliminate loop L if it is dead. - bool runOnLoop(Loop *L, LPPassManager &) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addRequiredID(LCSSAID); - - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreservedID(LoopSimplifyID); - AU.addPreservedID(LCSSAID); - } - - private: - bool isLoopDead(Loop *L, SmallVectorImpl<BasicBlock *> &exitingBlocks, - SmallVectorImpl<BasicBlock *> &exitBlocks, - bool &Changed, BasicBlock *Preheader); - - }; -} - -char LoopDeletion::ID = 0; -INITIALIZE_PASS_BEGIN(LoopDeletion, "loop-deletion", - "Delete dead loops", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSA) -INITIALIZE_PASS_END(LoopDeletion, "loop-deletion", - "Delete dead loops", false, false) - -Pass *llvm::createLoopDeletionPass() { - return new LoopDeletion(); -} - /// isLoopDead - Determined if a loop is dead. This assumes that we've already /// checked for unique exit and exiting blocks, and that the code is in LCSSA /// form. -bool LoopDeletion::isLoopDead(Loop *L, - SmallVectorImpl<BasicBlock *> &exitingBlocks, - SmallVectorImpl<BasicBlock *> &exitBlocks, - bool &Changed, BasicBlock *Preheader) { +bool LoopDeletionPass::isLoopDead(Loop *L, ScalarEvolution &SE, + SmallVectorImpl<BasicBlock *> &exitingBlocks, + SmallVectorImpl<BasicBlock *> &exitBlocks, + bool &Changed, BasicBlock *Preheader) { BasicBlock *exitBlock = exitBlocks[0]; // Make sure that all PHI entries coming from the loop are loop invariant. @@ -91,6 +44,8 @@ bool LoopDeletion::isLoopDead(Loop *L, // sufficient to guarantee that no loop-variant values are used outside // of the loop. 
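For orientation, a concrete example of a loop this pass removes (illustrative C++, not from the patch): it has no side effects, a computable trip count, and no value defined inside it is used afterwards, so the invariance check described above is trivially satisfied.

int loop_deleted(int n) {
  int t = 0;
  for (int i = 0; i < n; ++i)
    t += i;          // 't' is never read after the loop
  return n;          // the observable result does not depend on the loop
}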
BasicBlock::iterator BI = exitBlock->begin(); + bool AllEntriesInvariant = true; + bool AllOutgoingValuesSame = true; while (PHINode *P = dyn_cast<PHINode>(BI)) { Value *incoming = P->getIncomingValueForBlock(exitingBlocks[0]); @@ -98,27 +53,37 @@ bool LoopDeletion::isLoopDead(Loop *L, // block. If there are different incoming values for different exiting // blocks, then it is impossible to statically determine which value should // be used. - for (unsigned i = 1, e = exitingBlocks.size(); i < e; ++i) { - if (incoming != P->getIncomingValueForBlock(exitingBlocks[i])) - return false; - } + AllOutgoingValuesSame = + all_of(makeArrayRef(exitingBlocks).slice(1), [&](BasicBlock *BB) { + return incoming == P->getIncomingValueForBlock(BB); + }); + + if (!AllOutgoingValuesSame) + break; if (Instruction *I = dyn_cast<Instruction>(incoming)) - if (!L->makeLoopInvariant(I, Changed, Preheader->getTerminator())) - return false; + if (!L->makeLoopInvariant(I, Changed, Preheader->getTerminator())) { + AllEntriesInvariant = false; + break; + } ++BI; } + if (Changed) + SE.forgetLoopDispositions(L); + + if (!AllEntriesInvariant || !AllOutgoingValuesSame) + return false; + // Make sure that no instructions in the block have potential side-effects. // This includes instructions that could write to memory, and loads that are // marked volatile. This could be made more aggressive by using aliasing // information to identify readonly and readnone calls. for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end(); LI != LE; ++LI) { - for (BasicBlock::iterator BI = (*LI)->begin(), BE = (*LI)->end(); - BI != BE; ++BI) { - if (BI->mayHaveSideEffects()) + for (Instruction &I : **LI) { + if (I.mayHaveSideEffects()) return false; } } @@ -126,15 +91,15 @@ bool LoopDeletion::isLoopDead(Loop *L, return true; } -/// runOnLoop - Remove dead loops, by which we mean loops that do not impact the -/// observable behavior of the program other than finite running time. Note -/// we do ensure that this never remove a loop that might be infinite, as doing -/// so could change the halting/non-halting nature of a program. -/// NOTE: This entire process relies pretty heavily on LoopSimplify and LCSSA -/// in order to make various safety checks work. -bool LoopDeletion::runOnLoop(Loop *L, LPPassManager &) { - if (skipOptnoneFunction(L)) - return false; +/// Remove dead loops, by which we mean loops that do not impact the observable +/// behavior of the program other than finite running time. Note we do ensure +/// that this never remove a loop that might be infinite, as doing so could +/// change the halting/non-halting nature of a program. NOTE: This entire +/// process relies pretty heavily on LoopSimplify and LCSSA in order to make +/// various safety checks work. +bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE, + LoopInfo &loopInfo) { + assert(L->isLCSSAForm(DT) && "Expected LCSSA!"); // We can only remove the loop if there is a preheader that we can // branch from after removing it. @@ -151,10 +116,10 @@ bool LoopDeletion::runOnLoop(Loop *L, LPPassManager &) { if (L->begin() != L->end()) return false; - SmallVector<BasicBlock*, 4> exitingBlocks; + SmallVector<BasicBlock *, 4> exitingBlocks; L->getExitingBlocks(exitingBlocks); - SmallVector<BasicBlock*, 4> exitBlocks; + SmallVector<BasicBlock *, 4> exitBlocks; L->getUniqueExitBlocks(exitBlocks); // We require that the loop only have a single exit block. 
Otherwise, we'd @@ -166,12 +131,11 @@ bool LoopDeletion::runOnLoop(Loop *L, LPPassManager &) { // Finally, we have to check that the loop really is dead. bool Changed = false; - if (!isLoopDead(L, exitingBlocks, exitBlocks, Changed, preheader)) + if (!isLoopDead(L, SE, exitingBlocks, exitBlocks, Changed, preheader)) return Changed; // Don't remove loops for which we can't solve the trip count. // They could be infinite, in which case we'd be changing program behavior. - ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); const SCEV *S = SE.getMaxBackedgeTakenCount(L); if (isa<SCEVCouldNotCompute>(S)) return Changed; @@ -208,16 +172,14 @@ bool LoopDeletion::runOnLoop(Loop *L, LPPassManager &) { // Update the dominator tree and remove the instructions and blocks that will // be deleted from the reference counting scheme. - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); SmallVector<DomTreeNode*, 8> ChildNodes; for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end(); LI != LE; ++LI) { // Move all of the block's children to be children of the preheader, which // allows us to remove the domtree entry for the block. ChildNodes.insert(ChildNodes.begin(), DT[*LI]->begin(), DT[*LI]->end()); - for (SmallVectorImpl<DomTreeNode *>::iterator DI = ChildNodes.begin(), - DE = ChildNodes.end(); DI != DE; ++DI) { - DT.changeImmediateDominator(*DI, DT[preheader]); + for (DomTreeNode *ChildNode : ChildNodes) { + DT.changeImmediateDominator(ChildNode, DT[preheader]); } ChildNodes.clear(); @@ -238,8 +200,8 @@ bool LoopDeletion::runOnLoop(Loop *L, LPPassManager &) { // Finally, the blocks from loopinfo. This has to happen late because // otherwise our loop iterators won't work. - LoopInfo &loopInfo = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - SmallPtrSet<BasicBlock*, 8> blocks; + + SmallPtrSet<BasicBlock *, 8> blocks; blocks.insert(L->block_begin(), L->block_end()); for (BasicBlock *BB : blocks) loopInfo.removeBlock(BB); @@ -252,3 +214,56 @@ bool LoopDeletion::runOnLoop(Loop *L, LPPassManager &) { return Changed; } + +PreservedAnalyses LoopDeletionPass::run(Loop &L, AnalysisManager<Loop> &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function *F = L.getHeader()->getParent(); + + auto &DT = *FAM.getCachedResult<DominatorTreeAnalysis>(*F); + auto &SE = *FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); + auto &LI = *FAM.getCachedResult<LoopAnalysis>(*F); + + bool Changed = runImpl(&L, DT, SE, LI); + if (!Changed) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +namespace { +class LoopDeletionLegacyPass : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + LoopDeletionLegacyPass() : LoopPass(ID) { + initializeLoopDeletionLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + // Possibly eliminate loop L if it is dead. 
+ bool runOnLoop(Loop *L, LPPassManager &) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + getLoopAnalysisUsage(AU); + } +}; +} + +char LoopDeletionLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopDeletionLegacyPass, "loop-deletion", + "Delete dead loops", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(LoopDeletionLegacyPass, "loop-deletion", + "Delete dead loops", false, false) + +Pass *llvm::createLoopDeletionPass() { return new LoopDeletionLegacyPass(); } + +bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &) { + if (skipLoop(L)) + return false; + + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + LoopInfo &loopInfo = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + + LoopDeletionPass Impl; + return Impl.runImpl(L, DT, SE, loopInfo); +} diff --git a/lib/Transforms/Scalar/LoopDistribute.cpp b/lib/Transforms/Scalar/LoopDistribute.cpp index 3d3cf3e2890b..7eca28ed2bb7 100644 --- a/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/lib/Transforms/Scalar/LoopDistribute.cpp @@ -22,12 +22,17 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/LoopDistribute.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPassManager.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -60,6 +65,19 @@ static cl::opt<unsigned> DistributeSCEVCheckThreshold( cl::desc("The maximum number of SCEV checks allowed for Loop " "Distribution")); +static cl::opt<unsigned> PragmaDistributeSCEVCheckThreshold( + "loop-distribute-scev-check-threshold-with-pragma", cl::init(128), + cl::Hidden, + cl::desc( + "The maximum number of SCEV checks allowed for Loop " + "Distribution for loop marked with #pragma loop distribute(enable)")); + +// Note that the initial value for this depends on whether the pass is invoked +// directly or from the optimization pipeline. +static cl::opt<bool> EnableLoopDistribute( + "enable-loop-distribute", cl::Hidden, + cl::desc("Enable the new, experimental LoopDistribution Pass")); + STATISTIC(NumLoopsDistributed, "Number of loops distributed"); namespace { @@ -170,7 +188,7 @@ public: // Delete the instructions backwards, as it has a reduced likelihood of // having to update as many def-use and use-def chains. - for (auto *Inst : make_range(Unused.rbegin(), Unused.rend())) { + for (auto *Inst : reverse(Unused)) { if (!Inst->use_empty()) Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); Inst->eraseFromParent(); @@ -571,121 +589,39 @@ private: AccessesType Accesses; }; -/// \brief The pass class. -class LoopDistribute : public FunctionPass { +/// \brief The actual class performing the per-loop work. 
+class LoopDistributeForLoop { public: - LoopDistribute() : FunctionPass(ID) { - initializeLoopDistributePass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LAA = &getAnalysis<LoopAccessAnalysis>(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - - // Build up a worklist of inner-loops to vectorize. This is necessary as the - // act of distributing a loop creates new loops and can invalidate iterators - // across the loops. - SmallVector<Loop *, 8> Worklist; - - for (Loop *TopLevelLoop : *LI) - for (Loop *L : depth_first(TopLevelLoop)) - // We only handle inner-most loops. - if (L->empty()) - Worklist.push_back(L); - - // Now walk the identified inner loops. - bool Changed = false; - for (Loop *L : Worklist) - Changed |= processLoop(L); - - // Process each loop nest in the function. - return Changed; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<LoopAccessAnalysis>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - } - - static char ID; - -private: - /// \brief Filter out checks between pointers from the same partition. - /// - /// \p PtrToPartition contains the partition number for pointers. Partition - /// number -1 means that the pointer is used in multiple partitions. In this - /// case we can't safely omit the check. - SmallVector<RuntimePointerChecking::PointerCheck, 4> - includeOnlyCrossPartitionChecks( - const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &AllChecks, - const SmallVectorImpl<int> &PtrToPartition, - const RuntimePointerChecking *RtPtrChecking) { - SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; - - std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { - for (unsigned PtrIdx1 : Check.first->Members) - for (unsigned PtrIdx2 : Check.second->Members) - // Only include this check if there is a pair of pointers - // that require checking and the pointers fall into - // separate partitions. - // - // (Note that we already know at this point that the two - // pointer groups need checking but it doesn't follow - // that each pair of pointers within the two groups need - // checking as well. - // - // In other words we don't want to include a check just - // because there is a pair of pointers between the two - // pointer groups that require checks and a different - // pair whose pointers fall into different partitions.) - if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) && - !RuntimePointerChecking::arePointersInSamePartition( - PtrToPartition, PtrIdx1, PtrIdx2)) - return true; - return false; - }); - - return Checks; + LoopDistributeForLoop(Loop *L, Function *F, LoopInfo *LI, DominatorTree *DT, + ScalarEvolution *SE, OptimizationRemarkEmitter *ORE) + : L(L), F(F), LI(LI), LAI(nullptr), DT(DT), SE(SE), ORE(ORE) { + setForced(); } /// \brief Try to distribute an inner-most loop. 
- bool processLoop(Loop *L) { + bool processLoop(std::function<const LoopAccessInfo &(Loop &)> &GetLAA) { assert(L->empty() && "Only process inner loops."); DEBUG(dbgs() << "\nLDist: In \"" << L->getHeader()->getParent()->getName() << "\" checking " << *L << "\n"); BasicBlock *PH = L->getLoopPreheader(); - if (!PH) { - DEBUG(dbgs() << "Skipping; no preheader"); - return false; - } - if (!L->getExitBlock()) { - DEBUG(dbgs() << "Skipping; multiple exit blocks"); - return false; - } - // LAA will check that we only have a single exiting block. + if (!PH) + return fail("no preheader"); + if (!L->getExitBlock()) + return fail("multiple exit blocks"); - const LoopAccessInfo &LAI = LAA->getInfo(L, ValueToValueMap()); + // LAA will check that we only have a single exiting block. + LAI = &GetLAA(*L); // Currently, we only distribute to isolate the part of the loop with // dependence cycles to enable partial vectorization. - if (LAI.canVectorizeMemory()) { - DEBUG(dbgs() << "Skipping; memory operations are safe for vectorization"); - return false; - } - auto *Dependences = LAI.getDepChecker().getDependences(); - if (!Dependences || Dependences->empty()) { - DEBUG(dbgs() << "Skipping; No unsafe dependences to isolate"); - return false; - } + if (LAI->canVectorizeMemory()) + return fail("memory operations are safe for vectorization"); + + auto *Dependences = LAI->getDepChecker().getDependences(); + if (!Dependences || Dependences->empty()) + return fail("no unsafe dependences to isolate"); InstPartitionContainer Partitions(L, LI, DT); @@ -708,7 +644,7 @@ private: // NumUnsafeDependencesActive > 0 indicates this situation and in this case // we just keep assigning to the same cyclic partition until // NumUnsafeDependencesActive reaches 0. - const MemoryDepChecker &DepChecker = LAI.getDepChecker(); + const MemoryDepChecker &DepChecker = LAI->getDepChecker(); MemoryInstructionDependences MID(DepChecker.getMemoryInstructions(), *Dependences); @@ -738,14 +674,14 @@ private: DEBUG(dbgs() << "Seeded partitions:\n" << Partitions); if (Partitions.getSize() < 2) - return false; + return fail("cannot isolate unsafe dependencies"); // Run the merge heuristics: Merge non-cyclic adjacent partitions since we // should be able to vectorize these together. Partitions.mergeBeforePopulating(); DEBUG(dbgs() << "\nMerged partitions:\n" << Partitions); if (Partitions.getSize() < 2) - return false; + return fail("cannot isolate unsafe dependencies"); // Now, populate the partitions with non-memory operations. Partitions.populateUsedSet(); @@ -757,15 +693,15 @@ private: DEBUG(dbgs() << "\nPartitions merged to ensure unique loads:\n" << Partitions); if (Partitions.getSize() < 2) - return false; + return fail("cannot isolate unsafe dependencies"); } // Don't distribute the loop if we need too many SCEV run-time checks. - const SCEVUnionPredicate &Pred = LAI.PSE.getUnionPredicate(); - if (Pred.getComplexity() > DistributeSCEVCheckThreshold) { - DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); - return false; - } + const SCEVUnionPredicate &Pred = LAI->getPSE().getUnionPredicate(); + if (Pred.getComplexity() > (IsForced.getValueOr(false) + ? PragmaDistributeSCEVCheckThreshold + : DistributeSCEVCheckThreshold)) + return fail("too many SCEV run-time checks needed.\n"); DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); // We're done forming the partitions set up the reverse mapping from @@ -779,19 +715,20 @@ private: SplitBlock(PH, PH->getTerminator(), DT, LI); // If we need run-time checks, version the loop now. 
- auto PtrToPartition = Partitions.computePartitionSetForPointers(LAI); - const auto *RtPtrChecking = LAI.getRuntimePointerChecking(); + auto PtrToPartition = Partitions.computePartitionSetForPointers(*LAI); + const auto *RtPtrChecking = LAI->getRuntimePointerChecking(); const auto &AllChecks = RtPtrChecking->getChecks(); auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition, RtPtrChecking); if (!Pred.isAlwaysTrue() || !Checks.empty()) { DEBUG(dbgs() << "\nPointers:\n"); - DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); - LoopVersioning LVer(LAI, L, LI, DT, SE, false); + DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks)); + LoopVersioning LVer(*LAI, L, LI, DT, SE, false); LVer.setAliasChecks(std::move(Checks)); - LVer.setSCEVChecks(LAI.PSE.getUnionPredicate()); + LVer.setSCEVChecks(LAI->getPSE().getUnionPredicate()); LVer.versionLoop(DefsUsedOutside); + LVer.annotateLoopWithNoAlias(); } // Create identical copies of the original loop for each partition and hook @@ -810,27 +747,244 @@ private: } ++NumLoopsDistributed; + // Report the success. + emitOptimizationRemark(F->getContext(), LDIST_NAME, *F, L->getStartLoc(), + "distributed loop"); return true; } + /// \brief Provide diagnostics then \return with false. + bool fail(llvm::StringRef Message) { + LLVMContext &Ctx = F->getContext(); + bool Forced = isForced().getValueOr(false); + + DEBUG(dbgs() << "Skipping; " << Message << "\n"); + + // With Rpass-missed report that distribution failed. + ORE->emitOptimizationRemarkMissed( + LDIST_NAME, L, + "loop not distributed: use -Rpass-analysis=loop-distribute for more " + "info"); + + // With Rpass-analysis report why. This is on by default if distribution + // was requested explicitly. + emitOptimizationRemarkAnalysis( + Ctx, Forced ? DiagnosticInfoOptimizationRemarkAnalysis::AlwaysPrint + : LDIST_NAME, + *F, L->getStartLoc(), Twine("loop not distributed: ") + Message); + + // Also issue a warning if distribution was requested explicitly but it + // failed. + if (Forced) + Ctx.diagnose(DiagnosticInfoOptimizationFailure( + *F, L->getStartLoc(), "loop not distributed: failed " + "explicitly specified loop distribution")); + + return false; + } + + /// \brief Return if distribution forced to be enabled/disabled for the loop. + /// + /// If the optional has a value, it indicates whether distribution was forced + /// to be enabled (true) or disabled (false). If the optional has no value + /// distribution was not forced either way. + const Optional<bool> &isForced() const { return IsForced; } + +private: + /// \brief Filter out checks between pointers from the same partition. + /// + /// \p PtrToPartition contains the partition number for pointers. Partition + /// number -1 means that the pointer is used in multiple partitions. In this + /// case we can't safely omit the check. 
+ SmallVector<RuntimePointerChecking::PointerCheck, 4> + includeOnlyCrossPartitionChecks( + const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &AllChecks, + const SmallVectorImpl<int> &PtrToPartition, + const RuntimePointerChecking *RtPtrChecking) { + SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; + + std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks), + [&](const RuntimePointerChecking::PointerCheck &Check) { + for (unsigned PtrIdx1 : Check.first->Members) + for (unsigned PtrIdx2 : Check.second->Members) + // Only include this check if there is a pair of pointers + // that require checking and the pointers fall into + // separate partitions. + // + // (Note that we already know at this point that the two + // pointer groups need checking but it doesn't follow + // that each pair of pointers within the two groups need + // checking as well. + // + // In other words we don't want to include a check just + // because there is a pair of pointers between the two + // pointer groups that require checks and a different + // pair whose pointers fall into different partitions.) + if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) && + !RuntimePointerChecking::arePointersInSamePartition( + PtrToPartition, PtrIdx1, PtrIdx2)) + return true; + return false; + }); + + return Checks; + } + + /// \brief Check whether the loop metadata is forcing distribution to be + /// enabled/disabled. + void setForced() { + Optional<const MDOperand *> Value = + findStringMetadataForLoop(L, "llvm.loop.distribute.enable"); + if (!Value) + return; + + const MDOperand *Op = *Value; + assert(Op && mdconst::hasa<ConstantInt>(*Op) && "invalid metadata"); + IsForced = mdconst::extract<ConstantInt>(*Op)->getZExtValue(); + } + + Loop *L; + Function *F; + // Analyses used. LoopInfo *LI; - LoopAccessAnalysis *LAA; + const LoopAccessInfo *LAI; DominatorTree *DT; ScalarEvolution *SE; + OptimizationRemarkEmitter *ORE; + + /// \brief Indicates whether distribution is forced to be enabled/disabled for + /// the loop. + /// + /// If the optional has a value, it indicates whether distribution was forced + /// to be enabled (true) or disabled (false). If the optional has no value + /// distribution was not forced either way. + Optional<bool> IsForced; +}; + +/// Shared implementation between new and old PMs. +static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, + ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, + std::function<const LoopAccessInfo &(Loop &)> &GetLAA, + bool ProcessAllLoops) { + // Build up a worklist of inner-loops to vectorize. This is necessary as the + // act of distributing a loop creates new loops and can invalidate iterators + // across the loops. + SmallVector<Loop *, 8> Worklist; + + for (Loop *TopLevelLoop : *LI) + for (Loop *L : depth_first(TopLevelLoop)) + // We only handle inner-most loops. + if (L->empty()) + Worklist.push_back(L); + + // Now walk the identified inner loops. + bool Changed = false; + for (Loop *L : Worklist) { + LoopDistributeForLoop LDL(L, &F, LI, DT, SE, ORE); + + // If distribution was forced for the specific loop to be + // enabled/disabled, follow that. Otherwise use the global flag. + if (LDL.isForced().getValueOr(ProcessAllLoops)) + Changed |= LDL.processLoop(GetLAA); + } + + // Process each loop nest in the function. + return Changed; +} + +/// \brief The pass class. 
+class LoopDistributeLegacy : public FunctionPass { +public: + /// \p ProcessAllLoopsByDefault specifies whether loop distribution should be + /// performed by default. Pass -enable-loop-distribute={0,1} overrides this + /// default. We use this to keep LoopDistribution off by default when invoked + /// from the optimization pipeline but on when invoked explicitly from opt. + LoopDistributeLegacy(bool ProcessAllLoopsByDefault = true) + : FunctionPass(ID), ProcessAllLoops(ProcessAllLoopsByDefault) { + // The default is set by the caller. + if (EnableLoopDistribute.getNumOccurrences() > 0) + ProcessAllLoops = EnableLoopDistribute; + initializeLoopDistributeLegacyPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + std::function<const LoopAccessInfo &(Loop &)> GetLAA = + [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; + + return runImpl(F, LI, DT, SE, ORE, GetLAA, ProcessAllLoops); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<LoopAccessLegacyAnalysis>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + } + + static char ID; + +private: + /// \brief Whether distribution should be on in this function. The per-loop + /// pragma can override this. + bool ProcessAllLoops; }; } // anonymous namespace -char LoopDistribute::ID; +PreservedAnalyses LoopDistributePass::run(Function &F, + FunctionAnalysisManager &AM) { + // FIXME: This does not currently match the behavior from the old PM. + // ProcessAllLoops with the old PM defaults to true when invoked from opt and + // false when invoked from the optimization pipeline. 
+ bool ProcessAllLoops = false; + if (EnableLoopDistribute.getNumOccurrences() > 0) + ProcessAllLoops = EnableLoopDistribute; + + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + + auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); + std::function<const LoopAccessInfo &(Loop &)> GetLAA = + [&](Loop &L) -> const LoopAccessInfo & { + return LAM.getResult<LoopAccessAnalysis>(L); + }; + + bool Changed = runImpl(F, &LI, &DT, &SE, &ORE, GetLAA, ProcessAllLoops); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<LoopAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + return PA; +} + +char LoopDistributeLegacy::ID; static const char ldist_name[] = "Loop Distribition"; -INITIALIZE_PASS_BEGIN(LoopDistribute, LDIST_NAME, ldist_name, false, false) +INITIALIZE_PASS_BEGIN(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, + false) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_END(LoopDistribute, LDIST_NAME, ldist_name, false, false) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_END(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, false) namespace llvm { -FunctionPass *createLoopDistributePass() { return new LoopDistribute(); } +FunctionPass *createLoopDistributePass(bool ProcessAllLoopsByDefault) { + return new LoopDistributeLegacy(ProcessAllLoopsByDefault); +} } diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 4521640e3947..1468676a3543 100644 --- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -26,22 +26,21 @@ // i64 and larger types when i64 is legal and the value has few bits set. It // would be good to enhance isel to emit a loop for ctpop in this case. // -// We should enhance the memset/memcpy recognition to handle multiple stores in -// the loop. This would handle things like: -// void foo(_Complex float *P) -// for (i) { __real__(*P) = 0; __imag__(*P) = 0; } -// // This could recognize common matrix multiplies and dot product idioms and // replace them with calls to BLAS (if linked in??). 
// //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -55,7 +54,10 @@ #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; #define DEBUG_TYPE "loop-idiom" @@ -65,7 +67,7 @@ STATISTIC(NumMemCpy, "Number of memcpy's formed from loop load+stores"); namespace { -class LoopIdiomRecognize : public LoopPass { +class LoopIdiomRecognize { Loop *CurLoop; AliasAnalysis *AA; DominatorTree *DT; @@ -76,39 +78,21 @@ class LoopIdiomRecognize : public LoopPass { const DataLayout *DL; public: - static char ID; - explicit LoopIdiomRecognize() : LoopPass(ID) { - initializeLoopIdiomRecognizePass(*PassRegistry::getPassRegistry()); - } + explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT, + LoopInfo *LI, ScalarEvolution *SE, + TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI, + const DataLayout *DL) + : CurLoop(nullptr), AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), + DL(DL) {} - bool runOnLoop(Loop *L, LPPassManager &LPM) override; - - /// This transformation requires natural loop information & requires that - /// loop preheaders be inserted into the CFG. 
- /// - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addPreservedID(LoopSimplifyID); - AU.addRequiredID(LCSSAID); - AU.addPreservedID(LCSSAID); - AU.addRequired<AAResultsWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<SCEVAAWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<BasicAAWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } + bool runOnLoop(Loop *L); private: typedef SmallVector<StoreInst *, 8> StoreList; - StoreList StoreRefsForMemset; + typedef MapVector<Value *, StoreList> StoreListMap; + StoreListMap StoreRefsForMemset; + StoreListMap StoreRefsForMemsetPattern; StoreList StoreRefsForMemcpy; bool HasMemset; bool HasMemsetPattern; @@ -122,14 +106,18 @@ private: SmallVectorImpl<BasicBlock *> &ExitBlocks); void collectStores(BasicBlock *BB); - bool isLegalStore(StoreInst *SI, bool &ForMemset, bool &ForMemcpy); - bool processLoopStore(StoreInst *SI, const SCEV *BECount); + bool isLegalStore(StoreInst *SI, bool &ForMemset, bool &ForMemsetPattern, + bool &ForMemcpy); + bool processLoopStores(SmallVectorImpl<StoreInst *> &SL, const SCEV *BECount, + bool ForMemset); bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount); bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize, unsigned StoreAlignment, Value *StoredVal, - Instruction *TheStore, const SCEVAddRecExpr *Ev, - const SCEV *BECount, bool NegStride); + Instruction *TheStore, + SmallPtrSetImpl<Instruction *> &Stores, + const SCEVAddRecExpr *Ev, const SCEV *BECount, + bool NegStride); bool processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount); /// @} @@ -145,38 +133,82 @@ private: /// @} }; +class LoopIdiomRecognizeLegacyPass : public LoopPass { +public: + static char ID; + explicit LoopIdiomRecognizeLegacyPass() : LoopPass(ID) { + initializeLoopIdiomRecognizeLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + const TargetTransformInfo *TTI = + &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( + *L->getHeader()->getParent()); + const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout(); + + LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL); + return LIR.runOnLoop(L); + } + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG. + /// + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + getLoopAnalysisUsage(AU); + } +}; } // End anonymous namespace. 
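The anonymous namespace above now holds a plain LoopIdiomRecognize class that does the work plus a thin LoopIdiomRecognizeLegacyPass that only gathers analyses and delegates, so the same logic can also back the new pass manager. A compilable sketch of that shape, using stand-in types rather than LLVM's:

// Stand-ins for the real IR and analysis types.
struct Loop { bool Innermost = true; };
struct Analyses { bool HasMemset = true; };

// All of the transformation logic lives in one plain class.
class IdiomRecognizerImpl {
  const Analyses &A;
public:
  explicit IdiomRecognizerImpl(const Analyses &A) : A(A) {}
  bool runOnLoop(Loop &L) const { return A.HasMemset && L.Innermost; }
};

// Each pass-manager entry point only collects its inputs and delegates.
struct LegacyPMWrapper {
  Analyses A;
  bool runOnLoop(Loop &L) const { return IdiomRecognizerImpl(A).runOnLoop(L); }
};

struct NewPMWrapper {
  bool run(Loop &L, const Analyses &Cached) const {
    return IdiomRecognizerImpl(Cached).runOnLoop(L);
  }
};

Keeping the implementation free of pass-manager types is what lets the legacy wrapper and the new-PM run() method stay as thin adapters.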
-char LoopIdiomRecognize::ID = 0; -INITIALIZE_PASS_BEGIN(LoopIdiomRecognize, "loop-idiom", "Recognize loop idioms", - false, false) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSA) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, + AnalysisManager<Loop> &AM) { + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function *F = L.getHeader()->getParent(); + + // Use getCachedResult because Loop pass cannot trigger a function analysis. + auto *AA = FAM.getCachedResult<AAManager>(*F); + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); + auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); + auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); + auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(*F); + const auto *TTI = FAM.getCachedResult<TargetIRAnalysis>(*F); + const auto *DL = &L.getHeader()->getModule()->getDataLayout(); + assert((AA && DT && LI && SE && TLI && TTI && DL) && + "Analyses for Loop Idiom Recognition not available"); + + LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL); + if (!LIR.runOnLoop(&L)) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +char LoopIdiomRecognizeLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopIdiomRecognizeLegacyPass, "loop-idiom", + "Recognize loop idioms", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(LoopIdiomRecognize, "loop-idiom", "Recognize loop idioms", - false, false) +INITIALIZE_PASS_END(LoopIdiomRecognizeLegacyPass, "loop-idiom", + "Recognize loop idioms", false, false) -Pass *llvm::createLoopIdiomPass() { return new LoopIdiomRecognize(); } +Pass *llvm::createLoopIdiomPass() { return new LoopIdiomRecognizeLegacyPass(); } -/// deleteDeadInstruction - Delete this instruction. Before we do, go through -/// and zero out all the operands of this instruction. If any of them become -/// dead, delete them and the computation tree that feeds them. -/// -static void deleteDeadInstruction(Instruction *I, - const TargetLibraryInfo *TLI) { - SmallVector<Value *, 16> Operands(I->value_op_begin(), I->value_op_end()); +static void deleteDeadInstruction(Instruction *I) { I->replaceAllUsesWith(UndefValue::get(I->getType())); I->eraseFromParent(); - for (Value *Op : Operands) - RecursivelyDeleteTriviallyDeadInstructions(Op, TLI); } //===----------------------------------------------------------------------===// @@ -185,10 +217,7 @@ static void deleteDeadInstruction(Instruction *I, // //===----------------------------------------------------------------------===// -bool LoopIdiomRecognize::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipOptnoneFunction(L)) - return false; - +bool LoopIdiomRecognize::runOnLoop(Loop *L) { CurLoop = L; // If the loop could not be converted to canonical form, it must have an // indirectbr in it, just give up. 
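The new-PM run() method above pulls every function-level result through getCachedResult and asserts that it is present, because a loop pass cannot trigger function analyses itself. A small standalone illustration of that contract, with a hypothetical cache type standing in for the analysis manager:

#include <cassert>
#include <string>
#include <unordered_map>

struct DomTreeStub {};   // stand-in for a function-level analysis result

// A cache that can only hand back results someone else already computed.
class CachedAnalyses {
  std::unordered_map<std::string, void *> Results;
public:
  void record(const std::string &Key, void *Result) { Results[Key] = Result; }
  template <typename T> T *getCached(const std::string &Key) const {
    auto It = Results.find(Key);
    return It == Results.end() ? nullptr : static_cast<T *>(It->second);
  }
};

bool runLoopPass(const CachedAnalyses &FAM) {
  auto *DT = FAM.getCached<DomTreeStub>("domtree");
  // The pipeline must have run the function analyses beforehand.
  assert(DT && "analyses for the loop pass not available");
  return DT != nullptr;  // the real pass would use *DT to transform the loop
}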
@@ -200,15 +229,6 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L, LPPassManager &LPM) { if (Name == "memset" || Name == "memcpy") return false; - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *CurLoop->getHeader()->getParent()); - DL = &CurLoop->getHeader()->getModule()->getDataLayout(); - HasMemset = TLI->has(LibFunc::memset); HasMemsetPattern = TLI->has(LibFunc::memset_pattern16); HasMemcpy = TLI->has(LibFunc::memcpy); @@ -240,6 +260,14 @@ bool LoopIdiomRecognize::runOnCountableLoop() { << CurLoop->getHeader()->getName() << "\n"); bool MadeChange = false; + + // The following transforms hoist stores/memsets into the loop pre-header. + // Give up if the loop has instructions may throw. + LoopSafetyInfo SafetyInfo; + computeLoopSafetyInfo(&SafetyInfo, CurLoop); + if (SafetyInfo.MayThrow) + return MadeChange; + // Scan all the blocks in the loop that are not in subloops. for (auto *BB : CurLoop->getBlocks()) { // Ignore blocks in subloops. @@ -258,9 +286,9 @@ static unsigned getStoreSizeInBytes(StoreInst *SI, const DataLayout *DL) { return (unsigned)SizeInBits >> 3; } -static unsigned getStoreStride(const SCEVAddRecExpr *StoreEv) { +static APInt getStoreStride(const SCEVAddRecExpr *StoreEv) { const SCEVConstant *ConstStride = cast<SCEVConstant>(StoreEv->getOperand(1)); - return ConstStride->getAPInt().getZExtValue(); + return ConstStride->getAPInt(); } /// getMemSetPatternValue - If a strided store of the specified value is safe to @@ -305,11 +333,15 @@ static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) { } bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, - bool &ForMemcpy) { + bool &ForMemsetPattern, bool &ForMemcpy) { // Don't touch volatile stores. if (!SI->isSimple()) return false; + // Avoid merging nontemporal stores. + if (SI->getMetadata(LLVMContext::MD_nontemporal)) + return false; + Value *StoredVal = SI->getValueOperand(); Value *StorePtr = SI->getPointerOperand(); @@ -353,7 +385,7 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, StorePtr->getType()->getPointerAddressSpace() == 0 && (PatternValue = getMemSetPatternValue(StoredVal, DL))) { // It looks like we can use PatternValue! - ForMemset = true; + ForMemsetPattern = true; return true; } @@ -361,7 +393,7 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, if (HasMemcpy) { // Check to see if the stride matches the size of the store. If so, then we // know that every byte is touched in the loop. - unsigned Stride = getStoreStride(StoreEv); + APInt Stride = getStoreStride(StoreEv); unsigned StoreSize = getStoreSizeInBytes(SI, DL); if (StoreSize != Stride && StoreSize != -Stride) return false; @@ -393,6 +425,7 @@ bool LoopIdiomRecognize::isLegalStore(StoreInst *SI, bool &ForMemset, void LoopIdiomRecognize::collectStores(BasicBlock *BB) { StoreRefsForMemset.clear(); + StoreRefsForMemsetPattern.clear(); StoreRefsForMemcpy.clear(); for (Instruction &I : *BB) { StoreInst *SI = dyn_cast<StoreInst>(&I); @@ -400,15 +433,22 @@ void LoopIdiomRecognize::collectStores(BasicBlock *BB) { continue; bool ForMemset = false; + bool ForMemsetPattern = false; bool ForMemcpy = false; // Make sure this is a strided store with a constant stride. 
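The runOnCountableLoop change above gives up early when the loop contains instructions that may throw, because the transforms below hoist stores and memsets into the preheader where they would execute unconditionally. A minimal standalone version of that guard, with a stand-in instruction type:

#include <vector>

struct Instr { bool MayThrow = false; };  // stand-in for an IR instruction

// If any instruction in the loop may throw, a hoisted store could run on a
// path where the original loop would have thrown first, so refuse to hoist.
bool safeToHoistStores(const std::vector<Instr> &LoopBody) {
  for (const Instr &I : LoopBody)
    if (I.MayThrow)
      return false;
  return true;
}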
- if (!isLegalStore(SI, ForMemset, ForMemcpy)) + if (!isLegalStore(SI, ForMemset, ForMemsetPattern, ForMemcpy)) continue; // Save the store locations. - if (ForMemset) - StoreRefsForMemset.push_back(SI); - else if (ForMemcpy) + if (ForMemset) { + // Find the base pointer. + Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), *DL); + StoreRefsForMemset[Ptr].push_back(SI); + } else if (ForMemsetPattern) { + // Find the base pointer. + Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), *DL); + StoreRefsForMemsetPattern[Ptr].push_back(SI); + } else if (ForMemcpy) StoreRefsForMemcpy.push_back(SI); } } @@ -430,9 +470,14 @@ bool LoopIdiomRecognize::runOnLoopBlock( // Look for store instructions, which may be optimized to memset/memcpy. collectStores(BB); - // Look for a single store which can be optimized into a memset. - for (auto &SI : StoreRefsForMemset) - MadeChange |= processLoopStore(SI, BECount); + // Look for a single store or sets of stores with a common base, which can be + // optimized into a memset (memset_pattern). The latter most commonly happens + // with structs and handunrolled loops. + for (auto &SL : StoreRefsForMemset) + MadeChange |= processLoopStores(SL.second, BECount, true); + + for (auto &SL : StoreRefsForMemsetPattern) + MadeChange |= processLoopStores(SL.second, BECount, false); // Optimize the store into a memcpy, if it feeds an similarly strided load. for (auto &SI : StoreRefsForMemcpy) @@ -458,26 +503,144 @@ bool LoopIdiomRecognize::runOnLoopBlock( return MadeChange; } -/// processLoopStore - See if this store can be promoted to a memset. -bool LoopIdiomRecognize::processLoopStore(StoreInst *SI, const SCEV *BECount) { - assert(SI->isSimple() && "Expected only non-volatile stores."); +/// processLoopStores - See if this store(s) can be promoted to a memset. +bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, + const SCEV *BECount, + bool ForMemset) { + // Try to find consecutive stores that can be transformed into memsets. + SetVector<StoreInst *> Heads, Tails; + SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain; + + // Do a quadratic search on all of the given stores and find + // all of the pairs of stores that follow each other. + SmallVector<unsigned, 16> IndexQueue; + for (unsigned i = 0, e = SL.size(); i < e; ++i) { + assert(SL[i]->isSimple() && "Expected only non-volatile stores."); + + Value *FirstStoredVal = SL[i]->getValueOperand(); + Value *FirstStorePtr = SL[i]->getPointerOperand(); + const SCEVAddRecExpr *FirstStoreEv = + cast<SCEVAddRecExpr>(SE->getSCEV(FirstStorePtr)); + APInt FirstStride = getStoreStride(FirstStoreEv); + unsigned FirstStoreSize = getStoreSizeInBytes(SL[i], DL); + + // See if we can optimize just this store in isolation. + if (FirstStride == FirstStoreSize || -FirstStride == FirstStoreSize) { + Heads.insert(SL[i]); + continue; + } - Value *StoredVal = SI->getValueOperand(); - Value *StorePtr = SI->getPointerOperand(); + Value *FirstSplatValue = nullptr; + Constant *FirstPatternValue = nullptr; - // Check to see if the stride matches the size of the store. If so, then we - // know that every byte is touched in the loop. 
- const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); - unsigned Stride = getStoreStride(StoreEv); - unsigned StoreSize = getStoreSizeInBytes(SI, DL); - if (StoreSize != Stride && StoreSize != -Stride) - return false; + if (ForMemset) + FirstSplatValue = isBytewiseValue(FirstStoredVal); + else + FirstPatternValue = getMemSetPatternValue(FirstStoredVal, DL); + + assert((FirstSplatValue || FirstPatternValue) && + "Expected either splat value or pattern value."); + + IndexQueue.clear(); + // If a store has multiple consecutive store candidates, search Stores + // array according to the sequence: from i+1 to e, then from i-1 to 0. + // This is because usually pairing with immediate succeeding or preceding + // candidate create the best chance to find memset opportunity. + unsigned j = 0; + for (j = i + 1; j < e; ++j) + IndexQueue.push_back(j); + for (j = i; j > 0; --j) + IndexQueue.push_back(j - 1); + + for (auto &k : IndexQueue) { + assert(SL[k]->isSimple() && "Expected only non-volatile stores."); + Value *SecondStorePtr = SL[k]->getPointerOperand(); + const SCEVAddRecExpr *SecondStoreEv = + cast<SCEVAddRecExpr>(SE->getSCEV(SecondStorePtr)); + APInt SecondStride = getStoreStride(SecondStoreEv); + + if (FirstStride != SecondStride) + continue; - bool NegStride = StoreSize == -Stride; + Value *SecondStoredVal = SL[k]->getValueOperand(); + Value *SecondSplatValue = nullptr; + Constant *SecondPatternValue = nullptr; + + if (ForMemset) + SecondSplatValue = isBytewiseValue(SecondStoredVal); + else + SecondPatternValue = getMemSetPatternValue(SecondStoredVal, DL); + + assert((SecondSplatValue || SecondPatternValue) && + "Expected either splat value or pattern value."); + + if (isConsecutiveAccess(SL[i], SL[k], *DL, *SE, false)) { + if (ForMemset) { + if (FirstSplatValue != SecondSplatValue) + continue; + } else { + if (FirstPatternValue != SecondPatternValue) + continue; + } + Tails.insert(SL[k]); + Heads.insert(SL[i]); + ConsecutiveChain[SL[i]] = SL[k]; + break; + } + } + } + + // We may run into multiple chains that merge into a single chain. We mark the + // stores that we transformed so that we don't visit the same store twice. + SmallPtrSet<Value *, 16> TransformedStores; + bool Changed = false; + + // For stores that start but don't end a link in the chain: + for (SetVector<StoreInst *>::iterator it = Heads.begin(), e = Heads.end(); + it != e; ++it) { + if (Tails.count(*it)) + continue; + + // We found a store instr that starts a chain. Now follow the chain and try + // to transform it. + SmallPtrSet<Instruction *, 8> AdjacentStores; + StoreInst *I = *it; + + StoreInst *HeadStore = I; + unsigned StoreSize = 0; + + // Collect the chain into a list. + while (Tails.count(I) || Heads.count(I)) { + if (TransformedStores.count(I)) + break; + AdjacentStores.insert(I); + + StoreSize += getStoreSizeInBytes(I, DL); + // Move to the next value in the chain. + I = ConsecutiveChain[I]; + } - // See if we can optimize just this store in isolation. - return processLoopStridedStore(StorePtr, StoreSize, SI->getAlignment(), - StoredVal, SI, StoreEv, BECount, NegStride); + Value *StoredVal = HeadStore->getValueOperand(); + Value *StorePtr = HeadStore->getPointerOperand(); + const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); + APInt Stride = getStoreStride(StoreEv); + + // Check to see if the stride matches the size of the stores. If so, then + // we know that every byte is touched in the loop. 
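The new processLoopStores above links stores that write adjacent locations into chains (Heads, Tails, ConsecutiveChain) and only forms a memset when the combined size of a chain matches the per-iteration stride, so every byte of the strided region is covered. A simplified standalone sketch of the chain-walking step, assuming the pairing has already been computed and the chain links are acyclic (stand-in types, not LLVM's):

#include <map>
#include <set>
#include <vector>

struct StoreDesc { int Id; unsigned SizeInBytes; };

// Follow each chain from its head, sum the bytes written, and accept the
// chain only if the total equals the stride of one loop iteration.
std::vector<int> chainsCoveringStride(const std::vector<StoreDesc> &Stores,
                                      const std::map<int, int> &NextInChain,
                                      const std::set<int> &Tails,
                                      unsigned StrideInBytes) {
  std::map<int, StoreDesc> ById;
  for (const StoreDesc &S : Stores)
    ById[S.Id] = S;

  std::vector<int> HeadsToTransform;
  for (const StoreDesc &S : Stores) {
    if (Tails.count(S.Id))       // not the start of a chain
      continue;
    unsigned Total = 0;
    int Cur = S.Id;
    while (ById.count(Cur)) {    // walk the chain to its end
      Total += ById[Cur].SizeInBytes;
      auto It = NextInChain.find(Cur);
      if (It == NextInChain.end())
        break;
      Cur = It->second;
    }
    if (Total == StrideInBytes)  // the chain touches every byte per iteration
      HeadsToTransform.push_back(S.Id);
  }
  return HeadsToTransform;
}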
+ if (StoreSize != Stride && StoreSize != -Stride) + continue; + + bool NegStride = StoreSize == -Stride; + + if (processLoopStridedStore(StorePtr, StoreSize, HeadStore->getAlignment(), + StoredVal, HeadStore, AdjacentStores, StoreEv, + BECount, NegStride)) { + TransformedStores.insert(AdjacentStores.begin(), AdjacentStores.end()); + Changed = true; + } + } + + return Changed; } /// processLoopMemSet - See if this memset can be promoted to a large memset. @@ -488,7 +651,7 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, return false; // If we're not allowed to hack on memset, we fail. - if (!TLI->has(LibFunc::memset)) + if (!HasMemset) return false; Value *Pointer = MSI->getDest(); @@ -507,11 +670,12 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, // Check to see if the stride matches the size of the memset. If so, then we // know that every byte is touched in the loop. - const SCEVConstant *Stride = dyn_cast<SCEVConstant>(Ev->getOperand(1)); + const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1)); + if (!ConstStride) + return false; - // TODO: Could also handle negative stride here someday, that will require the - // validity check in mayLoopAccessLocation to be updated though. - if (!Stride || MSI->getLength() != Stride->getValue()) + APInt Stride = ConstStride->getAPInt(); + if (SizeInBytes != Stride && SizeInBytes != -Stride) return false; // Verify that the memset value is loop invariant. If not, we can't promote @@ -520,18 +684,22 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, if (!SplatValue || !CurLoop->isLoopInvariant(SplatValue)) return false; + SmallPtrSet<Instruction *, 1> MSIs; + MSIs.insert(MSI); + bool NegStride = SizeInBytes == -Stride; return processLoopStridedStore(Pointer, (unsigned)SizeInBytes, - MSI->getAlignment(), SplatValue, MSI, Ev, - BECount, /*NegStride=*/false); + MSI->getAlignment(), SplatValue, MSI, MSIs, Ev, + BECount, NegStride); } /// mayLoopAccessLocation - Return true if the specified loop might access the /// specified pointer location, which is a loop-strided access. The 'Access' /// argument specifies what the verboten forms of access are (read or write). -static bool mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, - const SCEV *BECount, unsigned StoreSize, - AliasAnalysis &AA, - Instruction *IgnoredStore) { +static bool +mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, + const SCEV *BECount, unsigned StoreSize, + AliasAnalysis &AA, + SmallPtrSetImpl<Instruction *> &IgnoredStores) { // Get the location that may be stored across the loop. Since the access is // strided positively through memory, we say that the modified location starts // at the pointer and has infinite size. @@ -550,8 +718,9 @@ static bool mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, for (Loop::block_iterator BI = L->block_begin(), E = L->block_end(); BI != E; ++BI) - for (BasicBlock::iterator I = (*BI)->begin(), E = (*BI)->end(); I != E; ++I) - if (&*I != IgnoredStore && (AA.getModRefInfo(&*I, StoreLoc) & Access)) + for (Instruction &I : **BI) + if (IgnoredStores.count(&I) == 0 && + (AA.getModRefInfo(&I, StoreLoc) & Access)) return true; return false; @@ -574,7 +743,8 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, /// transform this into a memset or memset_pattern in the loop preheader, do so. 
bool LoopIdiomRecognize::processLoopStridedStore( Value *DestPtr, unsigned StoreSize, unsigned StoreAlignment, - Value *StoredVal, Instruction *TheStore, const SCEVAddRecExpr *Ev, + Value *StoredVal, Instruction *TheStore, + SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev, const SCEV *BECount, bool NegStride) { Value *SplatValue = isBytewiseValue(StoredVal); Constant *PatternValue = nullptr; @@ -609,7 +779,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( Value *BasePtr = Expander.expandCodeFor(Start, DestInt8PtrTy, Preheader->getTerminator()); if (mayLoopAccessLocation(BasePtr, MRI_ModRef, CurLoop, BECount, StoreSize, - *AA, TheStore)) { + *AA, Stores)) { Expander.clear(); // If we generated new code for the base pointer, clean up. RecursivelyDeleteTriviallyDeadInstructions(BasePtr, TLI); @@ -644,13 +814,14 @@ bool LoopIdiomRecognize::processLoopStridedStore( Value *MSP = M->getOrInsertFunction("memset_pattern16", Builder.getVoidTy(), Int8PtrTy, Int8PtrTy, IntPtr, (void *)nullptr); + inferLibFuncAttributes(*M->getFunction("memset_pattern16"), *TLI); // Otherwise we should form a memset_pattern16. PatternValue is known to be // an constant array of 16-bytes. Plop the value into a mergable global. GlobalVariable *GV = new GlobalVariable(*M, PatternValue->getType(), true, GlobalValue::PrivateLinkage, PatternValue, ".memset_pattern"); - GV->setUnnamedAddr(true); // Ok to merge these. + GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these. GV->setAlignment(16); Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy); NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); @@ -662,7 +833,8 @@ bool LoopIdiomRecognize::processLoopStridedStore( // Okay, the memset has been formed. Zap the original store and anything that // feeds into it. - deleteDeadInstruction(TheStore, TLI); + for (auto *I : Stores) + deleteDeadInstruction(I); ++NumMemSet; return true; } @@ -676,7 +848,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, Value *StorePtr = SI->getPointerOperand(); const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); - unsigned Stride = getStoreStride(StoreEv); + APInt Stride = getStoreStride(StoreEv); unsigned StoreSize = getStoreSizeInBytes(SI, DL); bool NegStride = StoreSize == -Stride; @@ -714,8 +886,10 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, Value *StoreBasePtr = Expander.expandCodeFor( StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator()); + SmallPtrSet<Instruction *, 1> Stores; + Stores.insert(SI); if (mayLoopAccessLocation(StoreBasePtr, MRI_ModRef, CurLoop, BECount, - StoreSize, *AA, SI)) { + StoreSize, *AA, Stores)) { Expander.clear(); // If we generated new code for the base pointer, clean up. RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); @@ -735,7 +909,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator()); if (mayLoopAccessLocation(LoadBasePtr, MRI_Mod, CurLoop, BECount, StoreSize, - *AA, SI)) { + *AA, Stores)) { Expander.clear(); // If we generated new code for the base pointer, clean up. RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI); @@ -769,7 +943,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // Okay, the memcpy has been formed. Zap the original store and anything that // feeds into it. 
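Because a whole chain of stores is now replaced at once, mayLoopAccessLocation takes the full set of stores being removed rather than a single instruction, so none of the chained stores is misreported as a conflicting access. A standalone sketch of the scan-with-ignore-set pattern, with stand-in types instead of LLVM's alias-analysis interface:

#include <unordered_set>
#include <vector>

struct Instr {
  int Id;
  bool TouchesTarget;  // stand-in for "alias analysis says may read/write it"
};

// Return true if any instruction in the loop, other than the ones we are
// about to delete, may access the memory the new memset/memcpy will write.
bool otherAccessToTarget(const std::vector<Instr> &LoopBody,
                         const std::unordered_set<int> &IgnoredStores) {
  for (const Instr &I : LoopBody)
    if (IgnoredStores.count(I.Id) == 0 && I.TouchesTarget)
      return true;
  return false;
}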
- deleteDeadInstruction(SI, TLI); + deleteDeadInstruction(SI); ++NumMemCpy; return true; } @@ -993,7 +1167,7 @@ bool LoopIdiomRecognize::recognizePopcount() { } static CallInst *createPopcntIntrinsic(IRBuilder<> &IRBuilder, Value *Val, - DebugLoc DL) { + const DebugLoc &DL) { Value *Ops[] = {Val}; Type *Tys[] = {Val->getType()}; diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp index b4102fe9ba34..629cb87d7a91 100644 --- a/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -11,88 +11,43 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopInstSimplify.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/Support/Debug.h" -#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; #define DEBUG_TYPE "loop-instsimplify" STATISTIC(NumSimplified, "Number of redundant instructions simplified"); -namespace { - class LoopInstSimplify : public LoopPass { - public: - static char ID; // Pass ID, replacement for typeid - LoopInstSimplify() : LoopPass(ID) { - initializeLoopInstSimplifyPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop*, LPPassManager&) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addPreservedID(LoopSimplifyID); - AU.addPreservedID(LCSSAID); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } - }; -} - -char LoopInstSimplify::ID = 0; -INITIALIZE_PASS_BEGIN(LoopInstSimplify, "loop-instsimplify", - "Simplify instructions in loops", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LCSSA) -INITIALIZE_PASS_END(LoopInstSimplify, "loop-instsimplify", - "Simplify instructions in loops", false, false) - -Pass *llvm::createLoopInstSimplifyPass() { - return new LoopInstSimplify(); -} - -bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipOptnoneFunction(L)) - return false; - - DominatorTreeWrapperPass *DTWP = - getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DominatorTree *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; - LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( - *L->getHeader()->getParent()); - - SmallVector<BasicBlock*, 8> ExitBlocks; +static bool SimplifyLoopInst(Loop *L, DominatorTree *DT, LoopInfo *LI, + AssumptionCache *AC, + const TargetLibraryInfo *TLI) { + SmallVector<BasicBlock *, 8> ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); array_pod_sort(ExitBlocks.begin(), ExitBlocks.end()); - SmallPtrSet<const Instruction*, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; + SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; // The bit we are stealing from the pointer represents whether this basic // block is the header of a subloop, in which case we only process its phis. - typedef PointerIntPair<BasicBlock*, 1> WorklistItem; + typedef PointerIntPair<BasicBlock *, 1> WorklistItem; SmallVector<WorklistItem, 16> VisitStack; - SmallPtrSet<BasicBlock*, 32> Visited; + SmallPtrSet<BasicBlock *, 32> Visited; bool Changed = false; bool LocalChanged; @@ -122,7 +77,7 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { // Don't bother simplifying unused instructions. if (!I->use_empty()) { - Value *V = SimplifyInstruction(I, DL, TLI, DT, &AC); + Value *V = SimplifyInstruction(I, DL, TLI, DT, AC); if (V && LI->replacementPreservesLCSSAForm(I, V)) { // Mark all uses for resimplification next time round the loop. for (User *U : I->users()) @@ -133,14 +88,13 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { ++NumSimplified; } } - bool res = RecursivelyDeleteTriviallyDeadInstructions(I, TLI); - if (res) { - // RecursivelyDeleteTriviallyDeadInstruction can remove - // more than one instruction, so simply incrementing the - // iterator does not work. When instructions get deleted - // re-iterate instead. - BI = BB->begin(); BE = BB->end(); - LocalChanged |= res; + if (RecursivelyDeleteTriviallyDeadInstructions(I, TLI)) { + // RecursivelyDeleteTriviallyDeadInstruction can remove more than one + // instruction, so simply incrementing the iterator does not work. + // When instructions get deleted re-iterate instead. + BI = BB->begin(); + BE = BB->end(); + LocalChanged = true; } if (IsSubloopHeader && !isa<PHINode>(I)) @@ -148,8 +102,10 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { } // Add all successors to the worklist, except for loop exit blocks and the - // bodies of subloops. We visit the headers of loops so that we can process - // their phis, but we contract the rest of the subloop body and only follow + // bodies of subloops. We visit the headers of loops so that we can + // process + // their phis, but we contract the rest of the subloop body and only + // follow // edges leading back to the original loop. 
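The extracted SimplifyLoopInst walks the loop body with an explicit worklist whose entries carry one extra bit, packed into a PointerIntPair: whether the block is the header of a subloop, in which case only its phis are processed. A standalone sketch of that tagged-worklist walk, using a plain pair and integer block ids that index the block vector (illustrative only; the real code also re-enters through subloop exit blocks):

#include <unordered_set>
#include <utility>
#include <vector>

struct Block {
  std::vector<int> Succs;
  bool IsSubloopHeader;
};

// Visit every reachable block once; for subloop headers the real pass only
// touches the phi nodes and does not descend into the subloop body.
void walk(const std::vector<Block> &Blocks, int Entry) {
  std::vector<std::pair<int, bool>> Stack{{Entry, Blocks[Entry].IsSubloopHeader}};
  std::unordered_set<int> Visited;
  while (!Stack.empty()) {
    auto [Id, OnlyPhis] = Stack.back();
    Stack.pop_back();
    if (!Visited.insert(Id).second)
      continue;
    // ... simplify instructions in Blocks[Id]; stop after phis if OnlyPhis ...
    if (OnlyPhis)
      continue;  // do not follow edges into the subloop body
    for (int S : Blocks[Id].Succs)
      Stack.push_back({S, Blocks[S].IsSubloopHeader});
  }
}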
for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) { @@ -158,11 +114,11 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { continue; const Loop *SuccLoop = LI->getLoopFor(SuccBB); - if (SuccLoop && SuccLoop->getHeader() == SuccBB - && L->contains(SuccLoop)) { + if (SuccLoop && SuccLoop->getHeader() == SuccBB && + L->contains(SuccLoop)) { VisitStack.push_back(WorklistItem(SuccBB, true)); - SmallVector<BasicBlock*, 8> SubLoopExitBlocks; + SmallVector<BasicBlock *, 8> SubLoopExitBlocks; SuccLoop->getExitBlocks(SubLoopExitBlocks); for (unsigned i = 0; i < SubLoopExitBlocks.size(); ++i) { @@ -174,8 +130,8 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { continue; } - bool IsExitBlock = std::binary_search(ExitBlocks.begin(), - ExitBlocks.end(), SuccBB); + bool IsExitBlock = + std::binary_search(ExitBlocks.begin(), ExitBlocks.end(), SuccBB); if (IsExitBlock) continue; @@ -193,3 +149,68 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { return Changed; } + +namespace { +class LoopInstSimplifyLegacyPass : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + LoopInstSimplifyLegacyPass() : LoopPass(ID) { + initializeLoopInstSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + DominatorTreeWrapperPass *DTWP = + getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + AssumptionCache *AC = + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + + return SimplifyLoopInst(L, DT, LI, AC, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.setPreservesCFG(); + getLoopAnalysisUsage(AU); + } +}; +} + +PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, + AnalysisManager<Loop> &AM) { + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function *F = L.getHeader()->getParent(); + + // Use getCachedResult because Loop pass cannot trigger a function analysis. 
+ auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); + auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); + auto *AC = FAM.getCachedResult<AssumptionAnalysis>(*F); + const auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(*F); + assert((LI && AC && TLI) && "Analyses for Loop Inst Simplify not available"); + + if (!SimplifyLoopInst(&L, DT, LI, AC, TLI)) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +char LoopInstSimplifyLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopInstSimplifyLegacyPass, "loop-instsimplify", + "Simplify instructions in loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(LoopInstSimplifyLegacyPass, "loop-instsimplify", + "Simplify instructions in loops", false, false) + +Pass *llvm::createLoopInstSimplifyPass() { + return new LoopInstSimplifyLegacyPass(); +} diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp index 4295235a3f36..9241ec365277 100644 --- a/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/lib/Transforms/Scalar/LoopInterchange.cpp @@ -15,7 +15,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AliasSetTracker.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CodeMetrics.h" @@ -72,7 +71,7 @@ void printDepMatrix(CharMatrix &DepMatrix) { #endif static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, - Loop *L, DependenceAnalysis *DA) { + Loop *L, DependenceInfo *DI) { typedef SmallVector<Value *, 16> ValueVector; ValueVector MemInstr; @@ -117,7 +116,7 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, continue; if (isa<LoadInst>(Src) && isa<LoadInst>(Des)) continue; - if (auto D = DA->depends(Src, Des, true)) { + if (auto D = DI->depends(Src, Des, true)) { DEBUG(dbgs() << "Found Dependency between Src=" << Src << " Des=" << Des << "\n"); if (D->isFlow()) { @@ -404,12 +403,9 @@ public: private: void splitInnerLoopLatch(Instruction *); - void splitOuterLoopLatch(); void splitInnerLoopHeader(); bool adjustLoopLinks(); void adjustLoopPreheaders(); - void adjustOuterLoopPreheader(); - void adjustInnerLoopPreheader(); bool adjustLoopBranches(); void updateIncomingBlock(BasicBlock *CurrBlock, BasicBlock *OldPred, BasicBlock *NewPred); @@ -430,11 +426,11 @@ struct LoopInterchange : public FunctionPass { static char ID; ScalarEvolution *SE; LoopInfo *LI; - DependenceAnalysis *DA; + DependenceInfo *DI; DominatorTree *DT; bool PreserveLCSSA; LoopInterchange() - : FunctionPass(ID), SE(nullptr), LI(nullptr), DA(nullptr), DT(nullptr) { + : FunctionPass(ID), SE(nullptr), LI(nullptr), DI(nullptr), DT(nullptr) { initializeLoopInterchangePass(*PassRegistry::getPassRegistry()); } @@ -443,15 +439,18 @@ struct LoopInterchange : public FunctionPass { AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequired<DependenceAnalysis>(); + AU.addRequired<DependenceAnalysisWrapperPass>(); AU.addRequiredID(LoopSimplifyID); AU.addRequiredID(LCSSAID); } bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - DA = &getAnalysis<DependenceAnalysis>(); + DI = 
&getAnalysis<DependenceAnalysisWrapperPass>().getDI(); auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); DT = DTWP ? &DTWP->getDomTree() : nullptr; PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); @@ -472,8 +471,7 @@ struct LoopInterchange : public FunctionPass { } bool isComputableLoopNest(LoopVector LoopList) { - for (auto I = LoopList.begin(), E = LoopList.end(); I != E; ++I) { - Loop *L = *I; + for (Loop *L : LoopList) { const SCEV *ExitCountOuter = SE->getBackedgeTakenCount(L); if (ExitCountOuter == SE->getCouldNotCompute()) { DEBUG(dbgs() << "Couldn't compute Backedge count\n"); @@ -491,7 +489,7 @@ struct LoopInterchange : public FunctionPass { return true; } - unsigned selectLoopForInterchange(LoopVector LoopList) { + unsigned selectLoopForInterchange(const LoopVector &LoopList) { // TODO: Add a better heuristic to select the loop to be interchanged based // on the dependence matrix. Currently we select the innermost loop. return LoopList.size() - 1; @@ -515,7 +513,7 @@ struct LoopInterchange : public FunctionPass { << "\n"); if (!populateDependencyMatrix(DependencyMatrix, LoopList.size(), - OuterMostLoop, DA)) { + OuterMostLoop, DI)) { DEBUG(dbgs() << "Populating Dependency matrix failed\n"); return false; } @@ -813,7 +811,6 @@ bool LoopInterchangeLegality::currentLimitations() { // A[j+1][i+2] = A[j][i]+k; // } // } - bool FoundInduction = false; Instruction *InnerIndexVarInc = nullptr; if (InnerInductionVar->getIncomingBlock(0) == InnerLoopPreHeader) InnerIndexVarInc = @@ -829,17 +826,17 @@ bool LoopInterchangeLegality::currentLimitations() { // we do not have any instruction between the induction variable and branch // instruction. - for (auto I = InnerLoopLatch->rbegin(), E = InnerLoopLatch->rend(); - I != E && !FoundInduction; ++I) { - if (isa<BranchInst>(*I) || isa<CmpInst>(*I) || isa<TruncInst>(*I)) + bool FoundInduction = false; + for (const Instruction &I : reverse(*InnerLoopLatch)) { + if (isa<BranchInst>(I) || isa<CmpInst>(I) || isa<TruncInst>(I)) continue; - const Instruction &Ins = *I; // We found an instruction. If this is not induction variable then it is not // safe to split this loop latch. - if (!Ins.isIdenticalTo(InnerIndexVarInc)) + if (!I.isIdenticalTo(InnerIndexVarInc)) return true; - else - FoundInduction = true; + + FoundInduction = true; + break; } // The loop latch ended and we didn't find the induction variable return as // current limitation. @@ -903,8 +900,7 @@ int LoopInterchangeProfitability::getInstrOrderCost() { BadOrder = GoodOrder = 0; for (auto BI = InnerLoop->block_begin(), BE = InnerLoop->block_end(); BI != BE; ++BI) { - for (auto I = (*BI)->begin(), E = (*BI)->end(); I != E; ++I) { - const Instruction &Ins = *I; + for (Instruction &Ins : **BI) { if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&Ins)) { unsigned NumOp = GEP->getNumOperands(); bool FoundInnerInduction = false; @@ -1073,13 +1069,6 @@ void LoopInterchangeTransform::splitInnerLoopLatch(Instruction *Inc) { InnerLoopLatch = SplitBlock(InnerLoopLatchPred, Inc, DT, LI); } -void LoopInterchangeTransform::splitOuterLoopLatch() { - BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); - BasicBlock *OuterLatchLcssaPhiBlock = OuterLoopLatch; - OuterLoopLatch = SplitBlock(OuterLatchLcssaPhiBlock, - OuterLoopLatch->getFirstNonPHI(), DT, LI); -} - void LoopInterchangeTransform::splitInnerLoopHeader() { // Split the inner loop header out. 
Here make sure that the reduction PHI's @@ -1102,8 +1091,7 @@ void LoopInterchangeTransform::splitInnerLoopHeader() { PHI->replaceAllUsesWith(V); PHIVec.push_back((PHI)); } - for (auto I = PHIVec.begin(), E = PHIVec.end(); I != E; ++I) { - PHINode *P = *I; + for (PHINode *P : PHIVec) { P->eraseFromParent(); } } else { @@ -1124,20 +1112,6 @@ static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { FromBB->getTerminator()->getIterator()); } -void LoopInterchangeTransform::adjustOuterLoopPreheader() { - BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); - BasicBlock *InnerPreHeader = InnerLoop->getLoopPreheader(); - - moveBBContents(OuterLoopPreHeader, InnerPreHeader->getTerminator()); -} - -void LoopInterchangeTransform::adjustInnerLoopPreheader() { - BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); - BasicBlock *OuterHeader = OuterLoop->getHeader(); - - moveBBContents(InnerLoopPreHeader, OuterHeader->getTerminator()); -} - void LoopInterchangeTransform::updateIncomingBlock(BasicBlock *CurrBlock, BasicBlock *OldPred, BasicBlock *NewPred) { @@ -1234,8 +1208,7 @@ bool LoopInterchangeTransform::adjustLoopBranches() { PHINode *LcssaPhi = cast<PHINode>(I); LcssaVec.push_back(LcssaPhi); } - for (auto I = LcssaVec.begin(), E = LcssaVec.end(); I != E; ++I) { - PHINode *P = *I; + for (PHINode *P : LcssaVec) { Value *Incoming = P->getIncomingValueForBlock(InnerLoopLatch); P->replaceAllUsesWith(Incoming); P->eraseFromParent(); @@ -1294,11 +1267,11 @@ char LoopInterchange::ID = 0; INITIALIZE_PASS_BEGIN(LoopInterchange, "loop-interchange", "Interchanges loops for cache reuse", false, false) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DependenceAnalysis) +INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSA) +INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(LoopInterchange, "loop-interchange", diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp index 1064d088514d..f29228c7659e 100644 --- a/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -28,6 +28,7 @@ #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/LoopVersioning.h" #include <forward_list> @@ -61,7 +62,8 @@ struct StoreToLoadForwardingCandidate { /// \brief Return true if the dependence from the store to the load has a /// distance of one. E.g. A[i+1] = A[i] - bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE) const { + bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, + Loop *L) const { Value *LoadPtr = Load->getPointerOperand(); Value *StorePtr = Store->getPointerOperand(); Type *LoadPtrType = LoadPtr->getType(); @@ -72,6 +74,13 @@ struct StoreToLoadForwardingCandidate { LoadType == StorePtr->getType()->getPointerElementType() && "Should be a known dependence"); + // Currently we only support accesses with unit stride. FIXME: we should be + // able to handle non unit stirde as well as long as the stride is equal to + // the dependence distance. 
+ if (getPtrStride(PSE, LoadPtr, L) != 1 || + getPtrStride(PSE, StorePtr, L) != 1) + return false; + auto &DL = Load->getParent()->getModule()->getDataLayout(); unsigned TypeByteSize = DL.getTypeAllocSize(const_cast<Type *>(LoadType)); @@ -83,7 +92,7 @@ struct StoreToLoadForwardingCandidate { auto *Dist = cast<SCEVConstant>( PSE.getSE()->getMinusSCEV(StorePtrSCEV, LoadPtrSCEV)); const APInt &Val = Dist->getAPInt(); - return Val.abs() == TypeByteSize; + return Val == TypeByteSize; } Value *getLoadPtr() const { return Load->getPointerOperand(); } @@ -110,12 +119,17 @@ bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, }); } +/// \brief Return true if the load is not executed on all paths in the loop. +static bool isLoadConditional(LoadInst *Load, Loop *L) { + return Load->getParent() != L->getHeader(); +} + /// \brief The per-loop class that does most of the work. class LoadEliminationForLoop { public: LoadEliminationForLoop(Loop *L, LoopInfo *LI, const LoopAccessInfo &LAI, DominatorTree *DT) - : L(L), LI(LI), LAI(LAI), DT(DT), PSE(LAI.PSE) {} + : L(L), LI(LI), LAI(LAI), DT(DT), PSE(LAI.getPSE()) {} /// \brief Look through the loop-carried and loop-independent dependences in /// this loop and find store->load dependences. @@ -162,6 +176,12 @@ public: auto *Load = dyn_cast<LoadInst>(Destination); if (!Load) continue; + + // Only progagate the value if they are of the same type. + if (Store->getPointerOperand()->getType() != + Load->getPointerOperand()->getType()) + continue; + Candidates.emplace_front(Load, Store); } @@ -219,12 +239,12 @@ public: if (OtherCand == nullptr) continue; - // Handle the very basic of case when the two stores are in the same - // block so deciding which one forwards is easy. The later one forwards - // as long as they both have a dependence distance of one to the load. + // Handle the very basic case when the two stores are in the same block + // so deciding which one forwards is easy. The later one forwards as + // long as they both have a dependence distance of one to the load. if (Cand.Store->getParent() == OtherCand->Store->getParent() && - Cand.isDependenceDistanceOfOne(PSE) && - OtherCand->isDependenceDistanceOfOne(PSE)) { + Cand.isDependenceDistanceOfOne(PSE, L) && + OtherCand->isDependenceDistanceOfOne(PSE, L)) { // They are in the same block, the later one will forward to the load. if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store)) OtherCand = &Cand; @@ -429,14 +449,21 @@ public: unsigned NumForwarding = 0; for (const StoreToLoadForwardingCandidate Cand : StoreToLoadDependences) { DEBUG(dbgs() << "Candidate " << Cand); + // Make sure that the stored values is available everywhere in the loop in // the next iteration. if (!doesStoreDominatesAllLatches(Cand.Store->getParent(), L, DT)) continue; + // If the load is conditional we can't hoist its 0-iteration instance to + // the preheader because that would make it unconditional. Thus we would + // access a memory location that the original loop did not access. + if (isLoadConditional(Cand.Load, L)) + continue; + // Check whether the SCEV difference is the same as the induction step, // thus we load the value in the next iteration. 
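With the hunk above, isDependenceDistanceOfOne additionally requires both pointers to have unit stride and the forward distance to be exactly one element, so the value stored in iteration i is precisely the one loaded in iteration i+1 (the A[i+1] = A[i] shape). A tiny arithmetic sketch of the byte-offset check; the function and parameter names are illustrative:

#include <cstdint>

// The store writes Base + Distance + i*Stride, the load reads Base + i*Stride.
// Forwarding one iteration ahead is only valid when the access is unit-stride
// and the store is exactly one element past the load.
bool canForwardOneIteration(std::int64_t StrideBytes,
                            std::int64_t DistanceBytes,
                            std::int64_t ElementBytes) {
  if (StrideBytes != ElementBytes)      // unit stride, in elements
    return false;
  return DistanceBytes == ElementBytes; // exactly one element apart, and ahead
}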
- if (!Cand.isDependenceDistanceOfOne(PSE)) + if (!Cand.isDependenceDistanceOfOne(PSE, L)) continue; ++NumForwarding; @@ -459,18 +486,25 @@ public: return false; } - if (LAI.PSE.getUnionPredicate().getComplexity() > + if (LAI.getPSE().getUnionPredicate().getComplexity() > LoadElimSCEVCheckThreshold) { DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); return false; } - // Point of no-return, start the transformation. First, version the loop if - // necessary. - if (!Checks.empty() || !LAI.PSE.getUnionPredicate().isAlwaysTrue()) { + if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { + if (L->getHeader()->getParent()->optForSize()) { + DEBUG(dbgs() << "Versioning is needed but not allowed when optimizing " + "for size.\n"); + return false; + } + + // Point of no-return, start the transformation. First, version the loop + // if necessary. + LoopVersioning LV(LAI, L, LI, DT, PSE.getSE(), false); LV.setAliasChecks(std::move(Checks)); - LV.setSCEVChecks(LAI.PSE.getUnionPredicate()); + LV.setSCEVChecks(LAI.getPSE().getUnionPredicate()); LV.versionLoop(); } @@ -508,8 +542,11 @@ public: } bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *LAA = &getAnalysis<LoopAccessAnalysis>(); + auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); // Build up a worklist of inner-loops to vectorize. This is necessary as the @@ -526,7 +563,7 @@ public: // Now walk the identified inner loops. bool Changed = false; for (Loop *L : Worklist) { - const LoopAccessInfo &LAI = LAA->getInfo(L, ValueToValueMap()); + const LoopAccessInfo &LAI = LAA->getInfo(L); // The actual work is performed by LoadEliminationForLoop. 
LoadEliminationForLoop LEL(L, LI, LAI, DT); Changed |= LEL.processLoop(); @@ -537,9 +574,10 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredID(LoopSimplifyID); AU.addRequired<LoopInfoWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<LoopAccessAnalysis>(); + AU.addRequired<LoopAccessLegacyAnalysis>(); AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); @@ -554,9 +592,10 @@ static const char LLE_name[] = "Loop Load Elimination"; INITIALIZE_PASS_BEGIN(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_END(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) namespace llvm { diff --git a/lib/Transforms/Scalar/LoopRerollPass.cpp b/lib/Transforms/Scalar/LoopRerollPass.cpp index 27c2d8824df0..d2f1b66076a6 100644 --- a/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -14,7 +14,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -128,9 +128,8 @@ NumToleratedFailedMatches("reroll-num-tolerated-failed-matches", cl::init(400), namespace { enum IterationLimits { - /// The maximum number of iterations that we'll try and reroll. This - /// has to be less than 25 in order to fit into a SmallBitVector. - IL_MaxRerollIterations = 16, + /// The maximum number of iterations that we'll try and reroll. + IL_MaxRerollIterations = 32, /// The bitvector index used by loop induction variables and other /// instructions that belong to all iterations. IL_All, @@ -147,13 +146,8 @@ namespace { bool runOnLoop(Loop *L, LPPassManager &LPM) override; void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + getLoopAnalysisUsage(AU); } protected: @@ -169,6 +163,9 @@ namespace { // Map between induction variable and its increment DenseMap<Instruction *, int64_t> IVToIncMap; + // For loop with multiple induction variable, remember the one used only to + // control the loop. + Instruction *LoopControlIV; // A chain of isomorphic instructions, identified by a single-use PHI // representing a reduction. Only the last value may be used outside the @@ -356,9 +353,11 @@ namespace { ScalarEvolution *SE, AliasAnalysis *AA, TargetLibraryInfo *TLI, DominatorTree *DT, LoopInfo *LI, bool PreserveLCSSA, - DenseMap<Instruction *, int64_t> &IncrMap) + DenseMap<Instruction *, int64_t> &IncrMap, + Instruction *LoopCtrlIV) : Parent(Parent), L(L), SE(SE), AA(AA), TLI(TLI), DT(DT), LI(LI), - PreserveLCSSA(PreserveLCSSA), IV(IV), IVToIncMap(IncrMap) {} + PreserveLCSSA(PreserveLCSSA), IV(IV), IVToIncMap(IncrMap), + LoopControlIV(LoopCtrlIV) {} /// Stage 1: Find all the DAG roots for the induction variable. 
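(Editorial aside for orientation, not part of the patch: loop rerolling undoes manual unrolling. In the hypothetical loop below, x[i] *= a is the base instruction and the x[i + 1] and x[i + 2] copies are the extra roots that stage 1 collects for the other unrolled iterations; after rerolling, the loop runs three times as many iterations with a single copy of the body. Names and the scale factor are illustrative only:)

    // Hypothetical manually unrolled input, scale factor 3 (n assumed to be a
    // multiple of 3).
    void before(float *x, float a, int n) {
      for (int i = 0; i < n; i += 3) {
        x[i]     *= a;  // base instruction for offset +0
        x[i + 1] *= a;  // root for offset +1
        x[i + 2] *= a;  // root for offset +2
      }
    }

    // Rough shape after -loop-reroll: one copy of the body, trip count scaled up.
    void after(float *x, float a, int n) {
      for (int i = 0; i < n; ++i)
        x[i] *= a;
    }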
bool findRoots(); @@ -370,7 +369,7 @@ namespace { void replace(const SCEV *IterCount); protected: - typedef MapVector<Instruction*, SmallBitVector> UsesTy; + typedef MapVector<Instruction*, BitVector> UsesTy; bool findRootsRecursive(Instruction *IVU, SmallInstructionSet SubsumedInsts); @@ -396,6 +395,8 @@ namespace { bool instrDependsOn(Instruction *I, UsesTy::iterator Start, UsesTy::iterator End); + void replaceIV(Instruction *Inst, Instruction *IV, const SCEV *IterCount); + void updateNonLoopCtrlIncr(); LoopReroll *Parent; @@ -426,8 +427,18 @@ namespace { UsesTy Uses; // Map between induction variable and its increment DenseMap<Instruction *, int64_t> &IVToIncMap; + Instruction *LoopControlIV; }; + // Check if it is a compare-like instruction whose user is a branch + bool isCompareUsedByBranch(Instruction *I) { + auto *TI = I->getParent()->getTerminator(); + if (!isa<BranchInst>(TI) || !isa<CmpInst>(I)) + return false; + return I->hasOneUse() && TI->getOperand(0) == I; + }; + + bool isLoopControlIV(Loop *L, Instruction *IV); void collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs); void collectPossibleReductions(Loop *L, ReductionTracker &Reductions); @@ -438,10 +449,7 @@ namespace { char LoopReroll::ID = 0; INITIALIZE_PASS_BEGIN(LoopReroll, "loop-reroll", "Reroll loops", false, false) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(LoopReroll, "loop-reroll", "Reroll loops", false, false) @@ -460,6 +468,110 @@ static bool hasUsesOutsideLoop(Instruction *I, Loop *L) { return false; } +static const SCEVConstant *getIncrmentFactorSCEV(ScalarEvolution *SE, + const SCEV *SCEVExpr, + Instruction &IV) { + const SCEVMulExpr *MulSCEV = dyn_cast<SCEVMulExpr>(SCEVExpr); + + // If StepRecurrence of a SCEVExpr is a constant (c1 * c2, c2 = sizeof(ptr)), + // Return c1. + if (!MulSCEV && IV.getType()->isPointerTy()) + if (const SCEVConstant *IncSCEV = dyn_cast<SCEVConstant>(SCEVExpr)) { + const PointerType *PTy = cast<PointerType>(IV.getType()); + Type *ElTy = PTy->getElementType(); + const SCEV *SizeOfExpr = + SE->getSizeOfExpr(SE->getEffectiveSCEVType(IV.getType()), ElTy); + if (IncSCEV->getValue()->getValue().isNegative()) { + const SCEV *NewSCEV = + SE->getUDivExpr(SE->getNegativeSCEV(SCEVExpr), SizeOfExpr); + return dyn_cast<SCEVConstant>(SE->getNegativeSCEV(NewSCEV)); + } else { + return dyn_cast<SCEVConstant>(SE->getUDivExpr(SCEVExpr, SizeOfExpr)); + } + } + + if (!MulSCEV) + return nullptr; + + // If StepRecurrence of a SCEVExpr is a c * sizeof(x), where c is constant, + // Return c. + const SCEVConstant *CIncSCEV = nullptr; + for (const SCEV *Operand : MulSCEV->operands()) { + if (const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Operand)) { + CIncSCEV = Constant; + } else if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Operand)) { + Type *AllocTy; + if (!Unknown->isSizeOf(AllocTy)) + break; + } else { + return nullptr; + } + } + return CIncSCEV; +} + +// Check if an IV is only used to control the loop. There are two cases: +// 1. It only has one use which is loop increment, and the increment is only +// used by comparison and the PHI (could has sext with nsw in between), and the +// comparison is only used by branch. +// 2. 
It is used by loop increment and the comparison, the loop increment is +// only used by the PHI, and the comparison is used only by the branch. +bool LoopReroll::isLoopControlIV(Loop *L, Instruction *IV) { + unsigned IVUses = IV->getNumUses(); + if (IVUses != 2 && IVUses != 1) + return false; + + for (auto *User : IV->users()) { + int32_t IncOrCmpUses = User->getNumUses(); + bool IsCompInst = isCompareUsedByBranch(cast<Instruction>(User)); + + // User can only have one or two uses. + if (IncOrCmpUses != 2 && IncOrCmpUses != 1) + return false; + + // Case 1 + if (IVUses == 1) { + // The only user must be the loop increment. + // The loop increment must have two uses. + if (IsCompInst || IncOrCmpUses != 2) + return false; + } + + // Case 2 + if (IVUses == 2 && IncOrCmpUses != 1) + return false; + + // The users of the IV must be a binary operation or a comparison + if (auto *BO = dyn_cast<BinaryOperator>(User)) { + if (BO->getOpcode() == Instruction::Add) { + // Loop Increment + // User of Loop Increment should be either PHI or CMP + for (auto *UU : User->users()) { + if (PHINode *PN = dyn_cast<PHINode>(UU)) { + if (PN != IV) + return false; + } + // Must be a CMP or an ext (of a value with nsw) then CMP + else { + Instruction *UUser = dyn_cast<Instruction>(UU); + // Skip SExt if we are extending an nsw value + // TODO: Allow ZExt too + if (BO->hasNoSignedWrap() && UUser && UUser->getNumUses() == 1 && + isa<SExtInst>(UUser)) + UUser = dyn_cast<Instruction>(*(UUser->user_begin())); + if (!isCompareUsedByBranch(UUser)) + return false; + } + } + } else + return false; + // Compare : can only have one use, and must be branch + } else if (!IsCompInst) + return false; + } + return true; +} + // Collect the list of loop induction variables with respect to which it might // be possible to reroll the loop. 
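(Editorial aside, not part of the patch: the isLoopControlIV check above recognizes an induction variable whose only job is to drive the exit test — its increment feeds back into the PHI and its compare feeds the branch — while some other IV, now possibly a pointer, addresses the data. A hypothetical C++ loop with that shape, illustrative names only:)

    // 'j' is used only to control the loop: its increment feeds the PHI and its
    // compare feeds the branch.  'p' is a pointer induction variable that does
    // the actual addressing; its step is a constant times sizeof(float), which
    // is what getIncrmentFactorSCEV recovers.  n is assumed to be a multiple
    // of 2, so this loop is also a reroll candidate with scale 2.
    void scale(float *p, float a, int n) {
      for (int j = 0; j < n; j += 2) {  // loop-control-only IV
        p[0] *= a;
        p[1] *= a;
        p += 2;                         // pointer IV increment
      }
    }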
void LoopReroll::collectPossibleIVs(Loop *L, @@ -469,7 +581,7 @@ void LoopReroll::collectPossibleIVs(Loop *L, IE = Header->getFirstInsertionPt(); I != IE; ++I) { if (!isa<PHINode>(I)) continue; - if (!I->getType()->isIntegerTy()) + if (!I->getType()->isIntegerTy() && !I->getType()->isPointerTy()) continue; if (const SCEVAddRecExpr *PHISCEV = @@ -478,15 +590,27 @@ void LoopReroll::collectPossibleIVs(Loop *L, continue; if (!PHISCEV->isAffine()) continue; - if (const SCEVConstant *IncSCEV = - dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE))) { - const APInt &AInt = IncSCEV->getAPInt().abs(); + const SCEVConstant *IncSCEV = nullptr; + if (I->getType()->isPointerTy()) + IncSCEV = + getIncrmentFactorSCEV(SE, PHISCEV->getStepRecurrence(*SE), *I); + else + IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE)); + if (IncSCEV) { + const APInt &AInt = IncSCEV->getValue()->getValue().abs(); if (IncSCEV->getValue()->isZero() || AInt.uge(MaxInc)) continue; IVToIncMap[&*I] = IncSCEV->getValue()->getSExtValue(); DEBUG(dbgs() << "LRR: Possible IV: " << *I << " = " << *PHISCEV << "\n"); - PossibleIVs.push_back(&*I); + + if (isLoopControlIV(L, &*I)) { + assert(!LoopControlIV && "Found two loop control only IV"); + LoopControlIV = &(*I); + DEBUG(dbgs() << "LRR: Possible loop control only IV: " << *I << " = " + << *PHISCEV << "\n"); + } else + PossibleIVs.push_back(&*I); } } } @@ -611,9 +735,8 @@ void LoopReroll::DAGRootTracker::collectInLoopUserSet( const SmallInstructionSet &Exclude, const SmallInstructionSet &Final, DenseSet<Instruction *> &Users) { - for (SmallInstructionVector::const_iterator I = Roots.begin(), - IE = Roots.end(); I != IE; ++I) - collectInLoopUserSet(*I, Exclude, Final, Users); + for (Instruction *Root : Roots) + collectInLoopUserSet(Root, Exclude, Final, Users); } static bool isSimpleLoadStore(Instruction *I) { @@ -651,10 +774,12 @@ static bool isSimpleArithmeticOp(User *IVU) { static bool isLoopIncrement(User *U, Instruction *IV) { BinaryOperator *BO = dyn_cast<BinaryOperator>(U); - if (!BO || BO->getOpcode() != Instruction::Add) + + if ((BO && BO->getOpcode() != Instruction::Add) || + (!BO && !isa<GetElementPtrInst>(U))) return false; - for (auto *UU : BO->users()) { + for (auto *UU : U->users()) { PHINode *PN = dyn_cast<PHINode>(UU); if (PN && PN == IV) return true; @@ -1031,6 +1156,33 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { Uses[I].set(IL_All); } + // Make sure we mark loop-control-only PHIs as used in all iterations. See + // comment above LoopReroll::isLoopControlIV for more information. + BasicBlock *Header = L->getHeader(); + if (LoopControlIV && LoopControlIV != IV) { + for (auto *U : LoopControlIV->users()) { + Instruction *IVUser = dyn_cast<Instruction>(U); + // IVUser could be loop increment or compare + Uses[IVUser].set(IL_All); + for (auto *UU : IVUser->users()) { + Instruction *UUser = dyn_cast<Instruction>(UU); + // UUser could be compare, PHI or branch + Uses[UUser].set(IL_All); + // Skip SExt + if (isa<SExtInst>(UUser)) { + UUser = dyn_cast<Instruction>(*(UUser->user_begin())); + Uses[UUser].set(IL_All); + } + // Is UUser a compare instruction? + if (UU->hasOneUse()) { + Instruction *BI = dyn_cast<BranchInst>(*UUser->user_begin()); + if (BI == cast<BranchInst>(Header->getTerminator())) + Uses[BI].set(IL_All); + } + } + } + } + // Make sure all instructions in the loop are in one and only one // set. 
for (auto &KV : Uses) { @@ -1272,61 +1424,136 @@ void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) { ++J; } - bool Negative = IVToIncMap[IV] < 0; - const DataLayout &DL = Header->getModule()->getDataLayout(); - // We need to create a new induction variable for each different BaseInst. - for (auto &DRS : RootSets) { - // Insert the new induction variable. - const SCEVAddRecExpr *RealIVSCEV = - cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst)); - const SCEV *Start = RealIVSCEV->getStart(); - const SCEVAddRecExpr *H = cast<SCEVAddRecExpr>(SE->getAddRecExpr( - Start, SE->getConstant(RealIVSCEV->getType(), Negative ? -1 : 1), L, - SCEV::FlagAnyWrap)); - { // Limit the lifetime of SCEVExpander. - SCEVExpander Expander(*SE, DL, "reroll"); - Value *NewIV = Expander.expandCodeFor(H, IV->getType(), &Header->front()); - - for (auto &KV : Uses) { - if (KV.second.find_first() == 0) - KV.first->replaceUsesOfWith(DRS.BaseInst, NewIV); - } + bool HasTwoIVs = LoopControlIV && LoopControlIV != IV; + + if (HasTwoIVs) { + updateNonLoopCtrlIncr(); + replaceIV(LoopControlIV, LoopControlIV, IterCount); + } else + // We need to create a new induction variable for each different BaseInst. + for (auto &DRS : RootSets) + // Insert the new induction variable. + replaceIV(DRS.BaseInst, IV, IterCount); - if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) { - // FIXME: Why do we need this check? - if (Uses[BI].find_first() == IL_All) { - const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE); + SimplifyInstructionsInBlock(Header, TLI); + DeleteDeadPHIs(Header, TLI); +} - // Iteration count SCEV minus 1 - const SCEV *ICMinus1SCEV = SE->getMinusSCEV( - ICSCEV, SE->getConstant(ICSCEV->getType(), Negative ? -1 : 1)); +// For non-loop-control IVs, we only need to update the last increment +// with right amount, then we are done. 
+void LoopReroll::DAGRootTracker::updateNonLoopCtrlIncr() { + const SCEV *NewInc = nullptr; + for (auto *LoopInc : LoopIncs) { + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LoopInc); + const SCEVConstant *COp = nullptr; + if (GEP && LoopInc->getOperand(0)->getType()->isPointerTy()) { + COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(1))); + } else { + COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(0))); + if (!COp) + COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(1))); + } - Value *ICMinus1; // Iteration count minus 1 - if (isa<SCEVConstant>(ICMinus1SCEV)) { - ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(), BI); - } else { - BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) - Preheader = InsertPreheaderForLoop(L, DT, LI, PreserveLCSSA); + assert(COp && "Didn't find constant operand of LoopInc!\n"); - ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(), - Preheader->getTerminator()); - } + const APInt &AInt = COp->getValue()->getValue(); + const SCEV *ScaleSCEV = SE->getConstant(COp->getType(), Scale); + if (AInt.isNegative()) { + NewInc = SE->getNegativeSCEV(COp); + NewInc = SE->getUDivExpr(NewInc, ScaleSCEV); + NewInc = SE->getNegativeSCEV(NewInc); + } else + NewInc = SE->getUDivExpr(COp, ScaleSCEV); + + LoopInc->setOperand(1, dyn_cast<SCEVConstant>(NewInc)->getValue()); + } +} - Value *Cond = - new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinus1, "exitcond"); - BI->setCondition(Cond); +void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst, + Instruction *InstIV, + const SCEV *IterCount) { + BasicBlock *Header = L->getHeader(); + int64_t Inc = IVToIncMap[InstIV]; + bool NeedNewIV = InstIV == LoopControlIV; + bool Negative = !NeedNewIV && Inc < 0; + + const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(Inst)); + const SCEV *Start = RealIVSCEV->getStart(); + + if (NeedNewIV) + Start = SE->getConstant(Start->getType(), 0); + + const SCEV *SizeOfExpr = nullptr; + const SCEV *IncrExpr = + SE->getConstant(RealIVSCEV->getType(), Negative ? -1 : 1); + if (auto *PTy = dyn_cast<PointerType>(Inst->getType())) { + Type *ElTy = PTy->getElementType(); + SizeOfExpr = + SE->getSizeOfExpr(SE->getEffectiveSCEVType(Inst->getType()), ElTy); + IncrExpr = SE->getMulExpr(IncrExpr, SizeOfExpr); + } + const SCEV *NewIVSCEV = + SE->getAddRecExpr(Start, IncrExpr, L, SCEV::FlagAnyWrap); + + { // Limit the lifetime of SCEVExpander. + const DataLayout &DL = Header->getModule()->getDataLayout(); + SCEVExpander Expander(*SE, DL, "reroll"); + Value *NewIV = + Expander.expandCodeFor(NewIVSCEV, InstIV->getType(), &Header->front()); + + for (auto &KV : Uses) + if (KV.second.find_first() == 0) + KV.first->replaceUsesOfWith(Inst, NewIV); + + if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) { + // FIXME: Why do we need this check? + if (Uses[BI].find_first() == IL_All) { + const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE); + + if (NeedNewIV) + ICSCEV = SE->getMulExpr(IterCount, + SE->getConstant(IterCount->getType(), Scale)); + + // Iteration count SCEV minus or plus 1 + const SCEV *MinusPlus1SCEV = + SE->getConstant(ICSCEV->getType(), Negative ? 
-1 : 1); + if (Inst->getType()->isPointerTy()) { + assert(SizeOfExpr && "SizeOfExpr is not initialized"); + MinusPlus1SCEV = SE->getMulExpr(MinusPlus1SCEV, SizeOfExpr); + } - if (BI->getSuccessor(1) != Header) - BI->swapSuccessors(); + const SCEV *ICMinusPlus1SCEV = SE->getMinusSCEV(ICSCEV, MinusPlus1SCEV); + // Iteration count minus 1 + Instruction *InsertPtr = nullptr; + if (isa<SCEVConstant>(ICMinusPlus1SCEV)) { + InsertPtr = BI; + } else { + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) + Preheader = InsertPreheaderForLoop(L, DT, LI, PreserveLCSSA); + InsertPtr = Preheader->getTerminator(); } + + if (!isa<PointerType>(NewIV->getType()) && NeedNewIV && + (SE->getTypeSizeInBits(NewIV->getType()) < + SE->getTypeSizeInBits(ICMinusPlus1SCEV->getType()))) { + IRBuilder<> Builder(BI); + Builder.SetCurrentDebugLocation(BI->getDebugLoc()); + NewIV = Builder.CreateSExt(NewIV, ICMinusPlus1SCEV->getType()); + } + Value *ICMinusPlus1 = Expander.expandCodeFor( + ICMinusPlus1SCEV, NewIV->getType(), InsertPtr); + + Value *Cond = + new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinusPlus1, "exitcond"); + BI->setCondition(Cond); + + if (BI->getSuccessor(1) != Header) + BI->swapSuccessors(); } } } - - SimplifyInstructionsInBlock(Header, TLI); - DeleteDeadPHIs(Header, TLI); } // Validate the selected reductions. All iterations must have an isomorphic @@ -1334,9 +1561,7 @@ void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) { // entries must appear in order. bool LoopReroll::ReductionTracker::validateSelected() { // For a non-associative reduction, the chain entries must appear in order. - for (DenseSet<int>::iterator RI = Reds.begin(), RIE = Reds.end(); - RI != RIE; ++RI) { - int i = *RI; + for (int i : Reds) { int PrevIter = 0, BaseCount = 0, Count = 0; for (Instruction *J : PossibleReds[i]) { // Note that all instructions in the chain must have been found because @@ -1380,9 +1605,7 @@ bool LoopReroll::ReductionTracker::validateSelected() { void LoopReroll::ReductionTracker::replaceSelected() { // Fixup reductions to refer to the last instruction associated with the // first iteration (not the last). 
- for (DenseSet<int>::iterator RI = Reds.begin(), RIE = Reds.end(); - RI != RIE; ++RI) { - int i = *RI; + for (int i : Reds) { int j = 0; for (int e = PossibleReds[i].size(); j != e; ++j) if (PossibleRedIter[PossibleReds[i][j]] != 0) { @@ -1396,9 +1619,8 @@ void LoopReroll::ReductionTracker::replaceSelected() { Users.push_back(cast<Instruction>(U)); } - for (SmallInstructionVector::iterator J = Users.begin(), - JE = Users.end(); J != JE; ++J) - (*J)->replaceUsesOfWith(PossibleReds[i].getReducedValue(), + for (Instruction *User : Users) + User->replaceUsesOfWith(PossibleReds[i].getReducedValue(), PossibleReds[i][j]); } } @@ -1450,7 +1672,7 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, const SCEV *IterCount, ReductionTracker &Reductions) { DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI, DT, LI, PreserveLCSSA, - IVToIncMap); + IVToIncMap, LoopControlIV); if (!DAGRoots.findRoots()) return false; @@ -1472,7 +1694,7 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, } bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipOptnoneFunction(L)) + if (skipLoop(L)) return false; AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); @@ -1487,41 +1709,46 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) { "] Loop %" << Header->getName() << " (" << L->getNumBlocks() << " block(s))\n"); - bool Changed = false; - // For now, we'll handle only single BB loops. if (L->getNumBlocks() > 1) - return Changed; + return false; if (!SE->hasLoopInvariantBackedgeTakenCount(L)) - return Changed; + return false; const SCEV *LIBETC = SE->getBackedgeTakenCount(L); const SCEV *IterCount = SE->getAddExpr(LIBETC, SE->getOne(LIBETC->getType())); + DEBUG(dbgs() << "\n Before Reroll:\n" << *(L->getHeader()) << "\n"); DEBUG(dbgs() << "LRR: iteration count = " << *IterCount << "\n"); // First, we need to find the induction variable with respect to which we can // reroll (there may be several possible options). SmallInstructionVector PossibleIVs; IVToIncMap.clear(); + LoopControlIV = nullptr; collectPossibleIVs(L, PossibleIVs); if (PossibleIVs.empty()) { DEBUG(dbgs() << "LRR: No possible IVs found\n"); - return Changed; + return false; } ReductionTracker Reductions; collectPossibleReductions(L, Reductions); + bool Changed = false; // For each possible IV, collect the associated possible set of 'root' nodes // (i+1, i+2, etc.). - for (SmallInstructionVector::iterator I = PossibleIVs.begin(), - IE = PossibleIVs.end(); I != IE; ++I) - if (reroll(*I, L, Header, IterCount, Reductions)) { + for (Instruction *PossibleIV : PossibleIVs) + if (reroll(PossibleIV, L, Header, IterCount, Reductions)) { Changed = true; break; } + DEBUG(dbgs() << "\n After Reroll:\n" << *(L->getHeader()) << "\n"); + + // Trip count of L has changed so SE must be re-evaluated. 
+ if (Changed) + SE->forgetLoop(L); return Changed; } diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp index 5e6c2da08cc3..7a06a25a7073 100644 --- a/lib/Transforms/Scalar/LoopRotation.cpp +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -11,7 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopRotation.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -20,6 +20,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -32,20 +33,46 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/ValueMapper.h" using namespace llvm; #define DEBUG_TYPE "loop-rotate" -static cl::opt<unsigned> -DefaultRotationThreshold("rotation-max-header-size", cl::init(16), cl::Hidden, - cl::desc("The default maximum header size for automatic loop rotation")); +static cl::opt<unsigned> DefaultRotationThreshold( + "rotation-max-header-size", cl::init(16), cl::Hidden, + cl::desc("The default maximum header size for automatic loop rotation")); STATISTIC(NumRotated, "Number of loops rotated"); +namespace { +/// A simple loop rotation transformation. +class LoopRotate { + const unsigned MaxHeaderSize; + LoopInfo *LI; + const TargetTransformInfo *TTI; + AssumptionCache *AC; + DominatorTree *DT; + ScalarEvolution *SE; + +public: + LoopRotate(unsigned MaxHeaderSize, LoopInfo *LI, + const TargetTransformInfo *TTI, AssumptionCache *AC, + DominatorTree *DT, ScalarEvolution *SE) + : MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE) { + } + bool processLoop(Loop *L); + +private: + bool rotateLoop(Loop *L, bool SimplifiedLatch); + bool simplifyLoopLatch(Loop *L); +}; +} // end anonymous namespace + /// RewriteUsesOfClonedInstructions - We just cloned the instructions from the /// old header into the preheader. If there were uses of the values produced by /// these instruction that were outside of the loop, we have to insert PHI nodes @@ -69,7 +96,7 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, if (OrigHeaderVal->use_empty()) continue; - Value *OrigPreHeaderVal = ValueMap[OrigHeaderVal]; + Value *OrigPreHeaderVal = ValueMap.lookup(OrigHeaderVal); // The value now exits in two versions: the initial value in the preheader // and the loop "next" value in the original header. @@ -79,7 +106,8 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, // Visit each use of the OrigHeader instruction. for (Value::use_iterator UI = OrigHeaderVal->use_begin(), - UE = OrigHeaderVal->use_end(); UI != UE; ) { + UE = OrigHeaderVal->use_end(); + UI != UE;) { // Grab the use before incrementing the iterator. Use &U = *UI; @@ -108,6 +136,41 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, // Anything else can be handled by SSAUpdater. 
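(Editorial aside, not part of the patch: rotation clones the header's exit test into the preheader as a guard and leaves the loop in do-while form; RewriteUsesOfClonedInstructions above then patches up out-of-loop uses of the cloned values, through SSAUpdater or, for debug intrinsics, the new MetadataAsValue path. A source-level sketch of the transformation, illustrative names only:)

    // Hypothetical while-form input: the exit test sits in the header.
    void before(int *A, int n) {
      int i = 0;
      while (i < n) {
        A[i] = i;
        ++i;
      }
    }

    // Rough shape after rotation: the test is cloned into the preheader as a
    // guard, and the loop body now ends with the exiting latch test.
    void after(int *A, int n) {
      int i = 0;
      if (i < n) {        // guard cloned from the old header
        do {
          A[i] = i;
          ++i;
        } while (i < n);  // latch is now the exiting block
      }
    }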
SSA.RewriteUse(U); } + + // Replace MetadataAsValue(ValueAsMetadata(OrigHeaderVal)) uses in debug + // intrinsics. + LLVMContext &C = OrigHeader->getContext(); + if (auto *VAM = ValueAsMetadata::getIfExists(OrigHeaderVal)) { + if (auto *MAV = MetadataAsValue::getIfExists(C, VAM)) { + for (auto UI = MAV->use_begin(), E = MAV->use_end(); UI != E;) { + // Grab the use before incrementing the iterator. Otherwise, altering + // the Use will invalidate the iterator. + Use &U = *UI++; + DbgInfoIntrinsic *UserInst = dyn_cast<DbgInfoIntrinsic>(U.getUser()); + if (!UserInst) + continue; + + // The original users in the OrigHeader are already using the original + // definitions. + BasicBlock *UserBB = UserInst->getParent(); + if (UserBB == OrigHeader) + continue; + + // Users in the OrigPreHeader need to use the value to which the + // original definitions are mapped and anything else can be handled by + // the SSAUpdater. To avoid adding PHINodes, check if the value is + // available in UserBB, if not substitute undef. + Value *NewVal; + if (UserBB == OrigPreheader) + NewVal = OrigPreHeaderVal; + else if (SSA.HasValueForBlock(UserBB)) + NewVal = SSA.GetValueInMiddleOfBlock(UserBB); + else + NewVal = UndefValue::get(OrigHeaderVal->getType()); + U = MetadataAsValue::get(C, ValueAsMetadata::get(NewVal)); + } + } + } } } @@ -121,10 +184,7 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, /// rotation. LoopRotate should be repeatable and converge to a canonical /// form. This property is satisfied because simplifying the loop latch can only /// happen once across multiple invocations of the LoopRotate pass. -static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, - const TargetTransformInfo *TTI, AssumptionCache *AC, - DominatorTree *DT, ScalarEvolution *SE, - bool SimplifiedLatch) { +bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // If the loop has only one block then there is not much to rotate. if (L->getBlocks().size() == 1) return false; @@ -162,7 +222,14 @@ static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, Metrics.analyzeBasicBlock(OrigHeader, *TTI, EphValues); if (Metrics.notDuplicatable) { DEBUG(dbgs() << "LoopRotation: NOT rotating - contains non-duplicatable" - << " instructions: "; L->dump()); + << " instructions: "; + L->dump()); + return false; + } + if (Metrics.convergent) { + DEBUG(dbgs() << "LoopRotation: NOT rotating - contains convergent " + "instructions: "; + L->dump()); return false; } if (Metrics.NumInsts > MaxHeaderSize) @@ -225,10 +292,9 @@ static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, // executing in each iteration of the loop. This means it is safe to hoist // something that might trap, but isn't safe to hoist something that reads // memory (without proving that the loop doesn't write). - if (L->hasLoopInvariantOperands(Inst) && - !Inst->mayReadFromMemory() && !Inst->mayWriteToMemory() && - !isa<TerminatorInst>(Inst) && !isa<DbgInfoIntrinsic>(Inst) && - !isa<AllocaInst>(Inst)) { + if (L->hasLoopInvariantOperands(Inst) && !Inst->mayReadFromMemory() && + !Inst->mayWriteToMemory() && !isa<TerminatorInst>(Inst) && + !isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst)) { Inst->moveBefore(LoopEntryBranch); continue; } @@ -238,7 +304,7 @@ static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, // Eagerly remap the operands of the instruction. 
RemapInstruction(C, ValueMap, - RF_NoModuleLevelChanges|RF_IgnoreMissingEntries); + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); // With the operands remapped, see if the instruction constant folds or is // otherwise simplifyable. This commonly occurs because the entry from PHI @@ -248,13 +314,18 @@ static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, if (V && LI->replacementPreservesLCSSAForm(C, V)) { // If so, then delete the temporary instruction and stick the folded value // in the map. - delete C; ValueMap[Inst] = V; + if (!C->mayHaveSideEffects()) { + delete C; + C = nullptr; + } } else { + ValueMap[Inst] = C; + } + if (C) { // Otherwise, stick the new instruction into the new block! C->setName(Inst->getName()); C->insertBefore(LoopEntryBranch); - ValueMap[Inst] = C; } } @@ -280,7 +351,6 @@ static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, L->moveToHeader(NewHeader); assert(L->getHeader() == NewHeader && "Latch block is our new header"); - // At this point, we've finished our major CFG changes. As part of cloning // the loop into the preheader we've simplified instructions and the // duplicated conditional branch may now be branching on a constant. If it is @@ -291,8 +361,8 @@ static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator()); assert(PHBI->isConditional() && "Should be clone of BI condbr!"); if (!isa<ConstantInt>(PHBI->getCondition()) || - PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) - != NewHeader) { + PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) != + NewHeader) { // The conditional branch can't be folded, handle the general case. // Update DominatorTree to reflect the CFG change we just made. Then split // edges as necessary to preserve LoopSimplify form. @@ -329,18 +399,17 @@ static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, // be split. SmallVector<BasicBlock *, 4> ExitPreds(pred_begin(Exit), pred_end(Exit)); bool SplitLatchEdge = false; - for (SmallVectorImpl<BasicBlock *>::iterator PI = ExitPreds.begin(), - PE = ExitPreds.end(); - PI != PE; ++PI) { + for (BasicBlock *ExitPred : ExitPreds) { // We only need to split loop exit edges. - Loop *PredLoop = LI->getLoopFor(*PI); + Loop *PredLoop = LI->getLoopFor(ExitPred); if (!PredLoop || PredLoop->contains(Exit)) continue; - if (isa<IndirectBrInst>((*PI)->getTerminator())) + if (isa<IndirectBrInst>(ExitPred->getTerminator())) continue; - SplitLatchEdge |= L->getLoopLatch() == *PI; + SplitLatchEdge |= L->getLoopLatch() == ExitPred; BasicBlock *ExitSplit = SplitCriticalEdge( - *PI, Exit, CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); + ExitPred, Exit, + CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); ExitSplit->moveBefore(Exit); } assert(SplitLatchEdge && @@ -384,8 +453,8 @@ static bool rotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, } } - // If the dominator changed, this may have an effect on other - // predecessors, continue until we reach a fixpoint. + // If the dominator changed, this may have an effect on other + // predecessors, continue until we reach a fixpoint. } while (Changed); } } @@ -432,7 +501,7 @@ static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, // GEPs are cheap if all indices are constant. 
if (!cast<GEPOperator>(I)->hasAllConstantIndices()) return false; - // fall-thru to increment case + // fall-thru to increment case case Instruction::Add: case Instruction::Sub: case Instruction::And: @@ -441,11 +510,10 @@ static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, case Instruction::Shl: case Instruction::LShr: case Instruction::AShr: { - Value *IVOpnd = !isa<Constant>(I->getOperand(0)) - ? I->getOperand(0) - : !isa<Constant>(I->getOperand(1)) - ? I->getOperand(1) - : nullptr; + Value *IVOpnd = + !isa<Constant>(I->getOperand(0)) + ? I->getOperand(0) + : !isa<Constant>(I->getOperand(1)) ? I->getOperand(1) : nullptr; if (!IVOpnd) return false; @@ -482,7 +550,7 @@ static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, /// canonical form so downstream passes can handle it. /// /// I don't believe this invalidates SCEV. -static bool simplifyLoopLatch(Loop *L, LoopInfo *LI, DominatorTree *DT) { +bool LoopRotate::simplifyLoopLatch(Loop *L) { BasicBlock *Latch = L->getLoopLatch(); if (!Latch || Latch->hasAddressTaken()) return false; @@ -503,7 +571,7 @@ static bool simplifyLoopLatch(Loop *L, LoopInfo *LI, DominatorTree *DT) { return false; DEBUG(dbgs() << "Folding loop latch " << Latch->getName() << " into " - << LastExit->getName() << "\n"); + << LastExit->getName() << "\n"); // Hoist the instructions from Latch into LastExit. LastExit->getInstList().splice(BI->getIterator(), Latch->getInstList(), @@ -527,26 +595,19 @@ static bool simplifyLoopLatch(Loop *L, LoopInfo *LI, DominatorTree *DT) { return true; } -/// Rotate \c L as many times as possible. Return true if the loop is rotated -/// at least once. -static bool iterativelyRotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, - const TargetTransformInfo *TTI, - AssumptionCache *AC, DominatorTree *DT, - ScalarEvolution *SE) { +/// Rotate \c L, and return true if any modification was made. +bool LoopRotate::processLoop(Loop *L) { // Save the loop metadata. MDNode *LoopMD = L->getLoopID(); // Simplify the loop latch before attempting to rotate the header // upward. Rotation may not be needed if the loop tail can be folded into the // loop exit. - bool SimplifiedLatch = simplifyLoopLatch(L, LI, DT); + bool SimplifiedLatch = simplifyLoopLatch(L); - // One loop can be rotated multiple times. - bool MadeChange = false; - while (rotateLoop(L, MaxHeaderSize, LI, TTI, AC, DT, SE, SimplifiedLatch)) { - MadeChange = true; - SimplifiedLatch = false; - } + bool MadeChange = rotateLoop(L, SimplifiedLatch); + assert((!MadeChange || L->isLoopExiting(L->getLoopLatch())) && + "Loop latch should be exiting after loop-rotate."); // Restore the loop metadata. // NB! We presume LoopRotation DOESN'T ADD its own metadata. @@ -556,15 +617,37 @@ static bool iterativelyRotateLoop(Loop *L, unsigned MaxHeaderSize, LoopInfo *LI, return MadeChange; } +LoopRotatePass::LoopRotatePass() {} + +PreservedAnalyses LoopRotatePass::run(Loop &L, AnalysisManager<Loop> &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function *F = L.getHeader()->getParent(); + + auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); + const auto *TTI = FAM.getCachedResult<TargetIRAnalysis>(*F); + auto *AC = FAM.getCachedResult<AssumptionAnalysis>(*F); + assert((LI && TTI && AC) && "Analyses for loop rotation not available"); + + // Optional analyses. 
+ auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); + auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); + LoopRotate LR(DefaultRotationThreshold, LI, TTI, AC, DT, SE); + + bool Changed = LR.processLoop(&L); + if (!Changed) + return PreservedAnalyses::all(); + return getLoopPassPreservedAnalyses(); +} + namespace { -class LoopRotate : public LoopPass { +class LoopRotateLegacyPass : public LoopPass { unsigned MaxHeaderSize; public: static char ID; // Pass ID, replacement for typeid - LoopRotate(int SpecifiedMaxHeaderSize = -1) : LoopPass(ID) { - initializeLoopRotatePass(*PassRegistry::getPassRegistry()); + LoopRotateLegacyPass(int SpecifiedMaxHeaderSize = -1) : LoopPass(ID) { + initializeLoopRotateLegacyPassPass(*PassRegistry::getPassRegistry()); if (SpecifiedMaxHeaderSize == -1) MaxHeaderSize = DefaultRotationThreshold; else @@ -573,24 +656,13 @@ public: // LCSSA form makes instruction renaming easier. void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addPreserved<AAResultsWrapperPass>(); AU.addRequired<AssumptionCacheTracker>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addPreservedID(LoopSimplifyID); - AU.addRequiredID(LCSSAID); - AU.addPreservedID(LCSSAID); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<SCEVAAWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<BasicAAWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); + getLoopAnalysisUsage(AU); } bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipOptnoneFunction(L)) + if (skipLoop(L)) return false; Function &F = *L->getHeader()->getParent(); @@ -601,24 +673,21 @@ public: auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); auto *SE = SEWP ? 
&SEWP->getSE() : nullptr; - - return iterativelyRotateLoop(L, MaxHeaderSize, LI, TTI, AC, DT, SE); + LoopRotate LR(MaxHeaderSize, LI, TTI, AC, DT, SE); + return LR.processLoop(L); } }; } -char LoopRotate::ID = 0; -INITIALIZE_PASS_BEGIN(LoopRotate, "loop-rotate", "Rotate Loops", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +char LoopRotateLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", + false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSA) -INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_END(LoopRotate, "loop-rotate", "Rotate Loops", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", false, + false) Pass *llvm::createLoopRotatePass(int MaxHeaderSize) { - return new LoopRotate(MaxHeaderSize); + return new LoopRotateLegacyPass(MaxHeaderSize); } diff --git a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp new file mode 100644 index 000000000000..ec227932c09e --- /dev/null +++ b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -0,0 +1,114 @@ +//===--------- LoopSimplifyCFG.cpp - Loop CFG Simplification Pass ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Loop SimplifyCFG Pass. This pass is responsible for +// basic loop CFG cleanup, primarily to assist other loop passes. If you +// encounter a noncanonical CFG construct that causes another loop pass to +// perform suboptimally, this is the place to fix it up. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopSimplifyCFG.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopPassManager.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +using namespace llvm; + +#define DEBUG_TYPE "loop-simplifycfg" + +static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI) { + bool Changed = false; + // Copy blocks into a temporary array to avoid iterator invalidation issues + // as we remove them. + SmallVector<WeakVH, 16> Blocks(L.blocks()); + + for (auto &Block : Blocks) { + // Attempt to merge blocks in the trivial case. Don't modify blocks which + // belong to other loops. 
+ BasicBlock *Succ = cast_or_null<BasicBlock>(Block); + if (!Succ) + continue; + + BasicBlock *Pred = Succ->getSinglePredecessor(); + if (!Pred || !Pred->getSingleSuccessor() || LI.getLoopFor(Pred) != &L) + continue; + + // Pred is going to disappear, so we need to update the loop info. + if (L.getHeader() == Pred) + L.moveToHeader(Succ); + LI.removeBlock(Pred); + MergeBasicBlockIntoOnlyPred(Succ, &DT); + Changed = true; + } + + return Changed; +} + +PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, AnalysisManager<Loop> &AM) { + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function *F = L.getHeader()->getParent(); + + auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); + assert((LI && DT) && "Analyses for LoopSimplifyCFG not available"); + + if (!simplifyLoopCFG(L, *DT, *LI)) + return PreservedAnalyses::all(); + return getLoopPassPreservedAnalyses(); +} + +namespace { +class LoopSimplifyCFGLegacyPass : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + LoopSimplifyCFGLegacyPass() : LoopPass(ID) { + initializeLoopSimplifyCFGLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &) override { + if (skipLoop(L)) + return false; + + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + return simplifyLoopCFG(*L, DT, LI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<DependenceAnalysisWrapperPass>(); + getLoopAnalysisUsage(AU); + } +}; +} + +char LoopSimplifyCFGLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", + "Simplify loop CFG", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", + "Simplify loop CFG", false, false) + +Pass *llvm::createLoopSimplifyCFGPass() { + return new LoopSimplifyCFGLegacyPass(); +} diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index acfdec43d21a..77c77eb7d798 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -684,10 +684,6 @@ static bool isAddressUse(Instruction *Inst, Value *OperandVal) { switch (II->getIntrinsicID()) { default: break; case Intrinsic::prefetch: - case Intrinsic::x86_sse_storeu_ps: - case Intrinsic::x86_sse2_storeu_pd: - case Intrinsic::x86_sse2_storeu_dq: - case Intrinsic::x86_sse2_storel_dq: if (II->getArgOperand(0) == OperandVal) isAddress = true; break; @@ -704,18 +700,6 @@ static MemAccessTy getAccessType(const Instruction *Inst) { AccessTy.AddrSpace = SI->getPointerAddressSpace(); } else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) { AccessTy.AddrSpace = LI->getPointerAddressSpace(); - } else if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { - // Addressing modes can also be folded into prefetches and a variety - // of intrinsics. 
- switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::x86_sse_storeu_ps: - case Intrinsic::x86_sse2_storeu_pd: - case Intrinsic::x86_sse2_storeu_dq: - case Intrinsic::x86_sse2_storel_dq: - AccessTy.MemTy = II->getArgOperand(0)->getType(); - break; - } } // All pointers have the same requirements, so canonicalize them to an @@ -963,8 +947,8 @@ void Cost::RateRegister(const SCEV *Reg, isa<SCEVConstant>(cast<SCEVAddRecExpr>(Reg)->getStart())))) ++SetupCost; - NumIVMuls += isa<SCEVMulExpr>(Reg) && - SE.hasComputableLoopEvolution(Reg, L); + NumIVMuls += isa<SCEVMulExpr>(Reg) && + SE.hasComputableLoopEvolution(Reg, L); } /// Record this register in the set. If we haven't seen it before, rate @@ -2752,34 +2736,31 @@ void LSRInstance::CollectChains() { LatchPath.push_back(LoopHeader); // Walk the instruction stream from the loop header to the loop latch. - for (SmallVectorImpl<BasicBlock *>::reverse_iterator - BBIter = LatchPath.rbegin(), BBEnd = LatchPath.rend(); - BBIter != BBEnd; ++BBIter) { - for (BasicBlock::iterator I = (*BBIter)->begin(), E = (*BBIter)->end(); - I != E; ++I) { + for (BasicBlock *BB : reverse(LatchPath)) { + for (Instruction &I : *BB) { // Skip instructions that weren't seen by IVUsers analysis. - if (isa<PHINode>(I) || !IU.isIVUserOrOperand(&*I)) + if (isa<PHINode>(I) || !IU.isIVUserOrOperand(&I)) continue; // Ignore users that are part of a SCEV expression. This way we only // consider leaf IV Users. This effectively rediscovers a portion of // IVUsers analysis but in program order this time. - if (SE.isSCEVable(I->getType()) && !isa<SCEVUnknown>(SE.getSCEV(&*I))) + if (SE.isSCEVable(I.getType()) && !isa<SCEVUnknown>(SE.getSCEV(&I))) continue; // Remove this instruction from any NearUsers set it may be in. for (unsigned ChainIdx = 0, NChains = IVChainVec.size(); ChainIdx < NChains; ++ChainIdx) { - ChainUsersVec[ChainIdx].NearUsers.erase(&*I); + ChainUsersVec[ChainIdx].NearUsers.erase(&I); } // Search for operands that can be chained. SmallPtrSet<Instruction*, 4> UniqueOperands; - User::op_iterator IVOpEnd = I->op_end(); - User::op_iterator IVOpIter = findIVOperand(I->op_begin(), IVOpEnd, L, SE); + User::op_iterator IVOpEnd = I.op_end(); + User::op_iterator IVOpIter = findIVOperand(I.op_begin(), IVOpEnd, L, SE); while (IVOpIter != IVOpEnd) { Instruction *IVOpInst = cast<Instruction>(*IVOpIter); if (UniqueOperands.insert(IVOpInst).second) - ChainInstruction(&*I, IVOpInst, ChainUsersVec); + ChainInstruction(&I, IVOpInst, ChainUsersVec); IVOpIter = findIVOperand(std::next(IVOpIter), IVOpEnd, L, SE); } } // Continue walking down the instructions. @@ -4331,28 +4312,15 @@ BasicBlock::iterator LSRInstance::HoistInsertPosition(BasicBlock::iterator IP, const SmallVectorImpl<Instruction *> &Inputs) const { + Instruction *Tentative = &*IP; for (;;) { - const Loop *IPLoop = LI.getLoopFor(IP->getParent()); - unsigned IPLoopDepth = IPLoop ? IPLoop->getLoopDepth() : 0; - - BasicBlock *IDom; - for (DomTreeNode *Rung = DT.getNode(IP->getParent()); ; ) { - if (!Rung) return IP; - Rung = Rung->getIDom(); - if (!Rung) return IP; - IDom = Rung->getBlock(); - - // Don't climb into a loop though. - const Loop *IDomLoop = LI.getLoopFor(IDom); - unsigned IDomDepth = IDomLoop ? 
IDomLoop->getLoopDepth() : 0; - if (IDomDepth <= IPLoopDepth && - (IDomDepth != IPLoopDepth || IDomLoop == IPLoop)) - break; - } - bool AllDominate = true; Instruction *BetterPos = nullptr; - Instruction *Tentative = IDom->getTerminator(); + // Don't bother attempting to insert before a catchswitch, their basic block + // cannot have other non-PHI instructions. + if (isa<CatchSwitchInst>(Tentative)) + return IP; + for (Instruction *Inst : Inputs) { if (Inst == Tentative || !DT.dominates(Inst, Tentative)) { AllDominate = false; @@ -4360,7 +4328,7 @@ LSRInstance::HoistInsertPosition(BasicBlock::iterator IP, } // Attempt to find an insert position in the middle of the block, // instead of at the end, so that it can be used for other expansions. - if (IDom == Inst->getParent() && + if (Tentative->getParent() == Inst->getParent() && (!BetterPos || !DT.dominates(Inst, BetterPos))) BetterPos = &*std::next(BasicBlock::iterator(Inst)); } @@ -4370,6 +4338,26 @@ LSRInstance::HoistInsertPosition(BasicBlock::iterator IP, IP = BetterPos->getIterator(); else IP = Tentative->getIterator(); + + const Loop *IPLoop = LI.getLoopFor(IP->getParent()); + unsigned IPLoopDepth = IPLoop ? IPLoop->getLoopDepth() : 0; + + BasicBlock *IDom; + for (DomTreeNode *Rung = DT.getNode(IP->getParent()); ; ) { + if (!Rung) return IP; + Rung = Rung->getIDom(); + if (!Rung) return IP; + IDom = Rung->getBlock(); + + // Don't climb into a loop though. + const Loop *IDomLoop = LI.getLoopFor(IDom); + unsigned IDomDepth = IDomLoop ? IDomLoop->getLoopDepth() : 0; + if (IDomDepth <= IPLoopDepth && + (IDomDepth != IPLoopDepth || IDomLoop == IPLoop)) + break; + } + + Tentative = IDom->getTerminator(); } return IP; @@ -4426,7 +4414,7 @@ LSRInstance::AdjustInsertPositionForExpand(BasicBlock::iterator LowestIP, while (isa<PHINode>(IP)) ++IP; // Ignore landingpad instructions. - while (!isa<TerminatorInst>(IP) && IP->isEHPad()) ++IP; + while (IP->isEHPad()) ++IP; // Ignore debug intrinsics. while (isa<DbgInfoIntrinsic>(IP)) ++IP; @@ -4961,7 +4949,7 @@ INITIALIZE_PASS_BEGIN(LoopStrengthReduce, "loop-reduce", INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(IVUsers) +INITIALIZE_PASS_DEPENDENCY(IVUsersWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_END(LoopStrengthReduce, "loop-reduce", @@ -4991,16 +4979,16 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const { // Requiring LoopSimplify a second time here prevents IVUsers from running // twice, since LoopSimplify was invalidated by running ScalarEvolution. 
AU.addRequiredID(LoopSimplifyID); - AU.addRequired<IVUsers>(); - AU.addPreserved<IVUsers>(); + AU.addRequired<IVUsersWrapperPass>(); + AU.addPreserved<IVUsersWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); } bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { - if (skipOptnoneFunction(L)) + if (skipLoop(L)) return false; - auto &IU = getAnalysis<IVUsers>(); + auto &IU = getAnalysis<IVUsersWrapperPass>().getIU(); auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp index ecef6dbe24e6..91af4a1922ce 100644 --- a/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -12,13 +12,13 @@ // counts of loops easily. //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/SetVector.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopUnrollAnalyzer.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -31,8 +31,11 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/UnrollLoop.h" #include <climits> +#include <utility> using namespace llvm; @@ -43,40 +46,54 @@ static cl::opt<unsigned> cl::desc("The baseline cost threshold for loop unrolling")); static cl::opt<unsigned> UnrollPercentDynamicCostSavedThreshold( - "unroll-percent-dynamic-cost-saved-threshold", cl::Hidden, + "unroll-percent-dynamic-cost-saved-threshold", cl::init(50), cl::Hidden, cl::desc("The percentage of estimated dynamic cost which must be saved by " "unrolling to allow unrolling up to the max threshold.")); static cl::opt<unsigned> UnrollDynamicCostSavingsDiscount( - "unroll-dynamic-cost-savings-discount", cl::Hidden, + "unroll-dynamic-cost-savings-discount", cl::init(100), cl::Hidden, cl::desc("This is the amount discounted from the total unroll cost when " "the unrolled form has a high dynamic cost savings (triggered by " "the '-unroll-perecent-dynamic-cost-saved-threshold' flag).")); static cl::opt<unsigned> UnrollMaxIterationsCountToAnalyze( - "unroll-max-iteration-count-to-analyze", cl::init(0), cl::Hidden, + "unroll-max-iteration-count-to-analyze", cl::init(10), cl::Hidden, cl::desc("Don't allow loop unrolling to simulate more than this number of" "iterations when checking full unroll profitability")); -static cl::opt<unsigned> -UnrollCount("unroll-count", cl::Hidden, - cl::desc("Use this unroll count for all loops including those with " - "unroll_count pragma values, for testing purposes")); +static cl::opt<unsigned> UnrollCount( + "unroll-count", cl::Hidden, + cl::desc("Use this unroll count for all loops including those with " + "unroll_count pragma values, for testing purposes")); -static cl::opt<bool> -UnrollAllowPartial("unroll-allow-partial", cl::Hidden, - cl::desc("Allows loops to be partially unrolled until " - "-unroll-threshold loop size is 
reached.")); +static cl::opt<unsigned> UnrollMaxCount( + "unroll-max-count", cl::Hidden, + cl::desc("Set the max unroll count for partial and runtime unrolling, for" + "testing purposes")); + +static cl::opt<unsigned> UnrollFullMaxCount( + "unroll-full-max-count", cl::Hidden, + cl::desc( + "Set the max unroll count for full unrolling, for testing purposes")); static cl::opt<bool> -UnrollRuntime("unroll-runtime", cl::ZeroOrMore, cl::Hidden, - cl::desc("Unroll loops with run-time trip counts")); + UnrollAllowPartial("unroll-allow-partial", cl::Hidden, + cl::desc("Allows loops to be partially unrolled until " + "-unroll-threshold loop size is reached.")); -static cl::opt<unsigned> -PragmaUnrollThreshold("pragma-unroll-threshold", cl::init(16 * 1024), cl::Hidden, - cl::desc("Unrolled size limit for loops with an unroll(full) or " - "unroll_count pragma.")); +static cl::opt<bool> UnrollAllowRemainder( + "unroll-allow-remainder", cl::Hidden, + cl::desc("Allow generation of a loop remainder (extra iterations) " + "when unrolling a loop.")); +static cl::opt<bool> + UnrollRuntime("unroll-runtime", cl::ZeroOrMore, cl::Hidden, + cl::desc("Unroll loops with run-time trip counts")); + +static cl::opt<unsigned> PragmaUnrollThreshold( + "pragma-unroll-threshold", cl::init(16 * 1024), cl::Hidden, + cl::desc("Unrolled size limit for loops with an unroll(full) or " + "unroll_count pragma.")); /// A magic value for use with the Threshold parameter to indicate /// that the loop unroll should be performed regardless of how much @@ -88,26 +105,28 @@ static const unsigned NoThreshold = UINT_MAX; static const unsigned DefaultUnrollRuntimeCount = 8; /// Gather the various unrolling parameters based on the defaults, compiler -/// flags, TTI overrides, pragmas, and user specified parameters. +/// flags, TTI overrides and user specified parameters. 
static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( Loop *L, const TargetTransformInfo &TTI, Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, - Optional<bool> UserRuntime, unsigned PragmaCount, bool PragmaFullUnroll, - bool PragmaEnableUnroll, unsigned TripCount) { + Optional<bool> UserRuntime) { TargetTransformInfo::UnrollingPreferences UP; // Set up the defaults UP.Threshold = 150; - UP.PercentDynamicCostSavedThreshold = 20; - UP.DynamicCostSavingsDiscount = 2000; - UP.OptSizeThreshold = 50; + UP.PercentDynamicCostSavedThreshold = 50; + UP.DynamicCostSavingsDiscount = 100; + UP.OptSizeThreshold = 0; UP.PartialThreshold = UP.Threshold; - UP.PartialOptSizeThreshold = UP.OptSizeThreshold; + UP.PartialOptSizeThreshold = 0; UP.Count = 0; UP.MaxCount = UINT_MAX; + UP.FullUnrollMaxCount = UINT_MAX; UP.Partial = false; UP.Runtime = false; + UP.AllowRemainder = true; UP.AllowExpensiveTripCount = false; + UP.Force = false; // Override with any target specific settings TTI.getUnrollingPreferences(L, UP); @@ -118,12 +137,6 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.PartialThreshold = UP.PartialOptSizeThreshold; } - // Apply unroll count pragmas - if (PragmaCount) - UP.Count = PragmaCount; - else if (PragmaFullUnroll) - UP.Count = TripCount; - // Apply any user values specified by cl::opt if (UnrollThreshold.getNumOccurrences() > 0) { UP.Threshold = UnrollThreshold; @@ -134,10 +147,14 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UnrollPercentDynamicCostSavedThreshold; if (UnrollDynamicCostSavingsDiscount.getNumOccurrences() > 0) UP.DynamicCostSavingsDiscount = UnrollDynamicCostSavingsDiscount; - if (UnrollCount.getNumOccurrences() > 0) - UP.Count = UnrollCount; + if (UnrollMaxCount.getNumOccurrences() > 0) + UP.MaxCount = UnrollMaxCount; + if (UnrollFullMaxCount.getNumOccurrences() > 0) + UP.FullUnrollMaxCount = UnrollFullMaxCount; if (UnrollAllowPartial.getNumOccurrences() > 0) UP.Partial = UnrollAllowPartial; + if (UnrollAllowRemainder.getNumOccurrences() > 0) + UP.AllowRemainder = UnrollAllowRemainder; if (UnrollRuntime.getNumOccurrences() > 0) UP.Runtime = UnrollRuntime; @@ -153,259 +170,42 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( if (UserRuntime.hasValue()) UP.Runtime = *UserRuntime; - if (PragmaCount > 0 || - ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount != 0)) { - // If the loop has an unrolling pragma, we want to be more aggressive with - // unrolling limits. Set thresholds to at least the PragmaTheshold value - // which is larger than the default limits. - if (UP.Threshold != NoThreshold) - UP.Threshold = std::max<unsigned>(UP.Threshold, PragmaUnrollThreshold); - if (UP.PartialThreshold != NoThreshold) - UP.PartialThreshold = - std::max<unsigned>(UP.PartialThreshold, PragmaUnrollThreshold); - } - return UP; } namespace { -// This class is used to get an estimate of the optimization effects that we -// could get from complete loop unrolling. It comes from the fact that some -// loads might be replaced with concrete constant values and that could trigger -// a chain of instruction simplifications. -// -// E.g. 
we might have: -// int a[] = {0, 1, 0}; -// v = 0; -// for (i = 0; i < 3; i ++) -// v += b[i]*a[i]; -// If we completely unroll the loop, we would get: -// v = b[0]*a[0] + b[1]*a[1] + b[2]*a[2] -// Which then will be simplified to: -// v = b[0]* 0 + b[1]* 1 + b[2]* 0 -// And finally: -// v = b[1] -class UnrolledInstAnalyzer : private InstVisitor<UnrolledInstAnalyzer, bool> { - typedef InstVisitor<UnrolledInstAnalyzer, bool> Base; - friend class InstVisitor<UnrolledInstAnalyzer, bool>; - struct SimplifiedAddress { - Value *Base = nullptr; - ConstantInt *Offset = nullptr; - }; +/// A struct to densely store the state of an instruction after unrolling at +/// each iteration. +/// +/// This is designed to work like a tuple of <Instruction *, int> for the +/// purposes of hashing and lookup, but to be able to associate two boolean +/// states with each key. +struct UnrolledInstState { + Instruction *I; + int Iteration : 30; + unsigned IsFree : 1; + unsigned IsCounted : 1; +}; -public: - UnrolledInstAnalyzer(unsigned Iteration, - DenseMap<Value *, Constant *> &SimplifiedValues, - ScalarEvolution &SE) - : SimplifiedValues(SimplifiedValues), SE(SE) { - IterationNumber = SE.getConstant(APInt(64, Iteration)); +/// Hashing and equality testing for a set of the instruction states. +struct UnrolledInstStateKeyInfo { + typedef DenseMapInfo<Instruction *> PtrInfo; + typedef DenseMapInfo<std::pair<Instruction *, int>> PairInfo; + static inline UnrolledInstState getEmptyKey() { + return {PtrInfo::getEmptyKey(), 0, 0, 0}; } - - // Allow access to the initial visit method. - using Base::visit; - -private: - /// \brief A cache of pointer bases and constant-folded offsets corresponding - /// to GEP (or derived from GEP) instructions. - /// - /// In order to find the base pointer one needs to perform non-trivial - /// traversal of the corresponding SCEV expression, so it's good to have the - /// results saved. - DenseMap<Value *, SimplifiedAddress> SimplifiedAddresses; - - /// \brief SCEV expression corresponding to number of currently simulated - /// iteration. - const SCEV *IterationNumber; - - /// \brief A Value->Constant map for keeping values that we managed to - /// constant-fold on the given iteration. - /// - /// While we walk the loop instructions, we build up and maintain a mapping - /// of simplified values specific to this iteration. The idea is to propagate - /// any special information we have about loads that can be replaced with - /// constants after complete unrolling, and account for likely simplifications - /// post-unrolling. - DenseMap<Value *, Constant *> &SimplifiedValues; - - ScalarEvolution &SE; - - /// \brief Try to simplify instruction \param I using its SCEV expression. - /// - /// The idea is that some AddRec expressions become constants, which then - /// could trigger folding of other instructions. However, that only happens - /// for expressions whose start value is also constant, which isn't always the - /// case. In another common and important case the start value is just some - /// address (i.e. SCEVUnknown) - in this case we compute the offset and save - /// it along with the base address instead. 
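To make the running example from the class comment (int a[] = {0, 1, 0}; v += b[i]*a[i]) concrete, here is the kind of source loop whose loads the analyzer expects to fold away, written as plain C for illustration only (not code from the pass):

    static const int a[3] = {0, 1, 0};  // constant global with a
                                        // ConstantDataSequential initializer
    int dot(const int *b) {
      int v = 0;
      for (int i = 0; i < 3; ++i)
        v += b[i] * a[i];   // on iteration i, a[i] folds to a known constant
      return v;
      // Fully unrolled: v = b[0]*a[0] + b[1]*a[1] + b[2]*a[2]
      //                   = b[0]*0    + b[1]*1    + b[2]*0
      //                   = b[1]
    }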
- bool simplifyInstWithSCEV(Instruction *I) { - if (!SE.isSCEVable(I->getType())) - return false; - - const SCEV *S = SE.getSCEV(I); - if (auto *SC = dyn_cast<SCEVConstant>(S)) { - SimplifiedValues[I] = SC->getValue(); - return true; - } - - auto *AR = dyn_cast<SCEVAddRecExpr>(S); - if (!AR) - return false; - - const SCEV *ValueAtIteration = AR->evaluateAtIteration(IterationNumber, SE); - // Check if the AddRec expression becomes a constant. - if (auto *SC = dyn_cast<SCEVConstant>(ValueAtIteration)) { - SimplifiedValues[I] = SC->getValue(); - return true; - } - - // Check if the offset from the base address becomes a constant. - auto *Base = dyn_cast<SCEVUnknown>(SE.getPointerBase(S)); - if (!Base) - return false; - auto *Offset = - dyn_cast<SCEVConstant>(SE.getMinusSCEV(ValueAtIteration, Base)); - if (!Offset) - return false; - SimplifiedAddress Address; - Address.Base = Base->getValue(); - Address.Offset = Offset->getValue(); - SimplifiedAddresses[I] = Address; - return true; + static inline UnrolledInstState getTombstoneKey() { + return {PtrInfo::getTombstoneKey(), 0, 0, 0}; } - - /// Base case for the instruction visitor. - bool visitInstruction(Instruction &I) { - return simplifyInstWithSCEV(&I); + static inline unsigned getHashValue(const UnrolledInstState &S) { + return PairInfo::getHashValue({S.I, S.Iteration}); } - - /// Try to simplify binary operator I. - /// - /// TODO: Probably it's worth to hoist the code for estimating the - /// simplifications effects to a separate class, since we have a very similar - /// code in InlineCost already. - bool visitBinaryOperator(BinaryOperator &I) { - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - if (!isa<Constant>(LHS)) - if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) - LHS = SimpleLHS; - if (!isa<Constant>(RHS)) - if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) - RHS = SimpleRHS; - - Value *SimpleV = nullptr; - const DataLayout &DL = I.getModule()->getDataLayout(); - if (auto FI = dyn_cast<FPMathOperator>(&I)) - SimpleV = - SimplifyFPBinOp(I.getOpcode(), LHS, RHS, FI->getFastMathFlags(), DL); - else - SimpleV = SimplifyBinOp(I.getOpcode(), LHS, RHS, DL); - - if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) - SimplifiedValues[&I] = C; - - if (SimpleV) - return true; - return Base::visitBinaryOperator(I); - } - - /// Try to fold load I. - bool visitLoad(LoadInst &I) { - Value *AddrOp = I.getPointerOperand(); - - auto AddressIt = SimplifiedAddresses.find(AddrOp); - if (AddressIt == SimplifiedAddresses.end()) - return false; - ConstantInt *SimplifiedAddrOp = AddressIt->second.Offset; - - auto *GV = dyn_cast<GlobalVariable>(AddressIt->second.Base); - // We're only interested in loads that can be completely folded to a - // constant. - if (!GV || !GV->hasDefinitiveInitializer() || !GV->isConstant()) - return false; - - ConstantDataSequential *CDS = - dyn_cast<ConstantDataSequential>(GV->getInitializer()); - if (!CDS) - return false; - - // We might have a vector load from an array. FIXME: for now we just bail - // out in this case, but we should be able to resolve and simplify such - // loads. 
- if(!CDS->isElementTypeCompatible(I.getType())) - return false; - - int ElemSize = CDS->getElementType()->getPrimitiveSizeInBits() / 8U; - assert(SimplifiedAddrOp->getValue().getActiveBits() < 64 && - "Unexpectedly large index value."); - int64_t Index = SimplifiedAddrOp->getSExtValue() / ElemSize; - if (Index >= CDS->getNumElements()) { - // FIXME: For now we conservatively ignore out of bound accesses, but - // we're allowed to perform the optimization in this case. - return false; - } - - Constant *CV = CDS->getElementAsConstant(Index); - assert(CV && "Constant expected."); - SimplifiedValues[&I] = CV; - - return true; - } - - bool visitCastInst(CastInst &I) { - // Propagate constants through casts. - Constant *COp = dyn_cast<Constant>(I.getOperand(0)); - if (!COp) - COp = SimplifiedValues.lookup(I.getOperand(0)); - if (COp) - if (Constant *C = - ConstantExpr::getCast(I.getOpcode(), COp, I.getType())) { - SimplifiedValues[&I] = C; - return true; - } - - return Base::visitCastInst(I); - } - - bool visitCmpInst(CmpInst &I) { - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - - // First try to handle simplified comparisons. - if (!isa<Constant>(LHS)) - if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) - LHS = SimpleLHS; - if (!isa<Constant>(RHS)) - if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) - RHS = SimpleRHS; - - if (!isa<Constant>(LHS) && !isa<Constant>(RHS)) { - auto SimplifiedLHS = SimplifiedAddresses.find(LHS); - if (SimplifiedLHS != SimplifiedAddresses.end()) { - auto SimplifiedRHS = SimplifiedAddresses.find(RHS); - if (SimplifiedRHS != SimplifiedAddresses.end()) { - SimplifiedAddress &LHSAddr = SimplifiedLHS->second; - SimplifiedAddress &RHSAddr = SimplifiedRHS->second; - if (LHSAddr.Base == RHSAddr.Base) { - LHS = LHSAddr.Offset; - RHS = RHSAddr.Offset; - } - } - } - } - - if (Constant *CLHS = dyn_cast<Constant>(LHS)) { - if (Constant *CRHS = dyn_cast<Constant>(RHS)) { - if (Constant *C = ConstantExpr::getCompare(I.getPredicate(), CLHS, CRHS)) { - SimplifiedValues[&I] = C; - return true; - } - } - } - - return Base::visitCmpInst(I); + static inline bool isEqual(const UnrolledInstState &LHS, + const UnrolledInstState &RHS) { + return PairInfo::isEqual({LHS.I, LHS.Iteration}, {RHS.I, RHS.Iteration}); } }; -} // namespace - +} namespace { struct EstimatedUnrollCost { @@ -441,18 +241,25 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, assert(UnrollMaxIterationsCountToAnalyze < (INT_MAX / 2) && "The unroll iterations max is too large!"); + // Only analyze inner loops. We can't properly estimate cost of nested loops + // and we won't visit inner loops again anyway. + if (!L->empty()) + return None; + // Don't simulate loops with a big or unknown tripcount if (!UnrollMaxIterationsCountToAnalyze || !TripCount || TripCount > UnrollMaxIterationsCountToAnalyze) return None; SmallSetVector<BasicBlock *, 16> BBWorklist; + SmallSetVector<std::pair<BasicBlock *, BasicBlock *>, 4> ExitWorklist; DenseMap<Value *, Constant *> SimplifiedValues; SmallVector<std::pair<Value *, Constant *>, 4> SimplifiedInputValues; // The estimated cost of the unrolled form of the loop. We try to estimate // this by simplifying as much as we can while computing the estimate. int UnrolledCost = 0; + // We also track the estimated dynamic (that is, actually executed) cost in // the rolled form. 
This helps identify cases when the savings from unrolling // aren't just exposing dead control flows, but actual reduced dynamic @@ -460,6 +267,97 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // unrolling. int RolledDynamicCost = 0; + // We track the simplification of each instruction in each iteration. We use + // this to recursively merge costs into the unrolled cost on-demand so that + // we don't count the cost of any dead code. This is essentially a map from + // <instruction, int> to <bool, bool>, but stored as a densely packed struct. + DenseSet<UnrolledInstState, UnrolledInstStateKeyInfo> InstCostMap; + + // A small worklist used to accumulate cost of instructions from each + // observable and reached root in the loop. + SmallVector<Instruction *, 16> CostWorklist; + + // PHI-used worklist used between iterations while accumulating cost. + SmallVector<Instruction *, 4> PHIUsedList; + + // Helper function to accumulate cost for instructions in the loop. + auto AddCostRecursively = [&](Instruction &RootI, int Iteration) { + assert(Iteration >= 0 && "Cannot have a negative iteration!"); + assert(CostWorklist.empty() && "Must start with an empty cost list"); + assert(PHIUsedList.empty() && "Must start with an empty phi used list"); + CostWorklist.push_back(&RootI); + for (;; --Iteration) { + do { + Instruction *I = CostWorklist.pop_back_val(); + + // InstCostMap only uses I and Iteration as a key, the other two values + // don't matter here. + auto CostIter = InstCostMap.find({I, Iteration, 0, 0}); + if (CostIter == InstCostMap.end()) + // If an input to a PHI node comes from a dead path through the loop + // we may have no cost data for it here. What that actually means is + // that it is free. + continue; + auto &Cost = *CostIter; + if (Cost.IsCounted) + // Already counted this instruction. + continue; + + // Mark that we are counting the cost of this instruction now. + Cost.IsCounted = true; + + // If this is a PHI node in the loop header, just add it to the PHI set. + if (auto *PhiI = dyn_cast<PHINode>(I)) + if (PhiI->getParent() == L->getHeader()) { + assert(Cost.IsFree && "Loop PHIs shouldn't be evaluated as they " + "inherently simplify during unrolling."); + if (Iteration == 0) + continue; + + // Push the incoming value from the backedge into the PHI used list + // if it is an in-loop instruction. We'll use this to populate the + // cost worklist for the next iteration (as we count backwards). + if (auto *OpI = dyn_cast<Instruction>( + PhiI->getIncomingValueForBlock(L->getLoopLatch()))) + if (L->contains(OpI)) + PHIUsedList.push_back(OpI); + continue; + } + + // First accumulate the cost of this instruction. + if (!Cost.IsFree) { + UnrolledCost += TTI.getUserCost(I); + DEBUG(dbgs() << "Adding cost of instruction (iteration " << Iteration + << "): "); + DEBUG(I->dump()); + } + + // We must count the cost of every operand which is not free, + // recursively. If we reach a loop PHI node, simply add it to the set + // to be considered on the next iteration (backwards!). + for (Value *Op : I->operands()) { + // Check whether this operand is free due to being a constant or + // outside the loop. + auto *OpI = dyn_cast<Instruction>(Op); + if (!OpI || !L->contains(OpI)) + continue; + + // Otherwise accumulate its cost. + CostWorklist.push_back(OpI); + } + } while (!CostWorklist.empty()); + + if (PHIUsedList.empty()) + // We've exhausted the search. 
+ break; + + assert(Iteration > 0 && + "Cannot track PHI-used values past the first iteration!"); + CostWorklist.append(PHIUsedList.begin(), PHIUsedList.end()); + PHIUsedList.clear(); + } + }; + // Ensure that we don't violate the loop structure invariants relied on by // this analysis. assert(L->isLoopSimplifyForm() && "Must put loop into normal form first."); @@ -502,7 +400,7 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, while (!SimplifiedInputValues.empty()) SimplifiedValues.insert(SimplifiedInputValues.pop_back_val()); - UnrolledInstAnalyzer Analyzer(Iteration, SimplifiedValues, SE); + UnrolledInstAnalyzer Analyzer(Iteration, SimplifiedValues, SE, L); BBWorklist.clear(); BBWorklist.insert(L->getHeader()); @@ -514,22 +412,32 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // it. We don't change the actual IR, just count optimization // opportunities. for (Instruction &I : *BB) { - int InstCost = TTI.getUserCost(&I); + // Track this instruction's expected baseline cost when executing the + // rolled loop form. + RolledDynamicCost += TTI.getUserCost(&I); // Visit the instruction to analyze its loop cost after unrolling, - // and if the visitor returns false, include this instruction in the - // unrolled cost. - if (!Analyzer.visit(I)) - UnrolledCost += InstCost; - else { - DEBUG(dbgs() << " " << I - << " would be simplified if loop is unrolled.\n"); - (void)0; - } + // and if the visitor returns true, mark the instruction as free after + // unrolling and continue. + bool IsFree = Analyzer.visit(I); + bool Inserted = InstCostMap.insert({&I, (int)Iteration, + (unsigned)IsFree, + /*IsCounted*/ false}).second; + (void)Inserted; + assert(Inserted && "Cannot have a state for an unvisited instruction!"); + + if (IsFree) + continue; - // Also track this instructions expected cost when executing the rolled - // loop form. - RolledDynamicCost += InstCost; + // If the instruction might have a side-effect recursively account for + // the cost of it and all the instructions leading up to it. + if (I.mayHaveSideEffects()) + AddCostRecursively(I, Iteration); + + // Can't properly model a cost of a call. + // FIXME: With a proper cost model we should be able to do it. + if(isa<CallInst>(&I)) + return None; // If unrolled body turns out to be too big, bail out. if (UnrolledCost > MaxUnrolledLoopSize) { @@ -545,42 +453,45 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // Add in the live successors by first checking whether we have terminator // that may be simplified based on the values simplified by this call. + BasicBlock *KnownSucc = nullptr; if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { if (BI->isConditional()) { if (Constant *SimpleCond = SimplifiedValues.lookup(BI->getCondition())) { - BasicBlock *Succ = nullptr; // Just take the first successor if condition is undef if (isa<UndefValue>(SimpleCond)) - Succ = BI->getSuccessor(0); - else - Succ = BI->getSuccessor( - cast<ConstantInt>(SimpleCond)->isZero() ? 1 : 0); - if (L->contains(Succ)) - BBWorklist.insert(Succ); - continue; + KnownSucc = BI->getSuccessor(0); + else if (ConstantInt *SimpleCondVal = + dyn_cast<ConstantInt>(SimpleCond)) + KnownSucc = BI->getSuccessor(SimpleCondVal->isZero() ? 
1 : 0); } } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { if (Constant *SimpleCond = SimplifiedValues.lookup(SI->getCondition())) { - BasicBlock *Succ = nullptr; // Just take the first successor if condition is undef if (isa<UndefValue>(SimpleCond)) - Succ = SI->getSuccessor(0); - else - Succ = SI->findCaseValue(cast<ConstantInt>(SimpleCond)) - .getCaseSuccessor(); - if (L->contains(Succ)) - BBWorklist.insert(Succ); - continue; + KnownSucc = SI->getSuccessor(0); + else if (ConstantInt *SimpleCondVal = + dyn_cast<ConstantInt>(SimpleCond)) + KnownSucc = SI->findCaseValue(SimpleCondVal).getCaseSuccessor(); } } + if (KnownSucc) { + if (L->contains(KnownSucc)) + BBWorklist.insert(KnownSucc); + else + ExitWorklist.insert({BB, KnownSucc}); + continue; + } // Add BB's successors to the worklist. for (BasicBlock *Succ : successors(BB)) if (L->contains(Succ)) BBWorklist.insert(Succ); + else + ExitWorklist.insert({BB, Succ}); + AddCostRecursively(*TI, Iteration); } // If we found no optimization opportunities on the first iteration, we @@ -591,6 +502,23 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, return None; } } + + while (!ExitWorklist.empty()) { + BasicBlock *ExitingBB, *ExitBB; + std::tie(ExitingBB, ExitBB) = ExitWorklist.pop_back_val(); + + for (Instruction &I : *ExitBB) { + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) + break; + + Value *Op = PN->getIncomingValueForBlock(ExitingBB); + if (auto *OpI = dyn_cast<Instruction>(Op)) + if (L->contains(OpI)) + AddCostRecursively(*OpI, TripCount - 1); + } + } + DEBUG(dbgs() << "Analysis finished:\n" << "UnrolledCost: " << UnrolledCost << ", " << "RolledDynamicCost: " << RolledDynamicCost << "\n"); @@ -599,18 +527,18 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, /// ApproximateLoopSize - Approximate the size of the loop. static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls, - bool &NotDuplicatable, + bool &NotDuplicatable, bool &Convergent, const TargetTransformInfo &TTI, AssumptionCache *AC) { SmallPtrSet<const Value *, 32> EphValues; CodeMetrics::collectEphemeralValues(L, AC, EphValues); CodeMetrics Metrics; - for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); - I != E; ++I) - Metrics.analyzeBasicBlock(*I, TTI, EphValues); + for (BasicBlock *BB : L->blocks()) + Metrics.analyzeBasicBlock(BB, TTI, EphValues); NumCalls = Metrics.NumInlineCandidates; NotDuplicatable = Metrics.notDuplicatable; + Convergent = Metrics.convergent; unsigned LoopSize = Metrics.NumInsts; @@ -676,21 +604,22 @@ static unsigned UnrollCountPragmaValue(const Loop *L) { // unrolling pass is run more than once (which it generally is). static void SetLoopAlreadyUnrolled(Loop *L) { MDNode *LoopID = L->getLoopID(); - if (!LoopID) return; - // First remove any existing loop unrolling metadata. SmallVector<Metadata *, 4> MDs; // Reserve first location for self reference to the LoopID metadata node. 
MDs.push_back(nullptr); - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { - bool IsUnrollMetadata = false; - MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); - if (MD) { - const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); - IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); + + if (LoopID) { + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + bool IsUnrollMetadata = false; + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (MD) { + const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); + } + if (!IsUnrollMetadata) + MDs.push_back(LoopID->getOperand(i)); } - if (!IsUnrollMetadata) - MDs.push_back(LoopID->getOperand(i)); } // Add unroll(disable) metadata to disable future unrolling. @@ -737,9 +666,9 @@ static bool canUnrollCompletely(Loop *L, unsigned Threshold, (int64_t)UnrolledCost - (int64_t)DynamicCostSavingsDiscount <= (int64_t)Threshold) { DEBUG(dbgs() << " Can fully unroll, because unrolling will reduce the " - "expected dynamic cost by " << PercentDynamicCostSaved - << "% (threshold: " << PercentDynamicCostSavedThreshold - << "%)\n" + "expected dynamic cost by " + << PercentDynamicCostSaved << "% (threshold: " + << PercentDynamicCostSavedThreshold << "%)\n" << " and the unrolled cost (" << UnrolledCost << ") is less than the max threshold (" << DynamicCostSavingsDiscount << ").\n"); @@ -758,82 +687,77 @@ static bool canUnrollCompletely(Loop *L, unsigned Threshold, return false; } -static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, - ScalarEvolution *SE, const TargetTransformInfo &TTI, - AssumptionCache &AC, bool PreserveLCSSA, - Optional<unsigned> ProvidedCount, - Optional<unsigned> ProvidedThreshold, - Optional<bool> ProvidedAllowPartial, - Optional<bool> ProvidedRuntime) { - BasicBlock *Header = L->getHeader(); - DEBUG(dbgs() << "Loop Unroll: F[" << Header->getParent()->getName() - << "] Loop %" << Header->getName() << "\n"); +// Returns true if unroll count was set explicitly. +// Calculates unroll count and writes it to UP.Count. +static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, + DominatorTree &DT, LoopInfo *LI, + ScalarEvolution *SE, unsigned TripCount, + unsigned TripMultiple, unsigned LoopSize, + TargetTransformInfo::UnrollingPreferences &UP) { + // BEInsns represents number of instructions optimized when "back edge" + // becomes "fall through" in unrolled loop. + // For now we count a conditional branch on a backedge and a comparison + // feeding it. + unsigned BEInsns = 2; + // Check for explicit Count. + // 1st priority is unroll count set by "unroll-count" option. + bool UserUnrollCount = UnrollCount.getNumOccurrences() > 0; + if (UserUnrollCount) { + UP.Count = UnrollCount; + UP.AllowExpensiveTripCount = true; + UP.Force = true; + if (UP.AllowRemainder && + (LoopSize - BEInsns) * UP.Count + BEInsns < UP.Threshold) + return true; + } - if (HasUnrollDisablePragma(L)) { - return false; + // 2nd priority is unroll count set by pragma. 
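For context on where these pragma values come from: clang lowers source-level loop pragmas into llvm.loop.unroll.* metadata, which UnrollCountPragmaValue, HasUnrollFullPragma and friends read back here. A rough source-level illustration (the function is made up; the pragma spellings are the usual clang loop pragmas):

    void f(int *a, const int *b, int n) {
      // Requests an unroll count of 4; reaches this pass as
      // llvm.loop.unroll.count metadata (the pragma-count path below).
      #pragma clang loop unroll_count(4)
      for (int i = 0; i < n; ++i)
        a[i] += b[i];

      int s = 0;
      // Requests a full unroll; only honored when the trip count is
      // statically known (the full-unroll path further below).
      #pragma clang loop unroll(full)
      for (int i = 0; i < 8; ++i)
        s += b[i];
      a[0] = s;
    }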
+ unsigned PragmaCount = UnrollCountPragmaValue(L); + if (PragmaCount > 0) { + UP.Count = PragmaCount; + UP.Runtime = true; + UP.AllowExpensiveTripCount = true; + UP.Force = true; + if (UP.AllowRemainder && + (LoopSize - BEInsns) * UP.Count + BEInsns < PragmaUnrollThreshold) + return true; } bool PragmaFullUnroll = HasUnrollFullPragma(L); - bool PragmaEnableUnroll = HasUnrollEnablePragma(L); - unsigned PragmaCount = UnrollCountPragmaValue(L); - bool HasPragma = PragmaFullUnroll || PragmaEnableUnroll || PragmaCount > 0; - - // Find trip count and trip multiple if count is not available - unsigned TripCount = 0; - unsigned TripMultiple = 1; - // If there are multiple exiting blocks but one of them is the latch, use the - // latch for the trip count estimation. Otherwise insist on a single exiting - // block for the trip count estimation. - BasicBlock *ExitingBlock = L->getLoopLatch(); - if (!ExitingBlock || !L->isLoopExiting(ExitingBlock)) - ExitingBlock = L->getExitingBlock(); - if (ExitingBlock) { - TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); - TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); + if (PragmaFullUnroll && TripCount != 0) { + UP.Count = TripCount; + if ((LoopSize - BEInsns) * UP.Count + BEInsns < PragmaUnrollThreshold) + return false; } - TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( - L, TTI, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, - ProvidedRuntime, PragmaCount, PragmaFullUnroll, PragmaEnableUnroll, - TripCount); - - unsigned Count = UP.Count; - bool CountSetExplicitly = Count != 0; - // Use a heuristic count if we didn't set anything explicitly. - if (!CountSetExplicitly) - Count = TripCount == 0 ? DefaultUnrollRuntimeCount : TripCount; - if (TripCount && Count > TripCount) - Count = TripCount; + bool PragmaEnableUnroll = HasUnrollEnablePragma(L); + bool ExplicitUnroll = PragmaCount > 0 || PragmaFullUnroll || + PragmaEnableUnroll || UserUnrollCount; - unsigned NumInlineCandidates; - bool notDuplicatable; - unsigned LoopSize = - ApproximateLoopSize(L, NumInlineCandidates, notDuplicatable, TTI, &AC); - DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); + uint64_t UnrolledSize; + DebugLoc LoopLoc = L->getStartLoc(); + Function *F = L->getHeader()->getParent(); + LLVMContext &Ctx = F->getContext(); - // When computing the unrolled size, note that the conditional branch on the - // backedge and the comparison feeding it are not replicated like the rest of - // the loop body (which is why 2 is subtracted). - uint64_t UnrolledSize = (uint64_t)(LoopSize-2) * Count + 2; - if (notDuplicatable) { - DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" - << " instructions.\n"); - return false; - } - if (NumInlineCandidates != 0) { - DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); - return false; + if (ExplicitUnroll && TripCount != 0) { + // If the loop has an unrolling pragma, we want to be more aggressive with + // unrolling limits. Set thresholds to at least the PragmaThreshold value + // which is larger than the default limits. + UP.Threshold = std::max<unsigned>(UP.Threshold, PragmaUnrollThreshold); + UP.PartialThreshold = + std::max<unsigned>(UP.PartialThreshold, PragmaUnrollThreshold); } - // Given Count, TripCount and thresholds determine the type of - // unrolling which is to be performed. 
- enum { Full = 0, Partial = 1, Runtime = 2 }; - int Unrolling; - if (TripCount && Count == TripCount) { - Unrolling = Partial; - // If the loop is really small, we don't need to run an expensive analysis. + // 3rd priority is full unroll count. + // Full unroll make sense only when TripCount could be staticaly calculated. + // Also we need to check if we exceed FullUnrollMaxCount. + if (TripCount && TripCount <= UP.FullUnrollMaxCount) { + // When computing the unrolled size, note that BEInsns are not replicated + // like the rest of the loop body. + UnrolledSize = (uint64_t)(LoopSize - BEInsns) * TripCount + BEInsns; if (canUnrollCompletely(L, UP.Threshold, 100, UP.DynamicCostSavingsDiscount, UnrolledSize, UnrolledSize)) { - Unrolling = Full; + UP.Count = TripCount; + return ExplicitUnroll; } else { // The loop isn't that small, but we still can fully unroll it if that // helps to remove a significant number of instructions. @@ -845,99 +769,216 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, UP.PercentDynamicCostSavedThreshold, UP.DynamicCostSavingsDiscount, Cost->UnrolledCost, Cost->RolledDynamicCost)) { - Unrolling = Full; + UP.Count = TripCount; + return ExplicitUnroll; } } - } else if (TripCount && Count < TripCount) { - Unrolling = Partial; - } else { - Unrolling = Runtime; } - // Reduce count based on the type of unrolling and the threshold values. - unsigned OriginalCount = Count; - bool AllowRuntime = PragmaEnableUnroll || (PragmaCount > 0) || UP.Runtime; - // Don't unroll a runtime trip count loop with unroll full pragma. - if (HasRuntimeUnrollDisablePragma(L) || PragmaFullUnroll) { - AllowRuntime = false; - } - if (Unrolling == Partial) { - bool AllowPartial = PragmaEnableUnroll || UP.Partial; - if (!AllowPartial && !CountSetExplicitly) { + // 4rd priority is partial unrolling. + // Try partial unroll only when TripCount could be staticaly calculated. + if (TripCount) { + if (UP.Count == 0) + UP.Count = TripCount; + UP.Partial |= ExplicitUnroll; + if (!UP.Partial) { DEBUG(dbgs() << " will not try to unroll partially because " << "-unroll-allow-partial not given\n"); + UP.Count = 0; return false; } - if (UP.PartialThreshold != NoThreshold && - UnrolledSize > UP.PartialThreshold) { + if (UP.PartialThreshold != NoThreshold) { // Reduce unroll count to be modulo of TripCount for partial unrolling. - Count = (std::max(UP.PartialThreshold, 3u) - 2) / (LoopSize - 2); - while (Count != 0 && TripCount % Count != 0) - Count--; - } - } else if (Unrolling == Runtime) { - if (!AllowRuntime && !CountSetExplicitly) { - DEBUG(dbgs() << " will not try to unroll loop with runtime trip count " - << "-unroll-runtime not given\n"); - return false; - } - // Reduce unroll count to be the largest power-of-two factor of - // the original count which satisfies the threshold limit. - while (Count != 0 && UnrolledSize > UP.PartialThreshold) { - Count >>= 1; - UnrolledSize = (LoopSize-2) * Count + 2; + UnrolledSize = (uint64_t)(LoopSize - BEInsns) * UP.Count + BEInsns; + if (UnrolledSize > UP.PartialThreshold) + UP.Count = (std::max(UP.PartialThreshold, 3u) - BEInsns) / + (LoopSize - BEInsns); + if (UP.Count > UP.MaxCount) + UP.Count = UP.MaxCount; + while (UP.Count != 0 && TripCount % UP.Count != 0) + UP.Count--; + if (UP.AllowRemainder && UP.Count <= 1) { + // If there is no Count that is modulo of TripCount, set Count to + // largest power-of-two factor that satisfies the threshold limit. 
+ // As we'll create fixup loop, do the type of unrolling only if + // remainder loop is allowed. + UP.Count = DefaultUnrollRuntimeCount; + UnrolledSize = (LoopSize - BEInsns) * UP.Count + BEInsns; + while (UP.Count != 0 && UnrolledSize > UP.PartialThreshold) { + UP.Count >>= 1; + UnrolledSize = (LoopSize - BEInsns) * UP.Count + BEInsns; + } + } + if (UP.Count < 2) { + if (PragmaEnableUnroll) + emitOptimizationRemarkMissed( + Ctx, DEBUG_TYPE, *F, LoopLoc, + "Unable to unroll loop as directed by unroll(enable) pragma " + "because unrolled size is too large."); + UP.Count = 0; + } + } else { + UP.Count = TripCount; } - if (Count > UP.MaxCount) - Count = UP.MaxCount; - DEBUG(dbgs() << " partially unrolling with count: " << Count << "\n"); - } - - if (HasPragma) { - if (PragmaCount != 0) - // If loop has an unroll count pragma mark loop as unrolled to prevent - // unrolling beyond that requested by the pragma. - SetLoopAlreadyUnrolled(L); - - // Emit optimization remarks if we are unable to unroll the loop - // as directed by a pragma. - DebugLoc LoopLoc = L->getStartLoc(); - Function *F = Header->getParent(); - LLVMContext &Ctx = F->getContext(); - if ((PragmaCount > 0) && Count != OriginalCount) { - emitOptimizationRemarkMissed( - Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to unroll loop the number of times directed by " - "unroll_count pragma because unrolled size is too large."); - } else if (PragmaFullUnroll && !TripCount) { - emitOptimizationRemarkMissed( - Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to fully unroll loop as directed by unroll(full) pragma " - "because loop has a runtime trip count."); - } else if (PragmaEnableUnroll && Count != TripCount && Count < 2) { + if ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount && + UP.Count != TripCount) emitOptimizationRemarkMissed( Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to unroll loop as directed by unroll(enable) pragma because " + "Unable to fully unroll loop as directed by unroll pragma because " "unrolled size is too large."); - } else if ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount && - Count != TripCount) { + return ExplicitUnroll; + } + assert(TripCount == 0 && + "All cases when TripCount is constant should be covered here."); + if (PragmaFullUnroll) + emitOptimizationRemarkMissed( + Ctx, DEBUG_TYPE, *F, LoopLoc, + "Unable to fully unroll loop as directed by unroll(full) pragma " + "because loop has a runtime trip count."); + + // 5th priority is runtime unrolling. + // Don't unroll a runtime trip count loop when it is disabled. + if (HasRuntimeUnrollDisablePragma(L)) { + UP.Count = 0; + return false; + } + // Reduce count based on the type of unrolling and the threshold values. + UP.Runtime |= PragmaEnableUnroll || PragmaCount > 0 || UserUnrollCount; + if (!UP.Runtime) { + DEBUG(dbgs() << " will not try to unroll loop with runtime trip count " + << "-unroll-runtime not given\n"); + UP.Count = 0; + return false; + } + if (UP.Count == 0) + UP.Count = DefaultUnrollRuntimeCount; + UnrolledSize = (LoopSize - BEInsns) * UP.Count + BEInsns; + + // Reduce unroll count to be the largest power-of-two factor of + // the original count which satisfies the threshold limit. 
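As a worked example of the halving loop that follows (all numbers assumed for illustration, and assuming no explicit count was requested): with LoopSize = 30 and BEInsns = 2, the replicated body costs 28 per iteration. With UP.PartialThreshold at its default of 150 and UP.Count starting at DefaultUnrollRuntimeCount = 8, UnrolledSize = 28*8 + 2 = 226 > 150, so the count is halved to 4, giving 28*4 + 2 = 114 <= 150, and the loop is unrolled 4 times.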
+ while (UP.Count != 0 && UnrolledSize > UP.PartialThreshold) { + UP.Count >>= 1; + UnrolledSize = (LoopSize - BEInsns) * UP.Count + BEInsns; + } + +#ifndef NDEBUG + unsigned OrigCount = UP.Count; +#endif + + if (!UP.AllowRemainder && UP.Count != 0 && (TripMultiple % UP.Count) != 0) { + while (UP.Count != 0 && TripMultiple % UP.Count != 0) + UP.Count >>= 1; + DEBUG(dbgs() << "Remainder loop is restricted (that could architecture " + "specific or because the loop contains a convergent " + "instruction), so unroll count must divide the trip " + "multiple, " + << TripMultiple << ". Reducing unroll count from " + << OrigCount << " to " << UP.Count << ".\n"); + if (PragmaCount > 0 && !UP.AllowRemainder) emitOptimizationRemarkMissed( Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to fully unroll loop as directed by unroll pragma because " - "unrolled size is too large."); - } + Twine("Unable to unroll loop the number of times directed by " + "unroll_count pragma because remainder loop is restricted " + "(that could architecture specific or because the loop " + "contains a convergent instruction) and so must have an unroll " + "count that divides the loop trip multiple of ") + + Twine(TripMultiple) + ". Unrolling instead " + Twine(UP.Count) + + " time(s)."); } - if (Unrolling != Full && Count < 2) { - // Partial unrolling by 1 is a nop. For full unrolling, a factor - // of 1 makes sense because loop control can be eliminated. + if (UP.Count > UP.MaxCount) + UP.Count = UP.MaxCount; + DEBUG(dbgs() << " partially unrolling with count: " << UP.Count << "\n"); + if (UP.Count < 2) + UP.Count = 0; + return ExplicitUnroll; +} + +static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, + ScalarEvolution *SE, const TargetTransformInfo &TTI, + AssumptionCache &AC, bool PreserveLCSSA, + Optional<unsigned> ProvidedCount, + Optional<unsigned> ProvidedThreshold, + Optional<bool> ProvidedAllowPartial, + Optional<bool> ProvidedRuntime) { + DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() + << "] Loop %" << L->getHeader()->getName() << "\n"); + if (HasUnrollDisablePragma(L)) { return false; } + unsigned NumInlineCandidates; + bool NotDuplicatable; + bool Convergent; + unsigned LoopSize = ApproximateLoopSize( + L, NumInlineCandidates, NotDuplicatable, Convergent, TTI, &AC); + DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); + if (NotDuplicatable) { + DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" + << " instructions.\n"); + return false; + } + if (NumInlineCandidates != 0) { + DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); + return false; + } + if (!L->isLoopSimplifyForm()) { + DEBUG( + dbgs() << " Not unrolling loop which is not in loop-simplify form.\n"); + return false; + } + + // Find trip count and trip multiple if count is not available + unsigned TripCount = 0; + unsigned TripMultiple = 1; + // If there are multiple exiting blocks but one of them is the latch, use the + // latch for the trip count estimation. Otherwise insist on a single exiting + // block for the trip count estimation. 
+ BasicBlock *ExitingBlock = L->getLoopLatch(); + if (!ExitingBlock || !L->isLoopExiting(ExitingBlock)) + ExitingBlock = L->getExitingBlock(); + if (ExitingBlock) { + TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); + TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); + } + + TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( + L, TTI, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, + ProvidedRuntime); + + // If the loop contains a convergent operation, the prelude we'd add + // to do the first few instructions before we hit the unrolled loop + // is unsafe -- it adds a control-flow dependency to the convergent + // operation. Therefore restrict remainder loop (try unrollig without). + // + // TODO: This is quite conservative. In practice, convergent_op() + // is likely to be called unconditionally in the loop. In this + // case, the program would be ill-formed (on most architectures) + // unless n were the same on all threads in a thread group. + // Assuming n is the same on all threads, any kind of unrolling is + // safe. But currently llvm's notion of convergence isn't powerful + // enough to express this. + if (Convergent) + UP.AllowRemainder = false; + + bool IsCountSetExplicitly = computeUnrollCount(L, TTI, DT, LI, SE, TripCount, + TripMultiple, LoopSize, UP); + if (!UP.Count) + return false; + // Unroll factor (Count) must be less or equal to TripCount. + if (TripCount && UP.Count > TripCount) + UP.Count = TripCount; + // Unroll the loop. - if (!UnrollLoop(L, Count, TripCount, AllowRuntime, UP.AllowExpensiveTripCount, - TripMultiple, LI, SE, &DT, &AC, PreserveLCSSA)) + if (!UnrollLoop(L, UP.Count, TripCount, UP.Force, UP.Runtime, + UP.AllowExpensiveTripCount, TripMultiple, LI, SE, &DT, &AC, + PreserveLCSSA)) return false; + // If loop has an unroll count pragma or unrolled by explicitly set count + // mark loop as unrolled to prevent unrolling beyond that requested. + if (IsCountSetExplicitly) + SetLoopAlreadyUnrolled(L); return true; } @@ -948,8 +989,9 @@ public: LoopUnroll(Optional<unsigned> Threshold = None, Optional<unsigned> Count = None, Optional<bool> AllowPartial = None, Optional<bool> Runtime = None) - : LoopPass(ID), ProvidedCount(Count), ProvidedThreshold(Threshold), - ProvidedAllowPartial(AllowPartial), ProvidedRuntime(Runtime) { + : LoopPass(ID), ProvidedCount(std::move(Count)), + ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), + ProvidedRuntime(Runtime) { initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); } @@ -959,7 +1001,7 @@ public: Optional<bool> ProvidedRuntime; bool runOnLoop(Loop *L, LPPassManager &) override { - if (skipOptnoneFunction(L)) + if (skipLoop(L)) return false; Function &F = *L->getHeader()->getParent(); @@ -982,35 +1024,19 @@ public: /// void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addPreservedID(LoopSimplifyID); - AU.addRequiredID(LCSSAID); - AU.addPreservedID(LCSSAID); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); - // FIXME: Loop unroll requires LCSSA. And LCSSA requires dom info. - // If loop unroll does not preserve dom info then LCSSA pass on next - // loop will receive invalid dom info. - // For now, recreate dom info, if loop is unrolled. 
- AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); + // FIXME: Loop passes are required to preserve domtree, and for now we just + // recreate dom info if anything gets unrolled. + getLoopAnalysisUsage(AU); } }; } char LoopUnroll::ID = 0; INITIALIZE_PASS_BEGIN(LoopUnroll, "loop-unroll", "Unroll loops", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSA) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false) Pass *llvm::createLoopUnrollPass(int Threshold, int Count, int AllowPartial, diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index 95d7f8a3beda..71980e85e8ca 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -55,6 +55,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" #include <algorithm> #include <map> #include <set> @@ -64,6 +65,7 @@ using namespace llvm; STATISTIC(NumBranches, "Number of branches unswitched"); STATISTIC(NumSwitches, "Number of switches unswitched"); +STATISTIC(NumGuards, "Number of guards unswitched"); STATISTIC(NumSelects , "Number of selects unswitched"); STATISTIC(NumTrivial , "Number of unswitches that are trivial"); STATISTIC(NumSimplify, "Number of simplifications of unswitched code"); @@ -187,6 +189,9 @@ namespace { BasicBlock *loopHeader; BasicBlock *loopPreheader; + bool SanitizeMemory; + LoopSafetyInfo SafetyInfo; + // LoopBlocks contains all of the basic blocks of the loop, including the // preheader of the loop, the body of the loop, and the exit blocks of the // loop, in that order. @@ -211,17 +216,8 @@ namespace { /// void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); - AU.addRequiredID(LoopSimplifyID); - AU.addPreservedID(LoopSimplifyID); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequiredID(LCSSAID); - AU.addPreservedID(LCSSAID); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); + getLoopAnalysisUsage(AU); } private: @@ -382,11 +378,9 @@ void LUAnalysisCache::cloneData(const Loop *NewLoop, const Loop *OldLoop, char LoopUnswitch::ID = 0; INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LCSSA) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) @@ -396,7 +390,11 @@ Pass *llvm::createLoopUnswitchPass(bool Os) { /// Cond is a condition that occurs in L. 
If it is invariant in the loop, or has /// an invariant piece, return the invariant. Otherwise, return null. -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { +static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, + DenseMap<Value *, Value *> &Cache) { + auto CacheIt = Cache.find(Cond); + if (CacheIt != Cache.end()) + return CacheIt->second; // We started analyze new instruction, increment scanned instructions counter. ++TotalInsts; @@ -411,8 +409,10 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { // TODO: Handle: br (VARIANT|INVARIANT). // Hoist simple values out. - if (L->makeLoopInvariant(Cond, Changed)) + if (L->makeLoopInvariant(Cond, Changed)) { + Cache[Cond] = Cond; return Cond; + } if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond)) if (BO->getOpcode() == Instruction::And || @@ -420,17 +420,29 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { // If either the left or right side is invariant, we can unswitch on this, // which will cause the branch to go away in one loop and the condition to // simplify in the other one. - if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed)) + if (Value *LHS = + FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) { + Cache[Cond] = LHS; return LHS; - if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed)) + } + if (Value *RHS = + FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) { + Cache[Cond] = RHS; return RHS; + } } + Cache[Cond] = nullptr; return nullptr; } +static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { + DenseMap<Value *, Value *> Cache; + return FindLIVLoopCondition(Cond, L, Changed, Cache); +} + bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { - if (skipOptnoneFunction(L)) + if (skipLoop(L)) return false; AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( @@ -441,6 +453,10 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { currentLoop = L; Function *F = currentLoop->getHeader()->getParent(); + SanitizeMemory = F->hasFnAttribute(Attribute::SanitizeMemory); + if (SanitizeMemory) + computeLoopSafetyInfo(&SafetyInfo, L); + EnabledPGO = F->getEntryCount().hasValue(); if (LoopUnswitchWithBlockFrequency && EnabledPGO) { @@ -499,17 +515,34 @@ bool LoopUnswitch::processCurrentLoop() { return true; } - // Do not unswitch loops containing convergent operations, as we might be - // making them control dependent on the unswitch value when they were not - // before. - // FIXME: This could be refined to only bail if the convergent operation is - // not already control-dependent on the unswitch value. + // Run through the instructions in the loop, keeping track of three things: + // + // - That we do not unswitch loops containing convergent operations, as we + // might be making them control dependent on the unswitch value when they + // were not before. + // FIXME: This could be refined to only bail if the convergent operation is + // not already control-dependent on the unswitch value. + // + // - That basic blocks in the loop contain invokes whose predecessor edges we + // cannot split. + // + // - The set of guard intrinsics encountered (these are non terminator + // instructions that are also profitable to be unswitched). 
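As background on why loop-invariant conditions (including the guard conditions collected below) are worth unswitching, a small source-level illustration, not taken from the pass:

    // Before: 'flag' is not modified anywhere inside the loop.
    void before(int *a, int n, int flag) {
      for (int i = 0; i < n; ++i) {
        if (flag)
          a[i] = 0;
        else
          a[i] = 1;
      }
    }

    // After unswitching, the invariant test runs once and each loop version
    // has a straight-line body that later passes can simplify further.
    void after(int *a, int n, int flag) {
      if (flag)
        for (int i = 0; i < n; ++i) a[i] = 0;
      else
        for (int i = 0; i < n; ++i) a[i] = 1;
    }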
+ + SmallVector<IntrinsicInst *, 4> Guards; + for (const auto BB : currentLoop->blocks()) { for (auto &I : *BB) { auto CS = CallSite(&I); if (!CS) continue; if (CS.hasFnAttr(Attribute::Convergent)) return false; + if (auto *II = dyn_cast<InvokeInst>(&I)) + if (!II->getUnwindDest()->canSplitPredecessors()) + return false; + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::experimental_guard) + Guards.push_back(II); } } @@ -529,12 +562,36 @@ bool LoopUnswitch::processCurrentLoop() { return false; } + for (IntrinsicInst *Guard : Guards) { + Value *LoopCond = + FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed); + if (LoopCond && + UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { + // NB! Unswitching (if successful) could have erased some of the + // instructions in Guards leaving dangling pointers there. This is fine + // because we're returning now, and won't look at Guards again. + ++NumGuards; + return true; + } + } + // Loop over all of the basic blocks in the loop. If we find an interior // block that is branching on a loop-invariant condition, we can unswitch this // loop. for (Loop::block_iterator I = currentLoop->block_begin(), E = currentLoop->block_end(); I != E; ++I) { TerminatorInst *TI = (*I)->getTerminator(); + + // Unswitching on a potentially uninitialized predicate is not + // MSan-friendly. Limit this to the cases when the original predicate is + // guaranteed to execute, to avoid creating a use-of-uninitialized-value + // in the code that did not have one. + // This is a workaround for the discrepancy between LLVM IR and MSan + // semantics. See PR28054 for more details. + if (SanitizeMemory && + !isGuaranteedToExecute(*TI, DT, currentLoop, &SafetyInfo)) + continue; + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { // If this isn't branching on an invariant condition, we can't unswitch // it. @@ -628,8 +685,8 @@ static bool isTrivialLoopExitBlockHelper(Loop *L, BasicBlock *BB, // Okay, everything after this looks good, check to make sure that this block // doesn't include any side effects. - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) - if (I->mayHaveSideEffects()) + for (Instruction &I : *BB) + if (I.mayHaveSideEffects()) return false; return true; @@ -679,8 +736,8 @@ static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, New.addBasicBlockToLoop(cast<BasicBlock>(VM[*I]), *LI); // Add all of the subloops to the new loop. - for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) - CloneLoop(*I, &New, VM, LI, LPM); + for (Loop *I : *L) + CloneLoop(I, &New, VM, LI, LPM); return &New; } @@ -1075,10 +1132,9 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, // Rewrite the code to refer to itself. for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) - for (BasicBlock::iterator I = NewBlocks[i]->begin(), - E = NewBlocks[i]->end(); I != E; ++I) - RemapInstruction(&*I, VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); + for (Instruction &I : *NewBlocks[i]) + RemapInstruction(&I, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); // Rewrite the original preheader to select between versions of the loop. 
BranchInst *OldBR = cast<BranchInst>(loopPreheader->getTerminator()); @@ -1180,9 +1236,8 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, Worklist.push_back(UI); } - for (std::vector<Instruction*>::iterator UI = Worklist.begin(), - UE = Worklist.end(); UI != UE; ++UI) - (*UI)->replaceUsesOfWith(LIC, Replacement); + for (Instruction *UI : Worklist) + UI->replaceUsesOfWith(LIC, Replacement); SimplifyCode(Worklist, L); return; diff --git a/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/lib/Transforms/Scalar/LoopVersioningLICM.cpp new file mode 100644 index 000000000000..0ccf0af7165b --- /dev/null +++ b/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -0,0 +1,571 @@ +//===----------- LoopVersioningLICM.cpp - LICM Loop Versioning ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// When alias analysis is uncertain about the aliasing between any two accesses, +// it will return MayAlias. This uncertainty from alias analysis restricts LICM +// from proceeding further. In cases where alias analysis is uncertain we might +// use loop versioning as an alternative. +// +// Loop Versioning will create a version of the loop with aggressive aliasing +// assumptions in addition to the original with conservative (default) aliasing +// assumptions. The version of the loop making aggressive aliasing assumptions +// will have all the memory accesses marked as no-alias. These two versions of +// loop will be preceded by a memory runtime check. This runtime check consists +// of bound checks for all unique memory accessed in loop, and it ensures the +// lack of memory aliasing. The result of the runtime check determines which of +// the loop versions is executed: If the runtime check detects any memory +// aliasing, then the original loop is executed. Otherwise, the version with +// aggressive aliasing assumptions is used. +// +// Following are the top level steps: +// +// a) Perform LoopVersioningLICM's feasibility check. +// b) If loop is a candidate for versioning then create a memory bound check, +// by considering all the memory accesses in loop body. +// c) Clone original loop and set all memory accesses as no-alias in new loop. +// d) Set original loop & versioned loop as a branch target of the runtime check +// result. 
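To make steps a) through d) concrete before the block diagram below, a source-level sketch of the intended shape (simplified and hypothetical: the real pass emits the memcheck from LoopAccessInfo's runtime pointer checks and marks the versioned loop's accesses no-alias through metadata rather than duplicating source code):

    void scale(int *a, int *b, int n) {
      if (a + n <= b || b + 1 <= a) {
        // Versioned loop: the runtime check proved a and b disjoint, so the
        // read of b[0] is loop-invariant and LICM can hoist it.
        int t = b[0];
        for (int i = 0; i < n; ++i)
          a[i] = t;
      } else {
        // Original loop: conservatively reload b[0] every iteration because
        // the store to a[i] might alias it.
        for (int i = 0; i < n; ++i)
          a[i] = b[0];
      }
    }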
+// +// It transforms loop as shown below: +// +// +----------------+ +// |Runtime Memcheck| +// +----------------+ +// | +// +----------+----------------+----------+ +// | | +// +---------+----------+ +-----------+----------+ +// |Orig Loop Preheader | |Cloned Loop Preheader | +// +--------------------+ +----------------------+ +// | | +// +--------------------+ +----------------------+ +// |Orig Loop Body | |Cloned Loop Body | +// +--------------------+ +----------------------+ +// | | +// +--------------------+ +----------------------+ +// |Orig Loop Exit Block| |Cloned Loop Exit Block| +// +--------------------+ +-----------+----------+ +// | | +// +----------+--------------+-----------+ +// | +// +-----+----+ +// |Join Block| +// +----------+ +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredIteratorCache.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +#define DEBUG_TYPE "loop-versioning-licm" +static const char* LICMVersioningMetaData = + "llvm.loop.licm_versioning.disable"; + +using namespace llvm; + +/// Threshold minimum allowed percentage for possible +/// invariant instructions in a loop. +static cl::opt<float> + LVInvarThreshold("licm-versioning-invariant-threshold", + cl::desc("LoopVersioningLICM's minimum allowed percentage" + "of possible invariant instructions per loop"), + cl::init(25), cl::Hidden); + +/// Threshold for maximum allowed loop nest/depth +static cl::opt<unsigned> LVLoopDepthThreshold( + "licm-versioning-max-depth-threshold", + cl::desc( + "LoopVersioningLICM's threshold for maximum allowed loop nest/depth"), + cl::init(2), cl::Hidden); + +/// \brief Create MDNode for input string. +static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) { + LLVMContext &Context = TheLoop->getHeader()->getContext(); + Metadata *MDs[] = { + MDString::get(Context, Name), + ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))}; + return MDNode::get(Context, MDs); +} + +/// \brief Set input string into loop metadata by keeping other values intact. +void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *MDString, + unsigned V) { + SmallVector<Metadata *, 4> MDs(1); + // If the loop already has metadata, retain it. 
+ MDNode *LoopID = TheLoop->getLoopID(); + if (LoopID) { + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + MDNode *Node = cast<MDNode>(LoopID->getOperand(i)); + MDs.push_back(Node); + } + } + // Add new metadata. + MDs.push_back(createStringMetadata(TheLoop, MDString, V)); + // Replace current metadata node with new one. + LLVMContext &Context = TheLoop->getHeader()->getContext(); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + TheLoop->setLoopID(NewLoopID); +} + +namespace { +struct LoopVersioningLICM : public LoopPass { + static char ID; + + bool runOnLoop(Loop *L, LPPassManager &LPM) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequiredID(LCSSAID); + AU.addRequired<LoopAccessLegacyAnalysis>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + + LoopVersioningLICM() + : LoopPass(ID), AA(nullptr), SE(nullptr), LI(nullptr), DT(nullptr), + TLI(nullptr), LAA(nullptr), LAI(nullptr), Changed(false), + Preheader(nullptr), CurLoop(nullptr), CurAST(nullptr), + LoopDepthThreshold(LVLoopDepthThreshold), + InvariantThreshold(LVInvarThreshold), LoadAndStoreCounter(0), + InvariantCounter(0), IsReadOnlyLoop(true) { + initializeLoopVersioningLICMPass(*PassRegistry::getPassRegistry()); + } + + AliasAnalysis *AA; // Current AliasAnalysis information + ScalarEvolution *SE; // Current ScalarEvolution + LoopInfo *LI; // Current LoopInfo + DominatorTree *DT; // Dominator Tree for the current Loop. + TargetLibraryInfo *TLI; // TargetLibraryInfo for constant folding. + LoopAccessLegacyAnalysis *LAA; // Current LoopAccessAnalysis + const LoopAccessInfo *LAI; // Current Loop's LoopAccessInfo + + bool Changed; // Set to true when we change anything. + BasicBlock *Preheader; // The preheader block of the current loop. + Loop *CurLoop; // The current loop we are working on. + AliasSetTracker *CurAST; // AliasSet information for the current loop. + ValueToValueMap Strides; + + unsigned LoopDepthThreshold; // Maximum loop nest threshold + float InvariantThreshold; // Minimum invariant threshold + unsigned LoadAndStoreCounter; // Counter to track num of load & store + unsigned InvariantCounter; // Counter to track num of invariant + bool IsReadOnlyLoop; // Read only loop marker. + + bool isLegalForVersioning(); + bool legalLoopStructure(); + bool legalLoopInstructions(); + bool legalLoopMemoryAccesses(); + bool isLoopAlreadyVisited(); + void setNoAliasToLoop(Loop *); + bool instructionSafeForVersioning(Instruction *); + const char *getPassName() const override { return "Loop Versioning"; } +}; +} + +/// \brief Check loop structure and confirms it's good for LoopVersioningLICM. +bool LoopVersioningLICM::legalLoopStructure() { + // Loop must have a preheader, if not return false. + if (!CurLoop->getLoopPreheader()) { + DEBUG(dbgs() << " loop preheader is missing\n"); + return false; + } + // Loop should be innermost loop, if not return false. + if (CurLoop->getSubLoops().size()) { + DEBUG(dbgs() << " loop is not innermost\n"); + return false; + } + // Loop should have a single backedge, if not return false. 
+  if (CurLoop->getNumBackEdges() != 1) {
+    DEBUG(dbgs() << "    loop has multiple backedges\n");
+    return false;
+  }
+  // Loop must have a single exiting block, if not return false.
+  if (!CurLoop->getExitingBlock()) {
+    DEBUG(dbgs() << "    loop has multiple exiting blocks\n");
+    return false;
+  }
+  // We only handle bottom-tested loops, i.e. loops in which the condition is
+  // checked at the end of each iteration. With that we can assume that all
+  // instructions in the loop are executed the same number of times.
+  if (CurLoop->getExitingBlock() != CurLoop->getLoopLatch()) {
+    DEBUG(dbgs() << "    loop is not bottom tested\n");
+    return false;
+  }
+  // Parallel loops must not have aliasing loop-invariant memory accesses.
+  // Hence we don't need to version anything in this case.
+  if (CurLoop->isAnnotatedParallel()) {
+    DEBUG(dbgs() << "    Parallel loop is not worth versioning\n");
+    return false;
+  }
+  // A loop depth greater than LoopDepthThreshold is not allowed.
+  if (CurLoop->getLoopDepth() > LoopDepthThreshold) {
+    DEBUG(dbgs() << "    loop depth is greater than threshold\n");
+    return false;
+  }
+  // The loop should have dedicated exit blocks, if not return false.
+  if (!CurLoop->hasDedicatedExits()) {
+    DEBUG(dbgs() << "    loop does not have dedicated exit blocks\n");
+    return false;
+  }
+  // We need to be able to compute the loop trip count in order
+  // to generate the bound checks.
+  const SCEV *ExitCount = SE->getBackedgeTakenCount(CurLoop);
+  if (ExitCount == SE->getCouldNotCompute()) {
+    DEBUG(dbgs() << "    loop does not have a computable trip count\n");
+    return false;
+  }
+  return true;
+}
+
+/// \brief Check memory accesses in the loop and confirm they are good for
+/// LoopVersioningLICM.
+bool LoopVersioningLICM::legalLoopMemoryAccesses() {
+  bool HasMayAlias = false;
+  bool TypeSafety = false;
+  bool HasMod = false;
+  // Memory check:
+  // The transform phase will generate a versioned loop and also a runtime
+  // check to ensure the pointers are independent and don't alias.
+  // In the versioned variant of the loop, alias metadata asserts that all
+  // accesses are mutually independent.
+  //
+  // Pointers aliasing within an alias domain are avoided because, with
+  // multiple aliasing domains, we may not be able to hoist a potentially
+  // loop-invariant access out of the loop.
+  //
+  // Iterate over the alias tracker's sets, and confirm that there is no
+  // must-alias set.
+  for (const auto &I : *CurAST) {
+    const AliasSet &AS = I;
+    // Skip forwarding alias sets, as these should be ignored as part of
+    // the AliasSetTracker object.
+    if (AS.isForwardingAliasSet())
+      continue;
+    // With MustAlias it's not worth adding a runtime bound check.
+    if (AS.isMustAlias())
+      return false;
+    Value *SomePtr = AS.begin()->getValue();
+    bool TypeCheck = true;
+    // Check for Mod & MayAlias.
+    HasMayAlias |= AS.isMayAlias();
+    HasMod |= AS.isMod();
+    for (const auto &A : AS) {
+      Value *Ptr = A.getValue();
+      // The alias set should have pointers of the same data type.
+      TypeCheck = (TypeCheck && (SomePtr->getType() == Ptr->getType()));
+    }
+    // At least one alias set should have pointers of the same data type.
+    TypeSafety |= TypeCheck;
+  }
+  // Ensure type safety: at least one alias set must have pointers of the
+  // same type.
+  if (!TypeSafety) {
+    DEBUG(dbgs() << "    Alias tracker type safety failed!\n");
+    return false;
+  }
+  // Ensure the loop body is not read-only.
+  if (!HasMod) {
+    DEBUG(dbgs() << "    No memory modified in loop body\n");
+    return false;
+  }
+  // Make sure some alias set has a may-alias case.
+  // If there is no memory aliasing ambiguity, return false.
+  if (!HasMayAlias) {
+    DEBUG(dbgs() << "    No ambiguity in memory access.\n");
+    return false;
+  }
+  return true;
+}
+
+/// \brief Check whether a loop instruction is safe for loop versioning.
+/// Returns true if it's safe, false otherwise.
+/// It checks the following:
+/// 1) All loads and stores in the loop body are non-atomic and non-volatile.
+/// 2) Function calls are safe, i.e. they do not access memory.
+/// 3) The loop body has no instruction that may throw.
+bool LoopVersioningLICM::instructionSafeForVersioning(Instruction *I) {
+  assert(I != nullptr && "Null instruction found!");
+  // Check function call safety.
+  if (isa<CallInst>(I) && !AA->doesNotAccessMemory(CallSite(I))) {
+    DEBUG(dbgs() << "    Unsafe call site found.\n");
+    return false;
+  }
+  // Avoid loops with the possibility of throwing.
+  if (I->mayThrow()) {
+    DEBUG(dbgs() << "    May throw instruction found in loop body\n");
+    return false;
+  }
+  // If the current instruction is a load instruction,
+  // make sure it's a simple load (non-atomic and non-volatile).
+  if (I->mayReadFromMemory()) {
+    LoadInst *Ld = dyn_cast<LoadInst>(I);
+    if (!Ld || !Ld->isSimple()) {
+      DEBUG(dbgs() << "    Found a non-simple load.\n");
+      return false;
+    }
+    LoadAndStoreCounter++;
+    Value *Ptr = Ld->getPointerOperand();
+    // Check loop invariance.
+    if (SE->isLoopInvariant(SE->getSCEV(Ptr), CurLoop))
+      InvariantCounter++;
+  }
+  // If the current instruction is a store instruction,
+  // make sure it's a simple store (non-atomic and non-volatile).
+  else if (I->mayWriteToMemory()) {
+    StoreInst *St = dyn_cast<StoreInst>(I);
+    if (!St || !St->isSimple()) {
+      DEBUG(dbgs() << "    Found a non-simple store.\n");
+      return false;
+    }
+    LoadAndStoreCounter++;
+    Value *Ptr = St->getPointerOperand();
+    // Check loop invariance.
+    if (SE->isLoopInvariant(SE->getSCEV(Ptr), CurLoop))
+      InvariantCounter++;
+
+    IsReadOnlyLoop = false;
+  }
+  return true;
+}
+
+/// \brief Check loop instructions and confirm they are good for
+/// LoopVersioningLICM.
+bool LoopVersioningLICM::legalLoopInstructions() {
+  // Reset counters.
+  LoadAndStoreCounter = 0;
+  InvariantCounter = 0;
+  IsReadOnlyLoop = true;
+  // Iterate over the loop's blocks and the instructions of each block, and
+  // check instruction safety.
+  for (auto *Block : CurLoop->getBlocks())
+    for (auto &Inst : *Block) {
+      // If an instruction is unsafe just return false.
+      if (!instructionSafeForVersioning(&Inst))
+        return false;
+    }
+  // Get LoopAccessInfo for the current loop.
+  LAI = &LAA->getInfo(CurLoop);
+  // Check LoopAccessInfo for the need of a runtime check.
+  if (LAI->getRuntimePointerChecking()->getChecks().empty()) {
+    DEBUG(dbgs() << "    LAA: Runtime check not found !!\n");
+    return false;
+  }
+  // The number of runtime checks should be less than
+  // RuntimeMemoryCheckThreshold.
+  if (LAI->getNumRuntimePointerChecks() >
+      VectorizerParams::RuntimeMemoryCheckThreshold) {
+    DEBUG(dbgs() << "    LAA: Runtime checks are more than threshold !!\n");
+    return false;
+  }
+  // The loop should have at least one invariant load or store instruction.
+  if (!InvariantCounter) {
+    DEBUG(dbgs() << "    Invariant not found !!\n");
+    return false;
+  }
+  // Read-only loops are not allowed.
+  if (IsReadOnlyLoop) {
+    DEBUG(dbgs() << "    Found a read-only loop!\n");
+    return false;
+  }
+  // Profitability check:
+  // the ratio of invariant accesses must meet the threshold.
+  if (InvariantCounter * 100 < InvariantThreshold * LoadAndStoreCounter) {
+    DEBUG(dbgs()
+          << "    Invariant loads & stores are less than the threshold\n");
+    DEBUG(dbgs() << "    Invariant loads & stores: "
+                 << ((InvariantCounter * 100) / LoadAndStoreCounter) << "%\n");
+    DEBUG(dbgs() << "    Invariant loads & store threshold: "
+                 << InvariantThreshold << "%\n");
+    return false;
+  }
+  return true;
+}
+
+/// \brief Check whether the loop has already been visited.
+/// Checks the loop metadata; returns true if the loop was visited
+/// before, false otherwise.
+bool LoopVersioningLICM::isLoopAlreadyVisited() {
+  // Check for LoopVersioningLICM metadata on the loop.
+  if (findStringMetadataForLoop(CurLoop, LICMVersioningMetaData)) {
+    return true;
+  }
+  return false;
+}
+
+/// \brief Checks legality for LoopVersioningLICM by considering the following:
+/// a) loop structure legality b) loop instruction legality
+/// c) loop memory access legality.
+/// Returns true if legal, false otherwise.
+bool LoopVersioningLICM::isLegalForVersioning() {
+  DEBUG(dbgs() << "Loop: " << *CurLoop);
+  // Make sure we are not revisiting the same loop.
+  if (isLoopAlreadyVisited()) {
+    DEBUG(
+        dbgs() << "    Revisiting loop in LoopVersioningLICM not allowed.\n\n");
+    return false;
+  }
+  // Check loop structure legality.
+  if (!legalLoopStructure()) {
+    DEBUG(
+        dbgs() << "    Loop structure not suitable for LoopVersioningLICM\n\n");
+    return false;
+  }
+  // Check loop instruction legality.
+  if (!legalLoopInstructions()) {
+    DEBUG(dbgs()
+          << "    Loop instructions not suitable for LoopVersioningLICM\n\n");
+    return false;
+  }
+  // Check loop memory access legality.
+  if (!legalLoopMemoryAccesses()) {
+    DEBUG(dbgs()
+          << "    Loop memory access not suitable for LoopVersioningLICM\n\n");
+    return false;
+  }
+  // Loop versioning is feasible, return true.
+  DEBUG(dbgs() << "    Loop Versioning found to be beneficial\n\n");
+  return true;
+}
+
+/// \brief Update the loop with aggressive aliasing assumptions.
+/// It marks memory operations as no-alias, assuming the loop has no
+/// must-alias memory access pairs. During the LoopVersioningLICM
+/// legality checks we reject loops that have must-aliasing memory
+/// accesses.
+void LoopVersioningLICM::setNoAliasToLoop(Loop *VerLoop) {
+  // Get the latch terminator instruction.
+  Instruction *I = VerLoop->getLoopLatch()->getTerminator();
+  // Create an alias scope domain.
+  MDBuilder MDB(I->getContext());
+  MDNode *NewDomain = MDB.createAnonymousAliasScopeDomain("LVDomain");
+  StringRef Name = "LVAliasScope";
+  SmallVector<Metadata *, 4> Scopes, NoAliases;
+  MDNode *NewScope = MDB.createAnonymousAliasScope(NewDomain, Name);
+  // Iterate over each instruction of the loop and
+  // set no-alias metadata for all load & store instructions.
+  for (auto *Block : CurLoop->getBlocks()) {
+    for (auto &Inst : *Block) {
+      // Only interested in instructions that may modify or read memory.
+      if (!Inst.mayReadFromMemory() && !Inst.mayWriteToMemory())
+        continue;
+      Scopes.push_back(NewScope);
+      NoAliases.push_back(NewScope);
+      // Set no-alias metadata for the current instruction.
+      Inst.setMetadata(
+          LLVMContext::MD_noalias,
+          MDNode::concatenate(Inst.getMetadata(LLVMContext::MD_noalias),
+                              MDNode::get(Inst.getContext(), NoAliases)));
+      // Set alias-scope metadata for the current instruction.
+      Inst.setMetadata(
+          LLVMContext::MD_alias_scope,
+          MDNode::concatenate(Inst.getMetadata(LLVMContext::MD_alias_scope),
+                              MDNode::get(Inst.getContext(), Scopes)));
+    }
+  }
+}
+
+bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) {
+  if (skipLoop(L))
+    return false;
+  Changed = false;
+  // Get analysis information.
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+  LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
+  LAI = nullptr;
+  // Set the current loop.
+  CurLoop = L;
+  // Get the preheader block.
+  Preheader = L->getLoopPreheader();
+  // Initial allocation.
+  CurAST = new AliasSetTracker(*AA);
+
+  // Loop over the body of this loop and construct the AST.
+  for (auto *Block : L->getBlocks()) {
+    if (LI->getLoopFor(Block) == L) // Ignore blocks in subloops.
+      CurAST->add(*Block);          // Incorporate the specified basic block.
+  }
+  // Check the feasibility of LoopVersioningLICM.
+  // If versioning is found to be feasible and beneficial then proceed;
+  // otherwise simply return after cleaning up memory.
+  if (isLegalForVersioning()) {
+    // Do loop versioning.
+    // Create a memcheck for memory accessed inside the loop.
+    // Clone the original loop, and set up its blocks properly.
+    LoopVersioning LVer(*LAI, CurLoop, LI, DT, SE, true);
+    LVer.versionLoop();
+    // Set loop versioning metadata on the original loop.
+    addStringMetadataToLoop(LVer.getNonVersionedLoop(), LICMVersioningMetaData);
+    // Set loop versioning metadata on the versioned loop.
+    addStringMetadataToLoop(LVer.getVersionedLoop(), LICMVersioningMetaData);
+    // Set "llvm.mem.parallel_loop_access" metadata on the versioned loop.
+    addStringMetadataToLoop(LVer.getVersionedLoop(),
+                            "llvm.mem.parallel_loop_access");
+    // Update the versioned loop with aggressive aliasing assumptions.
+    setNoAliasToLoop(LVer.getVersionedLoop());
+    Changed = true;
+  }
+  // Delete allocated memory.
+ delete CurAST; + return Changed; +} + +char LoopVersioningLICM::ID = 0; +INITIALIZE_PASS_BEGIN(LoopVersioningLICM, "loop-versioning-licm", + "Loop Versioning For LICM", false, false) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(LoopVersioningLICM, "loop-versioning-licm", + "Loop Versioning For LICM", false, false) + +Pass *llvm::createLoopVersioningLICMPass() { return new LoopVersioningLICM(); } diff --git a/lib/Transforms/Scalar/LowerAtomic.cpp b/lib/Transforms/Scalar/LowerAtomic.cpp index 41511bcb7b04..08e60b16bedf 100644 --- a/lib/Transforms/Scalar/LowerAtomic.cpp +++ b/lib/Transforms/Scalar/LowerAtomic.cpp @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LowerAtomic.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" using namespace llvm; #define DEBUG_TYPE "loweratomic" @@ -100,49 +101,74 @@ static bool LowerFenceInst(FenceInst *FI) { } static bool LowerLoadInst(LoadInst *LI) { - LI->setAtomic(NotAtomic); + LI->setAtomic(AtomicOrdering::NotAtomic); return true; } static bool LowerStoreInst(StoreInst *SI) { - SI->setAtomic(NotAtomic); + SI->setAtomic(AtomicOrdering::NotAtomic); return true; } -namespace { - struct LowerAtomic : public BasicBlockPass { - static char ID; - LowerAtomic() : BasicBlockPass(ID) { - initializeLowerAtomicPass(*PassRegistry::getPassRegistry()); - } - bool runOnBasicBlock(BasicBlock &BB) override { - if (skipOptnoneFunction(BB)) - return false; - bool Changed = false; - for (BasicBlock::iterator DI = BB.begin(), DE = BB.end(); DI != DE; ) { - Instruction *Inst = &*DI++; - if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) - Changed |= LowerFenceInst(FI); - else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(Inst)) - Changed |= LowerAtomicCmpXchgInst(CXI); - else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(Inst)) - Changed |= LowerAtomicRMWInst(RMWI); - else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { - if (LI->isAtomic()) - LowerLoadInst(LI); - } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - if (SI->isAtomic()) - LowerStoreInst(SI); - } - } - return Changed; +static bool runOnBasicBlock(BasicBlock &BB) { + bool Changed = false; + for (BasicBlock::iterator DI = BB.begin(), DE = BB.end(); DI != DE;) { + Instruction *Inst = &*DI++; + if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) + Changed |= LowerFenceInst(FI); + else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(Inst)) + Changed |= LowerAtomicCmpXchgInst(CXI); + else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(Inst)) + Changed |= LowerAtomicRMWInst(RMWI); + else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + if (LI->isAtomic()) + LowerLoadInst(LI); + } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + if (SI->isAtomic()) + LowerStoreInst(SI); } + } + return Changed; +} + +static bool lowerAtomics(Function &F) { + bool Changed = false; + for (BasicBlock &BB : F) { + Changed |= runOnBasicBlock(BB); + } + 
return Changed; +} + +PreservedAnalyses LowerAtomicPass::run(Function &F, FunctionAnalysisManager &) { + if (lowerAtomics(F)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} + +namespace { +class LowerAtomicLegacyPass : public FunctionPass { +public: + static char ID; + + LowerAtomicLegacyPass() : FunctionPass(ID) { + initializeLowerAtomicLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + FunctionAnalysisManager DummyFAM; + auto PA = Impl.run(F, DummyFAM); + return !PA.areAllPreserved(); + } + +private: + LowerAtomicPass Impl; }; } -char LowerAtomic::ID = 0; -INITIALIZE_PASS(LowerAtomic, "loweratomic", - "Lower atomic intrinsics to non-atomic form", - false, false) +char LowerAtomicLegacyPass::ID = 0; +INITIALIZE_PASS(LowerAtomicLegacyPass, "loweratomic", + "Lower atomic intrinsics to non-atomic form", false, false) -Pass *llvm::createLowerAtomicPass() { return new LowerAtomic(); } +Pass *llvm::createLowerAtomicPass() { return new LowerAtomicLegacyPass(); } diff --git a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 2ace902a7a1b..79f0db1163a4 100644 --- a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -34,12 +34,24 @@ using namespace llvm; STATISTIC(ExpectIntrinsicsHandled, "Number of 'expect' intrinsic instructions handled"); -static cl::opt<uint32_t> -LikelyBranchWeight("likely-branch-weight", cl::Hidden, cl::init(64), - cl::desc("Weight of the branch likely to be taken (default = 64)")); -static cl::opt<uint32_t> -UnlikelyBranchWeight("unlikely-branch-weight", cl::Hidden, cl::init(4), - cl::desc("Weight of the branch unlikely to be taken (default = 4)")); +// These default values are chosen to represent an extremely skewed outcome for +// a condition, but they leave some room for interpretation by later passes. +// +// If the documentation for __builtin_expect() was made explicit that it should +// only be used in extreme cases, we could make this ratio higher. As it stands, +// programmers may be using __builtin_expect() / llvm.expect to annotate that a +// branch is likely or unlikely to be taken. +// +// There is a known dependency on this ratio in CodeGenPrepare when transforming +// 'select' instructions. It may be worthwhile to hoist these values to some +// shared space, so they can be used directly by other passes. 
+ +static cl::opt<uint32_t> LikelyBranchWeight( + "likely-branch-weight", cl::Hidden, cl::init(2000), + cl::desc("Weight of the branch likely to be taken (default = 2000)")); +static cl::opt<uint32_t> UnlikelyBranchWeight( + "unlikely-branch-weight", cl::Hidden, cl::init(1), + cl::desc("Weight of the branch unlikely to be taken (default = 1)")); static bool handleSwitchExpect(SwitchInst &SI) { CallInst *CI = dyn_cast<CallInst>(SI.getCondition()); @@ -158,7 +170,8 @@ static bool lowerExpectIntrinsic(Function &F) { return Changed; } -PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F) { +PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F, + FunctionAnalysisManager &) { if (lowerExpectIntrinsic(F)) return PreservedAnalyses::none(); diff --git a/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp new file mode 100644 index 000000000000..57491007d014 --- /dev/null +++ b/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp @@ -0,0 +1,123 @@ +//===- LowerGuardIntrinsic.cpp - Lower the guard intrinsic ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass lowers the llvm.experimental.guard intrinsic to a conditional call +// to @llvm.experimental.deoptimize. Once this happens, the guard can no longer +// be widened. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +static cl::opt<uint32_t> PredicatePassBranchWeight( + "guards-predicate-pass-branch-weight", cl::Hidden, cl::init(1 << 20), + cl::desc("The probability of a guard failing is assumed to be the " + "reciprocal of this value (default = 1 << 20)")); + +namespace { +struct LowerGuardIntrinsic : public FunctionPass { + static char ID; + LowerGuardIntrinsic() : FunctionPass(ID) { + initializeLowerGuardIntrinsicPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; +}; +} + +static void MakeGuardControlFlowExplicit(Function *DeoptIntrinsic, + CallInst *CI) { + OperandBundleDef DeoptOB(*CI->getOperandBundle(LLVMContext::OB_deopt)); + SmallVector<Value *, 4> Args(std::next(CI->arg_begin()), CI->arg_end()); + + auto *CheckBB = CI->getParent(); + auto *DeoptBlockTerm = + SplitBlockAndInsertIfThen(CI->getArgOperand(0), CI, true); + + auto *CheckBI = cast<BranchInst>(CheckBB->getTerminator()); + + // SplitBlockAndInsertIfThen inserts control flow that branches to + // DeoptBlockTerm if the condition is true. We want the opposite. 
+ CheckBI->swapSuccessors(); + + CheckBI->getSuccessor(0)->setName("guarded"); + CheckBI->getSuccessor(1)->setName("deopt"); + + if (auto *MD = CI->getMetadata(LLVMContext::MD_make_implicit)) + CheckBI->setMetadata(LLVMContext::MD_make_implicit, MD); + + MDBuilder MDB(CI->getContext()); + CheckBI->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(PredicatePassBranchWeight, 1)); + + IRBuilder<> B(DeoptBlockTerm); + auto *DeoptCall = B.CreateCall(DeoptIntrinsic, Args, {DeoptOB}, ""); + + if (DeoptIntrinsic->getReturnType()->isVoidTy()) { + B.CreateRetVoid(); + } else { + DeoptCall->setName("deoptcall"); + B.CreateRet(DeoptCall); + } + + DeoptCall->setCallingConv(CI->getCallingConv()); + DeoptBlockTerm->eraseFromParent(); +} + +bool LowerGuardIntrinsic::runOnFunction(Function &F) { + // Check if we can cheaply rule out the possibility of not having any work to + // do. + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + if (!GuardDecl || GuardDecl->use_empty()) + return false; + + SmallVector<CallInst *, 8> ToLower; + for (auto &I : instructions(F)) + if (auto *CI = dyn_cast<CallInst>(&I)) + if (auto *F = CI->getCalledFunction()) + if (F->getIntrinsicID() == Intrinsic::experimental_guard) + ToLower.push_back(CI); + + if (ToLower.empty()) + return false; + + auto *DeoptIntrinsic = Intrinsic::getDeclaration( + F.getParent(), Intrinsic::experimental_deoptimize, {F.getReturnType()}); + DeoptIntrinsic->setCallingConv(GuardDecl->getCallingConv()); + + for (auto *CI : ToLower) { + MakeGuardControlFlowExplicit(DeoptIntrinsic, CI); + CI->eraseFromParent(); + } + + return true; +} + +char LowerGuardIntrinsic::ID = 0; +INITIALIZE_PASS(LowerGuardIntrinsic, "lower-guard-intrinsic", + "Lower the guard intrinsic to normal control flow", false, + false) + +Pass *llvm::createLowerGuardIntrinsicPass() { + return new LowerGuardIntrinsic(); +} diff --git a/lib/Transforms/Scalar/Makefile b/lib/Transforms/Scalar/Makefile deleted file mode 100644 index cc42fd00ac7d..000000000000 --- a/lib/Transforms/Scalar/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Transforms/Scalar/Makefile ----------------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. 
-LIBRARYNAME = LLVMScalarOpts -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common - diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 6b43b0f7a2ad..d64c658f8436 100644 --- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -12,22 +12,16 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/MemCpyOptimizer.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/MemoryDependenceAnalysis.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" @@ -184,7 +178,7 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { // size. If so, check to see whether we will end up actually reducing the // number of stores used. unsigned Bytes = unsigned(End-Start); - unsigned MaxIntSize = DL.getLargestLegalIntTypeSize(); + unsigned MaxIntSize = DL.getLargestLegalIntTypeSizeInBits() / 8; if (MaxIntSize == 0) MaxIntSize = 1; unsigned NumPointerStores = Bytes / MaxIntSize; @@ -301,19 +295,16 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr, } //===----------------------------------------------------------------------===// -// MemCpyOpt Pass +// MemCpyOptLegacyPass Pass //===----------------------------------------------------------------------===// namespace { - class MemCpyOpt : public FunctionPass { - MemoryDependenceAnalysis *MD; - TargetLibraryInfo *TLI; + class MemCpyOptLegacyPass : public FunctionPass { + MemCpyOptPass Impl; public: static char ID; // Pass identification, replacement for typeid - MemCpyOpt() : FunctionPass(ID) { - initializeMemCpyOptPass(*PassRegistry::getPassRegistry()); - MD = nullptr; - TLI = nullptr; + MemCpyOptLegacyPass() : FunctionPass(ID) { + initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; @@ -324,11 +315,11 @@ namespace { AU.setPreservesCFG(); AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<MemoryDependenceAnalysis>(); + AU.addRequired<MemoryDependenceWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<MemoryDependenceAnalysis>(); + AU.addPreserved<MemoryDependenceWrapperPass>(); } // Helper functions @@ -348,29 +339,30 @@ namespace { bool iterateOnFunction(Function &F); }; - char MemCpyOpt::ID = 0; + char MemCpyOptLegacyPass::ID = 0; } /// The public interface to this file... 
-FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOpt(); }
+FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOptLegacyPass(); }
 
-INITIALIZE_PASS_BEGIN(MemCpyOpt, "memcpyopt", "MemCpy Optimization",
+INITIALIZE_PASS_BEGIN(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization",
                       false, false)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis)
+INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
-INITIALIZE_PASS_END(MemCpyOpt, "memcpyopt", "MemCpy Optimization",
+INITIALIZE_PASS_END(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization",
                     false, false)
 
 /// When scanning forward over instructions, we look for some other patterns to
 /// fold away. In particular, this looks for stores to neighboring locations of
 /// memory. If it sees enough consecutive ones, it attempts to merge them
 /// together into a memcpy/memset.
-Instruction *MemCpyOpt::tryMergingIntoMemset(Instruction *StartInst,
-                                             Value *StartPtr, Value *ByteVal) {
+Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
+                                                 Value *StartPtr,
+                                                 Value *ByteVal) {
   const DataLayout &DL = StartInst->getModule()->getDataLayout();
 
   // Okay, so we now have a single store that can be splatable.  Scan to find
@@ -493,7 +485,93 @@ static unsigned findCommonAlignment(const DataLayout &DL, const StoreInst *SI,
   return std::min(StoreAlign, LoadAlign);
 }
 
-bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
+// This method tries to lift a store instruction before position P.
+// It will lift the store and its argument, plus anything that
+// may alias with them.
+// The method returns true if it was successful.
+static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P) {
+  // If the store aliases this position, bail out early.
+  MemoryLocation StoreLoc = MemoryLocation::get(SI);
+  if (AA.getModRefInfo(P, StoreLoc) != MRI_NoModRef)
+    return false;
+
+  // Keep track of the arguments of all instructions we plan to lift
+  // so we can make sure to lift them as well if appropriate.
+  DenseSet<Instruction*> Args;
+  if (auto *Ptr = dyn_cast<Instruction>(SI->getPointerOperand()))
+    if (Ptr->getParent() == SI->getParent())
+      Args.insert(Ptr);
+
+  // Instructions to lift before P.
+  SmallVector<Instruction*, 8> ToLift;
+
+  // Memory locations of lifted instructions.
+  SmallVector<MemoryLocation, 8> MemLocs;
+  MemLocs.push_back(StoreLoc);
+
+  // Lifted call sites.
+  SmallVector<ImmutableCallSite, 8> CallSites;
+
+  for (auto I = --SI->getIterator(), E = P->getIterator(); I != E; --I) {
+    auto *C = &*I;
+
+    bool MayAlias = AA.getModRefInfo(C) != MRI_NoModRef;
+
+    bool NeedLift = false;
+    if (Args.erase(C))
+      NeedLift = true;
+    else if (MayAlias) {
+      NeedLift = std::any_of(MemLocs.begin(), MemLocs.end(),
+                             [C, &AA](const MemoryLocation &ML) {
+                               return AA.getModRefInfo(C, ML);
+                             });
+
+      if (!NeedLift)
+        NeedLift = std::any_of(CallSites.begin(), CallSites.end(),
+                               [C, &AA](const ImmutableCallSite &CS) {
+                                 return AA.getModRefInfo(C, CS);
+                               });
+    }
+
+    if (!NeedLift)
+      continue;
+
+    if (MayAlias) {
+      if (auto CS = ImmutableCallSite(C)) {
+        // If we can't lift this before P, it's game over.
+ if (AA.getModRefInfo(P, CS) != MRI_NoModRef) + return false; + + CallSites.push_back(CS); + } else if (isa<LoadInst>(C) || isa<StoreInst>(C) || isa<VAArgInst>(C)) { + // If we can't lift this before P, it's game over. + auto ML = MemoryLocation::get(C); + if (AA.getModRefInfo(P, ML) != MRI_NoModRef) + return false; + + MemLocs.push_back(ML); + } else + // We don't know how to lift this instruction. + return false; + } + + ToLift.push_back(C); + for (unsigned k = 0, e = C->getNumOperands(); k != e; ++k) + if (auto *A = dyn_cast<Instruction>(C->getOperand(k))) + if (A->getParent() == SI->getParent()) + Args.insert(A); + } + + // We made it, we need to lift + for (auto *I : reverse(ToLift)) { + DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n"); + I->moveBefore(P); + } + + return true; +} + +bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (!SI->isSimple()) return false; // Avoid merging nontemporal stores since the resulting @@ -514,7 +592,7 @@ bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { auto *T = LI->getType(); if (T->isAggregateType()) { - AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + AliasAnalysis &AA = LookupAliasAnalysis(); MemoryLocation LoadLoc = MemoryLocation::get(LI); // We use alias analysis to check if an instruction may store to @@ -522,26 +600,20 @@ bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // such an instruction is found, we try to promote there instead // of at the store position. Instruction *P = SI; - for (BasicBlock::iterator I = ++LI->getIterator(), E = SI->getIterator(); - I != E; ++I) { - if (!(AA.getModRefInfo(&*I, LoadLoc) & MRI_Mod)) - continue; - - // We found an instruction that may write to the loaded memory. - // We can try to promote at this position instead of the store - // position if nothing alias the store memory after this and the store - // destination is not in the range. - P = &*I; - for (; I != E; ++I) { - MemoryLocation StoreLoc = MemoryLocation::get(SI); - if (&*I == SI->getOperand(1) || - AA.getModRefInfo(&*I, StoreLoc) != MRI_NoModRef) { - P = nullptr; - break; - } + for (auto &I : make_range(++LI->getIterator(), SI->getIterator())) { + if (AA.getModRefInfo(&I, LoadLoc) & MRI_Mod) { + P = &I; + break; } + } - break; + // We found an instruction that may write to the loaded memory. + // We can try to promote at this position instead of the store + // position if nothing alias the store memory after this and the store + // destination is not in the range. + if (P && P != SI) { + if (!moveUp(AA, SI, P)) + P = nullptr; } // If a valid insertion position is found, then we can promote @@ -594,7 +666,9 @@ bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (C) { // Check that nothing touches the dest of the "copy" between // the call and the store. - AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + Value *CpyDest = SI->getPointerOperand()->stripPointerCasts(); + bool CpyDestIsLocal = isa<AllocaInst>(CpyDest); + AliasAnalysis &AA = LookupAliasAnalysis(); MemoryLocation StoreLoc = MemoryLocation::get(SI); for (BasicBlock::iterator I = --SI->getIterator(), E = C->getIterator(); I != E; --I) { @@ -602,6 +676,12 @@ bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { C = nullptr; break; } + // The store to dest may never happen if an exception can be thrown + // between the load and the store. 
+ if (I->mayThrow() && !CpyDestIsLocal) { + C = nullptr; + break; + } } } @@ -665,7 +745,7 @@ bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { return false; } -bool MemCpyOpt::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { +bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { // See if there is another memset or store neighboring this memset which // allows us to widen out the memset to do a single larger store. if (isa<ConstantInt>(MSI->getLength()) && !MSI->isVolatile()) @@ -681,10 +761,9 @@ bool MemCpyOpt::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { /// Takes a memcpy and a call that it depends on, /// and checks for the possibility of a call slot optimization by having /// the call write its result directly into the destination of the memcpy. -bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy, - Value *cpyDest, Value *cpySrc, - uint64_t cpyLen, unsigned cpyAlign, - CallInst *C) { +bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, + Value *cpySrc, uint64_t cpyLen, + unsigned cpyAlign, CallInst *C) { // The general transformation to keep in mind is // // call @func(..., src, ...) @@ -699,6 +778,11 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy, // src only holds uninitialized values at the moment of the call, meaning that // the memcpy can be discarded rather than moved. + // Lifetime marks shouldn't be operated on. + if (Function *F = C->getCalledFunction()) + if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start) + return false; + // Deliberately get the source and destination with bitcasts stripped away, // because we'll need to do type comparisons based on the underlying type. CallSite CS(C); @@ -734,6 +818,10 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy, if (destSize < srcSize) return false; } else if (Argument *A = dyn_cast<Argument>(cpyDest)) { + // The store to dest may never happen if the call can throw. + if (C->mayThrow()) + return false; + if (A->getDereferenceableBytes() < srcSize) { // If the destination is an sret parameter then only accesses that are // outside of the returned struct type can trap. @@ -805,7 +893,7 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy, // Since we're changing the parameter to the callsite, we need to make sure // that what would be the new parameter dominates the callsite. - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + DominatorTree &DT = LookupDomTree(); if (Instruction *cpyDestInst = dyn_cast<Instruction>(cpyDest)) if (!DT.dominates(cpyDestInst, C)) return false; @@ -814,7 +902,7 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy, // unexpected manner, for example via a global, which we deduce from // the use analysis, we also need to know that it does not sneakily // access dest. We rely on AA to figure this out for us. - AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + AliasAnalysis &AA = LookupAliasAnalysis(); ModRefInfo MR = AA.getModRefInfo(C, cpyDest, srcSize); // If necessary, perform additional analysis. if (MR != MRI_NoModRef) @@ -867,7 +955,8 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy, /// We've found that the (upward scanning) memory dependence of memcpy 'M' is /// the memcpy 'MDep'. Try to simplify M to copy from MDep's input if we can. 
-bool MemCpyOpt::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep) { +bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, + MemCpyInst *MDep) { // We can only transforms memcpy's where the dest of one is the source of the // other. if (M->getSource() != MDep->getDest() || MDep->isVolatile()) @@ -888,7 +977,7 @@ bool MemCpyOpt::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep) { if (!MDepLen || !MLen || MDepLen->getZExtValue() < MLen->getZExtValue()) return false; - AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + AliasAnalysis &AA = LookupAliasAnalysis(); // Verify that the copied-from memory doesn't change in between the two // transfers. For example, in: @@ -954,8 +1043,8 @@ bool MemCpyOpt::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep) { /// memcpy(dst, src, src_size); /// memset(dst + src_size, c, dst_size <= src_size ? 0 : dst_size - src_size); /// \endcode -bool MemCpyOpt::processMemSetMemCpyDependence(MemCpyInst *MemCpy, - MemSetInst *MemSet) { +bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, + MemSetInst *MemSet) { // We can only transform memset/memcpy with the same destination. if (MemSet->getDest() != MemCpy->getDest()) return false; @@ -1019,8 +1108,8 @@ bool MemCpyOpt::processMemSetMemCpyDependence(MemCpyInst *MemCpy, /// When dst2_size <= dst1_size. /// /// The \p MemCpy must have a Constant length. -bool MemCpyOpt::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, - MemSetInst *MemSet) { +bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, + MemSetInst *MemSet) { // This only makes sense on memcpy(..., memset(...), ...). if (MemSet->getRawDest() != MemCpy->getRawSource()) return false; @@ -1043,7 +1132,7 @@ bool MemCpyOpt::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, /// B to be a memcpy from X to Z (or potentially a memmove, depending on /// circumstances). This allows later passes to remove the first memcpy /// altogether. -bool MemCpyOpt::processMemCpy(MemCpyInst *M) { +bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { // We can only optimize non-volatile memcpy's. if (M->isVolatile()) return false; @@ -1141,8 +1230,8 @@ bool MemCpyOpt::processMemCpy(MemCpyInst *M) { /// Transforms memmove calls to memcpy calls when the src/dst are guaranteed /// not to alias. -bool MemCpyOpt::processMemMove(MemMoveInst *M) { - AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); +bool MemCpyOptPass::processMemMove(MemMoveInst *M) { + AliasAnalysis &AA = LookupAliasAnalysis(); if (!TLI->has(LibFunc::memmove)) return false; @@ -1152,7 +1241,8 @@ bool MemCpyOpt::processMemMove(MemMoveInst *M) { MemoryLocation::getForSource(M))) return false; - DEBUG(dbgs() << "MemCpyOpt: Optimizing memmove -> memcpy: " << *M << "\n"); + DEBUG(dbgs() << "MemCpyOptPass: Optimizing memmove -> memcpy: " << *M + << "\n"); // If not, then we know we can transform this. Type *ArgTys[3] = { M->getRawDest()->getType(), @@ -1170,7 +1260,7 @@ bool MemCpyOpt::processMemMove(MemMoveInst *M) { } /// This is called on every byval argument in call sites. -bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) { +bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { const DataLayout &DL = CS.getCaller()->getParent()->getDataLayout(); // Find out what feeds this byval argument. 
Value *ByValArg = CS.getArgument(ArgNo); @@ -1202,10 +1292,8 @@ bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) { // If it is greater than the memcpy, then we check to see if we can force the // source of the memcpy to the alignment we need. If we fail, we bail out. - AssumptionCache &AC = - getAnalysis<AssumptionCacheTracker>().getAssumptionCache( - *CS->getParent()->getParent()); - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AssumptionCache &AC = LookupAssumptionCache(); + DominatorTree &DT = LookupDomTree(); if (MDep->getAlignment() < ByValAlign && getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, CS.getInstruction(), &AC, &DT) < ByValAlign) @@ -1231,7 +1319,7 @@ bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) { TmpCast = new BitCastInst(MDep->getSource(), ByValArg->getType(), "tmpcast", CS.getInstruction()); - DEBUG(dbgs() << "MemCpyOpt: Forwarding memcpy to byval:\n" + DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n" << " " << *MDep << "\n" << " " << *CS.getInstruction() << "\n"); @@ -1241,13 +1329,13 @@ bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) { return true; } -/// Executes one iteration of MemCpyOpt. -bool MemCpyOpt::iterateOnFunction(Function &F) { +/// Executes one iteration of MemCpyOptPass. +bool MemCpyOptPass::iterateOnFunction(Function &F) { bool MadeChange = false; // Walk all instruction in the function. - for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) { - for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { + for (BasicBlock &BB : F) { + for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) { // Avoid invalidating the iterator. Instruction *I = &*BI++; @@ -1269,7 +1357,8 @@ bool MemCpyOpt::iterateOnFunction(Function &F) { // Reprocess the instruction if desired. if (RepeatInstruction) { - if (BI != BB->begin()) --BI; + if (BI != BB.begin()) + --BI; MadeChange = true; } } @@ -1278,14 +1367,42 @@ bool MemCpyOpt::iterateOnFunction(Function &F) { return MadeChange; } -/// This is the main transformation entry point for a function. 
-bool MemCpyOpt::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) - return false; +PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) { + + auto &MD = AM.getResult<MemoryDependenceAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + + auto LookupAliasAnalysis = [&]() -> AliasAnalysis & { + return AM.getResult<AAManager>(F); + }; + auto LookupAssumptionCache = [&]() -> AssumptionCache & { + return AM.getResult<AssumptionAnalysis>(F); + }; + auto LookupDomTree = [&]() -> DominatorTree & { + return AM.getResult<DominatorTreeAnalysis>(F); + }; + + bool MadeChange = runImpl(F, &MD, &TLI, LookupAliasAnalysis, + LookupAssumptionCache, LookupDomTree); + if (!MadeChange) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + PA.preserve<MemoryDependenceAnalysis>(); + return PA; +} +bool MemCpyOptPass::runImpl( + Function &F, MemoryDependenceResults *MD_, TargetLibraryInfo *TLI_, + std::function<AliasAnalysis &()> LookupAliasAnalysis_, + std::function<AssumptionCache &()> LookupAssumptionCache_, + std::function<DominatorTree &()> LookupDomTree_) { bool MadeChange = false; - MD = &getAnalysis<MemoryDependenceAnalysis>(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + MD = MD_; + TLI = TLI_; + LookupAliasAnalysis = std::move(LookupAliasAnalysis_); + LookupAssumptionCache = std::move(LookupAssumptionCache_); + LookupDomTree = std::move(LookupDomTree_); // If we don't have at least memset and memcpy, there is little point of doing // anything here. These are required by a freestanding implementation, so if @@ -1302,3 +1419,25 @@ bool MemCpyOpt::runOnFunction(Function &F) { MD = nullptr; return MadeChange; } + +/// This is the main transformation entry point for a function. 
+bool MemCpyOptLegacyPass::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + auto *MD = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + + auto LookupAliasAnalysis = [this]() -> AliasAnalysis & { + return getAnalysis<AAResultsWrapperPass>().getAAResults(); + }; + auto LookupAssumptionCache = [this, &F]() -> AssumptionCache & { + return getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + }; + auto LookupDomTree = [this]() -> DominatorTree & { + return getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + }; + + return Impl.runImpl(F, MD, TLI, LookupAliasAnalysis, LookupAssumptionCache, + LookupDomTree); +} diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index c812d618c16a..30261b755001 100644 --- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -72,9 +72,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/CFG.h" @@ -82,51 +80,37 @@ #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" -#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Support/Allocator.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" -#include <vector> using namespace llvm; #define DEBUG_TYPE "mldst-motion" +namespace { //===----------------------------------------------------------------------===// // MergedLoadStoreMotion Pass //===----------------------------------------------------------------------===// +class MergedLoadStoreMotion { + MemoryDependenceResults *MD = nullptr; + AliasAnalysis *AA = nullptr; -namespace { -class MergedLoadStoreMotion : public FunctionPass { - AliasAnalysis *AA; - MemoryDependenceAnalysis *MD; + // The mergeLoad/Store algorithms could have Size0 * Size1 complexity, + // where Size0 and Size1 are the #instructions on the two sides of + // the diamond. The constant chosen here is arbitrary. Compiler Time + // Control is enforced by the check Size0 * Size1 < MagicCompileTimeControl. 
+ const int MagicCompileTimeControl = 250; public: - static char ID; // Pass identification, replacement for typeid - MergedLoadStoreMotion() - : FunctionPass(ID), MD(nullptr), MagicCompileTimeControl(250) { - initializeMergedLoadStoreMotionPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; + bool run(Function &F, MemoryDependenceResults *MD, AliasAnalysis &AA); private: - // This transformation requires dominator postdominator info - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<MemoryDependenceAnalysis>(); - } - - // Helper routines - /// /// \brief Remove instruction from parent and update memory dependence /// analysis. @@ -135,9 +119,9 @@ private: BasicBlock *getDiamondTail(BasicBlock *BB); bool isDiamondHead(BasicBlock *BB); // Routines for hoisting loads - bool isLoadHoistBarrierInRange(const Instruction& Start, - const Instruction& End, - LoadInst* LI); + bool isLoadHoistBarrierInRange(const Instruction &Start, + const Instruction &End, LoadInst *LI, + bool SafeToLoadUnconditionally); LoadInst *canHoistFromBlock(BasicBlock *BB, LoadInst *LI); void hoistInstruction(BasicBlock *BB, Instruction *HoistCand, Instruction *ElseInst); @@ -151,31 +135,8 @@ private: const Instruction &End, MemoryLocation Loc); bool sinkStore(BasicBlock *BB, StoreInst *SinkCand, StoreInst *ElseInst); bool mergeStores(BasicBlock *BB); - // The mergeLoad/Store algorithms could have Size0 * Size1 complexity, - // where Size0 and Size1 are the #instructions on the two sides of - // the diamond. The constant chosen here is arbitrary. Compiler Time - // Control is enforced by the check Size0 * Size1 < MagicCompileTimeControl. - const int MagicCompileTimeControl; }; - -char MergedLoadStoreMotion::ID = 0; -} // anonymous namespace - -/// -/// \brief createMergedLoadStoreMotionPass - The public interface to this file. -/// -FunctionPass *llvm::createMergedLoadStoreMotionPass() { - return new MergedLoadStoreMotion(); -} - -INITIALIZE_PASS_BEGIN(MergedLoadStoreMotion, "mldst-motion", - "MergedLoadStoreMotion", false, false) -INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_END(MergedLoadStoreMotion, "mldst-motion", - "MergedLoadStoreMotion", false, false) +} // end anonymous namespace /// /// \brief Remove instruction from parent and update memory dependence analysis. @@ -184,9 +145,9 @@ void MergedLoadStoreMotion::removeInstruction(Instruction *Inst) { // Notify the memory dependence analysis. 
if (MD) { MD->removeInstruction(Inst); - if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) + if (auto *LI = dyn_cast<LoadInst>(Inst)) MD->invalidateCachedPointerInfo(LI->getPointerOperand()); - if (Inst->getType()->getScalarType()->isPointerTy()) { + if (Inst->getType()->isPtrOrPtrVectorTy()) { MD->invalidateCachedPointerInfo(Inst); } } @@ -198,10 +159,7 @@ void MergedLoadStoreMotion::removeInstruction(Instruction *Inst) { /// BasicBlock *MergedLoadStoreMotion::getDiamondTail(BasicBlock *BB) { assert(isDiamondHead(BB) && "Basic block is not head of a diamond"); - BranchInst *BI = (BranchInst *)(BB->getTerminator()); - BasicBlock *Succ0 = BI->getSuccessor(0); - BasicBlock *Tail = Succ0->getTerminator()->getSuccessor(0); - return Tail; + return BB->getTerminator()->getSuccessor(0)->getSingleSuccessor(); } /// @@ -210,25 +168,22 @@ BasicBlock *MergedLoadStoreMotion::getDiamondTail(BasicBlock *BB) { bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) { if (!BB) return false; - if (!isa<BranchInst>(BB->getTerminator())) - return false; - if (BB->getTerminator()->getNumSuccessors() != 2) + auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); + if (!BI || !BI->isConditional()) return false; - BranchInst *BI = (BranchInst *)(BB->getTerminator()); BasicBlock *Succ0 = BI->getSuccessor(0); BasicBlock *Succ1 = BI->getSuccessor(1); - if (!Succ0->getSinglePredecessor() || - Succ0->getTerminator()->getNumSuccessors() != 1) + if (!Succ0->getSinglePredecessor()) return false; - if (!Succ1->getSinglePredecessor() || - Succ1->getTerminator()->getNumSuccessors() != 1) + if (!Succ1->getSinglePredecessor()) return false; - BasicBlock *Tail = Succ0->getTerminator()->getSuccessor(0); + BasicBlock *Succ0Succ = Succ0->getSingleSuccessor(); + BasicBlock *Succ1Succ = Succ1->getSingleSuccessor(); // Ignore triangles. - if (Succ1->getTerminator()->getSuccessor(0) != Tail) + if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ) return false; return true; } @@ -240,9 +195,14 @@ bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) { /// being loaded or protect against the load from happening /// it is considered a hoist barrier. 
/// -bool MergedLoadStoreMotion::isLoadHoistBarrierInRange(const Instruction& Start, - const Instruction& End, - LoadInst* LI) { +bool MergedLoadStoreMotion::isLoadHoistBarrierInRange( + const Instruction &Start, const Instruction &End, LoadInst *LI, + bool SafeToLoadUnconditionally) { + if (!SafeToLoadUnconditionally) + for (const Instruction &Inst : + make_range(Start.getIterator(), End.getIterator())) + if (!isGuaranteedToTransferExecutionToSuccessor(&Inst)) + return true; MemoryLocation Loc = MemoryLocation::get(LI); return AA->canInstructionRangeModRef(Start, End, Loc, MRI_Mod); } @@ -256,23 +216,28 @@ bool MergedLoadStoreMotion::isLoadHoistBarrierInRange(const Instruction& Start, /// LoadInst *MergedLoadStoreMotion::canHoistFromBlock(BasicBlock *BB1, LoadInst *Load0) { - + BasicBlock *BB0 = Load0->getParent(); + BasicBlock *Head = BB0->getSinglePredecessor(); + bool SafeToLoadUnconditionally = isSafeToLoadUnconditionally( + Load0->getPointerOperand(), Load0->getAlignment(), + Load0->getModule()->getDataLayout(), + /*ScanFrom=*/Head->getTerminator()); for (BasicBlock::iterator BBI = BB1->begin(), BBE = BB1->end(); BBI != BBE; ++BBI) { Instruction *Inst = &*BBI; // Only merge and hoist loads when their result in used only in BB - if (!isa<LoadInst>(Inst) || Inst->isUsedOutsideOfBlock(BB1)) + auto *Load1 = dyn_cast<LoadInst>(Inst); + if (!Load1 || Inst->isUsedOutsideOfBlock(BB1)) continue; - LoadInst *Load1 = dyn_cast<LoadInst>(Inst); - BasicBlock *BB0 = Load0->getParent(); - MemoryLocation Loc0 = MemoryLocation::get(Load0); MemoryLocation Loc1 = MemoryLocation::get(Load1); - if (AA->isMustAlias(Loc0, Loc1) && Load0->isSameOperationAs(Load1) && - !isLoadHoistBarrierInRange(BB1->front(), *Load1, Load1) && - !isLoadHoistBarrierInRange(BB0->front(), *Load0, Load0)) { + if (Load0->isSameOperationAs(Load1) && AA->isMustAlias(Loc0, Loc1) && + !isLoadHoistBarrierInRange(BB1->front(), *Load1, Load1, + SafeToLoadUnconditionally) && + !isLoadHoistBarrierInRange(BB0->front(), *Load0, Load0, + SafeToLoadUnconditionally)) { return Load1; } } @@ -319,11 +284,10 @@ void MergedLoadStoreMotion::hoistInstruction(BasicBlock *BB, /// bool MergedLoadStoreMotion::isSafeToHoist(Instruction *I) const { BasicBlock *Parent = I->getParent(); - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - Instruction *Instr = dyn_cast<Instruction>(I->getOperand(i)); - if (Instr && Instr->getParent() == Parent) - return false; - } + for (Use &U : I->operands()) + if (auto *Instr = dyn_cast<Instruction>(&U)) + if (Instr->getParent() == Parent) + return false; return true; } @@ -333,8 +297,8 @@ bool MergedLoadStoreMotion::isSafeToHoist(Instruction *I) const { bool MergedLoadStoreMotion::hoistLoad(BasicBlock *BB, LoadInst *L0, LoadInst *L1) { // Only one definition? 
- Instruction *A0 = dyn_cast<Instruction>(L0->getPointerOperand()); - Instruction *A1 = dyn_cast<Instruction>(L1->getPointerOperand()); + auto *A0 = dyn_cast<Instruction>(L0->getPointerOperand()); + auto *A1 = dyn_cast<Instruction>(L1->getPointerOperand()); if (A0 && A1 && A0->isIdenticalTo(A1) && isSafeToHoist(A0) && A0->hasOneUse() && (A0->getParent() == L0->getParent()) && A1->hasOneUse() && (A1->getParent() == L1->getParent()) && @@ -345,8 +309,8 @@ bool MergedLoadStoreMotion::hoistLoad(BasicBlock *BB, LoadInst *L0, hoistInstruction(BB, A0, A1); hoistInstruction(BB, L0, L1); return true; - } else - return false; + } + return false; } /// @@ -358,7 +322,7 @@ bool MergedLoadStoreMotion::hoistLoad(BasicBlock *BB, LoadInst *L0, bool MergedLoadStoreMotion::mergeLoads(BasicBlock *BB) { bool MergedLoads = false; assert(isDiamondHead(BB)); - BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); + BranchInst *BI = cast<BranchInst>(BB->getTerminator()); BasicBlock *Succ0 = BI->getSuccessor(0); BasicBlock *Succ1 = BI->getSuccessor(1); // #Instructions in Succ1 for Compile Time Control @@ -369,8 +333,8 @@ bool MergedLoadStoreMotion::mergeLoads(BasicBlock *BB) { Instruction *I = &*BBI; ++BBI; - // Only move non-simple (atomic, volatile) loads. - LoadInst *L0 = dyn_cast<LoadInst>(I); + // Don't move non-simple (atomic, volatile) loads. + auto *L0 = dyn_cast<LoadInst>(I); if (!L0 || !L0->isSimple() || L0->isUsedOutsideOfBlock(Succ0)) continue; @@ -399,6 +363,10 @@ bool MergedLoadStoreMotion::mergeLoads(BasicBlock *BB) { bool MergedLoadStoreMotion::isStoreSinkBarrierInRange(const Instruction &Start, const Instruction &End, MemoryLocation Loc) { + for (const Instruction &Inst : + make_range(Start.getIterator(), End.getIterator())) + if (Inst.mayThrow()) + return true; return AA->canInstructionRangeModRef(Start, End, Loc, MRI_ModRef); } @@ -411,22 +379,16 @@ StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1, StoreInst *Store0) { DEBUG(dbgs() << "can Sink? : "; Store0->dump(); dbgs() << "\n"); BasicBlock *BB0 = Store0->getParent(); - for (BasicBlock::reverse_iterator RBI = BB1->rbegin(), RBE = BB1->rend(); - RBI != RBE; ++RBI) { - Instruction *Inst = &*RBI; - - if (!isa<StoreInst>(Inst)) - continue; - - StoreInst *Store1 = cast<StoreInst>(Inst); + for (Instruction &Inst : reverse(*BB1)) { + auto *Store1 = dyn_cast<StoreInst>(&Inst); + if (!Store1) + continue; MemoryLocation Loc0 = MemoryLocation::get(Store0); MemoryLocation Loc1 = MemoryLocation::get(Store1); if (AA->isMustAlias(Loc0, Loc1) && Store0->isSameOperationAs(Store1) && - !isStoreSinkBarrierInRange(*(std::next(BasicBlock::iterator(Store1))), - BB1->back(), Loc1) && - !isStoreSinkBarrierInRange(*(std::next(BasicBlock::iterator(Store0))), - BB0->back(), Loc0)) { + !isStoreSinkBarrierInRange(*Store1->getNextNode(), BB1->back(), Loc1) && + !isStoreSinkBarrierInRange(*Store0->getNextNode(), BB0->back(), Loc0)) { return Store1; } } @@ -439,17 +401,17 @@ StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1, PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1) { // Create a phi if the values mismatch. 
- PHINode *NewPN = nullptr; Value *Opd1 = S0->getValueOperand(); Value *Opd2 = S1->getValueOperand(); - if (Opd1 != Opd2) { - NewPN = PHINode::Create(Opd1->getType(), 2, Opd2->getName() + ".sink", - &BB->front()); - NewPN->addIncoming(Opd1, S0->getParent()); - NewPN->addIncoming(Opd2, S1->getParent()); - if (MD && NewPN->getType()->getScalarType()->isPointerTy()) - MD->invalidateCachedPointerInfo(NewPN); - } + if (Opd1 == Opd2) + return nullptr; + + auto *NewPN = PHINode::Create(Opd1->getType(), 2, Opd2->getName() + ".sink", + &BB->front()); + NewPN->addIncoming(Opd1, S0->getParent()); + NewPN->addIncoming(Opd2, S1->getParent()); + if (MD && NewPN->getType()->getScalarType()->isPointerTy()) + MD->invalidateCachedPointerInfo(NewPN); return NewPN; } @@ -461,8 +423,8 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0, StoreInst *S1) { // Only one definition? - Instruction *A0 = dyn_cast<Instruction>(S0->getPointerOperand()); - Instruction *A1 = dyn_cast<Instruction>(S1->getPointerOperand()); + auto *A0 = dyn_cast<Instruction>(S0->getPointerOperand()); + auto *A1 = dyn_cast<Instruction>(S1->getPointerOperand()); if (A0 && A1 && A0->isIdenticalTo(A1) && A0->hasOneUse() && (A0->getParent() == S0->getParent()) && A1->hasOneUse() && (A1->getParent() == S1->getParent()) && isa<GetElementPtrInst>(A0)) { @@ -476,7 +438,7 @@ bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0, S0->dropUnknownNonDebugMetadata(); // Create the new store to be inserted at the join point. - StoreInst *SNew = (StoreInst *)(S0->clone()); + StoreInst *SNew = cast<StoreInst>(S0->clone()); Instruction *ANew = A0->clone(); SNew->insertBefore(&*InsertPt); ANew->insertBefore(SNew); @@ -484,9 +446,8 @@ bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0, assert(S0->getParent() == A0->getParent()); assert(S1->getParent() == A1->getParent()); - PHINode *NewPN = getPHIOperand(BB, S0, S1); // New PHI operand? Use it. - if (NewPN) + if (PHINode *NewPN = getPHIOperand(BB, S0, S1)) SNew->setOperand(0, NewPN); removeInstruction(S0); removeInstruction(S1); @@ -532,11 +493,9 @@ bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) { Instruction *I = &*RBI; ++RBI; - // Sink move non-simple (atomic, volatile) stores - if (!isa<StoreInst>(I)) - continue; - StoreInst *S0 = (StoreInst *)I; - if (!S0->isSimple()) + // Don't sink non-simple (atomic, volatile) stores. + auto *S0 = dyn_cast<StoreInst>(I); + if (!S0 || !S0->isSimple()) continue; ++NStores; @@ -551,22 +510,18 @@ bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) { // is likely stale at this point. 
if (!Res) break; - else { - RBI = Pred0->rbegin(); - RBE = Pred0->rend(); - DEBUG(dbgs() << "Search again\n"; Instruction *I = &*RBI; I->dump()); - } + RBI = Pred0->rbegin(); + RBE = Pred0->rend(); + DEBUG(dbgs() << "Search again\n"; Instruction *I = &*RBI; I->dump()); } } return MergedStores; } -/// -/// \brief Run the transformation for each function -/// -bool MergedLoadStoreMotion::runOnFunction(Function &F) { - MD = getAnalysisIfAvailable<MemoryDependenceAnalysis>(); - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); +bool MergedLoadStoreMotion::run(Function &F, MemoryDependenceResults *MD, + AliasAnalysis &AA) { + this->MD = MD; + this->AA = &AA; bool Changed = false; DEBUG(dbgs() << "Instruction Merger\n"); @@ -585,3 +540,66 @@ bool MergedLoadStoreMotion::runOnFunction(Function &F) { } return Changed; } + +namespace { +class MergedLoadStoreMotionLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + MergedLoadStoreMotionLegacyPass() : FunctionPass(ID) { + initializeMergedLoadStoreMotionLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + /// + /// \brief Run the transformation for each function + /// + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + MergedLoadStoreMotion Impl; + auto *MDWP = getAnalysisIfAvailable<MemoryDependenceWrapperPass>(); + return Impl.run(F, MDWP ? &MDWP->getMemDep() : nullptr, + getAnalysis<AAResultsWrapperPass>().getAAResults()); + } + +private: + // This transformation requires dominator postdominator info + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<MemoryDependenceWrapperPass>(); + } +}; + +char MergedLoadStoreMotionLegacyPass::ID = 0; +} // anonymous namespace + +/// +/// \brief createMergedLoadStoreMotionPass - The public interface to this file. +/// +FunctionPass *llvm::createMergedLoadStoreMotionPass() { + return new MergedLoadStoreMotionLegacyPass(); +} + +INITIALIZE_PASS_BEGIN(MergedLoadStoreMotionLegacyPass, "mldst-motion", + "MergedLoadStoreMotion", false, false) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion", + "MergedLoadStoreMotion", false, false) + +PreservedAnalyses +MergedLoadStoreMotionPass::run(Function &F, AnalysisManager<Function> &AM) { + MergedLoadStoreMotion Impl; + auto *MD = AM.getCachedResult<MemoryDependenceAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + if (!Impl.run(F, MD, AA)) + return PreservedAnalyses::all(); + + // FIXME: This should also 'preserve the CFG'. + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + PA.preserve<MemoryDependenceAnalysis>(); + return PA; +} diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp index c8f885e7eec5..ed754fa71025 100644 --- a/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/lib/Transforms/Scalar/NaryReassociate.cpp @@ -208,7 +208,7 @@ FunctionPass *llvm::createNaryReassociatePass() { } bool NaryReassociate::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); @@ -299,49 +299,18 @@ Instruction *NaryReassociate::tryReassociate(Instruction *I) { } } -// FIXME: extract this method into TTI->getGEPCost. 
static bool isGEPFoldable(GetElementPtrInst *GEP, - const TargetTransformInfo *TTI, - const DataLayout *DL) { - GlobalVariable *BaseGV = nullptr; - int64_t BaseOffset = 0; - bool HasBaseReg = false; - int64_t Scale = 0; - - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getPointerOperand())) - BaseGV = GV; - else - HasBaseReg = true; - - gep_type_iterator GTI = gep_type_begin(GEP); - for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I, ++GTI) { - if (isa<SequentialType>(*GTI)) { - int64_t ElementSize = DL->getTypeAllocSize(GTI.getIndexedType()); - if (ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I)) { - BaseOffset += ConstIdx->getSExtValue() * ElementSize; - } else { - // Needs scale register. - if (Scale != 0) { - // No addressing mode takes two scale registers. - return false; - } - Scale = ElementSize; - } - } else { - StructType *STy = cast<StructType>(*GTI); - uint64_t Field = cast<ConstantInt>(*I)->getZExtValue(); - BaseOffset += DL->getStructLayout(STy)->getElementOffset(Field); - } - } - - unsigned AddrSpace = GEP->getPointerAddressSpace(); - return TTI->isLegalAddressingMode(GEP->getType()->getElementType(), BaseGV, - BaseOffset, HasBaseReg, Scale, AddrSpace); + const TargetTransformInfo *TTI) { + SmallVector<const Value*, 4> Indices; + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) + Indices.push_back(*I); + return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(), + Indices) == TargetTransformInfo::TCC_Free; } Instruction *NaryReassociate::tryReassociateGEP(GetElementPtrInst *GEP) { // Not worth reassociating GEP if it is foldable. - if (isGEPFoldable(GEP, TTI, DL)) + if (isGEPFoldable(GEP, TTI)) return nullptr; gep_type_iterator GTI = gep_type_begin(*GEP); @@ -434,7 +403,7 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex( // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType) uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType); - Type *ElementType = GEP->getType()->getElementType(); + Type *ElementType = GEP->getResultElementType(); uint64_t ElementSize = DL->getTypeAllocSize(ElementType); // Another less rare case: because I is not necessarily the last index of the // GEP, the size of the type at the I-th index (IndexedSize) is not diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 9f26f78892c6..c4b3e3464f40 100644 --- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -13,12 +13,10 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/PartiallyInlineLibCalls.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/Pass.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -26,85 +24,9 @@ using namespace llvm; #define DEBUG_TYPE "partially-inline-libcalls" -namespace { - class PartiallyInlineLibCalls : public FunctionPass { - public: - static char ID; - - PartiallyInlineLibCalls() : - FunctionPass(ID) { - initializePartiallyInlineLibCallsPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; - - private: - /// Optimize calls to sqrt. 
- bool optimizeSQRT(CallInst *Call, Function *CalledFunc, - BasicBlock &CurrBB, Function::iterator &BB); - }; - - char PartiallyInlineLibCalls::ID = 0; -} - -INITIALIZE_PASS(PartiallyInlineLibCalls, "partially-inline-libcalls", - "Partially inline calls to library functions", false, false) - -void PartiallyInlineLibCalls::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - FunctionPass::getAnalysisUsage(AU); -} - -bool PartiallyInlineLibCalls::runOnFunction(Function &F) { - bool Changed = false; - Function::iterator CurrBB; - TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - const TargetTransformInfo *TTI = - &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - for (Function::iterator BB = F.begin(), BE = F.end(); BB != BE;) { - CurrBB = BB++; - - for (BasicBlock::iterator II = CurrBB->begin(), IE = CurrBB->end(); - II != IE; ++II) { - CallInst *Call = dyn_cast<CallInst>(&*II); - Function *CalledFunc; - - if (!Call || !(CalledFunc = Call->getCalledFunction())) - continue; - - // Skip if function either has local linkage or is not a known library - // function. - LibFunc::Func LibFunc; - if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() || - !TLI->getLibFunc(CalledFunc->getName(), LibFunc)) - continue; - - switch (LibFunc) { - case LibFunc::sqrtf: - case LibFunc::sqrt: - if (TTI->haveFastSqrt(Call->getType()) && - optimizeSQRT(Call, CalledFunc, *CurrBB, BB)) - break; - continue; - default: - continue; - } - Changed = true; - break; - } - } - - return Changed; -} - -bool PartiallyInlineLibCalls::optimizeSQRT(CallInst *Call, - Function *CalledFunc, - BasicBlock &CurrBB, - Function::iterator &BB) { +static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, + BasicBlock &CurrBB, Function::iterator &BB) { // There is no need to change the IR, since backend will emit sqrt // instruction if the call has already been marked read-only. if (Call->onlyReadsMemory()) @@ -158,6 +80,97 @@ bool PartiallyInlineLibCalls::optimizeSQRT(CallInst *Call, return true; } +static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI) { + bool Changed = false; + + Function::iterator CurrBB; + for (Function::iterator BB = F.begin(), BE = F.end(); BB != BE;) { + CurrBB = BB++; + + for (BasicBlock::iterator II = CurrBB->begin(), IE = CurrBB->end(); + II != IE; ++II) { + CallInst *Call = dyn_cast<CallInst>(&*II); + Function *CalledFunc; + + if (!Call || !(CalledFunc = Call->getCalledFunction())) + continue; + + // Skip if function either has local linkage or is not a known library + // function. 
+ LibFunc::Func LibFunc; + if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() || + !TLI->getLibFunc(CalledFunc->getName(), LibFunc)) + continue; + + switch (LibFunc) { + case LibFunc::sqrtf: + case LibFunc::sqrt: + if (TTI->haveFastSqrt(Call->getType()) && + optimizeSQRT(Call, CalledFunc, *CurrBB, BB)) + break; + continue; + default: + continue; + } + + Changed = true; + break; + } + } + + return Changed; +} + +PreservedAnalyses +PartiallyInlineLibCallsPass::run(Function &F, AnalysisManager<Function> &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + if (!runPartiallyInlineLibCalls(F, &TLI, &TTI)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +namespace { +class PartiallyInlineLibCallsLegacyPass : public FunctionPass { +public: + static char ID; + + PartiallyInlineLibCallsLegacyPass() : FunctionPass(ID) { + initializePartiallyInlineLibCallsLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + FunctionPass::getAnalysisUsage(AU); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + const TargetTransformInfo *TTI = + &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + return runPartiallyInlineLibCalls(F, TLI, TTI); + } +}; +} + +char PartiallyInlineLibCallsLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(PartiallyInlineLibCallsLegacyPass, + "partially-inline-libcalls", + "Partially inline calls to library functions", false, + false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(PartiallyInlineLibCallsLegacyPass, + "partially-inline-libcalls", + "Partially inline calls to library functions", false, false) + FunctionPass *llvm::createPartiallyInlineLibCallsPass() { - return new PartiallyInlineLibCalls(); + return new PartiallyInlineLibCallsLegacyPass(); } diff --git a/lib/Transforms/Scalar/PlaceSafepoints.cpp b/lib/Transforms/Scalar/PlaceSafepoints.cpp index b56b35599120..e47b636348e3 100644 --- a/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -49,45 +49,32 @@ //===----------------------------------------------------------------------===// #include "llvm/Pass.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/ADT/SetOperations.h" + #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/IR/BasicBlock.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" +#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Statepoint.h" -#include "llvm/IR/Value.h" -#include "llvm/IR/Verifier.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/CommandLine.h" -#include 
"llvm/Support/raw_ostream.h" +#include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #define DEBUG_TYPE "safepoint-placement" + STATISTIC(NumEntrySafepoints, "Number of entry safepoints inserted"); -STATISTIC(NumCallSafepoints, "Number of call safepoints inserted"); STATISTIC(NumBackedgeSafepoints, "Number of backedge safepoints inserted"); -STATISTIC(CallInLoop, "Number of loops w/o safepoints due to calls in loop"); -STATISTIC(FiniteExecution, "Number of loops w/o safepoints finite execution"); +STATISTIC(CallInLoop, + "Number of loops without safepoints due to calls in loop"); +STATISTIC(FiniteExecution, + "Number of loops without safepoints finite execution"); using namespace llvm; @@ -108,9 +95,6 @@ static cl::opt<int> CountedLoopTripWidth("spp-counted-loop-trip-width", static cl::opt<bool> SplitBackedge("spp-split-backedge", cl::Hidden, cl::init(false)); -// Print tracing output -static cl::opt<bool> TraceLSP("spp-trace", cl::Hidden, cl::init(false)); - namespace { /// An analysis pass whose purpose is to identify each of the backedges in @@ -138,8 +122,8 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { bool runOnLoop(Loop *); void runOnLoopAndSubLoops(Loop *L) { // Visit all the subloops - for (auto I = L->begin(), E = L->end(); I != E; I++) - runOnLoopAndSubLoops(*I); + for (Loop *I : *L) + runOnLoopAndSubLoops(I); runOnLoop(L); } @@ -147,8 +131,8 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - for (auto I = LI->begin(), E = LI->end(); I != E; I++) { - runOnLoopAndSubLoops(*I); + for (Loop *I : *LI) { + runOnLoopAndSubLoops(I); } return false; } @@ -200,13 +184,9 @@ static bool needsStatepoint(const CallSite &CS) { if (call->isInlineAsm()) return false; } - if (isStatepoint(CS) || isGCRelocate(CS) || isGCResult(CS)) { - return false; - } - return true; -} -static Value *ReplaceWithStatepoint(const CallSite &CS); + return !(isStatepoint(CS) || isGCRelocate(CS) || isGCResult(CS)); +} /// Returns true if this loop is known to contain a call safepoint which /// must unconditionally execute on any iteration of the loop which returns @@ -278,43 +258,44 @@ static bool mustBeFiniteCountedLoop(Loop *L, ScalarEvolution *SE, return /* not finite */ false; } -static void scanOneBB(Instruction *start, Instruction *end, - std::vector<CallInst *> &calls, - std::set<BasicBlock *> &seen, - std::vector<BasicBlock *> &worklist) { - for (BasicBlock::iterator itr(start); - itr != start->getParent()->end() && itr != BasicBlock::iterator(end); - itr++) { - if (CallInst *CI = dyn_cast<CallInst>(&*itr)) { - calls.push_back(CI); - } +static void scanOneBB(Instruction *Start, Instruction *End, + std::vector<CallInst *> &Calls, + DenseSet<BasicBlock *> &Seen, + std::vector<BasicBlock *> &Worklist) { + for (BasicBlock::iterator BBI(Start), BBE0 = Start->getParent()->end(), + BBE1 = BasicBlock::iterator(End); + BBI != BBE0 && BBI != BBE1; BBI++) { + if (CallInst *CI = dyn_cast<CallInst>(&*BBI)) + Calls.push_back(CI); + // FIXME: This code does not handle invokes - assert(!dyn_cast<InvokeInst>(&*itr) && + assert(!isa<InvokeInst>(&*BBI) && "support for invokes in poll code needed"); + // Only add the successor blocks if we reach the terminator instruction // without 
encountering end first - if (itr->isTerminator()) { - BasicBlock *BB = itr->getParent(); + if (BBI->isTerminator()) { + BasicBlock *BB = BBI->getParent(); for (BasicBlock *Succ : successors(BB)) { - if (seen.count(Succ) == 0) { - worklist.push_back(Succ); - seen.insert(Succ); + if (Seen.insert(Succ).second) { + Worklist.push_back(Succ); } } } } } -static void scanInlinedCode(Instruction *start, Instruction *end, - std::vector<CallInst *> &calls, - std::set<BasicBlock *> &seen) { - calls.clear(); - std::vector<BasicBlock *> worklist; - seen.insert(start->getParent()); - scanOneBB(start, end, calls, seen, worklist); - while (!worklist.empty()) { - BasicBlock *BB = worklist.back(); - worklist.pop_back(); - scanOneBB(&*BB->begin(), end, calls, seen, worklist); + +static void scanInlinedCode(Instruction *Start, Instruction *End, + std::vector<CallInst *> &Calls, + DenseSet<BasicBlock *> &Seen) { + Calls.clear(); + std::vector<BasicBlock *> Worklist; + Seen.insert(Start->getParent()); + scanOneBB(Start, End, Calls, Seen, Worklist); + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.back(); + Worklist.pop_back(); + scanOneBB(&*BB->begin(), End, Calls, Seen, Worklist); } } @@ -324,29 +305,27 @@ bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) { // Note: In common usage, there will be only one edge due to LoopSimplify // having run sometime earlier in the pipeline, but this code must be correct // w.r.t. loops with multiple backedges. - BasicBlock *header = L->getHeader(); + BasicBlock *Header = L->getHeader(); SmallVector<BasicBlock*, 16> LoopLatches; L->getLoopLatches(LoopLatches); - for (BasicBlock *pred : LoopLatches) { - assert(L->contains(pred)); + for (BasicBlock *Pred : LoopLatches) { + assert(L->contains(Pred)); // Make a policy decision about whether this loop needs a safepoint or // not. Note that this is about unburdening the optimizer in loops, not // avoiding the runtime cost of the actual safepoint. if (!AllBackedges) { - if (mustBeFiniteCountedLoop(L, SE, pred)) { - if (TraceLSP) - errs() << "skipping safepoint placement in finite loop\n"; + if (mustBeFiniteCountedLoop(L, SE, Pred)) { + DEBUG(dbgs() << "skipping safepoint placement in finite loop\n"); FiniteExecution++; continue; } if (CallSafepointsEnabled && - containsUnconditionalCallSafepoint(L, header, pred, *DT)) { + containsUnconditionalCallSafepoint(L, Header, Pred, *DT)) { // Note: This is only semantically legal since we won't do any further // IPO or inlining before the actual call insertion.. If we hadn't, we // might latter loose this call safepoint. 
- if (TraceLSP) - errs() << "skipping safepoint placement due to unconditional call\n"; + DEBUG(dbgs() << "skipping safepoint placement due to unconditional call\n"); CallInLoop++; continue; } @@ -360,14 +339,11 @@ bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) { // Safepoint insertion would involve creating a new basic block (as the // target of the current backedge) which does the safepoint (of all live // variables) and branches to the true header - TerminatorInst *term = pred->getTerminator(); + TerminatorInst *Term = Pred->getTerminator(); - if (TraceLSP) { - errs() << "[LSP] terminator instruction: "; - term->dump(); - } + DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term); - PollLocations.push_back(term); + PollLocations.push_back(Term); } return false; @@ -411,27 +387,26 @@ static Instruction *findLocationForEntrySafepoint(Function &F, // hasNextInstruction and nextInstruction are used to iterate // through a "straight line" execution sequence. - auto hasNextInstruction = [](Instruction *I) { - if (!I->isTerminator()) { + auto HasNextInstruction = [](Instruction *I) { + if (!I->isTerminator()) return true; - } + BasicBlock *nextBB = I->getParent()->getUniqueSuccessor(); return nextBB && (nextBB->getUniquePredecessor() != nullptr); }; - auto nextInstruction = [&hasNextInstruction](Instruction *I) { - assert(hasNextInstruction(I) && + auto NextInstruction = [&](Instruction *I) { + assert(HasNextInstruction(I) && "first check if there is a next instruction!"); - if (I->isTerminator()) { + + if (I->isTerminator()) return &I->getParent()->getUniqueSuccessor()->front(); - } else { - return &*++I->getIterator(); - } + return &*++I->getIterator(); }; - Instruction *cursor = nullptr; - for (cursor = &F.getEntryBlock().front(); hasNextInstruction(cursor); - cursor = nextInstruction(cursor)) { + Instruction *Cursor = nullptr; + for (Cursor = &F.getEntryBlock().front(); HasNextInstruction(Cursor); + Cursor = NextInstruction(Cursor)) { // We need to ensure a safepoint poll occurs before any 'real' call. The // easiest way to ensure finite execution between safepoints in the face of @@ -440,51 +415,17 @@ static Instruction *findLocationForEntrySafepoint(Function &F, // which can grow the stack by an unbounded amount. This isn't required // for GC semantics per se, but is a common requirement for languages // which detect stack overflow via guard pages and then throw exceptions. - if (auto CS = CallSite(cursor)) { + if (auto CS = CallSite(Cursor)) { if (doesNotRequireEntrySafepointBefore(CS)) continue; break; } } - assert((hasNextInstruction(cursor) || cursor->isTerminator()) && + assert((HasNextInstruction(Cursor) || Cursor->isTerminator()) && "either we stopped because of a call, or because of terminator"); - return cursor; -} - -/// Identify the list of call sites which need to be have parseable state -static void findCallSafepoints(Function &F, - std::vector<CallSite> &Found /*rval*/) { - assert(Found.empty() && "must be empty!"); - for (Instruction &I : instructions(F)) { - Instruction *inst = &I; - if (isa<CallInst>(inst) || isa<InvokeInst>(inst)) { - CallSite CS(inst); - - // No safepoint needed or wanted - if (!needsStatepoint(CS)) { - continue; - } - - Found.push_back(CS); - } - } -} - -/// Implement a unique function which doesn't require we sort the input -/// vector. Doing so has the effect of changing the output of a couple of -/// tests in ways which make them less useful in testing fused safepoints. 
-template <typename T> static void unique_unsorted(std::vector<T> &vec) { - std::set<T> seen; - std::vector<T> tmp; - vec.reserve(vec.size()); - std::swap(tmp, vec); - for (auto V : tmp) { - if (seen.insert(V).second) { - vec.push_back(V); - } - } + return Cursor; } static const char *const GCSafepointPollName = "gc.safepoint_poll"; @@ -514,24 +455,6 @@ static bool enableEntrySafepoints(Function &F) { return !NoEntry; } static bool enableBackedgeSafepoints(Function &F) { return !NoBackedge; } static bool enableCallSafepoints(Function &F) { return !NoCall; } -// Normalize basic block to make it ready to be target of invoke statepoint. -// Ensure that 'BB' does not have phi nodes. It may require spliting it. -static BasicBlock *normalizeForInvokeSafepoint(BasicBlock *BB, - BasicBlock *InvokeParent) { - BasicBlock *ret = BB; - - if (!BB->getUniquePredecessor()) { - ret = SplitBlockPredecessors(BB, InvokeParent, ""); - } - - // Now that 'ret' has unique predecessor we can safely remove all phi nodes - // from it - FoldSingleEntryPHINodes(ret); - assert(!isa<PHINode>(ret->begin())); - - return ret; -} - bool PlaceSafepoints::runOnFunction(Function &F) { if (F.isDeclaration() || F.empty()) { // This is a declaration, nothing to do. Must exit early to avoid crash in @@ -549,13 +472,13 @@ bool PlaceSafepoints::runOnFunction(Function &F) { if (!shouldRewriteFunction(F)) return false; - bool modified = false; + bool Modified = false; // In various bits below, we rely on the fact that uses are reachable from // defs. When there are basic blocks unreachable from the entry, dominance // and reachablity queries return non-sensical results. Thus, we preprocess // the function to ensure these properties hold. - modified |= removeUnreachableBlocks(F); + Modified |= removeUnreachableBlocks(F); // STEP 1 - Insert the safepoint polling locations. We do not need to // actually insert parse points yet. That will be done for all polls and @@ -574,8 +497,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) { // with for the moment. legacy::FunctionPassManager FPM(F.getParent()); bool CanAssumeCallSafepoints = enableCallSafepoints(F); - PlaceBackedgeSafepointsImpl *PBS = - new PlaceBackedgeSafepointsImpl(CanAssumeCallSafepoints); + auto *PBS = new PlaceBackedgeSafepointsImpl(CanAssumeCallSafepoints); FPM.add(PBS); FPM.run(F); @@ -603,7 +525,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) { // The poll location must be the terminator of a loop latch block. for (TerminatorInst *Term : PollLocations) { // We are inserting a poll, the function is modified - modified = true; + Modified = true; if (SplitBackedge) { // Split the backedge of the loop and insert the poll within that new @@ -643,14 +565,13 @@ bool PlaceSafepoints::runOnFunction(Function &F) { } if (enableEntrySafepoints(F)) { - Instruction *Location = findLocationForEntrySafepoint(F, DT); - if (!Location) { - // policy choice not to insert? - } else { + if (Instruction *Location = findLocationForEntrySafepoint(F, DT)) { PollsNeeded.push_back(Location); - modified = true; + Modified = true; NumEntrySafepoints++; } + // TODO: else we should assert that there was, in fact, a policy choice to + // not insert a entry safepoint poll. 
} // Now that we've identified all the needed safepoint poll locations, insert @@ -661,71 +582,8 @@ bool PlaceSafepoints::runOnFunction(Function &F) { ParsePointNeeded.insert(ParsePointNeeded.end(), RuntimeCalls.begin(), RuntimeCalls.end()); } - PollsNeeded.clear(); // make sure we don't accidentally use - // The dominator tree has been invalidated by the inlining performed in the - // above loop. TODO: Teach the inliner how to update the dom tree? - DT.recalculate(F); - - if (enableCallSafepoints(F)) { - std::vector<CallSite> Calls; - findCallSafepoints(F, Calls); - NumCallSafepoints += Calls.size(); - ParsePointNeeded.insert(ParsePointNeeded.end(), Calls.begin(), Calls.end()); - } - - // Unique the vectors since we can end up with duplicates if we scan the call - // site for call safepoints after we add it for entry or backedge. The - // only reason we need tracking at all is that some functions might have - // polls but not call safepoints and thus we might miss marking the runtime - // calls for the polls. (This is useful in test cases!) - unique_unsorted(ParsePointNeeded); - - // Any parse point (no matter what source) will be handled here - - // We're about to start modifying the function - if (!ParsePointNeeded.empty()) - modified = true; - - // Now run through and insert the safepoints, but do _NOT_ update or remove - // any existing uses. We have references to live variables that need to - // survive to the last iteration of this loop. - std::vector<Value *> Results; - Results.reserve(ParsePointNeeded.size()); - for (size_t i = 0; i < ParsePointNeeded.size(); i++) { - CallSite &CS = ParsePointNeeded[i]; - - // For invoke statepoints we need to remove all phi nodes at the normal - // destination block. - // Reason for this is that we can place gc_result only after last phi node - // in basic block. We will get malformed code after RAUW for the - // gc_result if one of this phi nodes uses result from the invoke. - if (InvokeInst *Invoke = dyn_cast<InvokeInst>(CS.getInstruction())) { - normalizeForInvokeSafepoint(Invoke->getNormalDest(), - Invoke->getParent()); - } - - Value *GCResult = ReplaceWithStatepoint(CS); - Results.push_back(GCResult); - } - assert(Results.size() == ParsePointNeeded.size()); - - // Adjust all users of the old call sites to use the new ones instead - for (size_t i = 0; i < ParsePointNeeded.size(); i++) { - CallSite &CS = ParsePointNeeded[i]; - Value *GCResult = Results[i]; - if (GCResult) { - // Can not RAUW for the invoke gc result in case of phi nodes preset. - assert(CS.isCall() || !isa<PHINode>(cast<Instruction>(GCResult)->getParent()->begin())); - - // Replace all uses with the new call - CS.getInstruction()->replaceAllUsesWith(GCResult); - } - // Now that we've handled all uses, remove the original call itself - // Note: The insert point can't be the deleted instruction! 
- CS.getInstruction()->eraseFromParent(); - } - return modified; + return Modified; } char PlaceBackedgeSafepointsImpl::ID = 0; @@ -763,191 +621,60 @@ InsertSafepointPoll(Instruction *InsertBefore, auto *F = M->getFunction(GCSafepointPollName); assert(F && "gc.safepoint_poll function is missing"); - assert(F->getType()->getElementType() == + assert(F->getValueType() == FunctionType::get(Type::getVoidTy(M->getContext()), false) && "gc.safepoint_poll declared with wrong type"); assert(!F->empty() && "gc.safepoint_poll must be a non-empty function"); CallInst *PollCall = CallInst::Create(F, "", InsertBefore); // Record some information about the call site we're replacing - BasicBlock::iterator before(PollCall), after(PollCall); - bool isBegin(false); - if (before == OrigBB->begin()) { - isBegin = true; - } else { - before--; - } - after++; - assert(after != OrigBB->end() && "must have successor"); + BasicBlock::iterator Before(PollCall), After(PollCall); + bool IsBegin = false; + if (Before == OrigBB->begin()) + IsBegin = true; + else + Before--; - // do the actual inlining + After++; + assert(After != OrigBB->end() && "must have successor"); + + // Do the actual inlining InlineFunctionInfo IFI; bool InlineStatus = InlineFunction(PollCall, IFI); assert(InlineStatus && "inline must succeed"); (void)InlineStatus; // suppress warning in release-asserts - // Check post conditions + // Check post-conditions assert(IFI.StaticAllocas.empty() && "can't have allocs"); - std::vector<CallInst *> calls; // new calls - std::set<BasicBlock *> BBs; // new BBs + insertee + std::vector<CallInst *> Calls; // new calls + DenseSet<BasicBlock *> BBs; // new BBs + insertee + // Include only the newly inserted instructions, Note: begin may not be valid // if we inserted to the beginning of the basic block - BasicBlock::iterator start; - if (isBegin) { - start = OrigBB->begin(); - } else { - start = before; - start++; - } + BasicBlock::iterator Start = IsBegin ? OrigBB->begin() : std::next(Before); // If your poll function includes an unreachable at the end, that's not // valid. Bugpoint likes to create this, so check for it. - assert(isPotentiallyReachable(&*start, &*after, nullptr, nullptr) && + assert(isPotentiallyReachable(&*Start, &*After) && "malformed poll function"); - scanInlinedCode(&*(start), &*(after), calls, BBs); - assert(!calls.empty() && "slow path not found for safepoint poll"); + scanInlinedCode(&*Start, &*After, Calls, BBs); + assert(!Calls.empty() && "slow path not found for safepoint poll"); // Record the fact we need a parsable state at the runtime call contained in // the poll function. This is required so that the runtime knows how to // parse the last frame when we actually take the safepoint (i.e. execute // the slow path) assert(ParsePointsNeeded.empty()); - for (size_t i = 0; i < calls.size(); i++) { - + for (auto *CI : Calls) { // No safepoint needed or wanted - if (!needsStatepoint(calls[i])) { + if (!needsStatepoint(CI)) continue; - } // These are likely runtime calls. Should we assert that via calling // convention or something? - ParsePointsNeeded.push_back(CallSite(calls[i])); - } - assert(ParsePointsNeeded.size() <= calls.size()); -} - -/// Replaces the given call site (Call or Invoke) with a gc.statepoint -/// intrinsic with an empty deoptimization arguments list. This does -/// NOT do explicit relocation for GC support. 
-static Value *ReplaceWithStatepoint(const CallSite &CS /* to replace */) { - assert(CS.getInstruction()->getModule() && "must be set"); - - // TODO: technically, a pass is not allowed to get functions from within a - // function pass since it might trigger a new function addition. Refactor - // this logic out to the initialization of the pass. Doesn't appear to - // matter in practice. - - // Then go ahead and use the builder do actually do the inserts. We insert - // immediately before the previous instruction under the assumption that all - // arguments will be available here. We can't insert afterwards since we may - // be replacing a terminator. - IRBuilder<> Builder(CS.getInstruction()); - - // Note: The gc args are not filled in at this time, that's handled by - // RewriteStatepointsForGC (which is currently under review). - - // Create the statepoint given all the arguments - Instruction *Token = nullptr; - - uint64_t ID; - uint32_t NumPatchBytes; - - AttributeSet OriginalAttrs = CS.getAttributes(); - Attribute AttrID = - OriginalAttrs.getAttribute(AttributeSet::FunctionIndex, "statepoint-id"); - Attribute AttrNumPatchBytes = OriginalAttrs.getAttribute( - AttributeSet::FunctionIndex, "statepoint-num-patch-bytes"); - - AttrBuilder AttrsToRemove; - bool HasID = AttrID.isStringAttribute() && - !AttrID.getValueAsString().getAsInteger(10, ID); - - if (HasID) - AttrsToRemove.addAttribute("statepoint-id"); - else - ID = 0xABCDEF00; - - bool HasNumPatchBytes = - AttrNumPatchBytes.isStringAttribute() && - !AttrNumPatchBytes.getValueAsString().getAsInteger(10, NumPatchBytes); - - if (HasNumPatchBytes) - AttrsToRemove.addAttribute("statepoint-num-patch-bytes"); - else - NumPatchBytes = 0; - - OriginalAttrs = OriginalAttrs.removeAttributes( - CS.getInstruction()->getContext(), AttributeSet::FunctionIndex, - AttrsToRemove); - - if (CS.isCall()) { - CallInst *ToReplace = cast<CallInst>(CS.getInstruction()); - CallInst *Call = Builder.CreateGCStatepointCall( - ID, NumPatchBytes, CS.getCalledValue(), - makeArrayRef(CS.arg_begin(), CS.arg_end()), None, None, - "safepoint_token"); - Call->setTailCall(ToReplace->isTailCall()); - Call->setCallingConv(ToReplace->getCallingConv()); - - // In case if we can handle this set of attributes - set up function - // attributes directly on statepoint and return attributes later for - // gc_result intrinsic. - Call->setAttributes(OriginalAttrs.getFnAttributes()); - - Token = Call; - - // Put the following gc_result and gc_relocate calls immediately after - // the old call (which we're about to delete). - assert(ToReplace->getNextNode() && "not a terminator, must have next"); - Builder.SetInsertPoint(ToReplace->getNextNode()); - Builder.SetCurrentDebugLocation(ToReplace->getNextNode()->getDebugLoc()); - } else if (CS.isInvoke()) { - InvokeInst *ToReplace = cast<InvokeInst>(CS.getInstruction()); - - // Insert the new invoke into the old block. We'll remove the old one in a - // moment at which point this will become the new terminator for the - // original block. - Builder.SetInsertPoint(ToReplace->getParent()); - InvokeInst *Invoke = Builder.CreateGCStatepointInvoke( - ID, NumPatchBytes, CS.getCalledValue(), ToReplace->getNormalDest(), - ToReplace->getUnwindDest(), makeArrayRef(CS.arg_begin(), CS.arg_end()), - None, None, "safepoint_token"); - - Invoke->setCallingConv(ToReplace->getCallingConv()); - - // In case if we can handle this set of attributes - set up function - // attributes directly on statepoint and return attributes later for - // gc_result intrinsic. 
- Invoke->setAttributes(OriginalAttrs.getFnAttributes()); - - Token = Invoke; - - // We'll insert the gc.result into the normal block - BasicBlock *NormalDest = ToReplace->getNormalDest(); - // Can not insert gc.result in case of phi nodes preset. - // Should have removed this cases prior to running this function - assert(!isa<PHINode>(NormalDest->begin())); - Instruction *IP = &*(NormalDest->getFirstInsertionPt()); - Builder.SetInsertPoint(IP); - } else { - llvm_unreachable("unexpect type of CallSite"); - } - assert(Token); - - // Handle the return value of the original call - update all uses to use a - // gc_result hanging off the statepoint node we just inserted - - // Only add the gc_result iff there is actually a used result - if (!CS.getType()->isVoidTy() && !CS.getInstruction()->use_empty()) { - std::string TakenName = - CS.getInstruction()->hasName() ? CS.getInstruction()->getName() : ""; - CallInst *GCResult = Builder.CreateGCResult(Token, CS.getType(), TakenName); - GCResult->setAttributes(OriginalAttrs.getRetAttributes()); - return GCResult; - } else { - // No return value for the call. - return nullptr; + ParsePointsNeeded.push_back(CallSite(CI)); } + assert(ParsePointsNeeded.size() <= Calls.size()); } diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index bcadd4e2bee6..b930a8fb7e99 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -20,7 +20,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/Reassociate.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" @@ -39,9 +39,11 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> using namespace llvm; +using namespace reassociate; #define DEBUG_TYPE "reassociate" @@ -49,17 +51,6 @@ STATISTIC(NumChanged, "Number of insts reassociated"); STATISTIC(NumAnnihil, "Number of expr tree annihilated"); STATISTIC(NumFactor , "Number of multiplies factored"); -namespace { - struct ValueEntry { - unsigned Rank; - Value *Op; - ValueEntry(unsigned R, Value *O) : Rank(R), Op(O) {} - }; - inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) { - return LHS.Rank > RHS.Rank; // Sort so that highest rank goes to start. - } -} - #ifndef NDEBUG /// Print out the expression identified in the Ops list. /// @@ -75,120 +66,35 @@ static void PrintOps(Instruction *I, const SmallVectorImpl<ValueEntry> &Ops) { } #endif -namespace { - /// \brief Utility class representing a base and exponent pair which form one - /// factor of some product. - struct Factor { - Value *Base; - unsigned Power; - - Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {} - - /// \brief Sort factors in descending order by their power. - struct PowerDescendingSorter { - bool operator()(const Factor &LHS, const Factor &RHS) { - return LHS.Power > RHS.Power; - } - }; - - /// \brief Compare factors for equal powers. - struct PowerEqual { - bool operator()(const Factor &LHS, const Factor &RHS) { - return LHS.Power == RHS.Power; - } - }; - }; - - /// Utility class representing a non-constant Xor-operand. 
We classify - /// non-constant Xor-Operands into two categories: - /// C1) The operand is in the form "X & C", where C is a constant and C != ~0 - /// C2) - /// C2.1) The operand is in the form of "X | C", where C is a non-zero - /// constant. - /// C2.2) Any operand E which doesn't fall into C1 and C2.1, we view this - /// operand as "E | 0" - class XorOpnd { - public: - XorOpnd(Value *V); - - bool isInvalid() const { return SymbolicPart == nullptr; } - bool isOrExpr() const { return isOr; } - Value *getValue() const { return OrigVal; } - Value *getSymbolicPart() const { return SymbolicPart; } - unsigned getSymbolicRank() const { return SymbolicRank; } - const APInt &getConstPart() const { return ConstPart; } - - void Invalidate() { SymbolicPart = OrigVal = nullptr; } - void setSymbolicRank(unsigned R) { SymbolicRank = R; } - - // Sort the XorOpnd-Pointer in ascending order of symbolic-value-rank. - // The purpose is twofold: - // 1) Cluster together the operands sharing the same symbolic-value. - // 2) Operand having smaller symbolic-value-rank is permuted earlier, which - // could potentially shorten crital path, and expose more loop-invariants. - // Note that values' rank are basically defined in RPO order (FIXME). - // So, if Rank(X) < Rank(Y) < Rank(Z), it means X is defined earlier - // than Y which is defined earlier than Z. Permute "x | 1", "Y & 2", - // "z" in the order of X-Y-Z is better than any other orders. - struct PtrSortFunctor { - bool operator()(XorOpnd * const &LHS, XorOpnd * const &RHS) { - return LHS->getSymbolicRank() < RHS->getSymbolicRank(); - } - }; - private: - Value *OrigVal; - Value *SymbolicPart; - APInt ConstPart; - unsigned SymbolicRank; - bool isOr; - }; -} - -namespace { - class Reassociate : public FunctionPass { - DenseMap<BasicBlock*, unsigned> RankMap; - DenseMap<AssertingVH<Value>, unsigned> ValueRankMap; - SetVector<AssertingVH<Instruction> > RedoInsts; - bool MadeChange; - public: - static char ID; // Pass identification, replacement for typeid - Reassociate() : FunctionPass(ID) { - initializeReassociatePass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } - private: - void BuildRankMap(Function &F); - unsigned getRank(Value *V); - void canonicalizeOperands(Instruction *I); - void ReassociateExpression(BinaryOperator *I); - void RewriteExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops); - Value *OptimizeExpression(BinaryOperator *I, - SmallVectorImpl<ValueEntry> &Ops); - Value *OptimizeAdd(Instruction *I, SmallVectorImpl<ValueEntry> &Ops); - Value *OptimizeXor(Instruction *I, SmallVectorImpl<ValueEntry> &Ops); - bool CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt &ConstOpnd, - Value *&Res); - bool CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, XorOpnd *Opnd2, - APInt &ConstOpnd, Value *&Res); - bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, - SmallVectorImpl<Factor> &Factors); - Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder, - SmallVectorImpl<Factor> &Factors); - Value *OptimizeMul(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops); - Value *RemoveFactorFromExpression(Value *V, Value *Factor); - void EraseInst(Instruction *I); - void RecursivelyEraseDeadInsts(Instruction *I, - SetVector<AssertingVH<Instruction>> &Insts); - void OptimizeInst(Instruction *I); - Instruction *canonicalizeNegConstExpr(Instruction *I); - }; -} +/// Utility class 
representing a non-constant Xor-operand. We classify +/// non-constant Xor-Operands into two categories: +/// C1) The operand is in the form "X & C", where C is a constant and C != ~0 +/// C2) +/// C2.1) The operand is in the form of "X | C", where C is a non-zero +/// constant. +/// C2.2) Any operand E which doesn't fall into C1 and C2.1, we view this +/// operand as "E | 0" +class llvm::reassociate::XorOpnd { +public: + XorOpnd(Value *V); + + bool isInvalid() const { return SymbolicPart == nullptr; } + bool isOrExpr() const { return isOr; } + Value *getValue() const { return OrigVal; } + Value *getSymbolicPart() const { return SymbolicPart; } + unsigned getSymbolicRank() const { return SymbolicRank; } + const APInt &getConstPart() const { return ConstPart; } + + void Invalidate() { SymbolicPart = OrigVal = nullptr; } + void setSymbolicRank(unsigned R) { SymbolicRank = R; } + +private: + Value *OrigVal; + Value *SymbolicPart; + APInt ConstPart; + unsigned SymbolicRank; + bool isOr; +}; XorOpnd::XorOpnd(Value *V) { assert(!isa<ConstantInt>(V) && "No ConstantInt"); @@ -217,13 +123,6 @@ XorOpnd::XorOpnd(Value *V) { isOr = true; } -char Reassociate::ID = 0; -INITIALIZE_PASS(Reassociate, "reassociate", - "Reassociate expressions", false, false) - -// Public interface to the Reassociate pass -FunctionPass *llvm::createReassociatePass() { return new Reassociate(); } - /// Return true if V is an instruction of the specified opcode and if it /// only has one use. static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) { @@ -246,7 +145,8 @@ static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode1, return nullptr; } -void Reassociate::BuildRankMap(Function &F) { +void ReassociatePass::BuildRankMap( + Function &F, ReversePostOrderTraversal<Function *> &RPOT) { unsigned i = 2; // Assign distinct ranks to function arguments. @@ -255,22 +155,19 @@ void Reassociate::BuildRankMap(Function &F) { DEBUG(dbgs() << "Calculated Rank[" << I->getName() << "] = " << i << "\n"); } - ReversePostOrderTraversal<Function*> RPOT(&F); - for (ReversePostOrderTraversal<Function*>::rpo_iterator I = RPOT.begin(), - E = RPOT.end(); I != E; ++I) { - BasicBlock *BB = *I; + for (BasicBlock *BB : RPOT) { unsigned BBRank = RankMap[BB] = ++i << 16; // Walk the basic block, adding precomputed ranks for any instructions that // we cannot move. This ensures that the ranks for these instructions are // all different in the block. - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) - if (mayBeMemoryDependent(*I)) - ValueRankMap[&*I] = ++BBRank; + for (Instruction &I : *BB) + if (mayBeMemoryDependent(I)) + ValueRankMap[&I] = ++BBRank; } } -unsigned Reassociate::getRank(Value *V) { +unsigned ReassociatePass::getRank(Value *V) { Instruction *I = dyn_cast<Instruction>(V); if (!I) { if (isa<Argument>(V)) return ValueRankMap[V]; // Function argument. @@ -301,7 +198,7 @@ unsigned Reassociate::getRank(Value *V) { } // Canonicalize constants to RHS. Otherwise, sort the operands by rank. -void Reassociate::canonicalizeOperands(Instruction *I) { +void ReassociatePass::canonicalizeOperands(Instruction *I) { assert(isa<BinaryOperator>(I) && "Expected binary operator."); assert(I->isCommutative() && "Expected commutative operator."); @@ -711,8 +608,8 @@ static bool LinearizeExprTree(BinaryOperator *I, /// Now that the operands for this expression tree are /// linearized and optimized, emit them in-order. 
-void Reassociate::RewriteExprTree(BinaryOperator *I, - SmallVectorImpl<ValueEntry> &Ops) { +void ReassociatePass::RewriteExprTree(BinaryOperator *I, + SmallVectorImpl<ValueEntry> &Ops) { assert(Ops.size() > 1 && "Single values should be used directly!"); // Since our optimizations should never increase the number of operations, the @@ -1095,7 +992,7 @@ static Value *EmitAddTreeOfValues(Instruction *I, /// If V is an expression tree that is a multiplication sequence, /// and if this sequence contains a multiply by Factor, /// remove Factor from the tree and return the new tree. -Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { +Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul); if (!BO) return nullptr; @@ -1129,7 +1026,7 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { } } else if (ConstantFP *FC1 = dyn_cast<ConstantFP>(Factor)) { if (ConstantFP *FC2 = dyn_cast<ConstantFP>(Factors[i].Op)) { - APFloat F1(FC1->getValueAPF()); + const APFloat &F1 = FC1->getValueAPF(); APFloat F2(FC2->getValueAPF()); F2.changeSign(); if (F1.compare(F2) == APFloat::cmpEqual) { @@ -1258,9 +1155,9 @@ static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd, // If it was successful, true is returned, and the "R" and "C" is returned // via "Res" and "ConstOpnd", respectively; otherwise, false is returned, // and both "Res" and "ConstOpnd" remain unchanged. -// -bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, - APInt &ConstOpnd, Value *&Res) { +// +bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, + APInt &ConstOpnd, Value *&Res) { // Xor-Rule 1: (x | c1) ^ c2 = (x | c1) ^ (c1 ^ c1) ^ c2 // = ((x | c1) ^ c1) ^ (c1 ^ c2) // = (x & ~c1) ^ (c1 ^ c2) @@ -1294,8 +1191,9 @@ bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, // via "Res" and "ConstOpnd", respectively (If the entire expression is // evaluated to a constant, the Res is set to NULL); otherwise, false is // returned, and both "Res" and "ConstOpnd" remain unchanged. -bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, XorOpnd *Opnd2, - APInt &ConstOpnd, Value *&Res) { +bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, + XorOpnd *Opnd2, APInt &ConstOpnd, + Value *&Res) { Value *X = Opnd1->getSymbolicPart(); if (X != Opnd2->getSymbolicPart()) return false; @@ -1369,8 +1267,8 @@ bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, XorOpnd *Opnd2, /// Optimize a series of operands to an 'xor' instruction. If it can be reduced /// to a single Value, it is returned, otherwise the Ops list is mutated as /// necessary. -Value *Reassociate::OptimizeXor(Instruction *I, - SmallVectorImpl<ValueEntry> &Ops) { +Value *ReassociatePass::OptimizeXor(Instruction *I, + SmallVectorImpl<ValueEntry> &Ops) { if (Value *V = OptimizeAndOrXor(Instruction::Xor, Ops)) return V; @@ -1405,7 +1303,19 @@ Value *Reassociate::OptimizeXor(Instruction *I, // the same symbolic value cluster together. For instance, the input operand // sequence ("x | 123", "y & 456", "x & 789") will be sorted into: // ("x | 123", "x & 789", "y & 456"). - std::stable_sort(OpndPtrs.begin(), OpndPtrs.end(), XorOpnd::PtrSortFunctor()); + // + // The purpose is twofold: + // 1) Cluster together the operands sharing the same symbolic-value. 
+ // 2) Operand having smaller symbolic-value-rank is permuted earlier, which + // could potentially shorten crital path, and expose more loop-invariants. + // Note that values' rank are basically defined in RPO order (FIXME). + // So, if Rank(X) < Rank(Y) < Rank(Z), it means X is defined earlier + // than Y which is defined earlier than Z. Permute "x | 1", "Y & 2", + // "z" in the order of X-Y-Z is better than any other orders. + std::stable_sort(OpndPtrs.begin(), OpndPtrs.end(), + [](XorOpnd *LHS, XorOpnd *RHS) { + return LHS->getSymbolicRank() < RHS->getSymbolicRank(); + }); // Step 3: Combine adjacent operands XorOpnd *PrevOpnd = nullptr; @@ -1478,8 +1388,8 @@ Value *Reassociate::OptimizeXor(Instruction *I, /// Optimize a series of operands to an 'add' instruction. This /// optimizes based on identities. If it can be reduced to a single Value, it /// is returned, otherwise the Ops list is mutated as necessary. -Value *Reassociate::OptimizeAdd(Instruction *I, - SmallVectorImpl<ValueEntry> &Ops) { +Value *ReassociatePass::OptimizeAdd(Instruction *I, + SmallVectorImpl<ValueEntry> &Ops) { // Scan the operand lists looking for X and -X pairs. If we find any, we // can simplify expressions like X+-X == 0 and X+~X ==-1. While we're at it, // scan for any @@ -1716,8 +1626,8 @@ Value *Reassociate::OptimizeAdd(Instruction *I, /// ((((x*y)*x)*y)*x) -> [(x, 3), (y, 2)] /// /// \returns Whether any factors have a power greater than one. -bool Reassociate::collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, - SmallVectorImpl<Factor> &Factors) { +bool ReassociatePass::collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, + SmallVectorImpl<Factor> &Factors) { // FIXME: Have Ops be (ValueEntry, Multiplicity) pairs, simplifying this. // Compute the sum of powers of simplifiable factors. unsigned FactorPowerSum = 0; @@ -1763,7 +1673,10 @@ bool Reassociate::collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, // below our mininum of '4'. assert(FactorPowerSum >= 4); - std::stable_sort(Factors.begin(), Factors.end(), Factor::PowerDescendingSorter()); + std::stable_sort(Factors.begin(), Factors.end(), + [](const Factor &LHS, const Factor &RHS) { + return LHS.Power > RHS.Power; + }); return true; } @@ -1790,8 +1703,9 @@ static Value *buildMultiplyTree(IRBuilder<> &Builder, /// equal and the powers are sorted in decreasing order, compute the minimal /// DAG of multiplies to compute the final product, and return that product /// value. -Value *Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder, - SmallVectorImpl<Factor> &Factors) { +Value * +ReassociatePass::buildMinimalMultiplyDAG(IRBuilder<> &Builder, + SmallVectorImpl<Factor> &Factors) { assert(Factors[0].Power); SmallVector<Value *, 4> OuterProduct; for (unsigned LastIdx = 0, Idx = 1, Size = Factors.size(); @@ -1822,7 +1736,9 @@ Value *Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder, // Unique factors with equal powers -- we've folded them into the first one's // base. 
Factors.erase(std::unique(Factors.begin(), Factors.end(), - Factor::PowerEqual()), + [](const Factor &LHS, const Factor &RHS) { + return LHS.Power == RHS.Power; + }), Factors.end()); // Iteratively collect the base of each factor with an add power into the @@ -1845,8 +1761,8 @@ Value *Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder, return V; } -Value *Reassociate::OptimizeMul(BinaryOperator *I, - SmallVectorImpl<ValueEntry> &Ops) { +Value *ReassociatePass::OptimizeMul(BinaryOperator *I, + SmallVectorImpl<ValueEntry> &Ops) { // We can only optimize the multiplies when there is a chain of more than // three, such that a balanced tree might require fewer total multiplies. if (Ops.size() < 4) @@ -1869,8 +1785,8 @@ Value *Reassociate::OptimizeMul(BinaryOperator *I, return nullptr; } -Value *Reassociate::OptimizeExpression(BinaryOperator *I, - SmallVectorImpl<ValueEntry> &Ops) { +Value *ReassociatePass::OptimizeExpression(BinaryOperator *I, + SmallVectorImpl<ValueEntry> &Ops) { // Now that we have the linearized expression tree, try to optimize it. // Start by folding any constants that we found. Constant *Cst = nullptr; @@ -1930,7 +1846,7 @@ Value *Reassociate::OptimizeExpression(BinaryOperator *I, // Remove dead instructions and if any operands are trivially dead add them to // Insts so they will be removed as well. -void Reassociate::RecursivelyEraseDeadInsts( +void ReassociatePass::RecursivelyEraseDeadInsts( Instruction *I, SetVector<AssertingVH<Instruction>> &Insts) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); SmallVector<Value *, 4> Ops(I->op_begin(), I->op_end()); @@ -1945,7 +1861,7 @@ void Reassociate::RecursivelyEraseDeadInsts( } /// Zap the given instruction, adding interesting operands to the work list. -void Reassociate::EraseInst(Instruction *I) { +void ReassociatePass::EraseInst(Instruction *I) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); // Erase the dead instruction. @@ -1969,7 +1885,7 @@ void Reassociate::EraseInst(Instruction *I) { // Canonicalize expressions of the following form: // x + (-Constant * y) -> x - (Constant * y) // x - (-Constant * y) -> x + (Constant * y) -Instruction *Reassociate::canonicalizeNegConstExpr(Instruction *I) { +Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { if (!I->hasOneUse() || I->getType()->isVectorTy()) return nullptr; @@ -2046,7 +1962,7 @@ Instruction *Reassociate::canonicalizeNegConstExpr(Instruction *I) { /// Inspect and optimize the given instruction. Note that erasing /// instructions is not allowed. -void Reassociate::OptimizeInst(Instruction *I) { +void ReassociatePass::OptimizeInst(Instruction *I) { // Only consider operations that we understand. if (!isa<BinaryOperator>(I)) return; @@ -2173,7 +2089,7 @@ void Reassociate::OptimizeInst(Instruction *I) { ReassociateExpression(BO); } -void Reassociate::ReassociateExpression(BinaryOperator *I) { +void ReassociatePass::ReassociateExpression(BinaryOperator *I) { // First, walk the expression tree, linearizing the tree, collecting the // operand information. 
SmallVector<RepeatedValue, 8> Tree; @@ -2255,46 +2171,53 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) { RewriteExprTree(I, Ops); } -bool Reassociate::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) - return false; - - // Calculate the rank map for F - BuildRankMap(F); +PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { + // Reassociate needs for each instruction to have its operands already + // processed, so we first perform a RPOT of the basic blocks so that + // when we process a basic block, all its dominators have been processed + // before. + ReversePostOrderTraversal<Function *> RPOT(&F); + BuildRankMap(F, RPOT); MadeChange = false; - for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { + for (BasicBlock *BI : RPOT) { + // Use a worklist to keep track of which instructions have been processed + // (and which insts won't be optimized again) so when redoing insts, + // optimize insts rightaway which won't be processed later. + SmallSet<Instruction *, 8> Worklist; + + // Insert all instructions in the BB + for (Instruction &I : *BI) + Worklist.insert(&I); + // Optimize every instruction in the basic block. - for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE; ) + for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE;) { + // This instruction has been processed. + Worklist.erase(&*II); if (isInstructionTriviallyDead(&*II)) { EraseInst(&*II++); } else { OptimizeInst(&*II); - assert(II->getParent() == BI && "Moved to a different block!"); + assert(II->getParent() == &*BI && "Moved to a different block!"); ++II; } - // Make a copy of all the instructions to be redone so we can remove dead - // instructions. - SetVector<AssertingVH<Instruction>> ToRedo(RedoInsts); - // Iterate over all instructions to be reevaluated and remove trivially dead - // instructions. If any operand of the trivially dead instruction becomes - // dead mark it for deletion as well. Continue this process until all - // trivially dead instructions have been removed. - while (!ToRedo.empty()) { - Instruction *I = ToRedo.pop_back_val(); - if (isInstructionTriviallyDead(I)) - RecursivelyEraseDeadInsts(I, ToRedo); - } - - // Now that we have removed dead instructions, we can reoptimize the - // remaining instructions. - while (!RedoInsts.empty()) { - Instruction *I = RedoInsts.pop_back_val(); - if (isInstructionTriviallyDead(I)) - EraseInst(I); - else - OptimizeInst(I); + // If the above optimizations produced new instructions to optimize or + // made modifications which need to be redone, do them now if they won't + // be handled later. + while (!RedoInsts.empty()) { + Instruction *I = RedoInsts.pop_back_val(); + // Process instructions that won't be processed later, either + // inside the block itself or in another basic block (based on rank), + // since these will be processed later. + if ((I->getParent() != BI || !Worklist.count(I)) && + RankMap[I->getParent()] <= RankMap[BI]) { + if (isInstructionTriviallyDead(I)) + EraseInst(I); + else + OptimizeInst(I); + } + } } } @@ -2302,5 +2225,46 @@ bool Reassociate::runOnFunction(Function &F) { RankMap.clear(); ValueRankMap.clear(); - return MadeChange; + if (MadeChange) { + // FIXME: This should also 'preserve the CFG'. 
+ auto PA = PreservedAnalyses(); + PA.preserve<GlobalsAA>(); + return PA; + } + + return PreservedAnalyses::all(); +} + +namespace { + class ReassociateLegacyPass : public FunctionPass { + ReassociatePass Impl; + public: + static char ID; // Pass identification, replacement for typeid + ReassociateLegacyPass() : FunctionPass(ID) { + initializeReassociateLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + FunctionAnalysisManager DummyFAM; + auto PA = Impl.run(F, DummyFAM); + return !PA.areAllPreserved(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + }; +} + +char ReassociateLegacyPass::ID = 0; +INITIALIZE_PASS(ReassociateLegacyPass, "reassociate", + "Reassociate expressions", false, false) + +// Public interface to the Reassociate pass +FunctionPass *llvm::createReassociatePass() { + return new ReassociateLegacyPass(); } diff --git a/lib/Transforms/Scalar/Reg2Mem.cpp b/lib/Transforms/Scalar/Reg2Mem.cpp index 915f89780c08..615029dd161b 100644 --- a/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/lib/Transforms/Scalar/Reg2Mem.cpp @@ -68,7 +68,7 @@ INITIALIZE_PASS_END(RegToMem, "reg2mem", "Demote all values to stack slots", false, false) bool RegToMem::runOnFunction(Function &F) { - if (F.isDeclaration()) + if (F.isDeclaration() || skipFunction(F)) return false; // Insert all new allocas into entry block. @@ -89,10 +89,9 @@ bool RegToMem::runOnFunction(Function &F) { // Find the escaped instructions. But don't create stack slots for // allocas in entry block. std::list<Instruction*> WorkList; - for (Function::iterator ibb = F.begin(), ibe = F.end(); - ibb != ibe; ++ibb) - for (BasicBlock::iterator iib = ibb->begin(), iie = ibb->end(); - iib != iie; ++iib) { + for (BasicBlock &ibb : F) + for (BasicBlock::iterator iib = ibb.begin(), iie = ibb.end(); iib != iie; + ++iib) { if (!(isa<AllocaInst>(iib) && iib->getParent() == BBEntry) && valueEscapes(&*iib)) { WorkList.push_front(&*iib); @@ -101,25 +100,22 @@ bool RegToMem::runOnFunction(Function &F) { // Demote escaped instructions NumRegsDemoted += WorkList.size(); - for (std::list<Instruction*>::iterator ilb = WorkList.begin(), - ile = WorkList.end(); ilb != ile; ++ilb) - DemoteRegToStack(**ilb, false, AllocaInsertionPoint); + for (Instruction *ilb : WorkList) + DemoteRegToStack(*ilb, false, AllocaInsertionPoint); WorkList.clear(); // Find all phi's - for (Function::iterator ibb = F.begin(), ibe = F.end(); - ibb != ibe; ++ibb) - for (BasicBlock::iterator iib = ibb->begin(), iie = ibb->end(); - iib != iie; ++iib) + for (BasicBlock &ibb : F) + for (BasicBlock::iterator iib = ibb.begin(), iie = ibb.end(); iib != iie; + ++iib) if (isa<PHINode>(iib)) WorkList.push_front(&*iib); // Demote phi nodes NumPhisDemoted += WorkList.size(); - for (std::list<Instruction*>::iterator ilb = WorkList.begin(), - ile = WorkList.end(); ilb != ile; ++ilb) - DemotePHIToStack(cast<PHINode>(*ilb), AllocaInsertionPoint); + for (Instruction *ilb : WorkList) + DemotePHIToStack(cast<PHINode>(ilb), AllocaInsertionPoint); return true; } diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index d77d5745e60c..bab39a32677f 100644 --- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -14,7 +14,6 @@ #include "llvm/Pass.h" #include "llvm/Analysis/CFG.h" -#include 
"llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/Statistic.h" @@ -63,7 +62,7 @@ static cl::opt<unsigned> RematerializationThreshold("spp-rematerialization-threshold", cl::Hidden, cl::init(6)); -#ifdef XDEBUG +#ifdef EXPENSIVE_CHECKS static bool ClobberNonLive = true; #else static bool ClobberNonLive = false; @@ -72,19 +71,10 @@ static cl::opt<bool, true> ClobberNonLiveOverride("rs4gc-clobber-non-live", cl::location(ClobberNonLive), cl::Hidden); -static cl::opt<bool> UseDeoptBundles("rs4gc-use-deopt-bundles", cl::Hidden, - cl::init(false)); static cl::opt<bool> AllowStatepointWithNoDeoptInfo("rs4gc-allow-statepoint-with-no-deopt-info", cl::Hidden, cl::init(true)); -/// Should we split vectors of pointers into their individual elements? This -/// is known to be buggy, but the alternate implementation isn't yet ready. -/// This is purely to provide a debugging and dianostic hook until the vector -/// split is replaced with vector relocations. -static cl::opt<bool> UseVectorSplit("rs4gc-split-vector-values", cl::Hidden, - cl::init(true)); - namespace { struct RewriteStatepointsForGC : public ModulePass { static char ID; // Pass identification, replacement for typeid @@ -141,24 +131,25 @@ ModulePass *llvm::createRewriteStatepointsForGCPass() { INITIALIZE_PASS_BEGIN(RewriteStatepointsForGC, "rewrite-statepoints-for-gc", "Make relocations explicit at statepoints", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(RewriteStatepointsForGC, "rewrite-statepoints-for-gc", "Make relocations explicit at statepoints", false, false) namespace { struct GCPtrLivenessData { /// Values defined in this block. - DenseMap<BasicBlock *, DenseSet<Value *>> KillSet; + MapVector<BasicBlock *, SetVector<Value *>> KillSet; /// Values used in this block (and thus live); does not included values /// killed within this block. - DenseMap<BasicBlock *, DenseSet<Value *>> LiveSet; + MapVector<BasicBlock *, SetVector<Value *>> LiveSet; /// Values live into this basic block (i.e. used by any /// instruction in this basic block or ones reachable from here) - DenseMap<BasicBlock *, DenseSet<Value *>> LiveIn; + MapVector<BasicBlock *, SetVector<Value *>> LiveIn; /// Values live out of this basic block (i.e. live into /// any successor block) - DenseMap<BasicBlock *, DenseSet<Value *>> LiveOut; + MapVector<BasicBlock *, SetVector<Value *>> LiveOut; }; // The type of the internal cache used inside the findBasePointers family @@ -171,9 +162,9 @@ struct GCPtrLivenessData { // Generally, after the execution of a full findBasePointer call, only the // base relation will remain. 
Internally, we add a mixture of the two // types, then update all the second type to the first type -typedef DenseMap<Value *, Value *> DefiningValueMapTy; -typedef DenseSet<Value *> StatepointLiveSetTy; -typedef DenseMap<AssertingVH<Instruction>, AssertingVH<Value>> +typedef MapVector<Value *, Value *> DefiningValueMapTy; +typedef SetVector<Value *> StatepointLiveSetTy; +typedef MapVector<AssertingVH<Instruction>, AssertingVH<Value>> RematerializedValueMapTy; struct PartiallyConstructedSafepointRecord { @@ -181,7 +172,7 @@ struct PartiallyConstructedSafepointRecord { StatepointLiveSetTy LiveSet; /// Mapping from live pointers to a base-defining-value - DenseMap<Value *, Value *> PointerToBase; + MapVector<Value *, Value *> PointerToBase; /// The *new* gc.statepoint instruction itself. This produces the token /// that normal path gc.relocates and the gc.result are tied to. @@ -199,9 +190,8 @@ struct PartiallyConstructedSafepointRecord { } static ArrayRef<Use> GetDeoptBundleOperands(ImmutableCallSite CS) { - assert(UseDeoptBundles && "Should not be called otherwise!"); - - Optional<OperandBundleUse> DeoptBundle = CS.getOperandBundle("deopt"); + Optional<OperandBundleUse> DeoptBundle = + CS.getOperandBundle(LLVMContext::OB_deopt); if (!DeoptBundle.hasValue()) { assert(AllowStatepointWithNoDeoptInfo && @@ -229,7 +219,7 @@ static bool isGCPointerType(Type *T) { // For the sake of this example GC, we arbitrarily pick addrspace(1) as our // GC managed heap. We know that a pointer into this heap needs to be // updated and that no other pointer does. - return (1 == PT->getAddressSpace()); + return PT->getAddressSpace() == 1; return false; } @@ -260,8 +250,7 @@ static bool containsGCPtrType(Type *Ty) { if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) return containsGCPtrType(AT->getElementType()); if (StructType *ST = dyn_cast<StructType>(Ty)) - return std::any_of(ST->subtypes().begin(), ST->subtypes().end(), - containsGCPtrType); + return any_of(ST->subtypes(), containsGCPtrType); return false; } @@ -273,19 +262,6 @@ static bool isUnhandledGCPointerType(Type *Ty) { } #endif -static bool order_by_name(Value *a, Value *b) { - if (a->hasName() && b->hasName()) { - return -1 == a->getName().compare(b->getName()); - } else if (a->hasName() && !b->hasName()) { - return true; - } else if (!a->hasName() && b->hasName()) { - return false; - } else { - // Better than nothing, but not stable - return a < b; - } -} - // Return the name of the value suffixed with the provided value, or if the // value didn't have a name, the default value specified. static std::string suffixed_name_or(Value *V, StringRef Suffix, @@ -297,30 +273,25 @@ static std::string suffixed_name_or(Value *V, StringRef Suffix, // given instruction. The analysis is performed immediately before the // given instruction. Values defined by that instruction are not considered // live. Values used by that instruction are considered live. 
-static void analyzeParsePointLiveness( - DominatorTree &DT, GCPtrLivenessData &OriginalLivenessData, - const CallSite &CS, PartiallyConstructedSafepointRecord &result) { - Instruction *inst = CS.getInstruction(); +static void +analyzeParsePointLiveness(DominatorTree &DT, + GCPtrLivenessData &OriginalLivenessData, CallSite CS, + PartiallyConstructedSafepointRecord &Result) { + Instruction *Inst = CS.getInstruction(); StatepointLiveSetTy LiveSet; - findLiveSetAtInst(inst, OriginalLivenessData, LiveSet); + findLiveSetAtInst(Inst, OriginalLivenessData, LiveSet); if (PrintLiveSet) { - // Note: This output is used by several of the test cases - // The order of elements in a set is not stable, put them in a vec and sort - // by name - SmallVector<Value *, 64> Temp; - Temp.insert(Temp.end(), LiveSet.begin(), LiveSet.end()); - std::sort(Temp.begin(), Temp.end(), order_by_name); - errs() << "Live Variables:\n"; - for (Value *V : Temp) + dbgs() << "Live Variables:\n"; + for (Value *V : LiveSet) dbgs() << " " << V->getName() << " " << *V << "\n"; } if (PrintLiveSetSize) { - errs() << "Safepoint For: " << CS.getCalledValue()->getName() << "\n"; - errs() << "Number live values: " << LiveSet.size() << "\n"; + dbgs() << "Safepoint For: " << CS.getCalledValue()->getName() << "\n"; + dbgs() << "Number live values: " << LiveSet.size() << "\n"; } - result.LiveSet = LiveSet; + Result.LiveSet = LiveSet; } static bool isKnownBaseResult(Value *V); @@ -372,8 +343,10 @@ findBaseDefiningValueOfVector(Value *I) { return BaseDefiningValueResult(I, true); if (isa<Constant>(I)) - // Constant vectors consist only of constant pointers. - return BaseDefiningValueResult(I, true); + // Base of constant vector consists only of constant null pointers. + // For reasoning see similar case inside 'findBaseDefiningValue' function. + return BaseDefiningValueResult(ConstantAggregateZero::get(I->getType()), + true); if (isa<LoadInst>(I)) return BaseDefiningValueResult(I, true); @@ -415,14 +388,20 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { // We should have never reached here if this argument isn't an gc value return BaseDefiningValueResult(I, true); - if (isa<Constant>(I)) + if (isa<Constant>(I)) { // We assume that objects with a constant base (e.g. a global) can't move // and don't need to be reported to the collector because they are always - // live. All constants have constant bases. Besides global references, all - // kinds of constants (e.g. undef, constant expressions, null pointers) can - // be introduced by the inliner or the optimizer, especially on dynamically - // dead paths. See e.g. test4 in constants.ll. - return BaseDefiningValueResult(I, true); + // live. Besides global references, all kinds of constants (e.g. undef, + // constant expressions, null pointers) can be introduced by the inliner or + // the optimizer, especially on dynamically dead paths. + // Here we treat all of them as having single null base. By doing this we + // trying to avoid problems reporting various conflicts in a form of + // "phi (const1, const2)" or "phi (const, regular gc ptr)". + // See constant.ll file for relevant test cases. 
+ + return BaseDefiningValueResult( + ConstantPointerNull::get(cast<PointerType>(I->getType())), true); + } if (CastInst *CI = dyn_cast<CastInst>(I)) { Value *Def = CI->stripPointerCasts(); @@ -570,30 +549,36 @@ class BDVState { public: enum Status { Unknown, Base, Conflict }; - BDVState(Status s, Value *b = nullptr) : status(s), base(b) { - assert(status != Base || b); + BDVState() : Status(Unknown), BaseValue(nullptr) {} + + explicit BDVState(Status Status, Value *BaseValue = nullptr) + : Status(Status), BaseValue(BaseValue) { + assert(Status != Base || BaseValue); } - explicit BDVState(Value *b) : status(Base), base(b) {} - BDVState() : status(Unknown), base(nullptr) {} - Status getStatus() const { return status; } - Value *getBase() const { return base; } + explicit BDVState(Value *BaseValue) : Status(Base), BaseValue(BaseValue) {} + + Status getStatus() const { return Status; } + Value *getBaseValue() const { return BaseValue; } bool isBase() const { return getStatus() == Base; } bool isUnknown() const { return getStatus() == Unknown; } bool isConflict() const { return getStatus() == Conflict; } - bool operator==(const BDVState &other) const { - return base == other.base && status == other.status; + bool operator==(const BDVState &Other) const { + return BaseValue == Other.BaseValue && Status == Other.Status; } bool operator!=(const BDVState &other) const { return !(*this == other); } LLVM_DUMP_METHOD - void dump() const { print(dbgs()); dbgs() << '\n'; } - + void dump() const { + print(dbgs()); + dbgs() << '\n'; + } + void print(raw_ostream &OS) const { - switch (status) { + switch (getStatus()) { case Unknown: OS << "U"; break; @@ -604,13 +589,13 @@ public: OS << "C"; break; }; - OS << " (" << base << " - " - << (base ? base->getName() : "nullptr") << "): "; + OS << " (" << getBaseValue() << " - " + << (getBaseValue() ? getBaseValue()->getName() : "nullptr") << "): "; } private: - Status status; - AssertingVH<Value> base; // non null only if status == base + Status Status; + AssertingVH<Value> BaseValue; // Non-null only if Status == Base. }; } @@ -621,75 +606,50 @@ static raw_ostream &operator<<(raw_ostream &OS, const BDVState &State) { } #endif -namespace { -// Values of type BDVState form a lattice, and this is a helper -// class that implementes the meet operation. The meat of the meet -// operation is implemented in MeetBDVStates::pureMeet -class MeetBDVStates { -public: - /// Initializes the currentResult to the TOP state so that if can be met with - /// any other state to produce that state. - MeetBDVStates() {} - - // Destructively meet the current result with the given BDVState - void meetWith(BDVState otherState) { - currentResult = meet(otherState, currentResult); - } +static BDVState meetBDVStateImpl(const BDVState &LHS, const BDVState &RHS) { + switch (LHS.getStatus()) { + case BDVState::Unknown: + return RHS; - BDVState getResult() const { return currentResult; } + case BDVState::Base: + assert(LHS.getBaseValue() && "can't be null"); + if (RHS.isUnknown()) + return LHS; -private: - BDVState currentResult; - - /// Perform a meet operation on two elements of the BDVState lattice. 
- static BDVState meet(BDVState LHS, BDVState RHS) { - assert((pureMeet(LHS, RHS) == pureMeet(RHS, LHS)) && - "math is wrong: meet does not commute!"); - BDVState Result = pureMeet(LHS, RHS); - DEBUG(dbgs() << "meet of " << LHS << " with " << RHS - << " produced " << Result << "\n"); - return Result; - } - - static BDVState pureMeet(const BDVState &stateA, const BDVState &stateB) { - switch (stateA.getStatus()) { - case BDVState::Unknown: - return stateB; - - case BDVState::Base: - assert(stateA.getBase() && "can't be null"); - if (stateB.isUnknown()) - return stateA; - - if (stateB.isBase()) { - if (stateA.getBase() == stateB.getBase()) { - assert(stateA == stateB && "equality broken!"); - return stateA; - } - return BDVState(BDVState::Conflict); + if (RHS.isBase()) { + if (LHS.getBaseValue() == RHS.getBaseValue()) { + assert(LHS == RHS && "equality broken!"); + return LHS; } - assert(stateB.isConflict() && "only three states!"); return BDVState(BDVState::Conflict); - - case BDVState::Conflict: - return stateA; } - llvm_unreachable("only three states!"); + assert(RHS.isConflict() && "only three states!"); + return BDVState(BDVState::Conflict); + + case BDVState::Conflict: + return LHS; } -}; + llvm_unreachable("only three states!"); } +// Values of type BDVState form a lattice, and this function implements the meet +// operation. +static BDVState meetBDVState(BDVState LHS, BDVState RHS) { + BDVState Result = meetBDVStateImpl(LHS, RHS); + assert(Result == meetBDVStateImpl(RHS, LHS) && + "Math is wrong: meet does not commute!"); + return Result; +} -/// For a given value or instruction, figure out what base ptr it's derived -/// from. For gc objects, this is simply itself. On success, returns a value -/// which is the base pointer. (This is reliable and can be used for -/// relocation.) On failure, returns nullptr. -static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { - Value *def = findBaseOrBDV(I, cache); +/// For a given value or instruction, figure out what base ptr its derived from. +/// For gc objects, this is simply itself. On success, returns a value which is +/// the base pointer. (This is reliable and can be used for relocation.) On +/// failure, returns nullptr. 
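// A self-contained sketch (illustrative only, not the patch's code) of the
// three-element lattice that the meetBDVStateImpl/meetBDVState hunk above
// implements: Unknown is the top element, each concrete base is its own
// element, and Conflict is bottom. findBasePointer below iterates this meet
// over the def-use web until no state changes. The State type here is a
// simplified stand-in for BDVState.
#include <cassert>

struct State {
  enum Kind { Unknown, Base, Conflict } K;
  int BaseId;                                // meaningful only when K == Base
};

static State meet(State L, State R) {
  if (L.K == State::Unknown) return R;       // top meets anything -> that thing
  if (R.K == State::Unknown) return L;
  if (L.K == State::Conflict || R.K == State::Conflict)
    return {State::Conflict, 0};             // bottom absorbs everything
  // Both are bases: equal bases stay a base, different bases conflict.
  return L.BaseId == R.BaseId ? L : State{State::Conflict, 0};
}

int main() {
  State U{State::Unknown, 0}, A{State::Base, 1}, B{State::Base, 2};
  assert(meet(U, A).K == State::Base && meet(U, A).BaseId == 1);
  assert(meet(A, A).K == State::Base);
  assert(meet(A, B).K == State::Conflict);   // two distinct bases -> Conflict
  // The meet is commutative, which the real code also asserts.
  assert(meet(A, U).BaseId == meet(U, A).BaseId);
  return 0;
}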
+static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { + Value *Def = findBaseOrBDV(I, Cache); - if (isKnownBaseResult(def)) { - return def; - } + if (isKnownBaseResult(Def)) + return Def; // Here's the rough algorithm: // - For every SSA value, construct a mapping to either an actual base @@ -731,14 +691,14 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // one for which we don't already know a definite base value for /* scope */ { SmallVector<Value*, 16> Worklist; - Worklist.push_back(def); - States.insert(std::make_pair(def, BDVState())); + Worklist.push_back(Def); + States.insert({Def, BDVState()}); while (!Worklist.empty()) { Value *Current = Worklist.pop_back_val(); assert(!isKnownBaseResult(Current) && "why did it get added?"); auto visitIncomingValue = [&](Value *InVal) { - Value *Base = findBaseOrBDV(InVal, cache); + Value *Base = findBaseOrBDV(InVal, Cache); if (isKnownBaseResult(Base)) // Known bases won't need new instructions introduced and can be // ignored safely @@ -748,12 +708,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { if (States.insert(std::make_pair(Base, BDVState())).second) Worklist.push_back(Base); }; - if (PHINode *Phi = dyn_cast<PHINode>(Current)) { - for (Value *InVal : Phi->incoming_values()) + if (PHINode *PN = dyn_cast<PHINode>(Current)) { + for (Value *InVal : PN->incoming_values()) visitIncomingValue(InVal); - } else if (SelectInst *Sel = dyn_cast<SelectInst>(Current)) { - visitIncomingValue(Sel->getTrueValue()); - visitIncomingValue(Sel->getFalseValue()); + } else if (SelectInst *SI = dyn_cast<SelectInst>(Current)) { + visitIncomingValue(SI->getTrueValue()); + visitIncomingValue(SI->getFalseValue()); } else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) { visitIncomingValue(EE->getVectorOperand()); } else if (auto *IE = dyn_cast<InsertElementInst>(Current)) { @@ -762,7 +722,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { } else { // There is one known class of instructions we know we don't handle. assert(isa<ShuffleVectorInst>(Current)); - llvm_unreachable("unimplemented instruction case"); + llvm_unreachable("Unimplemented instruction case"); } } } @@ -784,12 +744,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { return I->second; }; - bool progress = true; - while (progress) { + bool Progress = true; + while (Progress) { #ifndef NDEBUG - const size_t oldSize = States.size(); + const size_t OldSize = States.size(); #endif - progress = false; + Progress = false; // We're only changing values in this loop, thus safe to keep iterators. // Since this is computing a fixed point, the order of visit does not // effect the result. TODO: We could use a worklist here and make this run @@ -801,38 +761,39 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // Given an input value for the current instruction, return a BDVState // instance which represents the BDV of that value. 
auto getStateForInput = [&](Value *V) mutable { - Value *BDV = findBaseOrBDV(V, cache); + Value *BDV = findBaseOrBDV(V, Cache); return getStateForBDV(BDV); }; - MeetBDVStates calculateMeet; - if (SelectInst *select = dyn_cast<SelectInst>(BDV)) { - calculateMeet.meetWith(getStateForInput(select->getTrueValue())); - calculateMeet.meetWith(getStateForInput(select->getFalseValue())); - } else if (PHINode *Phi = dyn_cast<PHINode>(BDV)) { - for (Value *Val : Phi->incoming_values()) - calculateMeet.meetWith(getStateForInput(Val)); + BDVState NewState; + if (SelectInst *SI = dyn_cast<SelectInst>(BDV)) { + NewState = meetBDVState(NewState, getStateForInput(SI->getTrueValue())); + NewState = + meetBDVState(NewState, getStateForInput(SI->getFalseValue())); + } else if (PHINode *PN = dyn_cast<PHINode>(BDV)) { + for (Value *Val : PN->incoming_values()) + NewState = meetBDVState(NewState, getStateForInput(Val)); } else if (auto *EE = dyn_cast<ExtractElementInst>(BDV)) { // The 'meet' for an extractelement is slightly trivial, but it's still // useful in that it drives us to conflict if our input is. - calculateMeet.meetWith(getStateForInput(EE->getVectorOperand())); + NewState = + meetBDVState(NewState, getStateForInput(EE->getVectorOperand())); } else { // Given there's a inherent type mismatch between the operands, will // *always* produce Conflict. auto *IE = cast<InsertElementInst>(BDV); - calculateMeet.meetWith(getStateForInput(IE->getOperand(0))); - calculateMeet.meetWith(getStateForInput(IE->getOperand(1))); + NewState = meetBDVState(NewState, getStateForInput(IE->getOperand(0))); + NewState = meetBDVState(NewState, getStateForInput(IE->getOperand(1))); } - BDVState oldState = States[BDV]; - BDVState newState = calculateMeet.getResult(); - if (oldState != newState) { - progress = true; - States[BDV] = newState; + BDVState OldState = States[BDV]; + if (OldState != NewState) { + Progress = true; + States[BDV] = NewState; } } - assert(oldSize == States.size() && + assert(OldSize == States.size() && "fixed point shouldn't be adding any new nodes to state"); } @@ -842,7 +803,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n"); } #endif - + // Insert Phis for all conflicts // TODO: adjust naming patterns to avoid this order of iteration dependency for (auto Pair : States) { @@ -856,14 +817,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // The problem is that we need to convert from a vector base to a scalar // base for the particular indice we're interested in. if (State.isBase() && isa<ExtractElementInst>(I) && - isa<VectorType>(State.getBase()->getType())) { + isa<VectorType>(State.getBaseValue()->getType())) { auto *EE = cast<ExtractElementInst>(I); // TODO: In many cases, the new instruction is just EE itself. We should // exploit this, but can't do it here since it would break the invariant // about the BDV not being known to be a base. - auto *BaseInst = ExtractElementInst::Create(State.getBase(), - EE->getIndexOperand(), - "base_ee", EE); + auto *BaseInst = ExtractElementInst::Create( + State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); States[I] = BDVState(BDVState::Base, BaseInst); } @@ -871,10 +831,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // Since we're joining a vector and scalar base, they can never be the // same. 
As a result, we should always see insert element having reached // the conflict state. - if (isa<InsertElementInst>(I)) { - assert(State.isConflict()); - } - + assert(!isa<InsertElementInst>(I) || State.isConflict()); + if (!State.isConflict()) continue; @@ -887,12 +845,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { assert(NumPreds > 0 && "how did we reach here"); std::string Name = suffixed_name_or(I, ".base", "base_phi"); return PHINode::Create(I->getType(), NumPreds, Name, I); - } else if (SelectInst *Sel = dyn_cast<SelectInst>(I)) { + } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { // The undef will be replaced later - UndefValue *Undef = UndefValue::get(Sel->getType()); + UndefValue *Undef = UndefValue::get(SI->getType()); std::string Name = suffixed_name_or(I, ".base", "base_select"); - return SelectInst::Create(Sel->getCondition(), Undef, - Undef, Name, Sel); + return SelectInst::Create(SI->getCondition(), Undef, Undef, Name, SI); } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) { UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType()); std::string Name = suffixed_name_or(I, ".base", "base_ee"); @@ -906,7 +863,6 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { return InsertElementInst::Create(VecUndef, ScalarUndef, IE->getOperand(2), Name, IE); } - }; Instruction *BaseInst = MakeBaseInstPlaceholder(I); // Add metadata marking this as a base value @@ -921,24 +877,21 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // instruction to propagate the base of it's BDV and have entered that newly // introduced instruction into the state table. In either case, we are // assured to be able to determine an instruction which produces it's base - // pointer. + // pointer. auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) { - Value *BDV = findBaseOrBDV(Input, cache); + Value *BDV = findBaseOrBDV(Input, Cache); Value *Base = nullptr; if (isKnownBaseResult(BDV)) { Base = BDV; } else { // Either conflict or base. assert(States.count(BDV)); - Base = States[BDV].getBase(); + Base = States[BDV].getBaseValue(); } - assert(Base && "can't be null"); + assert(Base && "Can't be null"); // The cast is needed since base traversal may strip away bitcasts - if (Base->getType() != Input->getType() && - InsertPt) { - Base = new BitCastInst(Base, Input->getType(), "cast", - InsertPt); - } + if (Base->getType() != Input->getType() && InsertPt) + Base = new BitCastInst(Base, Input->getType(), "cast", InsertPt); return Base; }; @@ -954,12 +907,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { if (!State.isConflict()) continue; - if (PHINode *basephi = dyn_cast<PHINode>(State.getBase())) { - PHINode *phi = cast<PHINode>(BDV); - unsigned NumPHIValues = phi->getNumIncomingValues(); + if (PHINode *BasePHI = dyn_cast<PHINode>(State.getBaseValue())) { + PHINode *PN = cast<PHINode>(BDV); + unsigned NumPHIValues = PN->getNumIncomingValues(); for (unsigned i = 0; i < NumPHIValues; i++) { - Value *InVal = phi->getIncomingValue(i); - BasicBlock *InBB = phi->getIncomingBlock(i); + Value *InVal = PN->getIncomingValue(i); + BasicBlock *InBB = PN->getIncomingBlock(i); // If we've already seen InBB, add the same incoming value // we added for it earlier. The IR verifier requires phi @@ -970,22 +923,21 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // bitcasts (and hence two distinct values) as incoming // values for the same basic block. 
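// Illustrative sketch only (plain C++, not LLVM IR manipulation) of the
// invariant the comment above describes and the code below enforces: when the
// same predecessor block appears more than once in a PHI, the parallel "base"
// PHI must reuse exactly the incoming value it already chose for that block.
// The names (Incoming, BaseOf, bb0/bb1) are stand-ins for this sketch.
#include <cassert>
#include <map>
#include <string>
#include <vector>

using Incoming = std::pair<std::string, std::string>; // (block, value)

static std::vector<Incoming>
buildBaseIncoming(const std::vector<Incoming> &Derived,
                  const std::map<std::string, std::string> &BaseOf) {
  std::vector<Incoming> Base;
  std::map<std::string, std::string> UsedForBlock; // block -> base already used
  for (const Incoming &In : Derived) {
    auto It = UsedForBlock.find(In.first);
    // Reuse the previously chosen base for a repeated block; otherwise look
    // up the base of this incoming value.
    std::string B =
        It != UsedForBlock.end() ? It->second : BaseOf.at(In.second);
    UsedForBlock.emplace(In.first, B);
    Base.push_back({In.first, B});
  }
  return Base;
}

int main() {
  // "bb1" feeds the PHI twice (e.g. via two bitcasts of the same pointer);
  // both corresponding entries of the base PHI must agree.
  std::vector<Incoming> Derived = {{"bb0", "p"}, {"bb1", "q"}, {"bb1", "q.cast"}};
  std::map<std::string, std::string> BaseOf = {
      {"p", "p.base"}, {"q", "q.base"}, {"q.cast", "q.base"}};
  auto Base = buildBaseIncoming(Derived, BaseOf);
  assert(Base[1].second == Base[2].second && Base[1].second == "q.base");
  return 0;
}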
- int blockIndex = basephi->getBasicBlockIndex(InBB); - if (blockIndex != -1) { - Value *oldBase = basephi->getIncomingValue(blockIndex); - basephi->addIncoming(oldBase, InBB); - + int BlockIndex = BasePHI->getBasicBlockIndex(InBB); + if (BlockIndex != -1) { + Value *OldBase = BasePHI->getIncomingValue(BlockIndex); + BasePHI->addIncoming(OldBase, InBB); + #ifndef NDEBUG Value *Base = getBaseForInput(InVal, nullptr); - // In essence this assert states: the only way two - // values incoming from the same basic block may be - // different is by being different bitcasts of the same - // value. A cleanup that remains TODO is changing - // findBaseOrBDV to return an llvm::Value of the correct - // type (and still remain pure). This will remove the - // need to add bitcasts. - assert(Base->stripPointerCasts() == oldBase->stripPointerCasts() && - "sanity -- findBaseOrBDV should be pure!"); + // In essence this assert states: the only way two values + // incoming from the same basic block may be different is by + // being different bitcasts of the same value. A cleanup + // that remains TODO is changing findBaseOrBDV to return an + // llvm::Value of the correct type (and still remain pure). + // This will remove the need to add bitcasts. + assert(Base->stripPointerCasts() == OldBase->stripPointerCasts() && + "Sanity -- findBaseOrBDV should be pure!"); #endif continue; } @@ -994,28 +946,25 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // need to insert a bitcast in the incoming block. // TODO: Need to split critical edges if insertion is needed Value *Base = getBaseForInput(InVal, InBB->getTerminator()); - basephi->addIncoming(Base, InBB); + BasePHI->addIncoming(Base, InBB); } - assert(basephi->getNumIncomingValues() == NumPHIValues); - } else if (SelectInst *BaseSel = dyn_cast<SelectInst>(State.getBase())) { - SelectInst *Sel = cast<SelectInst>(BDV); - // Operand 1 & 2 are true, false path respectively. TODO: refactor to - // something more safe and less hacky. - for (int i = 1; i <= 2; i++) { - Value *InVal = Sel->getOperand(i); - // Find the instruction which produces the base for each input. We may - // need to insert a bitcast. - Value *Base = getBaseForInput(InVal, BaseSel); - BaseSel->setOperand(i, Base); - } - } else if (auto *BaseEE = dyn_cast<ExtractElementInst>(State.getBase())) { + assert(BasePHI->getNumIncomingValues() == NumPHIValues); + } else if (SelectInst *BaseSI = + dyn_cast<SelectInst>(State.getBaseValue())) { + SelectInst *SI = cast<SelectInst>(BDV); + + // Find the instruction which produces the base for each input. + // We may need to insert a bitcast. + BaseSI->setTrueValue(getBaseForInput(SI->getTrueValue(), BaseSI)); + BaseSI->setFalseValue(getBaseForInput(SI->getFalseValue(), BaseSI)); + } else if (auto *BaseEE = + dyn_cast<ExtractElementInst>(State.getBaseValue())) { Value *InVal = cast<ExtractElementInst>(BDV)->getVectorOperand(); // Find the instruction which produces the base for each input. We may // need to insert a bitcast. 
- Value *Base = getBaseForInput(InVal, BaseEE); - BaseEE->setOperand(0, Base); + BaseEE->setOperand(0, getBaseForInput(InVal, BaseEE)); } else { - auto *BaseIE = cast<InsertElementInst>(State.getBase()); + auto *BaseIE = cast<InsertElementInst>(State.getBaseValue()); auto *BdvIE = cast<InsertElementInst>(BDV); auto UpdateOperand = [&](int OperandIdx) { Value *InVal = BdvIE->getOperand(OperandIdx); @@ -1025,69 +974,6 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { UpdateOperand(0); // vector operand UpdateOperand(1); // scalar operand } - - } - - // Now that we're done with the algorithm, see if we can optimize the - // results slightly by reducing the number of new instructions needed. - // Arguably, this should be integrated into the algorithm above, but - // doing as a post process step is easier to reason about for the moment. - DenseMap<Value *, Value *> ReverseMap; - SmallPtrSet<Instruction *, 16> NewInsts; - SmallSetVector<AssertingVH<Instruction>, 16> Worklist; - // Note: We need to visit the states in a deterministic order. We uses the - // Keys we sorted above for this purpose. Note that we are papering over a - // bigger problem with the algorithm above - it's visit order is not - // deterministic. A larger change is needed to fix this. - for (auto Pair : States) { - auto *BDV = Pair.first; - auto State = Pair.second; - Value *Base = State.getBase(); - assert(BDV && Base); - assert(!isKnownBaseResult(BDV) && "why did it get added?"); - assert(isKnownBaseResult(Base) && - "must be something we 'know' is a base pointer"); - if (!State.isConflict()) - continue; - - ReverseMap[Base] = BDV; - if (auto *BaseI = dyn_cast<Instruction>(Base)) { - NewInsts.insert(BaseI); - Worklist.insert(BaseI); - } - } - auto ReplaceBaseInstWith = [&](Value *BDV, Instruction *BaseI, - Value *Replacement) { - // Add users which are new instructions (excluding self references) - for (User *U : BaseI->users()) - if (auto *UI = dyn_cast<Instruction>(U)) - if (NewInsts.count(UI) && UI != BaseI) - Worklist.insert(UI); - // Then do the actual replacement - NewInsts.erase(BaseI); - ReverseMap.erase(BaseI); - BaseI->replaceAllUsesWith(Replacement); - assert(States.count(BDV)); - assert(States[BDV].isConflict() && States[BDV].getBase() == BaseI); - States[BDV] = BDVState(BDVState::Conflict, Replacement); - BaseI->eraseFromParent(); - }; - const DataLayout &DL = cast<Instruction>(def)->getModule()->getDataLayout(); - while (!Worklist.empty()) { - Instruction *BaseI = Worklist.pop_back_val(); - assert(NewInsts.count(BaseI)); - Value *Bdv = ReverseMap[BaseI]; - if (auto *BdvI = dyn_cast<Instruction>(Bdv)) - if (BaseI->isIdenticalTo(BdvI)) { - DEBUG(dbgs() << "Identical Base: " << *BaseI << "\n"); - ReplaceBaseInstWith(Bdv, BaseI, Bdv); - continue; - } - if (Value *V = SimplifyInstruction(BaseI, DL)) { - DEBUG(dbgs() << "Base " << *BaseI << " simplified to " << *V << "\n"); - ReplaceBaseInstWith(Bdv, BaseI, V); - continue; - } } // Cache all of our results so we can cheaply reuse them @@ -1095,25 +981,27 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // relation and one of the base pointer relation! FIXME for (auto Pair : States) { auto *BDV = Pair.first; - Value *base = Pair.second.getBase(); - assert(BDV && base); + Value *Base = Pair.second.getBaseValue(); + assert(BDV && Base); + assert(!isKnownBaseResult(BDV) && "why did it get added?"); - std::string fromstr = cache.count(BDV) ? 
cache[BDV]->getName() : "none"; DEBUG(dbgs() << "Updating base value cache" - << " for: " << BDV->getName() - << " from: " << fromstr - << " to: " << base->getName() << "\n"); - - if (cache.count(BDV)) { - // Once we transition from the BDV relation being store in the cache to + << " for: " << BDV->getName() << " from: " + << (Cache.count(BDV) ? Cache[BDV]->getName().str() : "none") + << " to: " << Base->getName() << "\n"); + + if (Cache.count(BDV)) { + assert(isKnownBaseResult(Base) && + "must be something we 'know' is a base pointer"); + // Once we transition from the BDV relation being store in the Cache to // the base relation being stored, it must be stable - assert((!isKnownBaseResult(cache[BDV]) || cache[BDV] == base) && + assert((!isKnownBaseResult(Cache[BDV]) || Cache[BDV] == Base) && "base relation should be stable"); } - cache[BDV] = base; + Cache[BDV] = Base; } - assert(cache.count(def)); - return cache[def]; + assert(Cache.count(Def)); + return Cache[Def]; } // For a set of live pointers (base and/or derived), identify the base @@ -1133,15 +1021,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) { // pointer was a base pointer. static void findBasePointers(const StatepointLiveSetTy &live, - DenseMap<Value *, Value *> &PointerToBase, + MapVector<Value *, Value *> &PointerToBase, DominatorTree *DT, DefiningValueMapTy &DVCache) { - // For the naming of values inserted to be deterministic - which makes for - // much cleaner and more stable tests - we need to assign an order to the - // live values. DenseSets do not provide a deterministic order across runs. - SmallVector<Value *, 64> Temp; - Temp.insert(Temp.end(), live.begin(), live.end()); - std::sort(Temp.begin(), Temp.end(), order_by_name); - for (Value *ptr : Temp) { + for (Value *ptr : live) { Value *base = findBasePointer(ptr, DVCache); assert(base && "failed to find base pointer"); PointerToBase[ptr] = base; @@ -1149,41 +1031,24 @@ findBasePointers(const StatepointLiveSetTy &live, DT->dominates(cast<Instruction>(base)->getParent(), cast<Instruction>(ptr)->getParent())) && "The base we found better dominate the derived pointer"); - - // If you see this trip and like to live really dangerously, the code should - // be correct, just with idioms the verifier can't handle. You can try - // disabling the verifier at your own substantial risk. - assert(!isa<ConstantPointerNull>(base) && - "the relocation code needs adjustment to handle the relocation of " - "a null pointer constant without causing false positives in the " - "safepoint ir verifier."); } } /// Find the required based pointers (and adjust the live set) for the given /// parse point. static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, - const CallSite &CS, + CallSite CS, PartiallyConstructedSafepointRecord &result) { - DenseMap<Value *, Value *> PointerToBase; + MapVector<Value *, Value *> PointerToBase; findBasePointers(result.LiveSet, PointerToBase, &DT, DVCache); if (PrintBasePointers) { - // Note: Need to print these in a stable order since this is checked in - // some tests. 
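// The surrounding hunks can drop the sort-by-name step because PointerToBase
// and the liveness sets now use MapVector/SetVector, which iterate in
// insertion order and so already print deterministically across runs. A
// minimal standalone illustration of that container idea (a toy, not LLVM's
// MapVector) follows.
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Keeps key/value pairs in insertion order while still giving O(1) lookup.
class OrderedMap {
  std::vector<std::pair<std::string, std::string>> Order;
  std::unordered_map<std::string, size_t> Index;

public:
  void insert(const std::string &K, const std::string &V) {
    if (Index.count(K))
      return;                                // first insertion wins
    Index.emplace(K, Order.size());
    Order.push_back({K, V});
  }
  const std::vector<std::pair<std::string, std::string>> &entries() const {
    return Order;                            // iteration order == insertion order
  }
};

int main() {
  OrderedMap PointerToBase;
  PointerToBase.insert("derived2", "base2");
  PointerToBase.insert("derived1", "base1");
  // Unlike an unordered/Dense map, iterating entries() is stable across runs:
  // derived2 always comes before derived1, so no extra sorting is needed.
  assert(PointerToBase.entries().front().first == "derived2");
  return 0;
}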
errs() << "Base Pairs (w/o Relocation):\n"; - SmallVector<Value *, 64> Temp; - Temp.reserve(PointerToBase.size()); - for (auto Pair : PointerToBase) { - Temp.push_back(Pair.first); - } - std::sort(Temp.begin(), Temp.end(), order_by_name); - for (Value *Ptr : Temp) { - Value *Base = PointerToBase[Ptr]; + for (auto &Pair : PointerToBase) { errs() << " derived "; - Ptr->printAsOperand(errs(), false); + Pair.first->printAsOperand(errs(), false); errs() << " base "; - Base->printAsOperand(errs(), false); + Pair.second->printAsOperand(errs(), false); errs() << "\n";; } } @@ -1194,7 +1059,7 @@ static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, /// Given an updated version of the dataflow liveness results, update the /// liveset and base pointer maps for the call site CS. static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData, - const CallSite &CS, + CallSite CS, PartiallyConstructedSafepointRecord &result); static void recomputeLiveInValues( @@ -1206,8 +1071,7 @@ static void recomputeLiveInValues( computeLiveInValues(DT, F, RevisedLivenessData); for (size_t i = 0; i < records.size(); i++) { struct PartiallyConstructedSafepointRecord &info = records[i]; - const CallSite &CS = toUpdate[i]; - recomputeLiveInValues(RevisedLivenessData, CS, info); + recomputeLiveInValues(RevisedLivenessData, toUpdate[i], info); } } @@ -1257,8 +1121,7 @@ static AttributeSet legalizeCallAttributes(AttributeSet AS) { // These attributes control the generation of the gc.statepoint call / // invoke itself; and once the gc.statepoint is in place, they're of no // use. - if (Attr.hasAttribute("statepoint-num-patch-bytes") || - Attr.hasAttribute("statepoint-id")) + if (isStatepointDirectiveAttr(Attr)) continue; Ret = Ret.addAttributes( @@ -1349,11 +1212,37 @@ namespace { class DeferredReplacement { AssertingVH<Instruction> Old; AssertingVH<Instruction> New; + bool IsDeoptimize = false; + + DeferredReplacement() {} public: - explicit DeferredReplacement(Instruction *Old, Instruction *New) : - Old(Old), New(New) { - assert(Old != New && "Not allowed!"); + static DeferredReplacement createRAUW(Instruction *Old, Instruction *New) { + assert(Old != New && Old && New && + "Cannot RAUW equal values or to / from null!"); + + DeferredReplacement D; + D.Old = Old; + D.New = New; + return D; + } + + static DeferredReplacement createDelete(Instruction *ToErase) { + DeferredReplacement D; + D.Old = ToErase; + return D; + } + + static DeferredReplacement createDeoptimizeReplacement(Instruction *Old) { +#ifndef NDEBUG + auto *F = cast<CallInst>(Old)->getCalledFunction(); + assert(F && F->getIntrinsicID() == Intrinsic::experimental_deoptimize && + "Only way to construct a deoptimize deferred replacement"); +#endif + DeferredReplacement D; + D.Old = Old; + D.IsDeoptimize = true; + return D; } /// Does the task represented by this instance. @@ -1362,12 +1251,23 @@ public: Instruction *NewI = New; assert(OldI != NewI && "Disallowed at construction?!"); + assert((!IsDeoptimize || !New) && + "Deoptimize instrinsics are not replaced!"); Old = nullptr; New = nullptr; if (NewI) OldI->replaceAllUsesWith(NewI); + + if (IsDeoptimize) { + // Note: we've inserted instructions, so the call to llvm.deoptimize may + // not necessarilly be followed by the matching return. 
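// A simplified, standalone sketch of the named-constructor idiom that the
// DeferredReplacement class above adopts (createRAUW / createDelete /
// createDeoptimizeReplacement): the default constructor is private, and each
// static factory spells out which kind of deferred action is being built. The
// Action class below is an illustrative stand-in, not the real type.
#include <cassert>
#include <string>

class Action {
  std::string Old, New;
  bool IsDelete = false;

  Action() = default;                  // only the factories may construct one

public:
  static Action createReplace(std::string OldName, std::string NewName) {
    assert(OldName != NewName && "cannot replace a value with itself");
    Action A;
    A.Old = std::move(OldName);
    A.New = std::move(NewName);
    return A;
  }

  static Action createDelete(std::string Name) {
    Action A;
    A.Old = std::move(Name);
    A.IsDelete = true;
    return A;
  }

  bool isDelete() const { return IsDelete; }
  const std::string &replacement() const { return New; }
};

int main() {
  Action R = Action::createReplace("old.call", "statepoint.token");
  Action D = Action::createDelete("dead.call");
  assert(!R.isDelete() && R.replacement() == "statepoint.token");
  assert(D.isDelete());
  return 0;
}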
+ auto *RI = cast<ReturnInst>(OldI->getParent()->getTerminator()); + new UnreachableInst(RI->getContext(), RI); + RI->eraseFromParent(); + } + OldI->eraseFromParent(); } }; @@ -1380,8 +1280,6 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ PartiallyConstructedSafepointRecord &Result, std::vector<DeferredReplacement> &Replacements) { assert(BasePtrs.size() == LiveVariables.size()); - assert((UseDeoptBundles || isStatepoint(CS)) && - "This method expects to be rewriting a statepoint"); // Then go ahead and use the builder do actually do the inserts. We insert // immediately before the previous instruction under the assumption that all @@ -1391,47 +1289,53 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ IRBuilder<> Builder(InsertBefore); ArrayRef<Value *> GCArgs(LiveVariables); - uint64_t StatepointID = 0xABCDEF00; + uint64_t StatepointID = StatepointDirectives::DefaultStatepointID; uint32_t NumPatchBytes = 0; uint32_t Flags = uint32_t(StatepointFlags::None); - ArrayRef<Use> CallArgs; - ArrayRef<Use> DeoptArgs; + ArrayRef<Use> CallArgs(CS.arg_begin(), CS.arg_end()); + ArrayRef<Use> DeoptArgs = GetDeoptBundleOperands(CS); ArrayRef<Use> TransitionArgs; - - Value *CallTarget = nullptr; - - if (UseDeoptBundles) { - CallArgs = {CS.arg_begin(), CS.arg_end()}; - DeoptArgs = GetDeoptBundleOperands(CS); - // TODO: we don't fill in TransitionArgs or Flags in this branch, but we - // could have an operand bundle for that too. - AttributeSet OriginalAttrs = CS.getAttributes(); - - Attribute AttrID = OriginalAttrs.getAttribute(AttributeSet::FunctionIndex, - "statepoint-id"); - if (AttrID.isStringAttribute()) - AttrID.getValueAsString().getAsInteger(10, StatepointID); - - Attribute AttrNumPatchBytes = OriginalAttrs.getAttribute( - AttributeSet::FunctionIndex, "statepoint-num-patch-bytes"); - if (AttrNumPatchBytes.isStringAttribute()) - AttrNumPatchBytes.getValueAsString().getAsInteger(10, NumPatchBytes); - - CallTarget = CS.getCalledValue(); - } else { - // This branch will be gone soon, and we will soon only support the - // UseDeoptBundles == true configuration. - Statepoint OldSP(CS); - StatepointID = OldSP.getID(); - NumPatchBytes = OldSP.getNumPatchBytes(); - Flags = OldSP.getFlags(); - - CallArgs = {OldSP.arg_begin(), OldSP.arg_end()}; - DeoptArgs = {OldSP.vm_state_begin(), OldSP.vm_state_end()}; - TransitionArgs = {OldSP.gc_transition_args_begin(), - OldSP.gc_transition_args_end()}; - CallTarget = OldSP.getCalledValue(); + if (auto TransitionBundle = + CS.getOperandBundle(LLVMContext::OB_gc_transition)) { + Flags |= uint32_t(StatepointFlags::GCTransition); + TransitionArgs = TransitionBundle->Inputs; + } + + // Instead of lowering calls to @llvm.experimental.deoptimize as normal calls + // with a return value, we lower then as never returning calls to + // __llvm_deoptimize that are followed by unreachable to get better codegen. + bool IsDeoptimize = false; + + StatepointDirectives SD = + parseStatepointDirectivesFromAttrs(CS.getAttributes()); + if (SD.NumPatchBytes) + NumPatchBytes = *SD.NumPatchBytes; + if (SD.StatepointID) + StatepointID = *SD.StatepointID; + + Value *CallTarget = CS.getCalledValue(); + if (Function *F = dyn_cast<Function>(CallTarget)) { + if (F->getIntrinsicID() == Intrinsic::experimental_deoptimize) { + // Calls to llvm.experimental.deoptimize are lowered to calls to the + // __llvm_deoptimize symbol. We want to resolve this now, since the + // verifier does not allow taking the address of an intrinsic function. 
+ + SmallVector<Type *, 8> DomainTy; + for (Value *Arg : CallArgs) + DomainTy.push_back(Arg->getType()); + auto *FTy = FunctionType::get(Type::getVoidTy(F->getContext()), DomainTy, + /* isVarArg = */ false); + + // Note: CallTarget can be a bitcast instruction of a symbol if there are + // calls to @llvm.experimental.deoptimize with different argument types in + // the same module. This is fine -- we assume the frontend knew what it + // was doing when generating this kind of IR. + CallTarget = + F->getParent()->getOrInsertFunction("__llvm_deoptimize", FTy); + + IsDeoptimize = true; + } } // Create the statepoint given all the arguments @@ -1514,7 +1418,13 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ } assert(Token && "Should be set in one of the above branches!"); - if (UseDeoptBundles) { + if (IsDeoptimize) { + // If we're wrapping an @llvm.experimental.deoptimize in a statepoint, we + // transform the tail-call like structure to a call to a void function + // followed by unreachable to get better codegen. + Replacements.push_back( + DeferredReplacement::createDeoptimizeReplacement(CS.getInstruction())); + } else { Token->setName("statepoint_token"); if (!CS.getType()->isVoidTy() && !CS.getInstruction()->use_empty()) { StringRef Name = @@ -1528,24 +1438,12 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ // llvm::Instruction. Instead, we defer the replacement and deletion to // after the live sets have been made explicit in the IR, and we no longer // have raw pointers to worry about. - Replacements.emplace_back(CS.getInstruction(), GCResult); + Replacements.emplace_back( + DeferredReplacement::createRAUW(CS.getInstruction(), GCResult)); } else { - Replacements.emplace_back(CS.getInstruction(), nullptr); + Replacements.emplace_back( + DeferredReplacement::createDelete(CS.getInstruction())); } - } else { - assert(!CS.getInstruction()->hasNUsesOrMore(2) && - "only valid use before rewrite is gc.result"); - assert(!CS.getInstruction()->hasOneUse() || - isGCResult(cast<Instruction>(*CS.getInstruction()->user_begin()))); - - // Take the name of the original statepoint token if there was one. - Token->takeName(CS.getInstruction()); - - // Update the gc.result of the original statepoint (if any) to use the newly - // inserted statepoint. This is safe to do here since the token can't be - // considered a live reference. - CS.getInstruction()->replaceAllUsesWith(Token); - CS.getInstruction()->eraseFromParent(); } Result.StatepointToken = Token; @@ -1555,43 +1453,13 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, Token, Builder); } -namespace { -struct NameOrdering { - Value *Base; - Value *Derived; - - bool operator()(NameOrdering const &a, NameOrdering const &b) { - return -1 == a.Derived->getName().compare(b.Derived->getName()); - } -}; -} - -static void StabilizeOrder(SmallVectorImpl<Value *> &BaseVec, - SmallVectorImpl<Value *> &LiveVec) { - assert(BaseVec.size() == LiveVec.size()); - - SmallVector<NameOrdering, 64> Temp; - for (size_t i = 0; i < BaseVec.size(); i++) { - NameOrdering v; - v.Base = BaseVec[i]; - v.Derived = LiveVec[i]; - Temp.push_back(v); - } - - std::sort(Temp.begin(), Temp.end(), NameOrdering()); - for (size_t i = 0; i < BaseVec.size(); i++) { - BaseVec[i] = Temp[i].Base; - LiveVec[i] = Temp[i].Derived; - } -} - // Replace an existing gc.statepoint with a new one and a set of gc.relocates // which make the relocations happening at this safepoint explicit. 
// // WARNING: Does not do any fixup to adjust users of the original live // values. That's the callers responsibility. static void -makeStatepointExplicit(DominatorTree &DT, const CallSite &CS, +makeStatepointExplicit(DominatorTree &DT, CallSite CS, PartiallyConstructedSafepointRecord &Result, std::vector<DeferredReplacement> &Replacements) { const auto &LiveSet = Result.LiveSet; @@ -1609,11 +1477,6 @@ makeStatepointExplicit(DominatorTree &DT, const CallSite &CS, } assert(LiveVec.size() == BaseVec.size()); - // To make the output IR slightly more stable (for use in diffs), ensure a - // fixed order of the values in the safepoint (by sorting the value name). - // The order is otherwise meaningless. - StabilizeOrder(BaseVec, LiveVec); - // Do the actual rewriting and delete the old statepoint makeStatepointExplicitImpl(CS, BaseVec, LiveVec, Result, Replacements); } @@ -1634,7 +1497,7 @@ insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs, if (!Relocate) continue; - Value *OriginalValue = const_cast<Value *>(Relocate->getDerivedPtr()); + Value *OriginalValue = Relocate->getDerivedPtr(); assert(AllocaMap.count(OriginalValue)); Value *Alloca = AllocaMap[OriginalValue]; @@ -1660,11 +1523,10 @@ insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs, // Helper function for the "relocationViaAlloca". Similar to the // "insertRelocationStores" but works for rematerialized values. -static void -insertRematerializationStores( - RematerializedValueMapTy RematerializedValues, - DenseMap<Value *, Value *> &AllocaMap, - DenseSet<Value *> &VisitedLiveValues) { +static void insertRematerializationStores( + const RematerializedValueMapTy &RematerializedValues, + DenseMap<Value *, Value *> &AllocaMap, + DenseSet<Value *> &VisitedLiveValues) { for (auto RematerializedValuePair: RematerializedValues) { Instruction *RematerializedValue = RematerializedValuePair.first; @@ -1691,9 +1553,8 @@ static void relocationViaAlloca( // record initial number of (static) allocas; we'll check we have the same // number when we get done. int InitialAllocaNum = 0; - for (auto I = F.getEntryBlock().begin(), E = F.getEntryBlock().end(); I != E; - I++) - if (isa<AllocaInst>(*I)) + for (Instruction &I : F.getEntryBlock()) + if (isa<AllocaInst>(I)) InitialAllocaNum++; #endif @@ -1777,8 +1638,7 @@ static void relocationViaAlloca( auto InsertClobbersAt = [&](Instruction *IP) { for (auto *AI : ToClobber) { - auto AIType = cast<PointerType>(AI->getType()); - auto PT = cast<PointerType>(AIType->getElementType()); + auto PT = cast<PointerType>(AI->getAllocatedType()); Constant *CPN = ConstantPointerNull::get(PT); StoreInst *Store = new StoreInst(CPN, AI); Store->insertBefore(IP); @@ -1919,141 +1779,7 @@ static void findLiveReferences( computeLiveInValues(DT, F, OriginalLivenessData); for (size_t i = 0; i < records.size(); i++) { struct PartiallyConstructedSafepointRecord &info = records[i]; - const CallSite &CS = toUpdate[i]; - analyzeParsePointLiveness(DT, OriginalLivenessData, CS, info); - } -} - -/// Remove any vector of pointers from the live set by scalarizing them over the -/// statepoint instruction. Adds the scalarized pieces to the live set. It -/// would be preferable to include the vector in the statepoint itself, but -/// the lowering code currently does not handle that. Extending it would be -/// slightly non-trivial since it requires a format change. Given how rare -/// such cases are (for the moment?) scalarizing is an acceptable compromise. 
-static void splitVectorValues(Instruction *StatepointInst, - StatepointLiveSetTy &LiveSet, - DenseMap<Value *, Value *>& PointerToBase, - DominatorTree &DT) { - SmallVector<Value *, 16> ToSplit; - for (Value *V : LiveSet) - if (isa<VectorType>(V->getType())) - ToSplit.push_back(V); - - if (ToSplit.empty()) - return; - - DenseMap<Value *, SmallVector<Value *, 16>> ElementMapping; - - Function &F = *(StatepointInst->getParent()->getParent()); - - DenseMap<Value *, AllocaInst *> AllocaMap; - // First is normal return, second is exceptional return (invoke only) - DenseMap<Value *, std::pair<Value *, Value *>> Replacements; - for (Value *V : ToSplit) { - AllocaInst *Alloca = - new AllocaInst(V->getType(), "", F.getEntryBlock().getFirstNonPHI()); - AllocaMap[V] = Alloca; - - VectorType *VT = cast<VectorType>(V->getType()); - IRBuilder<> Builder(StatepointInst); - SmallVector<Value *, 16> Elements; - for (unsigned i = 0; i < VT->getNumElements(); i++) - Elements.push_back(Builder.CreateExtractElement(V, Builder.getInt32(i))); - ElementMapping[V] = Elements; - - auto InsertVectorReform = [&](Instruction *IP) { - Builder.SetInsertPoint(IP); - Builder.SetCurrentDebugLocation(IP->getDebugLoc()); - Value *ResultVec = UndefValue::get(VT); - for (unsigned i = 0; i < VT->getNumElements(); i++) - ResultVec = Builder.CreateInsertElement(ResultVec, Elements[i], - Builder.getInt32(i)); - return ResultVec; - }; - - if (isa<CallInst>(StatepointInst)) { - BasicBlock::iterator Next(StatepointInst); - Next++; - Instruction *IP = &*(Next); - Replacements[V].first = InsertVectorReform(IP); - Replacements[V].second = nullptr; - } else { - InvokeInst *Invoke = cast<InvokeInst>(StatepointInst); - // We've already normalized - check that we don't have shared destination - // blocks - BasicBlock *NormalDest = Invoke->getNormalDest(); - assert(!isa<PHINode>(NormalDest->begin())); - BasicBlock *UnwindDest = Invoke->getUnwindDest(); - assert(!isa<PHINode>(UnwindDest->begin())); - // Insert insert element sequences in both successors - Instruction *IP = &*(NormalDest->getFirstInsertionPt()); - Replacements[V].first = InsertVectorReform(IP); - IP = &*(UnwindDest->getFirstInsertionPt()); - Replacements[V].second = InsertVectorReform(IP); - } - } - - for (Value *V : ToSplit) { - AllocaInst *Alloca = AllocaMap[V]; - - // Capture all users before we start mutating use lists - SmallVector<Instruction *, 16> Users; - for (User *U : V->users()) - Users.push_back(cast<Instruction>(U)); - - for (Instruction *I : Users) { - if (auto Phi = dyn_cast<PHINode>(I)) { - for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++) - if (V == Phi->getIncomingValue(i)) { - LoadInst *Load = new LoadInst( - Alloca, "", Phi->getIncomingBlock(i)->getTerminator()); - Phi->setIncomingValue(i, Load); - } - } else { - LoadInst *Load = new LoadInst(Alloca, "", I); - I->replaceUsesOfWith(V, Load); - } - } - - // Store the original value and the replacement value into the alloca - StoreInst *Store = new StoreInst(V, Alloca); - if (auto I = dyn_cast<Instruction>(V)) - Store->insertAfter(I); - else - Store->insertAfter(Alloca); - - // Normal return for invoke, or call return - Instruction *Replacement = cast<Instruction>(Replacements[V].first); - (new StoreInst(Replacement, Alloca))->insertAfter(Replacement); - // Unwind return for invoke only - Replacement = cast_or_null<Instruction>(Replacements[V].second); - if (Replacement) - (new StoreInst(Replacement, Alloca))->insertAfter(Replacement); - } - - // apply mem2reg to promote alloca to SSA - 
SmallVector<AllocaInst *, 16> Allocas; - for (Value *V : ToSplit) - Allocas.push_back(AllocaMap[V]); - PromoteMemToReg(Allocas, DT); - - // Update our tracking of live pointers and base mappings to account for the - // changes we just made. - for (Value *V : ToSplit) { - auto &Elements = ElementMapping[V]; - - LiveSet.erase(V); - LiveSet.insert(Elements.begin(), Elements.end()); - // We need to update the base mapping as well. - assert(PointerToBase.count(V)); - Value *OldBase = PointerToBase[V]; - auto &BaseElements = ElementMapping[OldBase]; - PointerToBase.erase(V); - assert(Elements.size() == BaseElements.size()); - for (unsigned i = 0; i < Elements.size(); i++) { - Value *Elem = Elements[i]; - PointerToBase[Elem] = BaseElements[i]; - } + analyzeParsePointLiveness(DT, OriginalLivenessData, toUpdate[i], info); } } @@ -2109,7 +1835,7 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain, } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) { // Cost of the address calculation - Type *ValTy = GEP->getPointerOperandType()->getPointerElementType(); + Type *ValTy = GEP->getSourceElementType(); Cost += TTI.getAddressComputationCost(ValTy); // And cost of the GEP itself @@ -2244,7 +1970,7 @@ static void rematerializeLiveValues(CallSite CS, // Remove rematerializaed values from the live set for (auto LiveValue: LiveValuesToBeDeleted) { - Info.LiveSet.erase(LiveValue); + Info.LiveSet.remove(LiveValue); } } @@ -2257,11 +1983,8 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Uniqued.insert(ToUpdate.begin(), ToUpdate.end()); assert(Uniqued.size() == ToUpdate.size() && "no duplicates please!"); - for (CallSite CS : ToUpdate) { - assert(CS.getInstruction()->getParent()->getParent() == &F); - assert((UseDeoptBundles || isStatepoint(CS)) && - "expected to already be a deopt statepoint"); - } + for (CallSite CS : ToUpdate) + assert(CS.getInstruction()->getFunction() == &F); #endif // When inserting gc.relocates for invokes, we need to be able to insert at @@ -2287,12 +2010,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, for (CallSite CS : ToUpdate) { SmallVector<Value *, 64> DeoptValues; - iterator_range<const Use *> DeoptStateRange = - UseDeoptBundles - ? iterator_range<const Use *>(GetDeoptBundleOperands(CS)) - : iterator_range<const Use *>(Statepoint(CS).vm_state_args()); - - for (Value *Arg : DeoptStateRange) { + for (Value *Arg : GetDeoptBundleOperands(CS)) { assert(!isUnhandledGCPointerType(Arg->getType()) && "support for FCA unimplemented"); if (isHandledGCPointerType(Arg->getType())) @@ -2374,29 +2092,13 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, for (auto &Info : Records) for (auto &BasePair : Info.PointerToBase) if (isa<Constant>(BasePair.second)) - Info.LiveSet.erase(BasePair.first); + Info.LiveSet.remove(BasePair.first); for (CallInst *CI : Holders) CI->eraseFromParent(); Holders.clear(); - // Do a limited scalarization of any live at safepoint vector values which - // contain pointers. This enables this pass to run after vectorization at - // the cost of some possible performance loss. Note: This is known to not - // handle updating of the side tables correctly which can lead to relocation - // bugs when the same vector is live at multiple statepoints. We're in the - // process of implementing the alternate lowering - relocating the - // vector-of-pointers as first class item and updating the backend to - // understand that - but that's not yet complete. 
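The chainToBasePointerCost and rematerializeLiveValues hunks above implement the alternative to relocating a derived pointer across a safepoint: re-derive it from its (already relocated) base afterwards, provided the defining chain is cheap enough. A rough standalone sketch of that cost-threshold decision, where Instr, Cost and the threshold value are made-up stand-ins for the instruction chain, the TargetTransformInfo queries and the pass's rematerialization threshold, might look like:

    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for one instruction in a base-to-derived chain.
    struct Instr {
      std::string Name;
      unsigned Cost; // stand-in for a TargetTransformInfo cost query
    };

    // Sum the cost of re-deriving a pointer from its base and decide whether
    // rematerializing after the safepoint beats keeping it live across it.
    bool shouldRematerialize(const std::vector<Instr> &ChainToBase,
                             unsigned RematThreshold) {
      unsigned Cost = 0;
      for (const Instr &I : ChainToBase)
        Cost += I.Cost;
      return Cost <= RematThreshold;
    }

    int main() {
      std::vector<Instr> Chain = {{"bitcast", 0}, {"gep", 1}};
      std::cout << (shouldRematerialize(Chain, 6) ? "rematerialize" : "relocate")
                << "\n";
    }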
- if (UseVectorSplit) - for (size_t i = 0; i < Records.size(); i++) { - PartiallyConstructedSafepointRecord &Info = Records[i]; - Instruction *Statepoint = ToUpdate[i].getInstruction(); - splitVectorValues(cast<Instruction>(Statepoint), Info.LiveSet, - Info.PointerToBase, DT); - } - // In order to reduce live set of statepoint we might choose to rematerialize // some values instead of relocating them. This is purely an optimization and // does not influence correctness. @@ -2592,13 +2294,9 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F) { getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); auto NeedsRewrite = [](Instruction &I) { - if (UseDeoptBundles) { - if (ImmutableCallSite CS = ImmutableCallSite(&I)) - return !callsGCLeafFunction(CS); - return false; - } - - return isStatepoint(I); + if (ImmutableCallSite CS = ImmutableCallSite(&I)) + return !callsGCLeafFunction(CS) && !isStatepoint(CS); + return false; }; // Gather all the statepoints which need rewritten. Be careful to only @@ -2682,15 +2380,12 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F) { /// Compute the live-in set for the location rbegin starting from /// the live-out set of the basic block -static void computeLiveInValues(BasicBlock::reverse_iterator rbegin, - BasicBlock::reverse_iterator rend, - DenseSet<Value *> &LiveTmp) { - - for (BasicBlock::reverse_iterator ritr = rbegin; ritr != rend; ritr++) { - Instruction *I = &*ritr; - +static void computeLiveInValues(BasicBlock::reverse_iterator Begin, + BasicBlock::reverse_iterator End, + SetVector<Value *> &LiveTmp) { + for (auto &I : make_range(Begin, End)) { // KILL/Def - Remove this definition from LiveIn - LiveTmp.erase(I); + LiveTmp.remove(&I); // Don't consider *uses* in PHI nodes, we handle their contribution to // predecessor blocks when we seed the LiveOut sets @@ -2698,7 +2393,7 @@ static void computeLiveInValues(BasicBlock::reverse_iterator rbegin, continue; // USE - Add to the LiveIn set for this instruction - for (Value *V : I->operands()) { + for (Value *V : I.operands()) { assert(!isUnhandledGCPointerType(V->getType()) && "support for FCA unimplemented"); if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) { @@ -2718,24 +2413,24 @@ static void computeLiveInValues(BasicBlock::reverse_iterator rbegin, } } -static void computeLiveOutSeed(BasicBlock *BB, DenseSet<Value *> &LiveTmp) { - +static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp) { for (BasicBlock *Succ : successors(BB)) { - const BasicBlock::iterator E(Succ->getFirstNonPHI()); - for (BasicBlock::iterator I = Succ->begin(); I != E; I++) { - PHINode *Phi = cast<PHINode>(&*I); - Value *V = Phi->getIncomingValueForBlock(BB); + for (auto &I : *Succ) { + PHINode *PN = dyn_cast<PHINode>(&I); + if (!PN) + break; + + Value *V = PN->getIncomingValueForBlock(BB); assert(!isUnhandledGCPointerType(V->getType()) && "support for FCA unimplemented"); - if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) { + if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) LiveTmp.insert(V); - } } } } -static DenseSet<Value *> computeKillSet(BasicBlock *BB) { - DenseSet<Value *> KillSet; +static SetVector<Value *> computeKillSet(BasicBlock *BB) { + SetVector<Value *> KillSet; for (Instruction &I : *BB) if (isHandledGCPointerType(I.getType())) KillSet.insert(&I); @@ -2745,7 +2440,7 @@ static DenseSet<Value *> computeKillSet(BasicBlock *BB) { #ifndef NDEBUG /// Check that the items in 'Live' dominate 'TI'. 
This is used as a basic /// sanity check for the liveness computation. -static void checkBasicSSA(DominatorTree &DT, DenseSet<Value *> &Live, +static void checkBasicSSA(DominatorTree &DT, SetVector<Value *> &Live, TerminatorInst *TI, bool TermOkay = false) { for (Value *V : Live) { if (auto *I = dyn_cast<Instruction>(V)) { @@ -2773,17 +2468,7 @@ static void checkBasicSSA(DominatorTree &DT, GCPtrLivenessData &Data, static void computeLiveInValues(DominatorTree &DT, Function &F, GCPtrLivenessData &Data) { - - SmallSetVector<BasicBlock *, 200> Worklist; - auto AddPredsToWorklist = [&](BasicBlock *BB) { - // We use a SetVector so that we don't have duplicates in the worklist. - Worklist.insert(pred_begin(BB), pred_end(BB)); - }; - auto NextItem = [&]() { - BasicBlock *BB = Worklist.back(); - Worklist.pop_back(); - return BB; - }; + SmallSetVector<BasicBlock *, 32> Worklist; // Seed the liveness for each individual block for (BasicBlock &BB : F) { @@ -2796,56 +2481,55 @@ static void computeLiveInValues(DominatorTree &DT, Function &F, assert(!Data.LiveSet[&BB].count(Kill) && "live set contains kill"); #endif - Data.LiveOut[&BB] = DenseSet<Value *>(); + Data.LiveOut[&BB] = SetVector<Value *>(); computeLiveOutSeed(&BB, Data.LiveOut[&BB]); Data.LiveIn[&BB] = Data.LiveSet[&BB]; - set_union(Data.LiveIn[&BB], Data.LiveOut[&BB]); - set_subtract(Data.LiveIn[&BB], Data.KillSet[&BB]); + Data.LiveIn[&BB].set_union(Data.LiveOut[&BB]); + Data.LiveIn[&BB].set_subtract(Data.KillSet[&BB]); if (!Data.LiveIn[&BB].empty()) - AddPredsToWorklist(&BB); + Worklist.insert(pred_begin(&BB), pred_end(&BB)); } // Propagate that liveness until stable while (!Worklist.empty()) { - BasicBlock *BB = NextItem(); + BasicBlock *BB = Worklist.pop_back_val(); - // Compute our new liveout set, then exit early if it hasn't changed - // despite the contribution of our successor. - DenseSet<Value *> LiveOut = Data.LiveOut[BB]; + // Compute our new liveout set, then exit early if it hasn't changed despite + // the contribution of our successor. + SetVector<Value *> LiveOut = Data.LiveOut[BB]; const auto OldLiveOutSize = LiveOut.size(); for (BasicBlock *Succ : successors(BB)) { assert(Data.LiveIn.count(Succ)); - set_union(LiveOut, Data.LiveIn[Succ]); + LiveOut.set_union(Data.LiveIn[Succ]); } // assert OutLiveOut is a subset of LiveOut if (OldLiveOutSize == LiveOut.size()) { // If the sets are the same size, then we didn't actually add anything - // when unioning our successors LiveIn Thus, the LiveIn of this block + // when unioning our successors LiveIn. Thus, the LiveIn of this block // hasn't changed. continue; } Data.LiveOut[BB] = LiveOut; // Apply the effects of this basic block - DenseSet<Value *> LiveTmp = LiveOut; - set_union(LiveTmp, Data.LiveSet[BB]); - set_subtract(LiveTmp, Data.KillSet[BB]); + SetVector<Value *> LiveTmp = LiveOut; + LiveTmp.set_union(Data.LiveSet[BB]); + LiveTmp.set_subtract(Data.KillSet[BB]); assert(Data.LiveIn.count(BB)); - const DenseSet<Value *> &OldLiveIn = Data.LiveIn[BB]; + const SetVector<Value *> &OldLiveIn = Data.LiveIn[BB]; // assert: OldLiveIn is a subset of LiveTmp if (OldLiveIn.size() != LiveTmp.size()) { Data.LiveIn[BB] = LiveTmp; - AddPredsToWorklist(BB); + Worklist.insert(pred_begin(BB), pred_end(BB)); } - } // while( !worklist.empty() ) + } // while (!Worklist.empty()) #ifndef NDEBUG // Sanity check our output against SSA properties. This helps catch any // missing kills during the above iteration. 
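The computeLiveInValues hunks above replace DenseSet with SetVector (for deterministic iteration order) but keep the standard backward liveness fixpoint: LiveOut(BB) is the union of the successors' LiveIn sets, LiveIn(BB) adds the block's upward-exposed uses and removes its kills, and predecessors are revisited via a worklist until nothing changes. A minimal self-contained sketch of the same scheme, using integer block ids and std::set in place of the LLVM containers, could look like:

    #include <algorithm>
    #include <deque>
    #include <iterator>
    #include <map>
    #include <set>
    #include <vector>

    using ValueSet = std::set<int>; // stand-in for SetVector<Value *>

    struct Block {
      std::vector<int> Succs, Preds;
      ValueSet Uses; // upward-exposed uses (the pass's LiveSet)
      ValueSet Defs; // definitions in the block (the pass's KillSet)
    };

    static ValueSet unionOf(const ValueSet &A, const ValueSet &B) {
      ValueSet R = A;
      R.insert(B.begin(), B.end());
      return R;
    }

    static ValueSet minus(const ValueSet &A, const ValueSet &B) {
      ValueSet R;
      std::set_difference(A.begin(), A.end(), B.begin(), B.end(),
                          std::inserter(R, R.end()));
      return R;
    }

    // LiveIn(BB) = Uses(BB) u (LiveOut(BB) \ Defs(BB)),
    // LiveOut(BB) = union of LiveIn over successors, iterated to a fixpoint.
    // (The LLVM code computes (LiveSet u LiveOut) \ KillSet, which is the same
    // thing in SSA form, where a value is never both upward-exposed-used and
    // defined in the same block.)
    void computeLiveness(const std::vector<Block> &CFG,
                         std::map<int, ValueSet> &LiveIn,
                         std::map<int, ValueSet> &LiveOut) {
      std::deque<int> Worklist;
      for (int BB = 0, E = (int)CFG.size(); BB != E; ++BB) {
        LiveIn[BB] = CFG[BB].Uses;
        LiveOut[BB] = {};
        if (!LiveIn[BB].empty())
          for (int P : CFG[BB].Preds)
            Worklist.push_back(P);
      }
      while (!Worklist.empty()) {
        int BB = Worklist.back();
        Worklist.pop_back();

        // Recompute LiveOut from the successors; bail if nothing new arrived.
        ValueSet Out;
        for (int S : CFG[BB].Succs)
          Out = unionOf(Out, LiveIn[S]);
        if (Out == LiveOut[BB])
          continue;
        LiveOut[BB] = Out;

        // Apply the block's effect; if LiveIn grew, revisit the predecessors.
        ValueSet In = unionOf(CFG[BB].Uses, minus(Out, CFG[BB].Defs));
        if (In != LiveIn[BB]) {
          LiveIn[BB] = In;
          for (int P : CFG[BB].Preds)
            Worklist.push_back(P);
        }
      }
    }

    int main() {
      // Two blocks: B0 defines value 1 and falls through to B1, which uses it.
      std::vector<Block> CFG(2);
      CFG[0].Succs = {1};
      CFG[1].Preds = {0};
      CFG[0].Defs = {1};
      CFG[1].Uses = {1};
      std::map<int, ValueSet> In, Out;
      computeLiveness(CFG, In, Out);
      return Out[0].count(1) && In[1].count(1) ? 0 : 1;
    }

LLVM's version drives the iteration with a SmallSetVector worklist to avoid duplicate entries; the deque above tolerates duplicates at the cost of some redundant re-processing.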
- for (BasicBlock &BB : F) { + for (BasicBlock &BB : F) checkBasicSSA(DT, Data, BB); - } #endif } @@ -2856,7 +2540,7 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data, // Note: The copy is intentional and required assert(Data.LiveOut.count(BB)); - DenseSet<Value *> LiveOut = Data.LiveOut[BB]; + SetVector<Value *> LiveOut = Data.LiveOut[BB]; // We want to handle the statepoint itself oddly. It's // call result is not live (normal), nor are it's arguments @@ -2864,12 +2548,12 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data, // specifically what we need to relocate BasicBlock::reverse_iterator rend(Inst->getIterator()); computeLiveInValues(BB->rbegin(), rend, LiveOut); - LiveOut.erase(Inst); + LiveOut.remove(Inst); Out.insert(LiveOut.begin(), LiveOut.end()); } static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData, - const CallSite &CS, + CallSite CS, PartiallyConstructedSafepointRecord &Info) { Instruction *Inst = CS.getInstruction(); StatepointLiveSetTy Updated; @@ -2877,33 +2561,32 @@ static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData, #ifndef NDEBUG DenseSet<Value *> Bases; - for (auto KVPair : Info.PointerToBase) { + for (auto KVPair : Info.PointerToBase) Bases.insert(KVPair.second); - } #endif + // We may have base pointers which are now live that weren't before. We need // to update the PointerToBase structure to reflect this. for (auto V : Updated) - if (!Info.PointerToBase.count(V)) { - assert(Bases.count(V) && "can't find base for unexpected live value"); - Info.PointerToBase[V] = V; + if (Info.PointerToBase.insert({V, V}).second) { + assert(Bases.count(V) && "Can't find base for unexpected live value!"); continue; } #ifndef NDEBUG - for (auto V : Updated) { + for (auto V : Updated) assert(Info.PointerToBase.count(V) && - "must be able to find base for live value"); - } + "Must be able to find base for live value!"); #endif // Remove any stale base mappings - this can happen since our liveness is - // more precise then the one inherent in the base pointer analysis + // more precise then the one inherent in the base pointer analysis. 
DenseSet<Value *> ToErase; for (auto KVPair : Info.PointerToBase) if (!Updated.count(KVPair.first)) ToErase.insert(KVPair.first); - for (auto V : ToErase) + + for (auto *V : ToErase) Info.PointerToBase.erase(V); #ifndef NDEBUG diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp index 8569e080873c..da700f18cdaf 100644 --- a/lib/Transforms/Scalar/SCCP.cpp +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -17,15 +17,15 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/IPO/SCCP.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" @@ -38,6 +38,8 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> using namespace llvm; @@ -57,8 +59,8 @@ namespace { /// class LatticeVal { enum LatticeValueTy { - /// undefined - This LLVM Value has no known value yet. - undefined, + /// unknown - This LLVM Value has no known value yet. + unknown, /// constant - This LLVM Value has a specific constant value. constant, @@ -83,9 +85,9 @@ class LatticeVal { } public: - LatticeVal() : Val(nullptr, undefined) {} + LatticeVal() : Val(nullptr, unknown) {} - bool isUndefined() const { return getLatticeValue() == undefined; } + bool isUnknown() const { return getLatticeValue() == unknown; } bool isConstant() const { return getLatticeValue() == constant || getLatticeValue() == forcedconstant; } @@ -112,7 +114,7 @@ public: return false; } - if (isUndefined()) { + if (isUnknown()) { Val.setInt(constant); assert(V && "Marking constant with NULL"); Val.setPointer(V); @@ -139,7 +141,7 @@ public: } void markForcedConstant(Constant *V) { - assert(isUndefined() && "Can't force a defined value!"); + assert(isUnknown() && "Can't force a defined value!"); Val.setInt(forcedconstant); Val.setPointer(V); } @@ -228,7 +230,7 @@ public: /// performing Interprocedural SCCP. void TrackValueOfGlobalVariable(GlobalVariable *GV) { // We only track the contents of scalar globals. 
- if (GV->getType()->getElementType()->isSingleValueType()) { + if (GV->getValueType()->isSingleValueType()) { LatticeVal &IV = TrackedGlobals[GV]; if (!isa<UndefValue>(GV->getInitializer())) IV.markConstant(GV->getInitializer()); @@ -268,6 +270,18 @@ public: return BBExecutable.count(BB); } + std::vector<LatticeVal> getStructLatticeValueFor(Value *V) const { + std::vector<LatticeVal> StructValues; + StructType *STy = dyn_cast<StructType>(V->getType()); + assert(STy && "getStructLatticeValueFor() can be called only on structs"); + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + auto I = StructValueState.find(std::make_pair(V, i)); + assert(I != StructValueState.end() && "Value not in valuemap!"); + StructValues.push_back(I->second); + } + return StructValues; + } + LatticeVal getLatticeValueFor(Value *V) const { DenseMap<Value*, LatticeVal>::const_iterator I = ValueState.find(V); assert(I != ValueState.end() && "V is not in valuemap!"); @@ -302,6 +316,13 @@ public: } private: + // pushToWorkList - Helper for markConstant/markForcedConstant + void pushToWorkList(LatticeVal &IV, Value *V) { + if (IV.isOverdefined()) + return OverdefinedInstWorkList.push_back(V); + InstWorkList.push_back(V); + } + // markConstant - Make a value be marked as "constant". If the value // is not already a constant, add it to the instruction work list so that // the users of the instruction are updated later. @@ -309,10 +330,7 @@ private: void markConstant(LatticeVal &IV, Value *V, Constant *C) { if (!IV.markConstant(C)) return; DEBUG(dbgs() << "markConstant: " << *C << ": " << *V << '\n'); - if (IV.isOverdefined()) - OverdefinedInstWorkList.push_back(V); - else - InstWorkList.push_back(V); + pushToWorkList(IV, V); } void markConstant(Value *V, Constant *C) { @@ -325,10 +343,7 @@ private: LatticeVal &IV = ValueState[V]; IV.markForcedConstant(C); DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n'); - if (IV.isOverdefined()) - OverdefinedInstWorkList.push_back(V); - else - InstWorkList.push_back(V); + pushToWorkList(IV, V); } @@ -348,14 +363,14 @@ private: } void mergeInValue(LatticeVal &IV, Value *V, LatticeVal MergeWithV) { - if (IV.isOverdefined() || MergeWithV.isUndefined()) + if (IV.isOverdefined() || MergeWithV.isUnknown()) return; // Noop. if (MergeWithV.isOverdefined()) - markOverdefined(IV, V); - else if (IV.isUndefined()) - markConstant(IV, V, MergeWithV.getConstant()); - else if (IV.getConstant() != MergeWithV.getConstant()) - markOverdefined(IV, V); + return markOverdefined(IV, V); + if (IV.isUnknown()) + return markConstant(IV, V, MergeWithV.getConstant()); + if (IV.getConstant() != MergeWithV.getConstant()) + return markOverdefined(IV, V); } void mergeInValue(Value *V, LatticeVal MergeWithV) { @@ -378,7 +393,7 @@ private: return LV; // Common case, already in the map. if (Constant *C = dyn_cast<Constant>(V)) { - // Undef values remain undefined. + // Undef values remain unknown. if (!isa<UndefValue>(V)) LV.markConstant(C); // Constants are constant } @@ -409,7 +424,7 @@ private: if (!Elt) LV.markOverdefined(); // Unknown sort of constant. else if (isa<UndefValue>(Elt)) - ; // Undef values remain undefined. + ; // Undef values remain unknown. else LV.markConstant(Elt); // Constants are constant. } @@ -537,7 +552,7 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, if (!CI) { // Overdefined condition variables, and branches on unfoldable constant // conditions, mean the branch could go either way. 
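The LatticeVal and mergeInValue hunks rename the bottom state from "undefined" to "unknown" and make the merge rules explicit: overdefined absorbs everything, unknown is the identity, and two disagreeing constants collapse to overdefined. A minimal standalone model of that lattice (omitting the forcedconstant state and using plain int constants) might be:

    #include <iostream>

    // Simplified model of SCCP's lattice: unknown < constant < overdefined.
    struct Lattice {
      enum State { Unknown, Constant, Overdefined };
      State St;
      int Val; // meaningful only when St == Constant

      Lattice() : St(Unknown), Val(0) {}
      static Lattice constant(int V) {
        Lattice L;
        L.St = Constant;
        L.Val = V;
        return L;
      }

      // Merge another lattice value into this one (the meet used by mergeInValue).
      void mergeIn(const Lattice &Other) {
        if (St == Overdefined || Other.St == Unknown)
          return; // no-op
        if (Other.St == Overdefined || (St == Constant && Val != Other.Val)) {
          St = Overdefined; // conflicting information
        } else {
          St = Constant;    // was unknown, or the constants already agreed
          Val = Other.Val;
        }
      }
    };

    int main() {
      Lattice X;                        // unknown
      X.mergeIn(Lattice::constant(7));  // -> constant 7
      X.mergeIn(Lattice::constant(7));  // still constant 7
      X.mergeIn(Lattice::constant(9));  // disagreement -> overdefined
      std::cout << X.St << "\n";        // prints 2 (Overdefined)
    }

visitPHINode in the surrounding diff is essentially this merge applied to the incoming values that arrive over feasible edges.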
- if (!BCValue.isUndefined()) + if (!BCValue.isUnknown()) Succs[0] = Succs[1] = true; return; } @@ -561,9 +576,9 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, LatticeVal SCValue = getValueState(SI->getCondition()); ConstantInt *CI = SCValue.getConstantInt(); - if (!CI) { // Overdefined or undefined condition? + if (!CI) { // Overdefined or unknown condition? // All destinations are executable! - if (!SCValue.isUndefined()) + if (!SCValue.isUnknown()) Succs.assign(TI.getNumSuccessors(), true); return; } @@ -607,7 +622,7 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { // undef conditions mean that neither edge is feasible yet. ConstantInt *CI = BCValue.getConstantInt(); if (!CI) - return !BCValue.isUndefined(); + return !BCValue.isUnknown(); // Constant condition variables mean the branch can only go a single way. return BI->getSuccessor(CI->isZero()) == To; @@ -625,7 +640,7 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { ConstantInt *CI = SCValue.getConstantInt(); if (!CI) - return !SCValue.isUndefined(); + return !SCValue.isUnknown(); return SI->findCaseValue(CI).getCaseSuccessor() == To; } @@ -677,12 +692,12 @@ void SCCPSolver::visitPHINode(PHINode &PN) { // are overdefined, the PHI becomes overdefined as well. If they are all // constant, and they agree with each other, the PHI becomes the identical // constant. If they are constant and don't agree, the PHI is overdefined. - // If there are no executable operands, the PHI remains undefined. + // If there are no executable operands, the PHI remains unknown. // Constant *OperandVal = nullptr; for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { LatticeVal IV = getValueState(PN.getIncomingValue(i)); - if (IV.isUndefined()) continue; // Doesn't influence PHI node. + if (IV.isUnknown()) continue; // Doesn't influence PHI node. if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent())) continue; @@ -708,7 +723,7 @@ void SCCPSolver::visitPHINode(PHINode &PN) { // If we exited the loop, this means that the PHI node only has constant // arguments that agree with each other(and OperandVal is the constant) or // OperandVal is null because there are no defined incoming arguments. If - // this is the case, the PHI remains undefined. + // this is the case, the PHI remains unknown. // if (OperandVal) markConstant(&PN, OperandVal); // Acquire operand value @@ -758,8 +773,9 @@ void SCCPSolver::visitCastInst(CastInst &I) { if (OpSt.isOverdefined()) // Inherit overdefinedness of operand markOverdefined(&I); else if (OpSt.isConstant()) { - Constant *C = - ConstantExpr::getCast(I.getOpcode(), OpSt.getConstant(), I.getType()); + // Fold the constant as we build. + Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpSt.getConstant(), + I.getType(), DL); if (isa<UndefValue>(C)) return; // Propagate constant value @@ -829,7 +845,7 @@ void SCCPSolver::visitSelectInst(SelectInst &I) { return markAnythingOverdefined(&I); LatticeVal CondValue = getValueState(I.getCondition()); - if (CondValue.isUndefined()) + if (CondValue.isUnknown()) return; if (ConstantInt *CondCB = CondValue.getConstantInt()) { @@ -849,9 +865,9 @@ void SCCPSolver::visitSelectInst(SelectInst &I) { TVal.getConstant() == FVal.getConstant()) return markConstant(&I, FVal.getConstant()); - if (TVal.isUndefined()) // select ?, undef, X -> X. + if (TVal.isUnknown()) // select ?, undef, X -> X. return mergeInValue(&I, FVal); - if (FVal.isUndefined()) // select ?, X, undef -> X. 
+ if (FVal.isUnknown()) // select ?, X, undef -> X. return mergeInValue(&I, TVal); markOverdefined(&I); } @@ -890,7 +906,7 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { NonOverdefVal = &V2State; if (NonOverdefVal) { - if (NonOverdefVal->isUndefined()) { + if (NonOverdefVal->isUnknown()) { // Could annihilate value. if (I.getOpcode() == Instruction::And) markConstant(IV, &I, Constant::getNullValue(I.getType())); @@ -934,7 +950,7 @@ void SCCPSolver::visitCmpInst(CmpInst &I) { return markConstant(IV, &I, C); } - // If operands are still undefined, wait for it to resolve. + // If operands are still unknown, wait for it to resolve. if (!V1State.isOverdefined() && !V2State.isOverdefined()) return; @@ -944,69 +960,16 @@ void SCCPSolver::visitCmpInst(CmpInst &I) { void SCCPSolver::visitExtractElementInst(ExtractElementInst &I) { // TODO : SCCP does not handle vectors properly. return markOverdefined(&I); - -#if 0 - LatticeVal &ValState = getValueState(I.getOperand(0)); - LatticeVal &IdxState = getValueState(I.getOperand(1)); - - if (ValState.isOverdefined() || IdxState.isOverdefined()) - markOverdefined(&I); - else if(ValState.isConstant() && IdxState.isConstant()) - markConstant(&I, ConstantExpr::getExtractElement(ValState.getConstant(), - IdxState.getConstant())); -#endif } void SCCPSolver::visitInsertElementInst(InsertElementInst &I) { // TODO : SCCP does not handle vectors properly. return markOverdefined(&I); -#if 0 - LatticeVal &ValState = getValueState(I.getOperand(0)); - LatticeVal &EltState = getValueState(I.getOperand(1)); - LatticeVal &IdxState = getValueState(I.getOperand(2)); - - if (ValState.isOverdefined() || EltState.isOverdefined() || - IdxState.isOverdefined()) - markOverdefined(&I); - else if(ValState.isConstant() && EltState.isConstant() && - IdxState.isConstant()) - markConstant(&I, ConstantExpr::getInsertElement(ValState.getConstant(), - EltState.getConstant(), - IdxState.getConstant())); - else if (ValState.isUndefined() && EltState.isConstant() && - IdxState.isConstant()) - markConstant(&I,ConstantExpr::getInsertElement(UndefValue::get(I.getType()), - EltState.getConstant(), - IdxState.getConstant())); -#endif } void SCCPSolver::visitShuffleVectorInst(ShuffleVectorInst &I) { // TODO : SCCP does not handle vectors properly. return markOverdefined(&I); -#if 0 - LatticeVal &V1State = getValueState(I.getOperand(0)); - LatticeVal &V2State = getValueState(I.getOperand(1)); - LatticeVal &MaskState = getValueState(I.getOperand(2)); - - if (MaskState.isUndefined() || - (V1State.isUndefined() && V2State.isUndefined())) - return; // Undefined output if mask or both inputs undefined. - - if (V1State.isOverdefined() || V2State.isOverdefined() || - MaskState.isOverdefined()) { - markOverdefined(&I); - } else { - // A mix of constant/undef inputs. - Constant *V1 = V1State.isConstant() ? - V1State.getConstant() : UndefValue::get(I.getType()); - Constant *V2 = V2State.isConstant() ? - V2State.getConstant() : UndefValue::get(I.getType()); - Constant *Mask = MaskState.isConstant() ? - MaskState.getConstant() : UndefValue::get(I.getOperand(2)->getType()); - markConstant(&I, ConstantExpr::getShuffleVector(V1, V2, Mask)); - } -#endif } // Handle getelementptr instructions. 
If all operands are constants then we @@ -1020,7 +983,7 @@ void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { LatticeVal State = getValueState(I.getOperand(i)); - if (State.isUndefined()) + if (State.isUnknown()) return; // Operands are not resolved yet. if (State.isOverdefined()) @@ -1066,7 +1029,7 @@ void SCCPSolver::visitLoadInst(LoadInst &I) { return markAnythingOverdefined(&I); LatticeVal PtrVal = getValueState(I.getOperand(0)); - if (PtrVal.isUndefined()) return; // The pointer is not resolved yet! + if (PtrVal.isUnknown()) return; // The pointer is not resolved yet! LatticeVal &IV = ValueState[&I]; if (IV.isOverdefined()) return; @@ -1094,7 +1057,7 @@ void SCCPSolver::visitLoadInst(LoadInst &I) { } // Transform load from a constant into a constant if possible. - if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, DL)) { + if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) { if (isa<UndefValue>(C)) return; return markConstant(IV, &I, C); @@ -1127,7 +1090,7 @@ CallOverdefined: AI != E; ++AI) { LatticeVal State = getValueState(*AI); - if (State.isUndefined()) + if (State.isUnknown()) return; // Operands are not resolved yet. if (State.isOverdefined()) return markOverdefined(I); @@ -1275,11 +1238,11 @@ void SCCPSolver::Solve() { /// conservatively, as "(zext i8 X -> i32) & 0xFF00" must always return zero, /// even if X isn't defined. bool SCCPSolver::ResolvedUndefsIn(Function &F) { - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - if (!BBExecutable.count(&*BB)) + for (BasicBlock &BB : F) { + if (!BBExecutable.count(&BB)) continue; - for (Instruction &I : *BB) { + for (Instruction &I : BB) { // Look for instructions which produce undef values. if (I.getType()->isVoidTy()) continue; @@ -1301,14 +1264,14 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // more precise than this but it isn't worth bothering. for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { LatticeVal &LV = getStructValueState(&I, i); - if (LV.isUndefined()) + if (LV.isUnknown()) markOverdefined(LV, &I); } continue; } LatticeVal &LV = getValueState(&I); - if (!LV.isUndefined()) continue; + if (!LV.isUnknown()) continue; // extractvalue is safe; check here because the argument is a struct. if (isa<ExtractValueInst>(I)) @@ -1347,7 +1310,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { case Instruction::FDiv: case Instruction::FRem: // Floating-point binary operation: be conservative. - if (Op0LV.isUndefined() && Op1LV.isUndefined()) + if (Op0LV.isUnknown() && Op1LV.isUnknown()) markForcedConstant(&I, Constant::getNullValue(ITy)); else markOverdefined(&I); @@ -1367,7 +1330,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { case Instruction::Mul: case Instruction::And: // Both operands undef -> undef - if (Op0LV.isUndefined() && Op1LV.isUndefined()) + if (Op0LV.isUnknown() && Op1LV.isUnknown()) break; // undef * X -> 0. X could be zero. // undef & X -> 0. X could be zero. @@ -1376,7 +1339,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { case Instruction::Or: // Both operands undef -> undef - if (Op0LV.isUndefined() && Op1LV.isUndefined()) + if (Op0LV.isUnknown() && Op1LV.isUnknown()) break; // undef | X -> -1. X could be -1. 
markForcedConstant(&I, Constant::getAllOnesValue(ITy)); @@ -1386,7 +1349,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // undef ^ undef -> 0; strictly speaking, this is not strictly // necessary, but we try to be nice to people who expect this // behavior in simple cases - if (Op0LV.isUndefined() && Op1LV.isUndefined()) { + if (Op0LV.isUnknown() && Op1LV.isUnknown()) { markForcedConstant(&I, Constant::getNullValue(ITy)); return true; } @@ -1399,7 +1362,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { case Instruction::URem: // X / undef -> undef. No change. // X % undef -> undef. No change. - if (Op1LV.isUndefined()) break; + if (Op1LV.isUnknown()) break; // X / 0 -> undef. No change. // X % 0 -> undef. No change. @@ -1413,7 +1376,15 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { case Instruction::AShr: // X >>a undef -> undef. - if (Op1LV.isUndefined()) break; + if (Op1LV.isUnknown()) break; + + // Shifting by the bitwidth or more is undefined. + if (Op1LV.isConstant()) { + if (auto *ShiftAmt = Op1LV.getConstantInt()) + if (ShiftAmt->getLimitedValue() >= + ShiftAmt->getType()->getScalarSizeInBits()) + break; + } // undef >>a X -> all ones markForcedConstant(&I, Constant::getAllOnesValue(ITy)); @@ -1422,7 +1393,15 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { case Instruction::Shl: // X << undef -> undef. // X >> undef -> undef. - if (Op1LV.isUndefined()) break; + if (Op1LV.isUnknown()) break; + + // Shifting by the bitwidth or more is undefined. + if (Op1LV.isConstant()) { + if (auto *ShiftAmt = Op1LV.getConstantInt()) + if (ShiftAmt->getLimitedValue() >= + ShiftAmt->getType()->getScalarSizeInBits()) + break; + } // undef << X -> 0 // undef >> X -> 0 @@ -1431,13 +1410,13 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { case Instruction::Select: Op1LV = getValueState(I.getOperand(1)); // undef ? X : Y -> X or Y. There could be commonality between X/Y. - if (Op0LV.isUndefined()) { + if (Op0LV.isUnknown()) { if (!Op1LV.isConstant()) // Pick the constant one if there is any. Op1LV = getValueState(I.getOperand(2)); - } else if (Op1LV.isUndefined()) { + } else if (Op1LV.isUnknown()) { // c ? undef : undef -> undef. No change. Op1LV = getValueState(I.getOperand(2)); - if (Op1LV.isUndefined()) + if (Op1LV.isUnknown()) break; // Otherwise, c ? undef : x -> x. } else { @@ -1487,17 +1466,17 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // Check to see if we have a branch or switch on an undefined value. If so // we force the branch to go one way or the other to make the successor // values live. It doesn't really matter which way we force it. - TerminatorInst *TI = BB->getTerminator(); + TerminatorInst *TI = BB.getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { if (!BI->isConditional()) continue; - if (!getValueState(BI->getCondition()).isUndefined()) + if (!getValueState(BI->getCondition()).isUnknown()) continue; // If the input to SCCP is actually branch on undef, fix the undef to // false. 
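The ResolvedUndefsIn hunks add a guard for shifts: "ashr undef, X" is normally forced to all ones (and shl/lshr of undef to zero), but when the shift amount is already known to be a constant greater than or equal to the bit width, the shift itself is undefined, so nothing is forced. A small sketch of that decision for the ashr case, with plain integers standing in for LatticeVal and the IR type, and with the still-unknown-operand case left out (the real code just breaks and revisits the instruction later), might be:

    #include <cstdint>
    #include <iostream>

    enum class Forced { AllOnes, None };

    // Decide what ResolvedUndefsIn forces for "ashr undef, ShiftAmt" on an
    // integer of BitWidth bits: if the shift amount is a known constant that
    // is >= the bit width, the result is undefined and nothing is forced;
    // otherwise "undef >>a X" is forced to all ones (X could be 0).
    Forced resolveAShrOfUndef(bool AmtIsConstant, uint64_t Amt, unsigned BitWidth) {
      if (AmtIsConstant && Amt >= BitWidth)
        return Forced::None; // shifting by the bit width or more is undefined
      return Forced::AllOnes;
    }

    int main() {
      std::cout << (resolveAShrOfUndef(true, 3, 32) == Forced::AllOnes)   // 1
                << (resolveAShrOfUndef(true, 40, 32) == Forced::None)     // 1
                << (resolveAShrOfUndef(false, 0, 32) == Forced::AllOnes)  // 1
                << "\n";
    }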
if (isa<UndefValue>(BI->getCondition())) { BI->setCondition(ConstantInt::getFalse(BI->getContext())); - markEdgeExecutable(&*BB, TI->getSuccessor(1)); + markEdgeExecutable(&BB, TI->getSuccessor(1)); return true; } @@ -1510,16 +1489,14 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { } if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { - if (!SI->getNumCases()) - continue; - if (!getValueState(SI->getCondition()).isUndefined()) + if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown()) continue; // If the input to SCCP is actually switch on undef, fix the undef to // the first constant. if (isa<UndefValue>(SI->getCondition())) { SI->setCondition(SI->case_begin().getCaseValue()); - markEdgeExecutable(&*BB, SI->case_begin().getCaseSuccessor()); + markEdgeExecutable(&BB, SI->case_begin().getCaseSuccessor()); return true; } @@ -1531,75 +1508,53 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { return false; } - -namespace { - //===--------------------------------------------------------------------===// - // - /// SCCP Class - This class uses the SCCPSolver to implement a per-function - /// Sparse Conditional Constant Propagator. - /// - struct SCCP : public FunctionPass { - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - } - static char ID; // Pass identification, replacement for typeid - SCCP() : FunctionPass(ID) { - initializeSCCPPass(*PassRegistry::getPassRegistry()); +static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { + Constant *Const = nullptr; + if (V->getType()->isStructTy()) { + std::vector<LatticeVal> IVs = Solver.getStructLatticeValueFor(V); + if (std::any_of(IVs.begin(), IVs.end(), + [](LatticeVal &LV) { return LV.isOverdefined(); })) + return false; + std::vector<Constant *> ConstVals; + StructType *ST = dyn_cast<StructType>(V->getType()); + for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { + LatticeVal V = IVs[i]; + ConstVals.push_back(V.isConstant() + ? V.getConstant() + : UndefValue::get(ST->getElementType(i))); } + Const = ConstantStruct::get(ST, ConstVals); + } else { + LatticeVal IV = Solver.getLatticeValueFor(V); + if (IV.isOverdefined()) + return false; + Const = IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); + } + assert(Const && "Constant is nullptr here!"); + DEBUG(dbgs() << " Constant: " << *Const << " = " << *V << '\n'); - // runOnFunction - Run the Sparse Conditional Constant Propagation - // algorithm, and return true if the function was modified. - // - bool runOnFunction(Function &F) override; - }; -} // end anonymous namespace - -char SCCP::ID = 0; -INITIALIZE_PASS(SCCP, "sccp", - "Sparse Conditional Constant Propagation", false, false) - -// createSCCPPass - This is the public interface to this file. -FunctionPass *llvm::createSCCPPass() { - return new SCCP(); + // Replaces all of the uses of a variable with uses of the constant. + V->replaceAllUsesWith(Const); + return true; } -static void DeleteInstructionInBlock(BasicBlock *BB) { - DEBUG(dbgs() << " BasicBlock Dead:" << *BB); - ++NumDeadBlocks; - - // Check to see if there are non-terminating instructions to delete. 
- if (isa<TerminatorInst>(BB->begin())) - return; +static bool tryToReplaceInstWithConstant(SCCPSolver &Solver, Instruction *Inst, + bool shouldEraseFromParent) { + if (!tryToReplaceWithConstant(Solver, Inst)) + return false; - // Delete the instructions backwards, as it has a reduced likelihood of having - // to update as many def-use and use-def chains. - Instruction *EndInst = BB->getTerminator(); // Last not to be deleted. - while (EndInst != BB->begin()) { - // Delete the next to last instruction. - Instruction *Inst = &*--EndInst->getIterator(); - if (!Inst->use_empty()) - Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); - if (Inst->isEHPad()) { - EndInst = Inst; - continue; - } - BB->getInstList().erase(Inst); - ++NumInstRemoved; - } + // Delete the instruction. + if (shouldEraseFromParent) + Inst->eraseFromParent(); + return true; } -// runOnFunction() - Run the Sparse Conditional Constant Propagation algorithm, +// runSCCP() - Run the Sparse Conditional Constant Propagation algorithm, // and return true if the function was modified. // -bool SCCP::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) - return false; - +static bool runSCCP(Function &F, const DataLayout &DL, + const TargetLibraryInfo *TLI) { DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n"); - const DataLayout &DL = F.getParent()->getDataLayout(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); SCCPSolver Solver(DL, TLI); // Mark the first block of the function as being executable. @@ -1623,9 +1578,13 @@ bool SCCP::runOnFunction(Function &F) { // delete their contents now. Note that we cannot actually delete the blocks, // as we cannot modify the CFG of the function. - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - if (!Solver.isBlockExecutable(&*BB)) { - DeleteInstructionInBlock(&*BB); + for (BasicBlock &BB : F) { + if (!Solver.isBlockExecutable(&BB)) { + DEBUG(dbgs() << " BasicBlock Dead:" << BB); + + ++NumDeadBlocks; + NumInstRemoved += removeAllNonTerminatorAndEHPadInstructions(&BB); + MadeChanges = true; continue; } @@ -1633,70 +1592,74 @@ bool SCCP::runOnFunction(Function &F) { // Iterate over all of the instructions in a function, replacing them with // constants if we have found them to be of constant values. // - for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { + for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { Instruction *Inst = &*BI++; if (Inst->getType()->isVoidTy() || isa<TerminatorInst>(Inst)) continue; - // TODO: Reconstruct structs from their elements. - if (Inst->getType()->isStructTy()) - continue; - - LatticeVal IV = Solver.getLatticeValueFor(Inst); - if (IV.isOverdefined()) - continue; - - Constant *Const = IV.isConstant() - ? IV.getConstant() : UndefValue::get(Inst->getType()); - DEBUG(dbgs() << " Constant: " << *Const << " = " << *Inst << '\n'); - - // Replaces all of the uses of a variable with uses of the constant. - Inst->replaceAllUsesWith(Const); - - // Delete the instruction. - Inst->eraseFromParent(); - - // Hey, we just changed something! - MadeChanges = true; - ++NumInstRemoved; + if (tryToReplaceInstWithConstant(Solver, Inst, + true /* shouldEraseFromParent */)) { + // Hey, we just changed something! 
+ MadeChanges = true; + ++NumInstRemoved; + } } } return MadeChanges; } +PreservedAnalyses SCCPPass::run(Function &F, AnalysisManager<Function> &AM) { + const DataLayout &DL = F.getParent()->getDataLayout(); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + if (!runSCCP(F, DL, &TLI)) + return PreservedAnalyses::all(); + + auto PA = PreservedAnalyses(); + PA.preserve<GlobalsAA>(); + return PA; +} + namespace { - //===--------------------------------------------------------------------===// +//===--------------------------------------------------------------------===// +// +/// SCCP Class - This class uses the SCCPSolver to implement a per-function +/// Sparse Conditional Constant Propagator. +/// +class SCCPLegacyPass : public FunctionPass { +public: + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + static char ID; // Pass identification, replacement for typeid + SCCPLegacyPass() : FunctionPass(ID) { + initializeSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + // runOnFunction - Run the Sparse Conditional Constant Propagation + // algorithm, and return true if the function was modified. // - /// IPSCCP Class - This class implements interprocedural Sparse Conditional - /// Constant Propagation. - /// - struct IPSCCP : public ModulePass { - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } - static char ID; - IPSCCP() : ModulePass(ID) { - initializeIPSCCPPass(*PassRegistry::getPassRegistry()); - } - bool runOnModule(Module &M) override; - }; + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + const DataLayout &DL = F.getParent()->getDataLayout(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return runSCCP(F, DL, TLI); + } +}; } // end anonymous namespace -char IPSCCP::ID = 0; -INITIALIZE_PASS_BEGIN(IPSCCP, "ipsccp", - "Interprocedural Sparse Conditional Constant Propagation", - false, false) +char SCCPLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(SCCPLegacyPass, "sccp", + "Sparse Conditional Constant Propagation", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(IPSCCP, "ipsccp", - "Interprocedural Sparse Conditional Constant Propagation", - false, false) - -// createIPSCCPPass - This is the public interface to this file. -ModulePass *llvm::createIPSCCPPass() { - return new IPSCCP(); -} +INITIALIZE_PASS_END(SCCPLegacyPass, "sccp", + "Sparse Conditional Constant Propagation", false, false) +// createSCCPPass - This is the public interface to this file. +FunctionPass *llvm::createSCCPPass() { return new SCCPLegacyPass(); } static bool AddressIsTaken(const GlobalValue *GV) { // Delete any dead constantexpr klingons. @@ -1725,10 +1688,8 @@ static bool AddressIsTaken(const GlobalValue *GV) { return false; } -bool IPSCCP::runOnModule(Module &M) { - const DataLayout &DL = M.getDataLayout(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); +static bool runIPSCCP(Module &M, const DataLayout &DL, + const TargetLibraryInfo *TLI) { SCCPSolver Solver(DL, TLI); // AddressTakenFunctions - This set keeps track of the address-taken functions @@ -1741,32 +1702,32 @@ bool IPSCCP::runOnModule(Module &M) { // Loop over all functions, marking arguments to those with their addresses // taken or that are external as overdefined. 
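The loop that follows seeds the interprocedural solve: an argument can only be tracked (and potentially proven constant) when every call site of its function is visible, which requires local linkage and an address that never escapes; declarations and externally callable functions start out overdefined. A toy model of that seeding decision, with a hypothetical FuncInfo record standing in for llvm::Function, might be:

    #include <iostream>
    #include <map>
    #include <string>

    struct FuncInfo {
      bool HasLocalLinkage;
      bool AddressTaken;
      bool IsDeclaration;
    };

    // Mirror of the check in the loop below: arguments are trackable only if
    // all callers are visible direct calls.
    bool canTrackArguments(const FuncInfo &F) {
      return !F.IsDeclaration && F.HasLocalLinkage && !F.AddressTaken;
    }

    int main() {
      std::map<std::string, FuncInfo> M = {
          {"static_helper", {true, false, false}},  // arguments trackable
          {"callback",      {true, true,  false}},  // address escapes: overdefined
          {"exported_api",  {false, false, false}}, // external callers: overdefined
      };
      for (const auto &KV : M)
        std::cout << KV.first << ": "
                  << (canTrackArguments(KV.second) ? "track args" : "args overdefined")
                  << "\n";
    }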
// - for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { - if (F->isDeclaration()) + for (Function &F : M) { + if (F.isDeclaration()) continue; - // If this is a strong or ODR definition of this function, then we can - // propagate information about its result into callsites of it. - if (!F->mayBeOverridden()) - Solver.AddTrackedFunction(&*F); + // If this is an exact definition of this function, then we can propagate + // information about its result into callsites of it. + if (F.hasExactDefinition()) + Solver.AddTrackedFunction(&F); // If this function only has direct calls that we can see, we can track its // arguments and return value aggressively, and can assume it is not called // unless we see evidence to the contrary. - if (F->hasLocalLinkage()) { - if (AddressIsTaken(&*F)) - AddressTakenFunctions.insert(&*F); + if (F.hasLocalLinkage()) { + if (AddressIsTaken(&F)) + AddressTakenFunctions.insert(&F); else { - Solver.AddArgumentTrackedFunction(&*F); + Solver.AddArgumentTrackedFunction(&F); continue; } } // Assume the function is called. - Solver.MarkBlockExecutable(&F->front()); + Solver.MarkBlockExecutable(&F.front()); // Assume nothing about the incoming arguments. - for (Argument &AI : F->args()) + for (Argument &AI : F.args()) Solver.markAnythingOverdefined(&AI); } @@ -1784,8 +1745,8 @@ bool IPSCCP::runOnModule(Module &M) { DEBUG(dbgs() << "RESOLVING UNDEFS\n"); ResolvedUndefs = false; - for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) - ResolvedUndefs |= Solver.ResolvedUndefsIn(*F); + for (Function &F : M) + ResolvedUndefs |= Solver.ResolvedUndefsIn(F); } bool MadeChanges = false; @@ -1795,79 +1756,47 @@ bool IPSCCP::runOnModule(Module &M) { // SmallVector<BasicBlock*, 512> BlocksToErase; - for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { - if (F->isDeclaration()) + for (Function &F : M) { + if (F.isDeclaration()) continue; - if (Solver.isBlockExecutable(&F->front())) { - for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); - AI != E; ++AI) { - if (AI->use_empty() || AI->getType()->isStructTy()) continue; - - // TODO: Could use getStructLatticeValueFor to find out if the entire - // result is a constant and replace it entirely if so. - - LatticeVal IV = Solver.getLatticeValueFor(&*AI); - if (IV.isOverdefined()) continue; - - Constant *CST = IV.isConstant() ? - IV.getConstant() : UndefValue::get(AI->getType()); - DEBUG(dbgs() << "*** Arg " << *AI << " = " << *CST <<"\n"); - - // Replaces all of the uses of a variable with uses of the - // constant. 
- AI->replaceAllUsesWith(CST); - ++IPNumArgsElimed; + if (Solver.isBlockExecutable(&F.front())) { + for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E; + ++AI) { + if (AI->use_empty()) + continue; + if (tryToReplaceWithConstant(Solver, &*AI)) + ++IPNumArgsElimed; } } - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { if (!Solver.isBlockExecutable(&*BB)) { - DeleteInstructionInBlock(&*BB); - MadeChanges = true; + DEBUG(dbgs() << " BasicBlock Dead:" << *BB); - TerminatorInst *TI = BB->getTerminator(); - for (BasicBlock *Succ : TI->successors()) { - if (!Succ->empty() && isa<PHINode>(Succ->begin())) - Succ->removePredecessor(&*BB); - } - if (!TI->use_empty()) - TI->replaceAllUsesWith(UndefValue::get(TI->getType())); - TI->eraseFromParent(); - new UnreachableInst(M.getContext(), &*BB); + ++NumDeadBlocks; + NumInstRemoved += + changeToUnreachable(BB->getFirstNonPHI(), /*UseLLVMTrap=*/false); + + MadeChanges = true; - if (&*BB != &F->front()) + if (&*BB != &F.front()) BlocksToErase.push_back(&*BB); continue; } for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { Instruction *Inst = &*BI++; - if (Inst->getType()->isVoidTy() || Inst->getType()->isStructTy()) + if (Inst->getType()->isVoidTy()) continue; - - // TODO: Could use getStructLatticeValueFor to find out if the entire - // result is a constant and replace it entirely if so. - - LatticeVal IV = Solver.getLatticeValueFor(Inst); - if (IV.isOverdefined()) - continue; - - Constant *Const = IV.isConstant() - ? IV.getConstant() : UndefValue::get(Inst->getType()); - DEBUG(dbgs() << " Constant: " << *Const << " = " << *Inst << '\n'); - - // Replaces all of the uses of a variable with uses of the - // constant. - Inst->replaceAllUsesWith(Const); - - // Delete the instruction. - if (!isa<CallInst>(Inst) && !isa<TerminatorInst>(Inst)) - Inst->eraseFromParent(); - - // Hey, we just changed something! - MadeChanges = true; - ++IPNumInstRemoved; + if (tryToReplaceInstWithConstant( + Solver, Inst, + !isa<CallInst>(Inst) && + !isa<TerminatorInst>(Inst) /* shouldEraseFromParent */)) { + // Hey, we just changed something! + MadeChanges = true; + ++IPNumInstRemoved; + } } } @@ -1918,7 +1847,7 @@ bool IPSCCP::runOnModule(Module &M) { } // Finally, delete the basic block. - F->getBasicBlockList().erase(DeadBB); + F.getBasicBlockList().erase(DeadBB); } BlocksToErase.clear(); } @@ -1937,18 +1866,17 @@ bool IPSCCP::runOnModule(Module &M) { // TODO: Process multiple value ret instructions also. const DenseMap<Function*, LatticeVal> &RV = Solver.getTrackedRetVals(); - for (DenseMap<Function*, LatticeVal>::const_iterator I = RV.begin(), - E = RV.end(); I != E; ++I) { - Function *F = I->first; - if (I->second.isOverdefined() || F->getReturnType()->isVoidTy()) + for (const auto &I : RV) { + Function *F = I.first; + if (I.second.isOverdefined() || F->getReturnType()->isVoidTy()) continue; // We can only do this if we know that nothing else can call the function. 
if (!F->hasLocalLinkage() || AddressTakenFunctions.count(F)) continue; - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) - if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) + for (BasicBlock &BB : *F) + if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) if (!isa<UndefValue>(RI->getOperand(0))) ReturnsToZap.push_back(RI); } @@ -1978,3 +1906,52 @@ bool IPSCCP::runOnModule(Module &M) { return MadeChanges; } + +PreservedAnalyses IPSCCPPass::run(Module &M, AnalysisManager<Module> &AM) { + const DataLayout &DL = M.getDataLayout(); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); + if (!runIPSCCP(M, DL, &TLI)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +namespace { +//===--------------------------------------------------------------------===// +// +/// IPSCCP Class - This class implements interprocedural Sparse Conditional +/// Constant Propagation. +/// +class IPSCCPLegacyPass : public ModulePass { +public: + static char ID; + + IPSCCPLegacyPass() : ModulePass(ID) { + initializeIPSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + const DataLayout &DL = M.getDataLayout(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return runIPSCCP(M, DL, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + } +}; +} // end anonymous namespace + +char IPSCCPLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", + "Interprocedural Sparse Conditional Constant Propagation", + false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", + "Interprocedural Sparse Conditional Constant Propagation", + false, false) + +// createIPSCCPPass - This is the public interface to this file. +ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); } diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index a7361b5fe083..7d33259c030b 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -55,8 +55,8 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" -#if __cplusplus >= 201103L && !defined(NDEBUG) -// We only use this for a debug check in C++11 +#ifndef NDEBUG +// We only use this for a debug check. #include <random> #endif @@ -87,12 +87,13 @@ static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), cl::Hidden); namespace { -/// \brief A custom IRBuilder inserter which prefixes all names if they are -/// preserved. -template <bool preserveNames = true> -class IRBuilderPrefixedInserter - : public IRBuilderDefaultInserter<preserveNames> { +/// \brief A custom IRBuilder inserter which prefixes all names, but only in +/// Assert builds. +class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter { std::string Prefix; + const Twine getNameWithPrefix(const Twine &Name) const { + return Name.isTriviallyEmpty() ? Name : Prefix + Name; + } public: void SetNamePrefix(const Twine &P) { Prefix = P.str(); } @@ -100,27 +101,13 @@ public: protected: void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, BasicBlock::iterator InsertPt) const { - IRBuilderDefaultInserter<preserveNames>::InsertHelper( - I, Name.isTriviallyEmpty() ? 
Name : Prefix + Name, BB, InsertPt); + IRBuilderDefaultInserter::InsertHelper(I, getNameWithPrefix(Name), BB, + InsertPt); } }; -// Specialization for not preserving the name is trivial. -template <> -class IRBuilderPrefixedInserter<false> - : public IRBuilderDefaultInserter<false> { -public: - void SetNamePrefix(const Twine &P) {} -}; - /// \brief Provide a typedef for IRBuilder that drops names in release builds. -#ifndef NDEBUG -typedef llvm::IRBuilder<true, ConstantFolder, IRBuilderPrefixedInserter<true>> - IRBuilderTy; -#else -typedef llvm::IRBuilder<false, ConstantFolder, IRBuilderPrefixedInserter<false>> - IRBuilderTy; -#endif +using IRBuilderTy = llvm::IRBuilder<ConstantFolder, IRBuilderPrefixedInserter>; } namespace { @@ -694,7 +681,7 @@ private: // langref in a very strict sense. If we ever want to enable // SROAStrictInbounds, this code should be factored cleanly into // PtrUseVisitor, but it is easier to experiment with SROAStrictInbounds - // by writing out the code here where we have tho underlying allocation + // by writing out the code here where we have the underlying allocation // size readily available. APInt GEPOffset = Offset; const DataLayout &DL = GEPI.getModule()->getDataLayout(); @@ -1015,7 +1002,7 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) }), Slices.end()); -#if __cplusplus >= 201103L && !defined(NDEBUG) +#ifndef NDEBUG if (SROARandomShuffleSlices) { std::mt19937 MT(static_cast<unsigned>(sys::TimeValue::now().msec())); std::shuffle(Slices.begin(), Slices.end(), MT); @@ -1192,8 +1179,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) { // If this pointer is always safe to load, or if we can prove that there // is already a load in the block, then we can move the load to the pred // block. - if (isDereferenceablePointer(InVal, DL) || - isSafeToLoadUnconditionally(InVal, TI, MaxAlign)) + if (isSafeToLoadUnconditionally(InVal, MaxAlign, DL, TI)) continue; return false; @@ -1262,8 +1248,6 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) { Value *TValue = SI.getTrueValue(); Value *FValue = SI.getFalseValue(); const DataLayout &DL = SI.getModule()->getDataLayout(); - bool TDerefable = isDereferenceablePointer(TValue, DL); - bool FDerefable = isDereferenceablePointer(FValue, DL); for (User *U : SI.users()) { LoadInst *LI = dyn_cast<LoadInst>(U); @@ -1273,11 +1257,9 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) { // Both operands to the select need to be dereferencable, either // absolutely (e.g. allocas) or at this point because we can see other // accesses to it. 
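The speculation hunks above switch isSafePHIToSpeculate and isSafeSelectToSpeculate over to the newer isSafeToLoadUnconditionally(Ptr, Align, DL, CtxI) form, but the underlying rewrite is unchanged: a load of a selected pointer becomes two unconditional loads feeding a select, which is only sound if both pointers are dereferenceable. A hand-written C++ picture of the before/after (illustrative source code, not LLVM output) is:

    #include <iostream>

    // Before: the load sits behind the select, so only one pointer is read.
    int loadOfSelect(bool c, int *a, int *b) {
      int *p = c ? a : b;
      return *p;
    }

    // After speculation: both loads execute unconditionally, so both pointers
    // must be safe to load (e.g. both come from allocas).
    int selectOfLoads(bool c, int *a, int *b) {
      int va = *a;
      int vb = *b;
      return c ? va : vb;
    }

    int main() {
      int x = 1, y = 2;
      std::cout << loadOfSelect(true, &x, &y) << " "
                << selectOfLoads(true, &x, &y) << "\n"; // 1 1
    }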
- if (!TDerefable && - !isSafeToLoadUnconditionally(TValue, LI, LI->getAlignment())) + if (!isSafeToLoadUnconditionally(TValue, LI->getAlignment(), DL, LI)) return false; - if (!FDerefable && - !isSafeToLoadUnconditionally(FValue, LI, LI->getAlignment())) + if (!isSafeToLoadUnconditionally(FValue, LI->getAlignment(), DL, LI)) return false; } @@ -1570,7 +1552,7 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr, if (Operator::getOpcode(Ptr) == Instruction::BitCast) { Ptr = cast<Operator>(Ptr)->getOperand(0); } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(Ptr)) { - if (GA->mayBeOverridden()) + if (GA->isInterposable()) break; Ptr = GA->getAliasee(); } else { @@ -1653,8 +1635,10 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) { OldTy = OldTy->getScalarType(); NewTy = NewTy->getScalarType(); if (NewTy->isPointerTy() || OldTy->isPointerTy()) { - if (NewTy->isPointerTy() && OldTy->isPointerTy()) - return true; + if (NewTy->isPointerTy() && OldTy->isPointerTy()) { + return cast<PointerType>(NewTy)->getPointerAddressSpace() == + cast<PointerType>(OldTy)->getPointerAddressSpace(); + } if (NewTy->isIntegerTy() || OldTy->isIntegerTy()) return true; return false; @@ -3123,9 +3107,14 @@ private: void emitFunc(Type *Ty, Value *&Agg, const Twine &Name) { assert(Ty->isSingleValueType()); // Extract the single value and store it using the indices. - Value *Store = IRB.CreateStore( - IRB.CreateExtractValue(Agg, Indices, Name + ".extract"), - IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep")); + // + // The gep and extractvalue values are factored out of the CreateStore + // call to make the output independent of the argument evaluation order. + Value *ExtractValue = + IRB.CreateExtractValue(Agg, Indices, Name + ".extract"); + Value *InBoundsGEP = + IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); + Value *Store = IRB.CreateStore(ExtractValue, InBoundsGEP); (void)Store; DEBUG(dbgs() << " to: " << *Store << "\n"); } @@ -3380,11 +3369,15 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { for (auto &P : AS.partitions()) { for (Slice &S : P) { Instruction *I = cast<Instruction>(S.getUse()->getUser()); - if (!S.isSplittable() ||S.endOffset() <= P.endOffset()) { - // If this was a load we have to track that it can't participate in any - // pre-splitting! + if (!S.isSplittable() || S.endOffset() <= P.endOffset()) { + // If this is a load we have to track that it can't participate in any + // pre-splitting. If this is a store of a load we have to track that + // that load also can't participate in any pre-splitting. if (auto *LI = dyn_cast<LoadInst>(I)) UnsplittableLoads.insert(LI); + else if (auto *SI = dyn_cast<StoreInst>(I)) + if (auto *LI = dyn_cast<LoadInst>(SI->getValueOperand())) + UnsplittableLoads.insert(LI); continue; } assert(P.endOffset() > S.beginOffset() && @@ -3411,9 +3404,9 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { } Loads.push_back(LI); - } else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser())) { - if (!SI || - S.getUse() != &SI->getOperandUse(SI->getPointerOperandIndex())) + } else if (auto *SI = dyn_cast<StoreInst>(I)) { + if (S.getUse() != &SI->getOperandUse(SI->getPointerOperandIndex())) + // Skip stores *of* pointers. FIXME: This shouldn't even be possible! 
continue; auto *StoredLoad = dyn_cast<LoadInst>(SI->getValueOperand()); if (!StoredLoad || !StoredLoad->isSimple()) @@ -3937,15 +3930,19 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, Worklist.insert(NewAI); } } else { - // If we can't promote the alloca, iterate on it to check for new - // refinements exposed by splitting the current alloca. Don't iterate on an - // alloca which didn't actually change and didn't get promoted. - if (NewAI != &AI) - Worklist.insert(NewAI); - // Drop any post-promotion work items if promotion didn't happen. while (PostPromotionWorklist.size() > PPWOldSize) PostPromotionWorklist.pop_back(); + + // We couldn't promote and we didn't create a new partition, nothing + // happened. + if (NewAI == &AI) + return nullptr; + + // If we can't promote the alloca, iterate on it to check for new + // refinements exposed by splitting the current alloca. Don't iterate on an + // alloca which didn't actually change and didn't get promoted. + Worklist.insert(NewAI); } return NewAI; @@ -4024,12 +4021,12 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { auto *Var = DbgDecl->getVariable(); auto *Expr = DbgDecl->getExpression(); DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); - bool IsSplit = Pieces.size() > 1; + uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType()); for (auto Piece : Pieces) { // Create a piece expression describing the new partition or reuse AI's // expression if there is only one partition. auto *PieceExpr = Expr; - if (IsSplit || Expr->isBitPiece()) { + if (Piece.Size < AllocaSize || Expr->isBitPiece()) { // If this alloca is already a scalar replacement of a larger aggregate, // Piece.Offset describes the offset inside the scalar. uint64_t Offset = Expr->isBitPiece() ? Expr->getBitPieceOffset() : 0; @@ -4043,6 +4040,9 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { Size = std::min(Size, AbsEnd - Start); } PieceExpr = DIB.createBitPieceExpression(Start, Size); + } else { + assert(Pieces.size() == 1 && + "partition is as large as original alloca"); } // Remove any existing dbg.declare intrinsic describing the same alloca. @@ -4237,14 +4237,19 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, PostPromotionWorklist.clear(); } while (!Worklist.empty()); + if (!Changed) + return PreservedAnalyses::all(); + // FIXME: Even when promoting allocas we should preserve some abstract set of // CFG-specific analyses. - return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; } -PreservedAnalyses SROA::run(Function &F, AnalysisManager<Function> *AM) { - return runImpl(F, AM->getResult<DominatorTreeAnalysis>(F), - AM->getResult<AssumptionAnalysis>(F)); +PreservedAnalyses SROA::run(Function &F, AnalysisManager<Function> &AM) { + return runImpl(F, AM.getResult<DominatorTreeAnalysis>(F), + AM.getResult<AssumptionAnalysis>(F)); } /// A legacy pass for the legacy pass manager that wraps the \c SROA pass. 
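The emitFunc hunk above factors the CreateExtractValue and CreateInBoundsGEP calls out of the CreateStore argument list because C++ leaves the evaluation order of function arguments unspecified; when each argument expression creates (and numbers) a new instruction, the emitted IR can differ between compilers or build modes. A standalone illustration of the same pitfall, with next() standing in for an IRBuilder call that has a side effect, is:

    #include <iostream>

    static int Counter = 0;
    static int next(const char *Tag) {
      std::cout << Tag << " evaluated\n";
      return ++Counter;
    }

    static void store(int A, int B) { std::cout << A << " " << B << "\n"; }

    int main() {
      // Unstable: "extract" and "gep" may be evaluated in either order.
      store(next("extract"), next("gep"));

      // Stable: force the order by naming the intermediate results first.
      int E = next("extract2");
      int G = next("gep2");
      store(E, G);
    }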
@@ -4260,7 +4265,7 @@ public: initializeSROALegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; auto PA = Impl.runImpl( diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp index 52d477cc9573..f235b12e49cc 100644 --- a/lib/Transforms/Scalar/Scalar.cpp +++ b/lib/Transforms/Scalar/Scalar.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/ScopedNoAliasAA.h" #include "llvm/Analysis/TypeBasedAliasAnalysis.h" +#include "llvm/Transforms/Scalar/GVN.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" @@ -31,49 +32,52 @@ using namespace llvm; /// ScalarOpts library. void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeADCELegacyPassPass(Registry); - initializeBDCEPass(Registry); + initializeBDCELegacyPassPass(Registry); initializeAlignmentFromAssumptionsPass(Registry); - initializeConstantHoistingPass(Registry); + initializeConstantHoistingLegacyPassPass(Registry); initializeConstantPropagationPass(Registry); initializeCorrelatedValuePropagationPass(Registry); - initializeDCEPass(Registry); + initializeDCELegacyPassPass(Registry); initializeDeadInstEliminationPass(Registry); initializeScalarizerPass(Registry); - initializeDSEPass(Registry); - initializeGVNPass(Registry); + initializeDSELegacyPassPass(Registry); + initializeGuardWideningLegacyPassPass(Registry); + initializeGVNLegacyPassPass(Registry); initializeEarlyCSELegacyPassPass(Registry); + initializeGVNHoistLegacyPassPass(Registry); initializeFlattenCFGPassPass(Registry); initializeInductiveRangeCheckEliminationPass(Registry); - initializeIndVarSimplifyPass(Registry); + initializeIndVarSimplifyLegacyPassPass(Registry); initializeJumpThreadingPass(Registry); - initializeLICMPass(Registry); - initializeLoopDeletionPass(Registry); - initializeLoopAccessAnalysisPass(Registry); - initializeLoopInstSimplifyPass(Registry); + initializeLegacyLICMPassPass(Registry); + initializeLoopDataPrefetchPass(Registry); + initializeLoopDeletionLegacyPassPass(Registry); + initializeLoopAccessLegacyAnalysisPass(Registry); + initializeLoopInstSimplifyLegacyPassPass(Registry); initializeLoopInterchangePass(Registry); - initializeLoopRotatePass(Registry); + initializeLoopRotateLegacyPassPass(Registry); initializeLoopStrengthReducePass(Registry); initializeLoopRerollPass(Registry); initializeLoopUnrollPass(Registry); initializeLoopUnswitchPass(Registry); - initializeLoopIdiomRecognizePass(Registry); - initializeLowerAtomicPass(Registry); + initializeLoopVersioningLICMPass(Registry); + initializeLoopIdiomRecognizeLegacyPassPass(Registry); + initializeLowerAtomicLegacyPassPass(Registry); initializeLowerExpectIntrinsicPass(Registry); - initializeMemCpyOptPass(Registry); - initializeMergedLoadStoreMotionPass(Registry); + initializeLowerGuardIntrinsicPass(Registry); + initializeMemCpyOptLegacyPassPass(Registry); + initializeMergedLoadStoreMotionLegacyPassPass(Registry); initializeNaryReassociatePass(Registry); - initializePartiallyInlineLibCallsPass(Registry); - initializeReassociatePass(Registry); + initializePartiallyInlineLibCallsLegacyPassPass(Registry); + initializeReassociateLegacyPassPass(Registry); initializeRegToMemPass(Registry); initializeRewriteStatepointsForGCPass(Registry); - initializeSCCPPass(Registry); - initializeIPSCCPPass(Registry); + initializeSCCPLegacyPassPass(Registry); + initializeIPSCCPLegacyPassPass(Registry); 
initializeSROALegacyPassPass(Registry); - initializeSROA_DTPass(Registry); - initializeSROA_SSAUpPass(Registry); initializeCFGSimplifyPassPass(Registry); initializeStructurizeCFGPass(Registry); - initializeSinkingPass(Registry); + initializeSinkingLegacyPassPass(Registry); initializeTailCallElimPass(Registry); initializeSeparateConstOffsetFromGEPPass(Registry); initializeSpeculativeExecutionPass(Registry); @@ -81,9 +85,11 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLoadCombinePass(Registry); initializePlaceBackedgeSafepointsImplPass(Registry); initializePlaceSafepointsPass(Registry); - initializeFloat2IntPass(Registry); - initializeLoopDistributePass(Registry); + initializeFloat2IntLegacyPassPass(Registry); + initializeLoopDistributeLegacyPass(Registry); initializeLoopLoadEliminationPass(Registry); + initializeLoopSimplifyCFGLegacyPassPass(Registry); + initializeLoopVersioningPassPass(Registry); } void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) { @@ -154,6 +160,10 @@ void LLVMAddLoopRerollPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopRerollPass()); } +void LLVMAddLoopSimplifyCFGPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopSimplifyCFGPass()); +} + void LLVMAddLoopUnrollPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopUnrollPass()); } @@ -187,16 +197,16 @@ void LLVMAddSCCPPass(LLVMPassManagerRef PM) { } void LLVMAddScalarReplAggregatesPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createScalarReplAggregatesPass()); + unwrap(PM)->add(createSROAPass()); } void LLVMAddScalarReplAggregatesPassSSA(LLVMPassManagerRef PM) { - unwrap(PM)->add(createScalarReplAggregatesPass(-1, false)); + unwrap(PM)->add(createSROAPass()); } void LLVMAddScalarReplAggregatesPassWithThreshold(LLVMPassManagerRef PM, int Threshold) { - unwrap(PM)->add(createScalarReplAggregatesPass(Threshold)); + unwrap(PM)->add(createSROAPass()); } void LLVMAddSimplifyLibCallsPass(LLVMPassManagerRef PM) { @@ -227,6 +237,10 @@ void LLVMAddEarlyCSEPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createEarlyCSEPass()); } +void LLVMAddGVNHoistLegacyPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createGVNHoistPass()); +} + void LLVMAddTypeBasedAliasAnalysisPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createTypeBasedAAWrapperPass()); } diff --git a/lib/Transforms/Scalar/ScalarReplAggregates.cpp b/lib/Transforms/Scalar/ScalarReplAggregates.cpp deleted file mode 100644 index 114d22ddf2e4..000000000000 --- a/lib/Transforms/Scalar/ScalarReplAggregates.cpp +++ /dev/null @@ -1,2630 +0,0 @@ -//===- ScalarReplAggregates.cpp - Scalar Replacement of Aggregates --------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This transformation implements the well known scalar replacement of -// aggregates transformation. This xform breaks up alloca instructions of -// aggregate type (structure or array) into individual alloca instructions for -// each member (if possible). Then, if possible, it transforms the individual -// alloca instructions into nice clean scalar SSA form. -// -// This combines a simple SRoA algorithm with the Mem2Reg algorithm because they -// often interact, especially for C++ programs. As such, iterating between -// SRoA, then Mem2Reg until we run out of things to promote works well. 
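The header comment above describes the transformation at a high level; a C++ source-level picture of what breaking an aggregate alloca into per-member allocas looks like (an illustrative analogy only, not the pass itself):

struct Pair { int A; float B; };

static float beforeSROA(int X) {
  Pair P;            // one aggregate local (an alloca of a struct)
  P.A = X;
  P.B = 2.0f;
  return P.A + P.B;
}

static float afterSROA(int X) {
  int P_A;           // one scalar local per member; each can then be
  float P_B;         // promoted to a register by mem2reg independently
  P_A = X;
  P_B = 2.0f;
  return P_A + P_B;
}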
-// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/Scalar.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/Loads.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/CallSite.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DIBuilder.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" -#include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" -#include "llvm/Pass.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/MathExtras.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/PromoteMemToReg.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" -using namespace llvm; - -#define DEBUG_TYPE "scalarrepl" - -STATISTIC(NumReplaced, "Number of allocas broken up"); -STATISTIC(NumPromoted, "Number of allocas promoted"); -STATISTIC(NumAdjusted, "Number of scalar allocas adjusted to allow promotion"); -STATISTIC(NumConverted, "Number of aggregates converted to scalar"); - -namespace { -#define SROA SROA_ - struct SROA : public FunctionPass { - SROA(int T, bool hasDT, char &ID, int ST, int AT, int SLT) - : FunctionPass(ID), HasDomTree(hasDT) { - if (T == -1) - SRThreshold = 128; - else - SRThreshold = T; - if (ST == -1) - StructMemberThreshold = 32; - else - StructMemberThreshold = ST; - if (AT == -1) - ArrayElementThreshold = 8; - else - ArrayElementThreshold = AT; - if (SLT == -1) - // Do not limit the scalar integer load size if no threshold is given. - ScalarLoadThreshold = -1; - else - ScalarLoadThreshold = SLT; - } - - bool runOnFunction(Function &F) override; - - bool performScalarRepl(Function &F); - bool performPromotion(Function &F); - - private: - bool HasDomTree; - - /// DeadInsts - Keep track of instructions we have made dead, so that - /// we can remove them after we are done working. - SmallVector<Value*, 32> DeadInsts; - - /// AllocaInfo - When analyzing uses of an alloca instruction, this captures - /// information about the uses. All these fields are initialized to false - /// and set to true when something is learned. - struct AllocaInfo { - /// The alloca to promote. - AllocaInst *AI; - - /// CheckedPHIs - This is a set of verified PHI nodes, to prevent infinite - /// looping and avoid redundant work. - SmallPtrSet<PHINode*, 8> CheckedPHIs; - - /// isUnsafe - This is set to true if the alloca cannot be SROA'd. - bool isUnsafe : 1; - - /// isMemCpySrc - This is true if this aggregate is memcpy'd from. - bool isMemCpySrc : 1; - - /// isMemCpyDst - This is true if this aggregate is memcpy'd into. - bool isMemCpyDst : 1; - - /// hasSubelementAccess - This is true if a subelement of the alloca is - /// ever accessed, or false if the alloca is only accessed with mem - /// intrinsics or load/store that only access the entire alloca at once. - bool hasSubelementAccess : 1; - - /// hasALoadOrStore - This is true if there are any loads or stores to it. - /// The alloca may just be accessed with memcpy, for example, which would - /// not set this. 
- bool hasALoadOrStore : 1; - - explicit AllocaInfo(AllocaInst *ai) - : AI(ai), isUnsafe(false), isMemCpySrc(false), isMemCpyDst(false), - hasSubelementAccess(false), hasALoadOrStore(false) {} - }; - - /// SRThreshold - The maximum alloca size to considered for SROA. - unsigned SRThreshold; - - /// StructMemberThreshold - The maximum number of members a struct can - /// contain to be considered for SROA. - unsigned StructMemberThreshold; - - /// ArrayElementThreshold - The maximum number of elements an array can - /// have to be considered for SROA. - unsigned ArrayElementThreshold; - - /// ScalarLoadThreshold - The maximum size in bits of scalars to load when - /// converting to scalar - unsigned ScalarLoadThreshold; - - void MarkUnsafe(AllocaInfo &I, Instruction *User) { - I.isUnsafe = true; - DEBUG(dbgs() << " Transformation preventing inst: " << *User << '\n'); - } - - bool isSafeAllocaToScalarRepl(AllocaInst *AI); - - void isSafeForScalarRepl(Instruction *I, uint64_t Offset, AllocaInfo &Info); - void isSafePHISelectUseForScalarRepl(Instruction *User, uint64_t Offset, - AllocaInfo &Info); - void isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset, AllocaInfo &Info); - void isSafeMemAccess(uint64_t Offset, uint64_t MemSize, - Type *MemOpType, bool isStore, AllocaInfo &Info, - Instruction *TheAccess, bool AllowWholeAccess); - bool TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size, - const DataLayout &DL); - uint64_t FindElementAndOffset(Type *&T, uint64_t &Offset, Type *&IdxTy, - const DataLayout &DL); - - void DoScalarReplacement(AllocaInst *AI, - std::vector<AllocaInst*> &WorkList); - void DeleteDeadInstructions(); - - void RewriteForScalarRepl(Instruction *I, AllocaInst *AI, uint64_t Offset, - SmallVectorImpl<AllocaInst *> &NewElts); - void RewriteBitCast(BitCastInst *BC, AllocaInst *AI, uint64_t Offset, - SmallVectorImpl<AllocaInst *> &NewElts); - void RewriteGEP(GetElementPtrInst *GEPI, AllocaInst *AI, uint64_t Offset, - SmallVectorImpl<AllocaInst *> &NewElts); - void RewriteLifetimeIntrinsic(IntrinsicInst *II, AllocaInst *AI, - uint64_t Offset, - SmallVectorImpl<AllocaInst *> &NewElts); - void RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst, - AllocaInst *AI, - SmallVectorImpl<AllocaInst *> &NewElts); - void RewriteStoreUserOfWholeAlloca(StoreInst *SI, AllocaInst *AI, - SmallVectorImpl<AllocaInst *> &NewElts); - void RewriteLoadUserOfWholeAlloca(LoadInst *LI, AllocaInst *AI, - SmallVectorImpl<AllocaInst *> &NewElts); - bool ShouldAttemptScalarRepl(AllocaInst *AI); - }; - - // SROA_DT - SROA that uses DominatorTree. - struct SROA_DT : public SROA { - static char ID; - public: - SROA_DT(int T = -1, int ST = -1, int AT = -1, int SLT = -1) : - SROA(T, true, ID, ST, AT, SLT) { - initializeSROA_DTPass(*PassRegistry::getPassRegistry()); - } - - // getAnalysisUsage - This pass does not require any passes, but we know it - // will not alter the CFG, so say so. - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.setPreservesCFG(); - } - }; - - // SROA_SSAUp - SROA that uses SSAUpdater. - struct SROA_SSAUp : public SROA { - static char ID; - public: - SROA_SSAUp(int T = -1, int ST = -1, int AT = -1, int SLT = -1) : - SROA(T, false, ID, ST, AT, SLT) { - initializeSROA_SSAUpPass(*PassRegistry::getPassRegistry()); - } - - // getAnalysisUsage - This pass does not require any passes, but we know it - // will not alter the CFG, so say so. 
- void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.setPreservesCFG(); - } - }; - -} - -char SROA_DT::ID = 0; -char SROA_SSAUp::ID = 0; - -INITIALIZE_PASS_BEGIN(SROA_DT, "scalarrepl", - "Scalar Replacement of Aggregates (DT)", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(SROA_DT, "scalarrepl", - "Scalar Replacement of Aggregates (DT)", false, false) - -INITIALIZE_PASS_BEGIN(SROA_SSAUp, "scalarrepl-ssa", - "Scalar Replacement of Aggregates (SSAUp)", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(SROA_SSAUp, "scalarrepl-ssa", - "Scalar Replacement of Aggregates (SSAUp)", false, false) - -// Public interface to the ScalarReplAggregates pass -FunctionPass *llvm::createScalarReplAggregatesPass(int Threshold, - bool UseDomTree, - int StructMemberThreshold, - int ArrayElementThreshold, - int ScalarLoadThreshold) { - if (UseDomTree) - return new SROA_DT(Threshold, StructMemberThreshold, ArrayElementThreshold, - ScalarLoadThreshold); - return new SROA_SSAUp(Threshold, StructMemberThreshold, - ArrayElementThreshold, ScalarLoadThreshold); -} - - -//===----------------------------------------------------------------------===// -// Convert To Scalar Optimization. -//===----------------------------------------------------------------------===// - -namespace { -/// ConvertToScalarInfo - This class implements the "Convert To Scalar" -/// optimization, which scans the uses of an alloca and determines if it can -/// rewrite it in terms of a single new alloca that can be mem2reg'd. -class ConvertToScalarInfo { - /// AllocaSize - The size of the alloca being considered in bytes. - unsigned AllocaSize; - const DataLayout &DL; - unsigned ScalarLoadThreshold; - - /// IsNotTrivial - This is set to true if there is some access to the object - /// which means that mem2reg can't promote it. - bool IsNotTrivial; - - /// ScalarKind - Tracks the kind of alloca being considered for promotion, - /// computed based on the uses of the alloca rather than the LLVM type system. - enum { - Unknown, - - // Accesses via GEPs that are consistent with element access of a vector - // type. This will not be converted into a vector unless there is a later - // access using an actual vector type. - ImplicitVector, - - // Accesses via vector operations and GEPs that are consistent with the - // layout of a vector type. - Vector, - - // An integer bag-of-bits with bitwise operations for insertion and - // extraction. Any combination of types can be converted into this kind - // of scalar. - Integer - } ScalarKind; - - /// VectorTy - This tracks the type that we should promote the vector to if - /// it is possible to turn it into a vector. This starts out null, and if it - /// isn't possible to turn into a vector type, it gets set to VoidTy. - VectorType *VectorTy; - - /// HadNonMemTransferAccess - True if there is at least one access to the - /// alloca that is not a MemTransferInst. We don't want to turn structs into - /// large integers unless there is some potential for optimization. - bool HadNonMemTransferAccess; - - /// HadDynamicAccess - True if some element of this alloca was dynamic. - /// We don't yet have support for turning a dynamic access into a large - /// integer. 
- bool HadDynamicAccess; - -public: - explicit ConvertToScalarInfo(unsigned Size, const DataLayout &DL, - unsigned SLT) - : AllocaSize(Size), DL(DL), ScalarLoadThreshold(SLT), IsNotTrivial(false), - ScalarKind(Unknown), VectorTy(nullptr), HadNonMemTransferAccess(false), - HadDynamicAccess(false) { } - - AllocaInst *TryConvert(AllocaInst *AI); - -private: - bool CanConvertToScalar(Value *V, uint64_t Offset, Value* NonConstantIdx); - void MergeInTypeForLoadOrStore(Type *In, uint64_t Offset); - bool MergeInVectorType(VectorType *VInTy, uint64_t Offset); - void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset, - Value *NonConstantIdx); - - Value *ConvertScalar_ExtractValue(Value *NV, Type *ToType, - uint64_t Offset, Value* NonConstantIdx, - IRBuilder<> &Builder); - Value *ConvertScalar_InsertValue(Value *StoredVal, Value *ExistingVal, - uint64_t Offset, Value* NonConstantIdx, - IRBuilder<> &Builder); -}; -} // end anonymous namespace. - - -/// TryConvert - Analyze the specified alloca, and if it is safe to do so, -/// rewrite it to be a new alloca which is mem2reg'able. This returns the new -/// alloca if possible or null if not. -AllocaInst *ConvertToScalarInfo::TryConvert(AllocaInst *AI) { - // If we can't convert this scalar, or if mem2reg can trivially do it, bail - // out. - if (!CanConvertToScalar(AI, 0, nullptr) || !IsNotTrivial) - return nullptr; - - // If an alloca has only memset / memcpy uses, it may still have an Unknown - // ScalarKind. Treat it as an Integer below. - if (ScalarKind == Unknown) - ScalarKind = Integer; - - if (ScalarKind == Vector && VectorTy->getBitWidth() != AllocaSize * 8) - ScalarKind = Integer; - - // If we were able to find a vector type that can handle this with - // insert/extract elements, and if there was at least one use that had - // a vector type, promote this to a vector. We don't want to promote - // random stuff that doesn't use vectors (e.g. <9 x double>) because then - // we just get a lot of insert/extracts. If at least one vector is - // involved, then we probably really do have a union of vector/array. - Type *NewTy; - if (ScalarKind == Vector) { - assert(VectorTy && "Missing type for vector scalar."); - DEBUG(dbgs() << "CONVERT TO VECTOR: " << *AI << "\n TYPE = " - << *VectorTy << '\n'); - NewTy = VectorTy; // Use the vector type. - } else { - unsigned BitWidth = AllocaSize * 8; - - // Do not convert to scalar integer if the alloca size exceeds the - // scalar load threshold. - if (BitWidth > ScalarLoadThreshold) - return nullptr; - - if ((ScalarKind == ImplicitVector || ScalarKind == Integer) && - !HadNonMemTransferAccess && !DL.fitsInLegalInteger(BitWidth)) - return nullptr; - // Dynamic accesses on integers aren't yet supported. They need us to shift - // by a dynamic amount which could be difficult to work out as we might not - // know whether to use a left or right shift. - if (ScalarKind == Integer && HadDynamicAccess) - return nullptr; - - DEBUG(dbgs() << "CONVERT TO SCALAR INTEGER: " << *AI << "\n"); - // Create and insert the integer alloca. - NewTy = IntegerType::get(AI->getContext(), BitWidth); - } - AllocaInst *NewAI = - new AllocaInst(NewTy, nullptr, "", &AI->getParent()->front()); - ConvertUsesToScalar(AI, NewAI, 0, nullptr); - return NewAI; -} - -/// MergeInTypeForLoadOrStore - Add the 'In' type to the accumulated vector type -/// (VectorTy) so far at the offset specified by Offset (which is specified in -/// bytes). 
-/// -/// There are two cases we handle here: -/// 1) A union of vector types of the same size and potentially its elements. -/// Here we turn element accesses into insert/extract element operations. -/// This promotes a <4 x float> with a store of float to the third element -/// into a <4 x float> that uses insert element. -/// 2) A fully general blob of memory, which we turn into some (potentially -/// large) integer type with extract and insert operations where the loads -/// and stores would mutate the memory. We mark this by setting VectorTy -/// to VoidTy. -void ConvertToScalarInfo::MergeInTypeForLoadOrStore(Type *In, - uint64_t Offset) { - // If we already decided to turn this into a blob of integer memory, there is - // nothing to be done. - if (ScalarKind == Integer) - return; - - // If this could be contributing to a vector, analyze it. - - // If the In type is a vector that is the same size as the alloca, see if it - // matches the existing VecTy. - if (VectorType *VInTy = dyn_cast<VectorType>(In)) { - if (MergeInVectorType(VInTy, Offset)) - return; - } else if (In->isFloatTy() || In->isDoubleTy() || - (In->isIntegerTy() && In->getPrimitiveSizeInBits() >= 8 && - isPowerOf2_32(In->getPrimitiveSizeInBits()))) { - // Full width accesses can be ignored, because they can always be turned - // into bitcasts. - unsigned EltSize = In->getPrimitiveSizeInBits()/8; - if (EltSize == AllocaSize) - return; - - // If we're accessing something that could be an element of a vector, see - // if the implied vector agrees with what we already have and if Offset is - // compatible with it. - if (Offset % EltSize == 0 && AllocaSize % EltSize == 0 && - (!VectorTy || EltSize == VectorTy->getElementType() - ->getPrimitiveSizeInBits()/8)) { - if (!VectorTy) { - ScalarKind = ImplicitVector; - VectorTy = VectorType::get(In, AllocaSize/EltSize); - } - return; - } - } - - // Otherwise, we have a case that we can't handle with an optimized vector - // form. We can still turn this into a large integer. - ScalarKind = Integer; -} - -/// MergeInVectorType - Handles the vector case of MergeInTypeForLoadOrStore, -/// returning true if the type was successfully merged and false otherwise. -bool ConvertToScalarInfo::MergeInVectorType(VectorType *VInTy, - uint64_t Offset) { - if (VInTy->getBitWidth()/8 == AllocaSize && Offset == 0) { - // If we're storing/loading a vector of the right size, allow it as a - // vector. If this the first vector we see, remember the type so that - // we know the element size. If this is a subsequent access, ignore it - // even if it is a differing type but the same size. Worst case we can - // bitcast the resultant vectors. - if (!VectorTy) - VectorTy = VInTy; - ScalarKind = Vector; - return true; - } - - return false; -} - -/// CanConvertToScalar - V is a pointer. If we can convert the pointee and all -/// its accesses to a single vector type, return true and set VecTy to -/// the new type. If we could convert the alloca into a single promotable -/// integer, return true but set VecTy to VoidTy. Further, if the use is not a -/// completely trivial use that mem2reg could promote, set IsNotTrivial. Offset -/// is the current offset from the base of the alloca being analyzed. -/// -/// If we see at least one access to the value that is as a vector type, set the -/// SawVec flag. 
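MergeInTypeForLoadOrStore above accepts a scalar access as an implied vector element only when its size divides both the access offset and the alloca size and agrees with any element size seen so far. A standalone sketch of that check, with illustrative names:

#include <cstdint>

// ChosenEltSize == 0 means no element type has been implied yet.
static bool fitsImpliedVector(uint64_t AllocaSize, uint64_t Offset,
                              uint64_t EltSize, uint64_t ChosenEltSize) {
  if (EltSize == 0 || Offset % EltSize != 0 || AllocaSize % EltSize != 0)
    return false;
  return ChosenEltSize == 0 || ChosenEltSize == EltSize;
}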
-bool ConvertToScalarInfo::CanConvertToScalar(Value *V, uint64_t Offset, - Value* NonConstantIdx) { - for (User *U : V->users()) { - Instruction *UI = cast<Instruction>(U); - - if (LoadInst *LI = dyn_cast<LoadInst>(UI)) { - // Don't break volatile loads. - if (!LI->isSimple()) - return false; - // Don't touch MMX operations. - if (LI->getType()->isX86_MMXTy()) - return false; - HadNonMemTransferAccess = true; - MergeInTypeForLoadOrStore(LI->getType(), Offset); - continue; - } - - if (StoreInst *SI = dyn_cast<StoreInst>(UI)) { - // Storing the pointer, not into the value? - if (SI->getOperand(0) == V || !SI->isSimple()) return false; - // Don't touch MMX operations. - if (SI->getOperand(0)->getType()->isX86_MMXTy()) - return false; - HadNonMemTransferAccess = true; - MergeInTypeForLoadOrStore(SI->getOperand(0)->getType(), Offset); - continue; - } - - if (BitCastInst *BCI = dyn_cast<BitCastInst>(UI)) { - if (!onlyUsedByLifetimeMarkers(BCI)) - IsNotTrivial = true; // Can't be mem2reg'd. - if (!CanConvertToScalar(BCI, Offset, NonConstantIdx)) - return false; - continue; - } - - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UI)) { - // If this is a GEP with a variable indices, we can't handle it. - PointerType* PtrTy = dyn_cast<PointerType>(GEP->getPointerOperandType()); - if (!PtrTy) - return false; - - // Compute the offset that this GEP adds to the pointer. - SmallVector<Value*, 8> Indices(GEP->op_begin()+1, GEP->op_end()); - Value *GEPNonConstantIdx = nullptr; - if (!GEP->hasAllConstantIndices()) { - if (!isa<VectorType>(PtrTy->getElementType())) - return false; - if (NonConstantIdx) - return false; - GEPNonConstantIdx = Indices.pop_back_val(); - if (!GEPNonConstantIdx->getType()->isIntegerTy(32)) - return false; - HadDynamicAccess = true; - } else - GEPNonConstantIdx = NonConstantIdx; - uint64_t GEPOffset = DL.getIndexedOffset(PtrTy, - Indices); - // See if all uses can be converted. - if (!CanConvertToScalar(GEP, Offset+GEPOffset, GEPNonConstantIdx)) - return false; - IsNotTrivial = true; // Can't be mem2reg'd. - HadNonMemTransferAccess = true; - continue; - } - - // If this is a constant sized memset of a constant value (e.g. 0) we can - // handle it. - if (MemSetInst *MSI = dyn_cast<MemSetInst>(UI)) { - // Store to dynamic index. - if (NonConstantIdx) - return false; - // Store of constant value. - if (!isa<ConstantInt>(MSI->getValue())) - return false; - - // Store of constant size. - ConstantInt *Len = dyn_cast<ConstantInt>(MSI->getLength()); - if (!Len) - return false; - - // If the size differs from the alloca, we can only convert the alloca to - // an integer bag-of-bits. - // FIXME: This should handle all of the cases that are currently accepted - // as vector element insertions. - if (Len->getZExtValue() != AllocaSize || Offset != 0) - ScalarKind = Integer; - - IsNotTrivial = true; // Can't be mem2reg'd. - HadNonMemTransferAccess = true; - continue; - } - - // If this is a memcpy or memmove into or out of the whole allocation, we - // can handle it like a load or store of the scalar type. - if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(UI)) { - // Store to dynamic index. - if (NonConstantIdx) - return false; - ConstantInt *Len = dyn_cast<ConstantInt>(MTI->getLength()); - if (!Len || Len->getZExtValue() != AllocaSize || Offset != 0) - return false; - - IsNotTrivial = true; // Can't be mem2reg'd. - continue; - } - - // If this is a lifetime intrinsic, we can handle it. 
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(UI)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) { - continue; - } - } - - // Otherwise, we cannot handle this! - return false; - } - - return true; -} - -/// ConvertUsesToScalar - Convert all of the users of Ptr to use the new alloca -/// directly. This happens when we are converting an "integer union" to a -/// single integer scalar, or when we are converting a "vector union" to a -/// vector with insert/extractelement instructions. -/// -/// Offset is an offset from the original alloca, in bits that need to be -/// shifted to the right. By the end of this, there should be no uses of Ptr. -void ConvertToScalarInfo::ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, - uint64_t Offset, - Value* NonConstantIdx) { - while (!Ptr->use_empty()) { - Instruction *User = cast<Instruction>(Ptr->user_back()); - - if (BitCastInst *CI = dyn_cast<BitCastInst>(User)) { - ConvertUsesToScalar(CI, NewAI, Offset, NonConstantIdx); - CI->eraseFromParent(); - continue; - } - - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) { - // Compute the offset that this GEP adds to the pointer. - SmallVector<Value*, 8> Indices(GEP->op_begin()+1, GEP->op_end()); - Value* GEPNonConstantIdx = nullptr; - if (!GEP->hasAllConstantIndices()) { - assert(!NonConstantIdx && - "Dynamic GEP reading from dynamic GEP unsupported"); - GEPNonConstantIdx = Indices.pop_back_val(); - } else - GEPNonConstantIdx = NonConstantIdx; - uint64_t GEPOffset = DL.getIndexedOffset(GEP->getPointerOperandType(), - Indices); - ConvertUsesToScalar(GEP, NewAI, Offset+GEPOffset*8, GEPNonConstantIdx); - GEP->eraseFromParent(); - continue; - } - - IRBuilder<> Builder(User); - - if (LoadInst *LI = dyn_cast<LoadInst>(User)) { - // The load is a bit extract from NewAI shifted right by Offset bits. - Value *LoadedVal = Builder.CreateLoad(NewAI); - Value *NewLoadVal - = ConvertScalar_ExtractValue(LoadedVal, LI->getType(), Offset, - NonConstantIdx, Builder); - LI->replaceAllUsesWith(NewLoadVal); - LI->eraseFromParent(); - continue; - } - - if (StoreInst *SI = dyn_cast<StoreInst>(User)) { - assert(SI->getOperand(0) != Ptr && "Consistency error!"); - Instruction *Old = Builder.CreateLoad(NewAI, NewAI->getName()+".in"); - Value *New = ConvertScalar_InsertValue(SI->getOperand(0), Old, Offset, - NonConstantIdx, Builder); - Builder.CreateStore(New, NewAI); - SI->eraseFromParent(); - - // If the load we just inserted is now dead, then the inserted store - // overwrote the entire thing. - if (Old->use_empty()) - Old->eraseFromParent(); - continue; - } - - // If this is a constant sized memset of a constant value (e.g. 0) we can - // transform it into a store of the expanded constant value. - if (MemSetInst *MSI = dyn_cast<MemSetInst>(User)) { - assert(MSI->getRawDest() == Ptr && "Consistency error!"); - assert(!NonConstantIdx && "Cannot replace dynamic memset with insert"); - int64_t SNumBytes = cast<ConstantInt>(MSI->getLength())->getSExtValue(); - if (SNumBytes > 0 && (SNumBytes >> 32) == 0) { - unsigned NumBytes = static_cast<unsigned>(SNumBytes); - unsigned Val = cast<ConstantInt>(MSI->getValue())->getZExtValue(); - - // Compute the value replicated the right number of times. - APInt APVal(NumBytes*8, Val); - - // Splat the value if non-zero. 
- if (Val) - for (unsigned i = 1; i != NumBytes; ++i) - APVal |= APVal << 8; - - Instruction *Old = Builder.CreateLoad(NewAI, NewAI->getName()+".in"); - Value *New = ConvertScalar_InsertValue( - ConstantInt::get(User->getContext(), APVal), - Old, Offset, nullptr, Builder); - Builder.CreateStore(New, NewAI); - - // If the load we just inserted is now dead, then the memset overwrote - // the entire thing. - if (Old->use_empty()) - Old->eraseFromParent(); - } - MSI->eraseFromParent(); - continue; - } - - // If this is a memcpy or memmove into or out of the whole allocation, we - // can handle it like a load or store of the scalar type. - if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(User)) { - assert(Offset == 0 && "must be store to start of alloca"); - assert(!NonConstantIdx && "Cannot replace dynamic transfer with insert"); - - // If the source and destination are both to the same alloca, then this is - // a noop copy-to-self, just delete it. Otherwise, emit a load and store - // as appropriate. - AllocaInst *OrigAI = cast<AllocaInst>(GetUnderlyingObject(Ptr, DL, 0)); - - if (GetUnderlyingObject(MTI->getSource(), DL, 0) != OrigAI) { - // Dest must be OrigAI, change this to be a load from the original - // pointer (bitcasted), then a store to our new alloca. - assert(MTI->getRawDest() == Ptr && "Neither use is of pointer?"); - Value *SrcPtr = MTI->getSource(); - PointerType* SPTy = cast<PointerType>(SrcPtr->getType()); - PointerType* AIPTy = cast<PointerType>(NewAI->getType()); - if (SPTy->getAddressSpace() != AIPTy->getAddressSpace()) { - AIPTy = PointerType::get(AIPTy->getElementType(), - SPTy->getAddressSpace()); - } - SrcPtr = Builder.CreateBitCast(SrcPtr, AIPTy); - - LoadInst *SrcVal = Builder.CreateLoad(SrcPtr, "srcval"); - SrcVal->setAlignment(MTI->getAlignment()); - Builder.CreateStore(SrcVal, NewAI); - } else if (GetUnderlyingObject(MTI->getDest(), DL, 0) != OrigAI) { - // Src must be OrigAI, change this to be a load from NewAI then a store - // through the original dest pointer (bitcasted). - assert(MTI->getRawSource() == Ptr && "Neither use is of pointer?"); - LoadInst *SrcVal = Builder.CreateLoad(NewAI, "srcval"); - - PointerType* DPTy = cast<PointerType>(MTI->getDest()->getType()); - PointerType* AIPTy = cast<PointerType>(NewAI->getType()); - if (DPTy->getAddressSpace() != AIPTy->getAddressSpace()) { - AIPTy = PointerType::get(AIPTy->getElementType(), - DPTy->getAddressSpace()); - } - Value *DstPtr = Builder.CreateBitCast(MTI->getDest(), AIPTy); - - StoreInst *NewStore = Builder.CreateStore(SrcVal, DstPtr); - NewStore->setAlignment(MTI->getAlignment()); - } else { - // Noop transfer. Src == Dst - } - - MTI->eraseFromParent(); - continue; - } - - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(User)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) { - // There's no need to preserve these, as the resulting alloca will be - // converted to a register anyways. - II->eraseFromParent(); - continue; - } - } - - llvm_unreachable("Unsupported operation!"); - } -} - -/// ConvertScalar_ExtractValue - Extract a value of type ToType from an integer -/// or vector value FromVal, extracting the bits from the offset specified by -/// Offset. This returns the value, which is of type ToType. -/// -/// This happens when we are converting an "integer union" to a single -/// integer scalar, or when we are converting a "vector union" to a vector with -/// insert/extractelement instructions. 
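The memset rewrite above replicates the stored byte across the whole scalar with a shift-and-or loop before inserting it. The same arithmetic on a plain 64-bit integer, as an illustrative helper rather than the APInt code:

#include <cassert>
#include <cstdint>

static uint64_t splatByte(uint8_t Val, unsigned NumBytes) {
  assert(NumBytes >= 1 && NumBytes <= 8);
  uint64_t V = Val;
  for (unsigned i = 1; i != NumBytes; ++i)
    V |= V << 8;                 // each step widens the replicated prefix
  return V;                      // splatByte(0xAB, 4) == 0xABABABAB
}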
-/// -/// Offset is an offset from the original alloca, in bits that need to be -/// shifted to the right. -Value *ConvertToScalarInfo:: -ConvertScalar_ExtractValue(Value *FromVal, Type *ToType, - uint64_t Offset, Value* NonConstantIdx, - IRBuilder<> &Builder) { - // If the load is of the whole new alloca, no conversion is needed. - Type *FromType = FromVal->getType(); - if (FromType == ToType && Offset == 0) - return FromVal; - - // If the result alloca is a vector type, this is either an element - // access or a bitcast to another vector type of the same size. - if (VectorType *VTy = dyn_cast<VectorType>(FromType)) { - unsigned FromTypeSize = DL.getTypeAllocSize(FromType); - unsigned ToTypeSize = DL.getTypeAllocSize(ToType); - if (FromTypeSize == ToTypeSize) - return Builder.CreateBitCast(FromVal, ToType); - - // Otherwise it must be an element access. - unsigned Elt = 0; - if (Offset) { - unsigned EltSize = DL.getTypeAllocSizeInBits(VTy->getElementType()); - Elt = Offset/EltSize; - assert(EltSize*Elt == Offset && "Invalid modulus in validity checking"); - } - // Return the element extracted out of it. - Value *Idx; - if (NonConstantIdx) { - if (Elt) - Idx = Builder.CreateAdd(NonConstantIdx, - Builder.getInt32(Elt), - "dyn.offset"); - else - Idx = NonConstantIdx; - } else - Idx = Builder.getInt32(Elt); - Value *V = Builder.CreateExtractElement(FromVal, Idx); - if (V->getType() != ToType) - V = Builder.CreateBitCast(V, ToType); - return V; - } - - // If ToType is a first class aggregate, extract out each of the pieces and - // use insertvalue's to form the FCA. - if (StructType *ST = dyn_cast<StructType>(ToType)) { - assert(!NonConstantIdx && - "Dynamic indexing into struct types not supported"); - const StructLayout &Layout = *DL.getStructLayout(ST); - Value *Res = UndefValue::get(ST); - for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { - Value *Elt = ConvertScalar_ExtractValue(FromVal, ST->getElementType(i), - Offset+Layout.getElementOffsetInBits(i), - nullptr, Builder); - Res = Builder.CreateInsertValue(Res, Elt, i); - } - return Res; - } - - if (ArrayType *AT = dyn_cast<ArrayType>(ToType)) { - assert(!NonConstantIdx && - "Dynamic indexing into array types not supported"); - uint64_t EltSize = DL.getTypeAllocSizeInBits(AT->getElementType()); - Value *Res = UndefValue::get(AT); - for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) { - Value *Elt = ConvertScalar_ExtractValue(FromVal, AT->getElementType(), - Offset+i*EltSize, nullptr, - Builder); - Res = Builder.CreateInsertValue(Res, Elt, i); - } - return Res; - } - - // Otherwise, this must be a union that was converted to an integer value. - IntegerType *NTy = cast<IntegerType>(FromVal->getType()); - - // If this is a big-endian system and the load is narrower than the - // full alloca type, we need to do a shift to get the right bits. - int ShAmt = 0; - if (DL.isBigEndian()) { - // On big-endian machines, the lowest bit is stored at the bit offset - // from the pointer given by getTypeStoreSizeInBits. This matters for - // integers with a bitwidth that is not a multiple of 8. - ShAmt = DL.getTypeStoreSizeInBits(NTy) - - DL.getTypeStoreSizeInBits(ToType) - Offset; - } else { - ShAmt = Offset; - } - - // Note: we support negative bitwidths (with shl) which are not defined. - // We do this to support (f.e.) loads off the end of a structure where - // only some bits are used. 
- if (ShAmt > 0 && (unsigned)ShAmt < NTy->getBitWidth()) - FromVal = Builder.CreateLShr(FromVal, - ConstantInt::get(FromVal->getType(), ShAmt)); - else if (ShAmt < 0 && (unsigned)-ShAmt < NTy->getBitWidth()) - FromVal = Builder.CreateShl(FromVal, - ConstantInt::get(FromVal->getType(), -ShAmt)); - - // Finally, unconditionally truncate the integer to the right width. - unsigned LIBitWidth = DL.getTypeSizeInBits(ToType); - if (LIBitWidth < NTy->getBitWidth()) - FromVal = - Builder.CreateTrunc(FromVal, IntegerType::get(FromVal->getContext(), - LIBitWidth)); - else if (LIBitWidth > NTy->getBitWidth()) - FromVal = - Builder.CreateZExt(FromVal, IntegerType::get(FromVal->getContext(), - LIBitWidth)); - - // If the result is an integer, this is a trunc or bitcast. - if (ToType->isIntegerTy()) { - // Should be done. - } else if (ToType->isFloatingPointTy() || ToType->isVectorTy()) { - // Just do a bitcast, we know the sizes match up. - FromVal = Builder.CreateBitCast(FromVal, ToType); - } else { - // Otherwise must be a pointer. - FromVal = Builder.CreateIntToPtr(FromVal, ToType); - } - assert(FromVal->getType() == ToType && "Didn't convert right?"); - return FromVal; -} - -/// ConvertScalar_InsertValue - Insert the value "SV" into the existing integer -/// or vector value "Old" at the offset specified by Offset. -/// -/// This happens when we are converting an "integer union" to a -/// single integer scalar, or when we are converting a "vector union" to a -/// vector with insert/extractelement instructions. -/// -/// Offset is an offset from the original alloca, in bits that need to be -/// shifted to the right. -/// -/// NonConstantIdx is an index value if there was a GEP with a non-constant -/// index value. If this is 0 then all GEPs used to find this insert address -/// are constant. -Value *ConvertToScalarInfo:: -ConvertScalar_InsertValue(Value *SV, Value *Old, - uint64_t Offset, Value* NonConstantIdx, - IRBuilder<> &Builder) { - // Convert the stored type to the actual type, shift it left to insert - // then 'or' into place. - Type *AllocaType = Old->getType(); - LLVMContext &Context = Old->getContext(); - - if (VectorType *VTy = dyn_cast<VectorType>(AllocaType)) { - uint64_t VecSize = DL.getTypeAllocSizeInBits(VTy); - uint64_t ValSize = DL.getTypeAllocSizeInBits(SV->getType()); - - // Changing the whole vector with memset or with an access of a different - // vector type? - if (ValSize == VecSize) - return Builder.CreateBitCast(SV, AllocaType); - - // Must be an element insertion. - Type *EltTy = VTy->getElementType(); - if (SV->getType() != EltTy) - SV = Builder.CreateBitCast(SV, EltTy); - uint64_t EltSize = DL.getTypeAllocSizeInBits(EltTy); - unsigned Elt = Offset/EltSize; - Value *Idx; - if (NonConstantIdx) { - if (Elt) - Idx = Builder.CreateAdd(NonConstantIdx, - Builder.getInt32(Elt), - "dyn.offset"); - else - Idx = NonConstantIdx; - } else - Idx = Builder.getInt32(Elt); - return Builder.CreateInsertElement(Old, SV, Idx); - } - - // If SV is a first-class aggregate value, insert each value recursively. 
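ConvertScalar_ExtractValue and ConvertScalar_InsertValue above treat the alloca as a bag of bits: extraction is a right shift plus truncation, and insertion is a mask-and-or at the same shift amount, with the shift direction derived from getTypeStoreSizeInBits on big-endian targets. A standalone little-endian sketch of both directions on a 64-bit word (names are illustrative):

#include <cassert>
#include <cstdint>

static uint64_t extractBits(uint64_t Whole, unsigned BitOffset, unsigned Width) {
  assert(Width > 0 && BitOffset + Width <= 64);
  uint64_t Mask = Width == 64 ? ~0ULL : ((1ULL << Width) - 1);
  return (Whole >> BitOffset) & Mask;                 // shift right, then truncate
}

static uint64_t insertBits(uint64_t Old, uint64_t Val,
                           unsigned BitOffset, unsigned Width) {
  assert(Width > 0 && BitOffset + Width <= 64);
  uint64_t Mask = (Width == 64 ? ~0ULL : ((1ULL << Width) - 1)) << BitOffset;
  return (Old & ~Mask) | ((Val << BitOffset) & Mask); // clear the slot, or in new bits
}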
- if (StructType *ST = dyn_cast<StructType>(SV->getType())) { - assert(!NonConstantIdx && - "Dynamic indexing into struct types not supported"); - const StructLayout &Layout = *DL.getStructLayout(ST); - for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { - Value *Elt = Builder.CreateExtractValue(SV, i); - Old = ConvertScalar_InsertValue(Elt, Old, - Offset+Layout.getElementOffsetInBits(i), - nullptr, Builder); - } - return Old; - } - - if (ArrayType *AT = dyn_cast<ArrayType>(SV->getType())) { - assert(!NonConstantIdx && - "Dynamic indexing into array types not supported"); - uint64_t EltSize = DL.getTypeAllocSizeInBits(AT->getElementType()); - for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) { - Value *Elt = Builder.CreateExtractValue(SV, i); - Old = ConvertScalar_InsertValue(Elt, Old, Offset+i*EltSize, nullptr, - Builder); - } - return Old; - } - - // If SV is a float, convert it to the appropriate integer type. - // If it is a pointer, do the same. - unsigned SrcWidth = DL.getTypeSizeInBits(SV->getType()); - unsigned DestWidth = DL.getTypeSizeInBits(AllocaType); - unsigned SrcStoreWidth = DL.getTypeStoreSizeInBits(SV->getType()); - unsigned DestStoreWidth = DL.getTypeStoreSizeInBits(AllocaType); - if (SV->getType()->isFloatingPointTy() || SV->getType()->isVectorTy()) - SV = Builder.CreateBitCast(SV, IntegerType::get(SV->getContext(),SrcWidth)); - else if (SV->getType()->isPointerTy()) - SV = Builder.CreatePtrToInt(SV, DL.getIntPtrType(SV->getType())); - - // Zero extend or truncate the value if needed. - if (SV->getType() != AllocaType) { - if (SV->getType()->getPrimitiveSizeInBits() < - AllocaType->getPrimitiveSizeInBits()) - SV = Builder.CreateZExt(SV, AllocaType); - else { - // Truncation may be needed if storing more than the alloca can hold - // (undefined behavior). - SV = Builder.CreateTrunc(SV, AllocaType); - SrcWidth = DestWidth; - SrcStoreWidth = DestStoreWidth; - } - } - - // If this is a big-endian system and the store is narrower than the - // full alloca type, we need to do a shift to get the right bits. - int ShAmt = 0; - if (DL.isBigEndian()) { - // On big-endian machines, the lowest bit is stored at the bit offset - // from the pointer given by getTypeStoreSizeInBits. This matters for - // integers with a bitwidth that is not a multiple of 8. - ShAmt = DestStoreWidth - SrcStoreWidth - Offset; - } else { - ShAmt = Offset; - } - - // Note: we support negative bitwidths (with shr) which are not defined. - // We do this to support (f.e.) stores off the end of a structure where - // only some bits in the structure are set. - APInt Mask(APInt::getLowBitsSet(DestWidth, SrcWidth)); - if (ShAmt > 0 && (unsigned)ShAmt < DestWidth) { - SV = Builder.CreateShl(SV, ConstantInt::get(SV->getType(), ShAmt)); - Mask <<= ShAmt; - } else if (ShAmt < 0 && (unsigned)-ShAmt < DestWidth) { - SV = Builder.CreateLShr(SV, ConstantInt::get(SV->getType(), -ShAmt)); - Mask = Mask.lshr(-ShAmt); - } - - // Mask out the bits we are about to insert from the old value, and or - // in the new bits. 
- if (SrcWidth != DestWidth) { - assert(DestWidth > SrcWidth); - Old = Builder.CreateAnd(Old, ConstantInt::get(Context, ~Mask), "mask"); - SV = Builder.CreateOr(Old, SV, "ins"); - } - return SV; -} - - -//===----------------------------------------------------------------------===// -// SRoA Driver -//===----------------------------------------------------------------------===// - - -bool SROA::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) - return false; - - bool Changed = performPromotion(F); - - while (1) { - bool LocalChange = performScalarRepl(F); - if (!LocalChange) break; // No need to repromote if no scalarrepl - Changed = true; - LocalChange = performPromotion(F); - if (!LocalChange) break; // No need to re-scalarrepl if no promotion - } - - return Changed; -} - -namespace { -class AllocaPromoter : public LoadAndStorePromoter { - AllocaInst *AI; - DIBuilder *DIB; - SmallVector<DbgDeclareInst *, 4> DDIs; - SmallVector<DbgValueInst *, 4> DVIs; -public: - AllocaPromoter(ArrayRef<Instruction*> Insts, SSAUpdater &S, - DIBuilder *DB) - : LoadAndStorePromoter(Insts, S), AI(nullptr), DIB(DB) {} - - void run(AllocaInst *AI, const SmallVectorImpl<Instruction*> &Insts) { - // Remember which alloca we're promoting (for isInstInList). - this->AI = AI; - if (auto *L = LocalAsMetadata::getIfExists(AI)) { - if (auto *DINode = MetadataAsValue::getIfExists(AI->getContext(), L)) { - for (User *U : DINode->users()) - if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) - DDIs.push_back(DDI); - else if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(U)) - DVIs.push_back(DVI); - } - } - - LoadAndStorePromoter::run(Insts); - AI->eraseFromParent(); - for (SmallVectorImpl<DbgDeclareInst *>::iterator I = DDIs.begin(), - E = DDIs.end(); I != E; ++I) { - DbgDeclareInst *DDI = *I; - DDI->eraseFromParent(); - } - for (SmallVectorImpl<DbgValueInst *>::iterator I = DVIs.begin(), - E = DVIs.end(); I != E; ++I) { - DbgValueInst *DVI = *I; - DVI->eraseFromParent(); - } - } - - bool isInstInList(Instruction *I, - const SmallVectorImpl<Instruction*> &Insts) const override { - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - return LI->getOperand(0) == AI; - return cast<StoreInst>(I)->getPointerOperand() == AI; - } - - void updateDebugInfo(Instruction *Inst) const override { - for (SmallVectorImpl<DbgDeclareInst *>::const_iterator I = DDIs.begin(), - E = DDIs.end(); I != E; ++I) { - DbgDeclareInst *DDI = *I; - if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) - ConvertDebugDeclareToDebugValue(DDI, SI, *DIB); - else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) - ConvertDebugDeclareToDebugValue(DDI, LI, *DIB); - } - for (SmallVectorImpl<DbgValueInst *>::const_iterator I = DVIs.begin(), - E = DVIs.end(); I != E; ++I) { - DbgValueInst *DVI = *I; - Value *Arg = nullptr; - if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - // If an argument is zero extended then use argument directly. The ZExt - // may be zapped by an optimization pass in future. 
- if (ZExtInst *ZExt = dyn_cast<ZExtInst>(SI->getOperand(0))) - Arg = dyn_cast<Argument>(ZExt->getOperand(0)); - if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0))) - Arg = dyn_cast<Argument>(SExt->getOperand(0)); - if (!Arg) - Arg = SI->getOperand(0); - } else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { - Arg = LI->getOperand(0); - } else { - continue; - } - DIB->insertDbgValueIntrinsic(Arg, 0, DVI->getVariable(), - DVI->getExpression(), DVI->getDebugLoc(), - Inst); - } - } -}; -} // end anon namespace - -/// isSafeSelectToSpeculate - Select instructions that use an alloca and are -/// subsequently loaded can be rewritten to load both input pointers and then -/// select between the result, allowing the load of the alloca to be promoted. -/// From this: -/// %P2 = select i1 %cond, i32* %Alloca, i32* %Other -/// %V = load i32* %P2 -/// to: -/// %V1 = load i32* %Alloca -> will be mem2reg'd -/// %V2 = load i32* %Other -/// %V = select i1 %cond, i32 %V1, i32 %V2 -/// -/// We can do this to a select if its only uses are loads and if the operand to -/// the select can be loaded unconditionally. -static bool isSafeSelectToSpeculate(SelectInst *SI) { - const DataLayout &DL = SI->getModule()->getDataLayout(); - bool TDerefable = isDereferenceablePointer(SI->getTrueValue(), DL); - bool FDerefable = isDereferenceablePointer(SI->getFalseValue(), DL); - - for (User *U : SI->users()) { - LoadInst *LI = dyn_cast<LoadInst>(U); - if (!LI || !LI->isSimple()) return false; - - // Both operands to the select need to be dereferencable, either absolutely - // (e.g. allocas) or at this point because we can see other accesses to it. - if (!TDerefable && - !isSafeToLoadUnconditionally(SI->getTrueValue(), LI, - LI->getAlignment())) - return false; - if (!FDerefable && - !isSafeToLoadUnconditionally(SI->getFalseValue(), LI, - LI->getAlignment())) - return false; - } - - return true; -} - -/// isSafePHIToSpeculate - PHI instructions that use an alloca and are -/// subsequently loaded can be rewritten to load both input pointers in the pred -/// blocks and then PHI the results, allowing the load of the alloca to be -/// promoted. -/// From this: -/// %P2 = phi [i32* %Alloca, i32* %Other] -/// %V = load i32* %P2 -/// to: -/// %V1 = load i32* %Alloca -> will be mem2reg'd -/// ... -/// %V2 = load i32* %Other -/// ... -/// %V = phi [i32 %V1, i32 %V2] -/// -/// We can do this to a select if its only uses are loads and if the operand to -/// the select can be loaded unconditionally. -static bool isSafePHIToSpeculate(PHINode *PN) { - // For now, we can only do this promotion if the load is in the same block as - // the PHI, and if there are no stores between the phi and load. - // TODO: Allow recursive phi users. - // TODO: Allow stores. - BasicBlock *BB = PN->getParent(); - unsigned MaxAlign = 0; - for (User *U : PN->users()) { - LoadInst *LI = dyn_cast<LoadInst>(U); - if (!LI || !LI->isSimple()) return false; - - // For now we only allow loads in the same block as the PHI. This is a - // common case that happens when instcombine merges two loads through a PHI. - if (LI->getParent() != BB) return false; - - // Ensure that there are no instructions between the PHI and the load that - // could store. - for (BasicBlock::iterator BBI(PN); &*BBI != LI; ++BBI) - if (BBI->mayWriteToMemory()) - return false; - - MaxAlign = std::max(MaxAlign, LI->getAlignment()); - } - - const DataLayout &DL = PN->getModule()->getDataLayout(); - - // Okay, we know that we have one or more loads in the same block as the PHI. 
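isSafeSelectToSpeculate above guards the rewrite of a load through a select of pointers into a select of two loads; the cost is that both loads become unconditional, which is why both operands must be known dereferenceable or otherwise safe to load. The same rewrite sketched at the C++ source level (the PHI case discussed next follows the same pattern across basic blocks):

static int loadOfSelect(bool Cond, int *Alloca, int *Other) {
  int *P = Cond ? Alloca : Other;   // select of pointers
  return *P;                        // this load keeps *Alloca's address live
}

static int selectOfLoads(bool Cond, int *Alloca, int *Other) {
  int V1 = *Alloca;                 // now promotable by mem2reg
  int V2 = *Other;                  // executes even when Cond is true,
  return Cond ? V1 : V2;            // hence the dereferenceability check
}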
- // We can transform this if it is safe to push the loads into the predecessor - // blocks. The only thing to watch out for is that we can't put a possibly - // trapping load in the predecessor if it is a critical edge. - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { - BasicBlock *Pred = PN->getIncomingBlock(i); - Value *InVal = PN->getIncomingValue(i); - - // If the terminator of the predecessor has side-effects (an invoke), - // there is no safe place to put a load in the predecessor. - if (Pred->getTerminator()->mayHaveSideEffects()) - return false; - - // If the value is produced by the terminator of the predecessor - // (an invoke), there is no valid place to put a load in the predecessor. - if (Pred->getTerminator() == InVal) - return false; - - // If the predecessor has a single successor, then the edge isn't critical. - if (Pred->getTerminator()->getNumSuccessors() == 1) - continue; - - // If this pointer is always safe to load, or if we can prove that there is - // already a load in the block, then we can move the load to the pred block. - if (isDereferenceablePointer(InVal, DL) || - isSafeToLoadUnconditionally(InVal, Pred->getTerminator(), MaxAlign)) - continue; - - return false; - } - - return true; -} - - -/// tryToMakeAllocaBePromotable - This returns true if the alloca only has -/// direct (non-volatile) loads and stores to it. If the alloca is close but -/// not quite there, this will transform the code to allow promotion. As such, -/// it is a non-pure predicate. -static bool tryToMakeAllocaBePromotable(AllocaInst *AI, const DataLayout &DL) { - SetVector<Instruction*, SmallVector<Instruction*, 4>, - SmallPtrSet<Instruction*, 4> > InstsToRewrite; - for (User *U : AI->users()) { - if (LoadInst *LI = dyn_cast<LoadInst>(U)) { - if (!LI->isSimple()) - return false; - continue; - } - - if (StoreInst *SI = dyn_cast<StoreInst>(U)) { - if (SI->getOperand(0) == AI || !SI->isSimple()) - return false; // Don't allow a store OF the AI, only INTO the AI. - continue; - } - - if (SelectInst *SI = dyn_cast<SelectInst>(U)) { - // If the condition being selected on is a constant, fold the select, yes - // this does (rarely) happen early on. - if (ConstantInt *CI = dyn_cast<ConstantInt>(SI->getCondition())) { - Value *Result = SI->getOperand(1+CI->isZero()); - SI->replaceAllUsesWith(Result); - SI->eraseFromParent(); - - // This is very rare and we just scrambled the use list of AI, start - // over completely. - return tryToMakeAllocaBePromotable(AI, DL); - } - - // If it is safe to turn "load (select c, AI, ptr)" into a select of two - // loads, then we can transform this by rewriting the select. - if (!isSafeSelectToSpeculate(SI)) - return false; - - InstsToRewrite.insert(SI); - continue; - } - - if (PHINode *PN = dyn_cast<PHINode>(U)) { - if (PN->use_empty()) { // Dead PHIs can be stripped. - InstsToRewrite.insert(PN); - continue; - } - - // If it is safe to turn "load (phi [AI, ptr, ...])" into a PHI of loads - // in the pred blocks, then we can transform this by rewriting the PHI. - if (!isSafePHIToSpeculate(PN)) - return false; - - InstsToRewrite.insert(PN); - continue; - } - - if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) { - if (onlyUsedByLifetimeMarkers(BCI)) { - InstsToRewrite.insert(BCI); - continue; - } - } - - return false; - } - - // If there are no instructions to rewrite, then all uses are load/stores and - // we're done! 
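The PHI case handled above is the block-level analogue: instead of loading through a merged pointer, a load is placed in each predecessor and the loaded values are merged, subject to the isSafePHIToSpeculate checks (no intervening stores, no trapping load hoisted onto a critical edge). A simplified source-level picture; the real rewrite, operating on the IR, follows below:

static int loadOfPhi(bool Cond, int *Alloca, int *Other) {
  int *P;
  if (Cond) P = Alloca;             // "phi" of incoming pointers
  else      P = Other;
  return *P;                        // single load after the merge
}

static int phiOfLoads(bool Cond, int *Alloca, int *Other) {
  int V;
  if (Cond) V = *Alloca;            // load moved into each predecessor;
  else      V = *Other;             // *Alloca no longer has its address taken
  return V;                         // "phi" of the loaded values
}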
- if (InstsToRewrite.empty()) - return true; - - // If we have instructions that need to be rewritten for this to be promotable - // take care of it now. - for (unsigned i = 0, e = InstsToRewrite.size(); i != e; ++i) { - if (BitCastInst *BCI = dyn_cast<BitCastInst>(InstsToRewrite[i])) { - // This could only be a bitcast used by nothing but lifetime intrinsics. - for (BitCastInst::user_iterator I = BCI->user_begin(), E = BCI->user_end(); - I != E;) - cast<Instruction>(*I++)->eraseFromParent(); - BCI->eraseFromParent(); - continue; - } - - if (SelectInst *SI = dyn_cast<SelectInst>(InstsToRewrite[i])) { - // Selects in InstsToRewrite only have load uses. Rewrite each as two - // loads with a new select. - while (!SI->use_empty()) { - LoadInst *LI = cast<LoadInst>(SI->user_back()); - - IRBuilder<> Builder(LI); - LoadInst *TrueLoad = - Builder.CreateLoad(SI->getTrueValue(), LI->getName()+".t"); - LoadInst *FalseLoad = - Builder.CreateLoad(SI->getFalseValue(), LI->getName()+".f"); - - // Transfer alignment and AA info if present. - TrueLoad->setAlignment(LI->getAlignment()); - FalseLoad->setAlignment(LI->getAlignment()); - - AAMDNodes Tags; - LI->getAAMetadata(Tags); - if (Tags) { - TrueLoad->setAAMetadata(Tags); - FalseLoad->setAAMetadata(Tags); - } - - Value *V = Builder.CreateSelect(SI->getCondition(), TrueLoad, FalseLoad); - V->takeName(LI); - LI->replaceAllUsesWith(V); - LI->eraseFromParent(); - } - - // Now that all the loads are gone, the select is gone too. - SI->eraseFromParent(); - continue; - } - - // Otherwise, we have a PHI node which allows us to push the loads into the - // predecessors. - PHINode *PN = cast<PHINode>(InstsToRewrite[i]); - if (PN->use_empty()) { - PN->eraseFromParent(); - continue; - } - - Type *LoadTy = cast<PointerType>(PN->getType())->getElementType(); - PHINode *NewPN = PHINode::Create(LoadTy, PN->getNumIncomingValues(), - PN->getName()+".ld", PN); - - // Get the AA tags and alignment to use from one of the loads. It doesn't - // matter which one we get and if any differ, it doesn't matter. - LoadInst *SomeLoad = cast<LoadInst>(PN->user_back()); - - AAMDNodes AATags; - SomeLoad->getAAMetadata(AATags); - unsigned Align = SomeLoad->getAlignment(); - - // Rewrite all loads of the PN to use the new PHI. - while (!PN->use_empty()) { - LoadInst *LI = cast<LoadInst>(PN->user_back()); - LI->replaceAllUsesWith(NewPN); - LI->eraseFromParent(); - } - - // Inject loads into all of the pred blocks. Keep track of which blocks we - // insert them into in case we have multiple edges from the same block. - DenseMap<BasicBlock*, LoadInst*> InsertedLoads; - - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { - BasicBlock *Pred = PN->getIncomingBlock(i); - LoadInst *&Load = InsertedLoads[Pred]; - if (!Load) { - Load = new LoadInst(PN->getIncomingValue(i), - PN->getName() + "." 
+ Pred->getName(), - Pred->getTerminator()); - Load->setAlignment(Align); - if (AATags) Load->setAAMetadata(AATags); - } - - NewPN->addIncoming(Load, Pred); - } - - PN->eraseFromParent(); - } - - ++NumAdjusted; - return true; -} - -bool SROA::performPromotion(Function &F) { - std::vector<AllocaInst*> Allocas; - const DataLayout &DL = F.getParent()->getDataLayout(); - DominatorTree *DT = nullptr; - if (HasDomTree) - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - AssumptionCache &AC = - getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - - BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function - DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false); - bool Changed = false; - SmallVector<Instruction*, 64> Insts; - while (1) { - Allocas.clear(); - - // Find allocas that are safe to promote, by looking at all instructions in - // the entry node - for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I) - if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) // Is it an alloca? - if (tryToMakeAllocaBePromotable(AI, DL)) - Allocas.push_back(AI); - - if (Allocas.empty()) break; - - if (HasDomTree) - PromoteMemToReg(Allocas, *DT, nullptr, &AC); - else { - SSAUpdater SSA; - for (unsigned i = 0, e = Allocas.size(); i != e; ++i) { - AllocaInst *AI = Allocas[i]; - - // Build list of instructions to promote. - for (User *U : AI->users()) - Insts.push_back(cast<Instruction>(U)); - AllocaPromoter(Insts, SSA, &DIB).run(AI, Insts); - Insts.clear(); - } - } - NumPromoted += Allocas.size(); - Changed = true; - } - - return Changed; -} - - -/// ShouldAttemptScalarRepl - Decide if an alloca is a good candidate for -/// SROA. It must be a struct or array type with a small number of elements. -bool SROA::ShouldAttemptScalarRepl(AllocaInst *AI) { - Type *T = AI->getAllocatedType(); - // Do not promote any struct that has too many members. - if (StructType *ST = dyn_cast<StructType>(T)) - return ST->getNumElements() <= StructMemberThreshold; - // Do not promote any array that has too many elements. - if (ArrayType *AT = dyn_cast<ArrayType>(T)) - return AT->getNumElements() <= ArrayElementThreshold; - return false; -} - -// performScalarRepl - This algorithm is a simple worklist driven algorithm, -// which runs on all of the alloca instructions in the entry block, removing -// them if they are only used by getelementptr instructions. -// -bool SROA::performScalarRepl(Function &F) { - std::vector<AllocaInst*> WorkList; - const DataLayout &DL = F.getParent()->getDataLayout(); - - // Scan the entry basic block, adding allocas to the worklist. - BasicBlock &BB = F.getEntryBlock(); - for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) - if (AllocaInst *A = dyn_cast<AllocaInst>(I)) - WorkList.push_back(A); - - // Process the worklist - bool Changed = false; - while (!WorkList.empty()) { - AllocaInst *AI = WorkList.back(); - WorkList.pop_back(); - - // Handle dead allocas trivially. These can be formed by SROA'ing arrays - // with unused elements. - if (AI->use_empty()) { - AI->eraseFromParent(); - Changed = true; - continue; - } - - // If this alloca is impossible for us to promote, reject it early. - if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized()) - continue; - - // Check to see if we can perform the core SROA transformation. We cannot - // transform the allocation instruction if it is an array allocation - // (allocations OF arrays are ok though), and an allocation of a scalar - // value cannot be decomposed at all. 
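ShouldAttemptScalarRepl above limits the old pass to small, fixed-shape aggregates. A rough source-level picture of what qualifies (thresholds not spelled out here, illustrative only):

    struct Pair { int a; int b; };

    int good_candidate(int x) {
      Pair p;            // small struct alloca: within the member threshold
      p.a = x;
      p.b = x + 1;
      return p.a + p.b;  // after replacement, p.a and p.b live in registers
    }

    int poor_candidate(int n) {
      int big[1024];     // more elements than the pass is willing to split
      big[0] = n;
      return big[0];
    }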
- uint64_t AllocaSize = DL.getTypeAllocSize(AI->getAllocatedType()); - - // Do not promote [0 x %struct]. - if (AllocaSize == 0) continue; - - // Do not promote any struct whose size is too big. - if (AllocaSize > SRThreshold) continue; - - // If the alloca looks like a good candidate for scalar replacement, and if - // all its users can be transformed, then split up the aggregate into its - // separate elements. - if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) { - DoScalarReplacement(AI, WorkList); - Changed = true; - continue; - } - - // If we can turn this aggregate value (potentially with casts) into a - // simple scalar value that can be mem2reg'd into a register value. - // IsNotTrivial tracks whether this is something that mem2reg could have - // promoted itself. If so, we don't want to transform it needlessly. Note - // that we can't just check based on the type: the alloca may be of an i32 - // but that has pointer arithmetic to set byte 3 of it or something. - if (AllocaInst *NewAI = - ConvertToScalarInfo((unsigned)AllocaSize, DL, ScalarLoadThreshold) - .TryConvert(AI)) { - NewAI->takeName(AI); - AI->eraseFromParent(); - ++NumConverted; - Changed = true; - continue; - } - - // Otherwise, couldn't process this alloca. - } - - return Changed; -} - -/// DoScalarReplacement - This alloca satisfied the isSafeAllocaToScalarRepl -/// predicate, do SROA now. -void SROA::DoScalarReplacement(AllocaInst *AI, - std::vector<AllocaInst*> &WorkList) { - DEBUG(dbgs() << "Found inst to SROA: " << *AI << '\n'); - SmallVector<AllocaInst*, 32> ElementAllocas; - if (StructType *ST = dyn_cast<StructType>(AI->getAllocatedType())) { - ElementAllocas.reserve(ST->getNumContainedTypes()); - for (unsigned i = 0, e = ST->getNumContainedTypes(); i != e; ++i) { - AllocaInst *NA = new AllocaInst(ST->getContainedType(i), nullptr, - AI->getAlignment(), - AI->getName() + "." + Twine(i), AI); - ElementAllocas.push_back(NA); - WorkList.push_back(NA); // Add to worklist for recursive processing - } - } else { - ArrayType *AT = cast<ArrayType>(AI->getAllocatedType()); - ElementAllocas.reserve(AT->getNumElements()); - Type *ElTy = AT->getElementType(); - for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) { - AllocaInst *NA = new AllocaInst(ElTy, nullptr, AI->getAlignment(), - AI->getName() + "." + Twine(i), AI); - ElementAllocas.push_back(NA); - WorkList.push_back(NA); // Add to worklist for recursive processing - } - } - - // Now that we have created the new alloca instructions, rewrite all the - // uses of the old alloca. - RewriteForScalarRepl(AI, AI, 0, ElementAllocas); - - // Now erase any instructions that were made dead while rewriting the alloca. - DeleteDeadInstructions(); - AI->eraseFromParent(); - - ++NumReplaced; -} - -/// DeleteDeadInstructions - Erase instructions on the DeadInstrs list, -/// recursively including all their operands that become trivially dead. -void SROA::DeleteDeadInstructions() { - while (!DeadInsts.empty()) { - Instruction *I = cast<Instruction>(DeadInsts.pop_back_val()); - - for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI) - if (Instruction *U = dyn_cast<Instruction>(*OI)) { - // Zero out the operand and see if it becomes trivially dead. - // (But, don't add allocas to the dead instruction list -- they are - // already on the worklist and will be deleted separately.) 
- *OI = nullptr; - if (isInstructionTriviallyDead(U) && !isa<AllocaInst>(U)) - DeadInsts.push_back(U); - } - - I->eraseFromParent(); - } -} - -/// isSafeForScalarRepl - Check if instruction I is a safe use with regard to -/// performing scalar replacement of alloca AI. The results are flagged in -/// the Info parameter. Offset indicates the position within AI that is -/// referenced by this instruction. -void SROA::isSafeForScalarRepl(Instruction *I, uint64_t Offset, - AllocaInfo &Info) { - const DataLayout &DL = I->getModule()->getDataLayout(); - for (Use &U : I->uses()) { - Instruction *User = cast<Instruction>(U.getUser()); - - if (BitCastInst *BC = dyn_cast<BitCastInst>(User)) { - isSafeForScalarRepl(BC, Offset, Info); - } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(User)) { - uint64_t GEPOffset = Offset; - isSafeGEP(GEPI, GEPOffset, Info); - if (!Info.isUnsafe) - isSafeForScalarRepl(GEPI, GEPOffset, Info); - } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) { - ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength()); - if (!Length || Length->isNegative()) - return MarkUnsafe(Info, User); - - isSafeMemAccess(Offset, Length->getZExtValue(), nullptr, - U.getOperandNo() == 0, Info, MI, - true /*AllowWholeAccess*/); - } else if (LoadInst *LI = dyn_cast<LoadInst>(User)) { - if (!LI->isSimple()) - return MarkUnsafe(Info, User); - Type *LIType = LI->getType(); - isSafeMemAccess(Offset, DL.getTypeAllocSize(LIType), LIType, false, Info, - LI, true /*AllowWholeAccess*/); - Info.hasALoadOrStore = true; - - } else if (StoreInst *SI = dyn_cast<StoreInst>(User)) { - // Store is ok if storing INTO the pointer, not storing the pointer - if (!SI->isSimple() || SI->getOperand(0) == I) - return MarkUnsafe(Info, User); - - Type *SIType = SI->getOperand(0)->getType(); - isSafeMemAccess(Offset, DL.getTypeAllocSize(SIType), SIType, true, Info, - SI, true /*AllowWholeAccess*/); - Info.hasALoadOrStore = true; - } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(User)) { - if (II->getIntrinsicID() != Intrinsic::lifetime_start && - II->getIntrinsicID() != Intrinsic::lifetime_end) - return MarkUnsafe(Info, User); - } else if (isa<PHINode>(User) || isa<SelectInst>(User)) { - isSafePHISelectUseForScalarRepl(User, Offset, Info); - } else { - return MarkUnsafe(Info, User); - } - if (Info.isUnsafe) return; - } -} - - -/// isSafePHIUseForScalarRepl - If we see a PHI node or select using a pointer -/// derived from the alloca, we can often still split the alloca into elements. -/// This is useful if we have a large alloca where one element is phi'd -/// together somewhere: we can SRoA and promote all the other elements even if -/// we end up not being able to promote this one. -/// -/// All we require is that the uses of the PHI do not index into other parts of -/// the alloca. The most important use case for this is single load and stores -/// that are PHI'd together, which can happen due to code sinking. -void SROA::isSafePHISelectUseForScalarRepl(Instruction *I, uint64_t Offset, - AllocaInfo &Info) { - // If we've already checked this PHI, don't do it again. 
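The use-walk in isSafeForScalarRepl above allows stores into the aggregate but rejects any use that stores the pointer itself, since an escaping address defeats promotion. In source terms, roughly:

    int *escaped;          // assumed global, used only to show an escape

    int promotable(int x) {
      int a[2];
      a[0] = x;            // store INTO the alloca: fine
      a[1] = x + 1;
      return a[0] + a[1];
    }

    int not_promotable(int x) {
      int a[2];
      escaped = a;         // store OF the alloca's address: it escapes
      a[0] = x;
      return a[0];
    }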
- if (PHINode *PN = dyn_cast<PHINode>(I)) - if (!Info.CheckedPHIs.insert(PN).second) - return; - - const DataLayout &DL = I->getModule()->getDataLayout(); - for (User *U : I->users()) { - Instruction *UI = cast<Instruction>(U); - - if (BitCastInst *BC = dyn_cast<BitCastInst>(UI)) { - isSafePHISelectUseForScalarRepl(BC, Offset, Info); - } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(UI)) { - // Only allow "bitcast" GEPs for simplicity. We could generalize this, - // but would have to prove that we're staying inside of an element being - // promoted. - if (!GEPI->hasAllZeroIndices()) - return MarkUnsafe(Info, UI); - isSafePHISelectUseForScalarRepl(GEPI, Offset, Info); - } else if (LoadInst *LI = dyn_cast<LoadInst>(UI)) { - if (!LI->isSimple()) - return MarkUnsafe(Info, UI); - Type *LIType = LI->getType(); - isSafeMemAccess(Offset, DL.getTypeAllocSize(LIType), LIType, false, Info, - LI, false /*AllowWholeAccess*/); - Info.hasALoadOrStore = true; - - } else if (StoreInst *SI = dyn_cast<StoreInst>(UI)) { - // Store is ok if storing INTO the pointer, not storing the pointer - if (!SI->isSimple() || SI->getOperand(0) == I) - return MarkUnsafe(Info, UI); - - Type *SIType = SI->getOperand(0)->getType(); - isSafeMemAccess(Offset, DL.getTypeAllocSize(SIType), SIType, true, Info, - SI, false /*AllowWholeAccess*/); - Info.hasALoadOrStore = true; - } else if (isa<PHINode>(UI) || isa<SelectInst>(UI)) { - isSafePHISelectUseForScalarRepl(UI, Offset, Info); - } else { - return MarkUnsafe(Info, UI); - } - if (Info.isUnsafe) return; - } -} - -/// isSafeGEP - Check if a GEP instruction can be handled for scalar -/// replacement. It is safe when all the indices are constant, in-bounds -/// references, and when the resulting offset corresponds to an element within -/// the alloca type. The results are flagged in the Info parameter. Upon -/// return, Offset is adjusted as specified by the GEP indices. -void SROA::isSafeGEP(GetElementPtrInst *GEPI, - uint64_t &Offset, AllocaInfo &Info) { - gep_type_iterator GEPIt = gep_type_begin(GEPI), E = gep_type_end(GEPI); - if (GEPIt == E) - return; - bool NonConstant = false; - unsigned NonConstantIdxSize = 0; - - // Walk through the GEP type indices, checking the types that this indexes - // into. - for (; GEPIt != E; ++GEPIt) { - // Ignore struct elements, no extra checking needed for these. - if ((*GEPIt)->isStructTy()) - continue; - - ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand()); - if (!IdxVal) - return MarkUnsafe(Info, GEPI); - } - - // Compute the offset due to this GEP and check if the alloca has a - // component element at that offset. - SmallVector<Value*, 8> Indices(GEPI->op_begin() + 1, GEPI->op_end()); - // If this GEP is non-constant then the last operand must have been a - // dynamic index into a vector. Pop this now as it has no impact on the - // constant part of the offset. - if (NonConstant) - Indices.pop_back(); - - const DataLayout &DL = GEPI->getModule()->getDataLayout(); - Offset += DL.getIndexedOffset(GEPI->getPointerOperandType(), Indices); - if (!TypeHasComponent(Info.AI->getAllocatedType(), Offset, NonConstantIdxSize, - DL)) - MarkUnsafe(Info, GEPI); -} - -/// isHomogeneousAggregate - Check if type T is a struct or array containing -/// elements of the same type (which is always true for arrays). If so, -/// return true with NumElts and EltTy set to the number of elements and the -/// element type, respectively. 
-static bool isHomogeneousAggregate(Type *T, unsigned &NumElts, - Type *&EltTy) { - if (ArrayType *AT = dyn_cast<ArrayType>(T)) { - NumElts = AT->getNumElements(); - EltTy = (NumElts == 0 ? nullptr : AT->getElementType()); - return true; - } - if (StructType *ST = dyn_cast<StructType>(T)) { - NumElts = ST->getNumContainedTypes(); - EltTy = (NumElts == 0 ? nullptr : ST->getContainedType(0)); - for (unsigned n = 1; n < NumElts; ++n) { - if (ST->getContainedType(n) != EltTy) - return false; - } - return true; - } - return false; -} - -/// isCompatibleAggregate - Check if T1 and T2 are either the same type or are -/// "homogeneous" aggregates with the same element type and number of elements. -static bool isCompatibleAggregate(Type *T1, Type *T2) { - if (T1 == T2) - return true; - - unsigned NumElts1, NumElts2; - Type *EltTy1, *EltTy2; - if (isHomogeneousAggregate(T1, NumElts1, EltTy1) && - isHomogeneousAggregate(T2, NumElts2, EltTy2) && - NumElts1 == NumElts2 && - EltTy1 == EltTy2) - return true; - - return false; -} - -/// isSafeMemAccess - Check if a load/store/memcpy operates on the entire AI -/// alloca or has an offset and size that corresponds to a component element -/// within it. The offset checked here may have been formed from a GEP with a -/// pointer bitcasted to a different type. -/// -/// If AllowWholeAccess is true, then this allows uses of the entire alloca as a -/// unit. If false, it only allows accesses known to be in a single element. -void SROA::isSafeMemAccess(uint64_t Offset, uint64_t MemSize, - Type *MemOpType, bool isStore, - AllocaInfo &Info, Instruction *TheAccess, - bool AllowWholeAccess) { - const DataLayout &DL = TheAccess->getModule()->getDataLayout(); - // Check if this is a load/store of the entire alloca. - if (Offset == 0 && AllowWholeAccess && - MemSize == DL.getTypeAllocSize(Info.AI->getAllocatedType())) { - // This can be safe for MemIntrinsics (where MemOpType is 0) and integer - // loads/stores (which are essentially the same as the MemIntrinsics with - // regard to copying padding between elements). But, if an alloca is - // flagged as both a source and destination of such operations, we'll need - // to check later for padding between elements. - if (!MemOpType || MemOpType->isIntegerTy()) { - if (isStore) - Info.isMemCpyDst = true; - else - Info.isMemCpySrc = true; - return; - } - // This is also safe for references using a type that is compatible with - // the type of the alloca, so that loads/stores can be rewritten using - // insertvalue/extractvalue. - if (isCompatibleAggregate(MemOpType, Info.AI->getAllocatedType())) { - Info.hasSubelementAccess = true; - return; - } - } - // Check if the offset/size correspond to a component within the alloca type. - Type *T = Info.AI->getAllocatedType(); - if (TypeHasComponent(T, Offset, MemSize, DL)) { - Info.hasSubelementAccess = true; - return; - } - - return MarkUnsafe(Info, TheAccess); -} - -/// TypeHasComponent - Return true if T has a component type with the -/// specified offset and size. If Size is zero, do not check the size. 
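isHomogeneousAggregate and isCompatibleAggregate above treat a struct of identical members and an array of the same length as interchangeable shapes. Illustrative C++ declarations; the LLVM types in the comments are the assumed lowering:

    struct AllInts { int a; int b; int c; };  // -> { i32, i32, i32 }: homogeneous
    struct Mixed   { int a; float b; };       // -> { i32, float }: rejected
    typedef int Three[3];                     // -> [3 x i32]: arrays always qualify
    // AllInts and Three count as "compatible": same element type and element
    // count, so whole-aggregate loads/stores of one can be rebuilt element-wise
    // in terms of the other.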
-bool SROA::TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size, - const DataLayout &DL) { - Type *EltTy; - uint64_t EltSize; - if (StructType *ST = dyn_cast<StructType>(T)) { - const StructLayout *Layout = DL.getStructLayout(ST); - unsigned EltIdx = Layout->getElementContainingOffset(Offset); - EltTy = ST->getContainedType(EltIdx); - EltSize = DL.getTypeAllocSize(EltTy); - Offset -= Layout->getElementOffset(EltIdx); - } else if (ArrayType *AT = dyn_cast<ArrayType>(T)) { - EltTy = AT->getElementType(); - EltSize = DL.getTypeAllocSize(EltTy); - if (Offset >= AT->getNumElements() * EltSize) - return false; - Offset %= EltSize; - } else if (VectorType *VT = dyn_cast<VectorType>(T)) { - EltTy = VT->getElementType(); - EltSize = DL.getTypeAllocSize(EltTy); - if (Offset >= VT->getNumElements() * EltSize) - return false; - Offset %= EltSize; - } else { - return false; - } - if (Offset == 0 && (Size == 0 || EltSize == Size)) - return true; - // Check if the component spans multiple elements. - if (Offset + Size > EltSize) - return false; - return TypeHasComponent(EltTy, Offset, Size, DL); -} - -/// RewriteForScalarRepl - Alloca AI is being split into NewElts, so rewrite -/// the instruction I, which references it, to use the separate elements. -/// Offset indicates the position within AI that is referenced by this -/// instruction. -void SROA::RewriteForScalarRepl(Instruction *I, AllocaInst *AI, uint64_t Offset, - SmallVectorImpl<AllocaInst *> &NewElts) { - const DataLayout &DL = I->getModule()->getDataLayout(); - for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI!=E;) { - Use &TheUse = *UI++; - Instruction *User = cast<Instruction>(TheUse.getUser()); - - if (BitCastInst *BC = dyn_cast<BitCastInst>(User)) { - RewriteBitCast(BC, AI, Offset, NewElts); - continue; - } - - if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(User)) { - RewriteGEP(GEPI, AI, Offset, NewElts); - continue; - } - - if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) { - ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength()); - uint64_t MemSize = Length->getZExtValue(); - if (Offset == 0 && MemSize == DL.getTypeAllocSize(AI->getAllocatedType())) - RewriteMemIntrinUserOfAlloca(MI, I, AI, NewElts); - // Otherwise the intrinsic can only touch a single element and the - // address operand will be updated, so nothing else needs to be done. 
- continue; - } - - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(User)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) { - RewriteLifetimeIntrinsic(II, AI, Offset, NewElts); - } - continue; - } - - if (LoadInst *LI = dyn_cast<LoadInst>(User)) { - Type *LIType = LI->getType(); - - if (isCompatibleAggregate(LIType, AI->getAllocatedType())) { - // Replace: - // %res = load { i32, i32 }* %alloc - // with: - // %load.0 = load i32* %alloc.0 - // %insert.0 insertvalue { i32, i32 } zeroinitializer, i32 %load.0, 0 - // %load.1 = load i32* %alloc.1 - // %insert = insertvalue { i32, i32 } %insert.0, i32 %load.1, 1 - // (Also works for arrays instead of structs) - Value *Insert = UndefValue::get(LIType); - IRBuilder<> Builder(LI); - for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { - Value *Load = Builder.CreateLoad(NewElts[i], "load"); - Insert = Builder.CreateInsertValue(Insert, Load, i, "insert"); - } - LI->replaceAllUsesWith(Insert); - DeadInsts.push_back(LI); - } else if (LIType->isIntegerTy() && - DL.getTypeAllocSize(LIType) == - DL.getTypeAllocSize(AI->getAllocatedType())) { - // If this is a load of the entire alloca to an integer, rewrite it. - RewriteLoadUserOfWholeAlloca(LI, AI, NewElts); - } - continue; - } - - if (StoreInst *SI = dyn_cast<StoreInst>(User)) { - Value *Val = SI->getOperand(0); - Type *SIType = Val->getType(); - if (isCompatibleAggregate(SIType, AI->getAllocatedType())) { - // Replace: - // store { i32, i32 } %val, { i32, i32 }* %alloc - // with: - // %val.0 = extractvalue { i32, i32 } %val, 0 - // store i32 %val.0, i32* %alloc.0 - // %val.1 = extractvalue { i32, i32 } %val, 1 - // store i32 %val.1, i32* %alloc.1 - // (Also works for arrays instead of structs) - IRBuilder<> Builder(SI); - for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { - Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName()); - Builder.CreateStore(Extract, NewElts[i]); - } - DeadInsts.push_back(SI); - } else if (SIType->isIntegerTy() && - DL.getTypeAllocSize(SIType) == - DL.getTypeAllocSize(AI->getAllocatedType())) { - // If this is a store of the entire alloca from an integer, rewrite it. - RewriteStoreUserOfWholeAlloca(SI, AI, NewElts); - } - continue; - } - - if (isa<SelectInst>(User) || isa<PHINode>(User)) { - // If we have a PHI user of the alloca itself (as opposed to a GEP or - // bitcast) we have to rewrite it. GEP and bitcast uses will be RAUW'd to - // the new pointer. - if (!isa<AllocaInst>(I)) continue; - - assert(Offset == 0 && NewElts[0] && - "Direct alloca use should have a zero offset"); - - // If we have a use of the alloca, we know the derived uses will be - // utilizing just the first element of the scalarized result. Insert a - // bitcast of the first alloca before the user as required. - AllocaInst *NewAI = NewElts[0]; - BitCastInst *BCI = new BitCastInst(NewAI, AI->getType(), "", NewAI); - NewAI->moveBefore(BCI); - TheUse = BCI; - continue; - } - } -} - -/// RewriteBitCast - Update a bitcast reference to the alloca being replaced -/// and recursively continue updating all of its uses. -void SROA::RewriteBitCast(BitCastInst *BC, AllocaInst *AI, uint64_t Offset, - SmallVectorImpl<AllocaInst *> &NewElts) { - RewriteForScalarRepl(BC, AI, Offset, NewElts); - if (BC->getOperand(0) != AI) - return; - - // The bitcast references the original alloca. 
Replace its uses with - // references to the alloca containing offset zero (which is normally at - // index zero, but might not be in cases involving structs with elements - // of size zero). - Type *T = AI->getAllocatedType(); - uint64_t EltOffset = 0; - Type *IdxTy; - uint64_t Idx = FindElementAndOffset(T, EltOffset, IdxTy, - BC->getModule()->getDataLayout()); - Instruction *Val = NewElts[Idx]; - if (Val->getType() != BC->getDestTy()) { - Val = new BitCastInst(Val, BC->getDestTy(), "", BC); - Val->takeName(BC); - } - BC->replaceAllUsesWith(Val); - DeadInsts.push_back(BC); -} - -/// FindElementAndOffset - Return the index of the element containing Offset -/// within the specified type, which must be either a struct or an array. -/// Sets T to the type of the element and Offset to the offset within that -/// element. IdxTy is set to the type of the index result to be used in a -/// GEP instruction. -uint64_t SROA::FindElementAndOffset(Type *&T, uint64_t &Offset, Type *&IdxTy, - const DataLayout &DL) { - uint64_t Idx = 0; - - if (StructType *ST = dyn_cast<StructType>(T)) { - const StructLayout *Layout = DL.getStructLayout(ST); - Idx = Layout->getElementContainingOffset(Offset); - T = ST->getContainedType(Idx); - Offset -= Layout->getElementOffset(Idx); - IdxTy = Type::getInt32Ty(T->getContext()); - return Idx; - } else if (ArrayType *AT = dyn_cast<ArrayType>(T)) { - T = AT->getElementType(); - uint64_t EltSize = DL.getTypeAllocSize(T); - Idx = Offset / EltSize; - Offset -= Idx * EltSize; - IdxTy = Type::getInt64Ty(T->getContext()); - return Idx; - } - VectorType *VT = cast<VectorType>(T); - T = VT->getElementType(); - uint64_t EltSize = DL.getTypeAllocSize(T); - Idx = Offset / EltSize; - Offset -= Idx * EltSize; - IdxTy = Type::getInt64Ty(T->getContext()); - return Idx; -} - -/// RewriteGEP - Check if this GEP instruction moves the pointer across -/// elements of the alloca that are being split apart, and if so, rewrite -/// the GEP to be relative to the new element. -void SROA::RewriteGEP(GetElementPtrInst *GEPI, AllocaInst *AI, uint64_t Offset, - SmallVectorImpl<AllocaInst *> &NewElts) { - uint64_t OldOffset = Offset; - const DataLayout &DL = GEPI->getModule()->getDataLayout(); - SmallVector<Value*, 8> Indices(GEPI->op_begin() + 1, GEPI->op_end()); - // If the GEP was dynamic then it must have been a dynamic vector lookup. - // In this case, it must be the last GEP operand which is dynamic so keep that - // aside until we've found the constant GEP offset then add it back in at the - // end. - Value* NonConstantIdx = nullptr; - if (!GEPI->hasAllConstantIndices()) - NonConstantIdx = Indices.pop_back_val(); - Offset += DL.getIndexedOffset(GEPI->getPointerOperandType(), Indices); - - RewriteForScalarRepl(GEPI, AI, Offset, NewElts); - - Type *T = AI->getAllocatedType(); - Type *IdxTy; - uint64_t OldIdx = FindElementAndOffset(T, OldOffset, IdxTy, DL); - if (GEPI->getOperand(0) == AI) - OldIdx = ~0ULL; // Force the GEP to be rewritten. - - T = AI->getAllocatedType(); - uint64_t EltOffset = Offset; - uint64_t Idx = FindElementAndOffset(T, EltOffset, IdxTy, DL); - - // If this GEP does not move the pointer across elements of the alloca - // being split, then it does not needs to be rewritten. 
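FindElementAndOffset above maps a byte offset into an (element index, offset within that element) pair; for arrays this is just a division and a remainder. A tiny standalone sketch, assuming a nonzero element size:

    #include <cstdint>

    struct ElementAddr { uint64_t index; uint64_t offsetInElement; };

    // Array case only: eltSize is the allocation size of one element.
    ElementAddr findElement(uint64_t offset, uint64_t eltSize) {
      return { offset / eltSize, offset % eltSize };
    }
    // findElement(20, 8) yields {2, 4}: byte 20 lands four bytes into element 2.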
- if (Idx == OldIdx) - return; - - Type *i32Ty = Type::getInt32Ty(AI->getContext()); - SmallVector<Value*, 8> NewArgs; - NewArgs.push_back(Constant::getNullValue(i32Ty)); - while (EltOffset != 0) { - uint64_t EltIdx = FindElementAndOffset(T, EltOffset, IdxTy, DL); - NewArgs.push_back(ConstantInt::get(IdxTy, EltIdx)); - } - if (NonConstantIdx) { - Type* GepTy = T; - // This GEP has a dynamic index. We need to add "i32 0" to index through - // any structs or arrays in the original type until we get to the vector - // to index. - while (!isa<VectorType>(GepTy)) { - NewArgs.push_back(Constant::getNullValue(i32Ty)); - GepTy = cast<CompositeType>(GepTy)->getTypeAtIndex(0U); - } - NewArgs.push_back(NonConstantIdx); - } - Instruction *Val = NewElts[Idx]; - if (NewArgs.size() > 1) { - Val = GetElementPtrInst::CreateInBounds(Val, NewArgs, "", GEPI); - Val->takeName(GEPI); - } - if (Val->getType() != GEPI->getType()) - Val = new BitCastInst(Val, GEPI->getType(), Val->getName(), GEPI); - GEPI->replaceAllUsesWith(Val); - DeadInsts.push_back(GEPI); -} - -/// RewriteLifetimeIntrinsic - II is a lifetime.start/lifetime.end. Rewrite it -/// to mark the lifetime of the scalarized memory. -void SROA::RewriteLifetimeIntrinsic(IntrinsicInst *II, AllocaInst *AI, - uint64_t Offset, - SmallVectorImpl<AllocaInst *> &NewElts) { - ConstantInt *OldSize = cast<ConstantInt>(II->getArgOperand(0)); - // Put matching lifetime markers on everything from Offset up to - // Offset+OldSize. - Type *AIType = AI->getAllocatedType(); - const DataLayout &DL = II->getModule()->getDataLayout(); - uint64_t NewOffset = Offset; - Type *IdxTy; - uint64_t Idx = FindElementAndOffset(AIType, NewOffset, IdxTy, DL); - - IRBuilder<> Builder(II); - uint64_t Size = OldSize->getLimitedValue(); - - if (NewOffset) { - // Splice the first element and index 'NewOffset' bytes in. SROA will - // split the alloca again later. - unsigned AS = AI->getType()->getAddressSpace(); - Value *V = Builder.CreateBitCast(NewElts[Idx], Builder.getInt8PtrTy(AS)); - V = Builder.CreateGEP(Builder.getInt8Ty(), V, Builder.getInt64(NewOffset)); - - IdxTy = NewElts[Idx]->getAllocatedType(); - uint64_t EltSize = DL.getTypeAllocSize(IdxTy) - NewOffset; - if (EltSize > Size) { - EltSize = Size; - Size = 0; - } else { - Size -= EltSize; - } - if (II->getIntrinsicID() == Intrinsic::lifetime_start) - Builder.CreateLifetimeStart(V, Builder.getInt64(EltSize)); - else - Builder.CreateLifetimeEnd(V, Builder.getInt64(EltSize)); - ++Idx; - } - - for (; Idx != NewElts.size() && Size; ++Idx) { - IdxTy = NewElts[Idx]->getAllocatedType(); - uint64_t EltSize = DL.getTypeAllocSize(IdxTy); - if (EltSize > Size) { - EltSize = Size; - Size = 0; - } else { - Size -= EltSize; - } - if (II->getIntrinsicID() == Intrinsic::lifetime_start) - Builder.CreateLifetimeStart(NewElts[Idx], - Builder.getInt64(EltSize)); - else - Builder.CreateLifetimeEnd(NewElts[Idx], - Builder.getInt64(EltSize)); - } - DeadInsts.push_back(II); -} - -/// RewriteMemIntrinUserOfAlloca - MI is a memcpy/memset/memmove from or to AI. -/// Rewrite it to copy or set the elements of the scalarized memory. -void -SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst, - AllocaInst *AI, - SmallVectorImpl<AllocaInst *> &NewElts) { - // If this is a memcpy/memmove, construct the other pointer as the - // appropriate type. The "Other" pointer is the pointer that goes to memory - // that doesn't have anything to do with the alloca that we are promoting. For - // memset, this Value* stays null. 
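RewriteMemIntrinUserOfAlloca splits a whole-aggregate memcpy or memset into one access per scalarized element. The source-level effect is roughly this change:

    #include <cstring>

    struct Pair { int a; int b; };

    void bulkCopy(Pair &dst, const Pair &src) {
      std::memcpy(&dst, &src, sizeof(Pair));   // one whole-aggregate copy
    }

    void perElementCopy(Pair &dst, const Pair &src) {
      dst.a = src.a;                           // one load/store per element
      dst.b = src.b;
    }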
- Value *OtherPtr = nullptr; - unsigned MemAlignment = MI->getAlignment(); - if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { // memmove/memcopy - if (Inst == MTI->getRawDest()) - OtherPtr = MTI->getRawSource(); - else { - assert(Inst == MTI->getRawSource()); - OtherPtr = MTI->getRawDest(); - } - } - - // If there is an other pointer, we want to convert it to the same pointer - // type as AI has, so we can GEP through it safely. - if (OtherPtr) { - unsigned AddrSpace = - cast<PointerType>(OtherPtr->getType())->getAddressSpace(); - - // Remove bitcasts and all-zero GEPs from OtherPtr. This is an - // optimization, but it's also required to detect the corner case where - // both pointer operands are referencing the same memory, and where - // OtherPtr may be a bitcast or GEP that currently being rewritten. (This - // function is only called for mem intrinsics that access the whole - // aggregate, so non-zero GEPs are not an issue here.) - OtherPtr = OtherPtr->stripPointerCasts(); - - // Copying the alloca to itself is a no-op: just delete it. - if (OtherPtr == AI || OtherPtr == NewElts[0]) { - // This code will run twice for a no-op memcpy -- once for each operand. - // Put only one reference to MI on the DeadInsts list. - for (SmallVectorImpl<Value *>::const_iterator I = DeadInsts.begin(), - E = DeadInsts.end(); I != E; ++I) - if (*I == MI) return; - DeadInsts.push_back(MI); - return; - } - - // If the pointer is not the right type, insert a bitcast to the right - // type. - Type *NewTy = - PointerType::get(AI->getType()->getElementType(), AddrSpace); - - if (OtherPtr->getType() != NewTy) - OtherPtr = new BitCastInst(OtherPtr, NewTy, OtherPtr->getName(), MI); - } - - // Process each element of the aggregate. - bool SROADest = MI->getRawDest() == Inst; - - Constant *Zero = Constant::getNullValue(Type::getInt32Ty(MI->getContext())); - const DataLayout &DL = MI->getModule()->getDataLayout(); - - for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { - // If this is a memcpy/memmove, emit a GEP of the other element address. - Value *OtherElt = nullptr; - unsigned OtherEltAlign = MemAlignment; - - if (OtherPtr) { - Value *Idx[2] = { Zero, - ConstantInt::get(Type::getInt32Ty(MI->getContext()), i) }; - OtherElt = GetElementPtrInst::CreateInBounds(OtherPtr, Idx, - OtherPtr->getName()+"."+Twine(i), - MI); - uint64_t EltOffset; - PointerType *OtherPtrTy = cast<PointerType>(OtherPtr->getType()); - Type *OtherTy = OtherPtrTy->getElementType(); - if (StructType *ST = dyn_cast<StructType>(OtherTy)) { - EltOffset = DL.getStructLayout(ST)->getElementOffset(i); - } else { - Type *EltTy = cast<SequentialType>(OtherTy)->getElementType(); - EltOffset = DL.getTypeAllocSize(EltTy) * i; - } - - // The alignment of the other pointer is the guaranteed alignment of the - // element, which is affected by both the known alignment of the whole - // mem intrinsic and the alignment of the element. If the alignment of - // the memcpy (f.e.) is 32 but the element is at a 4-byte offset, then the - // known alignment is just 4 bytes. - OtherEltAlign = (unsigned)MinAlign(OtherEltAlign, EltOffset); - } - - Value *EltPtr = NewElts[i]; - Type *EltTy = cast<PointerType>(EltPtr->getType())->getElementType(); - - // If we got down to a scalar, insert a load or store as appropriate. - if (EltTy->isSingleValueType()) { - if (isa<MemTransferInst>(MI)) { - if (SROADest) { - // From Other to Alloca. 
- Value *Elt = new LoadInst(OtherElt, "tmp", false, OtherEltAlign, MI); - new StoreInst(Elt, EltPtr, MI); - } else { - // From Alloca to Other. - Value *Elt = new LoadInst(EltPtr, "tmp", MI); - new StoreInst(Elt, OtherElt, false, OtherEltAlign, MI); - } - continue; - } - assert(isa<MemSetInst>(MI)); - - // If the stored element is zero (common case), just store a null - // constant. - Constant *StoreVal; - if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getArgOperand(1))) { - if (CI->isZero()) { - StoreVal = Constant::getNullValue(EltTy); // 0.0, null, 0, <0,0> - } else { - // If EltTy is a vector type, get the element type. - Type *ValTy = EltTy->getScalarType(); - - // Construct an integer with the right value. - unsigned EltSize = DL.getTypeSizeInBits(ValTy); - APInt OneVal(EltSize, CI->getZExtValue()); - APInt TotalVal(OneVal); - // Set each byte. - for (unsigned i = 0; 8*i < EltSize; ++i) { - TotalVal = TotalVal.shl(8); - TotalVal |= OneVal; - } - - // Convert the integer value to the appropriate type. - StoreVal = ConstantInt::get(CI->getContext(), TotalVal); - if (ValTy->isPointerTy()) - StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy); - else if (ValTy->isFloatingPointTy()) - StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy); - assert(StoreVal->getType() == ValTy && "Type mismatch!"); - - // If the requested value was a vector constant, create it. - if (EltTy->isVectorTy()) { - unsigned NumElts = cast<VectorType>(EltTy)->getNumElements(); - StoreVal = ConstantVector::getSplat(NumElts, StoreVal); - } - } - new StoreInst(StoreVal, EltPtr, MI); - continue; - } - // Otherwise, if we're storing a byte variable, use a memset call for - // this element. - } - - unsigned EltSize = DL.getTypeAllocSize(EltTy); - if (!EltSize) - continue; - - IRBuilder<> Builder(MI); - - // Finally, insert the meminst for this element. - if (isa<MemSetInst>(MI)) { - Builder.CreateMemSet(EltPtr, MI->getArgOperand(1), EltSize, - MI->isVolatile()); - } else { - assert(isa<MemTransferInst>(MI)); - Value *Dst = SROADest ? EltPtr : OtherElt; // Dest ptr - Value *Src = SROADest ? OtherElt : EltPtr; // Src ptr - - if (isa<MemCpyInst>(MI)) - Builder.CreateMemCpy(Dst, Src, EltSize, OtherEltAlign,MI->isVolatile()); - else - Builder.CreateMemMove(Dst, Src, EltSize,OtherEltAlign,MI->isVolatile()); - } - } - DeadInsts.push_back(MI); -} - -/// RewriteStoreUserOfWholeAlloca - We found a store of an integer that -/// overwrites the entire allocation. Extract out the pieces of the stored -/// integer and store them individually. -void -SROA::RewriteStoreUserOfWholeAlloca(StoreInst *SI, AllocaInst *AI, - SmallVectorImpl<AllocaInst *> &NewElts) { - // Extract each element out of the integer according to its structure offset - // and store the element value to the individual alloca. - Value *SrcVal = SI->getOperand(0); - Type *AllocaEltTy = AI->getAllocatedType(); - const DataLayout &DL = SI->getModule()->getDataLayout(); - uint64_t AllocaSizeBits = DL.getTypeAllocSizeInBits(AllocaEltTy); - - IRBuilder<> Builder(SI); - - // Handle tail padding by extending the operand - if (DL.getTypeSizeInBits(SrcVal->getType()) != AllocaSizeBits) - SrcVal = Builder.CreateZExt(SrcVal, - IntegerType::get(SI->getContext(), AllocaSizeBits)); - - DEBUG(dbgs() << "PROMOTING STORE TO WHOLE ALLOCA: " << *AI << '\n' << *SI - << '\n'); - - // There are two forms here: AI could be an array or struct. Both cases - // have different ways to compute the element offset. 
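The shl/or loop above builds the per-element memset constant by replicating the fill byte across the element's width. An equivalent computation in plain C++, fixed to a 32-bit element for brevity:

    #include <cstdint>

    uint32_t splatByte(uint8_t b) {
      uint32_t v = 0;
      for (int i = 0; i < 4; ++i)
        v = (v << 8) | b;       // shift previous bytes up, append the fill byte
      return v;                 // splatByte(0xAB) == 0xABABABAB
    }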
- if (StructType *EltSTy = dyn_cast<StructType>(AllocaEltTy)) { - const StructLayout *Layout = DL.getStructLayout(EltSTy); - - for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { - // Get the number of bits to shift SrcVal to get the value. - Type *FieldTy = EltSTy->getElementType(i); - uint64_t Shift = Layout->getElementOffsetInBits(i); - - if (DL.isBigEndian()) - Shift = AllocaSizeBits - Shift - DL.getTypeAllocSizeInBits(FieldTy); - - Value *EltVal = SrcVal; - if (Shift) { - Value *ShiftVal = ConstantInt::get(EltVal->getType(), Shift); - EltVal = Builder.CreateLShr(EltVal, ShiftVal, "sroa.store.elt"); - } - - // Truncate down to an integer of the right size. - uint64_t FieldSizeBits = DL.getTypeSizeInBits(FieldTy); - - // Ignore zero sized fields like {}, they obviously contain no data. - if (FieldSizeBits == 0) continue; - - if (FieldSizeBits != AllocaSizeBits) - EltVal = Builder.CreateTrunc(EltVal, - IntegerType::get(SI->getContext(), FieldSizeBits)); - Value *DestField = NewElts[i]; - if (EltVal->getType() == FieldTy) { - // Storing to an integer field of this size, just do it. - } else if (FieldTy->isFloatingPointTy() || FieldTy->isVectorTy()) { - // Bitcast to the right element type (for fp/vector values). - EltVal = Builder.CreateBitCast(EltVal, FieldTy); - } else { - // Otherwise, bitcast the dest pointer (for aggregates). - DestField = Builder.CreateBitCast(DestField, - PointerType::getUnqual(EltVal->getType())); - } - new StoreInst(EltVal, DestField, SI); - } - - } else { - ArrayType *ATy = cast<ArrayType>(AllocaEltTy); - Type *ArrayEltTy = ATy->getElementType(); - uint64_t ElementOffset = DL.getTypeAllocSizeInBits(ArrayEltTy); - uint64_t ElementSizeBits = DL.getTypeSizeInBits(ArrayEltTy); - - uint64_t Shift; - - if (DL.isBigEndian()) - Shift = AllocaSizeBits-ElementOffset; - else - Shift = 0; - - for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { - // Ignore zero sized fields like {}, they obviously contain no data. - if (ElementSizeBits == 0) continue; - - Value *EltVal = SrcVal; - if (Shift) { - Value *ShiftVal = ConstantInt::get(EltVal->getType(), Shift); - EltVal = Builder.CreateLShr(EltVal, ShiftVal, "sroa.store.elt"); - } - - // Truncate down to an integer of the right size. - if (ElementSizeBits != AllocaSizeBits) - EltVal = Builder.CreateTrunc(EltVal, - IntegerType::get(SI->getContext(), - ElementSizeBits)); - Value *DestField = NewElts[i]; - if (EltVal->getType() == ArrayEltTy) { - // Storing to an integer field of this size, just do it. - } else if (ArrayEltTy->isFloatingPointTy() || - ArrayEltTy->isVectorTy()) { - // Bitcast to the right element type (for fp/vector values). - EltVal = Builder.CreateBitCast(EltVal, ArrayEltTy); - } else { - // Otherwise, bitcast the dest pointer (for aggregates). - DestField = Builder.CreateBitCast(DestField, - PointerType::getUnqual(EltVal->getType())); - } - new StoreInst(EltVal, DestField, SI); - - if (DL.isBigEndian()) - Shift -= ElementOffset; - else - Shift += ElementOffset; - } - } - - DeadInsts.push_back(SI); -} - -/// RewriteLoadUserOfWholeAlloca - We found a load of the entire allocation to -/// an integer. Load the individual pieces to form the aggregate value. -void -SROA::RewriteLoadUserOfWholeAlloca(LoadInst *LI, AllocaInst *AI, - SmallVectorImpl<AllocaInst *> &NewElts) { - // Extract each element out of the NewElts according to its structure offset - // and form the result value. 
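RewriteStoreUserOfWholeAlloca above carves each field out of the wide integer with a shift and a truncation, flipping the shift direction on big-endian targets. A little-endian source-level sketch with an assumed two-field layout:

    #include <cstdint>

    struct Fields { uint16_t lo; uint16_t hi; };

    void storeWhole(Fields &f, uint32_t v) {
      f.lo = static_cast<uint16_t>(v);         // field at bit offset 0
      f.hi = static_cast<uint16_t>(v >> 16);   // field at bit offset 16
    }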
- Type *AllocaEltTy = AI->getAllocatedType(); - const DataLayout &DL = LI->getModule()->getDataLayout(); - uint64_t AllocaSizeBits = DL.getTypeAllocSizeInBits(AllocaEltTy); - - DEBUG(dbgs() << "PROMOTING LOAD OF WHOLE ALLOCA: " << *AI << '\n' << *LI - << '\n'); - - // There are two forms here: AI could be an array or struct. Both cases - // have different ways to compute the element offset. - const StructLayout *Layout = nullptr; - uint64_t ArrayEltBitOffset = 0; - if (StructType *EltSTy = dyn_cast<StructType>(AllocaEltTy)) { - Layout = DL.getStructLayout(EltSTy); - } else { - Type *ArrayEltTy = cast<ArrayType>(AllocaEltTy)->getElementType(); - ArrayEltBitOffset = DL.getTypeAllocSizeInBits(ArrayEltTy); - } - - Value *ResultVal = - Constant::getNullValue(IntegerType::get(LI->getContext(), AllocaSizeBits)); - - for (unsigned i = 0, e = NewElts.size(); i != e; ++i) { - // Load the value from the alloca. If the NewElt is an aggregate, cast - // the pointer to an integer of the same size before doing the load. - Value *SrcField = NewElts[i]; - Type *FieldTy = - cast<PointerType>(SrcField->getType())->getElementType(); - uint64_t FieldSizeBits = DL.getTypeSizeInBits(FieldTy); - - // Ignore zero sized fields like {}, they obviously contain no data. - if (FieldSizeBits == 0) continue; - - IntegerType *FieldIntTy = IntegerType::get(LI->getContext(), - FieldSizeBits); - if (!FieldTy->isIntegerTy() && !FieldTy->isFloatingPointTy() && - !FieldTy->isVectorTy()) - SrcField = new BitCastInst(SrcField, - PointerType::getUnqual(FieldIntTy), - "", LI); - SrcField = new LoadInst(SrcField, "sroa.load.elt", LI); - - // If SrcField is a fp or vector of the right size but that isn't an - // integer type, bitcast to an integer so we can shift it. - if (SrcField->getType() != FieldIntTy) - SrcField = new BitCastInst(SrcField, FieldIntTy, "", LI); - - // Zero extend the field to be the same size as the final alloca so that - // we can shift and insert it. - if (SrcField->getType() != ResultVal->getType()) - SrcField = new ZExtInst(SrcField, ResultVal->getType(), "", LI); - - // Determine the number of bits to shift SrcField. - uint64_t Shift; - if (Layout) // Struct case. - Shift = Layout->getElementOffsetInBits(i); - else // Array case. - Shift = i*ArrayEltBitOffset; - - if (DL.isBigEndian()) - Shift = AllocaSizeBits-Shift-FieldIntTy->getBitWidth(); - - if (Shift) { - Value *ShiftVal = ConstantInt::get(SrcField->getType(), Shift); - SrcField = BinaryOperator::CreateShl(SrcField, ShiftVal, "", LI); - } - - // Don't create an 'or x, 0' on the first iteration. - if (!isa<Constant>(ResultVal) || - !cast<Constant>(ResultVal)->isNullValue()) - ResultVal = BinaryOperator::CreateOr(SrcField, ResultVal, "", LI); - else - ResultVal = SrcField; - } - - // Handle tail padding by truncating the result - if (DL.getTypeSizeInBits(LI->getType()) != AllocaSizeBits) - ResultVal = new TruncInst(ResultVal, LI->getType(), "", LI); - - LI->replaceAllUsesWith(ResultVal); - DeadInsts.push_back(LI); -} - -/// HasPadding - Return true if the specified type has any structure or -/// alignment padding in between the elements that would be split apart -/// by SROA; return false otherwise. -static bool HasPadding(Type *Ty, const DataLayout &DL) { - if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) { - Ty = ATy->getElementType(); - return DL.getTypeSizeInBits(Ty) != DL.getTypeAllocSizeInBits(Ty); - } - - // SROA currently handles only Arrays and Structs. 
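RewriteLoadUserOfWholeAlloca above runs the same recipe in reverse: each element is loaded, zero-extended, shifted into position, and or'ed into the result. The matching little-endian sketch:

    #include <cstdint>

    struct Fields { uint16_t lo; uint16_t hi; };

    uint32_t loadWhole(const Fields &f) {
      uint32_t result = 0;
      result |= static_cast<uint32_t>(f.lo);         // bit offset 0
      result |= static_cast<uint32_t>(f.hi) << 16;   // bit offset 16
      return result;
    }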
- StructType *STy = cast<StructType>(Ty); - const StructLayout *SL = DL.getStructLayout(STy); - unsigned PrevFieldBitOffset = 0; - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - unsigned FieldBitOffset = SL->getElementOffsetInBits(i); - - // Check to see if there is any padding between this element and the - // previous one. - if (i) { - unsigned PrevFieldEnd = - PrevFieldBitOffset+DL.getTypeSizeInBits(STy->getElementType(i-1)); - if (PrevFieldEnd < FieldBitOffset) - return true; - } - PrevFieldBitOffset = FieldBitOffset; - } - // Check for tail padding. - if (unsigned EltCount = STy->getNumElements()) { - unsigned PrevFieldEnd = PrevFieldBitOffset + - DL.getTypeSizeInBits(STy->getElementType(EltCount-1)); - if (PrevFieldEnd < SL->getSizeInBits()) - return true; - } - return false; -} - -/// isSafeStructAllocaToScalarRepl - Check to see if the specified allocation of -/// an aggregate can be broken down into elements. Return 0 if not, 3 if safe, -/// or 1 if safe after canonicalization has been performed. -bool SROA::isSafeAllocaToScalarRepl(AllocaInst *AI) { - // Loop over the use list of the alloca. We can only transform it if all of - // the users are safe to transform. - AllocaInfo Info(AI); - - isSafeForScalarRepl(AI, 0, Info); - if (Info.isUnsafe) { - DEBUG(dbgs() << "Cannot transform: " << *AI << '\n'); - return false; - } - - const DataLayout &DL = AI->getModule()->getDataLayout(); - - // Okay, we know all the users are promotable. If the aggregate is a memcpy - // source and destination, we have to be careful. In particular, the memcpy - // could be moving around elements that live in structure padding of the LLVM - // types, but may actually be used. In these cases, we refuse to promote the - // struct. - if (Info.isMemCpySrc && Info.isMemCpyDst && - HasPadding(AI->getAllocatedType(), DL)) - return false; - - // If the alloca never has an access to just *part* of it, but is accessed - // via loads and stores, then we should use ConvertToScalarInfo to promote - // the alloca instead of promoting each piece at a time and inserting fission - // and fusion code. - if (!Info.hasSubelementAccess && Info.hasALoadOrStore) { - // If the struct/array just has one element, use basic SRoA. 
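HasPadding above is what stops the pass from splitting an aggregate that is both a memcpy source and destination when data could hide in padding bytes. A quick way to see the padding it worries about; the sizes are ABI-dependent and shown for a typical LP64 target:

    #include <cstdio>

    struct Padded { char c; int i; };   // usually 3 pad bytes after 'c'
    struct Dense  { int a; int b; };    // no padding between or after fields

    int main() {
      std::printf("Padded: sizeof=%zu, fields=%zu\n",
                  sizeof(Padded), sizeof(char) + sizeof(int));
      std::printf("Dense:  sizeof=%zu, fields=%zu\n",
                  sizeof(Dense), 2 * sizeof(int));
    }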
- if (StructType *ST = dyn_cast<StructType>(AI->getAllocatedType())) { - if (ST->getNumElements() > 1) return false; - } else { - if (cast<ArrayType>(AI->getAllocatedType())->getNumElements() > 1) - return false; - } - } - - return true; -} diff --git a/lib/Transforms/Scalar/Scalarizer.cpp b/lib/Transforms/Scalar/Scalarizer.cpp index 054bacdc706b..aed4a4ad4d26 100644 --- a/lib/Transforms/Scalar/Scalarizer.cpp +++ b/lib/Transforms/Scalar/Scalarizer.cpp @@ -14,12 +14,11 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/Pass.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -253,6 +252,8 @@ bool Scalarizer::doInitialization(Module &M) { } bool Scalarizer::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; assert(Gathered.empty() && Scattered.empty()); for (BasicBlock &BB : F) { for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) { @@ -305,7 +306,11 @@ void Scalarizer::gather(Instruction *Op, const ValueVector &CV) { ValueVector &SV = Scattered[Op]; if (!SV.empty()) { for (unsigned I = 0, E = SV.size(); I != E; ++I) { - Instruction *Old = cast<Instruction>(SV[I]); + Value *V = SV[I]; + if (V == nullptr) + continue; + + Instruction *Old = cast<Instruction>(V); CV[I]->takeName(Old); Old->replaceAllUsesWith(CV[I]); Old->eraseFromParent(); @@ -334,13 +339,11 @@ void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) { Op->getAllMetadataOtherThanDebugLoc(MDs); for (unsigned I = 0, E = CV.size(); I != E; ++I) { if (Instruction *New = dyn_cast<Instruction>(CV[I])) { - for (SmallVectorImpl<std::pair<unsigned, MDNode *>>::iterator - MI = MDs.begin(), - ME = MDs.end(); - MI != ME; ++MI) - if (canTransferMetadata(MI->first)) - New->setMetadata(MI->first, MI->second); - New->setDebugLoc(Op->getDebugLoc()); + for (const auto &MD : MDs) + if (canTransferMetadata(MD.first)) + New->setMetadata(MD.first, MD.second); + if (Op->getDebugLoc() && !New->getDebugLoc()) + New->setDebugLoc(Op->getDebugLoc()); } } } @@ -646,10 +649,9 @@ bool Scalarizer::finish() { // made to the Function. if (Gathered.empty() && Scattered.empty()) return false; - for (GatherList::iterator GMI = Gathered.begin(), GME = Gathered.end(); - GMI != GME; ++GMI) { - Instruction *Op = GMI->first; - ValueVector &CV = *GMI->second; + for (const auto &GMI : Gathered) { + Instruction *Op = GMI.first; + ValueVector &CV = *GMI.second; if (!Op->use_empty()) { // The value is still needed, so recreate it using a series of // InsertElements. diff --git a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 86a10d2a1612..d6ae186698c7 100644 --- a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -590,9 +590,9 @@ Value *ConstantOffsetExtractor::rebuildWithoutConstOffset() { distributeExtsAndCloneChain(UserChain.size() - 1); // Remove all nullptrs (used to be s/zext) from UserChain. 
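The UserChain cleanup that follows compacts the vector in place by copying the non-null entries forward. The standard erase/remove idiom does the same job; the element type here is a stand-in chosen for illustration:

    #include <algorithm>
    #include <vector>

    void dropNulls(std::vector<void *> &chain) {
      chain.erase(std::remove(chain.begin(), chain.end(), nullptr), chain.end());
    }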
unsigned NewSize = 0; - for (auto I = UserChain.begin(), E = UserChain.end(); I != E; ++I) { - if (*I != nullptr) { - UserChain[NewSize] = *I; + for (User *I : UserChain) { + if (I != nullptr) { + UserChain[NewSize] = I; NewSize++; } } @@ -824,8 +824,8 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( // If we created a GEP with constant index, and the base is loop invariant, // then we swap the first one with it, so LICM can move constant GEP out // later. - GetElementPtrInst *FirstGEP = dyn_cast<GetElementPtrInst>(FirstResult); - GetElementPtrInst *SecondGEP = dyn_cast<GetElementPtrInst>(ResultPtr); + GetElementPtrInst *FirstGEP = dyn_cast_or_null<GetElementPtrInst>(FirstResult); + GetElementPtrInst *SecondGEP = dyn_cast_or_null<GetElementPtrInst>(ResultPtr); if (isSwapCandidate && isLegalToSwapOperand(FirstGEP, SecondGEP, L)) swapGEPOperand(FirstGEP, SecondGEP); @@ -911,7 +911,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { getAnalysis<TargetTransformInfoWrapperPass>().getTTI( *GEP->getParent()->getParent()); unsigned AddrSpace = GEP->getPointerAddressSpace(); - if (!TTI.isLegalAddressingMode(GEP->getType()->getElementType(), + if (!TTI.isLegalAddressingMode(GEP->getResultElementType(), /*BaseGV=*/nullptr, AccumulativeByteOffset, /*HasBaseReg=*/true, /*Scale=*/0, AddrSpace)) { @@ -1018,7 +1018,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // unsigned.. Therefore, we cast ElementTypeSizeOfGEP to signed because it is // used with unsigned integers later. int64_t ElementTypeSizeOfGEP = static_cast<int64_t>( - DL->getTypeAllocSize(GEP->getType()->getElementType())); + DL->getTypeAllocSize(GEP->getResultElementType())); Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); if (AccumulativeByteOffset % ElementTypeSizeOfGEP == 0) { // Very likely. 
As long as %gep is natually aligned, the byte offset we @@ -1064,7 +1064,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { } bool SeparateConstOffsetFromGEP::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; if (DisableSeparateConstOffsetFromGEP) @@ -1075,8 +1075,8 @@ bool SeparateConstOffsetFromGEP::runOnFunction(Function &F) { LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); bool Changed = false; - for (Function::iterator B = F.begin(), BE = F.end(); B != BE; ++B) { - for (BasicBlock::iterator I = B->begin(), IE = B->end(); I != IE;) + for (BasicBlock &B : F) { + for (BasicBlock::iterator I = B.begin(), IE = B.end(); I != IE;) if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I++)) Changed |= splitGEP(GEP); // No need to split GEP ConstantExprs because all its indices are constant @@ -1162,8 +1162,8 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Function &F) { } void SeparateConstOffsetFromGEP::verifyNoDeadCode(Function &F) { - for (auto &B : F) { - for (auto &I : B) { + for (BasicBlock &B : F) { + for (Instruction &I : B) { if (isInstructionTriviallyDead(&I)) { std::string ErrMessage; raw_string_ostream RSO(ErrMessage); diff --git a/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 63c8836bf381..2d0a21d2c518 100644 --- a/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -21,12 +21,12 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar/SimplifyCFG.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CFG.h" @@ -37,8 +37,10 @@ #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/SimplifyCFG.h" +#include "llvm/Transforms/Utils/Local.h" +#include <utility> using namespace llvm; #define DEBUG_TYPE "simplifycfg" @@ -131,12 +133,19 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, unsigned BonusInstThreshold) { bool Changed = false; bool LocalChange = true; + + SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 32> Edges; + FindFunctionBackedges(F, Edges); + SmallPtrSet<BasicBlock *, 16> LoopHeaders; + for (unsigned i = 0, e = Edges.size(); i != e; ++i) + LoopHeaders.insert(const_cast<BasicBlock *>(Edges[i].second)); + while (LocalChange) { LocalChange = false; // Loop over all of the basic blocks and remove them if they are unneeded. 
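Stepping back to the splitGEP hunk above: the pass peels the constant part out of a GEP index so it can later fold into an addressing-mode displacement, while the variable part becomes reusable. A rough source-level analogy:

    float combinedIndex(float *base, long i) {
      return base[i + 8];      // constant 8 is buried inside the index math
    }

    float splitIndex(float *base, long i) {
      float *p = base + i;     // variable part, reusable across nearby accesses
      return *(p + 8);         // constant part left as a simple displacement
    }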
for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) { - if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC)) { + if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC, &LoopHeaders)) { LocalChange = true; ++NumSimpl; } @@ -178,14 +187,15 @@ SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold) : BonusInstThreshold(BonusInstThreshold) {} PreservedAnalyses SimplifyCFGPass::run(Function &F, - AnalysisManager<Function> *AM) { - auto &TTI = AM->getResult<TargetIRAnalysis>(F); - auto &AC = AM->getResult<AssumptionAnalysis>(F); + AnalysisManager<Function> &AM) { + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); if (!simplifyFunctionCFG(F, TTI, &AC, BonusInstThreshold)) - return PreservedAnalyses::none(); - - return PreservedAnalyses::all(); + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; } namespace { @@ -196,15 +206,12 @@ struct CFGSimplifyPass : public FunctionPass { CFGSimplifyPass(int T = -1, std::function<bool(const Function &)> Ftor = nullptr) - : FunctionPass(ID), PredicateFtor(Ftor) { + : FunctionPass(ID), PredicateFtor(std::move(Ftor)) { BonusInstThreshold = (T == -1) ? UserBonusInstThreshold : unsigned(T); initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override { - if (PredicateFtor && !PredicateFtor(F)) - return false; - - if (skipOptnoneFunction(F)) + if (skipFunction(F) || (PredicateFtor && !PredicateFtor(F))) return false; AssumptionCache *AC = @@ -234,6 +241,5 @@ INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, FunctionPass * llvm::createCFGSimplificationPass(int Threshold, std::function<bool(const Function &)> Ftor) { - return new CFGSimplifyPass(Threshold, Ftor); + return new CFGSimplifyPass(Threshold, std::move(Ftor)); } - diff --git a/lib/Transforms/Scalar/Sink.cpp b/lib/Transforms/Scalar/Sink.cpp index 64109b2df117..d9a296c63122 100644 --- a/lib/Transforms/Scalar/Sink.cpp +++ b/lib/Transforms/Scalar/Sink.cpp @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/Sink.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/LoopInfo.h" @@ -24,6 +24,7 @@ #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" using namespace llvm; #define DEBUG_TYPE "sink" @@ -31,50 +32,10 @@ using namespace llvm; STATISTIC(NumSunk, "Number of instructions sunk"); STATISTIC(NumSinkIter, "Number of sinking iterations"); -namespace { - class Sinking : public FunctionPass { - DominatorTree *DT; - LoopInfo *LI; - AliasAnalysis *AA; - - public: - static char ID; // Pass identification - Sinking() : FunctionPass(ID) { - initializeSinkingPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - FunctionPass::getAnalysisUsage(AU); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - } - private: - bool ProcessBlock(BasicBlock &BB); - bool SinkInstruction(Instruction *I, SmallPtrSetImpl<Instruction*> &Stores); - bool AllUsesDominatedByBlock(Instruction *Inst, BasicBlock *BB) const; - bool 
IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo) const; - }; -} // end anonymous namespace - -char Sinking::ID = 0; -INITIALIZE_PASS_BEGIN(Sinking, "sink", "Code sinking", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(Sinking, "sink", "Code sinking", false, false) - -FunctionPass *llvm::createSinkingPass() { return new Sinking(); } - /// AllUsesDominatedByBlock - Return true if all uses of the specified value /// occur in blocks dominated by the specified block. -bool Sinking::AllUsesDominatedByBlock(Instruction *Inst, - BasicBlock *BB) const { +static bool AllUsesDominatedByBlock(Instruction *Inst, BasicBlock *BB, + DominatorTree &DT) { // Ignoring debug uses is necessary so debug info doesn't affect the code. // This may leave a referencing dbg_value in the original block, before // the definition of the vreg. Dwarf generator handles this although the @@ -90,71 +51,13 @@ bool Sinking::AllUsesDominatedByBlock(Instruction *Inst, UseBlock = PN->getIncomingBlock(Num); } // Check that it dominates. - if (!DT->dominates(BB, UseBlock)) + if (!DT.dominates(BB, UseBlock)) return false; } return true; } -bool Sinking::runOnFunction(Function &F) { - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - - bool MadeChange, EverMadeChange = false; - - do { - MadeChange = false; - DEBUG(dbgs() << "Sinking iteration " << NumSinkIter << "\n"); - // Process all basic blocks. - for (Function::iterator I = F.begin(), E = F.end(); - I != E; ++I) - MadeChange |= ProcessBlock(*I); - EverMadeChange |= MadeChange; - NumSinkIter++; - } while (MadeChange); - - return EverMadeChange; -} - -bool Sinking::ProcessBlock(BasicBlock &BB) { - // Can't sink anything out of a block that has less than two successors. - if (BB.getTerminator()->getNumSuccessors() <= 1) return false; - - // Don't bother sinking code out of unreachable blocks. In addition to being - // unprofitable, it can also lead to infinite looping, because in an - // unreachable loop there may be nowhere to stop. - if (!DT->isReachableFromEntry(&BB)) return false; - - bool MadeChange = false; - - // Walk the basic block bottom-up. Remember if we saw a store. - BasicBlock::iterator I = BB.end(); - --I; - bool ProcessedBegin = false; - SmallPtrSet<Instruction *, 8> Stores; - do { - Instruction *Inst = &*I; // The instruction to sink. - - // Predecrement I (if it's not begin) so that it isn't invalidated by - // sinking. - ProcessedBegin = I == BB.begin(); - if (!ProcessedBegin) - --I; - - if (isa<DbgInfoIntrinsic>(Inst)) - continue; - - if (SinkInstruction(Inst, Stores)) - ++NumSunk, MadeChange = true; - - // If we just processed the first instruction in the block, we're done. 
- } while (!ProcessedBegin); - - return MadeChange; -} - -static bool isSafeToMove(Instruction *Inst, AliasAnalysis *AA, +static bool isSafeToMove(Instruction *Inst, AliasAnalysis &AA, SmallPtrSetImpl<Instruction *> &Stores) { if (Inst->mayWriteToMemory()) { @@ -165,7 +68,7 @@ static bool isSafeToMove(Instruction *Inst, AliasAnalysis *AA, if (LoadInst *L = dyn_cast<LoadInst>(Inst)) { MemoryLocation Loc = MemoryLocation::get(L); for (Instruction *S : Stores) - if (AA->getModRefInfo(S, Loc) & MRI_Mod) + if (AA.getModRefInfo(S, Loc) & MRI_Mod) return false; } @@ -173,11 +76,15 @@ static bool isSafeToMove(Instruction *Inst, AliasAnalysis *AA, Inst->mayThrow()) return false; - // Convergent operations cannot be made control-dependent on additional - // values. if (auto CS = CallSite(Inst)) { + // Convergent operations cannot be made control-dependent on additional + // values. if (CS.hasFnAttr(Attribute::Convergent)) return false; + + for (Instruction *S : Stores) + if (AA.getModRefInfo(S, CS) & MRI_Mod) + return false; } return true; @@ -185,8 +92,8 @@ static bool isSafeToMove(Instruction *Inst, AliasAnalysis *AA, /// IsAcceptableTarget - Return true if it is possible to sink the instruction /// in the specified basic block. -bool Sinking::IsAcceptableTarget(Instruction *Inst, - BasicBlock *SuccToSinkTo) const { +static bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo, + DominatorTree &DT, LoopInfo &LI) { assert(Inst && "Instruction to be sunk is null"); assert(SuccToSinkTo && "Candidate sink target is null"); @@ -212,25 +119,26 @@ bool Sinking::IsAcceptableTarget(Instruction *Inst, // We don't want to sink across a critical edge if we don't dominate the // successor. We could be introducing calculations to new code paths. - if (!DT->dominates(Inst->getParent(), SuccToSinkTo)) + if (!DT.dominates(Inst->getParent(), SuccToSinkTo)) return false; // Don't sink instructions into a loop. - Loop *succ = LI->getLoopFor(SuccToSinkTo); - Loop *cur = LI->getLoopFor(Inst->getParent()); + Loop *succ = LI.getLoopFor(SuccToSinkTo); + Loop *cur = LI.getLoopFor(Inst->getParent()); if (succ != nullptr && succ != cur) return false; } // Finally, check that all the uses of the instruction are actually // dominated by the candidate - return AllUsesDominatedByBlock(Inst, SuccToSinkTo); + return AllUsesDominatedByBlock(Inst, SuccToSinkTo, DT); } /// SinkInstruction - Determine whether it is safe to sink the specified machine /// instruction out of its current block into a successor. -bool Sinking::SinkInstruction(Instruction *Inst, - SmallPtrSetImpl<Instruction *> &Stores) { +static bool SinkInstruction(Instruction *Inst, + SmallPtrSetImpl<Instruction *> &Stores, + DominatorTree &DT, LoopInfo &LI, AAResults &AA) { // Don't sink static alloca instructions. CodeGen assumes allocas outside the // entry block are dynamically sized stack objects. @@ -257,12 +165,12 @@ bool Sinking::SinkInstruction(Instruction *Inst, // Instructions can only be sunk if all their uses are in blocks // dominated by one of the successors. // Look at all the postdominators and see if we can sink it in one. 
- DomTreeNode *DTN = DT->getNode(Inst->getParent()); + DomTreeNode *DTN = DT.getNode(Inst->getParent()); for (DomTreeNode::iterator I = DTN->begin(), E = DTN->end(); I != E && SuccToSinkTo == nullptr; ++I) { BasicBlock *Candidate = (*I)->getBlock(); if ((*I)->getIDom()->getBlock() == Inst->getParent() && - IsAcceptableTarget(Inst, Candidate)) + IsAcceptableTarget(Inst, Candidate, DT, LI)) SuccToSinkTo = Candidate; } @@ -270,7 +178,7 @@ bool Sinking::SinkInstruction(Instruction *Inst, // decide which one we should sink to, if any. for (succ_iterator I = succ_begin(Inst->getParent()), E = succ_end(Inst->getParent()); I != E && !SuccToSinkTo; ++I) { - if (IsAcceptableTarget(Inst, *I)) + if (IsAcceptableTarget(Inst, *I, DT, LI)) SuccToSinkTo = *I; } @@ -288,3 +196,111 @@ bool Sinking::SinkInstruction(Instruction *Inst, Inst->moveBefore(&*SuccToSinkTo->getFirstInsertionPt()); return true; } + +static bool ProcessBlock(BasicBlock &BB, DominatorTree &DT, LoopInfo &LI, + AAResults &AA) { + // Can't sink anything out of a block that has less than two successors. + if (BB.getTerminator()->getNumSuccessors() <= 1) return false; + + // Don't bother sinking code out of unreachable blocks. In addition to being + // unprofitable, it can also lead to infinite looping, because in an + // unreachable loop there may be nowhere to stop. + if (!DT.isReachableFromEntry(&BB)) return false; + + bool MadeChange = false; + + // Walk the basic block bottom-up. Remember if we saw a store. + BasicBlock::iterator I = BB.end(); + --I; + bool ProcessedBegin = false; + SmallPtrSet<Instruction *, 8> Stores; + do { + Instruction *Inst = &*I; // The instruction to sink. + + // Predecrement I (if it's not begin) so that it isn't invalidated by + // sinking. + ProcessedBegin = I == BB.begin(); + if (!ProcessedBegin) + --I; + + if (isa<DbgInfoIntrinsic>(Inst)) + continue; + + if (SinkInstruction(Inst, Stores, DT, LI, AA)) { + ++NumSunk; + MadeChange = true; + } + + // If we just processed the first instruction in the block, we're done. + } while (!ProcessedBegin); + + return MadeChange; +} + +static bool iterativelySinkInstructions(Function &F, DominatorTree &DT, + LoopInfo &LI, AAResults &AA) { + bool MadeChange, EverMadeChange = false; + + do { + MadeChange = false; + DEBUG(dbgs() << "Sinking iteration " << NumSinkIter << "\n"); + // Process all basic blocks. 
+ for (BasicBlock &I : F) + MadeChange |= ProcessBlock(I, DT, LI, AA); + EverMadeChange |= MadeChange; + NumSinkIter++; + } while (MadeChange); + + return EverMadeChange; +} + +PreservedAnalyses SinkingPass::run(Function &F, AnalysisManager<Function> &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + + if (!iterativelySinkInstructions(F, DT, LI, AA)) + return PreservedAnalyses::all(); + + auto PA = PreservedAnalyses(); + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + return PA; +} + +namespace { + class SinkingLegacyPass : public FunctionPass { + public: + static char ID; // Pass identification + SinkingLegacyPass() : FunctionPass(ID) { + initializeSinkingLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + + return iterativelySinkInstructions(F, DT, LI, AA); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + FunctionPass::getAnalysisUsage(AU); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + } + }; +} // end anonymous namespace + +char SinkingLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(SinkingLegacyPass, "sink", "Code sinking", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(SinkingLegacyPass, "sink", "Code sinking", false, false) + +FunctionPass *llvm::createSinkingPass() { return new SinkingLegacyPass(); } diff --git a/lib/Transforms/Scalar/SpeculativeExecution.cpp b/lib/Transforms/Scalar/SpeculativeExecution.cpp index 147d615488ff..9bf2d6206819 100644 --- a/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -50,9 +50,19 @@ // aggressive speculation while counting on later passes to either capitalize on // that or clean it up. // +// If the pass was created by calling +// createSpeculativeExecutionIfHasBranchDivergencePass or the +// -spec-exec-only-if-divergent-target option is present, this pass only has an +// effect on targets where TargetTransformInfo::hasBranchDivergence() is true; +// on other targets, it is a nop. +// +// This lets you include this pass unconditionally in the IR pass pipeline, but +// only enable it for relevant targets. 
+// //===----------------------------------------------------------------------===// #include "llvm/ADT/SmallSet.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Instructions.h" @@ -83,19 +93,39 @@ static cl::opt<unsigned> SpecExecMaxNotHoisted( "number of instructions that would not be speculatively executed " "exceeds this limit.")); +static cl::opt<bool> SpecExecOnlyIfDivergentTarget( + "spec-exec-only-if-divergent-target", cl::init(false), cl::Hidden, + cl::desc("Speculative execution is applied only to targets with divergent " + "branches, even if the pass was configured to apply only to all " + "targets.")); + namespace { + class SpeculativeExecution : public FunctionPass { public: - static char ID; - SpeculativeExecution(): FunctionPass(ID) {} + static char ID; + explicit SpeculativeExecution(bool OnlyIfDivergentTarget = false) + : FunctionPass(ID), + OnlyIfDivergentTarget(OnlyIfDivergentTarget || + SpecExecOnlyIfDivergentTarget) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnFunction(Function &F) override; - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; + const char *getPassName() const override { + if (OnlyIfDivergentTarget) + return "Speculatively execute instructions if target has divergent " + "branches"; + return "Speculatively execute instructions"; + } private: bool runOnBasicBlock(BasicBlock &B); bool considerHoistingFromTo(BasicBlock &FromBlock, BasicBlock &ToBlock); + // If true, this pass is a nop unless the target architecture has branch + // divergence. + const bool OnlyIfDivergentTarget; const TargetTransformInfo *TTI = nullptr; }; } // namespace @@ -105,17 +135,23 @@ INITIALIZE_PASS_BEGIN(SpeculativeExecution, "speculative-execution", "Speculatively execute instructions", false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(SpeculativeExecution, "speculative-execution", - "Speculatively execute instructions", false, false) + "Speculatively execute instructions", false, false) void SpeculativeExecution::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); } bool SpeculativeExecution::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence()) { + DEBUG(dbgs() << "Not running SpeculativeExecution because " + "TTI->hasBranchDivergence() is false.\n"); + return false; + } bool Changed = false; for (auto& B : F) { @@ -240,4 +276,8 @@ FunctionPass *createSpeculativeExecutionPass() { return new SpeculativeExecution(); } +FunctionPass *createSpeculativeExecutionIfHasBranchDivergencePass() { + return new SpeculativeExecution(/* OnlyIfDivergentTarget = */ true); +} + } // namespace llvm diff --git a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 1faa65eb3417..292d0400a516 100644 --- a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -57,8 +57,6 @@ // SLSR. 
#include <vector> -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/FoldingSet.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -76,6 +74,8 @@ using namespace PatternMatch; namespace { +static const unsigned UnknownAddressSpace = ~0u; + class StraightLineStrengthReduce : public FunctionPass { public: // SLSR candidate. Such a candidate must be in one of the forms described in @@ -234,51 +234,22 @@ bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis, Basis.CandidateKind == C.CandidateKind); } -// TODO: use TTI->getGEPCost. static bool isGEPFoldable(GetElementPtrInst *GEP, - const TargetTransformInfo *TTI, - const DataLayout *DL) { - GlobalVariable *BaseGV = nullptr; - int64_t BaseOffset = 0; - bool HasBaseReg = false; - int64_t Scale = 0; - - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getPointerOperand())) - BaseGV = GV; - else - HasBaseReg = true; - - gep_type_iterator GTI = gep_type_begin(GEP); - for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I, ++GTI) { - if (isa<SequentialType>(*GTI)) { - int64_t ElementSize = DL->getTypeAllocSize(GTI.getIndexedType()); - if (ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I)) { - BaseOffset += ConstIdx->getSExtValue() * ElementSize; - } else { - // Needs scale register. - if (Scale != 0) { - // No addressing mode takes two scale registers. - return false; - } - Scale = ElementSize; - } - } else { - StructType *STy = cast<StructType>(*GTI); - uint64_t Field = cast<ConstantInt>(*I)->getZExtValue(); - BaseOffset += DL->getStructLayout(STy)->getElementOffset(Field); - } - } - - unsigned AddrSpace = GEP->getPointerAddressSpace(); - return TTI->isLegalAddressingMode(GEP->getType()->getElementType(), BaseGV, - BaseOffset, HasBaseReg, Scale, AddrSpace); + const TargetTransformInfo *TTI) { + SmallVector<const Value*, 4> Indices; + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) + Indices.push_back(*I); + return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(), + Indices) == TargetTransformInfo::TCC_Free; } // Returns whether (Base + Index * Stride) can be folded to an addressing mode. static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI) { - return TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true, - Index->getSExtValue()); + // Index->getSExtValue() may crash if Index is wider than 64-bit. 
+ return Index->getBitWidth() <= 64 && + TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true, + Index->getSExtValue(), UnknownAddressSpace); } bool StraightLineStrengthReduce::isFoldable(const Candidate &C, @@ -287,7 +258,7 @@ bool StraightLineStrengthReduce::isFoldable(const Candidate &C, if (C.CandidateKind == Candidate::Add) return isAddFoldable(C.Base, C.Index, C.Stride, TTI); if (C.CandidateKind == Candidate::GEP) - return isGEPFoldable(cast<GetElementPtrInst>(C.Ins), TTI, DL); + return isGEPFoldable(cast<GetElementPtrInst>(C.Ins), TTI); return false; } @@ -533,13 +504,23 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( IndexExprs, GEP->isInBounds()); Value *ArrayIdx = GEP->getOperand(I); uint64_t ElementSize = DL->getTypeAllocSize(*GTI); - factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP); + if (ArrayIdx->getType()->getIntegerBitWidth() <= + DL->getPointerSizeInBits(GEP->getAddressSpace())) { + // Skip factoring if ArrayIdx is wider than the pointer size, because + // ArrayIdx is implicitly truncated to the pointer size. + factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP); + } // When ArrayIdx is the sext of a value, we try to factor that value as // well. Handling this case is important because array indices are // typically sign-extended to the pointer size. Value *TruncatedArrayIdx = nullptr; - if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx)))) + if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))) && + TruncatedArrayIdx->getType()->getIntegerBitWidth() <= + DL->getPointerSizeInBits(GEP->getAddressSpace())) { + // Skip factoring if TruncatedArrayIdx is wider than the pointer size, + // because TruncatedArrayIdx is implicitly truncated to the pointer size. factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP); + } IndexExprs[I - 1] = OrigIndexExpr; } @@ -567,10 +548,10 @@ Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis, APInt ElementSize( IndexOffset.getBitWidth(), DL->getTypeAllocSize( - cast<GetElementPtrInst>(Basis.Ins)->getType()->getElementType())); + cast<GetElementPtrInst>(Basis.Ins)->getResultElementType())); APInt Q, R; APInt::sdivrem(IndexOffset, ElementSize, Q, R); - if (R.getSExtValue() == 0) + if (R == 0) IndexOffset = Q; else BumpWithUglyGEP = true; @@ -578,10 +559,10 @@ Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis, // Compute Bump = C - Basis = (i' - i) * S. // Common case 1: if (i' - i) is 1, Bump = S. - if (IndexOffset.getSExtValue() == 1) + if (IndexOffset == 1) return C.Stride; // Common case 2: if (i' - i) is -1, Bump = -S. - if (IndexOffset.getSExtValue() == -1) + if (IndexOffset.isAllOnesValue()) return Builder.CreateNeg(C.Stride); // Otherwise, Bump = (i' - i) * sext/trunc(S). 
Note that (i' - i) and S may @@ -685,7 +666,7 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis( } bool StraightLineStrengthReduce::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) + if (skipFunction(F)) return false; TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); diff --git a/lib/Transforms/Scalar/StructurizeCFG.cpp b/lib/Transforms/Scalar/StructurizeCFG.cpp index 662513c7d8ae..e9ac39beae5a 100644 --- a/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -11,6 +11,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" @@ -161,6 +162,9 @@ public: /// consist of a network of PHI nodes where the true incoming values expresses /// breaks and the false values expresses continue states. class StructurizeCFG : public RegionPass { + bool SkipUniformRegions; + DivergenceAnalysis *DA; + Type *Boolean; ConstantInt *BoolTrue; ConstantInt *BoolFalse; @@ -232,11 +236,18 @@ class StructurizeCFG : public RegionPass { void rebuildSSA(); + bool hasOnlyUniformBranches(const Region *R); + public: static char ID; StructurizeCFG() : - RegionPass(ID) { + RegionPass(ID), SkipUniformRegions(false) { + initializeStructurizeCFGPass(*PassRegistry::getPassRegistry()); + } + + StructurizeCFG(bool SkipUniformRegions) : + RegionPass(ID), SkipUniformRegions(SkipUniformRegions) { initializeStructurizeCFGPass(*PassRegistry::getPassRegistry()); } @@ -250,6 +261,8 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { + if (SkipUniformRegions) + AU.addRequired<DivergenceAnalysis>(); AU.addRequiredID(LowerSwitchID); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); @@ -264,6 +277,7 @@ char StructurizeCFG::ID = 0; INITIALIZE_PASS_BEGIN(StructurizeCFG, "structurizecfg", "Structurize the CFG", false, false) +INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis) INITIALIZE_PASS_DEPENDENCY(LowerSwitch) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(RegionInfoPass) @@ -297,11 +311,7 @@ void StructurizeCFG::orderNodes() { for (RegionNode *RN : TempOrder) { BasicBlock *BB = RN->getEntry(); Loop *Loop = LI->getLoopFor(BB); - if (!LoopBlocks.count(Loop)) { - LoopBlocks[Loop] = 1; - continue; - } - LoopBlocks[Loop]++; + ++LoopBlocks[Loop]; } unsigned CurrentLoopDepth = 0; @@ -319,11 +329,11 @@ void StructurizeCFG::orderNodes() { // the outer loop. 
RNVector::iterator LoopI = I; - while(LoopBlocks[CurrentLoop]) { + while (unsigned &BlockCount = LoopBlocks[CurrentLoop]) { LoopI++; BasicBlock *LoopBB = (*LoopI)->getEntry(); if (LI->getLoopFor(LoopBB) == CurrentLoop) { - LoopBlocks[CurrentLoop]--; + --BlockCount; Order.push_back(*LoopI); } } @@ -367,14 +377,8 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) { /// \brief Invert the given condition Value *StructurizeCFG::invert(Value *Condition) { // First: Check if it's a constant - if (Condition == BoolTrue) - return BoolFalse; - - if (Condition == BoolFalse) - return BoolTrue; - - if (Condition == BoolUndef) - return BoolUndef; + if (Constant *C = dyn_cast<Constant>(Condition)) + return ConstantExpr::getNot(C); // Second: If the condition is already inverted, return the original value if (match(Condition, m_Not(m_Value(Condition)))) @@ -491,21 +495,21 @@ void StructurizeCFG::collectInfos() { // Reset the visited nodes Visited.clear(); - for (RNVector::reverse_iterator OI = Order.rbegin(), OE = Order.rend(); - OI != OE; ++OI) { + for (RegionNode *RN : reverse(Order)) { - DEBUG(dbgs() << "Visiting: " << - ((*OI)->isSubRegion() ? "SubRegion with entry: " : "") << - (*OI)->getEntry()->getName() << " Loop Depth: " << LI->getLoopDepth((*OI)->getEntry()) << "\n"); + DEBUG(dbgs() << "Visiting: " + << (RN->isSubRegion() ? "SubRegion with entry: " : "") + << RN->getEntry()->getName() << " Loop Depth: " + << LI->getLoopDepth(RN->getEntry()) << "\n"); // Analyze all the conditions leading to a node - gatherPredicates(*OI); + gatherPredicates(RN); // Remember that we've seen this node - Visited.insert((*OI)->getEntry()); + Visited.insert(RN->getEntry()); // Find the last back edges - analyzeLoops(*OI); + analyzeLoops(RN); } } @@ -584,20 +588,18 @@ void StructurizeCFG::addPhiValues(BasicBlock *From, BasicBlock *To) { /// \brief Add the real PHI value as soon as everything is set up void StructurizeCFG::setPhiValues() { SSAUpdater Updater; - for (BB2BBVecMap::iterator AI = AddedPhis.begin(), AE = AddedPhis.end(); - AI != AE; ++AI) { + for (const auto &AddedPhi : AddedPhis) { - BasicBlock *To = AI->first; - BBVector &From = AI->second; + BasicBlock *To = AddedPhi.first; + const BBVector &From = AddedPhi.second; if (!DeletedPhis.count(To)) continue; PhiMap &Map = DeletedPhis[To]; - for (PhiMap::iterator PI = Map.begin(), PE = Map.end(); - PI != PE; ++PI) { + for (const auto &PI : Map) { - PHINode *Phi = PI->first; + PHINode *Phi = PI.first; Value *Undef = UndefValue::get(Phi->getType()); Updater.Initialize(Phi->getType(), ""); Updater.AddAvailableValue(&Func->getEntryBlock(), Undef); @@ -605,22 +607,20 @@ void StructurizeCFG::setPhiValues() { NearestCommonDominator Dominator(DT); Dominator.addBlock(To, false); - for (BBValueVector::iterator VI = PI->second.begin(), - VE = PI->second.end(); VI != VE; ++VI) { + for (const auto &VI : PI.second) { - Updater.AddAvailableValue(VI->first, VI->second); - Dominator.addBlock(VI->first); + Updater.AddAvailableValue(VI.first, VI.second); + Dominator.addBlock(VI.first); } if (!Dominator.wasResultExplicitMentioned()) Updater.AddAvailableValue(Dominator.getResult(), Undef); - for (BBVector::iterator FI = From.begin(), FE = From.end(); - FI != FE; ++FI) { + for (BasicBlock *FI : From) { - int Idx = Phi->getBasicBlockIndex(*FI); + int Idx = Phi->getBasicBlockIndex(FI); assert(Idx != -1); - Phi->setIncomingValue(Idx, Updater.GetValueAtEndOfBlock(*FI)); + Phi->setIncomingValue(Idx, Updater.GetValueAtEndOfBlock(FI)); } } @@ -914,11 +914,48 @@ void 
StructurizeCFG::rebuildSSA() { } } +bool StructurizeCFG::hasOnlyUniformBranches(const Region *R) { + for (const BasicBlock *BB : R->blocks()) { + const BranchInst *Br = dyn_cast<BranchInst>(BB->getTerminator()); + if (!Br || !Br->isConditional()) + continue; + + if (!DA->isUniform(Br->getCondition())) + return false; + DEBUG(dbgs() << "BB: " << BB->getName() << " has uniform terminator\n"); + } + return true; +} + /// \brief Run the transformation for each region found bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { if (R->isTopLevelRegion()) return false; + if (SkipUniformRegions) { + DA = &getAnalysis<DivergenceAnalysis>(); + // TODO: We could probably be smarter here with how we handle sub-regions. + if (hasOnlyUniformBranches(R)) { + DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R << '\n'); + + // Mark all direct child block terminators as having been treated as + // uniform. To account for a possible future in which non-uniform + // sub-regions are treated more cleverly, indirect children are not + // marked as uniform. + MDNode *MD = MDNode::get(R->getEntry()->getParent()->getContext(), {}); + Region::element_iterator E = R->element_end(); + for (Region::element_iterator I = R->element_begin(); I != E; ++I) { + if (I->isSubRegion()) + continue; + + if (Instruction *Term = I->getEntry()->getTerminator()) + Term->setMetadata("structurizecfg.uniform", MD); + } + + return false; + } + } + Func = R->getEntry()->getParent(); ParentRegion = R; @@ -947,7 +984,6 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { return true; } -/// \brief Create the pass -Pass *llvm::createStructurizeCFGPass() { - return new StructurizeCFG(); +Pass *llvm::createStructurizeCFGPass(bool SkipUniformRegions) { + return new StructurizeCFG(SkipUniformRegions); } diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp index 4e84d72ae7bd..d5ff99750370 100644 --- a/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -50,6 +50,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/TailRecursionElimination.h" #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -85,64 +86,9 @@ STATISTIC(NumEliminated, "Number of tail calls removed"); STATISTIC(NumRetDuped, "Number of return duplicated"); STATISTIC(NumAccumAdded, "Number of accumulators introduced"); -namespace { - struct TailCallElim : public FunctionPass { - const TargetTransformInfo *TTI; - - static char ID; // Pass identification, replacement for typeid - TailCallElim() : FunctionPass(ID) { - initializeTailCallElimPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override; - - bool runOnFunction(Function &F) override; - - private: - bool runTRE(Function &F); - bool markTails(Function &F, bool &AllCallsAreTailCalls); - - CallInst *FindTRECandidate(Instruction *I, - bool CannotTailCallElimCallsMarkedTail); - bool EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, - BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail); - bool FoldReturnAndProcessPred(BasicBlock *BB, - ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail); - bool 
ProcessReturningBlock(ReturnInst *RI, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail); - bool CanMoveAboveCall(Instruction *I, CallInst *CI); - Value *CanTransformAccumulatorRecursion(Instruction *I, CallInst *CI); - }; -} - -char TailCallElim::ID = 0; -INITIALIZE_PASS_BEGIN(TailCallElim, "tailcallelim", - "Tail Call Elimination", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(TailCallElim, "tailcallelim", - "Tail Call Elimination", false, false) - -// Public interface to the TailCallElimination pass -FunctionPass *llvm::createTailCallEliminationPass() { - return new TailCallElim(); -} - -void TailCallElim::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); -} - /// \brief Scan the specified function for alloca instructions. /// If it contains any dynamic allocas, returns false. -static bool CanTRE(Function &F) { +static bool canTRE(Function &F) { // Because of PR962, we don't TRE dynamic allocas. for (auto &BB : F) { for (auto &I : BB) { @@ -156,20 +102,6 @@ static bool CanTRE(Function &F) { return true; } -bool TailCallElim::runOnFunction(Function &F) { - if (skipOptnoneFunction(F)) - return false; - - if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true") - return false; - - bool AllCallsAreTailCalls = false; - bool Modified = markTails(F, AllCallsAreTailCalls); - if (AllCallsAreTailCalls) - Modified |= runTRE(F); - return Modified; -} - namespace { struct AllocaDerivedValueTracker { // Start at a root value and walk its use-def chain to mark calls that use the @@ -250,7 +182,7 @@ struct AllocaDerivedValueTracker { }; } -bool TailCallElim::markTails(Function &F, bool &AllCallsAreTailCalls) { +static bool markTails(Function &F, bool &AllCallsAreTailCalls) { if (F.callsFunctionThatReturnsTwice()) return false; AllCallsAreTailCalls = true; @@ -385,63 +317,11 @@ bool TailCallElim::markTails(Function &F, bool &AllCallsAreTailCalls) { return Modified; } -bool TailCallElim::runTRE(Function &F) { - // If this function is a varargs function, we won't be able to PHI the args - // right, so don't even try to convert it... - if (F.getFunctionType()->isVarArg()) return false; - - TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - BasicBlock *OldEntry = nullptr; - bool TailCallsAreMarkedTail = false; - SmallVector<PHINode*, 8> ArgumentPHIs; - bool MadeChange = false; - - // If false, we cannot perform TRE on tail calls marked with the 'tail' - // attribute, because doing so would cause the stack size to increase (real - // TRE would deallocate variable sized allocas, TRE doesn't). - bool CanTRETailMarkedCall = CanTRE(F); - - // Change any tail recursive calls to loops. - // - // FIXME: The code generator produces really bad code when an 'escaping - // alloca' is changed from being a static alloca to being a dynamic alloca. - // Until this is resolved, disable this transformation if that would ever - // happen. This bug is PR962. - for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) { - BasicBlock *BB = &*BBI++; // FoldReturnAndProcessPred may delete BB. 
- if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) { - bool Change = ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, !CanTRETailMarkedCall); - if (!Change && BB->getFirstNonPHIOrDbg() == Ret) - Change = FoldReturnAndProcessPred(BB, Ret, OldEntry, - TailCallsAreMarkedTail, ArgumentPHIs, - !CanTRETailMarkedCall); - MadeChange |= Change; - } - } - - // If we eliminated any tail recursions, it's possible that we inserted some - // silly PHI nodes which just merge an initial value (the incoming operand) - // with themselves. Check to see if we did and clean up our mess if so. This - // occurs when a function passes an argument straight through to its tail - // call. - for (PHINode *PN : ArgumentPHIs) { - // If the PHI Node is a dynamic constant, replace it with the value it is. - if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { - PN->replaceAllUsesWith(PNV); - PN->eraseFromParent(); - } - } - - return MadeChange; -} - - /// Return true if it is safe to move the specified /// instruction from after the call to before the call, assuming that all /// instructions between the call and this instruction are movable. /// -bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) { +static bool canMoveAboveCall(Instruction *I, CallInst *CI) { // FIXME: We can move load/store/call/free instructions above the call if the // call does not mod/ref the memory location being processed. if (I->mayHaveSideEffects()) // This also handles volatile loads. @@ -454,9 +334,10 @@ bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) { // does not write to memory and the load provably won't trap. // FIXME: Writes to memory only matter if they may alias the pointer // being loaded from. + const DataLayout &DL = L->getModule()->getDataLayout(); if (CI->mayWriteToMemory() || - !isSafeToLoadUnconditionally(L->getPointerOperand(), L, - L->getAlignment())) + !isSafeToLoadUnconditionally(L->getPointerOperand(), + L->getAlignment(), DL, L)) return false; } } @@ -512,8 +393,8 @@ static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) { Function *F = CI->getParent()->getParent(); Value *ReturnedValue = nullptr; - for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI) { - ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator()); + for (BasicBlock &BBI : *F) { + ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()); if (RI == nullptr || RI == IgnoreRI) continue; // We can only perform this transformation if the value returned is @@ -534,8 +415,7 @@ static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) { /// If the specified instruction can be transformed using accumulator recursion /// elimination, return the constant which is the start of the accumulator /// value. Otherwise return null. 
-Value *TailCallElim::CanTransformAccumulatorRecursion(Instruction *I, - CallInst *CI) { +static Value *canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { if (!I->isAssociative() || !I->isCommutative()) return nullptr; assert(I->getNumOperands() == 2 && "Associative/commutative operations should have 2 args!"); @@ -555,15 +435,15 @@ Value *TailCallElim::CanTransformAccumulatorRecursion(Instruction *I, return getCommonReturnValue(cast<ReturnInst>(I->user_back()), CI); } -static Instruction *FirstNonDbg(BasicBlock::iterator I) { +static Instruction *firstNonDbg(BasicBlock::iterator I) { while (isa<DbgInfoIntrinsic>(I)) ++I; return &*I; } -CallInst* -TailCallElim::FindTRECandidate(Instruction *TI, - bool CannotTailCallElimCallsMarkedTail) { +static CallInst *findTRECandidate(Instruction *TI, + bool CannotTailCallElimCallsMarkedTail, + const TargetTransformInfo *TTI) { BasicBlock *BB = TI->getParent(); Function *F = BB->getParent(); @@ -594,8 +474,8 @@ TailCallElim::FindTRECandidate(Instruction *TI, // and disable this xform in this case, because the code generator will // lower the call to fabs into inline code. if (BB == &F->getEntryBlock() && - FirstNonDbg(BB->front().getIterator()) == CI && - FirstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() && + firstNonDbg(BB->front().getIterator()) == CI && + firstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() && !TTI->isLoweredToCall(CI->getCalledFunction())) { // A single-block function with just a call and a return. Check that // the arguments match. @@ -612,7 +492,7 @@ TailCallElim::FindTRECandidate(Instruction *TI, return CI; } -bool TailCallElim::EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, +static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, @@ -636,14 +516,14 @@ bool TailCallElim::EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, // Check that this is the case now. BasicBlock::iterator BBI(CI); for (++BBI; &*BBI != Ret; ++BBI) { - if (CanMoveAboveCall(&*BBI, CI)) continue; + if (canMoveAboveCall(&*BBI, CI)) continue; // If we can't move the instruction above the call, it might be because it // is an associative and commutative operation that could be transformed // using accumulator recursion elimination. Check to see if this is the // case, and if so, remember the initial accumulator value for later. if ((AccumulatorRecursionEliminationInitVal = - CanTransformAccumulatorRecursion(&*BBI, CI))) { + canTransformAccumulatorRecursion(&*BBI, CI))) { // Yes, this is accumulator recursion. Remember which instruction // accumulates. AccumulatorRecursionInstr = &*BBI; @@ -773,8 +653,8 @@ bool TailCallElim::EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, // Finally, rewrite any return instructions in the program to return the PHI // node instead of the "initval" that they do currently. This loop will // actually rewrite the return value we are destroying, but that's ok. 
- for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI) - if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator())) + for (BasicBlock &BBI : *F) + if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator())) RI->setOperand(0, AccPN); ++NumAccumAdded; } @@ -790,11 +670,12 @@ bool TailCallElim::EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, return true; } -bool TailCallElim::FoldReturnAndProcessPred(BasicBlock *BB, - ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail) { +static bool foldReturnAndProcessPred(BasicBlock *BB, ReturnInst *Ret, + BasicBlock *&OldEntry, + bool &TailCallsAreMarkedTail, + SmallVectorImpl<PHINode *> &ArgumentPHIs, + bool CannotTailCallElimCallsMarkedTail, + const TargetTransformInfo *TTI) { bool Change = false; // If the return block contains nothing but the return and PHI's, @@ -813,7 +694,7 @@ bool TailCallElim::FoldReturnAndProcessPred(BasicBlock *BB, while (!UncondBranchPreds.empty()) { BranchInst *BI = UncondBranchPreds.pop_back_val(); BasicBlock *Pred = BI->getParent(); - if (CallInst *CI = FindTRECandidate(BI, CannotTailCallElimCallsMarkedTail)){ + if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){ DEBUG(dbgs() << "FOLDING: " << *BB << "INTO UNCOND BRANCH PRED: " << *Pred); ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred); @@ -821,11 +702,11 @@ bool TailCallElim::FoldReturnAndProcessPred(BasicBlock *BB, // Cleanup: if all predecessors of BB have been eliminated by // FoldReturnIntoUncondBranch, delete it. It is important to empty it, // because the ret instruction in there is still using a value which - // EliminateRecursiveTailCall will attempt to remove. + // eliminateRecursiveTailCall will attempt to remove. if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB)) BB->eraseFromParent(); - EliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, + eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, CannotTailCallElimCallsMarkedTail); ++NumRetDuped; @@ -836,16 +717,124 @@ bool TailCallElim::FoldReturnAndProcessPred(BasicBlock *BB, return Change; } -bool -TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail) { - CallInst *CI = FindTRECandidate(Ret, CannotTailCallElimCallsMarkedTail); +static bool processReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, + bool &TailCallsAreMarkedTail, + SmallVectorImpl<PHINode *> &ArgumentPHIs, + bool CannotTailCallElimCallsMarkedTail, + const TargetTransformInfo *TTI) { + CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI); if (!CI) return false; - return EliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, + return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, CannotTailCallElimCallsMarkedTail); } + +static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI) { + if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true") + return false; + + bool MadeChange = false; + bool AllCallsAreTailCalls = false; + MadeChange |= markTails(F, AllCallsAreTailCalls); + if (!AllCallsAreTailCalls) + return MadeChange; + + // If this function is a varargs function, we won't be able to PHI the args + // right, so don't even try to convert it... 
+ if (F.getFunctionType()->isVarArg()) + return false; + + BasicBlock *OldEntry = nullptr; + bool TailCallsAreMarkedTail = false; + SmallVector<PHINode*, 8> ArgumentPHIs; + + // If false, we cannot perform TRE on tail calls marked with the 'tail' + // attribute, because doing so would cause the stack size to increase (real + // TRE would deallocate variable sized allocas, TRE doesn't). + bool CanTRETailMarkedCall = canTRE(F); + + // Change any tail recursive calls to loops. + // + // FIXME: The code generator produces really bad code when an 'escaping + // alloca' is changed from being a static alloca to being a dynamic alloca. + // Until this is resolved, disable this transformation if that would ever + // happen. This bug is PR962. + for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) { + BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB. + if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) { + bool Change = + processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, + ArgumentPHIs, !CanTRETailMarkedCall, TTI); + if (!Change && BB->getFirstNonPHIOrDbg() == Ret) + Change = + foldReturnAndProcessPred(BB, Ret, OldEntry, TailCallsAreMarkedTail, + ArgumentPHIs, !CanTRETailMarkedCall, TTI); + MadeChange |= Change; + } + } + + // If we eliminated any tail recursions, it's possible that we inserted some + // silly PHI nodes which just merge an initial value (the incoming operand) + // with themselves. Check to see if we did and clean up our mess if so. This + // occurs when a function passes an argument straight through to its tail + // call. + for (PHINode *PN : ArgumentPHIs) { + // If the PHI Node is a dynamic constant, replace it with the value it is. + if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { + PN->replaceAllUsesWith(PNV); + PN->eraseFromParent(); + } + } + + return MadeChange; +} + +namespace { +struct TailCallElim : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + TailCallElim() : FunctionPass(ID) { + initializeTailCallElimPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + return eliminateTailRecursion( + F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F)); + } +}; +} + +char TailCallElim::ID = 0; +INITIALIZE_PASS_BEGIN(TailCallElim, "tailcallelim", "Tail Call Elimination", + false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(TailCallElim, "tailcallelim", "Tail Call Elimination", + false, false) + +// Public interface to the TailCallElimination pass +FunctionPass *llvm::createTailCallEliminationPass() { + return new TailCallElim(); +} + +PreservedAnalyses TailCallElimPass::run(Function &F, + FunctionAnalysisManager &AM) { + + TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); + + bool Changed = eliminateTailRecursion(F, &TTI); + + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; +} diff --git a/lib/Transforms/Utils/ASanStackFrameLayout.cpp b/lib/Transforms/Utils/ASanStackFrameLayout.cpp index 409326eba401..7e50d4bb447e 100644 --- a/lib/Transforms/Utils/ASanStackFrameLayout.cpp +++ b/lib/Transforms/Utils/ASanStackFrameLayout.cpp @@ -44,7 +44,7 @@ static size_t VarAndRedzoneSize(size_t Size, 
size_t Alignment) { else if (Size <= 512) Res = Size + 64; else if (Size <= 4096) Res = Size + 128; else Res = Size + 256; - return RoundUpToAlignment(Res, Alignment); + return alignTo(Res, Alignment); } void diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index 0262358fa3d5..d034905b6572 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -52,7 +52,9 @@ // http://wiki.dwarfstd.org/index.php?title=Path_Discriminators //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/AddDiscriminators.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" @@ -72,20 +74,22 @@ using namespace llvm; #define DEBUG_TYPE "add-discriminators" namespace { -struct AddDiscriminators : public FunctionPass { +// The legacy pass of AddDiscriminators. +struct AddDiscriminatorsLegacyPass : public FunctionPass { static char ID; // Pass identification, replacement for typeid - AddDiscriminators() : FunctionPass(ID) { - initializeAddDiscriminatorsPass(*PassRegistry::getPassRegistry()); + AddDiscriminatorsLegacyPass() : FunctionPass(ID) { + initializeAddDiscriminatorsLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; }; -} -char AddDiscriminators::ID = 0; -INITIALIZE_PASS_BEGIN(AddDiscriminators, "add-discriminators", +} // end anonymous namespace + +char AddDiscriminatorsLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(AddDiscriminatorsLegacyPass, "add-discriminators", "Add DWARF path discriminators", false, false) -INITIALIZE_PASS_END(AddDiscriminators, "add-discriminators", +INITIALIZE_PASS_END(AddDiscriminatorsLegacyPass, "add-discriminators", "Add DWARF path discriminators", false, false) // Command line option to disable discriminator generation even in the @@ -95,13 +99,9 @@ static cl::opt<bool> NoDiscriminators( "no-discriminators", cl::init(false), cl::desc("Disable generation of discriminator information.")); +// Create the legacy AddDiscriminatorsPass. FunctionPass *llvm::createAddDiscriminatorsPass() { - return new AddDiscriminators(); -} - -static bool hasDebugInfo(const Function &F) { - DISubprogram *S = getDISubprogram(&F); - return S != nullptr; + return new AddDiscriminatorsLegacyPass(); } /// \brief Assign DWARF discriminators. @@ -155,13 +155,13 @@ static bool hasDebugInfo(const Function &F) { /// lexical block for I2 and all the instruction in B2 that share the same /// file and line location as I2. This new lexical block will have a /// different discriminator number than I1. -bool AddDiscriminators::runOnFunction(Function &F) { +static bool addDiscriminators(Function &F) { // If the function has debug information, but the user has disabled // discriminators, do nothing. // Simlarly, if the function has no debug info, do nothing. // Finally, if this module is built with dwarf versions earlier than 4, // do nothing (discriminator support is a DWARF 4 feature). 
- if (NoDiscriminators || !hasDebugInfo(F) || + if (NoDiscriminators || !F.getSubprogram() || F.getParent()->getDwarfVersion() < 4) return false; @@ -173,8 +173,11 @@ bool AddDiscriminators::runOnFunction(Function &F) { typedef std::pair<StringRef, unsigned> Location; typedef DenseMap<const BasicBlock *, Metadata *> BBScopeMap; typedef DenseMap<Location, BBScopeMap> LocationBBMap; + typedef DenseMap<Location, unsigned> LocationDiscriminatorMap; + typedef DenseSet<Location> LocationSet; LocationBBMap LBM; + LocationDiscriminatorMap LDM; // Traverse all instructions in the function. If the source line location // of the instruction appears in other basic block, assign a new @@ -199,8 +202,7 @@ bool AddDiscriminators::runOnFunction(Function &F) { auto *Scope = DIL->getScope(); auto *File = Builder.createFile(DIL->getFilename(), Scope->getDirectory()); - NewScope = Builder.createLexicalBlockFile( - Scope, File, DIL->computeNewDiscriminator()); + NewScope = Builder.createLexicalBlockFile(Scope, File, ++LDM[L]); } I.setDebugLoc(DILocation::get(Ctx, DIL->getLine(), DIL->getColumn(), NewScope, DIL->getInlinedAt())); @@ -217,32 +219,40 @@ bool AddDiscriminators::runOnFunction(Function &F) { // Sample base profile needs to distinguish different function calls within // a same source line for correct profile annotation. for (BasicBlock &B : F) { - const DILocation *FirstDIL = NULL; + LocationSet CallLocations; for (auto &I : B.getInstList()) { CallInst *Current = dyn_cast<CallInst>(&I); if (!Current || isa<DbgInfoIntrinsic>(&I)) continue; DILocation *CurrentDIL = Current->getDebugLoc(); - if (FirstDIL) { - if (CurrentDIL && CurrentDIL->getLine() == FirstDIL->getLine() && - CurrentDIL->getFilename() == FirstDIL->getFilename()) { - auto *Scope = FirstDIL->getScope(); - auto *File = Builder.createFile(FirstDIL->getFilename(), - Scope->getDirectory()); - auto *NewScope = Builder.createLexicalBlockFile( - Scope, File, FirstDIL->computeNewDiscriminator()); - Current->setDebugLoc(DILocation::get( - Ctx, CurrentDIL->getLine(), CurrentDIL->getColumn(), NewScope, - CurrentDIL->getInlinedAt())); - Changed = true; - } else { - FirstDIL = CurrentDIL; - } - } else { - FirstDIL = CurrentDIL; + if (!CurrentDIL) + continue; + Location L = + std::make_pair(CurrentDIL->getFilename(), CurrentDIL->getLine()); + if (!CallLocations.insert(L).second) { + auto *Scope = CurrentDIL->getScope(); + auto *File = Builder.createFile(CurrentDIL->getFilename(), + Scope->getDirectory()); + auto *NewScope = Builder.createLexicalBlockFile(Scope, File, ++LDM[L]); + Current->setDebugLoc(DILocation::get(Ctx, CurrentDIL->getLine(), + CurrentDIL->getColumn(), NewScope, + CurrentDIL->getInlinedAt())); + Changed = true; } } } return Changed; } + +bool AddDiscriminatorsLegacyPass::runOnFunction(Function &F) { + return addDiscriminators(F); +} +PreservedAnalyses AddDiscriminatorsPass::run(Function &F, + AnalysisManager<Function> &AM) { + if (!addDiscriminators(F)) + return PreservedAnalyses::all(); + + // FIXME: should be all() + return PreservedAnalyses::none(); +} diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 72db980cf572..b90349d3cdad 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -31,8 +31,6 @@ #include <algorithm> using namespace llvm; -/// DeleteDeadBlock - Delete the specified block, which must have no -/// predecessors. 
void llvm::DeleteDeadBlock(BasicBlock *BB) { assert((pred_begin(BB) == pred_end(BB) || // Can delete self loop. @@ -61,12 +59,8 @@ void llvm::DeleteDeadBlock(BasicBlock *BB) { BB->eraseFromParent(); } -/// FoldSingleEntryPHINodes - We know that BB has one predecessor. If there are -/// any single-entry PHI nodes in it, fold them away. This handles the case -/// when all entries to the PHI nodes in a block are guaranteed equal, such as -/// when the block has exactly one predecessor. void llvm::FoldSingleEntryPHINodes(BasicBlock *BB, - MemoryDependenceAnalysis *MemDep) { + MemoryDependenceResults *MemDep) { if (!isa<PHINode>(BB->begin())) return; while (PHINode *PN = dyn_cast<PHINode>(BB->begin())) { @@ -82,11 +76,6 @@ void llvm::FoldSingleEntryPHINodes(BasicBlock *BB, } } - -/// DeleteDeadPHIs - Examine each PHI in the given block and delete it if it -/// is dead. Also recursively delete any operands that become dead as -/// a result. This includes tracing the def-use list from the PHI to see if -/// it is ultimately unused or if it reaches an unused cycle. bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI) { // Recursively deleting a PHI may cause multiple PHIs to be deleted // or RAUW'd undef, so use an array of WeakVH for the PHIs to delete. @@ -103,11 +92,9 @@ bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI) { return Changed; } -/// MergeBlockIntoPredecessor - Attempts to merge a block into its predecessor, -/// if possible. The return value indicates success or failure. bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, LoopInfo *LI, - MemoryDependenceAnalysis *MemDep) { + MemoryDependenceResults *MemDep) { // Don't merge away blocks who have their address taken. if (BB->hasAddressTaken()) return false; @@ -165,10 +152,8 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, if (DomTreeNode *DTN = DT->getNode(BB)) { DomTreeNode *PredDTN = DT->getNode(PredBB); SmallVector<DomTreeNode *, 8> Children(DTN->begin(), DTN->end()); - for (SmallVectorImpl<DomTreeNode *>::iterator DI = Children.begin(), - DE = Children.end(); - DI != DE; ++DI) - DT->changeImmediateDominator(*DI, PredDTN); + for (DomTreeNode *DI : Children) + DT->changeImmediateDominator(DI, PredDTN); DT->eraseNode(BB); } @@ -183,9 +168,6 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, return true; } -/// ReplaceInstWithValue - Replace all uses of an instruction (specified by BI) -/// with a value, then remove and delete the original instruction. -/// void llvm::ReplaceInstWithValue(BasicBlock::InstListType &BIL, BasicBlock::iterator &BI, Value *V) { Instruction &I = *BI; @@ -200,11 +182,6 @@ void llvm::ReplaceInstWithValue(BasicBlock::InstListType &BIL, BI = BIL.erase(BI); } - -/// ReplaceInstWithInst - Replace the instruction specified by BI with the -/// instruction specified by I. The original instruction is deleted and BI is -/// updated to point to the new instruction. -/// void llvm::ReplaceInstWithInst(BasicBlock::InstListType &BIL, BasicBlock::iterator &BI, Instruction *I) { assert(I->getParent() == nullptr && @@ -225,16 +202,11 @@ void llvm::ReplaceInstWithInst(BasicBlock::InstListType &BIL, BI = New; } -/// ReplaceInstWithInst - Replace the instruction specified by From with the -/// instruction specified by To. 
-/// void llvm::ReplaceInstWithInst(Instruction *From, Instruction *To) { BasicBlock::iterator BI(From); ReplaceInstWithInst(From->getParent()->getInstList(), BI, To); } -/// SplitEdge - Split the edge connecting specified block. Pass P must -/// not be NULL. BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, DominatorTree *DT, LoopInfo *LI) { unsigned SuccNum = GetSuccessorNumber(BB, Succ); @@ -266,8 +238,8 @@ unsigned llvm::SplitAllCriticalEdges(Function &F, const CriticalEdgeSplittingOptions &Options) { unsigned NumBroken = 0; - for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { - TerminatorInst *TI = I->getTerminator(); + for (BasicBlock &BB : F) { + TerminatorInst *TI = BB.getTerminator(); if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI)) for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) if (SplitCriticalEdge(TI, i, Options)) @@ -276,11 +248,6 @@ llvm::SplitAllCriticalEdges(Function &F, return NumBroken; } -/// SplitBlock - Split the specified block at the specified instruction - every -/// thing before SplitPt stays in Old and everything starting with SplitPt moves -/// to a new block. The two blocks are joined by an unconditional branch and -/// the loop info is updated. -/// BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, DominatorTree *DT, LoopInfo *LI) { BasicBlock::iterator SplitIt = SplitPt->getIterator(); @@ -297,22 +264,17 @@ BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, if (DT) // Old dominates New. New node dominates all other nodes dominated by Old. if (DomTreeNode *OldNode = DT->getNode(Old)) { - std::vector<DomTreeNode *> Children; - for (DomTreeNode::iterator I = OldNode->begin(), E = OldNode->end(); - I != E; ++I) - Children.push_back(*I); + std::vector<DomTreeNode *> Children(OldNode->begin(), OldNode->end()); DomTreeNode *NewNode = DT->addNewBlock(New, Old); - for (std::vector<DomTreeNode *>::iterator I = Children.begin(), - E = Children.end(); I != E; ++I) - DT->changeImmediateDominator(*I, NewNode); + for (DomTreeNode *I : Children) + DT->changeImmediateDominator(I, NewNode); } return New; } -/// UpdateAnalysisInformation - Update DominatorTree, LoopInfo, and LCCSA -/// analysis information. +/// Update DominatorTree, LoopInfo, and LCCSA analysis information. static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, ArrayRef<BasicBlock *> Preds, DominatorTree *DT, LoopInfo *LI, @@ -331,10 +293,7 @@ static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, // this split will affect loops. bool IsLoopEntry = !!L; bool SplitMakesNewLoopHeader = false; - for (ArrayRef<BasicBlock *>::iterator i = Preds.begin(), e = Preds.end(); - i != e; ++i) { - BasicBlock *Pred = *i; - + for (BasicBlock *Pred : Preds) { // If we need to preserve LCSSA, determine if any of the preds is a loop // exit. if (PreserveLCSSA) @@ -362,9 +321,7 @@ static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, // loops enclose them, and select the most-nested loop which contains the // loop containing the block being split. Loop *InnermostPredLoop = nullptr; - for (ArrayRef<BasicBlock*>::iterator - i = Preds.begin(), e = Preds.end(); i != e; ++i) { - BasicBlock *Pred = *i; + for (BasicBlock *Pred : Preds) { if (Loop *PredLoop = LI->getLoopFor(Pred)) { // Seek a loop which actually contains the block being split (to avoid // adjacent loops). 
@@ -388,8 +345,8 @@ static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, } } -/// UpdatePHINodes - Update the PHI nodes in OrigBB to include the values coming -/// from NewBB. This also updates AliasAnalysis, if available. +/// Update the PHI nodes in OrigBB to include the values coming from NewBB. +/// This also updates AliasAnalysis, if available. static void UpdatePHINodes(BasicBlock *OrigBB, BasicBlock *NewBB, ArrayRef<BasicBlock *> Preds, BranchInst *BI, bool HasLoopExit) { @@ -456,21 +413,6 @@ static void UpdatePHINodes(BasicBlock *OrigBB, BasicBlock *NewBB, } } -/// SplitBlockPredecessors - This method introduces at least one new basic block -/// into the function and moves some of the predecessors of BB to be -/// predecessors of the new block. The new predecessors are indicated by the -/// Preds array. The new block is given a suffix of 'Suffix'. Returns new basic -/// block to which predecessors from Preds are now pointing. -/// -/// If BB is a landingpad block then additional basicblock might be introduced. -/// It will have suffix of 'Suffix'+".split_lp". -/// See SplitLandingPadPredecessors for more details on this case. -/// -/// This currently updates the LLVM IR, AliasAnalysis, DominatorTree, -/// LoopInfo, and LCCSA but no other analyses. In particular, it does not -/// preserve LoopSimplify (because it's complicated to handle the case where one -/// of the edges being split is an exit of a loop with other exits). -/// BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, const char *Suffix, DominatorTree *DT, @@ -529,19 +471,6 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, return NewBB; } -/// SplitLandingPadPredecessors - This method transforms the landing pad, -/// OrigBB, by introducing two new basic blocks into the function. One of those -/// new basic blocks gets the predecessors listed in Preds. The other basic -/// block gets the remaining predecessors of OrigBB. The landingpad instruction -/// OrigBB is clone into both of the new basic blocks. The new blocks are given -/// the suffixes 'Suffix1' and 'Suffix2', and are returned in the NewBBs vector. -/// -/// This currently updates the LLVM IR, AliasAnalysis, DominatorTree, -/// DominanceFrontier, LoopInfo, and LCCSA but no other analyses. In particular, -/// it does not preserve LoopSimplify (because it's complicated to handle the -/// case where one of the edges being split is an exit of a loop with other -/// exits). -/// void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, ArrayRef<BasicBlock *> Preds, const char *Suffix1, const char *Suffix2, @@ -603,9 +532,8 @@ void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, BI2->setDebugLoc(OrigBB->getFirstNonPHI()->getDebugLoc()); // Move the remaining edges from OrigBB to point to NewBB2. - for (SmallVectorImpl<BasicBlock*>::iterator - i = NewBB2Preds.begin(), e = NewBB2Preds.end(); i != e; ++i) - (*i)->getTerminator()->replaceUsesOfWith(OrigBB, NewBB2); + for (BasicBlock *NewBB2Pred : NewBB2Preds) + NewBB2Pred->getTerminator()->replaceUsesOfWith(OrigBB, NewBB2); // Update DominatorTree, LoopInfo, and LCCSA analysis information. HasLoopExit = false; @@ -646,11 +574,6 @@ void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, } } -/// FoldReturnIntoUncondBranch - This method duplicates the specified return -/// instruction into a predecessor which ends in an unconditional branch. 
If -/// the return instruction returns a value defined by a PHI, propagate the -/// right value into the return. It returns the new return instruction in the -/// predecessor. ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, BasicBlock *Pred) { Instruction *UncondBranch = Pred->getTerminator(); @@ -689,31 +612,10 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, return cast<ReturnInst>(NewRet); } -/// SplitBlockAndInsertIfThen - Split the containing block at the -/// specified instruction - everything before and including SplitBefore stays -/// in the old basic block, and everything after SplitBefore is moved to a -/// new block. The two blocks are connected by a conditional branch -/// (with value of Cmp being the condition). -/// Before: -/// Head -/// SplitBefore -/// Tail -/// After: -/// Head -/// if (Cond) -/// ThenBlock -/// SplitBefore -/// Tail -/// -/// If Unreachable is true, then ThenBlock ends with -/// UnreachableInst, otherwise it branches to Tail. -/// Returns the NewBasicBlock's terminator. - -TerminatorInst *llvm::SplitBlockAndInsertIfThen(Value *Cond, - Instruction *SplitBefore, - bool Unreachable, - MDNode *BranchWeights, - DominatorTree *DT) { +TerminatorInst * +llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore, + bool Unreachable, MDNode *BranchWeights, + DominatorTree *DT, LoopInfo *LI) { BasicBlock *Head = SplitBefore->getParent(); BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator()); TerminatorInst *HeadOldTerm = Head->getTerminator(); @@ -735,7 +637,7 @@ TerminatorInst *llvm::SplitBlockAndInsertIfThen(Value *Cond, std::vector<DomTreeNode *> Children(OldNode->begin(), OldNode->end()); DomTreeNode *NewNode = DT->addNewBlock(Tail, Head); - for (auto Child : Children) + for (DomTreeNode *Child : Children) DT->changeImmediateDominator(Child, NewNode); // Head dominates ThenBlock. @@ -743,23 +645,15 @@ TerminatorInst *llvm::SplitBlockAndInsertIfThen(Value *Cond, } } + if (LI) { + Loop *L = LI->getLoopFor(Head); + L->addBasicBlockToLoop(ThenBlock, *LI); + L->addBasicBlockToLoop(Tail, *LI); + } + return CheckTerm; } -/// SplitBlockAndInsertIfThenElse is similar to SplitBlockAndInsertIfThen, -/// but also creates the ElseBlock. -/// Before: -/// Head -/// SplitBefore -/// Tail -/// After: -/// Head -/// if (Cond) -/// ThenBlock -/// else -/// ElseBlock -/// SplitBefore -/// Tail void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, TerminatorInst **ThenTerm, TerminatorInst **ElseTerm, @@ -781,15 +675,6 @@ void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, } -/// GetIfCondition - Given a basic block (BB) with two predecessors, -/// check to see if the merge at this block is due -/// to an "if condition". If so, return the boolean condition that determines -/// which entry into BB will be taken. Also, return by references the block -/// that will be entered from if the condition is true, and the block that will -/// be entered if the condition is false. -/// -/// This does no checking to see if the true/false blocks have large or unsavory -/// instructions in them. 
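The reworked SplitBlockAndInsertIfThen() above now also maintains LoopInfo. A hedged usage sketch follows (the null-check instrumentation is invented for illustration, not taken from this patch); note that LI should only be passed when Head is known to sit inside a loop, since the new update code unconditionally adds ThenBlock and Tail to LI->getLoopFor(Head).

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"
    using namespace llvm;

    // Assumed parameters: Inst is the instruction to guard, Ptr the pointer to
    // test, TrapFn a callee of type void().
    static void insertNullCheck(Instruction *Inst, Value *Ptr, Value *TrapFn,
                                DominatorTree *DT, LoopInfo *LI) {
      IRBuilder<> B(Inst);
      Value *IsNull = B.CreateIsNull(Ptr);
      // Head now ends in 'br IsNull, ThenBlock, Tail'; CheckTerm is the branch
      // from ThenBlock back to Tail because Unreachable is false.
      TerminatorInst *CheckTerm = SplitBlockAndInsertIfThen(
          IsNull, Inst, /*Unreachable=*/false, /*BranchWeights=*/nullptr, DT, LI);
      B.SetInsertPoint(CheckTerm);
      B.CreateCall(TrapFn); // executed only on the ThenBlock path
    }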
Value *llvm::GetIfCondition(BasicBlock *BB, BasicBlock *&IfTrue, BasicBlock *&IfFalse) { PHINode *SomePHI = dyn_cast<PHINode>(BB->begin()); diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp index 95825991cee9..49b646a041f5 100644 --- a/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -76,11 +76,10 @@ FunctionPass *llvm::createBreakCriticalEdgesPass() { // Implementation of the external critical edge manipulation functions //===----------------------------------------------------------------------===// -/// createPHIsForSplitLoopExit - When a loop exit edge is split, LCSSA form -/// may require new PHIs in the new exit block. This function inserts the -/// new PHIs, as needed. Preds is a list of preds inside the loop, SplitBB -/// is the new loop exit block, and DestBB is the old loop exit, now the -/// successor of SplitBB. +/// When a loop exit edge is split, LCSSA form may require new PHIs in the new +/// exit block. This function inserts the new PHIs, as needed. Preds is a list +/// of preds inside the loop, SplitBB is the new loop exit block, and DestBB is +/// the old loop exit, now the successor of SplitBB. static void createPHIsForSplitLoopExit(ArrayRef<BasicBlock *> Preds, BasicBlock *SplitBB, BasicBlock *DestBB) { @@ -112,25 +111,9 @@ static void createPHIsForSplitLoopExit(ArrayRef<BasicBlock *> Preds, } } -/// SplitCriticalEdge - If this edge is a critical edge, insert a new node to -/// split the critical edge. This will update DominatorTree information if it -/// is available, thus calling this pass will not invalidate either of them. -/// This returns the new block if the edge was split, null otherwise. -/// -/// If MergeIdenticalEdges is true (not the default), *all* edges from TI to the -/// specified successor will be merged into the same critical edge block. -/// This is most commonly interesting with switch instructions, which may -/// have many edges to any one destination. This ensures that all edges to that -/// dest go to one block instead of each going to a different block, but isn't -/// the standard definition of a "critical edge". -/// -/// It is invalid to call this function on a critical edge that starts at an -/// IndirectBrInst. Splitting these edges will almost always create an invalid -/// program because the address of the new block won't be the one that is jumped -/// to. -/// -BasicBlock *llvm::SplitCriticalEdge(TerminatorInst *TI, unsigned SuccNum, - const CriticalEdgeSplittingOptions &Options) { +BasicBlock * +llvm::SplitCriticalEdge(TerminatorInst *TI, unsigned SuccNum, + const CriticalEdgeSplittingOptions &Options) { if (!isCriticalEdge(TI, SuccNum, Options.MergeIdenticalEdges)) return nullptr; diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index 64b44a6b7919..f4260a9ff980 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -13,6 +13,7 @@ #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -25,81 +26,742 @@ using namespace llvm; -/// CastToCStr - Return V if it is an i8*, otherwise cast it to i8*. 
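The SplitCriticalEdge() entry point kept above (BreakCriticalEdges.cpp hunk) is driven through CriticalEdgeSplittingOptions; a small caller-side sketch, mirroring the loop in SplitAllCriticalEdges and not taken from this patch:

    #include "llvm/Analysis/LoopInfo.h"
    #include "llvm/IR/Dominators.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"
    using namespace llvm;

    // Break every splittable critical edge in F, preserving DT and LI.
    static unsigned breakAllCriticalEdges(Function &F, DominatorTree *DT,
                                          LoopInfo *LI) {
      CriticalEdgeSplittingOptions Options(DT, LI);
      unsigned NumBroken = 0;
      for (BasicBlock &BB : F) {
        TerminatorInst *TI = BB.getTerminator();
        // Edges leaving an IndirectBrInst cannot be split, as noted above.
        if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI))
          for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
            if (SplitCriticalEdge(TI, i, Options))
              ++NumBroken;
      }
      return NumBroken;
    }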
-Value *llvm::CastToCStr(Value *V, IRBuilder<> &B) { +#define DEBUG_TYPE "build-libcalls" + +//- Infer Attributes ---------------------------------------------------------// + +STATISTIC(NumReadNone, "Number of functions inferred as readnone"); +STATISTIC(NumReadOnly, "Number of functions inferred as readonly"); +STATISTIC(NumArgMemOnly, "Number of functions inferred as argmemonly"); +STATISTIC(NumNoUnwind, "Number of functions inferred as nounwind"); +STATISTIC(NumNoCapture, "Number of arguments inferred as nocapture"); +STATISTIC(NumReadOnlyArg, "Number of arguments inferred as readonly"); +STATISTIC(NumNoAlias, "Number of function returns inferred as noalias"); +STATISTIC(NumNonNull, "Number of function returns inferred as nonnull returns"); + +static bool setDoesNotAccessMemory(Function &F) { + if (F.doesNotAccessMemory()) + return false; + F.setDoesNotAccessMemory(); + ++NumReadNone; + return true; +} + +static bool setOnlyReadsMemory(Function &F) { + if (F.onlyReadsMemory()) + return false; + F.setOnlyReadsMemory(); + ++NumReadOnly; + return true; +} + +static bool setOnlyAccessesArgMemory(Function &F) { + if (F.onlyAccessesArgMemory()) + return false; + F.setOnlyAccessesArgMemory (); + ++NumArgMemOnly; + return true; +} + +static bool setDoesNotThrow(Function &F) { + if (F.doesNotThrow()) + return false; + F.setDoesNotThrow(); + ++NumNoUnwind; + return true; +} + +static bool setDoesNotCapture(Function &F, unsigned n) { + if (F.doesNotCapture(n)) + return false; + F.setDoesNotCapture(n); + ++NumNoCapture; + return true; +} + +static bool setOnlyReadsMemory(Function &F, unsigned n) { + if (F.onlyReadsMemory(n)) + return false; + F.setOnlyReadsMemory(n); + ++NumReadOnlyArg; + return true; +} + +static bool setDoesNotAlias(Function &F, unsigned n) { + if (F.doesNotAlias(n)) + return false; + F.setDoesNotAlias(n); + ++NumNoAlias; + return true; +} + +static bool setNonNull(Function &F, unsigned n) { + assert((n != AttributeSet::ReturnIndex || + F.getReturnType()->isPointerTy()) && + "nonnull applies only to pointers"); + if (F.getAttributes().hasAttribute(n, Attribute::NonNull)) + return false; + F.addAttribute(n, Attribute::NonNull); + ++NumNonNull; + return true; +} + +bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { + LibFunc::Func TheLibFunc; + if (!(TLI.getLibFunc(F, TheLibFunc) && TLI.has(TheLibFunc))) + return false; + + bool Changed = false; + switch (TheLibFunc) { + case LibFunc::strlen: + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::strchr: + case LibFunc::strrchr: + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotThrow(F); + return Changed; + case LibFunc::strtol: + case LibFunc::strtod: + case LibFunc::strtof: + case LibFunc::strtoul: + case LibFunc::strtoll: + case LibFunc::strtold: + case LibFunc::strtoull: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::strcpy: + case LibFunc::stpcpy: + case LibFunc::strcat: + case LibFunc::strncat: + case LibFunc::strncpy: + case LibFunc::stpncpy: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::strxfrm: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::strcmp: // 0,1 + case 
LibFunc::strspn: // 0,1 + case LibFunc::strncmp: // 0,1 + case LibFunc::strcspn: // 0,1 + case LibFunc::strcoll: // 0,1 + case LibFunc::strcasecmp: // 0,1 + case LibFunc::strncasecmp: // + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::strstr: + case LibFunc::strpbrk: + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::strtok: + case LibFunc::strtok_r: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::scanf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::setbuf: + case LibFunc::setvbuf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::strdup: + case LibFunc::strndup: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::stat: + case LibFunc::statvfs: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::sscanf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::sprintf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::snprintf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 3); + Changed |= setOnlyReadsMemory(F, 3); + return Changed; + case LibFunc::setitimer: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + Changed |= setDoesNotCapture(F, 3); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::system: + // May throw; "system" is a valid pthread cancellation point. 
+ Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::malloc: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + return Changed; + case LibFunc::memcmp: + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::memchr: + case LibFunc::memrchr: + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotThrow(F); + return Changed; + case LibFunc::modf: + case LibFunc::modff: + case LibFunc::modfl: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::memcpy: + case LibFunc::memccpy: + case LibFunc::memmove: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::memcpy_chk: + Changed |= setDoesNotThrow(F); + return Changed; + case LibFunc::memalign: + Changed |= setDoesNotAlias(F, 0); + return Changed; + case LibFunc::mkdir: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::mktime: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::realloc: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::read: + // May throw; "read" is a valid pthread cancellation point. + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::rewind: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::rmdir: + case LibFunc::remove: + case LibFunc::realpath: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::rename: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::readlink: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::write: + // May throw; "write" is a valid pthread cancellation point. 
+ Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::bcopy: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::bcmp: + Changed |= setDoesNotThrow(F); + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::bzero: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::calloc: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + return Changed; + case LibFunc::chmod: + case LibFunc::chown: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::ctermid: + case LibFunc::clearerr: + case LibFunc::closedir: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::atoi: + case LibFunc::atol: + case LibFunc::atof: + case LibFunc::atoll: + Changed |= setDoesNotThrow(F); + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::access: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::fopen: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::fdopen: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::feof: + case LibFunc::free: + case LibFunc::fseek: + case LibFunc::ftell: + case LibFunc::fgetc: + case LibFunc::fseeko: + case LibFunc::ftello: + case LibFunc::fileno: + case LibFunc::fflush: + case LibFunc::fclose: + case LibFunc::fsetpos: + case LibFunc::flockfile: + case LibFunc::funlockfile: + case LibFunc::ftrylockfile: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::ferror: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F); + return Changed; + case LibFunc::fputc: + case LibFunc::fstat: + case LibFunc::frexp: + case LibFunc::frexpf: + case LibFunc::frexpl: + case LibFunc::fstatvfs: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::fgets: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 3); + return Changed; + case LibFunc::fread: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 4); + return Changed; + case LibFunc::fwrite: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 4); + // FIXME: readonly #1? 
+ return Changed; + case LibFunc::fputs: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::fscanf: + case LibFunc::fprintf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::fgetpos: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::getc: + case LibFunc::getlogin_r: + case LibFunc::getc_unlocked: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::getenv: + Changed |= setDoesNotThrow(F); + Changed |= setOnlyReadsMemory(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::gets: + case LibFunc::getchar: + Changed |= setDoesNotThrow(F); + return Changed; + case LibFunc::getitimer: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::getpwnam: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::ungetc: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::uname: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::unlink: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::unsetenv: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::utime: + case LibFunc::utimes: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::putc: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::puts: + case LibFunc::printf: + case LibFunc::perror: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::pread: + // May throw; "pread" is a valid pthread cancellation point. + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::pwrite: + // May throw; "pwrite" is a valid pthread cancellation point. 
+ Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::putchar: + Changed |= setDoesNotThrow(F); + return Changed; + case LibFunc::popen: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::pclose: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::vscanf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::vsscanf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::vfscanf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::valloc: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + return Changed; + case LibFunc::vprintf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::vfprintf: + case LibFunc::vsprintf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::vsnprintf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 3); + Changed |= setOnlyReadsMemory(F, 3); + return Changed; + case LibFunc::open: + // May throw; "open" is a valid pthread cancellation point. + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::opendir: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::tmpfile: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + return Changed; + case LibFunc::times: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::htonl: + case LibFunc::htons: + case LibFunc::ntohl: + case LibFunc::ntohs: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAccessMemory(F); + return Changed; + case LibFunc::lstat: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::lchown: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::qsort: + // May throw; places call through function pointer. 
+ Changed |= setDoesNotCapture(F, 4); + return Changed; + case LibFunc::dunder_strdup: + case LibFunc::dunder_strndup: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::dunder_strtok_r: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::under_IO_getc: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::under_IO_putc: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::dunder_isoc99_scanf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::stat64: + case LibFunc::lstat64: + case LibFunc::statvfs64: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::dunder_isoc99_sscanf: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::fopen64: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 1); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + case LibFunc::fseeko64: + case LibFunc::ftello64: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + return Changed; + case LibFunc::tmpfile64: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotAlias(F, 0); + return Changed; + case LibFunc::fstat64: + case LibFunc::fstatvfs64: + Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::open64: + // May throw; "open" is a valid pthread cancellation point. + Changed |= setDoesNotCapture(F, 1); + Changed |= setOnlyReadsMemory(F, 1); + return Changed; + case LibFunc::gettimeofday: + // Currently some platforms have the restrict keyword on the arguments to + // gettimeofday. To be conservative, do not add noalias to gettimeofday's + // arguments. 
+ Changed |= setDoesNotThrow(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + return Changed; + case LibFunc::Znwj: // new(unsigned int) + case LibFunc::Znwm: // new(unsigned long) + case LibFunc::Znaj: // new[](unsigned int) + case LibFunc::Znam: // new[](unsigned long) + case LibFunc::msvc_new_int: // new(unsigned int) + case LibFunc::msvc_new_longlong: // new(unsigned long long) + case LibFunc::msvc_new_array_int: // new[](unsigned int) + case LibFunc::msvc_new_array_longlong: // new[](unsigned long long) + // Operator new always returns a nonnull noalias pointer + Changed |= setNonNull(F, AttributeSet::ReturnIndex); + Changed |= setDoesNotAlias(F, AttributeSet::ReturnIndex); + return Changed; + //TODO: add LibFunc entries for: + //case LibFunc::memset_pattern4: + //case LibFunc::memset_pattern8: + case LibFunc::memset_pattern16: + Changed |= setOnlyAccessesArgMemory(F); + Changed |= setDoesNotCapture(F, 1); + Changed |= setDoesNotCapture(F, 2); + Changed |= setOnlyReadsMemory(F, 2); + return Changed; + // int __nvvm_reflect(const char *) + case LibFunc::nvvm_reflect: + Changed |= setDoesNotAccessMemory(F); + Changed |= setDoesNotThrow(F); + return Changed; + + default: + // FIXME: It'd be really nice to cover all the library functions we're + // aware of here. + return false; + } +} + +//- Emit LibCalls ------------------------------------------------------------// + +Value *llvm::castToCStr(Value *V, IRBuilder<> &B) { unsigned AS = V->getType()->getPointerAddressSpace(); return B.CreateBitCast(V, B.getInt8PtrTy(AS), "cstr"); } -/// EmitStrLen - Emit a call to the strlen function to the builder, for the -/// specified pointer. This always returns an integer value of size intptr_t. -Value *llvm::EmitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL, +Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::strlen)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[2]; - AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); - Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); - + Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Constant *StrLen = M->getOrInsertFunction( - "strlen", AttributeSet::get(M->getContext(), AS), - DL.getIntPtrType(Context), B.getInt8PtrTy(), nullptr); - CallInst *CI = B.CreateCall(StrLen, CastToCStr(Ptr, B), "strlen"); + Constant *StrLen = M->getOrInsertFunction("strlen", DL.getIntPtrType(Context), + B.getInt8PtrTy(), nullptr); + inferLibFuncAttributes(*M->getFunction("strlen"), *TLI); + CallInst *CI = B.CreateCall(StrLen, castToCStr(Ptr, B), "strlen"); if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } -/// EmitStrChr - Emit a call to the strchr function to the builder, for the -/// specified pointer and character. Ptr is required to be some pointer type, -/// and the return value has 'i8*' type. 
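The new inferLibFuncAttributes() entry point above is meant to be called on library-function declarations; a hypothetical driver (not in this patch) that applies it across a whole module:

    #include "llvm/Analysis/TargetLibraryInfo.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Transforms/Utils/BuildLibCalls.h"
    using namespace llvm;

    // Each declaration recognized by TLI gets the nounwind/readonly/nocapture
    // style attributes from the switch above; unknown functions are skipped.
    static bool inferAllLibFuncAttrs(Module &M, const TargetLibraryInfo &TLI) {
      bool Changed = false;
      for (Function &F : M)
        if (F.isDeclaration())
          Changed |= inferLibFuncAttributes(F, TLI);
      return Changed;
    }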
-Value *llvm::EmitStrChr(Value *Ptr, char C, IRBuilder<> &B, +Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::strchr)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AttributeSet AS = - AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); - + Module *M = B.GetInsertBlock()->getModule(); Type *I8Ptr = B.getInt8PtrTy(); Type *I32Ty = B.getInt32Ty(); - Constant *StrChr = M->getOrInsertFunction("strchr", - AttributeSet::get(M->getContext(), - AS), - I8Ptr, I8Ptr, I32Ty, nullptr); + Constant *StrChr = + M->getOrInsertFunction("strchr", I8Ptr, I8Ptr, I32Ty, nullptr); + inferLibFuncAttributes(*M->getFunction("strchr"), *TLI); CallInst *CI = B.CreateCall( - StrChr, {CastToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, "strchr"); + StrChr, {castToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, "strchr"); if (const Function *F = dyn_cast<Function>(StrChr->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } -/// EmitStrNCmp - Emit a call to the strncmp function to the builder. -Value *llvm::EmitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, +Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::strncmp)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[3]; - AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); - AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); - Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); - + Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *StrNCmp = M->getOrInsertFunction( - "strncmp", AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), - B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context), nullptr); + Value *StrNCmp = M->getOrInsertFunction("strncmp", B.getInt32Ty(), + B.getInt8PtrTy(), B.getInt8PtrTy(), + DL.getIntPtrType(Context), nullptr); + inferLibFuncAttributes(*M->getFunction("strncmp"), *TLI); CallInst *CI = B.CreateCall( - StrNCmp, {CastToCStr(Ptr1, B), CastToCStr(Ptr2, B), Len}, "strncmp"); + StrNCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "strncmp"); if (const Function *F = dyn_cast<Function>(StrNCmp->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -107,64 +769,46 @@ Value *llvm::EmitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, return CI; } -/// EmitStrCpy - Emit a call to the strcpy function to the builder, for the -/// specified pointer arguments. 
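A short sketch of the renamed emitters in use, in the spirit of the strchr(S, '\0') -> S + strlen(S) fold done by the library call simplifier; the surrounding context is assumed, not part of this patch:

    #include "llvm/Analysis/TargetLibraryInfo.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/Transforms/Utils/BuildLibCalls.h"
    using namespace llvm;

    // Returns a pointer to S's terminating NUL, or null if strlen() is not
    // available according to TLI.
    static Value *emitEndOfString(Value *S, IRBuilder<> &B, const DataLayout &DL,
                                  const TargetLibraryInfo *TLI) {
      Value *Len = emitStrLen(S, B, DL, TLI); // intptr_t-sized result
      if (!Len)
        return nullptr;
      return B.CreateGEP(B.getInt8Ty(), castToCStr(S, B), Len, "endptr");
    }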
-Value *llvm::EmitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, +Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, const TargetLibraryInfo *TLI, StringRef Name) { if (!TLI->has(LibFunc::strcpy)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[2]; - AS[0] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - Attribute::NoUnwind); + Module *M = B.GetInsertBlock()->getModule(); Type *I8Ptr = B.getInt8PtrTy(); - Value *StrCpy = M->getOrInsertFunction(Name, - AttributeSet::get(M->getContext(), AS), - I8Ptr, I8Ptr, I8Ptr, nullptr); + Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, nullptr); + inferLibFuncAttributes(*M->getFunction(Name), *TLI); CallInst *CI = - B.CreateCall(StrCpy, {CastToCStr(Dst, B), CastToCStr(Src, B)}, Name); + B.CreateCall(StrCpy, {castToCStr(Dst, B), castToCStr(Src, B)}, Name); if (const Function *F = dyn_cast<Function>(StrCpy->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } -/// EmitStrNCpy - Emit a call to the strncpy function to the builder, for the -/// specified pointer arguments. -Value *llvm::EmitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B, +Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B, const TargetLibraryInfo *TLI, StringRef Name) { if (!TLI->has(LibFunc::strncpy)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[2]; - AS[0] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - Attribute::NoUnwind); + Module *M = B.GetInsertBlock()->getModule(); Type *I8Ptr = B.getInt8PtrTy(); - Value *StrNCpy = M->getOrInsertFunction(Name, - AttributeSet::get(M->getContext(), - AS), - I8Ptr, I8Ptr, I8Ptr, + Value *StrNCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, Len->getType(), nullptr); + inferLibFuncAttributes(*M->getFunction(Name), *TLI); CallInst *CI = B.CreateCall( - StrNCpy, {CastToCStr(Dst, B), CastToCStr(Src, B), Len}, "strncpy"); + StrNCpy, {castToCStr(Dst, B), castToCStr(Src, B), Len}, "strncpy"); if (const Function *F = dyn_cast<Function>(StrNCpy->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } -/// EmitMemCpyChk - Emit a call to the __memcpy_chk function to the builder. -/// This expects that the Len and ObjSize have type 'intptr_t' and Dst/Src -/// are pointers. 
-Value *llvm::EmitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, +Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::memcpy_chk)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); + Module *M = B.GetInsertBlock()->getModule(); AttributeSet AS; AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, Attribute::NoUnwind); @@ -173,30 +817,26 @@ Value *llvm::EmitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, "__memcpy_chk", AttributeSet::get(M->getContext(), AS), B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context), DL.getIntPtrType(Context), nullptr); - Dst = CastToCStr(Dst, B); - Src = CastToCStr(Src, B); + Dst = castToCStr(Dst, B); + Src = castToCStr(Src, B); CallInst *CI = B.CreateCall(MemCpy, {Dst, Src, Len, ObjSize}); if (const Function *F = dyn_cast<Function>(MemCpy->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } -/// EmitMemChr - Emit a call to the memchr function. This assumes that Ptr is -/// a pointer, Val is an i32 value, and Len is an 'intptr_t' value. -Value *llvm::EmitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B, +Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::memchr)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS; - Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); + Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *MemChr = M->getOrInsertFunction( - "memchr", AttributeSet::get(M->getContext(), AS), B.getInt8PtrTy(), - B.getInt8PtrTy(), B.getInt32Ty(), DL.getIntPtrType(Context), nullptr); - CallInst *CI = B.CreateCall(MemChr, {CastToCStr(Ptr, B), Val, Len}, "memchr"); + Value *MemChr = M->getOrInsertFunction("memchr", B.getInt8PtrTy(), + B.getInt8PtrTy(), B.getInt32Ty(), + DL.getIntPtrType(Context), nullptr); + inferLibFuncAttributes(*M->getFunction("memchr"), *TLI); + CallInst *CI = B.CreateCall(MemChr, {castToCStr(Ptr, B), Val, Len}, "memchr"); if (const Function *F = dyn_cast<Function>(MemChr->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -204,25 +844,19 @@ Value *llvm::EmitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B, return CI; } -/// EmitMemCmp - Emit a call to the memcmp function. 
-Value *llvm::EmitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, +Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::memcmp)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[3]; - AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); - AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); - Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); - + Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *MemCmp = M->getOrInsertFunction( - "memcmp", AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), - B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context), nullptr); + Value *MemCmp = M->getOrInsertFunction("memcmp", B.getInt32Ty(), + B.getInt8PtrTy(), B.getInt8PtrTy(), + DL.getIntPtrType(Context), nullptr); + inferLibFuncAttributes(*M->getFunction("memcmp"), *TLI); CallInst *CI = B.CreateCall( - MemCmp, {CastToCStr(Ptr1, B), CastToCStr(Ptr2, B), Len}, "memcmp"); + MemCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "memcmp"); if (const Function *F = dyn_cast<Function>(MemCmp->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -231,7 +865,8 @@ Value *llvm::EmitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, } /// Append a suffix to the function name according to the type of 'Op'. -static void AppendTypeSuffix(Value *Op, StringRef &Name, SmallString<20> &NameBuffer) { +static void appendTypeSuffix(Value *Op, StringRef &Name, + SmallString<20> &NameBuffer) { if (!Op->getType()->isDoubleTy()) { NameBuffer += Name; @@ -242,19 +877,14 @@ static void AppendTypeSuffix(Value *Op, StringRef &Name, SmallString<20> &NameBu Name = NameBuffer; } - return; } -/// EmitUnaryFloatFnCall - Emit a call to the unary function named 'Name' (e.g. -/// 'floor'). This function is known to take a single of type matching 'Op' and -/// returns one value with the same type. If 'Op' is a long double, 'l' is -/// added as the suffix of name, if 'Op' is a float, we add a 'f' suffix. -Value *llvm::EmitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, +Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, const AttributeSet &Attrs) { SmallString<20> NameBuffer; - AppendTypeSuffix(Op, Name, NameBuffer); + appendTypeSuffix(Op, Name, NameBuffer); - Module *M = B.GetInsertBlock()->getParent()->getParent(); + Module *M = B.GetInsertBlock()->getModule(); Value *Callee = M->getOrInsertFunction(Name, Op->getType(), Op->getType(), nullptr); CallInst *CI = B.CreateCall(Callee, Op, Name); @@ -265,19 +895,14 @@ Value *llvm::EmitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, return CI; } -/// EmitBinaryFloatFnCall - Emit a call to the binary function named 'Name' -/// (e.g. 'fmin'). This function is known to take type matching 'Op1' and 'Op2' -/// and return one value with the same type. If 'Op1/Op2' are long double, 'l' -/// is added as the suffix of name, if 'Op1/Op2' is a float, we add a 'f' -/// suffix. 
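The appendTypeSuffix() helper above drives the naming of the float libcall emitters: a hedged example (not from the patch) where a float operand produces a call to "floorf", a double keeps "floor", and a long double yields "floorl".

    #include "llvm/IR/Attributes.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/Transforms/Utils/BuildLibCalls.h"
    using namespace llvm;

    // Emit floor()/floorf()/floorl() matching Op's type, with no extra
    // attributes on the call.
    static Value *emitFloor(Value *Op, IRBuilder<> &B) {
      return emitUnaryFloatFnCall(Op, "floor", B, AttributeSet());
    }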
-Value *llvm::EmitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, +Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, IRBuilder<> &B, const AttributeSet &Attrs) { SmallString<20> NameBuffer; - AppendTypeSuffix(Op1, Name, NameBuffer); + appendTypeSuffix(Op1, Name, NameBuffer); - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), - Op1->getType(), Op2->getType(), nullptr); + Module *M = B.GetInsertBlock()->getModule(); + Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), Op1->getType(), + Op2->getType(), nullptr); CallInst *CI = B.CreateCall(Callee, {Op1, Op2}, Name); CI->setAttributes(Attrs); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -286,14 +911,12 @@ Value *llvm::EmitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, return CI; } -/// EmitPutChar - Emit a call to the putchar function. This assumes that Char -/// is an integer. -Value *llvm::EmitPutChar(Value *Char, IRBuilder<> &B, +Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::putchar)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); + Module *M = B.GetInsertBlock()->getModule(); Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), B.getInt32Ty(), nullptr); CallInst *CI = B.CreateCall(PutChar, @@ -308,54 +931,31 @@ Value *llvm::EmitPutChar(Value *Char, IRBuilder<> &B, return CI; } -/// EmitPutS - Emit a call to the puts function. This assumes that Str is -/// some pointer. -Value *llvm::EmitPutS(Value *Str, IRBuilder<> &B, +Value *llvm::emitPutS(Value *Str, IRBuilder<> &B, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::puts)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[2]; - AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - Attribute::NoUnwind); - - Value *PutS = M->getOrInsertFunction("puts", - AttributeSet::get(M->getContext(), AS), - B.getInt32Ty(), - B.getInt8PtrTy(), - nullptr); - CallInst *CI = B.CreateCall(PutS, CastToCStr(Str, B), "puts"); + Module *M = B.GetInsertBlock()->getModule(); + Value *PutS = + M->getOrInsertFunction("puts", B.getInt32Ty(), B.getInt8PtrTy(), nullptr); + inferLibFuncAttributes(*M->getFunction("puts"), *TLI); + CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), "puts"); if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } -/// EmitFPutC - Emit a call to the fputc function. This assumes that Char is -/// an integer and File is a pointer to FILE. 
-Value *llvm::EmitFPutC(Value *Char, Value *File, IRBuilder<> &B, +Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::fputc)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[2]; - AS[0] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - Attribute::NoUnwind); - Constant *F; + Module *M = B.GetInsertBlock()->getModule(); + Constant *F = M->getOrInsertFunction("fputc", B.getInt32Ty(), B.getInt32Ty(), + File->getType(), nullptr); if (File->getType()->isPointerTy()) - F = M->getOrInsertFunction("fputc", - AttributeSet::get(M->getContext(), AS), - B.getInt32Ty(), - B.getInt32Ty(), File->getType(), - nullptr); - else - F = M->getOrInsertFunction("fputc", - B.getInt32Ty(), - B.getInt32Ty(), - File->getType(), nullptr); + inferLibFuncAttributes(*M->getFunction("fputc"), *TLI); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, "chari"); CallInst *CI = B.CreateCall(F, {Char, File}, "fputc"); @@ -365,66 +965,40 @@ Value *llvm::EmitFPutC(Value *Char, Value *File, IRBuilder<> &B, return CI; } -/// EmitFPutS - Emit a call to the puts function. Str is required to be a -/// pointer and File is a pointer to FILE. -Value *llvm::EmitFPutS(Value *Str, Value *File, IRBuilder<> &B, +Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::fputs)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[3]; - AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); - AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - Attribute::NoUnwind); + Module *M = B.GetInsertBlock()->getModule(); StringRef FPutsName = TLI->getName(LibFunc::fputs); - Constant *F; + Constant *F = M->getOrInsertFunction( + FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType(), nullptr); if (File->getType()->isPointerTy()) - F = M->getOrInsertFunction(FPutsName, - AttributeSet::get(M->getContext(), AS), - B.getInt32Ty(), - B.getInt8PtrTy(), - File->getType(), nullptr); - else - F = M->getOrInsertFunction(FPutsName, B.getInt32Ty(), - B.getInt8PtrTy(), - File->getType(), nullptr); - CallInst *CI = B.CreateCall(F, {CastToCStr(Str, B), File}, "fputs"); + inferLibFuncAttributes(*M->getFunction(FPutsName), *TLI); + CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, "fputs"); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } -/// EmitFWrite - Emit a call to the fwrite function. This assumes that Ptr is -/// a pointer, Size is an 'intptr_t', and File is a pointer to FILE. 
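As a usage note for the fputc/fputs emitters above, here is an invented fold in the style of the library call simplifier: a one-character fputs(S, F) can be emitted as fputc(C, F). The constant C and the caller are assumptions, not part of this patch.

    #include "llvm/Analysis/TargetLibraryInfo.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/Transforms/Utils/BuildLibCalls.h"
    using namespace llvm;

    // Returns the fputc call, or null if fputc() is unavailable per TLI.
    static Value *foldOneCharFPuts(char C, Value *File, IRBuilder<> &B,
                                   const TargetLibraryInfo *TLI) {
      return emitFPutC(B.getInt32(C), File, B, TLI);
    }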
-Value *llvm::EmitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B, +Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc::fwrite)) return nullptr; - Module *M = B.GetInsertBlock()->getParent()->getParent(); - AttributeSet AS[3]; - AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); - AS[1] = AttributeSet::get(M->getContext(), 4, Attribute::NoCapture); - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - Attribute::NoUnwind); + Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); StringRef FWriteName = TLI->getName(LibFunc::fwrite); - Constant *F; + Constant *F = M->getOrInsertFunction( + FWriteName, DL.getIntPtrType(Context), B.getInt8PtrTy(), + DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType(), + nullptr); if (File->getType()->isPointerTy()) - F = M->getOrInsertFunction( - FWriteName, AttributeSet::get(M->getContext(), AS), - DL.getIntPtrType(Context), B.getInt8PtrTy(), DL.getIntPtrType(Context), - DL.getIntPtrType(Context), File->getType(), nullptr); - else - F = M->getOrInsertFunction(FWriteName, DL.getIntPtrType(Context), - B.getInt8PtrTy(), DL.getIntPtrType(Context), - DL.getIntPtrType(Context), File->getType(), - nullptr); + inferLibFuncAttributes(*M->getFunction(FWriteName), *TLI); CallInst *CI = - B.CreateCall(F, {CastToCStr(Ptr, B), Size, + B.CreateCall(F, {castToCStr(Ptr, B), Size, ConstantInt::get(DL.getIntPtrType(Context), 1), File}); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index 8308a9b69149..5aec0dce34db 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -11,7 +11,9 @@ add_llvm_library(LLVMTransformUtils CodeExtractor.cpp CtorUtils.cpp DemoteRegToStack.cpp + Evaluator.cpp FlattenCFG.cpp + FunctionImportUtils.cpp GlobalStatus.cpp InlineFunction.cpp InstructionNamer.cpp @@ -26,10 +28,13 @@ add_llvm_library(LLVMTransformUtils LowerInvoke.cpp LowerSwitch.cpp Mem2Reg.cpp + MemorySSA.cpp MetaRenamer.cpp ModuleUtils.cpp + NameAnonFunctions.cpp PromoteMemoryToRegister.cpp SSAUpdater.cpp + SanitizerStats.cpp SimplifyCFG.cpp SimplifyIndVar.cpp SimplifyInstructions.cpp diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 6454afb8bc42..c5ca56360fc8 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -119,6 +119,15 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, .addAttributes(NewFunc->getContext(), AttributeSet::FunctionIndex, OldAttrs.getFnAttributes())); + SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; + OldFunc->getAllMetadata(MDs); + for (auto MD : MDs) + NewFunc->addMetadata( + MD.first, + *MapMetadata(MD.second, VMap, + ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges, + TypeMapper, Materializer)); + // Loop over all of the basic blocks in the function, cloning them as // appropriate. Note that we save BE this way in order to handle cloning of // recursive functions into themselves. @@ -163,65 +172,14 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, TypeMapper, Materializer); } -// Find the MDNode which corresponds to the subprogram data that described F. 
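The metadata-copying block added to CloneFunctionInto() above replaces the DebugInfoFinder-based subprogram cloning removed in this hunk, and CloneFunction() itself loses its ModuleLevelChanges parameter. A hedged sketch of the new caller-visible shape, with assumptions noted in comments:

    #include "llvm/IR/Function.h"
    #include "llvm/Transforms/Utils/Cloning.h"
    #include "llvm/Transforms/Utils/ValueMapper.h"
    using namespace llvm;

    static Function *cloneIntoSameModule(Function *F) {
      ValueToValueMapTy VMap;
      // New signature: the clone is created directly in F's module; function-
      // attached metadata (e.g. the !dbg subprogram) is carried over by the
      // getAllMetadata()/addMetadata() loop above rather than by a
      // DebugInfoFinder walk.
      return CloneFunction(F, VMap);
    }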
-static DISubprogram *FindSubprogram(const Function *F, - DebugInfoFinder &Finder) { - for (DISubprogram *Subprogram : Finder.subprograms()) { - if (Subprogram->describes(F)) - return Subprogram; - } - return nullptr; -} - -// Add an operand to an existing MDNode. The new operand will be added at the -// back of the operand list. -static void AddOperand(DICompileUnit *CU, DISubprogramArray SPs, - Metadata *NewSP) { - SmallVector<Metadata *, 16> NewSPs; - NewSPs.reserve(SPs.size() + 1); - for (auto *SP : SPs) - NewSPs.push_back(SP); - NewSPs.push_back(NewSP); - CU->replaceSubprograms(MDTuple::get(CU->getContext(), NewSPs)); -} - -// Clone the module-level debug info associated with OldFunc. The cloned data -// will point to NewFunc instead. -static void CloneDebugInfoMetadata(Function *NewFunc, const Function *OldFunc, - ValueToValueMapTy &VMap) { - DebugInfoFinder Finder; - Finder.processModule(*OldFunc->getParent()); - - const DISubprogram *OldSubprogramMDNode = FindSubprogram(OldFunc, Finder); - if (!OldSubprogramMDNode) return; - - auto *NewSubprogram = - cast<DISubprogram>(MapMetadata(OldSubprogramMDNode, VMap)); - NewFunc->setSubprogram(NewSubprogram); - - for (auto *CU : Finder.compile_units()) { - auto Subprograms = CU->getSubprograms(); - // If the compile unit's function list contains the old function, it should - // also contain the new one. - for (auto *SP : Subprograms) { - if (SP == OldSubprogramMDNode) { - AddOperand(CU, Subprograms, NewSubprogram); - break; - } - } - } -} - -/// Return a copy of the specified function, but without -/// embedding the function into another module. Also, any references specified -/// in the VMap are changed to refer to their mapped value instead of the -/// original one. If any of the arguments to the function are in the VMap, -/// the arguments are deleted from the resultant function. The VMap is -/// updated to include mappings from all of the instructions and basicblocks in -/// the function from their old to new values. +/// Return a copy of the specified function and add it to that function's +/// module. Also, any references specified in the VMap are changed to refer to +/// their mapped value instead of the original one. If any of the arguments to +/// the function are in the VMap, the arguments are deleted from the resultant +/// function. The VMap is updated to include mappings from all of the +/// instructions and basicblocks in the function from their old to new values. /// -Function *llvm::CloneFunction(const Function *F, ValueToValueMapTy &VMap, - bool ModuleLevelChanges, +Function *llvm::CloneFunction(Function *F, ValueToValueMapTy &VMap, ClonedCodeInfo *CodeInfo) { std::vector<Type*> ArgTypes; @@ -237,7 +195,8 @@ Function *llvm::CloneFunction(const Function *F, ValueToValueMapTy &VMap, ArgTypes, F->getFunctionType()->isVarArg()); // Create the new function... - Function *NewF = Function::Create(FTy, F->getLinkage(), F->getName()); + Function *NewF = + Function::Create(FTy, F->getLinkage(), F->getName(), F->getParent()); // Loop over the arguments, copying the names of the mapped arguments over... Function::arg_iterator DestI = NewF->arg_begin(); @@ -247,11 +206,10 @@ Function *llvm::CloneFunction(const Function *F, ValueToValueMapTy &VMap, VMap[&I] = &*DestI++; // Add mapping to VMap } - if (ModuleLevelChanges) - CloneDebugInfoMetadata(NewF, F, VMap); - SmallVector<ReturnInst*, 8> Returns; // Ignore returns cloned. 
- CloneFunctionInto(NewF, F, VMap, ModuleLevelChanges, Returns, "", CodeInfo); + CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns, "", + CodeInfo); + return NewF; } @@ -338,9 +296,11 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, if (Value *MappedV = VMap.lookup(V)) V = MappedV; - VMap[&*II] = V; - delete NewInst; - continue; + if (!NewInst->mayHaveSideEffects()) { + VMap[&*II] = V; + delete NewInst; + continue; + } } } @@ -372,7 +332,7 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, ConstantInt *Cond = dyn_cast<ConstantInt>(BI->getCondition()); // Or is a known constant in the caller... if (!Cond) { - Value *V = VMap[BI->getCondition()]; + Value *V = VMap.lookup(BI->getCondition()); Cond = dyn_cast_or_null<ConstantInt>(V); } @@ -388,7 +348,7 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, // If switching on a value known constant in the caller. ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition()); if (!Cond) { // Or known constant after constant prop in the callee... - Value *V = VMap[SI->getCondition()]; + Value *V = VMap.lookup(SI->getCondition()); Cond = dyn_cast_or_null<ConstantInt>(V); } if (Cond) { // Constant fold to uncond branch! @@ -475,7 +435,7 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, // Defer PHI resolution until rest of function is resolved. SmallVector<const PHINode*, 16> PHIToResolve; for (const BasicBlock &BI : *OldFunc) { - Value *V = VMap[&BI]; + Value *V = VMap.lookup(&BI); BasicBlock *NewBB = cast_or_null<BasicBlock>(V); if (!NewBB) continue; // Dead block. @@ -519,7 +479,7 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, OPN = PHIToResolve[phino]; PHINode *PN = cast<PHINode>(VMap[OPN]); for (unsigned pred = 0, e = NumPreds; pred != e; ++pred) { - Value *V = VMap[PN->getIncomingBlock(pred)]; + Value *V = VMap.lookup(PN->getIncomingBlock(pred)); if (BasicBlock *MappedBlock = cast_or_null<BasicBlock>(V)) { Value *InVal = MapValue(PN->getIncomingValue(pred), VMap, @@ -529,7 +489,8 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, PN->setIncomingBlock(pred, MappedBlock); } else { PN->removeIncomingValue(pred, false); - --pred, --e; // Revisit the next entry. + --pred; // Revisit the next entry. + --e; } } } @@ -558,10 +519,9 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, // entries. BasicBlock::iterator I = NewBB->begin(); for (; (PN = dyn_cast<PHINode>(I)); ++I) { - for (std::map<BasicBlock*, unsigned>::iterator PCI =PredCount.begin(), - E = PredCount.end(); PCI != E; ++PCI) { - BasicBlock *Pred = PCI->first; - for (unsigned NumToRemove = PCI->second; NumToRemove; --NumToRemove) + for (const auto &PCI : PredCount) { + BasicBlock *Pred = PCI.first; + for (unsigned NumToRemove = PCI.second; NumToRemove; --NumToRemove) PN->removeIncomingValue(Pred, false); } } @@ -684,7 +644,7 @@ void llvm::remapInstructionsInBlocks( for (auto *BB : Blocks) for (auto &Inst : *BB) RemapInstruction(&Inst, VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); } /// \brief Clones a loop \p OrigLoop. 
Returns the loop and the blocks in \p @@ -697,6 +657,8 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, const Twine &NameSuffix, LoopInfo *LI, DominatorTree *DT, SmallVectorImpl<BasicBlock *> &Blocks) { + assert(OrigLoop->getSubLoops().empty() && + "Loop to be cloned cannot have inner loop"); Function *F = OrigLoop->getHeader()->getParent(); Loop *ParentLoop = OrigLoop->getParentLoop(); @@ -727,13 +689,19 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, // Update LoopInfo. NewLoop->addBasicBlockToLoop(NewBB, *LI); - // Update DominatorTree. - BasicBlock *IDomBB = DT->getNode(BB)->getIDom()->getBlock(); - DT->addNewBlock(NewBB, cast<BasicBlock>(VMap[IDomBB])); + // Add DominatorTree node. After seeing all blocks, update to correct IDom. + DT->addNewBlock(NewBB, NewPH); Blocks.push_back(NewBB); } + for (BasicBlock *BB : OrigLoop->getBlocks()) { + // Update DominatorTree. + BasicBlock *IDomBB = DT->getNode(BB)->getIDom()->getBlock(); + DT->changeImmediateDominator(cast<BasicBlock>(VMap[BB]), + cast<BasicBlock>(VMap[IDomBB])); + } + // Move them physically from the end of the block list. F->getBasicBlockList().splice(Before->getIterator(), F->getBasicBlockList(), NewPH); diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index ab083353ece6..17e34c4ffa0f 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -38,7 +38,7 @@ std::unique_ptr<Module> llvm::CloneModule(const Module *M, std::unique_ptr<Module> llvm::CloneModule( const Module *M, ValueToValueMapTy &VMap, - std::function<bool(const GlobalValue *)> ShouldCloneDefinition) { + function_ref<bool(const GlobalValue *)> ShouldCloneDefinition) { // First off, we need to create the new module. std::unique_ptr<Module> New = llvm::make_unique<Module>(M->getModuleIdentifier(), M->getContext()); @@ -53,7 +53,7 @@ std::unique_ptr<Module> llvm::CloneModule( for (Module::const_global_iterator I = M->global_begin(), E = M->global_end(); I != E; ++I) { GlobalVariable *GV = new GlobalVariable(*New, - I->getType()->getElementType(), + I->getValueType(), I->isConstant(), I->getLinkage(), (Constant*) nullptr, I->getName(), (GlobalVariable*) nullptr, @@ -64,12 +64,11 @@ std::unique_ptr<Module> llvm::CloneModule( } // Loop over the functions in the module, making external functions as before - for (Module::const_iterator I = M->begin(), E = M->end(); I != E; ++I) { - Function *NF = - Function::Create(cast<FunctionType>(I->getType()->getElementType()), - I->getLinkage(), I->getName(), New.get()); - NF->copyAttributesFrom(&*I); - VMap[&*I] = NF; + for (const Function &I : *M) { + Function *NF = Function::Create(cast<FunctionType>(I.getValueType()), + I.getLinkage(), I.getName(), New.get()); + NF->copyAttributesFrom(&I); + VMap[&I] = NF; } // Loop over the aliases in the module @@ -109,6 +108,9 @@ std::unique_ptr<Module> llvm::CloneModule( // for (Module::const_global_iterator I = M->global_begin(), E = M->global_end(); I != E; ++I) { + if (I->isDeclaration()) + continue; + GlobalVariable *GV = cast<GlobalVariable>(VMap[&*I]); if (!ShouldCloneDefinition(&*I)) { // Skip after setting the correct linkage for an external reference. @@ -121,27 +123,31 @@ std::unique_ptr<Module> llvm::CloneModule( // Similarly, copy over function bodies now... 
// - for (Module::const_iterator I = M->begin(), E = M->end(); I != E; ++I) { - Function *F = cast<Function>(VMap[&*I]); - if (!ShouldCloneDefinition(&*I)) { + for (const Function &I : *M) { + if (I.isDeclaration()) + continue; + + Function *F = cast<Function>(VMap[&I]); + if (!ShouldCloneDefinition(&I)) { // Skip after setting the correct linkage for an external reference. F->setLinkage(GlobalValue::ExternalLinkage); + // Personality function is not valid on a declaration. + F->setPersonalityFn(nullptr); continue; } - if (!I->isDeclaration()) { - Function::arg_iterator DestI = F->arg_begin(); - for (Function::const_arg_iterator J = I->arg_begin(); J != I->arg_end(); - ++J) { - DestI->setName(J->getName()); - VMap[&*J] = &*DestI++; - } - - SmallVector<ReturnInst*, 8> Returns; // Ignore returns cloned. - CloneFunctionInto(F, &*I, VMap, /*ModuleLevelChanges=*/true, Returns); + + Function::arg_iterator DestI = F->arg_begin(); + for (Function::const_arg_iterator J = I.arg_begin(); J != I.arg_end(); + ++J) { + DestI->setName(J->getName()); + VMap[&*J] = &*DestI++; } - if (I->hasPersonalityFn()) - F->setPersonalityFn(MapValue(I->getPersonalityFn(), VMap)); + SmallVector<ReturnInst *, 8> Returns; // Ignore returns cloned. + CloneFunctionInto(F, &I, VMap, /*ModuleLevelChanges=*/true, Returns); + + if (I.hasPersonalityFn()) + F->setPersonalityFn(MapValue(I.getPersonalityFn(), VMap)); } // And aliases diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index 823696d88e65..9f2181f87cee 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -77,15 +77,15 @@ static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin, // Loop over the blocks, adding them to our set-vector, and aborting with an // empty set if we encounter invalid blocks. - for (IteratorT I = BBBegin, E = BBEnd; I != E; ++I) { - if (!Result.insert(*I)) + do { + if (!Result.insert(*BBBegin)) llvm_unreachable("Repeated basic blocks in extraction input"); - if (!isBlockValidForExtraction(**I)) { + if (!isBlockValidForExtraction(**BBBegin)) { Result.clear(); return Result; } - } + } while (++BBBegin != BBEnd); #ifndef NDEBUG for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()), @@ -159,23 +159,18 @@ static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) { void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs) const { - for (SetVector<BasicBlock *>::const_iterator I = Blocks.begin(), - E = Blocks.end(); - I != E; ++I) { - BasicBlock *BB = *I; - + for (BasicBlock *BB : Blocks) { // If a used value is defined outside the region, it's an input. If an // instruction is used outside the region, it's an output. 
- for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); - II != IE; ++II) { - for (User::op_iterator OI = II->op_begin(), OE = II->op_end(); - OI != OE; ++OI) + for (Instruction &II : *BB) { + for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE; + ++OI) if (definedInCaller(Blocks, *OI)) Inputs.insert(*OI); - for (User *U : II->users()) + for (User *U : II.users()) if (!definedInRegion(Blocks, U)) { - Outputs.insert(&*II); + Outputs.insert(&II); break; } } @@ -263,25 +258,21 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { } void CodeExtractor::splitReturnBlocks() { - for (SetVector<BasicBlock *>::iterator I = Blocks.begin(), E = Blocks.end(); - I != E; ++I) - if (ReturnInst *RI = dyn_cast<ReturnInst>((*I)->getTerminator())) { + for (BasicBlock *Block : Blocks) + if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) { BasicBlock *New = - (*I)->splitBasicBlock(RI->getIterator(), (*I)->getName() + ".ret"); + Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret"); if (DT) { // Old dominates New. New node dominates all other nodes dominated // by Old. - DomTreeNode *OldNode = DT->getNode(*I); - SmallVector<DomTreeNode*, 8> Children; - for (DomTreeNode::iterator DI = OldNode->begin(), DE = OldNode->end(); - DI != DE; ++DI) - Children.push_back(*DI); + DomTreeNode *OldNode = DT->getNode(Block); + SmallVector<DomTreeNode *, 8> Children(OldNode->begin(), + OldNode->end()); - DomTreeNode *NewNode = DT->addNewBlock(New, *I); + DomTreeNode *NewNode = DT->addNewBlock(New, Block); - for (SmallVectorImpl<DomTreeNode *>::iterator I = Children.begin(), - E = Children.end(); I != E; ++I) - DT->changeImmediateDominator(*I, NewNode); + for (DomTreeNode *I : Children) + DT->changeImmediateDominator(I, NewNode); } } } @@ -310,28 +301,26 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, std::vector<Type*> paramTy; // Add the types of the input values to the function's argument list - for (ValueSet::const_iterator i = inputs.begin(), e = inputs.end(); - i != e; ++i) { - const Value *value = *i; + for (Value *value : inputs) { DEBUG(dbgs() << "value used in func: " << *value << "\n"); paramTy.push_back(value->getType()); } // Add the types of the output values to the function's argument list. 
- for (ValueSet::const_iterator I = outputs.begin(), E = outputs.end(); - I != E; ++I) { - DEBUG(dbgs() << "instr used in func: " << **I << "\n"); + for (Value *output : outputs) { + DEBUG(dbgs() << "instr used in func: " << *output << "\n"); if (AggregateArgs) - paramTy.push_back((*I)->getType()); + paramTy.push_back(output->getType()); else - paramTy.push_back(PointerType::getUnqual((*I)->getType())); + paramTy.push_back(PointerType::getUnqual(output->getType())); } - DEBUG(dbgs() << "Function type: " << *RetTy << " f("); - for (std::vector<Type*>::iterator i = paramTy.begin(), - e = paramTy.end(); i != e; ++i) - DEBUG(dbgs() << **i << ", "); - DEBUG(dbgs() << ")\n"); + DEBUG({ + dbgs() << "Function type: " << *RetTy << " f("; + for (Type *i : paramTy) + dbgs() << *i << ", "; + dbgs() << ")\n"; + }); StructType *StructTy; if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { @@ -372,9 +361,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, RewriteVal = &*AI++; std::vector<User*> Users(inputs[i]->user_begin(), inputs[i]->user_end()); - for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end(); - use != useE; ++use) - if (Instruction* inst = dyn_cast<Instruction>(*use)) + for (User *use : Users) + if (Instruction *inst = dyn_cast<Instruction>(use)) if (Blocks.count(inst->getParent())) inst->replaceUsesOfWith(inputs[i], RewriteVal); } @@ -429,19 +417,19 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, LLVMContext &Context = newFunction->getContext(); // Add inputs as params, or to be filled into the struct - for (ValueSet::iterator i = inputs.begin(), e = inputs.end(); i != e; ++i) + for (Value *input : inputs) if (AggregateArgs) - StructValues.push_back(*i); + StructValues.push_back(input); else - params.push_back(*i); + params.push_back(input); // Create allocas for the outputs - for (ValueSet::iterator i = outputs.begin(), e = outputs.end(); i != e; ++i) { + for (Value *output : outputs) { if (AggregateArgs) { - StructValues.push_back(*i); + StructValues.push_back(output); } else { AllocaInst *alloca = - new AllocaInst((*i)->getType(), nullptr, (*i)->getName() + ".loc", + new AllocaInst(output->getType(), nullptr, output->getName() + ".loc", &codeReplacer->getParent()->front().front()); ReloadOutputs.push_back(alloca); params.push_back(alloca); @@ -522,9 +510,8 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, std::map<BasicBlock*, BasicBlock*> ExitBlockMap; unsigned switchVal = 0; - for (SetVector<BasicBlock*>::const_iterator i = Blocks.begin(), - e = Blocks.end(); i != e; ++i) { - TerminatorInst *TI = (*i)->getTerminator(); + for (BasicBlock *Block : Blocks) { + TerminatorInst *TI = Block->getTerminator(); for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) if (!Blocks.count(TI->getSuccessor(i))) { BasicBlock *OldTarget = TI->getSuccessor(i); @@ -576,10 +563,9 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, // Make sure we are looking at the original successor block, not // at a newly inserted exit block, which won't be in the dominator // info. 
- for (std::map<BasicBlock*, BasicBlock*>::iterator I = - ExitBlockMap.begin(), E = ExitBlockMap.end(); I != E; ++I) - if (DefBlock == I->second) { - DefBlock = I->first; + for (const auto &I : ExitBlockMap) + if (DefBlock == I.second) { + DefBlock = I.first; break; } @@ -677,13 +663,12 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) { Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); - for (SetVector<BasicBlock*>::const_iterator i = Blocks.begin(), - e = Blocks.end(); i != e; ++i) { + for (BasicBlock *Block : Blocks) { // Delete the basic block from the old function, and the list of blocks - oldBlocks.remove(*i); + oldBlocks.remove(Block); // Insert this basic block into the new function - newBlocks.push_back(*i); + newBlocks.push_back(Block); } } @@ -721,9 +706,9 @@ Function *CodeExtractor::extractCodeRegion() { findInputsOutputs(inputs, outputs); SmallPtrSet<BasicBlock *, 1> ExitBlocks; - for (SetVector<BasicBlock *>::iterator I = Blocks.begin(), E = Blocks.end(); - I != E; ++I) - for (succ_iterator SI = succ_begin(*I), SE = succ_end(*I); SI != SE; ++SI) + for (BasicBlock *Block : Blocks) + for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; + ++SI) if (!Blocks.count(*SI)) ExitBlocks.insert(*SI); NumExitBlocks = ExitBlocks.size(); diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp new file mode 100644 index 000000000000..cd130abf4519 --- /dev/null +++ b/lib/Transforms/Utils/Evaluator.cpp @@ -0,0 +1,596 @@ +//===- Evaluator.cpp - LLVM IR evaluator ----------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Function evaluator for LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/Evaluator.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "evaluator" + +using namespace llvm; + +static inline bool +isSimpleEnoughValueToCommit(Constant *C, + SmallPtrSetImpl<Constant *> &SimpleConstants, + const DataLayout &DL); + +/// Return true if the specified constant can be handled by the code generator. +/// We don't want to generate something like: +/// void *X = &X/42; +/// because the code generator doesn't have a relocation that can handle that. +/// +/// This function should be called if C was not found (but just got inserted) +/// in SimpleConstants to avoid having to rescan the same constants all the +/// time. +static bool +isSimpleEnoughValueToCommitHelper(Constant *C, + SmallPtrSetImpl<Constant *> &SimpleConstants, + const DataLayout &DL) { + // Simple global addresses are supported, do not allow dllimport or + // thread-local globals. + if (auto *GV = dyn_cast<GlobalValue>(C)) + return !GV->hasDLLImportStorageClass() && !GV->isThreadLocal(); + + // Simple integer, undef, constant aggregate zero, etc are all supported. 
+ if (C->getNumOperands() == 0 || isa<BlockAddress>(C)) + return true; + + // Aggregate values are safe if all their elements are. + if (isa<ConstantAggregate>(C)) { + for (Value *Op : C->operands()) + if (!isSimpleEnoughValueToCommit(cast<Constant>(Op), SimpleConstants, DL)) + return false; + return true; + } + + // We don't know exactly what relocations are allowed in constant expressions, + // so we allow &global+constantoffset, which is safe and uniformly supported + // across targets. + ConstantExpr *CE = cast<ConstantExpr>(C); + switch (CE->getOpcode()) { + case Instruction::BitCast: + // Bitcast is fine if the casted value is fine. + return isSimpleEnoughValueToCommit(CE->getOperand(0), SimpleConstants, DL); + + case Instruction::IntToPtr: + case Instruction::PtrToInt: + // int <=> ptr is fine if the int type is the same size as the + // pointer type. + if (DL.getTypeSizeInBits(CE->getType()) != + DL.getTypeSizeInBits(CE->getOperand(0)->getType())) + return false; + return isSimpleEnoughValueToCommit(CE->getOperand(0), SimpleConstants, DL); + + // GEP is fine if it is simple + constant offset. + case Instruction::GetElementPtr: + for (unsigned i = 1, e = CE->getNumOperands(); i != e; ++i) + if (!isa<ConstantInt>(CE->getOperand(i))) + return false; + return isSimpleEnoughValueToCommit(CE->getOperand(0), SimpleConstants, DL); + + case Instruction::Add: + // We allow simple+cst. + if (!isa<ConstantInt>(CE->getOperand(1))) + return false; + return isSimpleEnoughValueToCommit(CE->getOperand(0), SimpleConstants, DL); + } + return false; +} + +static inline bool +isSimpleEnoughValueToCommit(Constant *C, + SmallPtrSetImpl<Constant *> &SimpleConstants, + const DataLayout &DL) { + // If we already checked this constant, we win. + if (!SimpleConstants.insert(C).second) + return true; + // Check the constant. + return isSimpleEnoughValueToCommitHelper(C, SimpleConstants, DL); +} + +/// Return true if this constant is simple enough for us to understand. In +/// particular, if it is a cast to anything other than from one pointer type to +/// another pointer type, we punt. We basically just support direct accesses to +/// globals and GEP's of globals. This should be kept up to date with +/// CommitValueTo. +static bool isSimpleEnoughPointerToCommit(Constant *C) { + // Conservatively, avoid aggregate types. This is because we don't + // want to worry about them partially overlapping other stores. + if (!cast<PointerType>(C->getType())->getElementType()->isSingleValueType()) + return false; + + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) + // Do not allow weak/*_odr/linkonce linkage or external globals. + return GV->hasUniqueInitializer(); + + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { + // Handle a constantexpr gep. + if (CE->getOpcode() == Instruction::GetElementPtr && + isa<GlobalVariable>(CE->getOperand(0)) && + cast<GEPOperator>(CE)->isInBounds()) { + GlobalVariable *GV = cast<GlobalVariable>(CE->getOperand(0)); + // Do not allow weak/*_odr/linkonce/dllimport/dllexport linkage or + // external globals. + if (!GV->hasUniqueInitializer()) + return false; + + // The first index must be zero. + ConstantInt *CI = dyn_cast<ConstantInt>(*std::next(CE->op_begin())); + if (!CI || !CI->isZero()) return false; + + // The remaining indices must be compile-time known integers within the + // notional bounds of the corresponding static array types. 
+ if (!CE->isGEPWithNoNotionalOverIndexing()) + return false; + + return ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE); + + // A constantexpr bitcast from a pointer to another pointer is a no-op, + // and we know how to evaluate it by moving the bitcast from the pointer + // operand to the value operand. + } else if (CE->getOpcode() == Instruction::BitCast && + isa<GlobalVariable>(CE->getOperand(0))) { + // Do not allow weak/*_odr/linkonce/dllimport/dllexport linkage or + // external globals. + return cast<GlobalVariable>(CE->getOperand(0))->hasUniqueInitializer(); + } + } + + return false; +} + +/// Return the value that would be computed by a load from P after the stores +/// reflected by 'memory' have been performed. If we can't decide, return null. +Constant *Evaluator::ComputeLoadResult(Constant *P) { + // If this memory location has been recently stored, use the stored value: it + // is the most up-to-date. + DenseMap<Constant*, Constant*>::const_iterator I = MutatedMemory.find(P); + if (I != MutatedMemory.end()) return I->second; + + // Access it. + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(P)) { + if (GV->hasDefinitiveInitializer()) + return GV->getInitializer(); + return nullptr; + } + + // Handle a constantexpr getelementptr. + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(P)) + if (CE->getOpcode() == Instruction::GetElementPtr && + isa<GlobalVariable>(CE->getOperand(0))) { + GlobalVariable *GV = cast<GlobalVariable>(CE->getOperand(0)); + if (GV->hasDefinitiveInitializer()) + return ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE); + } + + return nullptr; // don't know how to evaluate. +} + +/// Evaluate all instructions in block BB, returning true if successful, false +/// if we can't evaluate it. NewBB returns the next BB that control flows into, +/// or null upon return. +bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, + BasicBlock *&NextBB) { + // This is the main evaluation loop. + while (1) { + Constant *InstResult = nullptr; + + DEBUG(dbgs() << "Evaluating Instruction: " << *CurInst << "\n"); + + if (StoreInst *SI = dyn_cast<StoreInst>(CurInst)) { + if (!SI->isSimple()) { + DEBUG(dbgs() << "Store is not simple! Can not evaluate.\n"); + return false; // no volatile/atomic accesses. + } + Constant *Ptr = getVal(SI->getOperand(1)); + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { + DEBUG(dbgs() << "Folding constant ptr expression: " << *Ptr); + Ptr = ConstantFoldConstantExpression(CE, DL, TLI); + DEBUG(dbgs() << "; To: " << *Ptr << "\n"); + } + if (!isSimpleEnoughPointerToCommit(Ptr)) { + // If this is too complex for us to commit, reject it. + DEBUG(dbgs() << "Pointer is too complex for us to evaluate store."); + return false; + } + + Constant *Val = getVal(SI->getOperand(0)); + + // If this might be too difficult for the backend to handle (e.g. the addr + // of one global variable divided by another) then we can't commit it. + if (!isSimpleEnoughValueToCommit(Val, SimpleConstants, DL)) { + DEBUG(dbgs() << "Store value is too complex to evaluate store. " << *Val + << "\n"); + return false; + } + + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { + if (CE->getOpcode() == Instruction::BitCast) { + DEBUG(dbgs() << "Attempting to resolve bitcast on constant ptr.\n"); + // If we're evaluating a store through a bitcast, then we need + // to pull the bitcast off the pointer type and push it onto the + // stored value. 
+ Ptr = CE->getOperand(0); + + Type *NewTy = cast<PointerType>(Ptr->getType())->getElementType(); + + // In order to push the bitcast onto the stored value, a bitcast + // from NewTy to Val's type must be legal. If it's not, we can try + // introspecting NewTy to find a legal conversion. + while (!Val->getType()->canLosslesslyBitCastTo(NewTy)) { + // If NewTy is a struct, we can convert the pointer to the struct + // into a pointer to its first member. + // FIXME: This could be extended to support arrays as well. + if (StructType *STy = dyn_cast<StructType>(NewTy)) { + NewTy = STy->getTypeAtIndex(0U); + + IntegerType *IdxTy = IntegerType::get(NewTy->getContext(), 32); + Constant *IdxZero = ConstantInt::get(IdxTy, 0, false); + Constant * const IdxList[] = {IdxZero, IdxZero}; + + Ptr = ConstantExpr::getGetElementPtr(nullptr, Ptr, IdxList); + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) + Ptr = ConstantFoldConstantExpression(CE, DL, TLI); + + // If we can't improve the situation by introspecting NewTy, + // we have to give up. + } else { + DEBUG(dbgs() << "Failed to bitcast constant ptr, can not " + "evaluate.\n"); + return false; + } + } + + // If we found compatible types, go ahead and push the bitcast + // onto the stored value. + Val = ConstantExpr::getBitCast(Val, NewTy); + + DEBUG(dbgs() << "Evaluated bitcast: " << *Val << "\n"); + } + } + + MutatedMemory[Ptr] = Val; + } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CurInst)) { + InstResult = ConstantExpr::get(BO->getOpcode(), + getVal(BO->getOperand(0)), + getVal(BO->getOperand(1))); + DEBUG(dbgs() << "Found a BinaryOperator! Simplifying: " << *InstResult + << "\n"); + } else if (CmpInst *CI = dyn_cast<CmpInst>(CurInst)) { + InstResult = ConstantExpr::getCompare(CI->getPredicate(), + getVal(CI->getOperand(0)), + getVal(CI->getOperand(1))); + DEBUG(dbgs() << "Found a CmpInst! Simplifying: " << *InstResult + << "\n"); + } else if (CastInst *CI = dyn_cast<CastInst>(CurInst)) { + InstResult = ConstantExpr::getCast(CI->getOpcode(), + getVal(CI->getOperand(0)), + CI->getType()); + DEBUG(dbgs() << "Found a Cast! Simplifying: " << *InstResult + << "\n"); + } else if (SelectInst *SI = dyn_cast<SelectInst>(CurInst)) { + InstResult = ConstantExpr::getSelect(getVal(SI->getOperand(0)), + getVal(SI->getOperand(1)), + getVal(SI->getOperand(2))); + DEBUG(dbgs() << "Found a Select! Simplifying: " << *InstResult + << "\n"); + } else if (auto *EVI = dyn_cast<ExtractValueInst>(CurInst)) { + InstResult = ConstantExpr::getExtractValue( + getVal(EVI->getAggregateOperand()), EVI->getIndices()); + DEBUG(dbgs() << "Found an ExtractValueInst! Simplifying: " << *InstResult + << "\n"); + } else if (auto *IVI = dyn_cast<InsertValueInst>(CurInst)) { + InstResult = ConstantExpr::getInsertValue( + getVal(IVI->getAggregateOperand()), + getVal(IVI->getInsertedValueOperand()), IVI->getIndices()); + DEBUG(dbgs() << "Found an InsertValueInst! Simplifying: " << *InstResult + << "\n"); + } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(CurInst)) { + Constant *P = getVal(GEP->getOperand(0)); + SmallVector<Constant*, 8> GEPOps; + for (User::op_iterator i = GEP->op_begin() + 1, e = GEP->op_end(); + i != e; ++i) + GEPOps.push_back(getVal(*i)); + InstResult = + ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), P, GEPOps, + cast<GEPOperator>(GEP)->isInBounds()); + DEBUG(dbgs() << "Found a GEP! 
Simplifying: " << *InstResult + << "\n"); + } else if (LoadInst *LI = dyn_cast<LoadInst>(CurInst)) { + + if (!LI->isSimple()) { + DEBUG(dbgs() << "Found a Load! Not a simple load, can not evaluate.\n"); + return false; // no volatile/atomic accesses. + } + + Constant *Ptr = getVal(LI->getOperand(0)); + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { + Ptr = ConstantFoldConstantExpression(CE, DL, TLI); + DEBUG(dbgs() << "Found a constant pointer expression, constant " + "folding: " << *Ptr << "\n"); + } + InstResult = ComputeLoadResult(Ptr); + if (!InstResult) { + DEBUG(dbgs() << "Failed to compute load result. Can not evaluate load." + "\n"); + return false; // Could not evaluate load. + } + + DEBUG(dbgs() << "Evaluated load: " << *InstResult << "\n"); + } else if (AllocaInst *AI = dyn_cast<AllocaInst>(CurInst)) { + if (AI->isArrayAllocation()) { + DEBUG(dbgs() << "Found an array alloca. Can not evaluate.\n"); + return false; // Cannot handle array allocs. + } + Type *Ty = AI->getAllocatedType(); + AllocaTmps.push_back( + make_unique<GlobalVariable>(Ty, false, GlobalValue::InternalLinkage, + UndefValue::get(Ty), AI->getName())); + InstResult = AllocaTmps.back().get(); + DEBUG(dbgs() << "Found an alloca. Result: " << *InstResult << "\n"); + } else if (isa<CallInst>(CurInst) || isa<InvokeInst>(CurInst)) { + CallSite CS(&*CurInst); + + // Debug info can safely be ignored here. + if (isa<DbgInfoIntrinsic>(CS.getInstruction())) { + DEBUG(dbgs() << "Ignoring debug info.\n"); + ++CurInst; + continue; + } + + // Cannot handle inline asm. + if (isa<InlineAsm>(CS.getCalledValue())) { + DEBUG(dbgs() << "Found inline asm, can not evaluate.\n"); + return false; + } + + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { + if (MemSetInst *MSI = dyn_cast<MemSetInst>(II)) { + if (MSI->isVolatile()) { + DEBUG(dbgs() << "Can not optimize a volatile memset " << + "intrinsic.\n"); + return false; + } + Constant *Ptr = getVal(MSI->getDest()); + Constant *Val = getVal(MSI->getValue()); + Constant *DestVal = ComputeLoadResult(getVal(Ptr)); + if (Val->isNullValue() && DestVal && DestVal->isNullValue()) { + // This memset is a no-op. + DEBUG(dbgs() << "Ignoring no-op memset.\n"); + ++CurInst; + continue; + } + } + + if (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end) { + DEBUG(dbgs() << "Ignoring lifetime intrinsic.\n"); + ++CurInst; + continue; + } + + if (II->getIntrinsicID() == Intrinsic::invariant_start) { + // We don't insert an entry into Values, as it doesn't have a + // meaningful return value. + if (!II->use_empty()) { + DEBUG(dbgs() << "Found unused invariant_start. Can't evaluate.\n"); + return false; + } + ConstantInt *Size = cast<ConstantInt>(II->getArgOperand(0)); + Value *PtrArg = getVal(II->getArgOperand(1)); + Value *Ptr = PtrArg->stripPointerCasts(); + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) { + Type *ElemTy = GV->getValueType(); + if (!Size->isAllOnesValue() && + Size->getValue().getLimitedValue() >= + DL.getTypeStoreSize(ElemTy)) { + Invariants.insert(GV); + DEBUG(dbgs() << "Found a global var that is an invariant: " << *GV + << "\n"); + } else { + DEBUG(dbgs() << "Found a global var, but can not treat it as an " + "invariant.\n"); + } + } + // Continue even if we do nothing. + ++CurInst; + continue; + } else if (II->getIntrinsicID() == Intrinsic::assume) { + DEBUG(dbgs() << "Skipping assume intrinsic.\n"); + ++CurInst; + continue; + } + + DEBUG(dbgs() << "Unknown intrinsic. 
Can not evaluate.\n"); + return false; + } + + // Resolve function pointers. + Function *Callee = dyn_cast<Function>(getVal(CS.getCalledValue())); + if (!Callee || Callee->isInterposable()) { + DEBUG(dbgs() << "Can not resolve function pointer.\n"); + return false; // Cannot resolve. + } + + SmallVector<Constant*, 8> Formals; + for (User::op_iterator i = CS.arg_begin(), e = CS.arg_end(); i != e; ++i) + Formals.push_back(getVal(*i)); + + if (Callee->isDeclaration()) { + // If this is a function we can constant fold, do it. + if (Constant *C = ConstantFoldCall(Callee, Formals, TLI)) { + InstResult = C; + DEBUG(dbgs() << "Constant folded function call. Result: " << + *InstResult << "\n"); + } else { + DEBUG(dbgs() << "Can not constant fold function call.\n"); + return false; + } + } else { + if (Callee->getFunctionType()->isVarArg()) { + DEBUG(dbgs() << "Can not constant fold vararg function call.\n"); + return false; + } + + Constant *RetVal = nullptr; + // Execute the call, if successful, use the return value. + ValueStack.emplace_back(); + if (!EvaluateFunction(Callee, RetVal, Formals)) { + DEBUG(dbgs() << "Failed to evaluate function.\n"); + return false; + } + ValueStack.pop_back(); + InstResult = RetVal; + + if (InstResult) { + DEBUG(dbgs() << "Successfully evaluated function. Result: " + << *InstResult << "\n\n"); + } else { + DEBUG(dbgs() << "Successfully evaluated function. Result: 0\n\n"); + } + } + } else if (isa<TerminatorInst>(CurInst)) { + DEBUG(dbgs() << "Found a terminator instruction.\n"); + + if (BranchInst *BI = dyn_cast<BranchInst>(CurInst)) { + if (BI->isUnconditional()) { + NextBB = BI->getSuccessor(0); + } else { + ConstantInt *Cond = + dyn_cast<ConstantInt>(getVal(BI->getCondition())); + if (!Cond) return false; // Cannot determine. + + NextBB = BI->getSuccessor(!Cond->getZExtValue()); + } + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurInst)) { + ConstantInt *Val = + dyn_cast<ConstantInt>(getVal(SI->getCondition())); + if (!Val) return false; // Cannot determine. + NextBB = SI->findCaseValue(Val).getCaseSuccessor(); + } else if (IndirectBrInst *IBI = dyn_cast<IndirectBrInst>(CurInst)) { + Value *Val = getVal(IBI->getAddress())->stripPointerCasts(); + if (BlockAddress *BA = dyn_cast<BlockAddress>(Val)) + NextBB = BA->getBasicBlock(); + else + return false; // Cannot determine. + } else if (isa<ReturnInst>(CurInst)) { + NextBB = nullptr; + } else { + // invoke, unwind, resume, unreachable. + DEBUG(dbgs() << "Can not handle terminator."); + return false; // Cannot handle this terminator. + } + + // We succeeded at evaluating this block! + DEBUG(dbgs() << "Successfully evaluated block.\n"); + return true; + } else { + // Did not know how to evaluate this! + DEBUG(dbgs() << "Failed to evaluate block due to unhandled instruction." + "\n"); + return false; + } + + if (!CurInst->use_empty()) { + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(InstResult)) + InstResult = ConstantFoldConstantExpression(CE, DL, TLI); + + setVal(&*CurInst, InstResult); + } + + // If we just processed an invoke, we finished evaluating the block. + if (InvokeInst *II = dyn_cast<InvokeInst>(CurInst)) { + NextBB = II->getNormalDest(); + DEBUG(dbgs() << "Found an invoke instruction. Finished Block.\n\n"); + return true; + } + + // Advance program counter. + ++CurInst; + } +} + +/// Evaluate a call to function F, returning true if successful, false if we +/// can't evaluate it. ActualArgs contains the formal arguments for the +/// function. 
+bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, + const SmallVectorImpl<Constant*> &ActualArgs) { + // Check to see if this function is already executing (recursion). If so, + // bail out. TODO: we might want to accept limited recursion. + if (std::find(CallStack.begin(), CallStack.end(), F) != CallStack.end()) + return false; + + CallStack.push_back(F); + + // Initialize arguments to the incoming values specified. + unsigned ArgNo = 0; + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); AI != E; + ++AI, ++ArgNo) + setVal(&*AI, ActualArgs[ArgNo]); + + // ExecutedBlocks - We only handle non-looping, non-recursive code. As such, + // we can only evaluate any one basic block at most once. This set keeps + // track of what we have executed so we can detect recursive cases etc. + SmallPtrSet<BasicBlock*, 32> ExecutedBlocks; + + // CurBB - The current basic block we're evaluating. + BasicBlock *CurBB = &F->front(); + + BasicBlock::iterator CurInst = CurBB->begin(); + + while (1) { + BasicBlock *NextBB = nullptr; // Initialized to avoid compiler warnings. + DEBUG(dbgs() << "Trying to evaluate BB: " << *CurBB << "\n"); + + if (!EvaluateBlock(CurInst, NextBB)) + return false; + + if (!NextBB) { + // Successfully running until there's no next block means that we found + // the return. Fill it the return value and pop the call stack. + ReturnInst *RI = cast<ReturnInst>(CurBB->getTerminator()); + if (RI->getNumOperands()) + RetVal = getVal(RI->getOperand(0)); + CallStack.pop_back(); + return true; + } + + // Okay, we succeeded in evaluating this control flow. See if we have + // executed the new block before. If so, we have a looping function, + // which we cannot evaluate in reasonable time. + if (!ExecutedBlocks.insert(NextBB).second) + return false; // looped! + + // Okay, we have never been in this block before. Check to see if there + // are any PHI nodes. If so, evaluate them with information about where + // we came from. + PHINode *PN = nullptr; + for (CurInst = NextBB->begin(); + (PN = dyn_cast<PHINode>(CurInst)); ++CurInst) + setVal(PN, getVal(PN->getIncomingValueForBlock(CurBB))); + + // Advance to the next block. + CurBB = NextBB; + } +} + diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp new file mode 100644 index 000000000000..fcb25baf3216 --- /dev/null +++ b/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -0,0 +1,243 @@ +//===- lib/Transforms/Utils/FunctionImportUtils.cpp - Importing utilities -===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the FunctionImportGlobalProcessing class, used +// to perform the necessary global value handling for function importing. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ModuleSummaryAnalysis.h" +#include "llvm/Transforms/Utils/FunctionImportUtils.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +using namespace llvm; + +/// Checks if we should import SGV as a definition, otherwise import as a +/// declaration. +bool FunctionImportGlobalProcessing::doImportAsDefinition( + const GlobalValue *SGV, DenseSet<const GlobalValue *> *GlobalsToImport) { + + // For alias, we tie the definition to the base object. 
Extract it and recurse + if (auto *GA = dyn_cast<GlobalAlias>(SGV)) { + if (GA->hasWeakAnyLinkage()) + return false; + const GlobalObject *GO = GA->getBaseObject(); + if (!GO->hasLinkOnceODRLinkage()) + return false; + return FunctionImportGlobalProcessing::doImportAsDefinition( + GO, GlobalsToImport); + } + // Only import the globals requested for importing. + if (GlobalsToImport->count(SGV)) + return true; + // Otherwise no. + return false; +} + +bool FunctionImportGlobalProcessing::doImportAsDefinition( + const GlobalValue *SGV) { + if (!isPerformingImport()) + return false; + return FunctionImportGlobalProcessing::doImportAsDefinition(SGV, + GlobalsToImport); +} + +bool FunctionImportGlobalProcessing::doPromoteLocalToGlobal( + const GlobalValue *SGV) { + assert(SGV->hasLocalLinkage()); + // Both the imported references and the original local variable must + // be promoted. + if (!isPerformingImport() && !isModuleExporting()) + return false; + + // Local const variables never need to be promoted unless they are address + // taken. The imported uses can simply use the clone created in this module. + // For now we are conservative in determining which variables are not + // address taken by checking the unnamed addr flag. To be more aggressive, + // the address taken information must be checked earlier during parsing + // of the module and recorded in the summary index for use when importing + // from that module. + auto *GVar = dyn_cast<GlobalVariable>(SGV); + if (GVar && GVar->isConstant() && GVar->hasGlobalUnnamedAddr()) + return false; + + if (GVar && GVar->hasSection()) + // Some sections like "__DATA,__cfstring" are "magic" and promotion is not + // allowed. Just disable promotion on any GVar with sections right now. + return false; + + // Eventually we only need to promote functions in the exporting module that + // are referenced by a potentially exported function (i.e. one that is in the + // summary index). + return true; +} + +std::string FunctionImportGlobalProcessing::getName(const GlobalValue *SGV) { + // For locals that must be promoted to global scope, ensure that + // the promoted name uniquely identifies the copy in the original module, + // using the ID assigned during combined index creation. When importing, + // we rename all locals (not just those that are promoted) in order to + // avoid naming conflicts between locals imported from different modules. + if (SGV->hasLocalLinkage() && + (doPromoteLocalToGlobal(SGV) || isPerformingImport())) + return ModuleSummaryIndex::getGlobalNameForLocal( + SGV->getName(), + ImportIndex.getModuleHash(SGV->getParent()->getModuleIdentifier())); + return SGV->getName(); +} + +GlobalValue::LinkageTypes +FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV) { + // Any local variable that is referenced by an exported function needs + // to be promoted to global scope. Since we don't currently know which + // functions reference which local variables/functions, we must treat + // all as potentially exported if this module is exporting anything. + if (isModuleExporting()) { + if (SGV->hasLocalLinkage() && doPromoteLocalToGlobal(SGV)) + return GlobalValue::ExternalLinkage; + return SGV->getLinkage(); + } + + // Otherwise, if we aren't importing, no linkage change is needed. 
+ if (!isPerformingImport()) + return SGV->getLinkage(); + + switch (SGV->getLinkage()) { + case GlobalValue::ExternalLinkage: + // External defnitions are converted to available_externally + // definitions upon import, so that they are available for inlining + // and/or optimization, but are turned into declarations later + // during the EliminateAvailableExternally pass. + if (doImportAsDefinition(SGV) && !dyn_cast<GlobalAlias>(SGV)) + return GlobalValue::AvailableExternallyLinkage; + // An imported external declaration stays external. + return SGV->getLinkage(); + + case GlobalValue::AvailableExternallyLinkage: + // An imported available_externally definition converts + // to external if imported as a declaration. + if (!doImportAsDefinition(SGV)) + return GlobalValue::ExternalLinkage; + // An imported available_externally declaration stays that way. + return SGV->getLinkage(); + + case GlobalValue::LinkOnceAnyLinkage: + case GlobalValue::LinkOnceODRLinkage: + // These both stay the same when importing the definition. + // The ThinLTO pass will eventually force-import their definitions. + return SGV->getLinkage(); + + case GlobalValue::WeakAnyLinkage: + // Can't import weak_any definitions correctly, or we might change the + // program semantics, since the linker will pick the first weak_any + // definition and importing would change the order they are seen by the + // linker. The module linking caller needs to enforce this. + assert(!doImportAsDefinition(SGV)); + // If imported as a declaration, it becomes external_weak. + return SGV->getLinkage(); + + case GlobalValue::WeakODRLinkage: + // For weak_odr linkage, there is a guarantee that all copies will be + // equivalent, so the issue described above for weak_any does not exist, + // and the definition can be imported. It can be treated similarly + // to an imported externally visible global value. + if (doImportAsDefinition(SGV) && !dyn_cast<GlobalAlias>(SGV)) + return GlobalValue::AvailableExternallyLinkage; + else + return GlobalValue::ExternalLinkage; + + case GlobalValue::AppendingLinkage: + // It would be incorrect to import an appending linkage variable, + // since it would cause global constructors/destructors to be + // executed multiple times. This should have already been handled + // by linkIfNeeded, and we will assert in shouldLinkFromSource + // if we try to import, so we simply return AppendingLinkage. + return GlobalValue::AppendingLinkage; + + case GlobalValue::InternalLinkage: + case GlobalValue::PrivateLinkage: + // If we are promoting the local to global scope, it is handled + // similarly to a normal externally visible global. + if (doPromoteLocalToGlobal(SGV)) { + if (doImportAsDefinition(SGV) && !dyn_cast<GlobalAlias>(SGV)) + return GlobalValue::AvailableExternallyLinkage; + else + return GlobalValue::ExternalLinkage; + } + // A non-promoted imported local definition stays local. + // The ThinLTO pass will eventually force-import their definitions. + return SGV->getLinkage(); + + case GlobalValue::ExternalWeakLinkage: + // External weak doesn't apply to definitions, must be a declaration. + assert(!doImportAsDefinition(SGV)); + // Linkage stays external_weak. + return SGV->getLinkage(); + + case GlobalValue::CommonLinkage: + // Linkage stays common on definitions. + // The ThinLTO pass will eventually force-import their definitions. 
+ return SGV->getLinkage(); + } + + llvm_unreachable("unknown linkage type"); +} + +void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { + if (GV.hasLocalLinkage() && + (doPromoteLocalToGlobal(&GV) || isPerformingImport())) { + GV.setName(getName(&GV)); + GV.setLinkage(getLinkage(&GV)); + if (!GV.hasLocalLinkage()) + GV.setVisibility(GlobalValue::HiddenVisibility); + } else + GV.setLinkage(getLinkage(&GV)); + + // Remove functions imported as available externally defs from comdats, + // as this is a declaration for the linker, and will be dropped eventually. + // It is illegal for comdats to contain declarations. + auto *GO = dyn_cast_or_null<GlobalObject>(&GV); + if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) { + // The IRMover should not have placed any imported declarations in + // a comdat, so the only declaration that should be in a comdat + // at this point would be a definition imported as available_externally. + assert(GO->hasAvailableExternallyLinkage() && + "Expected comdat on definition (possibly available external)"); + GO->setComdat(nullptr); + } +} + +void FunctionImportGlobalProcessing::processGlobalsForThinLTO() { + if (!moduleCanBeRenamedForThinLTO(M)) { + // We would have blocked importing from this module by suppressing index + // generation. We still may be able to import into this module though. + assert(!isPerformingImport() && + "Should have blocked importing from module with local used in ASM"); + return; + } + + for (GlobalVariable &GV : M.globals()) + processGlobalForThinLTO(GV); + for (Function &SF : M) + processGlobalForThinLTO(SF); + for (GlobalAlias &GA : M.aliases()) + processGlobalForThinLTO(GA); +} + +bool FunctionImportGlobalProcessing::run() { + processGlobalsForThinLTO(); + return false; +} + +bool llvm::renameModuleForThinLTO( + Module &M, const ModuleSummaryIndex &Index, + DenseSet<const GlobalValue *> *GlobalsToImport) { + FunctionImportGlobalProcessing ThinLTOProcessing(M, Index, GlobalsToImport); + return ThinLTOProcessing.run(); +} diff --git a/lib/Transforms/Utils/GlobalStatus.cpp b/lib/Transforms/Utils/GlobalStatus.cpp index 3893a752503b..266be41fbead 100644 --- a/lib/Transforms/Utils/GlobalStatus.cpp +++ b/lib/Transforms/Utils/GlobalStatus.cpp @@ -20,11 +20,11 @@ using namespace llvm; /// and release, then return AcquireRelease. /// static AtomicOrdering strongerOrdering(AtomicOrdering X, AtomicOrdering Y) { - if (X == Acquire && Y == Release) - return AcquireRelease; - if (Y == Acquire && X == Release) - return AcquireRelease; - return (AtomicOrdering)std::max(X, Y); + if (X == AtomicOrdering::Acquire && Y == AtomicOrdering::Release) + return AtomicOrdering::AcquireRelease; + if (Y == AtomicOrdering::Acquire && X == AtomicOrdering::Release) + return AtomicOrdering::AcquireRelease; + return (AtomicOrdering)std::max((unsigned)X, (unsigned)Y); } /// It is safe to destroy a constant iff it is only used by constants itself. 
@@ -105,7 +105,7 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, } } - if (StoredVal == GV->getInitializer()) { + if (GV->hasInitializer() && StoredVal == GV->getInitializer()) { if (GS.StoredType < GlobalStatus::InitializerStored) GS.StoredType = GlobalStatus::InitializerStored; } else if (isa<LoadInst>(StoredVal) && @@ -185,4 +185,4 @@ GlobalStatus::GlobalStatus() : IsCompared(false), IsLoaded(false), StoredType(NotStored), StoredOnceValue(nullptr), AccessingFunction(nullptr), HasMultipleAccessingFunctions(false), HasNonInstructionUser(false), - Ordering(NotAtomic) {} + Ordering(AtomicOrdering::NotAtomic) {} diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index 79282a2a703b..1fbb19d2b8ad 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -427,6 +427,17 @@ static BasicBlock *HandleCallsInBlockInlinedThroughInvoke( if (!CI || CI->doesNotThrow() || isa<InlineAsm>(CI->getCalledValue())) continue; + // We do not need to (and in fact, cannot) convert possibly throwing calls + // to @llvm.experimental_deoptimize (resp. @llvm.experimental.guard) into + // invokes. The caller's "segment" of the deoptimization continuation + // attached to the newly inlined @llvm.experimental_deoptimize + // (resp. @llvm.experimental.guard) call should contain the exception + // handling logic, if any. + if (auto *F = CI->getCalledFunction()) + if (F->getIntrinsicID() == Intrinsic::experimental_deoptimize || + F->getIntrinsicID() == Intrinsic::experimental_guard) + continue; + if (auto FuncletBundle = CI->getOperandBundle(LLVMContext::OB_funclet)) { // This call is nested inside a funclet. If that funclet has an unwind // destination within the inlinee, then unwinding out of this call would @@ -677,6 +688,34 @@ static void HandleInlinedEHPad(InvokeInst *II, BasicBlock *FirstNewBlock, UnwindDest->removePredecessor(InvokeBB); } +/// When inlining a call site that has !llvm.mem.parallel_loop_access metadata, +/// that metadata should be propagated to all memory-accessing cloned +/// instructions. +static void PropagateParallelLoopAccessMetadata(CallSite CS, + ValueToValueMapTy &VMap) { + MDNode *M = + CS.getInstruction()->getMetadata(LLVMContext::MD_mem_parallel_loop_access); + if (!M) + return; + + for (ValueToValueMapTy::iterator VMI = VMap.begin(), VMIE = VMap.end(); + VMI != VMIE; ++VMI) { + if (!VMI->second) + continue; + + Instruction *NI = dyn_cast<Instruction>(VMI->second); + if (!NI) + continue; + + if (MDNode *PM = NI->getMetadata(LLVMContext::MD_mem_parallel_loop_access)) { + M = MDNode::concatenate(PM, M); + NI->setMetadata(LLVMContext::MD_mem_parallel_loop_access, M); + } else if (NI->mayReadOrWriteMemory()) { + NI->setMetadata(LLVMContext::MD_mem_parallel_loop_access, M); + } + } +} + /// When inlining a function that contains noalias scope metadata, /// this metadata needs to be cloned so that the inlined blocks /// have different "unqiue scopes" at every call site. Were this not done, then @@ -693,12 +732,11 @@ static void CloneAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap) { // inter-procedural alias analysis passes. We can revisit this if it becomes // an efficiency or overhead problem. 
- for (Function::const_iterator I = CalledFunc->begin(), IE = CalledFunc->end(); - I != IE; ++I) - for (BasicBlock::const_iterator J = I->begin(), JE = I->end(); J != JE; ++J) { - if (const MDNode *M = J->getMetadata(LLVMContext::MD_alias_scope)) + for (const BasicBlock &I : *CalledFunc) + for (const Instruction &J : I) { + if (const MDNode *M = J.getMetadata(LLVMContext::MD_alias_scope)) MD.insert(M); - if (const MDNode *M = J->getMetadata(LLVMContext::MD_noalias)) + if (const MDNode *M = J.getMetadata(LLVMContext::MD_noalias)) MD.insert(M); } @@ -720,20 +758,18 @@ static void CloneAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap) { // the noalias scopes and the lists of those scopes. SmallVector<TempMDTuple, 16> DummyNodes; DenseMap<const MDNode *, TrackingMDNodeRef> MDMap; - for (SetVector<const MDNode *>::iterator I = MD.begin(), IE = MD.end(); - I != IE; ++I) { + for (const MDNode *I : MD) { DummyNodes.push_back(MDTuple::getTemporary(CalledFunc->getContext(), None)); - MDMap[*I].reset(DummyNodes.back().get()); + MDMap[I].reset(DummyNodes.back().get()); } // Create new metadata nodes to replace the dummy nodes, replacing old // metadata references with either a dummy node or an already-created new // node. - for (SetVector<const MDNode *>::iterator I = MD.begin(), IE = MD.end(); - I != IE; ++I) { + for (const MDNode *I : MD) { SmallVector<Metadata *, 4> NewOps; - for (unsigned i = 0, ie = (*I)->getNumOperands(); i != ie; ++i) { - const Metadata *V = (*I)->getOperand(i); + for (unsigned i = 0, ie = I->getNumOperands(); i != ie; ++i) { + const Metadata *V = I->getOperand(i); if (const MDNode *M = dyn_cast<MDNode>(V)) NewOps.push_back(MDMap[M]); else @@ -741,7 +777,7 @@ static void CloneAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap) { } MDNode *NewM = MDNode::get(CalledFunc->getContext(), NewOps); - MDTuple *TempM = cast<MDTuple>(MDMap[*I]); + MDTuple *TempM = cast<MDTuple>(MDMap[I]); assert(TempM->isTemporary() && "Expected temporary node"); TempM->replaceAllUsesWith(NewM); @@ -801,10 +837,9 @@ static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, const Function *CalledFunc = CS.getCalledFunction(); SmallVector<const Argument *, 4> NoAliasArgs; - for (const Argument &I : CalledFunc->args()) { - if (I.hasNoAliasAttr() && !I.hasNUses(0)) - NoAliasArgs.push_back(&I); - } + for (const Argument &Arg : CalledFunc->args()) + if (Arg.hasNoAliasAttr() && !Arg.use_empty()) + NoAliasArgs.push_back(&Arg); if (NoAliasArgs.empty()) return; @@ -885,17 +920,16 @@ static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, IsArgMemOnlyCall = true; } - for (ImmutableCallSite::arg_iterator AI = ICS.arg_begin(), - AE = ICS.arg_end(); AI != AE; ++AI) { + for (Value *Arg : ICS.args()) { // We need to check the underlying objects of all arguments, not just // the pointer arguments, because we might be passing pointers as // integers, etc. // However, if we know that the call only accesses pointer arguments, // then we only need to check the pointer arguments. 
- if (IsArgMemOnlyCall && !(*AI)->getType()->isPointerTy()) + if (IsArgMemOnlyCall && !Arg->getType()->isPointerTy()) continue; - PtrArgs.push_back(*AI); + PtrArgs.push_back(Arg); } } @@ -913,9 +947,9 @@ static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, SmallVector<Metadata *, 4> Scopes, NoAliases; SmallSetVector<const Argument *, 4> NAPtrArgs; - for (unsigned i = 0, ie = PtrArgs.size(); i != ie; ++i) { + for (const Value *V : PtrArgs) { SmallVector<Value *, 4> Objects; - GetUnderlyingObjects(const_cast<Value*>(PtrArgs[i]), + GetUnderlyingObjects(const_cast<Value*>(V), Objects, DL, /* LI = */ nullptr); for (Value *O : Objects) @@ -1228,7 +1262,8 @@ static bool hasLifetimeMarkers(AllocaInst *AI) { /// Rebuild the entire inlined-at chain for this instruction so that the top of /// the chain now is inlined-at the new call site. static DebugLoc -updateInlinedAtInfo(DebugLoc DL, DILocation *InlinedAtNode, LLVMContext &Ctx, +updateInlinedAtInfo(const DebugLoc &DL, DILocation *InlinedAtNode, + LLVMContext &Ctx, DenseMap<const DILocation *, DILocation *> &IANodes) { SmallVector<DILocation *, 3> InlinedAtLocations; DILocation *Last = InlinedAtNode; @@ -1249,8 +1284,7 @@ updateInlinedAtInfo(DebugLoc DL, DILocation *InlinedAtNode, LLVMContext &Ctx, // Starting from the top, rebuild the nodes to point to the new inlined-at // location (then rebuilding the rest of the chain behind it) and update the // map of already-constructed inlined-at nodes. - for (const DILocation *MD : make_range(InlinedAtLocations.rbegin(), - InlinedAtLocations.rend())) { + for (const DILocation *MD : reverse(InlinedAtLocations)) { Last = IANodes[MD] = DILocation::getDistinct( Ctx, MD->getLine(), MD->getColumn(), MD->getScope(), Last); } @@ -1264,7 +1298,7 @@ updateInlinedAtInfo(DebugLoc DL, DILocation *InlinedAtNode, LLVMContext &Ctx, /// to encode location where these instructions are inlined. static void fixupLineNumbers(Function *Fn, Function::iterator FI, Instruction *TheCall) { - DebugLoc TheCallDL = TheCall->getDebugLoc(); + const DebugLoc &TheCallDL = TheCall->getDebugLoc(); if (!TheCallDL) return; @@ -1422,6 +1456,19 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, } } + // Determine if we are dealing with a call in an EHPad which does not unwind + // to caller. + bool EHPadForCallUnwindsLocally = false; + if (CallSiteEHPad && CS.isCall()) { + UnwindDestMemoTy FuncletUnwindMap; + Value *CallSiteUnwindDestToken = + getUnwindDestToken(CallSiteEHPad, FuncletUnwindMap); + + EHPadForCallUnwindsLocally = + CallSiteUnwindDestToken && + !isa<ConstantTokenNone>(CallSiteUnwindDestToken); + } + // Get an iterator to the last basic block in the function, which will have // the new function inlined after it. Function::iterator LastBlock = --Caller->end(); @@ -1552,6 +1599,9 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Add noalias metadata if necessary. AddAliasScopeMetadata(CS, VMap, DL, CalleeAAR); + // Propagate llvm.mem.parallel_loop_access if necessary. + PropagateParallelLoopAccessMetadata(CS, VMap); + // FIXME: We could register any cloned assumptions instead of clearing the // whole function's cache. 
if (IFI.ACT) @@ -1602,7 +1652,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, replaceDbgDeclareForAlloca(AI, AI, DIB, /*Deref=*/false); } - bool InlinedMustTailCalls = false; + bool InlinedMustTailCalls = false, InlinedDeoptimizeCalls = false; if (InlinedFunctionInfo.ContainsCalls) { CallInst::TailCallKind CallSiteTailKind = CallInst::TCK_None; if (CallInst *CI = dyn_cast<CallInst>(TheCall)) @@ -1615,6 +1665,10 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, if (!CI) continue; + if (Function *F = CI->getCalledFunction()) + InlinedDeoptimizeCalls |= + F->getIntrinsicID() == Intrinsic::experimental_deoptimize; + // We need to reduce the strength of any inlined tail calls. For // musttail, we have to avoid introducing potential unbounded stack // growth. For example, if functions 'f' and 'g' are mutually recursive @@ -1677,11 +1731,14 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, builder.CreateLifetimeStart(AI, AllocaSize); for (ReturnInst *RI : Returns) { - // Don't insert llvm.lifetime.end calls between a musttail call and a - // return. The return kills all local allocas. + // Don't insert llvm.lifetime.end calls between a musttail or deoptimize + // call and a return. The return kills all local allocas. if (InlinedMustTailCalls && RI->getParent()->getTerminatingMustTailCall()) continue; + if (InlinedDeoptimizeCalls && + RI->getParent()->getTerminatingDeoptimizeCall()) + continue; IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize); } } @@ -1702,10 +1759,12 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Insert a call to llvm.stackrestore before any return instructions in the // inlined function. for (ReturnInst *RI : Returns) { - // Don't insert llvm.stackrestore calls between a musttail call and a - // return. The return will restore the stack pointer. + // Don't insert llvm.stackrestore calls between a musttail or deoptimize + // call and a return. The return will restore the stack pointer. if (InlinedMustTailCalls && RI->getParent()->getTerminatingMustTailCall()) continue; + if (InlinedDeoptimizeCalls && RI->getParent()->getTerminatingDeoptimizeCall()) + continue; IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr); } } @@ -1758,7 +1817,6 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, NewInst = CallInst::Create(cast<CallInst>(I), OpBundles, I); else NewInst = InvokeInst::Create(cast<InvokeInst>(I), OpBundles, I); - NewInst->setDebugLoc(I->getDebugLoc()); NewInst->takeName(I); I->replaceAllUsesWith(NewInst); I->eraseFromParent(); @@ -1766,6 +1824,14 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, OpBundles.clear(); } + // It is problematic if the inlinee has a cleanupret which unwinds to + // caller and we inline it into a call site which doesn't unwind but into + // an EH pad that does. Such an edge must be dynamically unreachable. + // As such, we replace the cleanupret with unreachable. + if (auto *CleanupRet = dyn_cast<CleanupReturnInst>(BB->getTerminator())) + if (CleanupRet->unwindsToCaller() && EHPadForCallUnwindsLocally) + changeToUnreachable(CleanupRet, /*UseLLVMTrap=*/false); + Instruction *I = BB->getFirstNonPHI(); if (!I->isEHPad()) continue; @@ -1781,6 +1847,64 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, } } + if (InlinedDeoptimizeCalls) { + // We need to at least remove the deoptimizing returns from the Return set, + // so that the control flow from those returns does not get merged into the + // caller (but terminate it instead). 
If the caller's return type does not + // match the callee's return type, we also need to change the return type of + // the intrinsic. + if (Caller->getReturnType() == TheCall->getType()) { + auto NewEnd = remove_if(Returns, [](ReturnInst *RI) { + return RI->getParent()->getTerminatingDeoptimizeCall() != nullptr; + }); + Returns.erase(NewEnd, Returns.end()); + } else { + SmallVector<ReturnInst *, 8> NormalReturns; + Function *NewDeoptIntrinsic = Intrinsic::getDeclaration( + Caller->getParent(), Intrinsic::experimental_deoptimize, + {Caller->getReturnType()}); + + for (ReturnInst *RI : Returns) { + CallInst *DeoptCall = RI->getParent()->getTerminatingDeoptimizeCall(); + if (!DeoptCall) { + NormalReturns.push_back(RI); + continue; + } + + // The calling convention on the deoptimize call itself may be bogus, + // since the code we're inlining may have undefined behavior (and may + // never actually execute at runtime); but all + // @llvm.experimental.deoptimize declarations have to have the same + // calling convention in a well-formed module. + auto CallingConv = DeoptCall->getCalledFunction()->getCallingConv(); + NewDeoptIntrinsic->setCallingConv(CallingConv); + auto *CurBB = RI->getParent(); + RI->eraseFromParent(); + + SmallVector<Value *, 4> CallArgs(DeoptCall->arg_begin(), + DeoptCall->arg_end()); + + SmallVector<OperandBundleDef, 1> OpBundles; + DeoptCall->getOperandBundlesAsDefs(OpBundles); + DeoptCall->eraseFromParent(); + assert(!OpBundles.empty() && + "Expected at least the deopt operand bundle"); + + IRBuilder<> Builder(CurBB); + CallInst *NewDeoptCall = + Builder.CreateCall(NewDeoptIntrinsic, CallArgs, OpBundles); + NewDeoptCall->setCallingConv(CallingConv); + if (NewDeoptCall->getType()->isVoidTy()) + Builder.CreateRetVoid(); + else + Builder.CreateRet(NewDeoptCall); + } + + // Leave behind the normal returns so we can merge control flow. + std::swap(Returns, NormalReturns); + } + } + // Handle any inlined musttail call sites. In order for a new call site to be // musttail, the source of the clone and the inlined call site must have been // musttail. Therefore it's safe to return without merging control into the diff --git a/lib/Transforms/Utils/InstructionNamer.cpp b/lib/Transforms/Utils/InstructionNamer.cpp index da890a297005..8a1973d1db05 100644 --- a/lib/Transforms/Utils/InstructionNamer.cpp +++ b/lib/Transforms/Utils/InstructionNamer.cpp @@ -37,13 +37,13 @@ namespace { if (!AI->hasName() && !AI->getType()->isVoidTy()) AI->setName("arg"); - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - if (!BB->hasName()) - BB->setName("bb"); - - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) - if (!I->hasName() && !I->getType()->isVoidTy()) - I->setName("tmp"); + for (BasicBlock &BB : F) { + if (!BB.hasName()) + BB.setName("bb"); + + for (Instruction &I : BB) + if (!I.hasName() && !I.getType()->isVoidTy()) + I.setName("tmp"); } return true; } diff --git a/lib/Transforms/Utils/IntegerDivision.cpp b/lib/Transforms/Utils/IntegerDivision.cpp index 5687afa61e2a..5a90dcb033b2 100644 --- a/lib/Transforms/Utils/IntegerDivision.cpp +++ b/lib/Transforms/Utils/IntegerDivision.cpp @@ -390,6 +390,8 @@ bool llvm::expandRemainder(BinaryOperator *Rem) { Value *Remainder = generateSignedRemainderCode(Rem->getOperand(0), Rem->getOperand(1), Builder); + // Check whether this is the insert point while Rem is still valid. 
+ bool IsInsertPoint = Rem->getIterator() == Builder.GetInsertPoint(); Rem->replaceAllUsesWith(Remainder); Rem->dropAllReferences(); Rem->eraseFromParent(); @@ -397,7 +399,7 @@ bool llvm::expandRemainder(BinaryOperator *Rem) { // If we didn't actually generate an urem instruction, we're done // This happens for example if the input were constant. In this case the // Builder insertion point was unchanged - if (Rem == Builder.GetInsertPoint().getNodePtrUnchecked()) + if (IsInsertPoint) return true; BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); @@ -446,6 +448,9 @@ bool llvm::expandDivision(BinaryOperator *Div) { // Lower the code to unsigned division, and reset Div to point to the udiv. Value *Quotient = generateSignedDivisionCode(Div->getOperand(0), Div->getOperand(1), Builder); + + // Check whether this is the insert point while Div is still valid. + bool IsInsertPoint = Div->getIterator() == Builder.GetInsertPoint(); Div->replaceAllUsesWith(Quotient); Div->dropAllReferences(); Div->eraseFromParent(); @@ -453,7 +458,7 @@ bool llvm::expandDivision(BinaryOperator *Div) { // If we didn't actually generate an udiv instruction, we're done // This happens for example if the input were constant. In this case the // Builder insertion point was unchanged - if (Div == Builder.GetInsertPoint().getNodePtrUnchecked()) + if (IsInsertPoint) return true; BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint()); diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp index b4b2e148dfbb..9658966779b9 100644 --- a/lib/Transforms/Utils/LCSSA.cpp +++ b/lib/Transforms/Utils/LCSSA.cpp @@ -27,10 +27,11 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/LCSSA.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -41,6 +42,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/PredIteratorCache.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" using namespace llvm; @@ -52,154 +54,156 @@ STATISTIC(NumLCSSA, "Number of live out of a loop variables"); /// Return true if the specified block is in the list. static bool isExitBlock(BasicBlock *BB, const SmallVectorImpl<BasicBlock *> &ExitBlocks) { - for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) - if (ExitBlocks[i] == BB) - return true; - return false; + return find(ExitBlocks, BB) != ExitBlocks.end(); } -/// Given an instruction in the loop, check to see if it has any uses that are -/// outside the current loop. If so, insert LCSSA PHI nodes and rewrite the -/// uses. -static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT, - const SmallVectorImpl<BasicBlock *> &ExitBlocks, - PredIteratorCache &PredCache, LoopInfo *LI) { +/// For every instruction from the worklist, check to see if it has any uses +/// that are outside the current loop. If so, insert LCSSA PHI nodes and +/// rewrite the uses. 
+bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, + DominatorTree &DT, LoopInfo &LI) { SmallVector<Use *, 16> UsesToRewrite; + SmallVector<BasicBlock *, 8> ExitBlocks; + PredIteratorCache PredCache; + bool Changed = false; - // Tokens cannot be used in PHI nodes, so we skip over them. - // We can run into tokens which are live out of a loop with catchswitch - // instructions in Windows EH if the catchswitch has one catchpad which - // is inside the loop and another which is not. - if (Inst.getType()->isTokenTy()) - return false; + while (!Worklist.empty()) { + UsesToRewrite.clear(); + ExitBlocks.clear(); - BasicBlock *InstBB = Inst.getParent(); + Instruction *I = Worklist.pop_back_val(); + BasicBlock *InstBB = I->getParent(); + Loop *L = LI.getLoopFor(InstBB); + L->getExitBlocks(ExitBlocks); - for (Use &U : Inst.uses()) { - Instruction *User = cast<Instruction>(U.getUser()); - BasicBlock *UserBB = User->getParent(); - if (PHINode *PN = dyn_cast<PHINode>(User)) - UserBB = PN->getIncomingBlock(U); + if (ExitBlocks.empty()) + continue; - if (InstBB != UserBB && !L.contains(UserBB)) - UsesToRewrite.push_back(&U); - } + // Tokens cannot be used in PHI nodes, so we skip over them. + // We can run into tokens which are live out of a loop with catchswitch + // instructions in Windows EH if the catchswitch has one catchpad which + // is inside the loop and another which is not. + if (I->getType()->isTokenTy()) + continue; - // If there are no uses outside the loop, exit with no change. - if (UsesToRewrite.empty()) - return false; + for (Use &U : I->uses()) { + Instruction *User = cast<Instruction>(U.getUser()); + BasicBlock *UserBB = User->getParent(); + if (PHINode *PN = dyn_cast<PHINode>(User)) + UserBB = PN->getIncomingBlock(U); - ++NumLCSSA; // We are applying the transformation + if (InstBB != UserBB && !L->contains(UserBB)) + UsesToRewrite.push_back(&U); + } - // Invoke instructions are special in that their result value is not available - // along their unwind edge. The code below tests to see whether DomBB - // dominates the value, so adjust DomBB to the normal destination block, - // which is effectively where the value is first usable. - BasicBlock *DomBB = Inst.getParent(); - if (InvokeInst *Inv = dyn_cast<InvokeInst>(&Inst)) - DomBB = Inv->getNormalDest(); + // If there are no uses outside the loop, exit with no change. + if (UsesToRewrite.empty()) + continue; - DomTreeNode *DomNode = DT.getNode(DomBB); + ++NumLCSSA; // We are applying the transformation - SmallVector<PHINode *, 16> AddedPHIs; - SmallVector<PHINode *, 8> PostProcessPHIs; + // Invoke instructions are special in that their result value is not + // available along their unwind edge. The code below tests to see whether + // DomBB dominates the value, so adjust DomBB to the normal destination + // block, which is effectively where the value is first usable. + BasicBlock *DomBB = InstBB; + if (InvokeInst *Inv = dyn_cast<InvokeInst>(I)) + DomBB = Inv->getNormalDest(); - SSAUpdater SSAUpdate; - SSAUpdate.Initialize(Inst.getType(), Inst.getName()); + DomTreeNode *DomNode = DT.getNode(DomBB); - // Insert the LCSSA phi's into all of the exit blocks dominated by the - // value, and add them to the Phi's map. - for (BasicBlock *ExitBB : ExitBlocks) { - if (!DT.dominates(DomNode, DT.getNode(ExitBB))) - continue; + SmallVector<PHINode *, 16> AddedPHIs; + SmallVector<PHINode *, 8> PostProcessPHIs; - // If we already inserted something for this BB, don't reprocess it. 
- if (SSAUpdate.HasValueForBlock(ExitBB)) - continue; + SSAUpdater SSAUpdate; + SSAUpdate.Initialize(I->getType(), I->getName()); - PHINode *PN = PHINode::Create(Inst.getType(), PredCache.size(ExitBB), - Inst.getName() + ".lcssa", &ExitBB->front()); + // Insert the LCSSA phi's into all of the exit blocks dominated by the + // value, and add them to the Phi's map. + for (BasicBlock *ExitBB : ExitBlocks) { + if (!DT.dominates(DomNode, DT.getNode(ExitBB))) + continue; - // Add inputs from inside the loop for this PHI. - for (BasicBlock *Pred : PredCache.get(ExitBB)) { - PN->addIncoming(&Inst, Pred); + // If we already inserted something for this BB, don't reprocess it. + if (SSAUpdate.HasValueForBlock(ExitBB)) + continue; - // If the exit block has a predecessor not within the loop, arrange for - // the incoming value use corresponding to that predecessor to be - // rewritten in terms of a different LCSSA PHI. - if (!L.contains(Pred)) - UsesToRewrite.push_back( - &PN->getOperandUse(PN->getOperandNumForIncomingValue( - PN->getNumIncomingValues() - 1))); + PHINode *PN = PHINode::Create(I->getType(), PredCache.size(ExitBB), + I->getName() + ".lcssa", &ExitBB->front()); + + // Add inputs from inside the loop for this PHI. + for (BasicBlock *Pred : PredCache.get(ExitBB)) { + PN->addIncoming(I, Pred); + + // If the exit block has a predecessor not within the loop, arrange for + // the incoming value use corresponding to that predecessor to be + // rewritten in terms of a different LCSSA PHI. + if (!L->contains(Pred)) + UsesToRewrite.push_back( + &PN->getOperandUse(PN->getOperandNumForIncomingValue( + PN->getNumIncomingValues() - 1))); + } + + AddedPHIs.push_back(PN); + + // Remember that this phi makes the value alive in this block. + SSAUpdate.AddAvailableValue(ExitBB, PN); + + // LoopSimplify might fail to simplify some loops (e.g. when indirect + // branches are involved). In such situations, it might happen that an + // exit for Loop L1 is the header of a disjoint Loop L2. Thus, when we + // create PHIs in such an exit block, we are also inserting PHIs into L2's + // header. This could break LCSSA form for L2 because these inserted PHIs + // can also have uses outside of L2. Remember all PHIs in such situation + // as to revisit than later on. FIXME: Remove this if indirectbr support + // into LoopSimplify gets improved. + if (auto *OtherLoop = LI.getLoopFor(ExitBB)) + if (!L->contains(OtherLoop)) + PostProcessPHIs.push_back(PN); } - AddedPHIs.push_back(PN); - - // Remember that this phi makes the value alive in this block. - SSAUpdate.AddAvailableValue(ExitBB, PN); - - // LoopSimplify might fail to simplify some loops (e.g. when indirect - // branches are involved). In such situations, it might happen that an exit - // for Loop L1 is the header of a disjoint Loop L2. Thus, when we create - // PHIs in such an exit block, we are also inserting PHIs into L2's header. - // This could break LCSSA form for L2 because these inserted PHIs can also - // have uses outside of L2. Remember all PHIs in such situation as to - // revisit than later on. FIXME: Remove this if indirectbr support into - // LoopSimplify gets improved. - if (auto *OtherLoop = LI->getLoopFor(ExitBB)) - if (!L.contains(OtherLoop)) - PostProcessPHIs.push_back(PN); - } + // Rewrite all uses outside the loop in terms of the new PHIs we just + // inserted. + for (Use *UseToRewrite : UsesToRewrite) { + // If this use is in an exit block, rewrite to use the newly inserted PHI. 
+ // This is required for correctness because SSAUpdate doesn't handle uses + // in the same block. It assumes the PHI we inserted is at the end of the + // block. + Instruction *User = cast<Instruction>(UseToRewrite->getUser()); + BasicBlock *UserBB = User->getParent(); + if (PHINode *PN = dyn_cast<PHINode>(User)) + UserBB = PN->getIncomingBlock(*UseToRewrite); + + if (isa<PHINode>(UserBB->begin()) && isExitBlock(UserBB, ExitBlocks)) { + // Tell the VHs that the uses changed. This updates SCEV's caches. + if (UseToRewrite->get()->hasValueHandle()) + ValueHandleBase::ValueIsRAUWd(*UseToRewrite, &UserBB->front()); + UseToRewrite->set(&UserBB->front()); + continue; + } - // Rewrite all uses outside the loop in terms of the new PHIs we just - // inserted. - for (Use *UseToRewrite : UsesToRewrite) { - // If this use is in an exit block, rewrite to use the newly inserted PHI. - // This is required for correctness because SSAUpdate doesn't handle uses in - // the same block. It assumes the PHI we inserted is at the end of the - // block. - Instruction *User = cast<Instruction>(UseToRewrite->getUser()); - BasicBlock *UserBB = User->getParent(); - if (PHINode *PN = dyn_cast<PHINode>(User)) - UserBB = PN->getIncomingBlock(*UseToRewrite); - - if (isa<PHINode>(UserBB->begin()) && isExitBlock(UserBB, ExitBlocks)) { - // Tell the VHs that the uses changed. This updates SCEV's caches. - if (UseToRewrite->get()->hasValueHandle()) - ValueHandleBase::ValueIsRAUWd(*UseToRewrite, &UserBB->front()); - UseToRewrite->set(&UserBB->front()); - continue; + // Otherwise, do full PHI insertion. + SSAUpdate.RewriteUse(*UseToRewrite); } - // Otherwise, do full PHI insertion. - SSAUpdate.RewriteUse(*UseToRewrite); - } + // Post process PHI instructions that were inserted into another disjoint + // loop and update their exits properly. + for (auto *PostProcessPN : PostProcessPHIs) { + if (PostProcessPN->use_empty()) + continue; - // Post process PHI instructions that were inserted into another disjoint loop - // and update their exits properly. - for (auto *I : PostProcessPHIs) { - if (I->use_empty()) - continue; + // Reprocess each PHI instruction. + Worklist.push_back(PostProcessPN); + } - BasicBlock *PHIBB = I->getParent(); - Loop *OtherLoop = LI->getLoopFor(PHIBB); - SmallVector<BasicBlock *, 8> EBs; - OtherLoop->getExitBlocks(EBs); - if (EBs.empty()) - continue; + // Remove PHI nodes that did not have any uses rewritten. + for (PHINode *PN : AddedPHIs) + if (PN->use_empty()) + PN->eraseFromParent(); - // Recurse and re-process each PHI instruction. FIXME: we should really - // convert this entire thing to a worklist approach where we process a - // vector of instructions... - processInstruction(*OtherLoop, *I, DT, EBs, PredCache, LI); + Changed = true; } - - // Remove PHI nodes that did not have any uses rewritten. 
- for (PHINode *PN : AddedPHIs) - if (PN->use_empty()) - PN->eraseFromParent(); - - return true; + return Changed; } /// Return true if the specified block dominates at least @@ -209,11 +213,9 @@ blockDominatesAnExit(BasicBlock *BB, DominatorTree &DT, const SmallVectorImpl<BasicBlock *> &ExitBlocks) { DomTreeNode *DomNode = DT.getNode(BB); - for (BasicBlock *ExitBB : ExitBlocks) - if (DT.dominates(DomNode, DT.getNode(ExitBB))) - return true; - - return false; + return llvm::any_of(ExitBlocks, [&](BasicBlock * EB) { + return DT.dominates(DomNode, DT.getNode(EB)); + }); } bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI, @@ -227,10 +229,10 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI, if (ExitBlocks.empty()) return false; - PredIteratorCache PredCache; + SmallVector<Instruction *, 8> Worklist; // Look at all the instructions in the loop, checking to see if they have uses - // outside the loop. If so, rewrite those uses. + // outside the loop. If so, put them into the worklist to rewrite those uses. for (BasicBlock *BB : L.blocks()) { // For large loops, avoid use-scanning by using dominance information: In // particular, if a block does not dominate any of the loop exits, then none @@ -246,9 +248,10 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI, !isa<PHINode>(I.user_back()))) continue; - Changed |= processInstruction(L, I, DT, ExitBlocks, PredCache, LI); + Worklist.push_back(&I); } } + Changed = formLCSSAForInstructions(Worklist, DT, *LI); // If we modified the code, remove any caches about the loop from SCEV to // avoid dangling entries. @@ -274,11 +277,20 @@ bool llvm::formLCSSARecursively(Loop &L, DominatorTree &DT, LoopInfo *LI, return Changed; } +/// Process all loops in the function, inner-most out. +static bool formLCSSAOnAllLoops(LoopInfo *LI, DominatorTree &DT, + ScalarEvolution *SE) { + bool Changed = false; + for (auto &L : *LI) + Changed |= formLCSSARecursively(*L, DT, LI, SE); + return Changed; +} + namespace { -struct LCSSA : public FunctionPass { +struct LCSSAWrapperPass : public FunctionPass { static char ID; // Pass identification, replacement for typeid - LCSSA() : FunctionPass(ID) { - initializeLCSSAPass(*PassRegistry::getPassRegistry()); + LCSSAWrapperPass() : FunctionPass(ID) { + initializeLCSSAWrapperPassPass(*PassRegistry::getPassRegistry()); } // Cached analysis information for the current function. 
@@ -298,6 +310,7 @@ struct LCSSA : public FunctionPass { AU.addRequired<LoopInfoWrapperPass>(); AU.addPreservedID(LoopSimplifyID); AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<ScalarEvolutionWrapperPass>(); AU.addPreserved<SCEVAAWrapperPass>(); @@ -305,30 +318,39 @@ struct LCSSA : public FunctionPass { }; } -char LCSSA::ID = 0; -INITIALIZE_PASS_BEGIN(LCSSA, "lcssa", "Loop-Closed SSA Form Pass", false, false) +char LCSSAWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(LCSSAWrapperPass, "lcssa", "Loop-Closed SSA Form Pass", + false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) -INITIALIZE_PASS_END(LCSSA, "lcssa", "Loop-Closed SSA Form Pass", false, false) - -Pass *llvm::createLCSSAPass() { return new LCSSA(); } -char &llvm::LCSSAID = LCSSA::ID; +INITIALIZE_PASS_END(LCSSAWrapperPass, "lcssa", "Loop-Closed SSA Form Pass", + false, false) +Pass *llvm::createLCSSAPass() { return new LCSSAWrapperPass(); } +char &llvm::LCSSAID = LCSSAWrapperPass::ID; -/// Process all loops in the function, inner-most out. -bool LCSSA::runOnFunction(Function &F) { - bool Changed = false; +/// Transform \p F into loop-closed SSA form. +bool LCSSAWrapperPass::runOnFunction(Function &F) { LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); SE = SEWP ? &SEWP->getSE() : nullptr; - // Simplify each loop nest in the function. - for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= formLCSSARecursively(**I, *DT, LI, SE); - - return Changed; + return formLCSSAOnAllLoops(LI, *DT, SE); } +PreservedAnalyses LCSSAPass::run(Function &F, AnalysisManager<Function> &AM) { + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto *SE = AM.getCachedResult<ScalarEvolutionAnalysis>(F); + if (!formLCSSAOnAllLoops(&LI, DT, SE)) + return PreservedAnalyses::all(); + + // FIXME: This should also 'preserve the CFG'. + PreservedAnalyses PA; + PA.preserve<BasicAA>(); + PA.preserve<GlobalsAA>(); + PA.preserve<SCEVAA>(); + PA.preserve<ScalarEvolutionAnalysis>(); + return PA; +} diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index abc9b65f7a39..f1838d891466 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -42,11 +42,13 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "local" @@ -148,9 +150,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, SmallVector<uint32_t, 8> Weights; for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; ++MD_i) { - ConstantInt *CI = - mdconst::dyn_extract<ConstantInt>(MD->getOperand(MD_i)); - assert(CI); + auto *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i)); Weights.push_back(CI->getValue().getZExtValue()); } // Merge weight of this case to the default weight. 
@@ -321,8 +321,12 @@ bool llvm::isInstructionTriviallyDead(Instruction *I, II->getIntrinsicID() == Intrinsic::lifetime_end) return isa<UndefValue>(II->getArgOperand(1)); - // Assumptions are dead if their condition is trivially true. - if (II->getIntrinsicID() == Intrinsic::assume) { + // Assumptions are dead if their condition is trivially true. Guards on + // true are operationally no-ops. In the future we can consider more + // sophisticated tradeoffs for guards considering potential for check + // widening, but for now we keep things simple. + if (II->getIntrinsicID() == Intrinsic::assume || + II->getIntrinsicID() == Intrinsic::experimental_guard) { if (ConstantInt *Cond = dyn_cast<ConstantInt>(II->getArgOperand(0))) return !Cond->isZero(); @@ -452,14 +456,23 @@ simplifyAndDCEInstruction(Instruction *I, if (Value *SimpleV = SimplifyInstruction(I, DL)) { // Add the users to the worklist. CAREFUL: an instruction can use itself, // in the case of a phi node. - for (User *U : I->users()) - if (U != I) + for (User *U : I->users()) { + if (U != I) { WorkList.insert(cast<Instruction>(U)); + } + } // Replace the instruction with its simplified value. - I->replaceAllUsesWith(SimpleV); - I->eraseFromParent(); - return true; + bool Changed = false; + if (!I->use_empty()) { + I->replaceAllUsesWith(SimpleV); + Changed = true; + } + if (isInstructionTriviallyDead(I, TLI)) { + I->eraseFromParent(); + Changed = true; + } + return Changed; } return false; } @@ -486,7 +499,8 @@ bool llvm::SimplifyInstructionsInBlock(BasicBlock *BB, // Iterate over the original function, only adding insts to the worklist // if they actually need to be revisited. This avoids having to pre-init // the worklist with the entire function's worth of instructions. - for (BasicBlock::iterator BI = BB->begin(), E = std::prev(BB->end()); BI != E;) { + for (BasicBlock::iterator BI = BB->begin(), E = std::prev(BB->end()); + BI != E;) { assert(!BI->isTerminator()); Instruction *I = &*BI; ++BI; @@ -1025,7 +1039,8 @@ unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign, /// /// See if there is a dbg.value intrinsic for DIVar before I. -static bool LdStHasDebugValue(const DILocalVariable *DIVar, Instruction *I) { +static bool LdStHasDebugValue(DILocalVariable *DIVar, DIExpression *DIExpr, + Instruction *I) { // Since we can't guarantee that the original dbg.declare instrinsic // is removed by LowerDbgDeclare(), we need to make sure that we are // not inserting the same dbg.value intrinsic over and over. @@ -1035,7 +1050,8 @@ static bool LdStHasDebugValue(const DILocalVariable *DIVar, Instruction *I) { if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(PrevI)) if (DVI->getValue() == I->getOperand(0) && DVI->getOffset() == 0 && - DVI->getVariable() == DIVar) + DVI->getVariable() == DIVar && + DVI->getExpression() == DIExpr) return true; } return false; @@ -1049,9 +1065,6 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, auto *DIExpr = DDI->getExpression(); assert(DIVar && "Missing variable"); - if (LdStHasDebugValue(DIVar, SI)) - return true; - // If an argument is zero extended then use argument directly. The ZExt // may be zapped by an optimization pass in future. Argument *ExtendedArg = nullptr; @@ -1066,25 +1079,25 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, // to the alloca described by DDI, if it's first operand is an extend, // we're guaranteed that before extension, the value was narrower than // the size of the alloca, hence the size of the described variable. 
- SmallVector<uint64_t, 3> NewDIExpr; + SmallVector<uint64_t, 3> Ops; unsigned PieceOffset = 0; // If this already is a bit piece, we drop the bit piece from the expression // and record the offset. if (DIExpr->isBitPiece()) { - NewDIExpr.append(DIExpr->elements_begin(), DIExpr->elements_end()-3); + Ops.append(DIExpr->elements_begin(), DIExpr->elements_end()-3); PieceOffset = DIExpr->getBitPieceOffset(); } else { - NewDIExpr.append(DIExpr->elements_begin(), DIExpr->elements_end()); + Ops.append(DIExpr->elements_begin(), DIExpr->elements_end()); } - NewDIExpr.push_back(dwarf::DW_OP_bit_piece); - NewDIExpr.push_back(PieceOffset); //Offset + Ops.push_back(dwarf::DW_OP_bit_piece); + Ops.push_back(PieceOffset); // Offset const DataLayout &DL = DDI->getModule()->getDataLayout(); - NewDIExpr.push_back(DL.getTypeSizeInBits(ExtendedArg->getType())); // Size - Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, - Builder.createExpression(NewDIExpr), - DDI->getDebugLoc(), SI); - } - else + Ops.push_back(DL.getTypeSizeInBits(ExtendedArg->getType())); // Size + auto NewDIExpr = Builder.createExpression(Ops); + if (!LdStHasDebugValue(DIVar, NewDIExpr, SI)) + Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, NewDIExpr, + DDI->getDebugLoc(), SI); + } else if (!LdStHasDebugValue(DIVar, DIExpr, SI)) Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, DIExpr, DDI->getDebugLoc(), SI); return true; @@ -1098,7 +1111,7 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, auto *DIExpr = DDI->getExpression(); assert(DIVar && "Missing variable"); - if (LdStHasDebugValue(DIVar, LI)) + if (LdStHasDebugValue(DIVar, DIExpr, LI)) return true; // We are now tracking the loaded value instead of the address. In the @@ -1140,12 +1153,14 @@ bool llvm::LowerDbgDeclare(Function &F) { // the stack slot (and at a lexical-scope granularity). Later // passes will attempt to elide the stack slot. if (AI && !isArray(AI)) { - for (User *U : AI->users()) - if (StoreInst *SI = dyn_cast<StoreInst>(U)) - ConvertDebugDeclareToDebugValue(DDI, SI, DIB); - else if (LoadInst *LI = dyn_cast<LoadInst>(U)) + for (auto &AIUse : AI->uses()) { + User *U = AIUse.getUser(); + if (StoreInst *SI = dyn_cast<StoreInst>(U)) { + if (AIUse.getOperandNo() == 1) + ConvertDebugDeclareToDebugValue(DDI, SI, DIB); + } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) { ConvertDebugDeclareToDebugValue(DDI, LI, DIB); - else if (CallInst *CI = dyn_cast<CallInst>(U)) { + } else if (CallInst *CI = dyn_cast<CallInst>(U)) { // This is a call by-value or some other instruction that // takes a pointer to the variable. Insert a *value* // intrinsic that describes the alloca. 
@@ -1157,6 +1172,7 @@ bool llvm::LowerDbgDeclare(Function &F) { DIB.createExpression(NewDIExpr), DDI->getDebugLoc(), CI); } + } DDI->eraseFromParent(); } } @@ -1175,6 +1191,38 @@ DbgDeclareInst *llvm::FindAllocaDbgDeclare(Value *V) { return nullptr; } +static void DIExprAddDeref(SmallVectorImpl<uint64_t> &Expr) { + Expr.push_back(dwarf::DW_OP_deref); +} + +static void DIExprAddOffset(SmallVectorImpl<uint64_t> &Expr, int Offset) { + if (Offset > 0) { + Expr.push_back(dwarf::DW_OP_plus); + Expr.push_back(Offset); + } else if (Offset < 0) { + Expr.push_back(dwarf::DW_OP_minus); + Expr.push_back(-Offset); + } +} + +static DIExpression *BuildReplacementDIExpr(DIBuilder &Builder, + DIExpression *DIExpr, bool Deref, + int Offset) { + if (!Deref && !Offset) + return DIExpr; + // Create a copy of the original DIDescriptor for user variable, prepending + // "deref" operation to a list of address elements, as new llvm.dbg.declare + // will take a value storing address of the memory for variable, not + // alloca itself. + SmallVector<uint64_t, 4> NewDIExpr; + if (Deref) + DIExprAddDeref(NewDIExpr); + DIExprAddOffset(NewDIExpr, Offset); + if (DIExpr) + NewDIExpr.append(DIExpr->elements_begin(), DIExpr->elements_end()); + return Builder.createExpression(NewDIExpr); +} + bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, Instruction *InsertBefore, DIBuilder &Builder, bool Deref, int Offset) { @@ -1186,25 +1234,7 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, auto *DIExpr = DDI->getExpression(); assert(DIVar && "Missing variable"); - if (Deref || Offset) { - // Create a copy of the original DIDescriptor for user variable, prepending - // "deref" operation to a list of address elements, as new llvm.dbg.declare - // will take a value storing address of the memory for variable, not - // alloca itself. - SmallVector<uint64_t, 4> NewDIExpr; - if (Deref) - NewDIExpr.push_back(dwarf::DW_OP_deref); - if (Offset > 0) { - NewDIExpr.push_back(dwarf::DW_OP_plus); - NewDIExpr.push_back(Offset); - } else if (Offset < 0) { - NewDIExpr.push_back(dwarf::DW_OP_minus); - NewDIExpr.push_back(-Offset); - } - if (DIExpr) - NewDIExpr.append(DIExpr->elements_begin(), DIExpr->elements_end()); - DIExpr = Builder.createExpression(NewDIExpr); - } + DIExpr = BuildReplacementDIExpr(Builder, DIExpr, Deref, Offset); // Insert llvm.dbg.declare immediately after the original alloca, and remove // old llvm.dbg.declare. @@ -1219,12 +1249,73 @@ bool llvm::replaceDbgDeclareForAlloca(AllocaInst *AI, Value *NewAllocaAddress, Deref, Offset); } -void llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap) { +static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress, + DIBuilder &Builder, int Offset) { + DebugLoc Loc = DVI->getDebugLoc(); + auto *DIVar = DVI->getVariable(); + auto *DIExpr = DVI->getExpression(); + assert(DIVar && "Missing variable"); + + // This is an alloca-based llvm.dbg.value. The first thing it should do with + // the alloca pointer is dereference it. Otherwise we don't know how to handle + // it and give up. + if (!DIExpr || DIExpr->getNumElements() < 1 || + DIExpr->getElement(0) != dwarf::DW_OP_deref) + return; + + // Insert the offset immediately after the first deref. + // We could just change the offset argument of dbg.value, but it's unsigned... 
+ if (Offset) { + SmallVector<uint64_t, 4> NewDIExpr; + DIExprAddDeref(NewDIExpr); + DIExprAddOffset(NewDIExpr, Offset); + NewDIExpr.append(DIExpr->elements_begin() + 1, DIExpr->elements_end()); + DIExpr = Builder.createExpression(NewDIExpr); + } + + Builder.insertDbgValueIntrinsic(NewAddress, DVI->getOffset(), DIVar, DIExpr, + Loc, DVI); + DVI->eraseFromParent(); +} + +void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress, + DIBuilder &Builder, int Offset) { + if (auto *L = LocalAsMetadata::getIfExists(AI)) + if (auto *MDV = MetadataAsValue::getIfExists(AI->getContext(), L)) + for (auto UI = MDV->use_begin(), UE = MDV->use_end(); UI != UE;) { + Use &U = *UI++; + if (auto *DVI = dyn_cast<DbgValueInst>(U.getUser())) + replaceOneDbgValueForAlloca(DVI, NewAllocaAddress, Builder, Offset); + } +} + +unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { + unsigned NumDeadInst = 0; + // Delete the instructions backwards, as it has a reduced likelihood of + // having to update as many def-use and use-def chains. + Instruction *EndInst = BB->getTerminator(); // Last not to be deleted. + while (EndInst != &BB->front()) { + // Delete the next to last instruction. + Instruction *Inst = &*--EndInst->getIterator(); + if (!Inst->use_empty() && !Inst->getType()->isTokenTy()) + Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); + if (Inst->isEHPad() || Inst->getType()->isTokenTy()) { + EndInst = Inst; + continue; + } + if (!isa<DbgInfoIntrinsic>(Inst)) + ++NumDeadInst; + Inst->eraseFromParent(); + } + return NumDeadInst; +} + +unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap) { BasicBlock *BB = I->getParent(); // Loop over all of the successors, removing BB's entry from any PHI // nodes. - for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) - (*SI)->removePredecessor(BB); + for (BasicBlock *Successor : successors(BB)) + Successor->removePredecessor(BB); // Insert a call to llvm.trap right before this. This turns the undefined // behavior into a hard fail instead of falling through into random code. @@ -1237,12 +1328,15 @@ void llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap) { new UnreachableInst(I->getContext(), I); // All instructions after this are dead. + unsigned NumInstrsRemoved = 0; BasicBlock::iterator BBI = I->getIterator(), BBE = BB->end(); while (BBI != BBE) { if (!BBI->use_empty()) BBI->replaceAllUsesWith(UndefValue::get(BBI->getType())); BB->getInstList().erase(BBI++); + ++NumInstrsRemoved; } + return NumInstrsRemoved; } /// changeToCall - Convert the specified invoke into a normal call. @@ -1280,36 +1374,52 @@ static bool markAliveBlocks(Function &F, // Do a quick scan of the basic block, turning any obviously unreachable // instructions into LLVM unreachable insts. The instruction combining pass // canonicalizes unreachable insts into stores to null or undef. - for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E;++BBI){ + for (Instruction &I : *BB) { // Assumptions that are known to be false are equivalent to unreachable. // Also, if the condition is undefined, then we make the choice most // beneficial to the optimizer, and choose that to also be unreachable. 
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(BBI)) + if (auto *II = dyn_cast<IntrinsicInst>(&I)) { if (II->getIntrinsicID() == Intrinsic::assume) { - bool MakeUnreachable = false; - if (isa<UndefValue>(II->getArgOperand(0))) - MakeUnreachable = true; - else if (ConstantInt *Cond = - dyn_cast<ConstantInt>(II->getArgOperand(0))) - MakeUnreachable = Cond->isZero(); - - if (MakeUnreachable) { + if (match(II->getArgOperand(0), m_CombineOr(m_Zero(), m_Undef()))) { // Don't insert a call to llvm.trap right before the unreachable. - changeToUnreachable(&*BBI, false); + changeToUnreachable(II, false); Changed = true; break; } } - if (CallInst *CI = dyn_cast<CallInst>(BBI)) { + if (II->getIntrinsicID() == Intrinsic::experimental_guard) { + // A call to the guard intrinsic bails out of the current compilation + // unit if the predicate passed to it is false. If the predicate is a + // constant false, then we know the guard will bail out of the current + // compile unconditionally, so all code following it is dead. + // + // Note: unlike in llvm.assume, it is not "obviously profitable" for + // guards to treat `undef` as `false` since a guard on `undef` can + // still be useful for widening. + if (match(II->getArgOperand(0), m_Zero())) + if (!isa<UnreachableInst>(II->getNextNode())) { + changeToUnreachable(II->getNextNode(), /*UseLLVMTrap=*/ false); + Changed = true; + break; + } + } + } + + if (auto *CI = dyn_cast<CallInst>(&I)) { + Value *Callee = CI->getCalledValue(); + if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { + changeToUnreachable(CI, /*UseLLVMTrap=*/false); + Changed = true; + break; + } if (CI->doesNotReturn()) { // If we found a call to a no-return function, insert an unreachable // instruction after it. Make sure there isn't *already* one there // though. - ++BBI; - if (!isa<UnreachableInst>(BBI)) { + if (!isa<UnreachableInst>(CI->getNextNode())) { // Don't insert a call to llvm.trap right before the unreachable. - changeToUnreachable(&*BBI, false); + changeToUnreachable(CI->getNextNode(), false); Changed = true; } break; @@ -1319,7 +1429,7 @@ static bool markAliveBlocks(Function &F, // Store to undef and store to null are undefined and used to signal that // they should be changed to unreachable by passes that can't modify the // CFG. - if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) { + if (auto *SI = dyn_cast<StoreInst>(&I)) { // Don't touch volatile stores. if (SI->isVolatile()) continue; @@ -1393,9 +1503,9 @@ static bool markAliveBlocks(Function &F, } Changed |= ConstantFoldTerminator(BB, true); - for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) - if (Reachable.insert(*SI).second) - Worklist.push_back(*SI); + for (BasicBlock *Successor : successors(BB)) + if (Reachable.insert(Successor).second) + Worklist.push_back(Successor); } while (!Worklist.empty()); return Changed; } @@ -1438,7 +1548,7 @@ void llvm::removeUnwindEdge(BasicBlock *BB) { /// if they are in a dead cycle. Return true if a change was made, false /// otherwise. bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI) { - SmallPtrSet<BasicBlock*, 128> Reachable; + SmallPtrSet<BasicBlock*, 16> Reachable; bool Changed = markAliveBlocks(F, Reachable); // If there are unreachable blocks in the CFG... 
@@ -1454,10 +1564,9 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI) { if (Reachable.count(&*BB)) continue; - for (succ_iterator SI = succ_begin(&*BB), SE = succ_end(&*BB); SI != SE; - ++SI) - if (Reachable.count(*SI)) - (*SI)->removePredecessor(&*BB); + for (BasicBlock *Successor : successors(&*BB)) + if (Reachable.count(Successor)) + Successor->removePredecessor(&*BB); if (LVI) LVI->eraseBlock(&*BB); BB->dropAllReferences(); @@ -1495,6 +1604,7 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, K->setMetadata(Kind, MDNode::getMostGenericAliasScope(JMD, KMD)); break; case LLVMContext::MD_noalias: + case LLVMContext::MD_mem_parallel_loop_access: K->setMetadata(Kind, MDNode::intersect(JMD, KMD)); break; case LLVMContext::MD_range: @@ -1566,7 +1676,7 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, UI != UE;) { Use &U = *UI++; auto *I = cast<Instruction>(U.getUser()); - if (DT.dominates(BB, I->getParent())) { + if (DT.properlyDominates(BB, I->getParent())) { U.set(To); DEBUG(dbgs() << "Replace dominated use of '" << From->getName() << "' as " << *To << " in " << *U << "\n"); @@ -1577,18 +1687,18 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, } bool llvm::callsGCLeafFunction(ImmutableCallSite CS) { - if (isa<IntrinsicInst>(CS.getInstruction())) - // Most LLVM intrinsics are things which can never take a safepoint. - // As a result, we don't need to have the stack parsable at the - // callsite. This is a highly useful optimization since intrinsic - // calls are fairly prevalent, particularly in debug builds. - return true; - // Check if the function is specifically marked as a gc leaf function. if (CS.hasFnAttr("gc-leaf-function")) return true; - if (const Function *F = CS.getCalledFunction()) - return F->hasFnAttribute("gc-leaf-function"); + if (const Function *F = CS.getCalledFunction()) { + if (F->hasFnAttribute("gc-leaf-function")) + return true; + + if (auto IID = F->getIntrinsicID()) + // Most LLVM intrinsics do not take safepoints. + return IID != Intrinsic::experimental_gc_statepoint && + IID != Intrinsic::experimental_deoptimize; + } return false; } @@ -1723,7 +1833,23 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // If the AndMask is zero for this bit, clear the bit. if ((AndMask & Bit) == 0) Result->Provenance[i] = BitPart::Unset; + return Result; + } + // If this is a zext instruction zero extend the result. + if (I->getOpcode() == Instruction::ZExt) { + auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, + MatchBitReversals, BPS); + if (!Res) + return Result; + + Result = BitPart(Res->Provider, BitWidth); + auto NarrowBitWidth = + cast<IntegerType>(cast<ZExtInst>(I)->getSrcTy())->getBitWidth(); + for (unsigned i = 0; i < NarrowBitWidth; ++i) + Result->Provenance[i] = Res->Provenance[i]; + for (unsigned i = NarrowBitWidth; i < BitWidth; ++i) + Result->Provenance[i] = BitPart::Unset; return Result; } } @@ -1754,7 +1880,7 @@ static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To, /// Given an OR instruction, check to see if this is a bitreverse /// idiom. If so, insert the new intrinsic and return true. 
-bool llvm::recognizeBitReverseOrBSwapIdiom( +bool llvm::recognizeBSwapOrBitReverseIdiom( Instruction *I, bool MatchBSwaps, bool MatchBitReversals, SmallVectorImpl<Instruction *> &InsertedInsts) { if (Operator::getOpcode(I) != Instruction::Or) @@ -1766,6 +1892,15 @@ bool llvm::recognizeBitReverseOrBSwapIdiom( return false; // Can't do vectors or integers > 128 bits. unsigned BW = ITy->getBitWidth(); + unsigned DemandedBW = BW; + IntegerType *DemandedTy = ITy; + if (I->hasOneUse()) { + if (TruncInst *Trunc = dyn_cast<TruncInst>(I->user_back())) { + DemandedTy = cast<IntegerType>(Trunc->getType()); + DemandedBW = DemandedTy->getBitWidth(); + } + } + // Try to find all the pieces corresponding to the bswap. std::map<Value *, Optional<BitPart>> BPS; auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS); @@ -1775,11 +1910,12 @@ bool llvm::recognizeBitReverseOrBSwapIdiom( // Now, is the bit permutation correct for a bswap or a bitreverse? We can // only byteswap values with an even number of bytes. - bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true; - for (unsigned i = 0; i < BW; ++i) { - OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW); + bool OKForBSwap = DemandedBW % 16 == 0, OKForBitReverse = true; + for (unsigned i = 0; i < DemandedBW; ++i) { + OKForBSwap &= + bitTransformIsCorrectForBSwap(BitProvenance[i], i, DemandedBW); OKForBitReverse &= - bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW); + bitTransformIsCorrectForBitReverse(BitProvenance[i], i, DemandedBW); } Intrinsic::ID Intrin; @@ -1790,7 +1926,51 @@ bool llvm::recognizeBitReverseOrBSwapIdiom( else return false; + if (ITy != DemandedTy) { + Function *F = Intrinsic::getDeclaration(I->getModule(), Intrin, DemandedTy); + Value *Provider = Res->Provider; + IntegerType *ProviderTy = cast<IntegerType>(Provider->getType()); + // We may need to truncate the provider. + if (DemandedTy != ProviderTy) { + auto *Trunc = CastInst::Create(Instruction::Trunc, Provider, DemandedTy, + "trunc", I); + InsertedInsts.push_back(Trunc); + Provider = Trunc; + } + auto *CI = CallInst::Create(F, Provider, "rev", I); + InsertedInsts.push_back(CI); + auto *ExtInst = CastInst::Create(Instruction::ZExt, CI, ITy, "zext", I); + InsertedInsts.push_back(ExtInst); + return true; + } + Function *F = Intrinsic::getDeclaration(I->getModule(), Intrin, ITy); InsertedInsts.push_back(CallInst::Create(F, Res->Provider, "rev", I)); return true; } + +// CodeGen has special handling for some string functions that may replace +// them with target-specific intrinsics. Since that'd skip our interceptors +// in ASan/MSan/TSan/DFSan, and thus make us miss some memory accesses, +// we mark affected calls as NoBuiltin, which will disable optimization +// in CodeGen. 
+void llvm::maybeMarkSanitizerLibraryCallNoBuiltin(CallInst *CI, + const TargetLibraryInfo *TLI) { + Function *F = CI->getCalledFunction(); + LibFunc::Func Func; + if (!F || F->hasLocalLinkage() || !F->hasName() || + !TLI->getLibFunc(F->getName(), Func)) + return; + switch (Func) { + default: break; + case LibFunc::memcmp: + case LibFunc::memchr: + case LibFunc::strcpy: + case LibFunc::stpcpy: + case LibFunc::strcmp: + case LibFunc::strlen: + case LibFunc::strnlen: + CI->addAttribute(AttributeSet::FunctionIndex, Attribute::NoBuiltin); + break; + } +} diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index 1fa469595d16..b3a928bf7753 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -37,6 +37,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SetOperations.h" @@ -489,14 +490,9 @@ ReprocessLoop: DEBUG(dbgs() << "LoopSimplify: Deleting edge from dead predecessor " << P->getName() << "\n"); - // Inform each successor of each dead pred. - for (succ_iterator SI = succ_begin(P), SE = succ_end(P); SI != SE; ++SI) - (*SI)->removePredecessor(P); // Zap the dead pred's terminator and replace it with unreachable. TerminatorInst *TI = P->getTerminator(); - TI->replaceAllUsesWith(UndefValue::get(TI->getType())); - P->getTerminator()->eraseFromParent(); - new UnreachableInst(P->getContext(), P); + changeToUnreachable(TI, /*UseLLVMTrap=*/false); Changed = true; } } @@ -506,14 +502,13 @@ ReprocessLoop: // trip count computations. SmallVector<BasicBlock*, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - for (SmallVectorImpl<BasicBlock *>::iterator I = ExitingBlocks.begin(), - E = ExitingBlocks.end(); I != E; ++I) - if (BranchInst *BI = dyn_cast<BranchInst>((*I)->getTerminator())) + for (BasicBlock *ExitingBlock : ExitingBlocks) + if (BranchInst *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator())) if (BI->isConditional()) { if (UndefValue *Cond = dyn_cast<UndefValue>(BI->getCondition())) { DEBUG(dbgs() << "LoopSimplify: Resolving \"br i1 undef\" to exit in " - << (*I)->getName() << "\n"); + << ExitingBlock->getName() << "\n"); BI->setCondition(ConstantInt::get(Cond->getType(), !L->contains(BI->getSuccessor(0)))); @@ -545,9 +540,7 @@ ReprocessLoop: SmallSetVector<BasicBlock *, 8> ExitBlockSet(ExitBlocks.begin(), ExitBlocks.end()); - for (SmallSetVector<BasicBlock *, 8>::iterator I = ExitBlockSet.begin(), - E = ExitBlockSet.end(); I != E; ++I) { - BasicBlock *ExitBlock = *I; + for (BasicBlock *ExitBlock : ExitBlockSet) { for (pred_iterator PI = pred_begin(ExitBlock), PE = pred_end(ExitBlock); PI != PE; ++PI) // Must be exactly this loop: no subloops, parent loops, or non-loop preds @@ -691,8 +684,10 @@ ReprocessLoop: } DT->eraseNode(ExitingBlock); - BI->getSuccessor(0)->removePredecessor(ExitingBlock); - BI->getSuccessor(1)->removePredecessor(ExitingBlock); + BI->getSuccessor(0)->removePredecessor( + ExitingBlock, /* DontDeleteUselessPHIs */ PreserveLCSSA); + BI->getSuccessor(1)->removePredecessor( + ExitingBlock, /* DontDeleteUselessPHIs */ PreserveLCSSA); ExitingBlock->eraseFromParent(); } } @@ -731,11 +726,6 @@ namespace { initializeLoopSimplifyPass(*PassRegistry::getPassRegistry()); } - DominatorTree *DT; - LoopInfo *LI; - ScalarEvolution *SE; - AssumptionCache *AC; - bool runOnFunction(Function &F) override; void 
getAnalysisUsage(AnalysisUsage &AU) const override { @@ -753,7 +743,8 @@ namespace { AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<ScalarEvolutionWrapperPass>(); AU.addPreserved<SCEVAAWrapperPass>(); - AU.addPreserved<DependenceAnalysis>(); + AU.addPreservedID(LCSSAID); + AU.addPreserved<DependenceAnalysisWrapperPass>(); AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added. } @@ -768,9 +759,6 @@ INITIALIZE_PASS_BEGIN(LoopSimplify, "loop-simplify", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", "Canonicalize natural loops", false, false) @@ -783,20 +771,64 @@ Pass *llvm::createLoopSimplifyPass() { return new LoopSimplify(); } /// bool LoopSimplify::runOnFunction(Function &F) { bool Changed = false; - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); - SE = SEWP ? &SEWP->getSE() : nullptr; - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + ScalarEvolution *SE = SEWP ? &SEWP->getSE() : nullptr; + AssumptionCache *AC = + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); +#ifndef NDEBUG + if (PreserveLCSSA) { + assert(DT && "DT not available."); + assert(LI && "LI not available."); + bool InLCSSA = + all_of(*LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT); }); + assert(InLCSSA && "Requested to preserve LCSSA, but it's already broken."); + } +#endif // Simplify each loop nest in the function. for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) Changed |= simplifyLoop(*I, DT, LI, SE, AC, PreserveLCSSA); +#ifndef NDEBUG + if (PreserveLCSSA) { + bool InLCSSA = + all_of(*LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT); }); + assert(InLCSSA && "LCSSA is broken after loop-simplify."); + } +#endif return Changed; } +PreservedAnalyses LoopSimplifyPass::run(Function &F, + AnalysisManager<Function> &AM) { + bool Changed = false; + LoopInfo *LI = &AM.getResult<LoopAnalysis>(F); + DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); + ScalarEvolution *SE = AM.getCachedResult<ScalarEvolutionAnalysis>(F); + AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); + + // FIXME: This pass should verify that the loops on which it's operating + // are in canonical SSA form, and that the pass itself preserves this form. + for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) + Changed |= simplifyLoop(*I, DT, LI, SE, AC, true /* PreserveLCSSA */); + + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + PA.preserve<BasicAA>(); + PA.preserve<GlobalsAA>(); + PA.preserve<SCEVAA>(); + PA.preserve<ScalarEvolutionAnalysis>(); + PA.preserve<DependenceAnalysis>(); + return PA; +} + // FIXME: Restore this code when we re-enable verification in verifyAnalysis // below. 
#if 0 diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index eea9237ba80c..7f1f78fa8b41 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -34,6 +34,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" using namespace llvm; @@ -44,9 +45,14 @@ using namespace llvm; STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled"); STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)"); -/// RemapInstruction - Convert the instruction operands from referencing the -/// current values into those specified by VMap. -static inline void RemapInstruction(Instruction *I, +static cl::opt<bool> +UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(true), cl::Hidden, + cl::desc("Allow runtime unrolled loops to be unrolled " + "with epilog instead of prolog.")); + +/// Convert the instruction operands from referencing the current values into +/// those specified by VMap. +static inline void remapInstruction(Instruction *I, ValueToValueMapTy &VMap) { for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) { Value *Op = I->getOperand(op); @@ -64,8 +70,8 @@ static inline void RemapInstruction(Instruction *I, } } -/// FoldBlockIntoPredecessor - Folds a basic block into its predecessor if it -/// only has one predecessor, and that predecessor only has one successor. +/// Folds a basic block into its predecessor if it only has one predecessor, and +/// that predecessor only has one successor. /// The LoopInfo Analysis that is passed will be kept consistent. If folding is /// successful references to the containing loop must be removed from /// ScalarEvolution by calling ScalarEvolution::forgetLoop because SE may have @@ -73,8 +79,9 @@ static inline void RemapInstruction(Instruction *I, /// of loops that have already been forgotten to prevent redundant, expensive /// calls to ScalarEvolution::forgetLoop. Returns the new combined block. static BasicBlock * -FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, ScalarEvolution *SE, - SmallPtrSetImpl<Loop *> &ForgottenLoops) { +foldBlockIntoPredecessor(BasicBlock *BB, LoopInfo *LI, ScalarEvolution *SE, + SmallPtrSetImpl<Loop *> &ForgottenLoops, + DominatorTree *DT) { // Merge basic blocks into their predecessor if there is only one distinct // pred, and if there is only one distinct successor of the predecessor, and // if there are no PHI nodes. @@ -106,7 +113,16 @@ FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, ScalarEvolution *SE, // OldName will be valid until erased. StringRef OldName = BB->getName(); - // Erase basic block from the function... + // Erase the old block and update dominator info. + if (DT) + if (DomTreeNode *DTN = DT->getNode(BB)) { + DomTreeNode *PredDTN = DT->getNode(OnlyPred); + SmallVector<DomTreeNode *, 8> Children(DTN->begin(), DTN->end()); + for (auto *DI : Children) + DT->changeImmediateDominator(DI, PredDTN); + + DT->eraseNode(BB); + } // ScalarEvolution holds references to loop exit blocks. if (SE) { @@ -126,6 +142,35 @@ FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, ScalarEvolution *SE, return OnlyPred; } +/// Check if unrolling created a situation where we need to insert phi nodes to +/// preserve LCSSA form. 
+/// \param Blocks is a vector of basic blocks representing unrolled loop. +/// \param L is the outer loop. +/// It's possible that some of the blocks are in L, and some are not. In this +/// case, if there is a use is outside L, and definition is inside L, we need to +/// insert a phi-node, otherwise LCSSA will be broken. +/// The function is just a helper function for llvm::UnrollLoop that returns +/// true if this situation occurs, indicating that LCSSA needs to be fixed. +static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks, + LoopInfo *LI) { + for (BasicBlock *BB : Blocks) { + if (LI->getLoopFor(BB) == L) + continue; + for (Instruction &I : *BB) { + for (Use &U : I.operands()) { + if (auto Def = dyn_cast<Instruction>(U)) { + Loop *DefLoop = LI->getLoopFor(Def->getParent()); + if (!DefLoop) + continue; + if (DefLoop->contains(L)) + return true; + } + } + } + } + return false; +} + /// Unroll the given loop by Count. The loop must be in LCSSA form. Returns true /// if unrolling was successful, or false if the loop was unmodified. Unrolling /// can only fail when the loop's latch block is not terminated by a conditional @@ -155,7 +200,7 @@ FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, ScalarEvolution *SE, /// /// This utility preserves LoopInfo. It will also preserve ScalarEvolution and /// DominatorTree if they are non-null. -bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, +bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, bool AllowRuntime, bool AllowExpensiveTripCount, unsigned TripMultiple, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC, @@ -218,20 +263,48 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool CompletelyUnroll = Count == TripCount; SmallVector<BasicBlock *, 4> ExitBlocks; L->getExitBlocks(ExitBlocks); - Loop *ParentL = L->getParentLoop(); - bool AllExitsAreInsideParentLoop = !ParentL || - std::all_of(ExitBlocks.begin(), ExitBlocks.end(), - [&](BasicBlock *BB) { return ParentL->contains(BB); }); + std::vector<BasicBlock*> OriginalLoopBlocks = L->getBlocks(); + + // Go through all exits of L and see if there are any phi-nodes there. We just + // conservatively assume that they're inserted to preserve LCSSA form, which + // means that complete unrolling might break this form. We need to either fix + // it in-place after the transformation, or entirely rebuild LCSSA. TODO: For + // now we just recompute LCSSA for the outer loop, but it should be possible + // to fix it in-place. + bool NeedToFixLCSSA = PreserveLCSSA && CompletelyUnroll && + std::any_of(ExitBlocks.begin(), ExitBlocks.end(), + [&](BasicBlock *BB) { return isa<PHINode>(BB->begin()); }); // We assume a run-time trip count if the compiler cannot // figure out the loop trip count and the unroll-runtime // flag is specified. bool RuntimeTripCount = (TripCount == 0 && Count > 0 && AllowRuntime); - if (RuntimeTripCount && - !UnrollRuntimeLoopProlog(L, Count, AllowExpensiveTripCount, LI, SE, DT, - PreserveLCSSA)) - return false; + // Loops containing convergent instructions must have a count that divides + // their TripMultiple. 
+ DEBUG( + { + bool HasConvergent = false; + for (auto &BB : L->blocks()) + for (auto &I : *BB) + if (auto CS = CallSite(&I)) + HasConvergent |= CS.isConvergent(); + assert((!HasConvergent || TripMultiple % Count == 0) && + "Unroll count must divide trip multiple if loop contains a " + "convergent operation."); + }); + // Don't output the runtime loop remainder if Count is a multiple of + // TripMultiple. Such a remainder is never needed, and is unsafe if the loop + // contains a convergent instruction. + if (RuntimeTripCount && TripMultiple % Count != 0 && + !UnrollRuntimeLoopRemainder(L, Count, AllowExpensiveTripCount, + UnrollRuntimeEpilog, LI, SE, DT, + PreserveLCSSA)) { + if (Force) + RuntimeTripCount = false; + else + return false; + } // Notify ScalarEvolution that the loop will be substantially changed, // if not outright eliminated. @@ -308,6 +381,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO(); LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO(); + std::vector<BasicBlock*> UnrolledLoopBlocks = L->getBlocks(); for (unsigned It = 1; It != Count; ++It) { std::vector<BasicBlock*> NewBlocks; SmallDenseMap<const Loop *, Loop *, 4> NewLoops; @@ -349,13 +423,13 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, if (*BB == Header) // Loop over all of the PHI nodes in the block, changing them to use // the incoming values from the previous block. - for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) { - PHINode *NewPHI = cast<PHINode>(VMap[OrigPHINode[i]]); + for (PHINode *OrigPHI : OrigPHINode) { + PHINode *NewPHI = cast<PHINode>(VMap[OrigPHI]); Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock); if (Instruction *InValI = dyn_cast<Instruction>(InVal)) if (It > 1 && L->contains(InValI)) InVal = LastValueMap[InValI]; - VMap[OrigPHINode[i]] = InVal; + VMap[OrigPHI] = InVal; New->getInstList().erase(NewPHI); } @@ -366,11 +440,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, LastValueMap[VI->first] = VI->second; // Add phi entries for newly created values to all exit blocks. - for (succ_iterator SI = succ_begin(*BB), SE = succ_end(*BB); - SI != SE; ++SI) { - if (L->contains(*SI)) + for (BasicBlock *Succ : successors(*BB)) { + if (L->contains(Succ)) continue; - for (BasicBlock::iterator BBI = (*SI)->begin(); + for (BasicBlock::iterator BBI = Succ->begin(); PHINode *phi = dyn_cast<PHINode>(BBI); ++BBI) { Value *Incoming = phi->getIncomingValueForBlock(*BB); ValueToValueMapTy::iterator It = LastValueMap.find(Incoming); @@ -387,18 +460,33 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, Latches.push_back(New); NewBlocks.push_back(New); + UnrolledLoopBlocks.push_back(New); + + // Update DomTree: since we just copy the loop body, and each copy has a + // dedicated entry block (copy of the header block), this header's copy + // dominates all copied blocks. That means, dominance relations in the + // copied body are the same as in the original body. 
+ if (DT) { + if (*BB == Header) + DT->addNewBlock(New, Latches[It - 1]); + else { + auto BBDomNode = DT->getNode(*BB); + auto BBIDom = BBDomNode->getIDom(); + BasicBlock *OriginalBBIDom = BBIDom->getBlock(); + DT->addNewBlock( + New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)])); + } + } } // Remap all instructions in the most recent iteration - for (unsigned i = 0; i < NewBlocks.size(); ++i) - for (BasicBlock::iterator I = NewBlocks[i]->begin(), - E = NewBlocks[i]->end(); I != E; ++I) - ::RemapInstruction(&*I, LastValueMap); + for (BasicBlock *NewBlock : NewBlocks) + for (Instruction &I : *NewBlock) + ::remapInstruction(&I, LastValueMap); } // Loop over the PHI nodes in the original block, setting incoming values. - for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) { - PHINode *PN = OrigPHINode[i]; + for (PHINode *PN : OrigPHINode) { if (CompletelyUnroll) { PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader)); Header->getInstList().erase(PN); @@ -453,11 +541,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, // Remove phi operands at this loop exit if (Dest != LoopExit) { BasicBlock *BB = Latches[i]; - for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); - SI != SE; ++SI) { - if (*SI == Headers[i]) + for (BasicBlock *Succ: successors(BB)) { + if (Succ == Headers[i]) continue; - for (BasicBlock::iterator BBI = (*SI)->begin(); + for (BasicBlock::iterator BBI = Succ->begin(); PHINode *Phi = dyn_cast<PHINode>(BBI); ++BBI) { Phi->removeIncomingValue(BB, false); } @@ -468,16 +555,43 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, Term->eraseFromParent(); } } + // Update dominators of blocks we might reach through exits. + // Immediate dominator of such block might change, because we add more + // routes which can lead to the exit: we can now reach it from the copied + // iterations too. Thus, the new idom of the block will be the nearest + // common dominator of the previous idom and common dominator of all copies of + // the previous idom. This is equivalent to the nearest common dominator of + // the previous idom and the first latch, which dominates all copies of the + // previous idom. + if (DT && Count > 1) { + for (auto *BB : OriginalLoopBlocks) { + auto *BBDomNode = DT->getNode(BB); + SmallVector<BasicBlock *, 16> ChildrenToUpdate; + for (auto *ChildDomNode : BBDomNode->getChildren()) { + auto *ChildBB = ChildDomNode->getBlock(); + if (!L->contains(ChildBB)) + ChildrenToUpdate.push_back(ChildBB); + } + BasicBlock *NewIDom = DT->findNearestCommonDominator(BB, Latches[0]); + for (auto *ChildBB : ChildrenToUpdate) + DT->changeImmediateDominator(ChildBB, NewIDom); + } + } // Merge adjacent basic blocks, if possible. SmallPtrSet<Loop *, 4> ForgottenLoops; - for (unsigned i = 0, e = Latches.size(); i != e; ++i) { - BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator()); + for (BasicBlock *Latch : Latches) { + BranchInst *Term = cast<BranchInst>(Latch->getTerminator()); if (Term->isUnconditional()) { BasicBlock *Dest = Term->getSuccessor(0); - if (BasicBlock *Fold = FoldBlockIntoPredecessor(Dest, LI, SE, - ForgottenLoops)) + if (BasicBlock *Fold = + foldBlockIntoPredecessor(Dest, LI, SE, ForgottenLoops, DT)) { + // Dest has been folded into Fold. Update our worklists accordingly. 
std::replace(Latches.begin(), Latches.end(), Dest, Fold); + UnrolledLoopBlocks.erase(std::remove(UnrolledLoopBlocks.begin(), + UnrolledLoopBlocks.end(), Dest), + UnrolledLoopBlocks.end()); + } } } @@ -485,10 +599,12 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, // whole function's cache. AC->clear(); - // FIXME: Reconstruct dom info, because it is not preserved properly. - // Incrementally updating domtree after loop unrolling would be easy. - if (DT) + // FIXME: We only preserve DT info for complete unrolling now. Incrementally + // updating domtree after partial loop unrolling should also be easy. + if (DT && !CompletelyUnroll) DT->recalculate(*L->getHeader()->getParent()); + else if (DT) + DEBUG(DT->verifyDomTree()); // Simplify any new induction variables in the partially unrolled loop. if (SE && !CompletelyUnroll) { @@ -508,19 +624,17 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, // go. const DataLayout &DL = Header->getModule()->getDataLayout(); const std::vector<BasicBlock*> &NewLoopBlocks = L->getBlocks(); - for (std::vector<BasicBlock*>::const_iterator BB = NewLoopBlocks.begin(), - BBE = NewLoopBlocks.end(); BB != BBE; ++BB) - for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); I != E; ) { + for (BasicBlock *BB : NewLoopBlocks) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) { Instruction *Inst = &*I++; - if (isInstructionTriviallyDead(Inst)) - (*BB)->getInstList().erase(Inst); - else if (Value *V = SimplifyInstruction(Inst, DL)) - if (LI->replacementPreservesLCSSAForm(Inst, V)) { + if (Value *V = SimplifyInstruction(Inst, DL)) + if (LI->replacementPreservesLCSSAForm(Inst, V)) Inst->replaceAllUsesWith(V); - (*BB)->getInstList().erase(Inst); - } + if (isInstructionTriviallyDead(Inst)) + BB->getInstList().erase(Inst); } + } NumCompletelyUnrolled += CompletelyUnroll; ++NumUnrolled; @@ -530,6 +644,17 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, if (CompletelyUnroll) LI->markAsRemoved(L); + // After complete unrolling most of the blocks should be contained in OuterL. + // However, some of them might happen to be out of OuterL (e.g. if they + // precede a loop exit). In this case we might need to insert PHI nodes in + // order to preserve LCSSA form. + // We don't need to check this if we already know that we need to fix LCSSA + // form. + // TODO: For now we just recompute LCSSA for the outer loop in this case, but + // it should be possible to fix it in-place. + if (PreserveLCSSA && OuterL && CompletelyUnroll && !NeedToFixLCSSA) + NeedToFixLCSSA |= ::needToInsertPhisForLCSSA(OuterL, UnrolledLoopBlocks, LI); + // If we have a pass and a DominatorTree we should re-simplify impacted loops // to ensure subsequent analyses can rely on this form. We want to simplify // at least one layer outside of the loop that was unrolled so that any @@ -538,7 +663,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, if (!OuterL && !CompletelyUnroll) OuterL = L; if (OuterL) { - bool Simplified = simplifyLoop(OuterL, DT, LI, SE, AC, PreserveLCSSA); + simplifyLoop(OuterL, DT, LI, SE, AC, PreserveLCSSA); // LCSSA must be performed on the outermost affected loop. 
The unrolled // loop's last loop latch is guaranteed to be in the outermost loop after @@ -548,7 +673,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, while (OuterL->getParentLoop() != LatchLoop) OuterL = OuterL->getParentLoop(); - if (CompletelyUnroll && (!AllExitsAreInsideParentLoop || Simplified)) + if (NeedToFixLCSSA) formLCSSARecursively(*OuterL, *DT, LI, SE); else assert(OuterL->isLCSSAForm(*DT) && diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 0d68f18ad0e5..861a50cf354d 100644 --- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -16,8 +16,8 @@ // case, we need to generate code to execute these 'left over' iterations. // // The current strategy generates an if-then-else sequence prior to the -// unrolled loop to execute the 'left over' iterations. Other strategies -// include generate a loop before or after the unrolled loop. +// unrolled loop to execute the 'left over' iterations before or after the +// unrolled loop. // //===----------------------------------------------------------------------===// @@ -60,91 +60,220 @@ STATISTIC(NumRuntimeUnrolled, /// than the unroll factor. /// static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, - BasicBlock *LastPrologBB, BasicBlock *PrologEnd, - BasicBlock *OrigPH, BasicBlock *NewPH, - ValueToValueMapTy &VMap, DominatorTree *DT, - LoopInfo *LI, bool PreserveLCSSA) { + BasicBlock *PrologExit, BasicBlock *PreHeader, + BasicBlock *NewPreHeader, ValueToValueMapTy &VMap, + DominatorTree *DT, LoopInfo *LI, bool PreserveLCSSA) { BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Loop must have a latch"); + BasicBlock *PrologLatch = cast<BasicBlock>(VMap[Latch]); // Create a PHI node for each outgoing value from the original loop // (which means it is an outgoing value from the prolog code too). // The new PHI node is inserted in the prolog end basic block. - // The new PHI name is added as an operand of a PHI node in either + // The new PHI node value is added as an operand of a PHI node in either // the loop header or the loop exit block. - for (succ_iterator SBI = succ_begin(Latch), SBE = succ_end(Latch); - SBI != SBE; ++SBI) { - for (BasicBlock::iterator BBI = (*SBI)->begin(); - PHINode *PN = dyn_cast<PHINode>(BBI); ++BBI) { - + for (BasicBlock *Succ : successors(Latch)) { + for (Instruction &BBI : *Succ) { + PHINode *PN = dyn_cast<PHINode>(&BBI); + // Exit when we passed all PHI nodes. + if (!PN) + break; // Add a new PHI node to the prolog end block and add the // appropriate incoming values. - PHINode *NewPN = PHINode::Create(PN->getType(), 2, PN->getName()+".unr", - PrologEnd->getTerminator()); + PHINode *NewPN = PHINode::Create(PN->getType(), 2, PN->getName() + ".unr", + PrologExit->getFirstNonPHI()); // Adding a value to the new PHI node from the original loop preheader. // This is the value that skips all the prolog code. if (L->contains(PN)) { - NewPN->addIncoming(PN->getIncomingValueForBlock(NewPH), OrigPH); + NewPN->addIncoming(PN->getIncomingValueForBlock(NewPreHeader), + PreHeader); } else { - NewPN->addIncoming(UndefValue::get(PN->getType()), OrigPH); + NewPN->addIncoming(UndefValue::get(PN->getType()), PreHeader); } Value *V = PN->getIncomingValueForBlock(Latch); if (Instruction *I = dyn_cast<Instruction>(V)) { if (L->contains(I)) { - V = VMap[I]; + V = VMap.lookup(I); } } // Adding a value to the new PHI node from the last prolog block // that was created. 
- NewPN->addIncoming(V, LastPrologBB); + NewPN->addIncoming(V, PrologLatch); // Update the existing PHI node operand with the value from the // new PHI node. How this is done depends on if the existing // PHI node is in the original loop block, or the exit block. if (L->contains(PN)) { - PN->setIncomingValue(PN->getBasicBlockIndex(NewPH), NewPN); + PN->setIncomingValue(PN->getBasicBlockIndex(NewPreHeader), NewPN); } else { - PN->addIncoming(NewPN, PrologEnd); + PN->addIncoming(NewPN, PrologExit); } } } - // Create a branch around the orignal loop, which is taken if there are no + // Create a branch around the original loop, which is taken if there are no // iterations remaining to be executed after running the prologue. - Instruction *InsertPt = PrologEnd->getTerminator(); + Instruction *InsertPt = PrologExit->getTerminator(); IRBuilder<> B(InsertPt); assert(Count != 0 && "nonsensical Count!"); - // If BECount <u (Count - 1) then (BECount + 1) & (Count - 1) == (BECount + 1) - // (since Count is a power of 2). This means %xtraiter is (BECount + 1) and - // and all of the iterations of this loop were executed by the prologue. Note - // that if BECount <u (Count - 1) then (BECount + 1) cannot unsigned-overflow. + // If BECount <u (Count - 1) then (BECount + 1) % Count == (BECount + 1) + // This means %xtraiter is (BECount + 1) and all of the iterations of this + // loop were executed by the prologue. Note that if BECount <u (Count - 1) + // then (BECount + 1) cannot unsigned-overflow. Value *BrLoopExit = B.CreateICmpULT(BECount, ConstantInt::get(BECount->getType(), Count - 1)); BasicBlock *Exit = L->getUniqueExitBlock(); assert(Exit && "Loop must have a single exit block only"); // Split the exit to maintain loop canonicalization guarantees - SmallVector<BasicBlock*, 4> Preds(pred_begin(Exit), pred_end(Exit)); + SmallVector<BasicBlock*, 4> Preds(predecessors(Exit)); SplitBlockPredecessors(Exit, Preds, ".unr-lcssa", DT, LI, PreserveLCSSA); // Add the branch to the exit block (around the unrolled loop) - B.CreateCondBr(BrLoopExit, Exit, NewPH); + B.CreateCondBr(BrLoopExit, Exit, NewPreHeader); + InsertPt->eraseFromParent(); +} + +/// Connect the unrolling epilog code to the original loop. +/// The unrolling epilog code contains code to execute the +/// 'extra' iterations if the run-time trip count modulo the +/// unroll count is non-zero. +/// +/// This function performs the following: +/// - Update PHI nodes at the unrolling loop exit and epilog loop exit +/// - Create PHI nodes at the unrolling loop exit to combine +/// values that exit the unrolling loop code and jump around it. +/// - Update PHI operands in the epilog loop by the new PHI nodes +/// - Branch around the epilog loop if extra iters (ModVal) is zero. +/// +static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, + BasicBlock *Exit, BasicBlock *PreHeader, + BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader, + ValueToValueMapTy &VMap, DominatorTree *DT, + LoopInfo *LI, bool PreserveLCSSA) { + BasicBlock *Latch = L->getLoopLatch(); + assert(Latch && "Loop must have a latch"); + BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]); + + // Loop structure should be the following: + // + // PreHeader + // NewPreHeader + // Header + // ... + // Latch + // NewExit (PN) + // EpilogPreHeader + // EpilogHeader + // ... + // EpilogLatch + // Exit (EpilogPN) + + // Update PHI nodes at NewExit and Exit. + for (Instruction &BBI : *NewExit) { + PHINode *PN = dyn_cast<PHINode>(&BBI); + // Exit when we passed all PHI nodes. 
+    if (!PN)
+      break;
+    // PN should be used in another PHI located in Exit block as
+    // Exit was split by SplitBlockPredecessors into Exit and NewExit
+    // Basically it should look like:
+    // NewExit:
+    //   PN = PHI [I, Latch]
+    // ...
+    // Exit:
+    //   EpilogPN = PHI [PN, EpilogPreHeader]
+    //
+    // There is EpilogPreHeader incoming block instead of NewExit as
+    // NewExit was split 1 more time to get EpilogPreHeader.
+    assert(PN->hasOneUse() && "The phi should have 1 use");
+    PHINode *EpilogPN = cast<PHINode> (PN->use_begin()->getUser());
+    assert(EpilogPN->getParent() == Exit && "EpilogPN should be in Exit block");
+
+    // Add incoming PreHeader from branch around the Loop
+    PN->addIncoming(UndefValue::get(PN->getType()), PreHeader);
+
+    Value *V = PN->getIncomingValueForBlock(Latch);
+    Instruction *I = dyn_cast<Instruction>(V);
+    if (I && L->contains(I))
+      // If the value comes from an instruction in the loop, add the VMap value.
+      V = VMap.lookup(I);
+    // For an instruction outside the loop, a constant, or an undefined value,
+    // insert the value itself.
+    EpilogPN->addIncoming(V, EpilogLatch);
+
+    assert(EpilogPN->getBasicBlockIndex(EpilogPreHeader) >= 0 &&
+           "EpilogPN should have EpilogPreHeader incoming block");
+    // Change EpilogPreHeader incoming block to NewExit.
+    EpilogPN->setIncomingBlock(EpilogPN->getBasicBlockIndex(EpilogPreHeader),
+                               NewExit);
+    // Now PHIs should look like:
+    // NewExit:
+    //   PN = PHI [I, Latch], [undef, PreHeader]
+    // ...
+    // Exit:
+    //   EpilogPN = PHI [PN, NewExit], [VMap[I], EpilogLatch]
+  }
+
+  // Create PHI nodes at NewExit (from the unrolling loop Latch and PreHeader).
+  // Update corresponding PHI nodes in epilog loop.
+  for (BasicBlock *Succ : successors(Latch)) {
+    // Skip this as we already updated phis in exit blocks.
+    if (!L->contains(Succ))
+      continue;
+    for (Instruction &BBI : *Succ) {
+      PHINode *PN = dyn_cast<PHINode>(&BBI);
+      // Exit when we passed all PHI nodes.
+      if (!PN)
+        break;
+      // Add new PHI nodes to the loop exit block and update epilog
+      // PHIs with the new PHI values.
+      PHINode *NewPN = PHINode::Create(PN->getType(), 2, PN->getName() + ".unr",
+                                       NewExit->getFirstNonPHI());
+      // Adding a value to the new PHI node from the unrolling loop preheader.
+      NewPN->addIncoming(PN->getIncomingValueForBlock(NewPreHeader), PreHeader);
+      // Adding a value to the new PHI node from the unrolling loop latch.
+      NewPN->addIncoming(PN->getIncomingValueForBlock(Latch), Latch);
+
+      // Update the existing PHI node operand with the value from the new PHI
+      // node. Corresponding instruction in epilog loop should be PHI.
+      PHINode *VPN = cast<PHINode>(VMap[&BBI]);
+      VPN->setIncomingValue(VPN->getBasicBlockIndex(EpilogPreHeader), NewPN);
+    }
+  }
+
+  Instruction *InsertPt = NewExit->getTerminator();
+  IRBuilder<> B(InsertPt);
+  Value *BrLoopExit = B.CreateIsNotNull(ModVal, "lcmp.mod");
+  assert(Exit && "Loop must have a single exit block only");
+  // Split the exit to maintain loop canonicalization guarantees
+  SmallVector<BasicBlock*, 4> Preds(predecessors(Exit));
+  SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI,
+                         PreserveLCSSA);
+  // Add the branch to the exit block (around the unrolling loop)
+  B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit);
  InsertPt->eraseFromParent();
}

/// Create a clone of the blocks in a loop and connect them together.
-/// If UnrollProlog is true, loop structure will not be cloned, otherwise a new
-/// loop will be created including all cloned blocks, and the iterator of it
-/// switches to count NewIter down to 0.
+/// If CreateRemainderLoop is false, loop structure will not be cloned, +/// otherwise a new loop will be created including all cloned blocks, and the +/// iterator of it switches to count NewIter down to 0. +/// The cloned blocks should be inserted between InsertTop and InsertBot. +/// If loop structure is cloned InsertTop should be new preheader, InsertBot +/// new loop exit. /// -static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, +static void CloneLoopBlocks(Loop *L, Value *NewIter, + const bool CreateRemainderLoop, + const bool UseEpilogRemainder, BasicBlock *InsertTop, BasicBlock *InsertBot, + BasicBlock *Preheader, std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, LoopInfo *LI) { - BasicBlock *Preheader = L->getLoopPreheader(); + StringRef suffix = UseEpilogRemainder ? "epil" : "prol"; BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); Function *F = Header->getParent(); @@ -152,7 +281,7 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, LoopBlocksDFS::RPOIterator BlockEnd = LoopBlocks.endRPO(); Loop *NewLoop = nullptr; Loop *ParentLoop = L->getParentLoop(); - if (!UnrollProlog) { + if (CreateRemainderLoop) { NewLoop = new Loop(); if (ParentLoop) ParentLoop->addChildLoop(NewLoop); @@ -163,7 +292,7 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, // For each block in the original loop, create a new copy, // and update the value map with the newly created values. for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { - BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, ".prol", F); + BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, "." + suffix, F); NewBlocks.push_back(NewBB); if (NewLoop) @@ -176,19 +305,20 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, // For the first block, add a CFG connection to this newly // created block. InsertTop->getTerminator()->setSuccessor(0, NewBB); - } + if (Latch == *BB) { - // For the last block, if UnrollProlog is true, create a direct jump to - // InsertBot. If not, create a loop back to cloned head. + // For the last block, if CreateRemainderLoop is false, create a direct + // jump to InsertBot. If not, create a loop back to cloned head. VMap.erase((*BB)->getTerminator()); BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]); BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator()); IRBuilder<> Builder(LatchBR); - if (UnrollProlog) { + if (!CreateRemainderLoop) { Builder.CreateBr(InsertBot); } else { - PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, "prol.iter", + PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, + suffix + ".iter", FirstLoopBB->getFirstNonPHI()); Value *IdxSub = Builder.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), @@ -207,9 +337,15 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, // cloned loop. 
for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { PHINode *NewPHI = cast<PHINode>(VMap[&*I]); - if (UnrollProlog) { - VMap[&*I] = NewPHI->getIncomingValueForBlock(Preheader); - cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); + if (!CreateRemainderLoop) { + if (UseEpilogRemainder) { + unsigned idx = NewPHI->getBasicBlockIndex(Preheader); + NewPHI->setIncomingBlock(idx, InsertTop); + NewPHI->removeIncomingValue(Latch, false); + } else { + VMap[&*I] = NewPHI->getIncomingValueForBlock(Preheader); + cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); + } } else { unsigned idx = NewPHI->getBasicBlockIndex(Preheader); NewPHI->setIncomingBlock(idx, InsertTop); @@ -217,8 +353,8 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, idx = NewPHI->getBasicBlockIndex(Latch); Value *InVal = NewPHI->getIncomingValue(idx); NewPHI->setIncomingBlock(idx, NewLatch); - if (VMap[InVal]) - NewPHI->setIncomingValue(idx, VMap[InVal]); + if (Value *V = VMap.lookup(InVal)) + NewPHI->setIncomingValue(idx, V); } } if (NewLoop) { @@ -254,11 +390,11 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, } } -/// Insert code in the prolog code when unrolling a loop with a +/// Insert code in the prolog/epilog code when unrolling a loop with a /// run-time trip-count. /// /// This method assumes that the loop unroll factor is total number -/// of loop bodes in the loop after unrolling. (Some folks refer +/// of loop bodies in the loop after unrolling. (Some folks refer /// to the unroll factor as the number of *extra* copies added). /// We assume also that the loop unroll factor is a power-of-two. So, after /// unrolling the loop, the number of loop bodies executed is 2, @@ -266,37 +402,56 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog, /// instruction in SimplifyCFG.cpp. Then, the backend decides how code for /// the switch instruction is generated. /// +/// ***Prolog case*** /// extraiters = tripcount % loopfactor /// if (extraiters == 0) jump Loop: -/// else jump Prol +/// else jump Prol: /// Prol: LoopBody; /// extraiters -= 1 // Omitted if unroll factor is 2. /// if (extraiters != 0) jump Prol: // Omitted if unroll factor is 2. -/// if (tripcount < loopfactor) jump End +/// if (tripcount < loopfactor) jump End: /// Loop: /// ... /// End: /// -bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, - bool AllowExpensiveTripCount, LoopInfo *LI, - ScalarEvolution *SE, DominatorTree *DT, - bool PreserveLCSSA) { +/// ***Epilog case*** +/// extraiters = tripcount % loopfactor +/// if (tripcount < loopfactor) jump LoopExit: +/// unroll_iters = tripcount - extraiters +/// Loop: LoopBody; (executes unroll_iter times); +/// unroll_iter -= 1 +/// if (unroll_iter != 0) jump Loop: +/// LoopExit: +/// if (extraiters == 0) jump EpilExit: +/// Epil: LoopBody; (executes extraiters times) +/// extraiters -= 1 // Omitted if unroll factor is 2. +/// if (extraiters != 0) jump Epil: // Omitted if unroll factor is 2. +/// EpilExit: + +bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, + bool AllowExpensiveTripCount, + bool UseEpilogRemainder, + LoopInfo *LI, ScalarEvolution *SE, + DominatorTree *DT, bool PreserveLCSSA) { // for now, only unroll loops that contain a single exit if (!L->getExitingBlock()) return false; // Make sure the loop is in canonical form, and there is a single // exit block only. 
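(To make the ***Epilog case*** pseudocode above concrete, here is a rough source-level picture of unrolling by Count = 4 with an epilog remainder. Illustration only: the function and its names are made up, and this is not code produced by or contained in the patch.)

void saxpy(float *y, const float *x, float a, unsigned n) {
  unsigned extraiters = n % 4;            // trip count % unroll factor
  unsigned unroll_iter = n - extraiters;  // iterations done by the unrolled loop
  unsigned i = 0;
  for (; i < unroll_iter; i += 4) {       // Loop: 4 copies of the body
    y[i]     += a * x[i];
    y[i + 1] += a * x[i + 1];
    y[i + 2] += a * x[i + 2];
    y[i + 3] += a * x[i + 3];
  }
  for (; i < n; ++i)                      // Epil: executes 'extraiters' times
    y[i] += a * x[i];
}

(The prolog variant runs the 'extraiters' copies before the unrolled loop instead, as sketched in the ***Prolog case*** pseudocode.)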
- if (!L->isLoopSimplifyForm() || !L->getUniqueExitBlock()) + if (!L->isLoopSimplifyForm()) + return false; + BasicBlock *Exit = L->getUniqueExitBlock(); // successor out of loop + if (!Exit) return false; - // Use Scalar Evolution to compute the trip count. This allows more - // loops to be unrolled than relying on induction var simplification + // Use Scalar Evolution to compute the trip count. This allows more loops to + // be unrolled than relying on induction var simplification. if (!SE) return false; - // Only unroll loops with a computable trip count and the trip count needs - // to be an int value (allowing a pointer type is a TODO item) + // Only unroll loops with a computable trip count, and the trip count needs + // to be an int value (allowing a pointer type is a TODO item). const SCEV *BECountSC = SE->getBackedgeTakenCount(L); if (isa<SCEVCouldNotCompute>(BECountSC) || !BECountSC->getType()->isIntegerTy()) @@ -304,21 +459,19 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth(); - // Add 1 since the backedge count doesn't include the first loop iteration + // Add 1 since the backedge count doesn't include the first loop iteration. const SCEV *TripCountSC = SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1)); if (isa<SCEVCouldNotCompute>(TripCountSC)) return false; BasicBlock *Header = L->getHeader(); + BasicBlock *PreHeader = L->getLoopPreheader(); + BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator()); const DataLayout &DL = Header->getModule()->getDataLayout(); SCEVExpander Expander(*SE, DL, "loop-unroll"); - if (!AllowExpensiveTripCount && Expander.isHighCostExpansion(TripCountSC, L)) - return false; - - // We only handle cases when the unroll factor is a power of 2. - // Count is the loop unroll factor, the number of extra copies added + 1. - if (!isPowerOf2_32(Count)) + if (!AllowExpensiveTripCount && + Expander.isHighCostExpansion(TripCountSC, L, PreHeaderBR)) return false; // This constraint lets us deal with an overflowing trip count easily; see the @@ -326,51 +479,115 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, if (Log2_32(Count) > BEWidth) return false; - // If this loop is nested, then the loop unroller changes the code in - // parent loop, so the Scalar Evolution pass needs to be run again + // If this loop is nested, then the loop unroller changes the code in the + // parent loop, so the Scalar Evolution pass needs to be run again. if (Loop *ParentLoop = L->getParentLoop()) SE->forgetLoop(ParentLoop); - BasicBlock *PH = L->getLoopPreheader(); BasicBlock *Latch = L->getLoopLatch(); - // It helps to splits the original preheader twice, one for the end of the - // prolog code and one for a new loop preheader - BasicBlock *PEnd = SplitEdge(PH, Header, DT, LI); - BasicBlock *NewPH = SplitBlock(PEnd, PEnd->getTerminator(), DT, LI); - BranchInst *PreHeaderBR = cast<BranchInst>(PH->getTerminator()); + // Loop structure is the following: + // + // PreHeader + // Header + // ... + // Latch + // Exit + + BasicBlock *NewPreHeader; + BasicBlock *NewExit = nullptr; + BasicBlock *PrologExit = nullptr; + BasicBlock *EpilogPreHeader = nullptr; + BasicBlock *PrologPreHeader = nullptr; + + if (UseEpilogRemainder) { + // If epilog remainder + // Split PreHeader to insert a branch around loop for unrolling. 
+    NewPreHeader = SplitBlock(PreHeader, PreHeader->getTerminator(), DT, LI);
+    NewPreHeader->setName(PreHeader->getName() + ".new");
+    // Split Exit to create phi nodes from branch above.
+    SmallVector<BasicBlock*, 4> Preds(predecessors(Exit));
+    NewExit = SplitBlockPredecessors(Exit, Preds, ".unr-lcssa",
+                                     DT, LI, PreserveLCSSA);
+    // Split NewExit to insert epilog remainder loop.
+    EpilogPreHeader = SplitBlock(NewExit, NewExit->getTerminator(), DT, LI);
+    EpilogPreHeader->setName(Header->getName() + ".epil.preheader");
+  } else {
+    // If prolog remainder
+    // Split the original preheader twice to insert prolog remainder loop
+    PrologPreHeader = SplitEdge(PreHeader, Header, DT, LI);
+    PrologPreHeader->setName(Header->getName() + ".prol.preheader");
+    PrologExit = SplitBlock(PrologPreHeader, PrologPreHeader->getTerminator(),
+                            DT, LI);
+    PrologExit->setName(Header->getName() + ".prol.loopexit");
+    // Split PrologExit to get NewPreHeader.
+    NewPreHeader = SplitBlock(PrologExit, PrologExit->getTerminator(), DT, LI);
+    NewPreHeader->setName(PreHeader->getName() + ".new");
+  }
+  // Loop structure should be the following:
+  //  Epilog              Prolog
+  //
+  //  PreHeader           PreHeader
+  //  *NewPreHeader       *PrologPreHeader
+  //  Header              *PrologExit
+  //  ...                 *NewPreHeader
+  //  Latch               Header
+  //  *NewExit            ...
+  //  *EpilogPreHeader    Latch
+  //  Exit                Exit
+
+  // Calculate conditions for branch around loop for unrolling
+  // in epilog case and around prolog remainder loop in prolog case.
  // Compute the number of extra iterations required, which is:
-  //  extra iterations = run-time trip count % (loop unroll factor + 1)
+  //  extra iterations = run-time trip count % loop unroll factor
+  PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());
  Value *TripCount = Expander.expandCodeFor(TripCountSC, TripCountSC->getType(),
                                            PreHeaderBR);
  Value *BECount = Expander.expandCodeFor(BECountSC, BECountSC->getType(),
                                          PreHeaderBR);
-  IRBuilder<> B(PreHeaderBR);
-  Value *ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter");
-
-  // If ModVal is zero, we know that either
-  //  1. there are no iteration to be run in the prologue loop
-  //  OR
-  //  2. the addition computing TripCount overflowed
-  //
-  // If (2) is true, we know that TripCount really is (1 << BEWidth) and so the
-  // number of iterations that remain to be run in the original loop is a
-  // multiple Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth (we
-  // explicitly check this above).
-
-  Value *BranchVal = B.CreateIsNotNull(ModVal, "lcmp.mod");
-
-  // Branch to either the extra iterations or the cloned/unrolled loop
-  // We will fix up the true branch label when adding loop body copies
-  B.CreateCondBr(BranchVal, PEnd, PEnd);
-  assert(PreHeaderBR->isUnconditional() &&
-         PreHeaderBR->getSuccessor(0) == PEnd &&
-         "CFG edges in Preheader are not correct");
+  Value *ModVal;
+  // Calculate ModVal = (BECount + 1) % Count.
+  // Note that TripCount is BECount + 1.
+  if (isPowerOf2_32(Count)) {
+    // When Count is a power of 2 we don't need BECount for the epilog case,
+    // however we'll need it for a branch around the unrolling loop in the
+    // prolog case.
+    ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter");
+    //  1. There are no iterations to be run in the prolog/epilog loop.
+    //  OR
+    //  2. The addition computing TripCount overflowed.
+    //
+    //  If (2) is true, we know that TripCount really is (1 << BEWidth) and so
+    //  the number of iterations that remain to be run in the original loop is
+    //  a multiple Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth
+    //  (we explicitly check this above).
+  } else {
+    // As (BECount + 1) can potentially unsigned-overflow, we compute
+    // (BECount % Count) + 1 instead, which is overflow safe as
+    // BECount % Count < Count.
+    Value *ModValTmp = B.CreateURem(BECount,
+                                    ConstantInt::get(BECount->getType(),
+                                                     Count));
+    Value *ModValAdd = B.CreateAdd(ModValTmp,
+                                   ConstantInt::get(ModValTmp->getType(), 1));
+    // At that point (BECount % Count) + 1 could be equal to Count.
+    // To handle this case we need to take mod by Count one more time.
+    ModVal = B.CreateURem(ModValAdd,
+                          ConstantInt::get(BECount->getType(), Count),
+                          "xtraiter");
+  }
+  Value *BranchVal =
+      UseEpilogRemainder ? B.CreateICmpULT(BECount,
+                                           ConstantInt::get(BECount->getType(),
+                                                            Count - 1)) :
+                           B.CreateIsNotNull(ModVal, "lcmp.mod");
+  BasicBlock *RemainderLoop = UseEpilogRemainder ? NewExit : PrologPreHeader;
+  BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
+  // Branch to either remainder (extra iterations) loop or unrolling loop.
+  B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop);
  PreHeaderBR->eraseFromParent();
  Function *F = Header->getParent();
  // Get an ordered list of blocks in the loop to help with the ordering of the
-  // cloned blocks in the prolog code
+  // cloned blocks in the prolog/epilog code
  LoopBlocksDFS LoopBlocks(L);
  LoopBlocks.perform(LI);
@@ -382,34 +599,80 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count,
  std::vector<BasicBlock *> NewBlocks;
  ValueToValueMapTy VMap;
-  bool UnrollPrologue = Count == 2;
+  // For unroll factor 2 the remainder loop will have 1 iteration.
+  // Do not create a 1-iteration loop.
+  bool CreateRemainderLoop = (Count != 2);
  // Clone all the basic blocks in the loop. If Count is 2, we don't clone
  // the loop, otherwise we create a cloned loop to execute the extra
  // iterations. This function adds the appropriate CFG connections.
-  CloneLoopBlocks(L, ModVal, UnrollPrologue, PH, PEnd, NewBlocks, LoopBlocks,
-                  VMap, LI);
-
-  // Insert the cloned blocks into function just before the original loop
-  F->getBasicBlockList().splice(PEnd->getIterator(), F->getBasicBlockList(),
-                                NewBlocks[0]->getIterator(), F->end());
-
-  // Rewrite the cloned instruction operands to use the values
-  // created when the clone is created.
-  for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) {
-    for (BasicBlock::iterator I = NewBlocks[i]->begin(),
-         E = NewBlocks[i]->end();
-         I != E; ++I) {
-      RemapInstruction(&*I, VMap,
-                       RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
+  BasicBlock *InsertBot = UseEpilogRemainder ? Exit : PrologExit;
+  BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
+  CloneLoopBlocks(L, ModVal, CreateRemainderLoop, UseEpilogRemainder, InsertTop,
+                  InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, LI);
+
+  // Insert the cloned blocks into the function.
+  F->getBasicBlockList().splice(InsertBot->getIterator(),
+                                F->getBasicBlockList(),
+                                NewBlocks[0]->getIterator(),
+                                F->end());
+
+  // Loop structure should be the following:
+  //  Epilog              Prolog
+  //
+  //  PreHeader           PreHeader
+  //  NewPreHeader        PrologPreHeader
+  //  Header              PrologHeader
+  //  ...                 ...
+  //  Latch               PrologLatch
+  //  NewExit             PrologExit
+  //  EpilogPreHeader     NewPreHeader
+  //  EpilogHeader        Header
+  //  ...                 ...
+ // EpilogLatch Latch + // Exit Exit + + // Rewrite the cloned instruction operands to use the values created when the + // clone is created. + for (BasicBlock *BB : NewBlocks) { + for (Instruction &I : *BB) { + RemapInstruction(&I, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); } } - // Connect the prolog code to the original loop and update the - // PHI functions. - BasicBlock *LastLoopBB = cast<BasicBlock>(VMap[Latch]); - ConnectProlog(L, BECount, Count, LastLoopBB, PEnd, PH, NewPH, VMap, DT, LI, - PreserveLCSSA); + if (UseEpilogRemainder) { + // Connect the epilog code to the original loop and update the + // PHI functions. + ConnectEpilog(L, ModVal, NewExit, Exit, PreHeader, + EpilogPreHeader, NewPreHeader, VMap, DT, LI, + PreserveLCSSA); + + // Update counter in loop for unrolling. + // I should be multiply of Count. + IRBuilder<> B2(NewPreHeader->getTerminator()); + Value *TestVal = B2.CreateSub(TripCount, ModVal, "unroll_iter"); + BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator()); + B2.SetInsertPoint(LatchBR); + PHINode *NewIdx = PHINode::Create(TestVal->getType(), 2, "niter", + Header->getFirstNonPHI()); + Value *IdxSub = + B2.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), + NewIdx->getName() + ".nsub"); + Value *IdxCmp; + if (LatchBR->getSuccessor(0) == Header) + IdxCmp = B2.CreateIsNotNull(IdxSub, NewIdx->getName() + ".ncmp"); + else + IdxCmp = B2.CreateIsNull(IdxSub, NewIdx->getName() + ".ncmp"); + NewIdx->addIncoming(TestVal, NewPreHeader); + NewIdx->addIncoming(IdxSub, Latch); + LatchBR->setCondition(IdxCmp); + } else { + // Connect the prolog code to the original loop and update the + // PHI functions. + ConnectProlog(L, BECount, Count, PrologExit, PreHeader, NewPreHeader, + VMap, DT, LI, PreserveLCSSA); + } NumRuntimeUnrolled++; return true; } diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp index fa958e913b7b..3902c67c6a01 100644 --- a/lib/Transforms/Utils/LoopUtils.cpp +++ b/lib/Transforms/Utils/LoopUtils.cpp @@ -11,13 +11,20 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -423,7 +430,7 @@ RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind, default: return InstDesc(false, I); case Instruction::PHI: - return InstDesc(I, Prev.getMinMaxKind()); + return InstDesc(I, Prev.getMinMaxKind(), Prev.getUnsafeAlgebraInst()); case Instruction::Sub: case Instruction::Add: return InstDesc(Kind == RK_IntegerAdd, I); @@ -466,12 +473,10 @@ bool RecurrenceDescriptor::hasMultipleUsesOf( bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, RecurrenceDescriptor &RedDes) { - bool HasFunNoNaNAttr = false; BasicBlock *Header = TheLoop->getHeader(); Function &F = *Header->getParent(); - if (F.hasFnAttribute("no-nans-fp-math")) - HasFunNoNaNAttr = - F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; + 
bool HasFunNoNaNAttr = + F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes)) { DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n"); @@ -514,6 +519,43 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, return false; } +bool RecurrenceDescriptor::isFirstOrderRecurrence(PHINode *Phi, Loop *TheLoop, + DominatorTree *DT) { + + // Ensure the phi node is in the loop header and has two incoming values. + if (Phi->getParent() != TheLoop->getHeader() || + Phi->getNumIncomingValues() != 2) + return false; + + // Ensure the loop has a preheader and a single latch block. The loop + // vectorizer will need the latch to set up the next iteration of the loop. + auto *Preheader = TheLoop->getLoopPreheader(); + auto *Latch = TheLoop->getLoopLatch(); + if (!Preheader || !Latch) + return false; + + // Ensure the phi node's incoming blocks are the loop preheader and latch. + if (Phi->getBasicBlockIndex(Preheader) < 0 || + Phi->getBasicBlockIndex(Latch) < 0) + return false; + + // Get the previous value. The previous value comes from the latch edge while + // the initial value comes form the preheader edge. + auto *Previous = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(Latch)); + if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous)) + return false; + + // Ensure every user of the phi node is dominated by the previous value. The + // dominance requirement ensures the loop vectorizer will not need to + // vectorize the initial value prior to the first iteration of the loop. + for (User *U : Phi->users()) + if (auto *I = dyn_cast<Instruction>(U)) + if (!DT->dominates(Previous, I)) + return false; + + return true; +} + /// This function returns the identity element (or neutral element) for /// the operation K. Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurrenceKind K, @@ -612,61 +654,120 @@ Value *RecurrenceDescriptor::createMinMaxOp(IRBuilder<> &Builder, } InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, - ConstantInt *Step) - : StartValue(Start), IK(K), StepValue(Step) { + const SCEV *Step) + : StartValue(Start), IK(K), Step(Step) { assert(IK != IK_NoInduction && "Not an induction"); + + // Start value type should match the induction kind and the value + // itself should not be null. assert(StartValue && "StartValue is null"); - assert(StepValue && !StepValue->isZero() && "StepValue is zero"); assert((IK != IK_PtrInduction || StartValue->getType()->isPointerTy()) && "StartValue is not a pointer for pointer induction"); assert((IK != IK_IntInduction || StartValue->getType()->isIntegerTy()) && "StartValue is not an integer for integer induction"); - assert(StepValue->getType()->isIntegerTy() && - "StepValue is not an integer"); + + // Check the Step Value. It should be non-zero integer value. 
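(As an illustration of the isFirstOrderRecurrence check above: a first-order recurrence is a header phi whose next value is computed in the current iteration and consumed in the following one. A made-up source-level example, not from the patch:)

void shiftedAdd(int *out, const int *in, int init, unsigned n) {
  int prev = init;           // header phi: [init, preheader], [cur, latch]
  for (unsigned i = 0; i < n; ++i) {
    int cur = in[i];         // the 'previous' value carried into iteration i+1
    out[i] = prev + cur;     // the phi's user
    prev = cur;
  }
}

(In the natural IR for this loop the load of in[i] is emitted before the add that uses prev, so the phi's user is dominated by the previous value, which is exactly the dominance condition the function above checks.)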
+  assert((!getConstIntStepValue() || !getConstIntStepValue()->isZero()) &&
+         "Step value is zero");
+
+  assert((IK != IK_PtrInduction || getConstIntStepValue()) &&
+         "Step value should be constant for pointer induction");
+  assert(Step->getType()->isIntegerTy() && "StepValue is not an integer");
}

int InductionDescriptor::getConsecutiveDirection() const {
-  if (StepValue && (StepValue->isOne() || StepValue->isMinusOne()))
-    return StepValue->getSExtValue();
+  ConstantInt *ConstStep = getConstIntStepValue();
+  if (ConstStep && (ConstStep->isOne() || ConstStep->isMinusOne()))
+    return ConstStep->getSExtValue();
  return 0;
}

-Value *InductionDescriptor::transform(IRBuilder<> &B, Value *Index) const {
+ConstantInt *InductionDescriptor::getConstIntStepValue() const {
+  if (isa<SCEVConstant>(Step))
+    return dyn_cast<ConstantInt>(cast<SCEVConstant>(Step)->getValue());
+  return nullptr;
+}
+
+Value *InductionDescriptor::transform(IRBuilder<> &B, Value *Index,
+                                      ScalarEvolution *SE,
+                                      const DataLayout& DL) const {
+
+  SCEVExpander Exp(*SE, DL, "induction");
  switch (IK) {
-  case IK_IntInduction:
+  case IK_IntInduction: {
    assert(Index->getType() == StartValue->getType() &&
           "Index type does not match StartValue type");
-    if (StepValue->isMinusOne())
-      return B.CreateSub(StartValue, Index);
-    if (!StepValue->isOne())
-      Index = B.CreateMul(Index, StepValue);
-    return B.CreateAdd(StartValue, Index);
-  case IK_PtrInduction:
-    assert(Index->getType() == StepValue->getType() &&
+    // FIXME: Theoretically, we can call getAddExpr() of ScalarEvolution
+    // and calculate (Start + Index * Step) for all cases, without
+    // special handling for "isOne" and "isMinusOne".
+    // But in real life the resulting code gets worse. We would mix SCEV
+    // expressions and ADD/SUB operations and end up with redundant
+    // intermediate values, calculated in different ways, that
+    // InstCombine is unable to reduce.
+
+    if (getConstIntStepValue() &&
+        getConstIntStepValue()->isMinusOne())
+      return B.CreateSub(StartValue, Index);
+    if (getConstIntStepValue() &&
+        getConstIntStepValue()->isOne())
+      return B.CreateAdd(StartValue, Index);
+    const SCEV *S = SE->getAddExpr(SE->getSCEV(StartValue),
+                                   SE->getMulExpr(Step, SE->getSCEV(Index)));
+    return Exp.expandCodeFor(S, StartValue->getType(), &*B.GetInsertPoint());
+  }
+  case IK_PtrInduction: {
+    assert(Index->getType() == Step->getType() &&
           "Index type does not match StepValue type");
-    if (StepValue->isMinusOne())
-      Index = B.CreateNeg(Index);
-    else if (!StepValue->isOne())
-      Index = B.CreateMul(Index, StepValue);
+    assert(isa<SCEVConstant>(Step) &&
+           "Expected constant step for pointer induction");
+    const SCEV *S = SE->getMulExpr(SE->getSCEV(Index), Step);
+    Index = Exp.expandCodeFor(S, Index->getType(), &*B.GetInsertPoint());
    return B.CreateGEP(nullptr, StartValue, Index);
-
+  }
  case IK_NoInduction:
    return nullptr;
  }
  llvm_unreachable("invalid enum");
}

-bool InductionDescriptor::isInductionPHI(PHINode *Phi, ScalarEvolution *SE,
-                                         InductionDescriptor &D) {
+bool InductionDescriptor::isInductionPHI(PHINode *Phi,
+                                         PredicatedScalarEvolution &PSE,
+                                         InductionDescriptor &D,
+                                         bool Assume) {
+  Type *PhiTy = Phi->getType();
+  // We only handle integer and pointer induction variables.
+  if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy())
+    return false;
+
+  const SCEV *PhiScev = PSE.getSCEV(Phi);
+  const auto *AR = dyn_cast<SCEVAddRecExpr>(PhiScev);
+
+  // We need this expression to be an AddRecExpr.
+ if (Assume && !AR) + AR = PSE.getAsAddRec(Phi); + + if (!AR) { + DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); + return false; + } + + return isInductionPHI(Phi, PSE.getSE(), D, AR); +} + +bool InductionDescriptor::isInductionPHI(PHINode *Phi, + ScalarEvolution *SE, + InductionDescriptor &D, + const SCEV *Expr) { Type *PhiTy = Phi->getType(); // We only handle integer and pointer inductions variables. if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy()) return false; // Check that the PHI is consecutive. - const SCEV *PhiScev = SE->getSCEV(Phi); + const SCEV *PhiScev = Expr ? Expr : SE->getSCEV(Phi); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PhiScev); + if (!AR) { DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); return false; @@ -678,17 +779,22 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, ScalarEvolution *SE, Phi->getIncomingValueForBlock(AR->getLoop()->getLoopPreheader()); const SCEV *Step = AR->getStepRecurrence(*SE); // Calculate the pointer stride and check if it is consecutive. - const SCEVConstant *C = dyn_cast<SCEVConstant>(Step); - if (!C) + // The stride may be a constant or a loop invariant integer value. + const SCEVConstant *ConstStep = dyn_cast<SCEVConstant>(Step); + if (!ConstStep && !SE->isLoopInvariant(Step, AR->getLoop())) return false; - ConstantInt *CV = C->getValue(); if (PhiTy->isIntegerTy()) { - D = InductionDescriptor(StartValue, IK_IntInduction, CV); + D = InductionDescriptor(StartValue, IK_IntInduction, Step); return true; } assert(PhiTy->isPointerTy() && "The PHI must be a pointer"); + // Pointer induction should be a constant. + if (!ConstStep) + return false; + + ConstantInt *CV = ConstStep->getValue(); Type *PointerElementType = PhiTy->getPointerElementType(); // The pointer stride cannot be determined if the pointer element type is not // sized. @@ -703,8 +809,8 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, ScalarEvolution *SE, int64_t CVSize = CV->getSExtValue(); if (CVSize % Size) return false; - auto *StepValue = ConstantInt::getSigned(CV->getType(), CVSize / Size); - + auto *StepValue = SE->getConstant(CV->getType(), CVSize / Size, + true /* signed */); D = InductionDescriptor(StartValue, IK_PtrInduction, StepValue); return true; } @@ -727,3 +833,137 @@ SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) { return UsedOutside; } + +void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) { + // By definition, all loop passes need the LoopInfo analysis and the + // Dominator tree it depends on. Because they all participate in the loop + // pass manager, they must also preserve these. + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + + // We must also preserve LoopSimplify and LCSSA. We locally access their IDs + // here because users shouldn't directly get them from this header. + extern char &LoopSimplifyID; + extern char &LCSSAID; + AU.addRequiredID(LoopSimplifyID); + AU.addPreservedID(LoopSimplifyID); + AU.addRequiredID(LCSSAID); + AU.addPreservedID(LCSSAID); + + // Loop passes are designed to run inside of a loop pass manager which means + // that any function analyses they require must be required by the first loop + // pass in the manager (so that it is computed before the loop pass manager + // runs) and preserved by all loop pasess in the manager. To make this + // reasonably robust, the set needed for most loop passes is maintained here. 
+ // If your loop pass requires an analysis not listed here, you will need to + // carefully audit the loop pass manager nesting structure that results. + AU.addRequired<AAResultsWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<BasicAAWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<SCEVAAWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addPreserved<ScalarEvolutionWrapperPass>(); +} + +/// Manually defined generic "LoopPass" dependency initialization. This is used +/// to initialize the exact set of passes from above in \c +/// getLoopAnalysisUsage. It can be used within a loop pass's initialization +/// with: +/// +/// INITIALIZE_PASS_DEPENDENCY(LoopPass) +/// +/// As-if "LoopPass" were a pass. +void llvm::initializeLoopPassPass(PassRegistry &Registry) { + INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) + INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) + INITIALIZE_PASS_DEPENDENCY(LoopSimplify) + INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) + INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) + INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) + INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) + INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) + INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +} + +/// \brief Find string metadata for loop +/// +/// If it has a value (e.g. {"llvm.distribute", 1} return the value as an +/// operand or null otherwise. If the string metadata is not found return +/// Optional's not-a-value. +Optional<const MDOperand *> llvm::findStringMetadataForLoop(Loop *TheLoop, + StringRef Name) { + MDNode *LoopID = TheLoop->getLoopID(); + // Return none if LoopID is false. + if (!LoopID) + return None; + + // First operand should refer to the loop id itself. + assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); + assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); + + // Iterate over LoopID operands and look for MDString Metadata + for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (!MD) + continue; + MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + if (!S) + continue; + // Return true if MDString holds expected MetaData. + if (Name.equals(S->getString())) + switch (MD->getNumOperands()) { + case 1: + return nullptr; + case 2: + return &MD->getOperand(1); + default: + llvm_unreachable("loop metadata has 0 or 1 operand"); + } + } + return None; +} + +/// Returns true if the instruction in a loop is guaranteed to execute at least +/// once. +bool llvm::isGuaranteedToExecute(const Instruction &Inst, + const DominatorTree *DT, const Loop *CurLoop, + const LoopSafetyInfo *SafetyInfo) { + // We have to check to make sure that the instruction dominates all + // of the exit blocks. If it doesn't, then there is a path out of the loop + // which does not execute this instruction, so we can't hoist it. + + // If the instruction is in the header block for the loop (which is very + // common), it is always guaranteed to dominate the exit blocks. Since this + // is a common case, and can save some work, check it now. + if (Inst.getParent() == CurLoop->getHeader()) + // If there's a throw in the header block, we can't guarantee we'll reach + // Inst. + return !SafetyInfo->HeaderMayThrow; + + // Somewhere in this loop there is an instruction which may throw and make us + // exit the loop. + if (SafetyInfo->MayThrow) + return false; + + // Get the exit blocks for the current loop. 
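(A brief, hypothetical usage sketch for findStringMetadataForLoop above; the caller, the Loop *TheLoop variable, and the metadata name are assumptions, not from the patch. The returned Optional distinguishes "string not found" (None), "found with no value" (a null operand pointer), and "found with a value".)

if (Optional<const MDOperand *> Op =
        findStringMetadataForLoop(TheLoop, "llvm.loop.distribute.enable")) {
  if (const MDOperand *Val = *Op) {
    // Two-operand form, e.g. !{!"llvm.loop.distribute.enable", i1 true}.
    if (auto *Enable = mdconst::dyn_extract<ConstantInt>(*Val)) {
      bool Enabled = Enable->isOne();
      (void)Enabled; // a real caller would act on this
    }
  } else {
    // One-operand form: the string is present but carries no value.
  }
}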
+ SmallVector<BasicBlock *, 8> ExitBlocks; + CurLoop->getExitBlocks(ExitBlocks); + + // Verify that the block dominates each of the exit blocks of the loop. + for (BasicBlock *ExitBlock : ExitBlocks) + if (!DT->dominates(Inst.getParent(), ExitBlock)) + return false; + + // As a degenerate case, if the loop is statically infinite then we haven't + // proven anything since there are no exit blocks. + if (ExitBlocks.empty()) + return false; + + // FIXME: In general, we have to prove that the loop isn't an infinite loop. + // See http::llvm.org/PR24078 . (The "ExitBlocks.empty()" check above is + // just a special case of this.) + return true; +} diff --git a/lib/Transforms/Utils/LoopVersioning.cpp b/lib/Transforms/Utils/LoopVersioning.cpp index 9a2a06cf6891..b3c61691da30 100644 --- a/lib/Transforms/Utils/LoopVersioning.cpp +++ b/lib/Transforms/Utils/LoopVersioning.cpp @@ -18,11 +18,18 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" using namespace llvm; +static cl::opt<bool> + AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true), + cl::Hidden, + cl::desc("Add no-alias annotation for instructions that " + "are disambiguated by memchecks")); + LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE, bool UseLAIChecks) @@ -32,12 +39,12 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, assert(L->getLoopPreheader() && "No preheader"); if (UseLAIChecks) { setAliasChecks(LAI.getRuntimePointerChecking()->getChecks()); - setSCEVChecks(LAI.PSE.getUnionPredicate()); + setSCEVChecks(LAI.getPSE().getUnionPredicate()); } } void LoopVersioning::setAliasChecks( - const SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) { + SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) { AliasChecks = std::move(Checks); } @@ -56,9 +63,8 @@ void LoopVersioning::versionLoop( BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader(); std::tie(FirstCheckInst, MemRuntimeCheck) = LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks); - assert(MemRuntimeCheck && "called even though needsAnyChecking = false"); - const SCEVUnionPredicate &Pred = LAI.PSE.getUnionPredicate(); + const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate(); SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), "scev.check"); SCEVRuntimeCheck = @@ -71,7 +77,7 @@ void LoopVersioning::versionLoop( if (MemRuntimeCheck && SCEVRuntimeCheck) { RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck, - SCEVRuntimeCheck, "ldist.safe"); + SCEVRuntimeCheck, "lver.safe"); if (auto *I = dyn_cast<Instruction>(RuntimeCheck)) I->insertBefore(RuntimeCheckBB->getTerminator()); } else @@ -119,16 +125,14 @@ void LoopVersioning::addPHINodes( const SmallVectorImpl<Instruction *> &DefsUsedOutside) { BasicBlock *PHIBlock = VersionedLoop->getExitBlock(); assert(PHIBlock && "No single successor to loop exit block"); + PHINode *PN; + // First add a single-operand PHI for each DefsUsedOutside if one does not + // exists yet. for (auto *Inst : DefsUsedOutside) { - auto *NonVersionedLoopInst = cast<Instruction>(VMap[Inst]); - PHINode *PN; - - // First see if we have a single-operand PHI with the value defined by the + // See if we have a single-operand PHI with the value defined by the // original loop. 
for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
-      assert(PN->getNumOperands() == 1 &&
-             "Exit block should only have on predecessor");
      if (PN->getIncomingValue(0) == Inst)
        break;
    }
@@ -141,7 +145,179 @@ void LoopVersioning::addPHINodes(
        User->replaceUsesOfWith(Inst, PN);
      PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
    }
-    // Add the new incoming value from the non-versioned loop.
-    PN->addIncoming(NonVersionedLoopInst, NonVersionedLoop->getExitingBlock());
  }
+
+  // Then for each PHI add the operand for the edge from the cloned loop.
+  for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
+    assert(PN->getNumOperands() == 1 &&
+           "Exit block should only have one predecessor");
+
+    // If the definition was cloned, use that; otherwise use the same value.
+    Value *ClonedValue = PN->getIncomingValue(0);
+    auto Mapped = VMap.find(ClonedValue);
+    if (Mapped != VMap.end())
+      ClonedValue = Mapped->second;
+
+    PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
+  }
+}
+
+void LoopVersioning::prepareNoAliasMetadata() {
+  // We need to turn the no-alias relation between pointer checking groups into
+  // no-aliasing annotations between instructions.
+  //
+  // We accomplish this by mapping each pointer checking group (a set of
+  // pointers memchecked together) to an alias scope and then also mapping each
+  // group to the list of scopes it can't alias.
+
+  const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
+  LLVMContext &Context = VersionedLoop->getHeader()->getContext();
+
+  // First allocate an aliasing scope for each pointer checking group.
+  //
+  // While traversing through the checking groups in the loop, also create a
+  // reverse map from pointers to the pointer checking group they were assigned
+  // to.
+  MDBuilder MDB(Context);
+  MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
+
+  for (const auto &Group : RtPtrChecking->CheckingGroups) {
+    GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
+
+    for (unsigned PtrIdx : Group.Members)
+      PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
+  }
+
+  // Go through the checks and for each pointer group, collect the scopes for
+  // each non-aliasing pointer group.
+  DenseMap<const RuntimePointerChecking::CheckingPtrGroup *,
+           SmallVector<Metadata *, 4>>
+      GroupToNonAliasingScopes;
+
+  for (const auto &Check : AliasChecks)
+    GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
+
+  // Finally, transform the above to actually map to a scope list, which is
+  // what the metadata uses.
+
+  for (auto Pair : GroupToNonAliasingScopes)
+    GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
+}
+
+void LoopVersioning::annotateLoopWithNoAlias() {
+  if (!AnnotateNoAlias)
+    return;
+
+  // First prepare the maps.
+  prepareNoAliasMetadata();
+
+  // Add the scope and no-alias metadata to the instructions.
+  for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
+    annotateInstWithNoAlias(I);
+  }
+}
+
+void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
+                                             const Instruction *OrigInst) {
+  if (!AnnotateNoAlias)
+    return;
+
+  LLVMContext &Context = VersionedLoop->getHeader()->getContext();
+  const Value *Ptr = isa<LoadInst>(OrigInst)
+                         ? cast<LoadInst>(OrigInst)->getPointerOperand()
+                         : cast<StoreInst>(OrigInst)->getPointerOperand();
+
+  // Find the group for the pointer and then add the scope metadata.
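(For orientation, the overall effect of LoopVersioning, which the no-alias annotation above feeds, can be pictured at the source level roughly as follows. Illustration only: the real run-time check is built from LAI's pointer checks and SCEV predicates, and the function here is made up.)

void addOne(int *a, const int *b, int n) {
  bool NoOverlap = a + n <= b || b + n <= a;  // simplified stand-in memcheck
  if (NoOverlap) {
    for (int i = 0; i < n; ++i)   // versioned loop: its memory accesses get
      a[i] = b[i] + 1;            // the scope/noalias metadata added above
  } else {
    for (int i = 0; i < n; ++i)   // fall-back: the original, unversioned loop
      a[i] = b[i] + 1;
  }
}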
+ auto Group = PtrToGroup.find(Ptr); + if (Group != PtrToGroup.end()) { + VersionedInst->setMetadata( + LLVMContext::MD_alias_scope, + MDNode::concatenate( + VersionedInst->getMetadata(LLVMContext::MD_alias_scope), + MDNode::get(Context, GroupToScope[Group->second]))); + + // Add the no-alias metadata. + auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second); + if (NonAliasingScopeList != GroupToNonAliasingScopeList.end()) + VersionedInst->setMetadata( + LLVMContext::MD_noalias, + MDNode::concatenate( + VersionedInst->getMetadata(LLVMContext::MD_noalias), + NonAliasingScopeList->second)); + } +} + +namespace { +/// \brief Also expose this is a pass. Currently this is only used for +/// unit-testing. It adds all memchecks necessary to remove all may-aliasing +/// array accesses from the loop. +class LoopVersioningPass : public FunctionPass { +public: + LoopVersioningPass() : FunctionPass(ID) { + initializeLoopVersioningPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + + // Build up a worklist of inner-loops to version. This is necessary as the + // act of versioning a loop creates new loops and can invalidate iterators + // across the loops. + SmallVector<Loop *, 8> Worklist; + + for (Loop *TopLevelLoop : *LI) + for (Loop *L : depth_first(TopLevelLoop)) + // We only handle inner-most loops. + if (L->empty()) + Worklist.push_back(L); + + // Now walk the identified inner loops. + bool Changed = false; + for (Loop *L : Worklist) { + const LoopAccessInfo &LAI = LAA->getInfo(L); + if (LAI.getNumRuntimePointerChecks() || + !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { + LoopVersioning LVer(LAI, L, LI, DT, SE); + LVer.versionLoop(); + LVer.annotateLoopWithNoAlias(); + Changed = true; + } + } + + return Changed; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<LoopAccessLegacyAnalysis>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + } + + static char ID; +}; +} + +#define LVER_OPTION "loop-versioning" +#define DEBUG_TYPE LVER_OPTION + +char LoopVersioningPass::ID; +static const char LVer_name[] = "Loop Versioning"; + +INITIALIZE_PASS_BEGIN(LoopVersioningPass, LVER_OPTION, LVer_name, false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(LoopVersioningPass, LVER_OPTION, LVer_name, false, false) + +namespace llvm { +FunctionPass *createLoopVersioningPass() { + return new LoopVersioningPass(); +} } diff --git a/lib/Transforms/Utils/LowerInvoke.cpp b/lib/Transforms/Utils/LowerInvoke.cpp index b0ad4d5e84a1..1b31c5ae580a 100644 --- a/lib/Transforms/Utils/LowerInvoke.cpp +++ b/lib/Transforms/Utils/LowerInvoke.cpp @@ -14,14 +14,13 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" 
#include "llvm/IR/Module.h" #include "llvm/Pass.h" -#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Scalar.h" using namespace llvm; #define DEBUG_TYPE "lowerinvoke" @@ -53,8 +52,8 @@ FunctionPass *llvm::createLowerInvokePass() { bool LowerInvoke::runOnFunction(Function &F) { bool Changed = false; - for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) - if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) { + for (BasicBlock &BB : F) + if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator())) { SmallVector<Value*,16> CallArgs(II->op_begin(), II->op_end() - 3); // Insert a normal call instruction... CallInst *NewCall = CallInst::Create(II->getCalledValue(), @@ -69,10 +68,10 @@ bool LowerInvoke::runOnFunction(Function &F) { BranchInst::Create(II->getNormalDest(), II); // Remove any PHI node entries from the exception destination. - II->getUnwindDest()->removePredecessor(&*BB); + II->getUnwindDest()->removePredecessor(&BB); // Remove the invoke instruction now. - BB->getInstList().erase(II); + BB.getInstList().erase(II); ++NumInvokes; Changed = true; } diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index 52beb1542497..5c07469869ff 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -59,12 +59,6 @@ namespace { bool runOnFunction(Function &F) override; - void getAnalysisUsage(AnalysisUsage &AU) const override { - // This is a cluster of orthogonal Transforms - AU.addPreserved<UnifyFunctionExitNodes>(); - AU.addPreservedID(LowerInvokePassID); - } - struct CaseRange { ConstantInt* Low; ConstantInt* High; @@ -192,8 +186,8 @@ static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, } // Remove incoming values in the reverse order to prevent invalidating // *successive* index. - for (auto III = Indices.rbegin(), IIE = Indices.rend(); III != IIE; ++III) - PN->removeIncomingValue(*III); + for (unsigned III : reverse(Indices)) + PN->removeIncomingValue(III); } } diff --git a/lib/Transforms/Utils/Makefile b/lib/Transforms/Utils/Makefile deleted file mode 100644 index d1e9336d67f0..000000000000 --- a/lib/Transforms/Utils/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Transforms/Utils/Makefile -----------------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. 
-LIBRARYNAME = LLVMTransformUtils -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common - diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp index aa1e35ddba02..1419254bcb4f 100644 --- a/lib/Transforms/Utils/Mem2Reg.cpp +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Mem2Reg.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" using namespace llvm; @@ -26,51 +27,11 @@ using namespace llvm; STATISTIC(NumPromoted, "Number of alloca's promoted"); -namespace { - struct PromotePass : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - PromotePass() : FunctionPass(ID) { - initializePromotePassPass(*PassRegistry::getPassRegistry()); - } - - // runOnFunction - To run this pass, first we calculate the alloca - // instructions that are safe for promotion, then we promote each one. - // - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.setPreservesCFG(); - // This is a cluster of orthogonal Transforms - AU.addPreserved<UnifyFunctionExitNodes>(); - AU.addPreservedID(LowerSwitchID); - AU.addPreservedID(LowerInvokePassID); - } - }; -} // end of anonymous namespace - -char PromotePass::ID = 0; -INITIALIZE_PASS_BEGIN(PromotePass, "mem2reg", "Promote Memory to Register", - false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(PromotePass, "mem2reg", "Promote Memory to Register", - false, false) - -bool PromotePass::runOnFunction(Function &F) { - std::vector<AllocaInst*> Allocas; - - BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function - - if (F.hasFnAttribute(Attribute::OptimizeNone)) - return false; - - bool Changed = false; - - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - AssumptionCache &AC = - getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); +static bool promoteMemoryToRegister(Function &F, DominatorTree &DT, + AssumptionCache &AC) { + std::vector<AllocaInst *> Allocas; + BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function + bool Changed = false; while (1) { Allocas.clear(); @@ -78,22 +39,69 @@ bool PromotePass::runOnFunction(Function &F) { // Find allocas that are safe to promote, by looking at all instructions in // the entry node for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I) - if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) // Is it an alloca? + if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) // Is it an alloca? 
if (isAllocaPromotable(AI)) Allocas.push_back(AI); - if (Allocas.empty()) break; + if (Allocas.empty()) + break; PromoteMemToReg(Allocas, DT, nullptr, &AC); NumPromoted += Allocas.size(); Changed = true; } - return Changed; } +PreservedAnalyses PromotePass::run(Function &F, AnalysisManager<Function> &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + if (!promoteMemoryToRegister(F, DT, AC)) + return PreservedAnalyses::all(); + + // FIXME: This should also 'preserve the CFG'. + return PreservedAnalyses::none(); +} + +namespace { +struct PromoteLegacyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + PromoteLegacyPass() : FunctionPass(ID) { + initializePromoteLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + // runOnFunction - To run this pass, first we calculate the alloca + // instructions that are safe for promotion, then we promote each one. + // + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AssumptionCache &AC = + getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + return promoteMemoryToRegister(F, DT, AC); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.setPreservesCFG(); + } + }; +} // end of anonymous namespace + +char PromoteLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(PromoteLegacyPass, "mem2reg", "Promote Memory to " + "Register", + false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(PromoteLegacyPass, "mem2reg", "Promote Memory to Register", + false, false) + // createPromoteMemoryToRegister - Provide an entry point to create this pass. // FunctionPass *llvm::createPromoteMemoryToRegisterPass() { - return new PromotePass(); + return new PromoteLegacyPass(); } diff --git a/lib/Transforms/Utils/MemorySSA.cpp b/lib/Transforms/Utils/MemorySSA.cpp new file mode 100644 index 000000000000..8ba3cae43b18 --- /dev/null +++ b/lib/Transforms/Utils/MemorySSA.cpp @@ -0,0 +1,1361 @@ +//===-- MemorySSA.cpp - Memory SSA Builder---------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------===// +// +// This file implements the MemorySSA class. 
+// +//===----------------------------------------------------------------===// +#include "llvm/Transforms/Utils/MemorySSA.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/IR/AssemblyAnnotationWriter.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormattedStream.h" +#include "llvm/Transforms/Scalar.h" +#include <algorithm> + +#define DEBUG_TYPE "memoryssa" +using namespace llvm; +STATISTIC(NumClobberCacheLookups, "Number of Memory SSA version cache lookups"); +STATISTIC(NumClobberCacheHits, "Number of Memory SSA version cache hits"); +STATISTIC(NumClobberCacheInserts, "Number of MemorySSA version cache inserts"); + +INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, + true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, + true) + +INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa", + "Memory SSA Printer", false, false) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa", + "Memory SSA Printer", false, false) + +static cl::opt<bool> + VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden, + cl::desc("Verify MemorySSA in legacy printer pass.")); + +namespace llvm { +/// \brief An assembly annotator class to print Memory SSA information in +/// comments. +class MemorySSAAnnotatedWriter : public AssemblyAnnotationWriter { + friend class MemorySSA; + const MemorySSA *MSSA; + +public: + MemorySSAAnnotatedWriter(const MemorySSA *M) : MSSA(M) {} + + virtual void emitBasicBlockStartAnnot(const BasicBlock *BB, + formatted_raw_ostream &OS) { + if (MemoryAccess *MA = MSSA->getMemoryAccess(BB)) + OS << "; " << *MA << "\n"; + } + + virtual void emitInstructionAnnot(const Instruction *I, + formatted_raw_ostream &OS) { + if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) + OS << "; " << *MA << "\n"; + } +}; + +/// \brief A MemorySSAWalker that does AA walks and caching of lookups to +/// disambiguate accesses. +/// +/// FIXME: The current implementation of this can take quadratic space in rare +/// cases. This can be fixed, but it is something to note until it is fixed. +/// +/// In order to trigger this behavior, you need to store to N distinct locations +/// (that AA can prove don't alias), perform M stores to other memory +/// locations that AA can prove don't alias any of the initial N locations, and +/// then load from all of the N locations. In this case, we insert M cache +/// entries for each of the N loads. 
+/// +/// For example: +/// define i32 @foo() { +/// %a = alloca i32, align 4 +/// %b = alloca i32, align 4 +/// store i32 0, i32* %a, align 4 +/// store i32 0, i32* %b, align 4 +/// +/// ; Insert M stores to other memory that doesn't alias %a or %b here +/// +/// %c = load i32, i32* %a, align 4 ; Caches M entries in +/// ; CachedUpwardsClobberingAccess for the +/// ; MemoryLocation %a +/// %d = load i32, i32* %b, align 4 ; Caches M entries in +/// ; CachedUpwardsClobberingAccess for the +/// ; MemoryLocation %b +/// +/// ; For completeness' sake, loading %a or %b again would not cache *another* +/// ; M entries. +/// %r = add i32 %c, %d +/// ret i32 %r +/// } +class MemorySSA::CachingWalker final : public MemorySSAWalker { +public: + CachingWalker(MemorySSA *, AliasAnalysis *, DominatorTree *); + ~CachingWalker() override; + + MemoryAccess *getClobberingMemoryAccess(const Instruction *) override; + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, + MemoryLocation &) override; + void invalidateInfo(MemoryAccess *) override; + +protected: + struct UpwardsMemoryQuery; + MemoryAccess *doCacheLookup(const MemoryAccess *, const UpwardsMemoryQuery &, + const MemoryLocation &); + + void doCacheInsert(const MemoryAccess *, MemoryAccess *, + const UpwardsMemoryQuery &, const MemoryLocation &); + + void doCacheRemove(const MemoryAccess *, const UpwardsMemoryQuery &, + const MemoryLocation &); + +private: + MemoryAccessPair UpwardsDFSWalk(MemoryAccess *, const MemoryLocation &, + UpwardsMemoryQuery &, bool); + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &); + bool instructionClobbersQuery(const MemoryDef *, UpwardsMemoryQuery &, + const MemoryLocation &Loc) const; + void verifyRemoved(MemoryAccess *); + SmallDenseMap<ConstMemoryAccessPair, MemoryAccess *> + CachedUpwardsClobberingAccess; + DenseMap<const MemoryAccess *, MemoryAccess *> CachedUpwardsClobberingCall; + AliasAnalysis *AA; + DominatorTree *DT; +}; +} + +namespace { +struct RenamePassData { + DomTreeNode *DTN; + DomTreeNode::const_iterator ChildIt; + MemoryAccess *IncomingVal; + + RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It, + MemoryAccess *M) + : DTN(D), ChildIt(It), IncomingVal(M) {} + void swap(RenamePassData &RHS) { + std::swap(DTN, RHS.DTN); + std::swap(ChildIt, RHS.ChildIt); + std::swap(IncomingVal, RHS.IncomingVal); + } +}; +} + +namespace llvm { +/// \brief Rename a single basic block into MemorySSA form. +/// Uses the standard SSA renaming algorithm. +/// \returns The new incoming value. +MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB, + MemoryAccess *IncomingVal) { + auto It = PerBlockAccesses.find(BB); + // Skip most processing if the list is empty. + if (It != PerBlockAccesses.end()) { + AccessList *Accesses = It->second.get(); + for (MemoryAccess &L : *Accesses) { + switch (L.getValueID()) { + case Value::MemoryUseVal: + cast<MemoryUse>(&L)->setDefiningAccess(IncomingVal); + break; + case Value::MemoryDefVal: + // We can't legally optimize defs, because we only allow single + // memory phis/uses on operations, and if we optimize these, we can + // end up with multiple reaching defs. 
Uses do not have this + // problem, since they do not produce a value + cast<MemoryDef>(&L)->setDefiningAccess(IncomingVal); + IncomingVal = &L; + break; + case Value::MemoryPhiVal: + IncomingVal = &L; + break; + } + } + } + + // Pass through values to our successors + for (const BasicBlock *S : successors(BB)) { + auto It = PerBlockAccesses.find(S); + // Rename the phi nodes in our successor block + if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) + continue; + AccessList *Accesses = It->second.get(); + auto *Phi = cast<MemoryPhi>(&Accesses->front()); + Phi->addIncoming(IncomingVal, BB); + } + + return IncomingVal; +} + +/// \brief This is the standard SSA renaming algorithm. +/// +/// We walk the dominator tree in preorder, renaming accesses, and then filling +/// in phi nodes in our successors. +void MemorySSA::renamePass(DomTreeNode *Root, MemoryAccess *IncomingVal, + SmallPtrSet<BasicBlock *, 16> &Visited) { + SmallVector<RenamePassData, 32> WorkStack; + IncomingVal = renameBlock(Root->getBlock(), IncomingVal); + WorkStack.push_back({Root, Root->begin(), IncomingVal}); + Visited.insert(Root->getBlock()); + + while (!WorkStack.empty()) { + DomTreeNode *Node = WorkStack.back().DTN; + DomTreeNode::const_iterator ChildIt = WorkStack.back().ChildIt; + IncomingVal = WorkStack.back().IncomingVal; + + if (ChildIt == Node->end()) { + WorkStack.pop_back(); + } else { + DomTreeNode *Child = *ChildIt; + ++WorkStack.back().ChildIt; + BasicBlock *BB = Child->getBlock(); + Visited.insert(BB); + IncomingVal = renameBlock(BB, IncomingVal); + WorkStack.push_back({Child, Child->begin(), IncomingVal}); + } + } +} + +/// \brief Compute dominator levels, used by the phi insertion algorithm above. +void MemorySSA::computeDomLevels(DenseMap<DomTreeNode *, unsigned> &DomLevels) { + for (auto DFI = df_begin(DT->getRootNode()), DFE = df_end(DT->getRootNode()); + DFI != DFE; ++DFI) + DomLevels[*DFI] = DFI.getPathLength() - 1; +} + +/// \brief This handles unreachable block accesses by deleting phi nodes in +/// unreachable blocks, and marking all other unreachable MemoryAccess's as +/// being uses of the live on entry definition. +void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) { + assert(!DT->isReachableFromEntry(BB) && + "Reachable block found while handling unreachable blocks"); + + // Make sure phi nodes in our reachable successors end up with a + // LiveOnEntryDef for our incoming edge, even though our block is forward + // unreachable. We could just disconnect these blocks from the CFG fully, + // but we do not right now. + for (const BasicBlock *S : successors(BB)) { + if (!DT->isReachableFromEntry(S)) + continue; + auto It = PerBlockAccesses.find(S); + // Rename the phi nodes in our successor block + if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) + continue; + AccessList *Accesses = It->second.get(); + auto *Phi = cast<MemoryPhi>(&Accesses->front()); + Phi->addIncoming(LiveOnEntryDef.get(), BB); + } + + auto It = PerBlockAccesses.find(BB); + if (It == PerBlockAccesses.end()) + return; + + auto &Accesses = It->second; + for (auto AI = Accesses->begin(), AE = Accesses->end(); AI != AE;) { + auto Next = std::next(AI); + // If we have a phi, just remove it. We are going to replace all + // users with live on entry. 
+ if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(AI)) + UseOrDef->setDefiningAccess(LiveOnEntryDef.get()); + else + Accesses->erase(AI); + AI = Next; + } +} + +MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT) + : AA(AA), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr), + NextID(0) { + buildMemorySSA(); +} + +MemorySSA::MemorySSA(MemorySSA &&MSSA) + : AA(MSSA.AA), DT(MSSA.DT), F(MSSA.F), + ValueToMemoryAccess(std::move(MSSA.ValueToMemoryAccess)), + PerBlockAccesses(std::move(MSSA.PerBlockAccesses)), + LiveOnEntryDef(std::move(MSSA.LiveOnEntryDef)), + Walker(std::move(MSSA.Walker)), NextID(MSSA.NextID) { + // Update the Walker MSSA pointer so it doesn't point to the moved-from MSSA + // object any more. + Walker->MSSA = this; +} + +MemorySSA::~MemorySSA() { + // Drop all our references + for (const auto &Pair : PerBlockAccesses) + for (MemoryAccess &MA : *Pair.second) + MA.dropAllReferences(); +} + +MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) { + auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr)); + + if (Res.second) + Res.first->second = make_unique<AccessList>(); + return Res.first->second.get(); +} + +void MemorySSA::buildMemorySSA() { + // We create an access to represent "live on entry", for things like + // arguments or users of globals, where the memory they use is defined before + // the beginning of the function. We do not actually insert it into the IR. + // We do not define a live on exit for the immediate uses, and thus our + // semantics do *not* imply that something with no immediate uses can simply + // be removed. + BasicBlock &StartingPoint = F.getEntryBlock(); + LiveOnEntryDef = make_unique<MemoryDef>(F.getContext(), nullptr, nullptr, + &StartingPoint, NextID++); + + // We maintain lists of memory accesses per-block, trading memory for time. We + // could just look up the memory access for every possible instruction in the + // stream. + SmallPtrSet<BasicBlock *, 32> DefiningBlocks; + SmallPtrSet<BasicBlock *, 32> DefUseBlocks; + // Go through each block, figure out where defs occur, and chain together all + // the accesses. + for (BasicBlock &B : F) { + bool InsertIntoDef = false; + AccessList *Accesses = nullptr; + for (Instruction &I : B) { + MemoryUseOrDef *MUD = createNewAccess(&I); + if (!MUD) + continue; + InsertIntoDef |= isa<MemoryDef>(MUD); + + if (!Accesses) + Accesses = getOrCreateAccessList(&B); + Accesses->push_back(MUD); + } + if (InsertIntoDef) + DefiningBlocks.insert(&B); + if (Accesses) + DefUseBlocks.insert(&B); + } + + // Compute live-in. + // Live in is normally defined as "all the blocks on the path from each def to + // each of it's uses". + // MemoryDef's are implicit uses of previous state, so they are also uses. + // This means we don't really have def-only instructions. The only + // MemoryDef's that are not really uses are those that are of the LiveOnEntry + // variable (because LiveOnEntry can reach anywhere, and every def is a + // must-kill of LiveOnEntry). + // In theory, you could precisely compute live-in by using alias-analysis to + // disambiguate defs and uses to see which really pair up with which. + // In practice, this would be really expensive and difficult. So we simply + // assume all defs are also uses that need to be kept live. + // Because of this, the end result of this live-in computation will be "the + // entire set of basic blocks that reach any use". 
+ + SmallPtrSet<BasicBlock *, 32> LiveInBlocks; + SmallVector<BasicBlock *, 64> LiveInBlockWorklist(DefUseBlocks.begin(), + DefUseBlocks.end()); + // Now that we have a set of blocks where a value is live-in, recursively add + // predecessors until we find the full region the value is live. + while (!LiveInBlockWorklist.empty()) { + BasicBlock *BB = LiveInBlockWorklist.pop_back_val(); + + // The block really is live in here, insert it into the set. If already in + // the set, then it has already been processed. + if (!LiveInBlocks.insert(BB).second) + continue; + + // Since the value is live into BB, it is either defined in a predecessor or + // live into it to. + LiveInBlockWorklist.append(pred_begin(BB), pred_end(BB)); + } + + // Determine where our MemoryPhi's should go + ForwardIDFCalculator IDFs(*DT); + IDFs.setDefiningBlocks(DefiningBlocks); + IDFs.setLiveInBlocks(LiveInBlocks); + SmallVector<BasicBlock *, 32> IDFBlocks; + IDFs.calculate(IDFBlocks); + + // Now place MemoryPhi nodes. + for (auto &BB : IDFBlocks) { + // Insert phi node + AccessList *Accesses = getOrCreateAccessList(BB); + MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); + ValueToMemoryAccess.insert(std::make_pair(BB, Phi)); + // Phi's always are placed at the front of the block. + Accesses->push_front(Phi); + } + + // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get + // filled in with all blocks. + SmallPtrSet<BasicBlock *, 16> Visited; + renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited); + + MemorySSAWalker *Walker = getWalker(); + + // Now optimize the MemoryUse's defining access to point to the nearest + // dominating clobbering def. + // This ensures that MemoryUse's that are killed by the same store are + // immediate users of that store, one of the invariants we guarantee. + for (auto DomNode : depth_first(DT)) { + BasicBlock *BB = DomNode->getBlock(); + auto AI = PerBlockAccesses.find(BB); + if (AI == PerBlockAccesses.end()) + continue; + AccessList *Accesses = AI->second.get(); + for (auto &MA : *Accesses) { + if (auto *MU = dyn_cast<MemoryUse>(&MA)) { + Instruction *Inst = MU->getMemoryInst(); + MU->setDefiningAccess(Walker->getClobberingMemoryAccess(Inst)); + } + } + } + + // Mark the uses in unreachable blocks as live on entry, so that they go + // somewhere. + for (auto &BB : F) + if (!Visited.count(&BB)) + markUnreachableAsLiveOnEntry(&BB); +} + +MemorySSAWalker *MemorySSA::getWalker() { + if (Walker) + return Walker.get(); + + Walker = make_unique<CachingWalker>(this, AA, DT); + return Walker.get(); +} + +MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) { + assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB"); + AccessList *Accesses = getOrCreateAccessList(BB); + MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); + ValueToMemoryAccess.insert(std::make_pair(BB, Phi)); + // Phi's always are placed at the front of the block. 
+ Accesses->push_front(Phi); + return Phi; +} + +MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I, + MemoryAccess *Definition) { + assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI"); + MemoryUseOrDef *NewAccess = createNewAccess(I); + assert( + NewAccess != nullptr && + "Tried to create a memory access for a non-memory touching instruction"); + NewAccess->setDefiningAccess(Definition); + return NewAccess; +} + +MemoryAccess *MemorySSA::createMemoryAccessInBB(Instruction *I, + MemoryAccess *Definition, + const BasicBlock *BB, + InsertionPlace Point) { + MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); + auto *Accesses = getOrCreateAccessList(BB); + if (Point == Beginning) { + // It goes after any phi nodes + auto AI = std::find_if( + Accesses->begin(), Accesses->end(), + [](const MemoryAccess &MA) { return !isa<MemoryPhi>(MA); }); + + Accesses->insert(AI, NewAccess); + } else { + Accesses->push_back(NewAccess); + } + + return NewAccess; +} +MemoryAccess *MemorySSA::createMemoryAccessBefore(Instruction *I, + MemoryAccess *Definition, + MemoryAccess *InsertPt) { + assert(I->getParent() == InsertPt->getBlock() && + "New and old access must be in the same block"); + MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); + auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); + Accesses->insert(AccessList::iterator(InsertPt), NewAccess); + return NewAccess; +} + +MemoryAccess *MemorySSA::createMemoryAccessAfter(Instruction *I, + MemoryAccess *Definition, + MemoryAccess *InsertPt) { + assert(I->getParent() == InsertPt->getBlock() && + "New and old access must be in the same block"); + MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); + auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); + Accesses->insertAfter(AccessList::iterator(InsertPt), NewAccess); + return NewAccess; +} + +/// \brief Helper function to create new memory accesses +MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { + // The assume intrinsic has a control dependency which we model by claiming + // that it writes arbitrarily. Ignore that fake memory dependency here. + // FIXME: Replace this special casing with a more accurate modelling of + // assume's control dependency. + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) + if (II->getIntrinsicID() == Intrinsic::assume) + return nullptr; + + // Find out what affect this instruction has on memory. + ModRefInfo ModRef = AA->getModRefInfo(I); + bool Def = bool(ModRef & MRI_Mod); + bool Use = bool(ModRef & MRI_Ref); + + // It's possible for an instruction to not modify memory at all. During + // construction, we ignore them. 
+ if (!Def && !Use) + return nullptr; + + assert((Def || Use) && + "Trying to create a memory access with a non-memory instruction"); + + MemoryUseOrDef *MUD; + if (Def) + MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++); + else + MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent()); + ValueToMemoryAccess.insert(std::make_pair(I, MUD)); + return MUD; +} + +MemoryAccess *MemorySSA::findDominatingDef(BasicBlock *UseBlock, + enum InsertionPlace Where) { + // Handle the initial case + if (Where == Beginning) + // The only thing that could define us at the beginning is a phi node + if (MemoryPhi *Phi = getMemoryAccess(UseBlock)) + return Phi; + + DomTreeNode *CurrNode = DT->getNode(UseBlock); + // Need to be defined by our dominator + if (Where == Beginning) + CurrNode = CurrNode->getIDom(); + Where = End; + while (CurrNode) { + auto It = PerBlockAccesses.find(CurrNode->getBlock()); + if (It != PerBlockAccesses.end()) { + auto &Accesses = It->second; + for (MemoryAccess &RA : reverse(*Accesses)) { + if (isa<MemoryDef>(RA) || isa<MemoryPhi>(RA)) + return &RA; + } + } + CurrNode = CurrNode->getIDom(); + } + return LiveOnEntryDef.get(); +} + +/// \brief Returns true if \p Replacer dominates \p Replacee . +bool MemorySSA::dominatesUse(const MemoryAccess *Replacer, + const MemoryAccess *Replacee) const { + if (isa<MemoryUseOrDef>(Replacee)) + return DT->dominates(Replacer->getBlock(), Replacee->getBlock()); + const auto *MP = cast<MemoryPhi>(Replacee); + // For a phi node, the use occurs in the predecessor block of the phi node. + // Since we may occur multiple times in the phi node, we have to check each + // operand to ensure Replacer dominates each operand where Replacee occurs. + for (const Use &Arg : MP->operands()) { + if (Arg.get() != Replacee && + !DT->dominates(Replacer->getBlock(), MP->getIncomingBlock(Arg))) + return false; + } + return true; +} + +/// \brief If all arguments of a MemoryPHI are defined by the same incoming +/// argument, return that argument. +static MemoryAccess *onlySingleValue(MemoryPhi *MP) { + MemoryAccess *MA = nullptr; + + for (auto &Arg : MP->operands()) { + if (!MA) + MA = cast<MemoryAccess>(Arg); + else if (MA != Arg) + return nullptr; + } + return MA; +} + +/// \brief Properly remove \p MA from all of MemorySSA's lookup tables. +/// +/// Because of the way the intrusive list and use lists work, it is important to +/// do removal in the right order. 
+void MemorySSA::removeFromLookups(MemoryAccess *MA) { + assert(MA->use_empty() && + "Trying to remove memory access that still has uses"); + if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) + MUD->setDefiningAccess(nullptr); + // Invalidate our walker's cache if necessary + if (!isa<MemoryUse>(MA)) + Walker->invalidateInfo(MA); + // The call below to erase will destroy MA, so we can't change the order we + // are doing things here + Value *MemoryInst; + if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) { + MemoryInst = MUD->getMemoryInst(); + } else { + MemoryInst = MA->getBlock(); + } + ValueToMemoryAccess.erase(MemoryInst); + + auto AccessIt = PerBlockAccesses.find(MA->getBlock()); + std::unique_ptr<AccessList> &Accesses = AccessIt->second; + Accesses->erase(MA); + if (Accesses->empty()) + PerBlockAccesses.erase(AccessIt); +} + +void MemorySSA::removeMemoryAccess(MemoryAccess *MA) { + assert(!isLiveOnEntryDef(MA) && "Trying to remove the live on entry def"); + // We can only delete phi nodes if they have no uses, or we can replace all + // uses with a single definition. + MemoryAccess *NewDefTarget = nullptr; + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(MA)) { + // Note that it is sufficient to know that all edges of the phi node have + // the same argument. If they do, by the definition of dominance frontiers + // (which we used to place this phi), that argument must dominate this phi, + // and thus, must dominate the phi's uses, and so we will not hit the assert + // below. + NewDefTarget = onlySingleValue(MP); + assert((NewDefTarget || MP->use_empty()) && + "We can't delete this memory phi"); + } else { + NewDefTarget = cast<MemoryUseOrDef>(MA)->getDefiningAccess(); + } + + // Re-point the uses at our defining access + if (!MA->use_empty()) + MA->replaceAllUsesWith(NewDefTarget); + + // The call below to erase will destroy MA, so we can't change the order we + // are doing things here + removeFromLookups(MA); +} + +void MemorySSA::print(raw_ostream &OS) const { + MemorySSAAnnotatedWriter Writer(this); + F.print(OS, &Writer); +} + +void MemorySSA::dump() const { + MemorySSAAnnotatedWriter Writer(this); + F.print(dbgs(), &Writer); +} + +void MemorySSA::verifyMemorySSA() const { + verifyDefUses(F); + verifyDomination(F); + verifyOrdering(F); +} + +/// \brief Verify that the order and existence of MemoryAccesses matches the +/// order and existence of memory affecting instructions. +void MemorySSA::verifyOrdering(Function &F) const { + // Walk all the blocks, comparing what the lookups think and what the access + // lists think, as well as the order in the blocks vs the order in the access + // lists. 
+ SmallVector<MemoryAccess *, 32> ActualAccesses; + for (BasicBlock &B : F) { + const AccessList *AL = getBlockAccesses(&B); + MemoryAccess *Phi = getMemoryAccess(&B); + if (Phi) + ActualAccesses.push_back(Phi); + for (Instruction &I : B) { + MemoryAccess *MA = getMemoryAccess(&I); + assert((!MA || AL) && "We have memory affecting instructions " + "in this block but they are not in the " + "access list"); + if (MA) + ActualAccesses.push_back(MA); + } + // Either we hit the assert, really have no accesses, or we have both + // accesses and an access list + if (!AL) + continue; + assert(AL->size() == ActualAccesses.size() && + "We don't have the same number of accesses in the block as on the " + "access list"); + auto ALI = AL->begin(); + auto AAI = ActualAccesses.begin(); + while (ALI != AL->end() && AAI != ActualAccesses.end()) { + assert(&*ALI == *AAI && "Not the same accesses in the same order"); + ++ALI; + ++AAI; + } + ActualAccesses.clear(); + } +} + +/// \brief Verify the domination properties of MemorySSA by checking that each +/// definition dominates all of its uses. +void MemorySSA::verifyDomination(Function &F) const { + for (BasicBlock &B : F) { + // Phi nodes are attached to basic blocks + if (MemoryPhi *MP = getMemoryAccess(&B)) { + for (User *U : MP->users()) { + BasicBlock *UseBlock; + // Phi operands are used on edges, we simulate the right domination by + // acting as if the use occurred at the end of the predecessor block. + if (MemoryPhi *P = dyn_cast<MemoryPhi>(U)) { + for (const auto &Arg : P->operands()) { + if (Arg == MP) { + UseBlock = P->getIncomingBlock(Arg); + break; + } + } + } else { + UseBlock = cast<MemoryAccess>(U)->getBlock(); + } + (void)UseBlock; + assert(DT->dominates(MP->getBlock(), UseBlock) && + "Memory PHI does not dominate it's uses"); + } + } + + for (Instruction &I : B) { + MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I)); + if (!MD) + continue; + + for (User *U : MD->users()) { + BasicBlock *UseBlock; + (void)UseBlock; + // Things are allowed to flow to phi nodes over their predecessor edge. + if (auto *P = dyn_cast<MemoryPhi>(U)) { + for (const auto &Arg : P->operands()) { + if (Arg == MD) { + UseBlock = P->getIncomingBlock(Arg); + break; + } + } + } else { + UseBlock = cast<MemoryAccess>(U)->getBlock(); + } + assert(DT->dominates(MD->getBlock(), UseBlock) && + "Memory Def does not dominate it's uses"); + } + } + } +} + +/// \brief Verify the def-use lists in MemorySSA, by verifying that \p Use +/// appears in the use list of \p Def. +/// +/// llvm_unreachable is used instead of asserts because this may be called in +/// a build without asserts. In that case, we don't want this to turn into a +/// nop. 
+void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const { + // The live on entry use may cause us to get a NULL def here + if (!Def) { + if (!isLiveOnEntryDef(Use)) + llvm_unreachable("Null def but use not point to live on entry def"); + } else if (std::find(Def->user_begin(), Def->user_end(), Use) == + Def->user_end()) { + llvm_unreachable("Did not find use in def's use list"); + } +} + +/// \brief Verify the immediate use information, by walking all the memory +/// accesses and verifying that, for each use, it appears in the +/// appropriate def's use list +void MemorySSA::verifyDefUses(Function &F) const { + for (BasicBlock &B : F) { + // Phi nodes are attached to basic blocks + if (MemoryPhi *Phi = getMemoryAccess(&B)) { + assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance( + pred_begin(&B), pred_end(&B))) && + "Incomplete MemoryPhi Node"); + for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) + verifyUseInDefs(Phi->getIncomingValue(I), Phi); + } + + for (Instruction &I : B) { + if (MemoryAccess *MA = getMemoryAccess(&I)) { + assert(isa<MemoryUseOrDef>(MA) && + "Found a phi node not attached to a bb"); + verifyUseInDefs(cast<MemoryUseOrDef>(MA)->getDefiningAccess(), MA); + } + } + } +} + +MemoryAccess *MemorySSA::getMemoryAccess(const Value *I) const { + return ValueToMemoryAccess.lookup(I); +} + +MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const { + return cast_or_null<MemoryPhi>(getMemoryAccess((const Value *)BB)); +} + +/// \brief Determine, for two memory accesses in the same block, +/// whether \p Dominator dominates \p Dominatee. +/// \returns True if \p Dominator dominates \p Dominatee. +bool MemorySSA::locallyDominates(const MemoryAccess *Dominator, + const MemoryAccess *Dominatee) const { + + assert((Dominator->getBlock() == Dominatee->getBlock()) && + "Asking for local domination when accesses are in different blocks!"); + + // A node dominates itself. + if (Dominatee == Dominator) + return true; + + // When Dominatee is defined on function entry, it is not dominated by another + // memory access. + if (isLiveOnEntryDef(Dominatee)) + return false; + + // When Dominator is defined on function entry, it dominates the other memory + // access. 
+ if (isLiveOnEntryDef(Dominator)) + return true; + + // Get the access list for the block + const AccessList *AccessList = getBlockAccesses(Dominator->getBlock()); + AccessList::const_reverse_iterator It(Dominator->getIterator()); + + // If we hit the beginning of the access list before we hit dominatee, we must + // dominate it + return std::none_of(It, AccessList->rend(), + [&](const MemoryAccess &MA) { return &MA == Dominatee; }); +} + +const static char LiveOnEntryStr[] = "liveOnEntry"; + +void MemoryDef::print(raw_ostream &OS) const { + MemoryAccess *UO = getDefiningAccess(); + + OS << getID() << " = MemoryDef("; + if (UO && UO->getID()) + OS << UO->getID(); + else + OS << LiveOnEntryStr; + OS << ')'; +} + +void MemoryPhi::print(raw_ostream &OS) const { + bool First = true; + OS << getID() << " = MemoryPhi("; + for (const auto &Op : operands()) { + BasicBlock *BB = getIncomingBlock(Op); + MemoryAccess *MA = cast<MemoryAccess>(Op); + if (!First) + OS << ','; + else + First = false; + + OS << '{'; + if (BB->hasName()) + OS << BB->getName(); + else + BB->printAsOperand(OS, false); + OS << ','; + if (unsigned ID = MA->getID()) + OS << ID; + else + OS << LiveOnEntryStr; + OS << '}'; + } + OS << ')'; +} + +MemoryAccess::~MemoryAccess() {} + +void MemoryUse::print(raw_ostream &OS) const { + MemoryAccess *UO = getDefiningAccess(); + OS << "MemoryUse("; + if (UO && UO->getID()) + OS << UO->getID(); + else + OS << LiveOnEntryStr; + OS << ')'; +} + +void MemoryAccess::dump() const { + print(dbgs()); + dbgs() << "\n"; +} + +char MemorySSAPrinterLegacyPass::ID = 0; + +MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) { + initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); +} + +void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); +} + +bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) { + auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSA.print(dbgs()); + if (VerifyMemorySSA) + MSSA.verifyMemorySSA(); + return false; +} + +char MemorySSAAnalysis::PassID; + +MemorySSA MemorySSAAnalysis::run(Function &F, AnalysisManager<Function> &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + return MemorySSA(F, &AA, &DT); +} + +PreservedAnalyses MemorySSAPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + OS << "MemorySSA for function: " << F.getName() << "\n"; + AM.getResult<MemorySSAAnalysis>(F).print(OS); + + return PreservedAnalyses::all(); +} + +PreservedAnalyses MemorySSAVerifierPass::run(Function &F, + FunctionAnalysisManager &AM) { + AM.getResult<MemorySSAAnalysis>(F).verifyMemorySSA(); + + return PreservedAnalyses::all(); +} + +char MemorySSAWrapperPass::ID = 0; + +MemorySSAWrapperPass::MemorySSAWrapperPass() : FunctionPass(ID) { + initializeMemorySSAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void MemorySSAWrapperPass::releaseMemory() { MSSA.reset(); } + +void MemorySSAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<DominatorTreeWrapperPass>(); + AU.addRequiredTransitive<AAResultsWrapperPass>(); +} + +bool MemorySSAWrapperPass::runOnFunction(Function &F) { + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + MSSA.reset(new MemorySSA(F, &AA, &DT)); + return false; +} + +void 
MemorySSAWrapperPass::verifyAnalysis() const { MSSA->verifyMemorySSA(); } + +void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const { + MSSA->print(OS); +} + +MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {} + +MemorySSA::CachingWalker::CachingWalker(MemorySSA *M, AliasAnalysis *A, + DominatorTree *D) + : MemorySSAWalker(M), AA(A), DT(D) {} + +MemorySSA::CachingWalker::~CachingWalker() {} + +struct MemorySSA::CachingWalker::UpwardsMemoryQuery { + // True if we saw a phi whose predecessor was a backedge + bool SawBackedgePhi; + // True if our original query started off as a call + bool IsCall; + // The pointer location we started the query with. This will be empty if + // IsCall is true. + MemoryLocation StartingLoc; + // This is the instruction we were querying about. + const Instruction *Inst; + // Set of visited Instructions for this query. + DenseSet<MemoryAccessPair> Visited; + // Vector of visited call accesses for this query. This is separated out + // because you can always cache and lookup the result of call queries (IE when + // IsCall == true) for every call in the chain. The calls have no AA location + // associated with them with them, and thus, no context dependence. + SmallVector<const MemoryAccess *, 32> VisitedCalls; + // The MemoryAccess we actually got called with, used to test local domination + const MemoryAccess *OriginalAccess; + + UpwardsMemoryQuery() + : SawBackedgePhi(false), IsCall(false), Inst(nullptr), + OriginalAccess(nullptr) {} + + UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access) + : SawBackedgePhi(false), IsCall(ImmutableCallSite(Inst)), Inst(Inst), + OriginalAccess(Access) {} +}; + +void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) { + + // TODO: We can do much better cache invalidation with differently stored + // caches. For now, for MemoryUses, we simply remove them + // from the cache, and kill the entire call/non-call cache for everything + // else. The problem is for phis or defs, currently we'd need to follow use + // chains down and invalidate anything below us in the chain that currently + // terminates at this access. + + // See if this is a MemoryUse, if so, just remove the cached info. MemoryUse + // is by definition never a barrier, so nothing in the cache could point to + // this use. In that case, we only need invalidate the info for the use + // itself. + + if (MemoryUse *MU = dyn_cast<MemoryUse>(MA)) { + UpwardsMemoryQuery Q; + Instruction *I = MU->getMemoryInst(); + Q.IsCall = bool(ImmutableCallSite(I)); + Q.Inst = I; + if (!Q.IsCall) + Q.StartingLoc = MemoryLocation::get(I); + doCacheRemove(MA, Q, Q.StartingLoc); + } else { + // If it is not a use, the best we can do right now is destroy the cache. + CachedUpwardsClobberingCall.clear(); + CachedUpwardsClobberingAccess.clear(); + } + +#ifdef EXPENSIVE_CHECKS + // Run this only when expensive checks are enabled. + verifyRemoved(MA); +#endif +} + +void MemorySSA::CachingWalker::doCacheRemove(const MemoryAccess *M, + const UpwardsMemoryQuery &Q, + const MemoryLocation &Loc) { + if (Q.IsCall) + CachedUpwardsClobberingCall.erase(M); + else + CachedUpwardsClobberingAccess.erase({M, Loc}); +} + +void MemorySSA::CachingWalker::doCacheInsert(const MemoryAccess *M, + MemoryAccess *Result, + const UpwardsMemoryQuery &Q, + const MemoryLocation &Loc) { + // This is fine for Phis, since there are times where we can't optimize them. + // Making a def its own clobber is never correct, though. 
+ assert((Result != M || isa<MemoryPhi>(M)) && + "Something can't clobber itself!"); + ++NumClobberCacheInserts; + if (Q.IsCall) + CachedUpwardsClobberingCall[M] = Result; + else + CachedUpwardsClobberingAccess[{M, Loc}] = Result; +} + +MemoryAccess * +MemorySSA::CachingWalker::doCacheLookup(const MemoryAccess *M, + const UpwardsMemoryQuery &Q, + const MemoryLocation &Loc) { + ++NumClobberCacheLookups; + MemoryAccess *Result; + + if (Q.IsCall) + Result = CachedUpwardsClobberingCall.lookup(M); + else + Result = CachedUpwardsClobberingAccess.lookup({M, Loc}); + + if (Result) + ++NumClobberCacheHits; + return Result; +} + +bool MemorySSA::CachingWalker::instructionClobbersQuery( + const MemoryDef *MD, UpwardsMemoryQuery &Q, + const MemoryLocation &Loc) const { + Instruction *DefMemoryInst = MD->getMemoryInst(); + assert(DefMemoryInst && "Defining instruction not actually an instruction"); + + if (!Q.IsCall) + return AA->getModRefInfo(DefMemoryInst, Loc) & MRI_Mod; + + // If this is a call, mark it for caching + if (ImmutableCallSite(DefMemoryInst)) + Q.VisitedCalls.push_back(MD); + ModRefInfo I = AA->getModRefInfo(DefMemoryInst, ImmutableCallSite(Q.Inst)); + return I != MRI_NoModRef; +} + +MemoryAccessPair MemorySSA::CachingWalker::UpwardsDFSWalk( + MemoryAccess *StartingAccess, const MemoryLocation &Loc, + UpwardsMemoryQuery &Q, bool FollowingBackedge) { + MemoryAccess *ModifyingAccess = nullptr; + + auto DFI = df_begin(StartingAccess); + for (auto DFE = df_end(StartingAccess); DFI != DFE;) { + MemoryAccess *CurrAccess = *DFI; + if (MSSA->isLiveOnEntryDef(CurrAccess)) + return {CurrAccess, Loc}; + // If this is a MemoryDef, check whether it clobbers our current query. This + // needs to be done before consulting the cache, because the cache reports + // the clobber for CurrAccess. If CurrAccess is a clobber for this query, + // and we ask the cache for information first, then we might skip this + // clobber, which is bad. + if (auto *MD = dyn_cast<MemoryDef>(CurrAccess)) { + // If we hit the top, stop following this path. + // While we can do lookups, we can't sanely do inserts here unless we were + // to track everything we saw along the way, since we don't know where we + // will stop. + if (instructionClobbersQuery(MD, Q, Loc)) { + ModifyingAccess = CurrAccess; + break; + } + } + if (auto CacheResult = doCacheLookup(CurrAccess, Q, Loc)) + return {CacheResult, Loc}; + + // We need to know whether it is a phi so we can track backedges. + // Otherwise, walk all upward defs. + if (!isa<MemoryPhi>(CurrAccess)) { + ++DFI; + continue; + } + +#ifndef NDEBUG + // The loop below visits the phi's children for us. Because phis are the + // only things with multiple edges, skipping the children should always lead + // us to the end of the loop. + // + // Use a copy of DFI because skipChildren would kill our search stack, which + // would make caching anything on the way back impossible. + auto DFICopy = DFI; + assert(DFICopy.skipChildren() == DFE && + "Skipping phi's children doesn't end the DFS?"); +#endif + + const MemoryAccessPair PHIPair(CurrAccess, Loc); + + // Don't try to optimize this phi again if we've already tried to do so. + if (!Q.Visited.insert(PHIPair).second) { + ModifyingAccess = CurrAccess; + break; + } + + std::size_t InitialVisitedCallSize = Q.VisitedCalls.size(); + + // Recurse on PHI nodes, since we need to change locations. + // TODO: Allow graphtraits on pairs, which would turn this whole function + // into a normal single depth first walk. 
+ MemoryAccess *FirstDef = nullptr; + for (auto MPI = upward_defs_begin(PHIPair), MPE = upward_defs_end(); + MPI != MPE; ++MPI) { + bool Backedge = + !FollowingBackedge && + DT->dominates(CurrAccess->getBlock(), MPI.getPhiArgBlock()); + + MemoryAccessPair CurrentPair = + UpwardsDFSWalk(MPI->first, MPI->second, Q, Backedge); + // All the phi arguments should reach the same point if we can bypass + // this phi. The alternative is that they hit this phi node, which + // means we can skip this argument. + if (FirstDef && CurrentPair.first != PHIPair.first && + CurrentPair.first != FirstDef) { + ModifyingAccess = CurrAccess; + break; + } + + if (!FirstDef) + FirstDef = CurrentPair.first; + } + + // If we exited the loop early, go with the result it gave us. + if (!ModifyingAccess) { + assert(FirstDef && "Found a Phi with no upward defs?"); + ModifyingAccess = FirstDef; + } else { + // If we can't optimize this Phi, then we can't safely cache any of the + // calls we visited when trying to optimize it. Wipe them out now. + Q.VisitedCalls.resize(InitialVisitedCallSize); + } + break; + } + + if (!ModifyingAccess) + return {MSSA->getLiveOnEntryDef(), Q.StartingLoc}; + + const BasicBlock *OriginalBlock = StartingAccess->getBlock(); + assert(DFI.getPathLength() > 0 && "We dropped our path?"); + unsigned N = DFI.getPathLength(); + // If we found a clobbering def, the last element in the path will be our + // clobber, so we don't want to cache that to itself. OTOH, if we optimized a + // phi, we can add the last thing in the path to the cache, since that won't + // be the result. + if (DFI.getPath(N - 1) == ModifyingAccess) + --N; + for (; N > 1; --N) { + MemoryAccess *CacheAccess = DFI.getPath(N - 1); + BasicBlock *CurrBlock = CacheAccess->getBlock(); + if (!FollowingBackedge) + doCacheInsert(CacheAccess, ModifyingAccess, Q, Loc); + if (DT->dominates(CurrBlock, OriginalBlock) && + (CurrBlock != OriginalBlock || !FollowingBackedge || + MSSA->locallyDominates(CacheAccess, StartingAccess))) + break; + } + + // Cache everything else on the way back. The caller should cache + // StartingAccess for us. + for (; N > 1; --N) { + MemoryAccess *CacheAccess = DFI.getPath(N - 1); + doCacheInsert(CacheAccess, ModifyingAccess, Q, Loc); + } + + return {ModifyingAccess, Loc}; +} + +/// \brief Walk the use-def chains starting at \p MA and find +/// the MemoryAccess that actually clobbers Loc. +/// +/// \returns our clobbering memory access +MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( + MemoryAccess *StartingAccess, UpwardsMemoryQuery &Q) { + return UpwardsDFSWalk(StartingAccess, Q.StartingLoc, Q, false).first; +} + +MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( + MemoryAccess *StartingAccess, MemoryLocation &Loc) { + if (isa<MemoryPhi>(StartingAccess)) + return StartingAccess; + + auto *StartingUseOrDef = cast<MemoryUseOrDef>(StartingAccess); + if (MSSA->isLiveOnEntryDef(StartingUseOrDef)) + return StartingUseOrDef; + + Instruction *I = StartingUseOrDef->getMemoryInst(); + + // Conservatively, fences are always clobbers, so don't perform the walk if we + // hit a fence. 
+ if (!ImmutableCallSite(I) && I->isFenceLike()) + return StartingUseOrDef; + + UpwardsMemoryQuery Q; + Q.OriginalAccess = StartingUseOrDef; + Q.StartingLoc = Loc; + Q.Inst = StartingUseOrDef->getMemoryInst(); + Q.IsCall = false; + + if (auto CacheResult = doCacheLookup(StartingUseOrDef, Q, Q.StartingLoc)) + return CacheResult; + + // Unlike the other function, do not walk to the def of a def, because we are + // handed something we already believe is the clobbering access. + MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef) + ? StartingUseOrDef->getDefiningAccess() + : StartingUseOrDef; + + MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q); + // Only cache this if it wouldn't make Clobber point to itself. + if (Clobber != StartingAccess) + doCacheInsert(Q.OriginalAccess, Clobber, Q, Q.StartingLoc); + DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); + DEBUG(dbgs() << *StartingUseOrDef << "\n"); + DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); + DEBUG(dbgs() << *Clobber << "\n"); + return Clobber; +} + +MemoryAccess * +MemorySSA::CachingWalker::getClobberingMemoryAccess(const Instruction *I) { + // There should be no way to lookup an instruction and get a phi as the + // access, since we only map BB's to PHI's. So, this must be a use or def. + auto *StartingAccess = cast<MemoryUseOrDef>(MSSA->getMemoryAccess(I)); + + bool IsCall = bool(ImmutableCallSite(I)); + + // We can't sanely do anything with a fences, they conservatively + // clobber all memory, and have no locations to get pointers from to + // try to disambiguate. + if (!IsCall && I->isFenceLike()) + return StartingAccess; + + UpwardsMemoryQuery Q; + Q.OriginalAccess = StartingAccess; + Q.IsCall = IsCall; + if (!Q.IsCall) + Q.StartingLoc = MemoryLocation::get(I); + Q.Inst = I; + if (auto CacheResult = doCacheLookup(StartingAccess, Q, Q.StartingLoc)) + return CacheResult; + + // Start with the thing we already think clobbers this location + MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); + + // At this point, DefiningAccess may be the live on entry def. + // If it is, we will not get a better result. + if (MSSA->isLiveOnEntryDef(DefiningAccess)) + return DefiningAccess; + + MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q); + // DFS won't cache a result for DefiningAccess. So, if DefiningAccess isn't + // our clobber, be sure that it gets a cache entry, too. + if (Result != DefiningAccess) + doCacheInsert(DefiningAccess, Result, Q, Q.StartingLoc); + doCacheInsert(Q.OriginalAccess, Result, Q, Q.StartingLoc); + // TODO: When this implementation is more mature, we may want to figure out + // what this additional caching buys us. It's most likely A Good Thing. + if (Q.IsCall) + for (const MemoryAccess *MA : Q.VisitedCalls) + if (MA != Result) + doCacheInsert(MA, Result, Q, Q.StartingLoc); + + DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); + DEBUG(dbgs() << *DefiningAccess << "\n"); + DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); + DEBUG(dbgs() << *Result << "\n"); + + return Result; +} + +// Verify that MA doesn't exist in any of the caches. 
+void MemorySSA::CachingWalker::verifyRemoved(MemoryAccess *MA) { +#ifndef NDEBUG + for (auto &P : CachedUpwardsClobberingAccess) + assert(P.first.first != MA && P.second != MA && + "Found removed MemoryAccess in cache."); + for (auto &P : CachedUpwardsClobberingCall) + assert(P.first != MA && P.second != MA && + "Found removed MemoryAccess in cache."); +#endif // !NDEBUG +} + +MemoryAccess * +DoNothingMemorySSAWalker::getClobberingMemoryAccess(const Instruction *I) { + MemoryAccess *MA = MSSA->getMemoryAccess(I); + if (auto *Use = dyn_cast<MemoryUseOrDef>(MA)) + return Use->getDefiningAccess(); + return MA; +} + +MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess( + MemoryAccess *StartingAccess, MemoryLocation &) { + if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess)) + return Use->getDefiningAccess(); + return StartingAccess; +} +} diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index 9ec28a3f3d47..eb9188518624 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/ModuleUtils.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -21,8 +20,8 @@ using namespace llvm; -static void appendToGlobalArray(const char *Array, - Module &M, Function *F, int Priority) { +static void appendToGlobalArray(const char *Array, Module &M, Function *F, + int Priority, Constant *Data) { IRBuilder<> IRB(M.getContext()); FunctionType *FnTy = FunctionType::get(IRB.getVoidTy(), false); @@ -31,15 +30,26 @@ static void appendToGlobalArray(const char *Array, SmallVector<Constant *, 16> CurrentCtors; StructType *EltTy; if (GlobalVariable *GVCtor = M.getNamedGlobal(Array)) { - // If there is a global_ctors array, use the existing struct type, which can - // have 2 or 3 fields. - ArrayType *ATy = cast<ArrayType>(GVCtor->getType()->getElementType()); - EltTy = cast<StructType>(ATy->getElementType()); + ArrayType *ATy = cast<ArrayType>(GVCtor->getValueType()); + StructType *OldEltTy = cast<StructType>(ATy->getElementType()); + // Upgrade a 2-field global array type to the new 3-field format if needed. + if (Data && OldEltTy->getNumElements() < 3) + EltTy = StructType::get(IRB.getInt32Ty(), PointerType::getUnqual(FnTy), + IRB.getInt8PtrTy(), nullptr); + else + EltTy = OldEltTy; if (Constant *Init = GVCtor->getInitializer()) { unsigned n = Init->getNumOperands(); CurrentCtors.reserve(n + 1); - for (unsigned i = 0; i != n; ++i) - CurrentCtors.push_back(cast<Constant>(Init->getOperand(i))); + for (unsigned i = 0; i != n; ++i) { + auto Ctor = cast<Constant>(Init->getOperand(i)); + if (EltTy != OldEltTy) + Ctor = ConstantStruct::get( + EltTy, Ctor->getAggregateElement((unsigned)0), + Ctor->getAggregateElement(1), + Constant::getNullValue(IRB.getInt8PtrTy()), nullptr); + CurrentCtors.push_back(Ctor); + } } GVCtor->eraseFromParent(); } else { @@ -54,7 +64,8 @@ static void appendToGlobalArray(const char *Array, CSVals[1] = F; // FIXME: Drop support for the two element form in LLVM 4.0. if (EltTy->getNumElements() >= 3) - CSVals[2] = llvm::Constant::getNullValue(IRB.getInt8PtrTy()); + CSVals[2] = Data ? 
ConstantExpr::getPointerCast(Data, IRB.getInt8PtrTy()) + : Constant::getNullValue(IRB.getInt8PtrTy()); Constant *RuntimeCtorInit = ConstantStruct::get(EltTy, makeArrayRef(CSVals, EltTy->getNumElements())); @@ -70,29 +81,12 @@ static void appendToGlobalArray(const char *Array, GlobalValue::AppendingLinkage, NewInit, Array); } -void llvm::appendToGlobalCtors(Module &M, Function *F, int Priority) { - appendToGlobalArray("llvm.global_ctors", M, F, Priority); +void llvm::appendToGlobalCtors(Module &M, Function *F, int Priority, Constant *Data) { + appendToGlobalArray("llvm.global_ctors", M, F, Priority, Data); } -void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority) { - appendToGlobalArray("llvm.global_dtors", M, F, Priority); -} - -GlobalVariable * -llvm::collectUsedGlobalVariables(Module &M, SmallPtrSetImpl<GlobalValue *> &Set, - bool CompilerUsed) { - const char *Name = CompilerUsed ? "llvm.compiler.used" : "llvm.used"; - GlobalVariable *GV = M.getGlobalVariable(Name); - if (!GV || !GV->hasInitializer()) - return GV; - - const ConstantArray *Init = cast<ConstantArray>(GV->getInitializer()); - for (unsigned I = 0, E = Init->getNumOperands(); I != E; ++I) { - Value *Op = Init->getOperand(I); - GlobalValue *G = cast<GlobalValue>(Op->stripPointerCastsNoFollowAliases()); - Set.insert(G); - } - return GV; +void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *Data) { + appendToGlobalArray("llvm.global_dtors", M, F, Priority, Data); } Function *llvm::checkSanitizerInterfaceFunction(Constant *FuncOrBitcast) { @@ -132,4 +126,3 @@ std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions( } return std::make_pair(Ctor, InitFunction); } - diff --git a/lib/Transforms/Utils/NameAnonFunctions.cpp b/lib/Transforms/Utils/NameAnonFunctions.cpp new file mode 100644 index 000000000000..c4f3839d8482 --- /dev/null +++ b/lib/Transforms/Utils/NameAnonFunctions.cpp @@ -0,0 +1,102 @@ +//===- NameAnonFunctions.cpp - ThinLTO Summary-based Function Import ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements naming anonymous function to make sure they can be +// refered to by ThinLTO. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallString.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/MD5.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +// Compute a "unique" hash for the module based on the name of the public +// functions. +class ModuleHasher { + Module &TheModule; + std::string TheHash; + +public: + ModuleHasher(Module &M) : TheModule(M) {} + + /// Return the lazily computed hash. + std::string &get() { + if (!TheHash.empty()) + // Cache hit :) + return TheHash; + + MD5 Hasher; + for (auto &F : TheModule) { + if (F.isDeclaration() || F.hasLocalLinkage() || !F.hasName()) + continue; + auto Name = F.getName(); + Hasher.update(Name); + } + for (auto &GV : TheModule.globals()) { + if (GV.isDeclaration() || GV.hasLocalLinkage() || !GV.hasName()) + continue; + auto Name = GV.getName(); + Hasher.update(Name); + } + + // Now return the result. 
+ MD5::MD5Result Hash; + Hasher.final(Hash); + SmallString<32> Result; + MD5::stringifyResult(Hash, Result); + TheHash = Result.str(); + return TheHash; + } +}; + +// Rename all the anon functions in the module +bool llvm::nameUnamedFunctions(Module &M) { + bool Changed = false; + ModuleHasher ModuleHash(M); + int count = 0; + for (auto &F : M) { + if (F.hasName()) + continue; + F.setName(Twine("anon.") + ModuleHash.get() + "." + Twine(count++)); + Changed = true; + } + return Changed; +} + +namespace { + +// Simple pass that provides a name to every anon function. +class NameAnonFunction : public ModulePass { + +public: + /// Pass identification, replacement for typeid + static char ID; + + /// Specify pass name for debug output + const char *getPassName() const override { return "Name Anon Functions"; } + + explicit NameAnonFunction() : ModulePass(ID) {} + + bool runOnModule(Module &M) override { return nameUnamedFunctions(M); } +}; +char NameAnonFunction::ID = 0; + +} // anonymous namespace + +INITIALIZE_PASS_BEGIN(NameAnonFunction, "name-anon-functions", + "Provide a name to nameless functions", false, false) +INITIALIZE_PASS_END(NameAnonFunction, "name-anon-functions", + "Provide a name to nameless functions", false, false) + +namespace llvm { +ModulePass *createNameAnonFunctionPass() { return new NameAnonFunction(); } +} diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index c4f9b9f61407..cbf385d56339 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -523,7 +523,7 @@ void PromoteMem2Reg::run() { AllocaInfo Info; LargeBlockInfo LBI; - IDFCalculator IDF(DT); + ForwardIDFCalculator IDF(DT); for (unsigned AllocaNum = 0; AllocaNum != Allocas.size(); ++AllocaNum) { AllocaInst *AI = Allocas[AllocaNum]; @@ -802,7 +802,8 @@ void PromoteMem2Reg::ComputeLiveInBlocks( // actually live-in here. LiveInBlockWorklist[i] = LiveInBlockWorklist.back(); LiveInBlockWorklist.pop_back(); - --i, --e; + --i; + --e; break; } diff --git a/lib/Transforms/Utils/SanitizerStats.cpp b/lib/Transforms/Utils/SanitizerStats.cpp new file mode 100644 index 000000000000..9afd175c10ed --- /dev/null +++ b/lib/Transforms/Utils/SanitizerStats.cpp @@ -0,0 +1,108 @@ +//===- SanitizerStats.cpp - Sanitizer statistics gathering ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Implements code generation for sanitizer statistics gathering. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/SanitizerStats.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" +#include "llvm/ADT/Triple.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" + +using namespace llvm; + +SanitizerStatReport::SanitizerStatReport(Module *M) : M(M) { + StatTy = ArrayType::get(Type::getInt8PtrTy(M->getContext()), 2); + EmptyModuleStatsTy = makeModuleStatsTy(); + + ModuleStatsGV = new GlobalVariable(*M, EmptyModuleStatsTy, false, + GlobalValue::InternalLinkage, nullptr); +} + +ArrayType *SanitizerStatReport::makeModuleStatsArrayTy() { + return ArrayType::get(StatTy, Inits.size()); +} + +StructType *SanitizerStatReport::makeModuleStatsTy() { + return StructType::get(M->getContext(), {Type::getInt8PtrTy(M->getContext()), + Type::getInt32Ty(M->getContext()), + makeModuleStatsArrayTy()}); +} + +void SanitizerStatReport::create(IRBuilder<> &B, SanitizerStatKind SK) { + Function *F = B.GetInsertBlock()->getParent(); + Module *M = F->getParent(); + PointerType *Int8PtrTy = B.getInt8PtrTy(); + IntegerType *IntPtrTy = B.getIntPtrTy(M->getDataLayout()); + ArrayType *StatTy = ArrayType::get(Int8PtrTy, 2); + + Inits.push_back(ConstantArray::get( + StatTy, + {Constant::getNullValue(Int8PtrTy), + ConstantExpr::getIntToPtr( + ConstantInt::get(IntPtrTy, uint64_t(SK) << (IntPtrTy->getBitWidth() - + kSanitizerStatKindBits)), + Int8PtrTy)})); + + FunctionType *StatReportTy = + FunctionType::get(B.getVoidTy(), Int8PtrTy, false); + Constant *StatReport = M->getOrInsertFunction( + "__sanitizer_stat_report", StatReportTy); + + auto InitAddr = ConstantExpr::getGetElementPtr( + EmptyModuleStatsTy, ModuleStatsGV, + ArrayRef<Constant *>{ + ConstantInt::get(IntPtrTy, 0), ConstantInt::get(B.getInt32Ty(), 2), + ConstantInt::get(IntPtrTy, Inits.size() - 1), + }); + B.CreateCall(StatReport, ConstantExpr::getBitCast(InitAddr, Int8PtrTy)); +} + +void SanitizerStatReport::finish() { + if (Inits.empty()) { + ModuleStatsGV->eraseFromParent(); + return; + } + + PointerType *Int8PtrTy = Type::getInt8PtrTy(M->getContext()); + IntegerType *Int32Ty = Type::getInt32Ty(M->getContext()); + Type *VoidTy = Type::getVoidTy(M->getContext()); + + // Create a new ModuleStatsGV to replace the old one. We can't just set the + // old one's initializer because its type is different. + auto NewModuleStatsGV = new GlobalVariable( + *M, makeModuleStatsTy(), false, GlobalValue::InternalLinkage, + ConstantStruct::getAnon( + {Constant::getNullValue(Int8PtrTy), + ConstantInt::get(Int32Ty, Inits.size()), + ConstantArray::get(makeModuleStatsArrayTy(), Inits)})); + ModuleStatsGV->replaceAllUsesWith( + ConstantExpr::getBitCast(NewModuleStatsGV, ModuleStatsGV->getType())); + ModuleStatsGV->eraseFromParent(); + + // Create a global constructor to register NewModuleStatsGV. 
+ auto F = Function::Create(FunctionType::get(VoidTy, false), + GlobalValue::InternalLinkage, "", M); + auto BB = BasicBlock::Create(M->getContext(), "", F); + IRBuilder<> B(BB); + + FunctionType *StatInitTy = FunctionType::get(VoidTy, Int8PtrTy, false); + Constant *StatInit = M->getOrInsertFunction( + "__sanitizer_stat_init", StatInitTy); + + B.CreateCall(StatInit, ConstantExpr::getBitCast(NewModuleStatsGV, Int8PtrTy)); + B.CreateRetVoid(); + + appendToGlobalCtors(*M, F, 0); +} diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index e484b690597e..0504646c304e 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -11,7 +11,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/Local.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetOperations.h" @@ -45,6 +44,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <map> @@ -58,17 +58,18 @@ using namespace PatternMatch; // a select, so the "clamp" idiom (of a min followed by a max) will be caught. // To catch this, we need to fold a compare and a select, hence '2' being the // minimum reasonable default. -static cl::opt<unsigned> -PHINodeFoldingThreshold("phi-node-folding-threshold", cl::Hidden, cl::init(2), - cl::desc("Control the amount of phi node folding to perform (default = 2)")); +static cl::opt<unsigned> PHINodeFoldingThreshold( + "phi-node-folding-threshold", cl::Hidden, cl::init(2), + cl::desc( + "Control the amount of phi node folding to perform (default = 2)")); -static cl::opt<bool> -DupRet("simplifycfg-dup-ret", cl::Hidden, cl::init(false), - cl::desc("Duplicate return instructions into unconditional branches")); +static cl::opt<bool> DupRet( + "simplifycfg-dup-ret", cl::Hidden, cl::init(false), + cl::desc("Duplicate return instructions into unconditional branches")); static cl::opt<bool> -SinkCommon("simplifycfg-sink-common", cl::Hidden, cl::init(true), - cl::desc("Sink common instructions down to the end block")); + SinkCommon("simplifycfg-sink-common", cl::Hidden, cl::init(true), + cl::desc("Sink common instructions down to the end block")); static cl::opt<bool> HoistCondStores( "simplifycfg-hoist-cond-stores", cl::Hidden, cl::init(true), @@ -96,48 +97,54 @@ static cl::opt<unsigned> MaxSpeculationDepth( "speculatively executed instructions")); STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps"); -STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping"); -STATISTIC(NumLookupTables, "Number of switch instructions turned into lookup tables"); -STATISTIC(NumLookupTablesHoles, "Number of switch instructions turned into lookup tables (holes checked)"); +STATISTIC(NumLinearMaps, + "Number of switch instructions turned into linear mapping"); +STATISTIC(NumLookupTables, + "Number of switch instructions turned into lookup tables"); +STATISTIC( + NumLookupTablesHoles, + "Number of switch instructions turned into lookup tables (holes checked)"); STATISTIC(NumTableCmpReuses, "Number of reused switch table lookup compares"); -STATISTIC(NumSinkCommons, "Number of common instructions sunk down to the end block"); +STATISTIC(NumSinkCommons, + "Number of common instructions sunk down to the end block"); 
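For context on the extended ModuleUtils API that the SanitizerStats constructor registration above relies on, here is a minimal sketch of registering a module constructor together with an associated data pointer. Only the four-argument appendToGlobalCtors(M, F, Priority, Data) overload comes from this patch; the helper name, the constructor label, and the rest of the scaffolding are hypothetical.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace llvm;

// Hypothetical helper: create an empty internal constructor and register it
// with priority 0. Data becomes the third field of the llvm.global_ctors
// entry; a pre-existing two-field array is upgraded to the three-field form.
static void registerCtorWithData(Module &M, GlobalVariable *Data) {
  LLVMContext &Ctx = M.getContext();
  FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Ctx), false);
  Function *Ctor = Function::Create(FnTy, GlobalValue::InternalLinkage,
                                    "my.module.ctor", &M);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", Ctor));
  B.CreateRetVoid();
  appendToGlobalCtors(M, Ctor, /*Priority=*/0, Data);
}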
STATISTIC(NumSpeculations, "Number of speculative executed instructions"); namespace { - // The first field contains the value that the switch produces when a certain - // case group is selected, and the second field is a vector containing the - // cases composing the case group. - typedef SmallVector<std::pair<Constant *, SmallVector<ConstantInt *, 4>>, 2> +// The first field contains the value that the switch produces when a certain +// case group is selected, and the second field is a vector containing the +// cases composing the case group. +typedef SmallVector<std::pair<Constant *, SmallVector<ConstantInt *, 4>>, 2> SwitchCaseResultVectorTy; - // The first field contains the phi node that generates a result of the switch - // and the second field contains the value generated for a certain case in the - // switch for that PHI. - typedef SmallVector<std::pair<PHINode *, Constant *>, 4> SwitchCaseResultsTy; +// The first field contains the phi node that generates a result of the switch +// and the second field contains the value generated for a certain case in the +// switch for that PHI. +typedef SmallVector<std::pair<PHINode *, Constant *>, 4> SwitchCaseResultsTy; - /// ValueEqualityComparisonCase - Represents a case of a switch. - struct ValueEqualityComparisonCase { - ConstantInt *Value; - BasicBlock *Dest; +/// ValueEqualityComparisonCase - Represents a case of a switch. +struct ValueEqualityComparisonCase { + ConstantInt *Value; + BasicBlock *Dest; - ValueEqualityComparisonCase(ConstantInt *Value, BasicBlock *Dest) + ValueEqualityComparisonCase(ConstantInt *Value, BasicBlock *Dest) : Value(Value), Dest(Dest) {} - bool operator<(ValueEqualityComparisonCase RHS) const { - // Comparing pointers is ok as we only rely on the order for uniquing. - return Value < RHS.Value; - } + bool operator<(ValueEqualityComparisonCase RHS) const { + // Comparing pointers is ok as we only rely on the order for uniquing. 
+ return Value < RHS.Value; + } - bool operator==(BasicBlock *RHSDest) const { return Dest == RHSDest; } - }; + bool operator==(BasicBlock *RHSDest) const { return Dest == RHSDest; } +}; class SimplifyCFGOpt { const TargetTransformInfo &TTI; const DataLayout &DL; unsigned BonusInstThreshold; AssumptionCache *AC; + SmallPtrSetImpl<BasicBlock *> *LoopHeaders; Value *isValueEqualityComparison(TerminatorInst *TI); - BasicBlock *GetValueEqualityComparisonCases(TerminatorInst *TI, - std::vector<ValueEqualityComparisonCase> &Cases); + BasicBlock *GetValueEqualityComparisonCases( + TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); bool SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, BasicBlock *Pred, IRBuilder<> &Builder); @@ -152,13 +159,15 @@ class SimplifyCFGOpt { bool SimplifyUnreachable(UnreachableInst *UI); bool SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder); bool SimplifyIndirectBr(IndirectBrInst *IBI); - bool SimplifyUncondBranch(BranchInst *BI, IRBuilder <> &Builder); - bool SimplifyCondBranch(BranchInst *BI, IRBuilder <>&Builder); + bool SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder); + bool SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder); public: SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout &DL, - unsigned BonusInstThreshold, AssumptionCache *AC) - : TTI(TTI), DL(DL), BonusInstThreshold(BonusInstThreshold), AC(AC) {} + unsigned BonusInstThreshold, AssumptionCache *AC, + SmallPtrSetImpl<BasicBlock *> *LoopHeaders) + : TTI(TTI), DL(DL), BonusInstThreshold(BonusInstThreshold), AC(AC), + LoopHeaders(LoopHeaders) {} bool run(BasicBlock *BB); }; } @@ -166,19 +175,19 @@ public: /// Return true if it is safe to merge these two /// terminator instructions together. static bool SafeToMergeTerminators(TerminatorInst *SI1, TerminatorInst *SI2) { - if (SI1 == SI2) return false; // Can't merge with self! + if (SI1 == SI2) + return false; // Can't merge with self! // It is not safe to merge these two switch instructions if they have a common // successor, and if that successor has a PHI node, and if *that* PHI node has // conflicting incoming values from the two switch blocks. BasicBlock *SI1BB = SI1->getParent(); BasicBlock *SI2BB = SI2->getParent(); - SmallPtrSet<BasicBlock*, 16> SI1Succs(succ_begin(SI1BB), succ_end(SI1BB)); + SmallPtrSet<BasicBlock *, 16> SI1Succs(succ_begin(SI1BB), succ_end(SI1BB)); - for (succ_iterator I = succ_begin(SI2BB), E = succ_end(SI2BB); I != E; ++I) - if (SI1Succs.count(*I)) - for (BasicBlock::iterator BBI = (*I)->begin(); - isa<PHINode>(BBI); ++BBI) { + for (BasicBlock *Succ : successors(SI2BB)) + if (SI1Succs.count(Succ)) + for (BasicBlock::iterator BBI = Succ->begin(); isa<PHINode>(BBI); ++BBI) { PHINode *PN = cast<PHINode>(BBI); if (PN->getIncomingValueForBlock(SI1BB) != PN->getIncomingValueForBlock(SI2BB)) @@ -191,11 +200,12 @@ static bool SafeToMergeTerminators(TerminatorInst *SI1, TerminatorInst *SI2) { /// Return true if it is safe and profitable to merge these two terminator /// instructions together, where SI1 is an unconditional branch. PhiNodes will /// store all PHI nodes in common successors. -static bool isProfitableToFoldUnconditional(BranchInst *SI1, - BranchInst *SI2, - Instruction *Cond, - SmallVectorImpl<PHINode*> &PhiNodes) { - if (SI1 == SI2) return false; // Can't merge with self! 
+static bool +isProfitableToFoldUnconditional(BranchInst *SI1, BranchInst *SI2, + Instruction *Cond, + SmallVectorImpl<PHINode *> &PhiNodes) { + if (SI1 == SI2) + return false; // Can't merge with self! assert(SI1->isUnconditional() && SI2->isConditional()); // We fold the unconditional branch if we can easily update all PHI nodes in @@ -204,7 +214,8 @@ static bool isProfitableToFoldUnconditional(BranchInst *SI1, // 2> We have "Cond" as the incoming value for the unconditional branch; // 3> SI2->getCondition() and Cond have same operands. CmpInst *Ci2 = dyn_cast<CmpInst>(SI2->getCondition()); - if (!Ci2) return false; + if (!Ci2) + return false; if (!(Cond->getOperand(0) == Ci2->getOperand(0) && Cond->getOperand(1) == Ci2->getOperand(1)) && !(Cond->getOperand(0) == Ci2->getOperand(1) && @@ -213,11 +224,10 @@ static bool isProfitableToFoldUnconditional(BranchInst *SI1, BasicBlock *SI1BB = SI1->getParent(); BasicBlock *SI2BB = SI2->getParent(); - SmallPtrSet<BasicBlock*, 16> SI1Succs(succ_begin(SI1BB), succ_end(SI1BB)); - for (succ_iterator I = succ_begin(SI2BB), E = succ_end(SI2BB); I != E; ++I) - if (SI1Succs.count(*I)) - for (BasicBlock::iterator BBI = (*I)->begin(); - isa<PHINode>(BBI); ++BBI) { + SmallPtrSet<BasicBlock *, 16> SI1Succs(succ_begin(SI1BB), succ_end(SI1BB)); + for (BasicBlock *Succ : successors(SI2BB)) + if (SI1Succs.count(Succ)) + for (BasicBlock::iterator BBI = Succ->begin(); isa<PHINode>(BBI); ++BBI) { PHINode *PN = cast<PHINode>(BBI); if (PN->getIncomingValueForBlock(SI1BB) != Cond || !isa<ConstantInt>(PN->getIncomingValueForBlock(SI2BB))) @@ -233,11 +243,11 @@ static bool isProfitableToFoldUnconditional(BranchInst *SI1, /// of Succ. static void AddPredecessorToBlock(BasicBlock *Succ, BasicBlock *NewPred, BasicBlock *ExistPred) { - if (!isa<PHINode>(Succ->begin())) return; // Quick exit if nothing to do + if (!isa<PHINode>(Succ->begin())) + return; // Quick exit if nothing to do PHINode *PN; - for (BasicBlock::iterator I = Succ->begin(); - (PN = dyn_cast<PHINode>(I)); ++I) + for (BasicBlock::iterator I = Succ->begin(); (PN = dyn_cast<PHINode>(I)); ++I) PN->addIncoming(PN->getIncomingValueForBlock(ExistPred), NewPred); } @@ -270,7 +280,7 @@ static unsigned ComputeSpeculationCost(const User *I, /// V plus its non-dominating operands. If that cost is greater than /// CostRemaining, false is returned and CostRemaining is undefined. static bool DominatesMergePoint(Value *V, BasicBlock *BB, - SmallPtrSetImpl<Instruction*> *AggressiveInsts, + SmallPtrSetImpl<Instruction *> *AggressiveInsts, unsigned &CostRemaining, const TargetTransformInfo &TTI, unsigned Depth = 0) { @@ -294,7 +304,8 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB, // We don't want to allow weird loops that might have the "if condition" in // the bottom of this block. - if (PBB == BB) return false; + if (PBB == BB) + return false; // If this instruction is defined in a block that contains an unconditional // branch to BB, then it must be in the 'conditional' part of the "if @@ -305,10 +316,12 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB, // If we aren't allowing aggressive promotion anymore, then don't consider // instructions in the 'if region'. - if (!AggressiveInsts) return false; + if (!AggressiveInsts) + return false; // If we have seen this instruction before, don't count it again. - if (AggressiveInsts->count(I)) return true; + if (AggressiveInsts->count(I)) + return true; // Okay, it looks like the instruction IS in the "condition". 
Check to // see if it's a cheap instruction to unconditionally compute, and if it @@ -366,8 +379,8 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout &DL) { if (CI->getType() == PtrTy) return CI; else - return cast<ConstantInt> - (ConstantExpr::getIntegerCast(CI, PtrTy, /*isSigned=*/false)); + return cast<ConstantInt>( + ConstantExpr::getIntegerCast(CI, PtrTy, /*isSigned=*/false)); } return nullptr; } @@ -403,11 +416,11 @@ struct ConstantComparesGatherer { operator=(const ConstantComparesGatherer &) = delete; private: - /// Try to set the current value used for the comparison, it succeeds only if /// it wasn't set before or if the new value is the same as the old one bool setValueOnce(Value *NewVal) { - if(CompValue && CompValue != NewVal) return false; + if (CompValue && CompValue != NewVal) + return false; CompValue = NewVal; return (CompValue != nullptr); } @@ -424,35 +437,99 @@ private: ICmpInst *ICI; ConstantInt *C; if (!((ICI = dyn_cast<ICmpInst>(I)) && - (C = GetConstantInt(I->getOperand(1), DL)))) { + (C = GetConstantInt(I->getOperand(1), DL)))) { return false; } Value *RHSVal; - ConstantInt *RHSC; + const APInt *RHSC; // Pattern match a special case - // (x & ~2^x) == y --> x == y || x == y|2^x + // (x & ~2^z) == y --> x == y || x == y|2^z // This undoes a transformation done by instcombine to fuse 2 compares. - if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) { + if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE)) { + + // It's a little bit hard to see why the following transformations are + // correct. Here is a CVC3 program to verify them for 64-bit values: + + /* + ONE : BITVECTOR(64) = BVZEROEXTEND(0bin1, 63); + x : BITVECTOR(64); + y : BITVECTOR(64); + z : BITVECTOR(64); + mask : BITVECTOR(64) = BVSHL(ONE, z); + QUERY( (y & ~mask = y) => + ((x & ~mask = y) <=> (x = y OR x = (y | mask))) + ); + QUERY( (y | mask = y) => + ((x | mask = y) <=> (x = y OR x = (y & ~mask))) + ); + */ + + // Please note that each pattern must be a dual implication (<--> or + // iff). One directional implication can create spurious matches. If the + // implication is only one-way, an unsatisfiable condition on the left + // side can imply a satisfiable condition on the right side. Dual + // implication ensures that satisfiable conditions are transformed to + // other satisfiable conditions and unsatisfiable conditions are + // transformed to other unsatisfiable conditions. + + // Here is a concrete example of a unsatisfiable condition on the left + // implying a satisfiable condition on the right: + // + // mask = (1 << z) + // (x & ~mask) == y --> (x == y || x == (y | mask)) + // + // Substituting y = 3, z = 0 yields: + // (x & -2) == 3 --> (x == 3 || x == 2) + + // Pattern match a special case: + /* + QUERY( (y & ~mask = y) => + ((x & ~mask = y) <=> (x = y OR x = (y | mask))) + ); + */ if (match(ICI->getOperand(0), - m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) { - APInt Not = ~RHSC->getValue(); - if (Not.isPowerOf2()) { + m_And(m_Value(RHSVal), m_APInt(RHSC)))) { + APInt Mask = ~*RHSC; + if (Mask.isPowerOf2() && (C->getValue() & ~Mask) == C->getValue()) { // If we already have a value for the switch, it has to match! 
- if(!setValueOnce(RHSVal)) + if (!setValueOnce(RHSVal)) + return false; + + Vals.push_back(C); + Vals.push_back( + ConstantInt::get(C->getContext(), + C->getValue() | Mask)); + UsedICmps++; + return true; + } + } + + // Pattern match a special case: + /* + QUERY( (y | mask = y) => + ((x | mask = y) <=> (x = y OR x = (y & ~mask))) + ); + */ + if (match(ICI->getOperand(0), + m_Or(m_Value(RHSVal), m_APInt(RHSC)))) { + APInt Mask = *RHSC; + if (Mask.isPowerOf2() && (C->getValue() | Mask) == C->getValue()) { + // If we already have a value for the switch, it has to match! + if (!setValueOnce(RHSVal)) return false; Vals.push_back(C); Vals.push_back(ConstantInt::get(C->getContext(), - C->getValue() | Not)); + C->getValue() & ~Mask)); UsedICmps++; return true; } } // If we already have a value for the switch, it has to match! - if(!setValueOnce(ICI->getOperand(0))) + if (!setValueOnce(ICI->getOperand(0))) return false; UsedICmps++; @@ -467,8 +544,8 @@ private: // Shift the range if the compare is fed by an add. This is the range // compare idiom as emitted by instcombine. Value *CandidateVal = I->getOperand(0); - if(match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC)))) { - Span = Span.subtract(RHSC->getValue()); + if (match(I->getOperand(0), m_Add(m_Value(RHSVal), m_APInt(RHSC)))) { + Span = Span.subtract(*RHSC); CandidateVal = RHSVal; } @@ -484,7 +561,7 @@ private: } // If we already have a value for the switch, it has to match! - if(!setValueOnce(CandidateVal)) + if (!setValueOnce(CandidateVal)) return false; // Add all values from the range to the set @@ -493,7 +570,6 @@ private: UsedICmps++; return true; - } /// Given a potentially 'or'd or 'and'd together collection of icmp @@ -507,18 +583,22 @@ private: // Keep a stack (SmallVector for efficiency) for depth-first traversal SmallVector<Value *, 8> DFT; + SmallPtrSet<Value *, 8> Visited; // Initialize + Visited.insert(V); DFT.push_back(V); - while(!DFT.empty()) { + while (!DFT.empty()) { V = DFT.pop_back_val(); if (Instruction *I = dyn_cast<Instruction>(V)) { // If it is a || (or && depending on isEQ), process the operands. if (I->getOpcode() == (isEQ ? Instruction::Or : Instruction::And)) { - DFT.push_back(I->getOperand(1)); - DFT.push_back(I->getOperand(0)); + if (Visited.insert(I->getOperand(1)).second) + DFT.push_back(I->getOperand(1)); + if (Visited.insert(I->getOperand(0)).second) + DFT.push_back(I->getOperand(0)); continue; } @@ -541,7 +621,6 @@ private: } } }; - } static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) { @@ -556,7 +635,8 @@ static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) { } TI->eraseFromParent(); - if (Cond) RecursivelyDeleteTriviallyDeadInstructions(Cond); + if (Cond) + RecursivelyDeleteTriviallyDeadInstructions(Cond); } /// Return true if the specified terminator checks @@ -566,8 +646,9 @@ Value *SimplifyCFGOpt::isValueEqualityComparison(TerminatorInst *TI) { if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { // Do not permit merging of large switch instructions into their // predecessors unless there is only one predecessor. 
- if (SI->getNumSuccessors()*std::distance(pred_begin(SI->getParent()), - pred_end(SI->getParent())) <= 128) + if (SI->getNumSuccessors() * std::distance(pred_begin(SI->getParent()), + pred_end(SI->getParent())) <= + 128) CV = SI->getCondition(); } else if (BranchInst *BI = dyn_cast<BranchInst>(TI)) if (BI->isConditional() && BI->getCondition()->hasOneUse()) @@ -589,46 +670,44 @@ Value *SimplifyCFGOpt::isValueEqualityComparison(TerminatorInst *TI) { /// Given a value comparison instruction, /// decode all of the 'cases' that it represents and return the 'default' block. -BasicBlock *SimplifyCFGOpt:: -GetValueEqualityComparisonCases(TerminatorInst *TI, - std::vector<ValueEqualityComparisonCase> - &Cases) { +BasicBlock *SimplifyCFGOpt::GetValueEqualityComparisonCases( + TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases) { if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { Cases.reserve(SI->getNumCases()); - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) - Cases.push_back(ValueEqualityComparisonCase(i.getCaseValue(), - i.getCaseSuccessor())); + for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; + ++i) + Cases.push_back( + ValueEqualityComparisonCase(i.getCaseValue(), i.getCaseSuccessor())); return SI->getDefaultDest(); } BranchInst *BI = cast<BranchInst>(TI); ICmpInst *ICI = cast<ICmpInst>(BI->getCondition()); BasicBlock *Succ = BI->getSuccessor(ICI->getPredicate() == ICmpInst::ICMP_NE); - Cases.push_back(ValueEqualityComparisonCase(GetConstantInt(ICI->getOperand(1), - DL), - Succ)); + Cases.push_back(ValueEqualityComparisonCase( + GetConstantInt(ICI->getOperand(1), DL), Succ)); return BI->getSuccessor(ICI->getPredicate() == ICmpInst::ICMP_EQ); } - /// Given a vector of bb/value pairs, remove any entries /// in the list that match the specified block. -static void EliminateBlockCases(BasicBlock *BB, - std::vector<ValueEqualityComparisonCase> &Cases) { +static void +EliminateBlockCases(BasicBlock *BB, + std::vector<ValueEqualityComparisonCase> &Cases) { Cases.erase(std::remove(Cases.begin(), Cases.end(), BB), Cases.end()); } /// Return true if there are any keys in C1 that exist in C2 as well. -static bool -ValuesOverlap(std::vector<ValueEqualityComparisonCase> &C1, - std::vector<ValueEqualityComparisonCase > &C2) { +static bool ValuesOverlap(std::vector<ValueEqualityComparisonCase> &C1, + std::vector<ValueEqualityComparisonCase> &C2) { std::vector<ValueEqualityComparisonCase> *V1 = &C1, *V2 = &C2; // Make V1 be smaller than V2. if (V1->size() > V2->size()) std::swap(V1, V2); - if (V1->size() == 0) return false; + if (V1->size() == 0) + return false; if (V1->size() == 1) { // Just scan V2. ConstantInt *TheVal = (*V1)[0].Value; @@ -657,30 +736,30 @@ ValuesOverlap(std::vector<ValueEqualityComparisonCase> &C1, /// also a value comparison with the same value, and if that comparison /// determines the outcome of this comparison. If so, simplify TI. This does a /// very limited form of jump threading. -bool SimplifyCFGOpt:: -SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, - BasicBlock *Pred, - IRBuilder<> &Builder) { +bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( + TerminatorInst *TI, BasicBlock *Pred, IRBuilder<> &Builder) { Value *PredVal = isValueEqualityComparison(Pred->getTerminator()); - if (!PredVal) return false; // Not a value comparison in predecessor. + if (!PredVal) + return false; // Not a value comparison in predecessor. 
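As a side note on the case decoding used throughout this file, a standalone sketch (hypothetical name) of how a conditional branch on an equality compare is viewed as a one-case switch. It assumes the compare's second operand is already a ConstantInt, whereas the real helper goes through GetConstantInt and also handles SwitchInst.

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"

// Sketch only: fill in the single case and return the "default" successor.
static const llvm::BasicBlock *
decodeEqBranch(const llvm::BranchInst *BI, const llvm::ConstantInt *&CaseVal,
               const llvm::BasicBlock *&CaseDest) {
  const auto *ICI = llvm::cast<llvm::ICmpInst>(BI->getCondition());
  bool IsNE = ICI->getPredicate() == llvm::ICmpInst::ICMP_NE;
  CaseVal = llvm::cast<llvm::ConstantInt>(ICI->getOperand(1)); // assumed constant
  CaseDest = BI->getSuccessor(IsNE ? 1 : 0); // edge taken when the values are equal
  return BI->getSuccessor(IsNE ? 0 : 1);     // not-equal edge acts as the default
}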
Value *ThisVal = isValueEqualityComparison(TI); assert(ThisVal && "This isn't a value comparison!!"); - if (ThisVal != PredVal) return false; // Different predicates. + if (ThisVal != PredVal) + return false; // Different predicates. // TODO: Preserve branch weight metadata, similarly to how // FoldValueComparisonIntoPredecessors preserves it. // Find out information about when control will move from Pred to TI's block. std::vector<ValueEqualityComparisonCase> PredCases; - BasicBlock *PredDef = GetValueEqualityComparisonCases(Pred->getTerminator(), - PredCases); - EliminateBlockCases(PredDef, PredCases); // Remove default from cases. + BasicBlock *PredDef = + GetValueEqualityComparisonCases(Pred->getTerminator(), PredCases); + EliminateBlockCases(PredDef, PredCases); // Remove default from cases. // Find information about how control leaves this block. std::vector<ValueEqualityComparisonCase> ThisCases; BasicBlock *ThisDef = GetValueEqualityComparisonCases(TI, ThisCases); - EliminateBlockCases(ThisDef, ThisCases); // Remove default from cases. + EliminateBlockCases(ThisDef, ThisCases); // Remove default from cases. // If TI's block is the default block from Pred's comparison, potentially // simplify TI based on this knowledge. @@ -697,13 +776,14 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, assert(ThisCases.size() == 1 && "Branch can only have one case!"); // Insert the new branch. Instruction *NI = Builder.CreateBr(ThisDef); - (void) NI; + (void)NI; // Remove PHI node entries for the dead edge. ThisCases[0].Dest->removePredecessor(TI->getParent()); DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() - << "Through successor TI: " << *TI << "Leaving: " << *NI << "\n"); + << "Through successor TI: " << *TI << "Leaving: " << *NI + << "\n"); EraseTerminatorInstAndDCECond(TI); return true; @@ -711,7 +791,7 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, SwitchInst *SI = cast<SwitchInst>(TI); // Okay, TI has cases that are statically dead, prune them away. - SmallPtrSet<Constant*, 16> DeadCases; + SmallPtrSet<Constant *, 16> DeadCases; for (unsigned i = 0, e = PredCases.size(); i != e; ++i) DeadCases.insert(PredCases[i].Value); @@ -732,7 +812,7 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, --i; if (DeadCases.count(i.getCaseValue())) { if (HasWeight) { - std::swap(Weights[i.getCaseIndex()+1], Weights.back()); + std::swap(Weights[i.getCaseIndex() + 1], Weights.back()); Weights.pop_back(); } i.getCaseSuccessor()->removePredecessor(TI->getParent()); @@ -741,8 +821,8 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, } if (HasWeight && Weights.size() >= 2) SI->setMetadata(LLVMContext::MD_prof, - MDBuilder(SI->getParent()->getContext()). - createBranchWeights(Weights)); + MDBuilder(SI->getParent()->getContext()) + .createBranchWeights(Weights)); DEBUG(dbgs() << "Leaving: " << *TI << "\n"); return true; @@ -755,7 +835,7 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, for (unsigned i = 0, e = PredCases.size(); i != e; ++i) if (PredCases[i].Dest == TIBB) { if (TIV) - return false; // Cannot handle multiple values coming to this block. + return false; // Cannot handle multiple values coming to this block. TIV = PredCases[i].Value; } assert(TIV && "No edge from pred to succ?"); @@ -770,53 +850,53 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, } // If not handled by any explicit cases, it is handled by the default case. 
- if (!TheRealDest) TheRealDest = ThisDef; + if (!TheRealDest) + TheRealDest = ThisDef; // Remove PHI node entries for dead edges. BasicBlock *CheckEdge = TheRealDest; - for (succ_iterator SI = succ_begin(TIBB), e = succ_end(TIBB); SI != e; ++SI) - if (*SI != CheckEdge) - (*SI)->removePredecessor(TIBB); + for (BasicBlock *Succ : successors(TIBB)) + if (Succ != CheckEdge) + Succ->removePredecessor(TIBB); else CheckEdge = nullptr; // Insert the new branch. Instruction *NI = Builder.CreateBr(TheRealDest); - (void) NI; + (void)NI; DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() - << "Through successor TI: " << *TI << "Leaving: " << *NI << "\n"); + << "Through successor TI: " << *TI << "Leaving: " << *NI + << "\n"); EraseTerminatorInstAndDCECond(TI); return true; } namespace { - /// This class implements a stable ordering of constant - /// integers that does not depend on their address. This is important for - /// applications that sort ConstantInt's to ensure uniqueness. - struct ConstantIntOrdering { - bool operator()(const ConstantInt *LHS, const ConstantInt *RHS) const { - return LHS->getValue().ult(RHS->getValue()); - } - }; +/// This class implements a stable ordering of constant +/// integers that does not depend on their address. This is important for +/// applications that sort ConstantInt's to ensure uniqueness. +struct ConstantIntOrdering { + bool operator()(const ConstantInt *LHS, const ConstantInt *RHS) const { + return LHS->getValue().ult(RHS->getValue()); + } +}; } static int ConstantIntSortPredicate(ConstantInt *const *P1, ConstantInt *const *P2) { const ConstantInt *LHS = *P1; const ConstantInt *RHS = *P2; - if (LHS->getValue().ult(RHS->getValue())) - return 1; - if (LHS->getValue() == RHS->getValue()) + if (LHS == RHS) return 0; - return -1; + return LHS->getValue().ult(RHS->getValue()) ? 1 : -1; } -static inline bool HasBranchWeights(const Instruction* I) { +static inline bool HasBranchWeights(const Instruction *I) { MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof); if (ProfMD && ProfMD->getOperand(0)) - if (MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0))) + if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0))) return MDS->getString().equals("branch_weights"); return false; @@ -837,7 +917,7 @@ static void GetBranchWeights(TerminatorInst *TI, // If TI is a conditional eq, the default case is the false case, // and the corresponding branch-weight data is at index 2. We swap the // default weight to be the first entry. - if (BranchInst* BI = dyn_cast<BranchInst>(TI)) { + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { assert(Weights.size() == 2); ICmpInst *ICI = cast<ICmpInst>(BI->getCondition()); if (ICI->getPredicate() == ICmpInst::ICMP_EQ) @@ -862,17 +942,17 @@ static void FitWeights(MutableArrayRef<uint64_t> Weights) { bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, IRBuilder<> &Builder) { BasicBlock *BB = TI->getParent(); - Value *CV = isValueEqualityComparison(TI); // CondVal + Value *CV = isValueEqualityComparison(TI); // CondVal assert(CV && "Not a comparison?"); bool Changed = false; - SmallVector<BasicBlock*, 16> Preds(pred_begin(BB), pred_end(BB)); + SmallVector<BasicBlock *, 16> Preds(pred_begin(BB), pred_end(BB)); while (!Preds.empty()) { BasicBlock *Pred = Preds.pop_back_val(); // See if the predecessor is a comparison with the same value. 
TerminatorInst *PTI = Pred->getTerminator(); - Value *PCV = isValueEqualityComparison(PTI); // PredCondVal + Value *PCV = isValueEqualityComparison(PTI); // PredCondVal if (PCV == CV && SafeToMergeTerminators(TI, PTI)) { // Figure out which 'cases' to copy from SI to PSI. @@ -885,7 +965,7 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, // Based on whether the default edge from PTI goes to BB or not, fill in // PredCases and PredDefault with the new switch cases we would like to // build. - SmallVector<BasicBlock*, 8> NewSuccessors; + SmallVector<BasicBlock *, 8> NewSuccessors; // Update the branch weight metadata along the way SmallVector<uint64_t, 8> Weights; @@ -915,7 +995,7 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, if (PredDefault == BB) { // If this is the default destination from PTI, only the edges in TI // that don't occur in PTI, or that branch to BB will be activated. - std::set<ConstantInt*, ConstantIntOrdering> PTIHandled; + std::set<ConstantInt *, ConstantIntOrdering> PTIHandled; for (unsigned i = 0, e = PredCases.size(); i != e; ++i) if (PredCases[i].Dest != BB) PTIHandled.insert(PredCases[i].Value); @@ -925,13 +1005,14 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, if (PredHasWeights || SuccHasWeights) { // Increase weight for the default case. - Weights[0] += Weights[i+1]; - std::swap(Weights[i+1], Weights.back()); + Weights[0] += Weights[i + 1]; + std::swap(Weights[i + 1], Weights.back()); Weights.pop_back(); } PredCases.pop_back(); - --i; --e; + --i; + --e; } // Reconstruct the new switch statement we will be building. @@ -952,8 +1033,8 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, // The default weight is at index 0, so weight for the ith case // should be at index i+1. Scale the cases from successor by // PredDefaultWeight (Weights[0]). - Weights.push_back(Weights[0] * SuccWeights[i+1]); - ValidTotalSuccWeight += SuccWeights[i+1]; + Weights.push_back(Weights[0] * SuccWeights[i + 1]); + ValidTotalSuccWeight += SuccWeights[i + 1]; } } @@ -969,21 +1050,22 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, // If this is not the default destination from PSI, only the edges // in SI that occur in PSI with a destination of BB will be // activated. 
- std::set<ConstantInt*, ConstantIntOrdering> PTIHandled; - std::map<ConstantInt*, uint64_t> WeightsForHandled; + std::set<ConstantInt *, ConstantIntOrdering> PTIHandled; + std::map<ConstantInt *, uint64_t> WeightsForHandled; for (unsigned i = 0, e = PredCases.size(); i != e; ++i) if (PredCases[i].Dest == BB) { PTIHandled.insert(PredCases[i].Value); if (PredHasWeights || SuccHasWeights) { - WeightsForHandled[PredCases[i].Value] = Weights[i+1]; - std::swap(Weights[i+1], Weights.back()); + WeightsForHandled[PredCases[i].Value] = Weights[i + 1]; + std::swap(Weights[i + 1], Weights.back()); Weights.pop_back(); } std::swap(PredCases[i], PredCases.back()); PredCases.pop_back(); - --i; --e; + --i; + --e; } // Okay, now we know which constants were sent to BB from the @@ -995,17 +1077,16 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, Weights.push_back(WeightsForHandled[BBCases[i].Value]); PredCases.push_back(BBCases[i]); NewSuccessors.push_back(BBCases[i].Dest); - PTIHandled.erase(BBCases[i].Value);// This constant is taken care of + PTIHandled.erase( + BBCases[i].Value); // This constant is taken care of } // If there are any constants vectored to BB that TI doesn't handle, // they must go to the default destination of TI. - for (std::set<ConstantInt*, ConstantIntOrdering>::iterator I = - PTIHandled.begin(), - E = PTIHandled.end(); I != E; ++I) { + for (ConstantInt *I : PTIHandled) { if (PredHasWeights || SuccHasWeights) - Weights.push_back(WeightsForHandled[*I]); - PredCases.push_back(ValueEqualityComparisonCase(*I, BBDefault)); + Weights.push_back(WeightsForHandled[I]); + PredCases.push_back(ValueEqualityComparisonCase(I, BBDefault)); NewSuccessors.push_back(BBDefault); } } @@ -1024,8 +1105,8 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, } // Now that the successors are updated, create the new Switch instruction. - SwitchInst *NewSI = Builder.CreateSwitch(CV, PredDefault, - PredCases.size()); + SwitchInst *NewSI = + Builder.CreateSwitch(CV, PredDefault, PredCases.size()); NewSI->setDebugLoc(PTI->getDebugLoc()); for (ValueEqualityComparisonCase &V : PredCases) NewSI->addCase(V.Value, V.Dest); @@ -1036,9 +1117,9 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); - NewSI->setMetadata(LLVMContext::MD_prof, - MDBuilder(BB->getContext()). - createBranchWeights(MDWeights)); + NewSI->setMetadata( + LLVMContext::MD_prof, + MDBuilder(BB->getContext()).createBranchWeights(MDWeights)); } EraseTerminatorInstAndDCECond(PTI); @@ -1052,8 +1133,8 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, if (!InfLoopBlock) { // Insert it at the end of the function, because it's either code, // or it won't matter if it's hot. :) - InfLoopBlock = BasicBlock::Create(BB->getContext(), - "infloop", BB->getParent()); + InfLoopBlock = BasicBlock::Create(BB->getContext(), "infloop", + BB->getParent()); BranchInst::Create(InfLoopBlock, InfLoopBlock); } NewSI->setSuccessor(i, InfLoopBlock); @@ -1070,13 +1151,13 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, // can't hoist the invoke, as there is nowhere to put the select in this case. 
static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2, Instruction *I1, Instruction *I2) { - for (succ_iterator SI = succ_begin(BB1), E = succ_end(BB1); SI != E; ++SI) { + for (BasicBlock *Succ : successors(BB1)) { PHINode *PN; - for (BasicBlock::iterator BBI = SI->begin(); + for (BasicBlock::iterator BBI = Succ->begin(); (PN = dyn_cast<PHINode>(BBI)); ++BBI) { Value *BB1V = PN->getIncomingValueForBlock(BB1); Value *BB2V = PN->getIncomingValueForBlock(BB2); - if (BB1V != BB2V && (BB1V==I1 || BB2V==I2)) { + if (BB1V != BB2V && (BB1V == I1 || BB2V == I2)) { return false; } } @@ -1096,8 +1177,8 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, // O(M*N) situations here where M and N are the sizes of BB1 and BB2. As // such, we currently just scan for obviously identical instructions in an // identical order. - BasicBlock *BB1 = BI->getSuccessor(0); // The true destination. - BasicBlock *BB2 = BI->getSuccessor(1); // The false destination + BasicBlock *BB1 = BI->getSuccessor(0); // The true destination. + BasicBlock *BB2 = BI->getSuccessor(1); // The false destination BasicBlock::iterator BB1_Itr = BB1->begin(); BasicBlock::iterator BB2_Itr = BB2->begin(); @@ -1135,12 +1216,16 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, if (!I2->use_empty()) I2->replaceAllUsesWith(I1); I1->intersectOptionalDataWith(I2); - unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, LLVMContext::MD_range, - LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, - LLVMContext::MD_nonnull, LLVMContext::MD_invariant_group, - LLVMContext::MD_align, LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null}; + unsigned KnownIDs[] = {LLVMContext::MD_tbaa, + LLVMContext::MD_range, + LLVMContext::MD_fpmath, + LLVMContext::MD_invariant_load, + LLVMContext::MD_nonnull, + LLVMContext::MD_invariant_group, + LLVMContext::MD_align, + LLVMContext::MD_dereferenceable, + LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_mem_parallel_loop_access}; combineMetadata(I1, I2, KnownIDs); I2->eraseFromParent(); Changed = true; @@ -1165,9 +1250,9 @@ HoistTerminator: if (isa<InvokeInst>(I1) && !isSafeToHoistInvoke(BB1, BB2, I1, I2)) return Changed; - for (succ_iterator SI = succ_begin(BB1), E = succ_end(BB1); SI != E; ++SI) { + for (BasicBlock *Succ : successors(BB1)) { PHINode *PN; - for (BasicBlock::iterator BBI = SI->begin(); + for (BasicBlock::iterator BBI = Succ->begin(); (PN = dyn_cast<PHINode>(BBI)); ++BBI) { Value *BB1V = PN->getIncomingValueForBlock(BB1); Value *BB2V = PN->getIncomingValueForBlock(BB2); @@ -1178,7 +1263,7 @@ HoistTerminator: // eliminate undefined control flow then converting it to a select. if (passingValueIsAlwaysUndefined(BB1V, PN) || passingValueIsAlwaysUndefined(BB2V, PN)) - return Changed; + return Changed; if (isa<ConstantExpr>(BB1V) && !isSafeToSpeculativelyExecute(BB1V)) return Changed; @@ -1196,27 +1281,28 @@ HoistTerminator: NT->takeName(I1); } - IRBuilder<true, NoFolder> Builder(NT); + IRBuilder<NoFolder> Builder(NT); // Hoisting one of the terminators from our successor is a great thing. // Unfortunately, the successors of the if/else blocks may have PHI nodes in // them. If they do, all PHI entries for BB1/BB2 must agree for all PHI // nodes, so we insert select instruction to compute the final result. 
- std::map<std::pair<Value*,Value*>, SelectInst*> InsertedSelects; - for (succ_iterator SI = succ_begin(BB1), E = succ_end(BB1); SI != E; ++SI) { + std::map<std::pair<Value *, Value *>, SelectInst *> InsertedSelects; + for (BasicBlock *Succ : successors(BB1)) { PHINode *PN; - for (BasicBlock::iterator BBI = SI->begin(); + for (BasicBlock::iterator BBI = Succ->begin(); (PN = dyn_cast<PHINode>(BBI)); ++BBI) { Value *BB1V = PN->getIncomingValueForBlock(BB1); Value *BB2V = PN->getIncomingValueForBlock(BB2); - if (BB1V == BB2V) continue; + if (BB1V == BB2V) + continue; // These values do not agree. Insert a select instruction before NT // that determines the right value. SelectInst *&SI = InsertedSelects[std::make_pair(BB1V, BB2V)]; if (!SI) - SI = cast<SelectInst> - (Builder.CreateSelect(BI->getCondition(), BB1V, BB2V, - BB1V->getName()+"."+BB2V->getName())); + SI = cast<SelectInst>( + Builder.CreateSelect(BI->getCondition(), BB1V, BB2V, + BB1V->getName() + "." + BB2V->getName(), BI)); // Make the PHI node use the select for all incoming values for BB1/BB2 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) @@ -1226,8 +1312,8 @@ HoistTerminator: } // Update any PHI nodes in our new successors. - for (succ_iterator SI = succ_begin(BB1), E = succ_end(BB1); SI != E; ++SI) - AddPredecessorToBlock(*SI, BIParent, BB1); + for (BasicBlock *Succ : successors(BB1)) + AddPredecessorToBlock(Succ, BIParent, BB1); EraseTerminatorInstAndDCECond(BI); return true; @@ -1280,10 +1366,12 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { RI2 = BB2->getInstList().rbegin(), RE2 = BB2->getInstList().rend(); // Skip debug info. - while (RI1 != RE1 && isa<DbgInfoIntrinsic>(&*RI1)) ++RI1; + while (RI1 != RE1 && isa<DbgInfoIntrinsic>(&*RI1)) + ++RI1; if (RI1 == RE1) return false; - while (RI2 != RE2 && isa<DbgInfoIntrinsic>(&*RI2)) ++RI2; + while (RI2 != RE2 && isa<DbgInfoIntrinsic>(&*RI2)) + ++RI2; if (RI2 == RE2) return false; // Skip the unconditional branches. @@ -1293,10 +1381,12 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { bool Changed = false; while (RI1 != RE1 && RI2 != RE2) { // Skip debug info. - while (RI1 != RE1 && isa<DbgInfoIntrinsic>(&*RI1)) ++RI1; + while (RI1 != RE1 && isa<DbgInfoIntrinsic>(&*RI1)) + ++RI1; if (RI1 == RE1) return Changed; - while (RI2 != RE2 && isa<DbgInfoIntrinsic>(&*RI2)) ++RI2; + while (RI2 != RE2 && isa<DbgInfoIntrinsic>(&*RI2)) + ++RI2; if (RI2 == RE2) return Changed; @@ -1305,22 +1395,19 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { // I1 and I2 should have a single use in the same PHI node, and they // perform the same operation. // Cannot move control-flow-involving, volatile loads, vaarg, etc. - if (isa<PHINode>(I1) || isa<PHINode>(I2) || - isa<TerminatorInst>(I1) || isa<TerminatorInst>(I2) || - I1->isEHPad() || I2->isEHPad() || + if (isa<PHINode>(I1) || isa<PHINode>(I2) || isa<TerminatorInst>(I1) || + isa<TerminatorInst>(I2) || I1->isEHPad() || I2->isEHPad() || isa<AllocaInst>(I1) || isa<AllocaInst>(I2) || I1->mayHaveSideEffects() || I2->mayHaveSideEffects() || I1->mayReadOrWriteMemory() || I2->mayReadOrWriteMemory() || - !I1->hasOneUse() || !I2->hasOneUse() || - !JointValueMap.count(InstPair)) + !I1->hasOneUse() || !I2->hasOneUse() || !JointValueMap.count(InstPair)) return Changed; // Check whether we should swap the operands of ICmpInst. // TODO: Add support of communativity. 
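Before the operand comparison that follows (note the TODO about commutativity), a simplified, hypothetical helper capturing the symmetric idea: two compares are sinking candidates when they use the same two operands, in the same or crossed order. The surrounding code tracks whether one compare's operands were swapped so the swap can be undone if sinking turns out to be impossible.

#include "llvm/IR/Instructions.h"

// Sketch only; not part of the patch.
static bool sameOperandsPossiblySwapped(const llvm::CmpInst *A,
                                        const llvm::CmpInst *B) {
  const llvm::Value *A0 = A->getOperand(0), *A1 = A->getOperand(1);
  const llvm::Value *B0 = B->getOperand(0), *B1 = B->getOperand(1);
  return (A0 == B0 && A1 == B1) || (A0 == B1 && A1 == B0);
}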
ICmpInst *ICmp1 = dyn_cast<ICmpInst>(I1), *ICmp2 = dyn_cast<ICmpInst>(I2); bool SwapOpnds = false; - if (ICmp1 && ICmp2 && - ICmp1->getOperand(0) != ICmp2->getOperand(0) && + if (ICmp1 && ICmp2 && ICmp1->getOperand(0) != ICmp2->getOperand(0) && ICmp1->getOperand(1) != ICmp2->getOperand(1) && (ICmp1->getOperand(0) == ICmp2->getOperand(1) || ICmp1->getOperand(1) == ICmp2->getOperand(0))) { @@ -1343,8 +1430,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { continue; // Early exit if we have more-than one pair of different operands or if // we need a PHI node to replace a constant. - if (Op1Idx != ~0U || - isa<Constant>(I1->getOperand(I)) || + if (Op1Idx != ~0U || isa<Constant>(I1->getOperand(I)) || isa<Constant>(I2->getOperand(I))) { // If we can't sink the instructions, undo the swapping. if (SwapOpnds) @@ -1379,7 +1465,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { // We need to update RE1 and RE2 if we are going to sink the first // instruction in the basic block down. - bool UpdateRE1 = (I1 == BB1->begin()), UpdateRE2 = (I2 == BB2->begin()); + bool UpdateRE1 = (I1 == &BB1->front()), UpdateRE2 = (I2 == &BB2->front()); // Sink the instruction. BBEnd->getInstList().splice(FirstNonPhiInBBEnd->getIterator(), BB1->getInstList(), I1); @@ -1444,22 +1530,26 @@ static Value *isSafeToSpeculateStore(Instruction *I, BasicBlock *BrBB, Value *StorePtr = StoreToHoist->getPointerOperand(); // Look for a store to the same pointer in BrBB. - unsigned MaxNumInstToLookAt = 10; - for (BasicBlock::reverse_iterator RI = BrBB->rbegin(), - RE = BrBB->rend(); RI != RE && (--MaxNumInstToLookAt); ++RI) { - Instruction *CurI = &*RI; + unsigned MaxNumInstToLookAt = 9; + for (Instruction &CurI : reverse(*BrBB)) { + if (!MaxNumInstToLookAt) + break; + // Skip debug info. + if (isa<DbgInfoIntrinsic>(CurI)) + continue; + --MaxNumInstToLookAt; // Could be calling an instruction that effects memory like free(). - if (CurI->mayHaveSideEffects() && !isa<StoreInst>(CurI)) + if (CurI.mayHaveSideEffects() && !isa<StoreInst>(CurI)) return nullptr; - StoreInst *SI = dyn_cast<StoreInst>(CurI); - // Found the previous store make sure it stores to the same location. - if (SI && SI->getPointerOperand() == StorePtr) - // Found the previous store, return its value operand. - return SI->getValueOperand(); - else if (SI) + if (auto *SI = dyn_cast<StoreInst>(&CurI)) { + // Found the previous store make sure it stores to the same location. + if (SI->getPointerOperand() == StorePtr) + // Found the previous store, return its value operand. + return SI->getValueOperand(); return nullptr; // Unknown store. + } } return nullptr; @@ -1562,11 +1652,9 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Do not hoist the instruction if any of its operands are defined but not // used in BB. The transformation will prevent the operand from // being sunk into the use block. - for (User::op_iterator i = I->op_begin(), e = I->op_end(); - i != e; ++i) { + for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i) { Instruction *OpI = dyn_cast<Instruction>(*i); - if (!OpI || OpI->getParent() != BB || - OpI->mayHaveSideEffects()) + if (!OpI || OpI->getParent() != BB || OpI->mayHaveSideEffects()) continue; // Not a candidate for sinking. ++SinkCandidateUseCounts[OpI]; @@ -1576,8 +1664,9 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Consider any sink candidates which are only used in CondBB as costs for // speculation. 
Note, while we iterate over a DenseMap here, we are summing // and so iteration order isn't significant. - for (SmallDenseMap<Instruction *, unsigned, 4>::iterator I = - SinkCandidateUseCounts.begin(), E = SinkCandidateUseCounts.end(); + for (SmallDenseMap<Instruction *, unsigned, 4>::iterator + I = SinkCandidateUseCounts.begin(), + E = SinkCandidateUseCounts.end(); I != E; ++I) if (I->first->getNumUses() == I->second) { ++SpeculationCost; @@ -1613,8 +1702,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, return false; unsigned OrigCost = OrigCE ? ComputeSpeculationCost(OrigCE, TTI) : 0; unsigned ThenCost = ThenCE ? ComputeSpeculationCost(ThenCE, TTI) : 0; - unsigned MaxCost = 2 * PHINodeFoldingThreshold * - TargetTransformInfo::TCC_Basic; + unsigned MaxCost = + 2 * PHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic; if (OrigCost + ThenCost > MaxCost) return false; @@ -1637,19 +1726,19 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Insert a select of the value of the speculated store. if (SpeculatedStoreValue) { - IRBuilder<true, NoFolder> Builder(BI); + IRBuilder<NoFolder> Builder(BI); Value *TrueV = SpeculatedStore->getValueOperand(); Value *FalseV = SpeculatedStoreValue; if (Invert) std::swap(TrueV, FalseV); - Value *S = Builder.CreateSelect(BrCond, TrueV, FalseV, TrueV->getName() + - "." + FalseV->getName()); + Value *S = Builder.CreateSelect( + BrCond, TrueV, FalseV, TrueV->getName() + "." + FalseV->getName(), BI); SpeculatedStore->setOperand(0, S); } // Metadata can be dependent on the condition we are hoisting above. // Conservatively strip all metadata on the instruction. - for (auto &I: *ThenBB) + for (auto &I : *ThenBB) I.dropUnknownNonDebugMetadata(); // Hoist the instructions. @@ -1657,7 +1746,7 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, ThenBB->begin(), std::prev(ThenBB->end())); // Insert selects and rewrite the PHI operands. - IRBuilder<true, NoFolder> Builder(BI); + IRBuilder<NoFolder> Builder(BI); for (BasicBlock::iterator I = EndBB->begin(); PHINode *PN = dyn_cast<PHINode>(I); ++I) { unsigned OrigI = PN->getBasicBlockIndex(BB); @@ -1675,8 +1764,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, Value *TrueV = ThenV, *FalseV = OrigV; if (Invert) std::swap(TrueV, FalseV); - Value *V = Builder.CreateSelect(BrCond, TrueV, FalseV, - TrueV->getName() + "." + FalseV->getName()); + Value *V = Builder.CreateSelect( + BrCond, TrueV, FalseV, TrueV->getName() + "." + FalseV->getName(), BI); PN->setIncomingValue(OrigI, V); PN->setIncomingValue(ThenI, V); } @@ -1685,19 +1774,6 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, return true; } -/// \returns True if this block contains a CallInst with the NoDuplicate -/// attribute. -static bool HasNoDuplicateCall(const BasicBlock *BB) { - for (BasicBlock::const_iterator I = BB->begin(), E = BB->end(); I != E; ++I) { - const CallInst *CI = dyn_cast<CallInst>(I); - if (!CI) - continue; - if (CI->cannotDuplicate()) - return true; - } - return false; -} - /// Return true if we can thread a branch across this block. static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { BranchInst *BI = cast<BranchInst>(BB->getTerminator()); @@ -1706,14 +1782,16 @@ static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI) { if (isa<DbgInfoIntrinsic>(BBI)) continue; - if (Size > 10) return false; // Don't clone large BB's. 
+ if (Size > 10) + return false; // Don't clone large BB's. ++Size; // We can only support instructions that do not define values that are // live outside of the current basic block. for (User *U : BBI->users()) { Instruction *UI = cast<Instruction>(U); - if (UI->getParent() != BB || isa<PHINode>(UI)) return false; + if (UI->getParent() != BB || isa<PHINode>(UI)) + return false; } // Looks ok, continue checking. @@ -1740,32 +1818,41 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { } // Now we know that this block has multiple preds and two succs. - if (!BlockIsSimpleEnoughToThreadThrough(BB)) return false; + if (!BlockIsSimpleEnoughToThreadThrough(BB)) + return false; - if (HasNoDuplicateCall(BB)) return false; + // Can't fold blocks that contain noduplicate or convergent calls. + if (llvm::any_of(*BB, [](const Instruction &I) { + const CallInst *CI = dyn_cast<CallInst>(&I); + return CI && (CI->cannotDuplicate() || CI->isConvergent()); + })) + return false; // Okay, this is a simple enough basic block. See if any phi values are // constants. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { ConstantInt *CB = dyn_cast<ConstantInt>(PN->getIncomingValue(i)); - if (!CB || !CB->getType()->isIntegerTy(1)) continue; + if (!CB || !CB->getType()->isIntegerTy(1)) + continue; // Okay, we now know that all edges from PredBB should be revectored to // branch to RealDest. BasicBlock *PredBB = PN->getIncomingBlock(i); BasicBlock *RealDest = BI->getSuccessor(!CB->getZExtValue()); - if (RealDest == BB) continue; // Skip self loops. + if (RealDest == BB) + continue; // Skip self loops. // Skip if the predecessor's terminator is an indirect branch. - if (isa<IndirectBrInst>(PredBB->getTerminator())) continue; + if (isa<IndirectBrInst>(PredBB->getTerminator())) + continue; // The dest block might have PHI nodes, other predecessors and other // difficult cases. Instead of being smart about this, just insert a new // block that jumps to the destination block, effectively splitting // the edge we are about to create. - BasicBlock *EdgeBB = BasicBlock::Create(BB->getContext(), - RealDest->getName()+".critedge", - RealDest->getParent(), RealDest); + BasicBlock *EdgeBB = + BasicBlock::Create(BB->getContext(), RealDest->getName() + ".critedge", + RealDest->getParent(), RealDest); BranchInst::Create(RealDest, EdgeBB); // Update PHI nodes. @@ -1775,7 +1862,7 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { // instructions into EdgeBB. We know that there will be no uses of the // cloned instructions outside of EdgeBB. BasicBlock::iterator InsertPt = EdgeBB->begin(); - DenseMap<Value*, Value*> TranslateMap; // Track translated values. + DenseMap<Value *, Value *> TranslateMap; // Track translated values. for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI) { if (PHINode *PN = dyn_cast<PHINode>(BBI)) { TranslateMap[PN] = PN->getIncomingValueForBlock(PredBB); @@ -1783,26 +1870,31 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { } // Clone the instruction. Instruction *N = BBI->clone(); - if (BBI->hasName()) N->setName(BBI->getName()+".c"); + if (BBI->hasName()) + N->setName(BBI->getName() + ".c"); // Update operands due to translation. 
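The guard above folds the old HasNoDuplicateCall helper into a single llvm::any_of over the block, now also rejecting convergent calls. The same shape with the standard library, using an invented MockCall type:

#include <algorithm>
#include <iostream>
#include <vector>

struct MockCall {
  bool NoDuplicate;
  bool Convergent;
  bool cannotDuplicate() const { return NoDuplicate; }
  bool isConvergent() const { return Convergent; }
};

// Returns true if any call in the block may not be duplicated, mirroring the
// early-exit guard in FoldCondBranchOnPHI.
static bool hasUnclonableCall(const std::vector<MockCall> &BB) {
  return std::any_of(BB.begin(), BB.end(), [](const MockCall &CI) {
    return CI.cannotDuplicate() || CI.isConvergent();
  });
}

int main() {
  std::vector<MockCall> BB = {{false, false}, {false, true}};
  std::cout << (hasUnclonableCall(BB) ? "skip folding\n" : "fold\n");
  return 0;
}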
- for (User::op_iterator i = N->op_begin(), e = N->op_end(); - i != e; ++i) { - DenseMap<Value*, Value*>::iterator PI = TranslateMap.find(*i); + for (User::op_iterator i = N->op_begin(), e = N->op_end(); i != e; ++i) { + DenseMap<Value *, Value *>::iterator PI = TranslateMap.find(*i); if (PI != TranslateMap.end()) *i = PI->second; } // Check for trivial simplification. if (Value *V = SimplifyInstruction(N, DL)) { - TranslateMap[&*BBI] = V; - delete N; // Instruction folded away, don't need actual inst + if (!BBI->use_empty()) + TranslateMap[&*BBI] = V; + if (!N->mayHaveSideEffects()) { + delete N; // Instruction folded away, don't need actual inst + N = nullptr; + } } else { - // Insert the new instruction into its new home. - EdgeBB->getInstList().insert(InsertPt, N); if (!BBI->use_empty()) TranslateMap[&*BBI] = N; } + // Insert the new instruction into its new home. + if (N) + EdgeBB->getInstList().insert(InsertPt, N); } // Loop over all of the edges from PredBB to BB, changing them to branch @@ -1852,7 +1944,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // Loop over the PHI's seeing if we can promote them all to select // instructions. While we are at it, keep track of the instructions // that need to be moved to the dominating block. - SmallPtrSet<Instruction*, 4> AggressiveInsts; + SmallPtrSet<Instruction *, 4> AggressiveInsts; unsigned MaxCostVal0 = PHINodeFoldingThreshold, MaxCostVal1 = PHINodeFoldingThreshold; MaxCostVal0 *= TargetTransformInfo::TCC_Basic; @@ -1876,7 +1968,8 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // If we folded the first phi, PN dangles at this point. Refresh it. If // we ran out of PHIs then we simplified them all. PN = dyn_cast<PHINode>(BB->begin()); - if (!PN) return true; + if (!PN) + return true; // Don't fold i1 branches on PHIs which contain binary operators. These can // often be turned into switches and other things. @@ -1886,10 +1979,10 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, isa<BinaryOperator>(IfCond))) return false; - // If we all PHI nodes are promotable, check to make sure that all - // instructions in the predecessor blocks can be promoted as well. If - // not, we won't be able to get rid of the control flow, so it's not - // worth promoting to select instructions. + // If all PHI nodes are promotable, check to make sure that all instructions + // in the predecessor blocks can be promoted as well. If not, we won't be able + // to get rid of the control flow, so it's not worth promoting to select + // instructions. BasicBlock *DomBlock = nullptr; BasicBlock *IfBlock1 = PN->getIncomingBlock(0); BasicBlock *IfBlock2 = PN->getIncomingBlock(1); @@ -1897,11 +1990,12 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, IfBlock1 = nullptr; } else { DomBlock = *pred_begin(IfBlock1); - for (BasicBlock::iterator I = IfBlock1->begin();!isa<TerminatorInst>(I);++I) + for (BasicBlock::iterator I = IfBlock1->begin(); !isa<TerminatorInst>(I); + ++I) if (!AggressiveInsts.count(&*I) && !isa<DbgInfoIntrinsic>(I)) { // This is not an aggressive instruction that we can promote. - // Because of this, we won't be able to get rid of the control - // flow, so the xform is not worth it. + // Because of this, we won't be able to get rid of the control flow, so + // the xform is not worth it. 
return false; } } @@ -1910,11 +2004,12 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, IfBlock2 = nullptr; } else { DomBlock = *pred_begin(IfBlock2); - for (BasicBlock::iterator I = IfBlock2->begin();!isa<TerminatorInst>(I);++I) + for (BasicBlock::iterator I = IfBlock2->begin(); !isa<TerminatorInst>(I); + ++I) if (!AggressiveInsts.count(&*I) && !isa<DbgInfoIntrinsic>(I)) { // This is not an aggressive instruction that we can promote. - // Because of this, we won't be able to get rid of the control - // flow, so the xform is not worth it. + // Because of this, we won't be able to get rid of the control flow, so + // the xform is not worth it. return false; } } @@ -1925,7 +2020,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // If we can still promote the PHI nodes after this gauntlet of tests, // do all of the PHI's now. Instruction *InsertPt = DomBlock->getTerminator(); - IRBuilder<true, NoFolder> Builder(InsertPt); + IRBuilder<NoFolder> Builder(InsertPt); // Move all 'aggressive' instructions, which are defined in the // conditional parts of the if's up to the dominating block. @@ -1940,13 +2035,12 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, while (PHINode *PN = dyn_cast<PHINode>(BB->begin())) { // Change the PHI node into a select instruction. - Value *TrueVal = PN->getIncomingValue(PN->getIncomingBlock(0) == IfFalse); + Value *TrueVal = PN->getIncomingValue(PN->getIncomingBlock(0) == IfFalse); Value *FalseVal = PN->getIncomingValue(PN->getIncomingBlock(0) == IfTrue); - SelectInst *NV = - cast<SelectInst>(Builder.CreateSelect(IfCond, TrueVal, FalseVal, "")); - PN->replaceAllUsesWith(NV); - NV->takeName(PN); + Value *Sel = Builder.CreateSelect(IfCond, TrueVal, FalseVal, "", InsertPt); + PN->replaceAllUsesWith(Sel); + Sel->takeName(PN); PN->eraseFromParent(); } @@ -2029,51 +2123,32 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI, } else if (isa<UndefValue>(TrueValue)) { TrueValue = FalseValue; } else { - TrueValue = Builder.CreateSelect(BrCond, TrueValue, - FalseValue, "retval"); + TrueValue = + Builder.CreateSelect(BrCond, TrueValue, FalseValue, "retval", BI); } } - Value *RI = !TrueValue ? - Builder.CreateRetVoid() : Builder.CreateRet(TrueValue); + Value *RI = + !TrueValue ? Builder.CreateRetVoid() : Builder.CreateRet(TrueValue); - (void) RI; + (void)RI; DEBUG(dbgs() << "\nCHANGING BRANCH TO TWO RETURNS INTO SELECT:" << "\n " << *BI << "NewRet = " << *RI - << "TRUEBLOCK: " << *TrueSucc << "FALSEBLOCK: "<< *FalseSucc); + << "TRUEBLOCK: " << *TrueSucc << "FALSEBLOCK: " << *FalseSucc); EraseTerminatorInstAndDCECond(BI); return true; } -/// Given a conditional BranchInstruction, retrieve the probabilities of the -/// branch taking each edge. Fills in the two APInt parameters and returns true, -/// or returns false if no or invalid metadata was found. 
-static bool ExtractBranchMetadata(BranchInst *BI, - uint64_t &ProbTrue, uint64_t &ProbFalse) { - assert(BI->isConditional() && - "Looking for probabilities on unconditional branch?"); - MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof); - if (!ProfileData || ProfileData->getNumOperands() != 3) return false; - ConstantInt *CITrue = - mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1)); - ConstantInt *CIFalse = - mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)); - if (!CITrue || !CIFalse) return false; - ProbTrue = CITrue->getValue().getZExtValue(); - ProbFalse = CIFalse->getValue().getZExtValue(); - return true; -} - /// Return true if the given instruction is available /// in its predecessor block. If yes, the instruction will be removed. static bool checkCSEInPredecessor(Instruction *Inst, BasicBlock *PB) { if (!isa<BinaryOperator>(Inst) && !isa<CmpInst>(Inst)) return false; - for (BasicBlock::iterator I = PB->begin(), E = PB->end(); I != E; I++) { - Instruction *PBI = &*I; + for (Instruction &I : *PB) { + Instruction *PBI = &I; // Check whether Inst and PBI generate the same value. if (Inst->isIdenticalTo(PBI)) { Inst->replaceAllUsesWith(PBI); @@ -2084,6 +2159,29 @@ static bool checkCSEInPredecessor(Instruction *Inst, BasicBlock *PB) { return false; } +/// Return true if either PBI or BI has branch weight available, and store +/// the weights in {Pred|Succ}{True|False}Weight. If one of PBI and BI does +/// not have branch weight, use 1:1 as its weight. +static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI, + uint64_t &PredTrueWeight, + uint64_t &PredFalseWeight, + uint64_t &SuccTrueWeight, + uint64_t &SuccFalseWeight) { + bool PredHasWeights = + PBI->extractProfMetadata(PredTrueWeight, PredFalseWeight); + bool SuccHasWeights = + BI->extractProfMetadata(SuccTrueWeight, SuccFalseWeight); + if (PredHasWeights || SuccHasWeights) { + if (!PredHasWeights) + PredTrueWeight = PredFalseWeight = 1; + if (!SuccHasWeights) + SuccTrueWeight = SuccFalseWeight = 1; + return true; + } else { + return false; + } +} + /// If this basic block is simple enough, and if a predecessor branches to us /// and one of our successors, fold the block into the predecessor and use /// logical operations to pick the right destination. @@ -2103,8 +2201,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { if (PBI->isConditional() && (BI->getSuccessor(0) == PBI->getSuccessor(0) || BI->getSuccessor(0) == PBI->getSuccessor(1))) { - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); - I != E; ) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { Instruction *Curr = &*I++; if (isa<CmpInst>(Curr)) { Cond = Curr; @@ -2122,13 +2219,14 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { if (!Cond || (!isa<CmpInst>(Cond) && !isa<BinaryOperator>(Cond)) || Cond->getParent() != BB || !Cond->hasOneUse()) - return false; + return false; // Make sure the instruction after the condition is the cond branch. BasicBlock::iterator CondIt = ++Cond->getIterator(); // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(CondIt)) ++CondIt; + while (isa<DbgInfoIntrinsic>(CondIt)) + ++CondIt; if (&*CondIt != BI) return false; @@ -2139,7 +2237,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // as "bonus instructions", and only allow this transformation when the // number of the bonus instructions does not exceed a certain threshold. 
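extractPredSuccWeights, added above, treats the fold as profile-guided if either branch carries weights and substitutes a 1:1 ratio for the branch that does not. A standalone sketch of that defaulting rule (the Weights struct and HasProfile flag are illustrative):

#include <cstdint>
#include <iostream>

struct Weights {
  bool HasProfile; // did this branch carry !prof metadata?
  uint64_t TrueWeight;
  uint64_t FalseWeight;
};

// Mirrors extractPredSuccWeights: succeed if either branch has weights,
// defaulting the other one to 1:1.
static bool extractPredSuccWeights(Weights &Pred, Weights &Succ) {
  if (!Pred.HasProfile && !Succ.HasProfile)
    return false;
  if (!Pred.HasProfile)
    Pred.TrueWeight = Pred.FalseWeight = 1;
  if (!Succ.HasProfile)
    Succ.TrueWeight = Succ.FalseWeight = 1;
  return true;
}

int main() {
  Weights Pred = {true, 30, 10};
  Weights Succ = {false, 0, 0};
  if (extractPredSuccWeights(Pred, Succ))
    std::cout << "succ defaults to " << Succ.TrueWeight << ":"
              << Succ.FalseWeight << "\n"; // 1:1
  return 0;
}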
unsigned NumBonusInsts = 0; - for (auto I = BB->begin(); Cond != I; ++I) { + for (auto I = BB->begin(); Cond != &*I; ++I) { // Ignore dbg intrinsics. if (isa<DbgInfoIntrinsic>(I)) continue; @@ -2168,7 +2266,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { return false; // Finally, don't infinitely unroll conditional loops. - BasicBlock *TrueDest = BI->getSuccessor(0); + BasicBlock *TrueDest = BI->getSuccessor(0); BasicBlock *FalseDest = (BI->isConditional()) ? BI->getSuccessor(1) : nullptr; if (TrueDest == BB || FalseDest == BB) return false; @@ -2180,10 +2278,9 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // Check that we have two conditional branches. If there is a PHI node in // the common successor, verify that the same value flows in from both // blocks. - SmallVector<PHINode*, 4> PHIs; + SmallVector<PHINode *, 4> PHIs; if (!PBI || PBI->isUnconditional() || - (BI->isConditional() && - !SafeToMergeTerminators(BI, PBI)) || + (BI->isConditional() && !SafeToMergeTerminators(BI, PBI)) || (!BI->isConditional() && !isProfitableToFoldUnconditional(BI, PBI, Cond, PHIs))) continue; @@ -2193,16 +2290,19 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { bool InvertPredCond = false; if (BI->isConditional()) { - if (PBI->getSuccessor(0) == TrueDest) + if (PBI->getSuccessor(0) == TrueDest) { Opc = Instruction::Or; - else if (PBI->getSuccessor(1) == FalseDest) + } else if (PBI->getSuccessor(1) == FalseDest) { Opc = Instruction::And; - else if (PBI->getSuccessor(0) == FalseDest) - Opc = Instruction::And, InvertPredCond = true; - else if (PBI->getSuccessor(1) == TrueDest) - Opc = Instruction::Or, InvertPredCond = true; - else + } else if (PBI->getSuccessor(0) == FalseDest) { + Opc = Instruction::And; + InvertPredCond = true; + } else if (PBI->getSuccessor(1) == TrueDest) { + Opc = Instruction::Or; + InvertPredCond = true; + } else { continue; + } } else { if (PBI->getSuccessor(0) != TrueDest && PBI->getSuccessor(1) != TrueDest) continue; @@ -2219,8 +2319,8 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { CmpInst *CI = cast<CmpInst>(NewCond); CI->setPredicate(CI->getInversePredicate()); } else { - NewCond = Builder.CreateNot(NewCond, - PBI->getCondition()->getName()+".not"); + NewCond = + Builder.CreateNot(NewCond, PBI->getCondition()->getName() + ".not"); } PBI->setCondition(NewCond); @@ -2234,12 +2334,12 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // We already make sure Cond is the last instruction before BI. Therefore, // all instructions before Cond other than DbgInfoIntrinsic are bonus // instructions. - for (auto BonusInst = BB->begin(); Cond != BonusInst; ++BonusInst) { + for (auto BonusInst = BB->begin(); Cond != &*BonusInst; ++BonusInst) { if (isa<DbgInfoIntrinsic>(BonusInst)) continue; Instruction *NewBonusInst = BonusInst->clone(); RemapInstruction(NewBonusInst, VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); VMap[&*BonusInst] = NewBonusInst; // If we moved a load, we cannot any longer claim any knowledge about @@ -2258,49 +2358,49 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // two conditions together. 
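The Opc selection above picks Or when the predecessor's true edge already reaches TrueDest and And when its false edge reaches FalseDest, inverting the predecessor condition in the mirrored cases. The standalone check below enumerates the Or case and confirms the folded branch reaches TrueDest exactly when PredCond || BlockCond does (function names are invented; bonus-instruction side effects are ignored):

#include <cassert>
#include <iostream>

// Destination reached by the original CFG: the predecessor branch either
// jumps straight to TrueDest or falls into the block, whose branch then
// picks TrueDest or FalseDest.
static bool reachesTrueDestOriginal(bool PredCond, bool BlockCond) {
  if (PredCond)
    return true;    // PBI->getSuccessor(0) == TrueDest
  return BlockCond; // BI: br BlockCond, TrueDest, FalseDest
}

// Destination reached after folding the two conditions with Instruction::Or.
static bool reachesTrueDestFolded(bool PredCond, bool BlockCond) {
  return PredCond || BlockCond;
}

int main() {
  for (int P = 0; P <= 1; ++P)
    for (int B = 0; B <= 1; ++B)
      assert(reachesTrueDestOriginal(P, B) == reachesTrueDestFolded(P, B));
  std::cout << "Or-folding preserves the taken destination\n";
  return 0;
}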
Instruction *New = Cond->clone(); RemapInstruction(New, VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); PredBlock->getInstList().insert(PBI->getIterator(), New); New->takeName(Cond); Cond->setName(New->getName() + ".old"); if (BI->isConditional()) { - Instruction *NewCond = - cast<Instruction>(Builder.CreateBinOp(Opc, PBI->getCondition(), - New, "or.cond")); + Instruction *NewCond = cast<Instruction>( + Builder.CreateBinOp(Opc, PBI->getCondition(), New, "or.cond")); PBI->setCondition(NewCond); uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight; - bool PredHasWeights = ExtractBranchMetadata(PBI, PredTrueWeight, - PredFalseWeight); - bool SuccHasWeights = ExtractBranchMetadata(BI, SuccTrueWeight, - SuccFalseWeight); + bool HasWeights = + extractPredSuccWeights(PBI, BI, PredTrueWeight, PredFalseWeight, + SuccTrueWeight, SuccFalseWeight); SmallVector<uint64_t, 8> NewWeights; if (PBI->getSuccessor(0) == BB) { - if (PredHasWeights && SuccHasWeights) { + if (HasWeights) { // PBI: br i1 %x, BB, FalseDest // BI: br i1 %y, TrueDest, FalseDest - //TrueWeight is TrueWeight for PBI * TrueWeight for BI. + // TrueWeight is TrueWeight for PBI * TrueWeight for BI. NewWeights.push_back(PredTrueWeight * SuccTrueWeight); - //FalseWeight is FalseWeight for PBI * TotalWeight for BI + + // FalseWeight is FalseWeight for PBI * TotalWeight for BI + // TrueWeight for PBI * FalseWeight for BI. // We assume that total weights of a BranchInst can fit into 32 bits. // Therefore, we will not have overflow using 64-bit arithmetic. - NewWeights.push_back(PredFalseWeight * (SuccFalseWeight + - SuccTrueWeight) + PredTrueWeight * SuccFalseWeight); + NewWeights.push_back(PredFalseWeight * + (SuccFalseWeight + SuccTrueWeight) + + PredTrueWeight * SuccFalseWeight); } AddPredecessorToBlock(TrueDest, PredBlock, BB); PBI->setSuccessor(0, TrueDest); } if (PBI->getSuccessor(1) == BB) { - if (PredHasWeights && SuccHasWeights) { + if (HasWeights) { // PBI: br i1 %x, TrueDest, BB // BI: br i1 %y, TrueDest, FalseDest - //TrueWeight is TrueWeight for PBI * TotalWeight for BI + + // TrueWeight is TrueWeight for PBI * TotalWeight for BI + // FalseWeight for PBI * TrueWeight for BI. - NewWeights.push_back(PredTrueWeight * (SuccFalseWeight + - SuccTrueWeight) + PredFalseWeight * SuccTrueWeight); - //FalseWeight is FalseWeight for PBI * FalseWeight for BI. + NewWeights.push_back(PredTrueWeight * + (SuccFalseWeight + SuccTrueWeight) + + PredFalseWeight * SuccTrueWeight); + // FalseWeight is FalseWeight for PBI * FalseWeight for BI. NewWeights.push_back(PredFalseWeight * SuccFalseWeight); } AddPredecessorToBlock(FalseDest, PredBlock, BB); @@ -2310,51 +2410,42 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // Halve the weights if any of them cannot fit in an uint32_t FitWeights(NewWeights); - SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(),NewWeights.end()); - PBI->setMetadata(LLVMContext::MD_prof, - MDBuilder(BI->getContext()). - createBranchWeights(MDWeights)); + SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(), + NewWeights.end()); + PBI->setMetadata( + LLVMContext::MD_prof, + MDBuilder(BI->getContext()).createBranchWeights(MDWeights)); } else PBI->setMetadata(LLVMContext::MD_prof, nullptr); } else { // Update PHI nodes in the common successors. 
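The weight update above follows the commented formulas: with the predecessor's true edge falling into BB, the merged true weight is PredTrue * SuccTrue and the merged false weight is PredFalse * (SuccTrue + SuccFalse) + PredTrue * SuccFalse, after which everything is scaled down to fit 32 bits. A standalone sketch with invented helpers mergedWeights and fitWeights (not LLVM's FitWeights) and made-up numbers:

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// New weights when the predecessor's true edge falls into the block:
//   true  = PredTrue * SuccTrue
//   false = PredFalse * (SuccTrue + SuccFalse) + PredTrue * SuccFalse
static std::vector<uint64_t> mergedWeights(uint64_t PredTrue, uint64_t PredFalse,
                                           uint64_t SuccTrue, uint64_t SuccFalse) {
  return {PredTrue * SuccTrue,
          PredFalse * (SuccTrue + SuccFalse) + PredTrue * SuccFalse};
}

// Halve all weights until each fits in a uint32_t, keeping their ratios.
static void fitWeights(std::vector<uint64_t> &Ws) {
  uint64_t Max = 0;
  for (uint64_t W : Ws)
    if (W > Max)
      Max = W;
  while (Max > std::numeric_limits<uint32_t>::max()) {
    for (uint64_t &W : Ws)
      W /= 2;
    Max /= 2;
  }
}

int main() {
  std::vector<uint64_t> Ws = mergedWeights(3, 1, 1u << 31, 1u << 31);
  fitWeights(Ws);
  std::cout << Ws[0] << " : " << Ws[1] << "\n";
  return 0;
}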
for (unsigned i = 0, e = PHIs.size(); i != e; ++i) { ConstantInt *PBI_C = cast<ConstantInt>( - PHIs[i]->getIncomingValueForBlock(PBI->getParent())); + PHIs[i]->getIncomingValueForBlock(PBI->getParent())); assert(PBI_C->getType()->isIntegerTy(1)); Instruction *MergedCond = nullptr; if (PBI->getSuccessor(0) == TrueDest) { // Create (PBI_Cond and PBI_C) or (!PBI_Cond and BI_Value) // PBI_C is true: PBI_Cond or (!PBI_Cond and BI_Value) // is false: !PBI_Cond and BI_Value - Instruction *NotCond = - cast<Instruction>(Builder.CreateNot(PBI->getCondition(), - "not.cond")); - MergedCond = - cast<Instruction>(Builder.CreateBinOp(Instruction::And, - NotCond, New, - "and.cond")); + Instruction *NotCond = cast<Instruction>( + Builder.CreateNot(PBI->getCondition(), "not.cond")); + MergedCond = cast<Instruction>( + Builder.CreateBinOp(Instruction::And, NotCond, New, "and.cond")); if (PBI_C->isOne()) - MergedCond = - cast<Instruction>(Builder.CreateBinOp(Instruction::Or, - PBI->getCondition(), MergedCond, - "or.cond")); + MergedCond = cast<Instruction>(Builder.CreateBinOp( + Instruction::Or, PBI->getCondition(), MergedCond, "or.cond")); } else { // Create (PBI_Cond and BI_Value) or (!PBI_Cond and PBI_C) // PBI_C is true: (PBI_Cond and BI_Value) or (!PBI_Cond) // is false: PBI_Cond and BI_Value - MergedCond = - cast<Instruction>(Builder.CreateBinOp(Instruction::And, - PBI->getCondition(), New, - "and.cond")); + MergedCond = cast<Instruction>(Builder.CreateBinOp( + Instruction::And, PBI->getCondition(), New, "and.cond")); if (PBI_C->isOne()) { - Instruction *NotCond = - cast<Instruction>(Builder.CreateNot(PBI->getCondition(), - "not.cond")); - MergedCond = - cast<Instruction>(Builder.CreateBinOp(Instruction::Or, - NotCond, MergedCond, - "or.cond")); + Instruction *NotCond = cast<Instruction>( + Builder.CreateNot(PBI->getCondition(), "not.cond")); + MergedCond = cast<Instruction>(Builder.CreateBinOp( + Instruction::Or, NotCond, MergedCond, "or.cond")); } } // Update PHI Node. @@ -2371,9 +2462,9 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // could replace PBI's branch probabilities with BI's. // Copy any debug value intrinsics into the end of PredBlock. - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) - if (isa<DbgInfoIntrinsic>(*I)) - I->clone()->insertBefore(PBI); + for (Instruction &I : *BB) + if (isa<DbgInfoIntrinsic>(I)) + I.clone()->insertBefore(PBI); return true; } @@ -2417,7 +2508,7 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB, // where OtherBB is the single other predecessor of BB's only successor. PHINode *PHI = nullptr; BasicBlock *Succ = BB->getSingleSuccessor(); - + for (auto I = Succ->begin(); isa<PHINode>(I); ++I) if (cast<PHINode>(I)->getIncomingValueForBlock(BB) == V) { PHI = cast<PHINode>(I); @@ -2443,8 +2534,8 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB, PHI->addIncoming(V, BB); for (BasicBlock *PredBB : predecessors(Succ)) if (PredBB != BB) - PHI->addIncoming(AlternativeV ? AlternativeV : UndefValue::get(V->getType()), - PredBB); + PHI->addIncoming( + AlternativeV ? 
AlternativeV : UndefValue::get(V->getType()), PredBB); return PHI; } @@ -2481,10 +2572,9 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, return N <= PHINodeFoldingThreshold; }; - if (!MergeCondStoresAggressively && (!IsWorthwhile(PTB) || - !IsWorthwhile(PFB) || - !IsWorthwhile(QTB) || - !IsWorthwhile(QFB))) + if (!MergeCondStoresAggressively && + (!IsWorthwhile(PTB) || !IsWorthwhile(PFB) || !IsWorthwhile(QTB) || + !IsWorthwhile(QFB))) return false; // For every pointer, there must be exactly two stores, one coming from @@ -2561,7 +2651,7 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, QStore->eraseFromParent(); PStore->eraseFromParent(); - + return true; } @@ -2593,7 +2683,7 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { // We model triangles as a type of diamond with a nullptr "true" block. // Triangles are canonicalized so that the fallthrough edge is represented by // a true condition, as in the diagram above. - // + // BasicBlock *PTB = PBI->getSuccessor(0); BasicBlock *PFB = PBI->getSuccessor(1); BasicBlock *QTB = QBI->getSuccessor(0); @@ -2622,8 +2712,7 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { // the post-dominating block, and the non-fallthroughs must only have one // predecessor. auto HasOnePredAndOneSucc = [](BasicBlock *BB, BasicBlock *P, BasicBlock *S) { - return BB->getSinglePredecessor() == P && - BB->getSingleSuccessor() == S; + return BB->getSinglePredecessor() == P && BB->getSingleSuccessor() == S; }; if (!PostBB || !HasOnePredAndOneSucc(PFB, PBI->getParent(), QBI->getParent()) || @@ -2637,7 +2726,7 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { // OK, this is a sequence of two diamonds or triangles. // Check if there are stores in PTB or PFB that are repeated in QTB or QFB. - SmallPtrSet<Value *,4> PStoreAddresses, QStoreAddresses; + SmallPtrSet<Value *, 4> PStoreAddresses, QStoreAddresses; for (auto *BB : {PTB, PFB}) { if (!BB) continue; @@ -2652,7 +2741,7 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { if (StoreInst *SI = dyn_cast<StoreInst>(&I)) QStoreAddresses.insert(SI->getPointerOperand()); } - + set_intersect(PStoreAddresses, QStoreAddresses); // set_intersect mutates PStoreAddresses in place. Rename it here to make it // clear what it contains. @@ -2684,9 +2773,9 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, if (BB->getSinglePredecessor()) { // Turn this into a branch on constant. bool CondIsTrue = PBI->getSuccessor(0) == BB; - BI->setCondition(ConstantInt::get(Type::getInt1Ty(BB->getContext()), - CondIsTrue)); - return true; // Nuke the branch on constant. + BI->setCondition( + ConstantInt::get(Type::getInt1Ty(BB->getContext()), CondIsTrue)); + return true; // Nuke the branch on constant. } // Otherwise, if there are multiple predecessors, insert a PHI that merges @@ -2702,13 +2791,13 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // Any predecessor where the condition is not computable we keep symbolic. 
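mergeConditionalStores above gathers the pointers stored to on the P side and the Q side and intersects the two sets; only addresses written on both sides are candidates for a merged conditional store. A standalone sketch of that in-place intersection, using std::set and invented names:

#include <iostream>
#include <set>
#include <string>

// Keep only the addresses present in both sides, mirroring how set_intersect
// mutates PStoreAddresses in place.
static void intersectInPlace(std::set<std::string> &P,
                             const std::set<std::string> &Q) {
  for (auto It = P.begin(); It != P.end();) {
    if (!Q.count(*It))
      It = P.erase(It);
    else
      ++It;
  }
}

int main() {
  std::set<std::string> PStoreAddresses = {"%a", "%b", "%c"};
  std::set<std::string> QStoreAddresses = {"%b", "%c", "%d"};
  intersectInPlace(PStoreAddresses, QStoreAddresses);
  for (const std::string &Addr : PStoreAddresses)
    std::cout << Addr << "\n"; // %b and %c are candidates for merging
  return 0;
}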
for (pred_iterator PI = PB; PI != PE; ++PI) { BasicBlock *P = *PI; - if ((PBI = dyn_cast<BranchInst>(P->getTerminator())) && - PBI != BI && PBI->isConditional() && - PBI->getCondition() == BI->getCondition() && + if ((PBI = dyn_cast<BranchInst>(P->getTerminator())) && PBI != BI && + PBI->isConditional() && PBI->getCondition() == BI->getCondition() && PBI->getSuccessor(0) != PBI->getSuccessor(1)) { bool CondIsTrue = PBI->getSuccessor(0) == BB; - NewPN->addIncoming(ConstantInt::get(Type::getInt1Ty(BB->getContext()), - CondIsTrue), P); + NewPN->addIncoming( + ConstantInt::get(Type::getInt1Ty(BB->getContext()), CondIsTrue), + P); } else { NewPN->addIncoming(BI->getCondition(), P); } @@ -2723,19 +2812,6 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, if (CE->canTrap()) return false; - // If BI is reached from the true path of PBI and PBI's condition implies - // BI's condition, we know the direction of the BI branch. - if (PBI->getSuccessor(0) == BI->getParent() && - isImpliedCondition(PBI->getCondition(), BI->getCondition(), DL) && - PBI->getSuccessor(0) != PBI->getSuccessor(1) && - BB->getSinglePredecessor()) { - // Turn this into a branch on constant. - auto *OldCond = BI->getCondition(); - BI->setCondition(ConstantInt::getTrue(BB->getContext())); - RecursivelyDeleteTriviallyDeadInstructions(OldCond); - return true; // Nuke the branch on constant. - } - // If both branches are conditional and both contain stores to the same // address, remove the stores from the conditionals and create a conditional // merged store at the end. @@ -2753,16 +2829,21 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, return false; int PBIOp, BIOp; - if (PBI->getSuccessor(0) == BI->getSuccessor(0)) - PBIOp = BIOp = 0; - else if (PBI->getSuccessor(0) == BI->getSuccessor(1)) - PBIOp = 0, BIOp = 1; - else if (PBI->getSuccessor(1) == BI->getSuccessor(0)) - PBIOp = 1, BIOp = 0; - else if (PBI->getSuccessor(1) == BI->getSuccessor(1)) - PBIOp = BIOp = 1; - else + if (PBI->getSuccessor(0) == BI->getSuccessor(0)) { + PBIOp = 0; + BIOp = 0; + } else if (PBI->getSuccessor(0) == BI->getSuccessor(1)) { + PBIOp = 0; + BIOp = 1; + } else if (PBI->getSuccessor(1) == BI->getSuccessor(0)) { + PBIOp = 1; + BIOp = 0; + } else if (PBI->getSuccessor(1) == BI->getSuccessor(1)) { + PBIOp = 1; + BIOp = 1; + } else { return false; + } // Check to make sure that the other destination of this branch // isn't BB itself. If so, this is an infinite loop that will @@ -2780,8 +2861,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, BasicBlock *CommonDest = PBI->getSuccessor(PBIOp); unsigned NumPhis = 0; - for (BasicBlock::iterator II = CommonDest->begin(); - isa<PHINode>(II); ++II, ++NumPhis) { + for (BasicBlock::iterator II = CommonDest->begin(); isa<PHINode>(II); + ++II, ++NumPhis) { if (NumPhis > 2) // Disable this xform. return false; @@ -2804,7 +2885,6 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, DEBUG(dbgs() << "FOLDING BRs:" << *PBI->getParent() << "AND: " << *BI->getParent()); - // If OtherDest *is* BB, then BB is a basic block with a single conditional // branch in it, where one edge (OtherDest) goes back to itself but the other // exits. We don't *know* that the program avoids the infinite loop @@ -2815,8 +2895,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, if (OtherDest == BB) { // Insert it at the end of the function, because it's either code, // or it won't matter if it's hot. 
:) - BasicBlock *InfLoopBlock = BasicBlock::Create(BB->getContext(), - "infloop", BB->getParent()); + BasicBlock *InfLoopBlock = + BasicBlock::Create(BB->getContext(), "infloop", BB->getParent()); BranchInst::Create(InfLoopBlock, InfLoopBlock); OtherDest = InfLoopBlock; } @@ -2828,13 +2908,13 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // Make sure we get to CommonDest on True&True directions. Value *PBICond = PBI->getCondition(); - IRBuilder<true, NoFolder> Builder(PBI); + IRBuilder<NoFolder> Builder(PBI); if (PBIOp) - PBICond = Builder.CreateNot(PBICond, PBICond->getName()+".not"); + PBICond = Builder.CreateNot(PBICond, PBICond->getName() + ".not"); Value *BICond = BI->getCondition(); if (BIOp) - BICond = Builder.CreateNot(BICond, BICond->getName()+".not"); + BICond = Builder.CreateNot(BICond, BICond->getName() + ".not"); // Merge the conditions. Value *Cond = Builder.CreateOr(PBICond, BICond, "brmerge"); @@ -2846,15 +2926,15 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // Update branch weight for PBI. uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight; - bool PredHasWeights = ExtractBranchMetadata(PBI, PredTrueWeight, - PredFalseWeight); - bool SuccHasWeights = ExtractBranchMetadata(BI, SuccTrueWeight, - SuccFalseWeight); - if (PredHasWeights && SuccHasWeights) { - uint64_t PredCommon = PBIOp ? PredFalseWeight : PredTrueWeight; - uint64_t PredOther = PBIOp ?PredTrueWeight : PredFalseWeight; - uint64_t SuccCommon = BIOp ? SuccFalseWeight : SuccTrueWeight; - uint64_t SuccOther = BIOp ? SuccTrueWeight : SuccFalseWeight; + uint64_t PredCommon, PredOther, SuccCommon, SuccOther; + bool HasWeights = + extractPredSuccWeights(PBI, BI, PredTrueWeight, PredFalseWeight, + SuccTrueWeight, SuccFalseWeight); + if (HasWeights) { + PredCommon = PBIOp ? PredFalseWeight : PredTrueWeight; + PredOther = PBIOp ? PredTrueWeight : PredFalseWeight; + SuccCommon = BIOp ? SuccFalseWeight : SuccTrueWeight; + SuccOther = BIOp ? SuccTrueWeight : SuccFalseWeight; // The weight to CommonDest should be PredCommon * SuccTotal + // PredOther * SuccCommon. // The weight to OtherDest should be PredOther * SuccOther. @@ -2885,9 +2965,29 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, Value *PBIV = PN->getIncomingValue(PBBIdx); if (BIV != PBIV) { // Insert a select in PBI to pick the right value. - Value *NV = cast<SelectInst> - (Builder.CreateSelect(PBICond, PBIV, BIV, PBIV->getName()+".mux")); + SelectInst *NV = cast<SelectInst>( + Builder.CreateSelect(PBICond, PBIV, BIV, PBIV->getName() + ".mux")); PN->setIncomingValue(PBBIdx, NV); + // Although the select has the same condition as PBI, the original branch + // weights for PBI do not apply to the new select because the select's + // 'logical' edges are incoming edges of the phi that is eliminated, not + // the outgoing edges of PBI. + if (HasWeights) { + uint64_t PredCommon = PBIOp ? PredFalseWeight : PredTrueWeight; + uint64_t PredOther = PBIOp ? PredTrueWeight : PredFalseWeight; + uint64_t SuccCommon = BIOp ? SuccFalseWeight : SuccTrueWeight; + uint64_t SuccOther = BIOp ? SuccTrueWeight : SuccFalseWeight; + // The weight to PredCommonDest should be PredCommon * SuccTotal. + // The weight to PredOtherDest should be PredOther * SuccCommon. 
+ uint64_t NewWeights[2] = {PredCommon * (SuccCommon + SuccOther), + PredOther * SuccCommon}; + + FitWeights(NewWeights); + + NV->setMetadata(LLVMContext::MD_prof, + MDBuilder(BI->getContext()) + .createBranchWeights(NewWeights[0], NewWeights[1])); + } } } @@ -2907,7 +3007,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, static bool SimplifyTerminatorOnSelect(TerminatorInst *OldTerm, Value *Cond, BasicBlock *TrueBB, BasicBlock *FalseBB, uint32_t TrueWeight, - uint32_t FalseWeight){ + uint32_t FalseWeight) { // Remove any superfluous successor edges from the CFG. // First, figure out which successors to preserve. // If TrueBB and FalseBB are equal, only try to preserve one copy of that @@ -2942,8 +3042,8 @@ static bool SimplifyTerminatorOnSelect(TerminatorInst *OldTerm, Value *Cond, BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB); if (TrueWeight != FalseWeight) NewBI->setMetadata(LLVMContext::MD_prof, - MDBuilder(OldTerm->getContext()). - createBranchWeights(TrueWeight, FalseWeight)); + MDBuilder(OldTerm->getContext()) + .createBranchWeights(TrueWeight, FalseWeight)); } } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) { // Neither of the selected blocks were successors, so this @@ -2988,16 +3088,16 @@ static bool SimplifySwitchOnSelect(SwitchInst *SI, SelectInst *Select) { if (HasWeights) { GetBranchWeights(SI, Weights); if (Weights.size() == 1 + SI->getNumCases()) { - TrueWeight = (uint32_t)Weights[SI->findCaseValue(TrueVal). - getSuccessorIndex()]; - FalseWeight = (uint32_t)Weights[SI->findCaseValue(FalseVal). - getSuccessorIndex()]; + TrueWeight = + (uint32_t)Weights[SI->findCaseValue(TrueVal).getSuccessorIndex()]; + FalseWeight = + (uint32_t)Weights[SI->findCaseValue(FalseVal).getSuccessorIndex()]; } } // Perform the actual simplification. - return SimplifyTerminatorOnSelect(SI, Condition, TrueBB, FalseBB, - TrueWeight, FalseWeight); + return SimplifyTerminatorOnSelect(SI, Condition, TrueBB, FalseBB, TrueWeight, + FalseWeight); } // Replaces @@ -3017,8 +3117,8 @@ static bool SimplifyIndirectBrOnSelect(IndirectBrInst *IBI, SelectInst *SI) { BasicBlock *FalseBB = FBA->getBasicBlock(); // Perform the actual simplification. - return SimplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, - 0, 0); + return SimplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB, 0, + 0); } /// This is called when we find an icmp instruction @@ -3046,7 +3146,8 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( // If the block has any PHIs in it or the icmp has multiple uses, it is too // complex. - if (isa<PHINode>(BB->begin()) || !ICI->hasOneUse()) return false; + if (isa<PHINode>(BB->begin()) || !ICI->hasOneUse()) + return false; Value *V = ICI->getOperand(0); ConstantInt *Cst = cast<ConstantInt>(ICI->getOperand(1)); @@ -3055,7 +3156,8 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( // 'V' and this block is the default case for the switch. In this case we can // fold the compared value into the switch to simplify things. BasicBlock *Pred = BB->getSinglePredecessor(); - if (!Pred || !isa<SwitchInst>(Pred->getTerminator())) return false; + if (!Pred || !isa<SwitchInst>(Pred->getTerminator())) + return false; SwitchInst *SI = cast<SwitchInst>(Pred->getTerminator()); if (SI->getCondition() != V) @@ -3104,7 +3206,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( // If the icmp is a SETEQ, then the default dest gets false, the new edge gets // true in the PHI. 
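SimplifySwitchOnSelect, touched above, rewrites a switch whose condition is a two-constant select into a conditional branch to the destinations of the two matching cases, carrying their weights over. A standalone check that the two shapes agree (the enum and function names are invented):

#include <cassert>
#include <iostream>

enum Dest { TrueBB, FalseBB, DefaultBB };

// Original shape: switch over the selected constant.
static Dest switchOnSelect(bool Cond) {
  int V = Cond ? 1 : 5; // select i1 %cond, i32 1, i32 5
  switch (V) {
  case 1:
    return TrueBB;
  case 5:
    return FalseBB;
  default:
    return DefaultBB;
  }
}

// Simplified shape: branch directly on the select's condition to the
// destinations of the matching cases.
static Dest branchOnCondition(bool Cond) { return Cond ? TrueBB : FalseBB; }

int main() {
  for (int C = 0; C <= 1; ++C)
    assert(switchOnSelect(C) == branchOnCondition(C));
  std::cout << "switch-on-select folds to a conditional branch\n";
  return 0;
}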
Constant *DefaultCst = ConstantInt::getTrue(BB->getContext()); - Constant *NewCst = ConstantInt::getFalse(BB->getContext()); + Constant *NewCst = ConstantInt::getFalse(BB->getContext()); if (ICI->getPredicate() == ICmpInst::ICMP_EQ) std::swap(DefaultCst, NewCst); @@ -3116,21 +3218,21 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( // Okay, the switch goes to this block on a default value. Add an edge from // the switch to the merge point on the compared value. - BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "switch.edge", - BB->getParent(), BB); + BasicBlock *NewBB = + BasicBlock::Create(BB->getContext(), "switch.edge", BB->getParent(), BB); SmallVector<uint64_t, 8> Weights; bool HasWeights = HasBranchWeights(SI); if (HasWeights) { GetBranchWeights(SI, Weights); if (Weights.size() == 1 + SI->getNumCases()) { // Split weight for default case to case for "Cst". - Weights[0] = (Weights[0]+1) >> 1; + Weights[0] = (Weights[0] + 1) >> 1; Weights.push_back(Weights[0]); SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); - SI->setMetadata(LLVMContext::MD_prof, - MDBuilder(SI->getContext()). - createBranchWeights(MDWeights)); + SI->setMetadata( + LLVMContext::MD_prof, + MDBuilder(SI->getContext()).createBranchWeights(MDWeights)); } } SI->addCase(Cst, NewBB); @@ -3149,7 +3251,8 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, const DataLayout &DL) { Instruction *Cond = dyn_cast<Instruction>(BI->getCondition()); - if (!Cond) return false; + if (!Cond) + return false; // Change br (X == 0 | X == 1), T, F into a switch instruction. // If this is a bunch of seteq's or'd together, or if it's a bunch of @@ -3158,13 +3261,14 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, // Try to gather values from a chain of and/or to be turned into a switch ConstantComparesGatherer ConstantCompare(Cond, DL); // Unpack the result - SmallVectorImpl<ConstantInt*> &Values = ConstantCompare.Vals; + SmallVectorImpl<ConstantInt *> &Values = ConstantCompare.Vals; Value *CompVal = ConstantCompare.CompValue; unsigned UsedICmps = ConstantCompare.UsedICmps; Value *ExtraCase = ConstantCompare.Extra; // If we didn't have a multiply compared value, fail. - if (!CompVal) return false; + if (!CompVal) + return false; // Avoid turning single icmps into a switch. if (UsedICmps <= 1) @@ -3179,20 +3283,23 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, // If Extra was used, we require at least two switch values to do the // transformation. A switch with one value is just a conditional branch. - if (ExtraCase && Values.size() < 2) return false; + if (ExtraCase && Values.size() < 2) + return false; // TODO: Preserve branch weight metadata, similarly to how // FoldValueComparisonIntoPredecessors preserves it. // Figure out which block is which destination. BasicBlock *DefaultBB = BI->getSuccessor(1); - BasicBlock *EdgeBB = BI->getSuccessor(0); - if (!TrueWhenEqual) std::swap(DefaultBB, EdgeBB); + BasicBlock *EdgeBB = BI->getSuccessor(0); + if (!TrueWhenEqual) + std::swap(DefaultBB, EdgeBB); BasicBlock *BB = BI->getParent(); DEBUG(dbgs() << "Converting 'icmp' chain with " << Values.size() - << " cases into SWITCH. BB is:\n" << *BB); + << " cases into SWITCH. BB is:\n" + << *BB); // If there are any extra values that couldn't be folded into the switch // then we evaluate them with an explicit branch first. 
Split the block @@ -3216,7 +3323,7 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, AddPredecessorToBlock(EdgeBB, BB, NewBB); DEBUG(dbgs() << " ** 'icmp' chain unhandled condition: " << *ExtraCase - << "\nEXTRABB = " << *BB); + << "\nEXTRABB = " << *BB); BB = NewBB; } @@ -3237,11 +3344,10 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, // We added edges from PI to the EdgeBB. As such, if there were any // PHI nodes in EdgeBB, they need entries to be added corresponding to // the number of edges added. - for (BasicBlock::iterator BBI = EdgeBB->begin(); - isa<PHINode>(BBI); ++BBI) { + for (BasicBlock::iterator BBI = EdgeBB->begin(); isa<PHINode>(BBI); ++BBI) { PHINode *PN = cast<PHINode>(BBI); Value *InVal = PN->getIncomingValueForBlock(BB); - for (unsigned i = 0, e = Values.size()-1; i != e; ++i) + for (unsigned i = 0, e = Values.size() - 1; i != e; ++i) PN->addIncoming(InVal, BB); } @@ -3270,7 +3376,7 @@ bool SimplifyCFGOpt::SimplifyCommonResume(ResumeInst *RI) { // Check that there are no other instructions except for debug intrinsics // between the phi of landing pads (RI->getValue()) and resume instruction. BasicBlock::iterator I = cast<Instruction>(RI->getValue())->getIterator(), - E = RI->getIterator(); + E = RI->getIterator(); while (++I != E) if (!isa<DbgInfoIntrinsic>(I)) return false; @@ -3279,8 +3385,8 @@ bool SimplifyCFGOpt::SimplifyCommonResume(ResumeInst *RI) { auto *PhiLPInst = cast<PHINode>(RI->getValue()); // Check incoming blocks to see if any of them are trivial. - for (unsigned Idx = 0, End = PhiLPInst->getNumIncomingValues(); - Idx != End; Idx++) { + for (unsigned Idx = 0, End = PhiLPInst->getNumIncomingValues(); Idx != End; + Idx++) { auto *IncomingBB = PhiLPInst->getIncomingBlock(Idx); auto *IncomingValue = PhiLPInst->getIncomingValue(Idx); @@ -3289,8 +3395,7 @@ bool SimplifyCFGOpt::SimplifyCommonResume(ResumeInst *RI) { if (IncomingBB->getUniqueSuccessor() != BB) continue; - auto *LandingPad = - dyn_cast<LandingPadInst>(IncomingBB->getFirstNonPHI()); + auto *LandingPad = dyn_cast<LandingPadInst>(IncomingBB->getFirstNonPHI()); // Not the landing pad that caused the control to branch here. if (IncomingValue != LandingPad) continue; @@ -3310,7 +3415,8 @@ bool SimplifyCFGOpt::SimplifyCommonResume(ResumeInst *RI) { } // If no trivial unwind blocks, don't do any simplifications. - if (TrivialUnwindBlocks.empty()) return false; + if (TrivialUnwindBlocks.empty()) + return false; // Turn all invokes that unwind here into calls. for (auto *TrivialBB : TrivialUnwindBlocks) { @@ -3346,8 +3452,8 @@ bool SimplifyCFGOpt::SimplifyCommonResume(ResumeInst *RI) { bool SimplifyCFGOpt::SimplifySingleResume(ResumeInst *RI) { BasicBlock *BB = RI->getParent(); LandingPadInst *LPInst = dyn_cast<LandingPadInst>(BB->getFirstNonPHI()); - assert (RI->getValue() == LPInst && - "Resume must unwind the exception that caused control to here"); + assert(RI->getValue() == LPInst && + "Resume must unwind the exception that caused control to here"); // Check that there are no other instructions except for debug intrinsics. BasicBlock::iterator I = LPInst->getIterator(), E = RI->getIterator(); @@ -3363,10 +3469,12 @@ bool SimplifyCFGOpt::SimplifySingleResume(ResumeInst *RI) { // The landingpad is now unreachable. Zap it. 
BB->eraseFromParent(); + if (LoopHeaders) + LoopHeaders->erase(BB); return true; } -bool SimplifyCFGOpt::SimplifyCleanupReturn(CleanupReturnInst *RI) { +static bool removeEmptyCleanup(CleanupReturnInst *RI) { // If this is a trivial cleanup pad that executes no instructions, it can be // eliminated. If the cleanup pad continues to the caller, any predecessor // that is an EH pad will be updated to continue to the caller and any @@ -3381,12 +3489,29 @@ bool SimplifyCFGOpt::SimplifyCleanupReturn(CleanupReturnInst *RI) { // This isn't an empty cleanup. return false; - // Check that there are no other instructions except for debug intrinsics. + // We cannot kill the pad if it has multiple uses. This typically arises + // from unreachable basic blocks. + if (!CPInst->hasOneUse()) + return false; + + // Check that there are no other instructions except for benign intrinsics. BasicBlock::iterator I = CPInst->getIterator(), E = RI->getIterator(); - while (++I != E) - if (!isa<DbgInfoIntrinsic>(I)) + while (++I != E) { + auto *II = dyn_cast<IntrinsicInst>(I); + if (!II) return false; + Intrinsic::ID IntrinsicID = II->getIntrinsicID(); + switch (IntrinsicID) { + case Intrinsic::dbg_declare: + case Intrinsic::dbg_value: + case Intrinsic::lifetime_end: + break; + default: + return false; + } + } + // If the cleanup return we are simplifying unwinds to the caller, this will // set UnwindDest to nullptr. BasicBlock *UnwindDest = RI->getUnwindDest(); @@ -3430,7 +3555,7 @@ bool SimplifyCFGOpt::SimplifyCleanupReturn(CleanupReturnInst *RI) { // removing, we need to merge that PHI node's incoming values into // DestPN. for (unsigned SrcIdx = 0, SrcE = SrcPN->getNumIncomingValues(); - SrcIdx != SrcE; ++SrcIdx) { + SrcIdx != SrcE; ++SrcIdx) { DestPN->addIncoming(SrcPN->getIncomingValue(SrcIdx), SrcPN->getIncomingBlock(SrcIdx)); } @@ -3484,13 +3609,63 @@ bool SimplifyCFGOpt::SimplifyCleanupReturn(CleanupReturnInst *RI) { return true; } +// Try to merge two cleanuppads together. +static bool mergeCleanupPad(CleanupReturnInst *RI) { + // Skip any cleanuprets which unwind to caller, there is nothing to merge + // with. + BasicBlock *UnwindDest = RI->getUnwindDest(); + if (!UnwindDest) + return false; + + // This cleanupret isn't the only predecessor of this cleanuppad, it wouldn't + // be safe to merge without code duplication. + if (UnwindDest->getSinglePredecessor() != RI->getParent()) + return false; + + // Verify that our cleanuppad's unwind destination is another cleanuppad. + auto *SuccessorCleanupPad = dyn_cast<CleanupPadInst>(&UnwindDest->front()); + if (!SuccessorCleanupPad) + return false; + + CleanupPadInst *PredecessorCleanupPad = RI->getCleanupPad(); + // Replace any uses of the successor cleanupad with the predecessor pad + // The only cleanuppad uses should be this cleanupret, it's cleanupret and + // funclet bundle operands. + SuccessorCleanupPad->replaceAllUsesWith(PredecessorCleanupPad); + // Remove the old cleanuppad. + SuccessorCleanupPad->eraseFromParent(); + // Now, we simply replace the cleanupret with a branch to the unwind + // destination. + BranchInst::Create(UnwindDest, RI->getParent()); + RI->eraseFromParent(); + + return true; +} + +bool SimplifyCFGOpt::SimplifyCleanupReturn(CleanupReturnInst *RI) { + // It is possible to transiantly have an undef cleanuppad operand because we + // have deleted some, but not all, dead blocks. + // Eventually, this block will be deleted. 
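removeEmptyCleanup above now tolerates dbg.declare, dbg.value and lifetime.end between the cleanuppad and the cleanupret; anything else keeps the pad alive. A standalone sketch of that whitelist scan with an invented IntrinsicID enum:

#include <iostream>
#include <vector>

enum IntrinsicID { dbg_declare, dbg_value, lifetime_end, other_call };

// A cleanup pad only counts as "empty" if everything between the cleanuppad
// and the cleanupret is one of the benign intrinsics above.
static bool isEffectivelyEmpty(const std::vector<IntrinsicID> &Body) {
  for (IntrinsicID ID : Body) {
    switch (ID) {
    case dbg_declare:
    case dbg_value:
    case lifetime_end:
      break; // ignored, keep scanning
    default:
      return false; // anything else keeps the cleanup pad alive
    }
  }
  return true;
}

int main() {
  std::cout << isEffectivelyEmpty({dbg_value, lifetime_end}) << "\n"; // 1
  std::cout << isEffectivelyEmpty({dbg_value, other_call}) << "\n";   // 0
  return 0;
}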
+ if (isa<UndefValue>(RI->getOperand(0))) + return false; + + if (mergeCleanupPad(RI)) + return true; + + if (removeEmptyCleanup(RI)) + return true; + + return false; +} + bool SimplifyCFGOpt::SimplifyReturn(ReturnInst *RI, IRBuilder<> &Builder) { BasicBlock *BB = RI->getParent(); - if (!BB->getFirstNonPHIOrDbg()->isTerminator()) return false; + if (!BB->getFirstNonPHIOrDbg()->isTerminator()) + return false; // Find predecessors that end with branches. - SmallVector<BasicBlock*, 8> UncondBranchPreds; - SmallVector<BranchInst*, 8> CondBranchPreds; + SmallVector<BasicBlock *, 8> UncondBranchPreds; + SmallVector<BranchInst *, 8> CondBranchPreds; for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { BasicBlock *P = *PI; TerminatorInst *PTI = P->getTerminator(); @@ -3507,14 +3682,17 @@ bool SimplifyCFGOpt::SimplifyReturn(ReturnInst *RI, IRBuilder<> &Builder) { while (!UncondBranchPreds.empty()) { BasicBlock *Pred = UncondBranchPreds.pop_back_val(); DEBUG(dbgs() << "FOLDING: " << *BB - << "INTO UNCOND BRANCH PRED: " << *Pred); + << "INTO UNCOND BRANCH PRED: " << *Pred); (void)FoldReturnIntoUncondBranch(RI, BB, Pred); } // If we eliminated all predecessors of the block, delete the block now. - if (pred_empty(BB)) + if (pred_empty(BB)) { // We know there are no successors, so just nuke the block. BB->eraseFromParent(); + if (LoopHeaders) + LoopHeaders->erase(BB); + } return true; } @@ -3547,7 +3725,8 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { // Do not delete instructions that can have side effects which might cause // the unreachable to not be reachable; specifically, calls and volatile // operations may have this effect. - if (isa<CallInst>(BBI) && !isa<DbgInfoIntrinsic>(BBI)) break; + if (isa<CallInst>(BBI) && !isa<DbgInfoIntrinsic>(BBI)) + break; if (BBI->mayHaveSideEffects()) { if (auto *SI = dyn_cast<StoreInst>(BBI)) { @@ -3589,9 +3768,10 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { // If the unreachable instruction is the first in the block, take a gander // at all of the predecessors of this instruction, and simplify them. - if (&BB->front() != UI) return Changed; + if (&BB->front() != UI) + return Changed; - SmallVector<BasicBlock*, 8> Preds(pred_begin(BB), pred_end(BB)); + SmallVector<BasicBlock *, 8> Preds(pred_begin(BB), pred_end(BB)); for (unsigned i = 0, e = Preds.size(); i != e; ++i) { TerminatorInst *TI = Preds[i]->getTerminator(); IRBuilder<> Builder(TI); @@ -3613,12 +3793,13 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { } } } else if (auto *SI = dyn_cast<SwitchInst>(TI)) { - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) + for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; + ++i) if (i.getCaseSuccessor() == BB) { BB->removePredecessor(SI->getParent()); SI->removeCase(i); - --i; --e; + --i; + --e; Changed = true; } } else if (auto *II = dyn_cast<InvokeInst>(TI)) { @@ -3667,10 +3848,11 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { } // If this block is now dead, remove it. - if (pred_empty(BB) && - BB != &BB->getParent()->getEntryBlock()) { + if (pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) { // We know there are no successors, so just nuke the block. BB->eraseFromParent(); + if (LoopHeaders) + LoopHeaders->erase(BB); return true; } @@ -3699,25 +3881,28 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { // Partition the cases into two sets with different destinations. 
BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr; BasicBlock *DestB = nullptr; - SmallVector <ConstantInt *, 16> CasesA; - SmallVector <ConstantInt *, 16> CasesB; + SmallVector<ConstantInt *, 16> CasesA; + SmallVector<ConstantInt *, 16> CasesB; for (SwitchInst::CaseIt I : SI->cases()) { BasicBlock *Dest = I.getCaseSuccessor(); - if (!DestA) DestA = Dest; + if (!DestA) + DestA = Dest; if (Dest == DestA) { CasesA.push_back(I.getCaseValue()); continue; } - if (!DestB) DestB = Dest; + if (!DestB) + DestB = Dest; if (Dest == DestB) { CasesB.push_back(I.getCaseValue()); continue; } - return false; // More than two destinations. + return false; // More than two destinations. } - assert(DestA && DestB && "Single-destination switch should have been folded."); + assert(DestA && DestB && + "Single-destination switch should have been folded."); assert(DestA != DestB); assert(DestB != SI->getDefaultDest()); assert(!CasesB.empty() && "There must be non-default cases."); @@ -3741,7 +3926,8 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { // Start building the compare and branch. Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back()); - Constant *NumCases = ConstantInt::get(Offset->getType(), ContiguousCases->size()); + Constant *NumCases = + ConstantInt::get(Offset->getType(), ContiguousCases->size()); Value *Sub = SI->getCondition(); if (!Offset->isNullValue()) @@ -3773,21 +3959,24 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { FalseWeight /= 2; } NewBI->setMetadata(LLVMContext::MD_prof, - MDBuilder(SI->getContext()).createBranchWeights( - (uint32_t)TrueWeight, (uint32_t)FalseWeight)); + MDBuilder(SI->getContext()) + .createBranchWeights((uint32_t)TrueWeight, + (uint32_t)FalseWeight)); } } // Prune obsolete incoming values off the successors' PHI nodes. for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) { unsigned PreviousEdges = ContiguousCases->size(); - if (ContiguousDest == SI->getDefaultDest()) ++PreviousEdges; + if (ContiguousDest == SI->getDefaultDest()) + ++PreviousEdges; for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) cast<PHINode>(BBI)->removeIncomingValue(SI->getParent()); } for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) { unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size(); - if (OtherDest == SI->getDefaultDest()) ++PreviousEdges; + if (OtherDest == SI->getDefaultDest()) + ++PreviousEdges; for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I) cast<PHINode>(BBI)->removeIncomingValue(SI->getParent()); } @@ -3807,32 +3996,38 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, APInt KnownZero(Bits, 0), KnownOne(Bits, 0); computeKnownBits(Cond, KnownZero, KnownOne, DL, 0, AC, SI); + // We can also eliminate cases by determining that their values are outside of + // the limited range of the condition based on how many significant (non-sign) + // bits are in the condition value. + unsigned ExtraSignBits = ComputeNumSignBits(Cond, DL, 0, AC, SI) - 1; + unsigned MaxSignificantBitsInCond = Bits - ExtraSignBits; + // Gather dead cases. 
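TurnSwitchRangeIntoICmp above replaces a contiguous run of cases with one subtraction and one unsigned compare against the number of cases. The standalone check below exhausts an 8-bit domain to confirm that (x - lowest) u< NumCases selects exactly the original case values (the concrete range 250..254 is made up for illustration):

#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  // Contiguous cases {250, 251, ..., 254} of an i8 switch; the pass rewrites
  // the dispatch as one subtract plus one unsigned compare.
  const uint8_t Lo = 250;
  const uint8_t NumCases = 5;
  for (unsigned V = 0; V < 256; ++V) {
    uint8_t X = static_cast<uint8_t>(V);
    bool InSwitch = X >= Lo && X < static_cast<unsigned>(Lo) + NumCases;
    bool ViaICmp =
        static_cast<uint8_t>(X - Lo) < NumCases; // (x + Offset) ult NumCases
    assert(InSwitch == ViaICmp);
  }
  std::cout << "contiguous case range folds to a single unsigned compare\n";
  return 0;
}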
- SmallVector<ConstantInt*, 8> DeadCases; - for (SwitchInst::CaseIt I = SI->case_begin(), E = SI->case_end(); I != E; ++I) { - if ((I.getCaseValue()->getValue() & KnownZero) != 0 || - (I.getCaseValue()->getValue() & KnownOne) != KnownOne) { - DeadCases.push_back(I.getCaseValue()); - DEBUG(dbgs() << "SimplifyCFG: switch case '" - << I.getCaseValue() << "' is dead.\n"); + SmallVector<ConstantInt *, 8> DeadCases; + for (auto &Case : SI->cases()) { + APInt CaseVal = Case.getCaseValue()->getValue(); + if ((CaseVal & KnownZero) != 0 || (CaseVal & KnownOne) != KnownOne || + (CaseVal.getMinSignedBits() > MaxSignificantBitsInCond)) { + DeadCases.push_back(Case.getCaseValue()); + DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal << " is dead.\n"); } } - // If we can prove that the cases must cover all possible values, the - // default destination becomes dead and we can remove it. If we know some + // If we can prove that the cases must cover all possible values, the + // default destination becomes dead and we can remove it. If we know some // of the bits in the value, we can use that to more precisely compute the // number of possible unique case values. bool HasDefault = - !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()); - const unsigned NumUnknownBits = Bits - - (KnownZero.Or(KnownOne)).countPopulation(); + !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()); + const unsigned NumUnknownBits = + Bits - (KnownZero.Or(KnownOne)).countPopulation(); assert(NumUnknownBits <= Bits); if (HasDefault && DeadCases.empty() && - NumUnknownBits < 64 /* avoid overflow */ && + NumUnknownBits < 64 /* avoid overflow */ && SI->getNumCases() == (1ULL << NumUnknownBits)) { DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n"); - BasicBlock *NewDefault = SplitBlockPredecessors(SI->getDefaultDest(), - SI->getParent(), ""); + BasicBlock *NewDefault = + SplitBlockPredecessors(SI->getDefaultDest(), SI->getParent(), ""); SI->setDefaultDest(&*NewDefault); SplitBlock(&*NewDefault, &NewDefault->front()); auto *OldTI = NewDefault->getTerminator(); @@ -3849,12 +4044,12 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, } // Remove dead cases from the switch. - for (unsigned I = 0, E = DeadCases.size(); I != E; ++I) { - SwitchInst::CaseIt Case = SI->findCaseValue(DeadCases[I]); + for (ConstantInt *DeadCase : DeadCases) { + SwitchInst::CaseIt Case = SI->findCaseValue(DeadCase); assert(Case != SI->case_default() && "Case was not found. Probably mistake in DeadCases forming."); if (HasWeight) { - std::swap(Weights[Case.getCaseIndex()+1], Weights.back()); + std::swap(Weights[Case.getCaseIndex() + 1], Weights.back()); Weights.pop_back(); } @@ -3865,8 +4060,8 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, if (HasWeight && Weights.size() >= 2) { SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); SI->setMetadata(LLVMContext::MD_prof, - MDBuilder(SI->getParent()->getContext()). - createBranchWeights(MDWeights)); + MDBuilder(SI->getParent()->getContext()) + .createBranchWeights(MDWeights)); } return !DeadCases.empty(); @@ -3878,8 +4073,7 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, /// block and see if the incoming value is equal to CaseValue. If so, return /// the phi node, and set PhiIndex to BB's index in the phi node. 
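FindPHIForConditionForwarding, defined next, is the helper behind ForwardSwitchConditionToPHI: when a case block is empty and the phi in its successor already receives exactly the case value, that phi entry can be replaced by the switch condition itself, letting the empty case block fold away. A hypothetical source-level example of the pattern (not taken from the tree):

int pick(int v) {
  switch (v) {
  case 2: return 2;    // the value produced equals the case value...
  case 7: return 7;    // ...so each phi entry can just use v,
  default: return -1;  // and the empty case blocks disappear.
  }
}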
static PHINode *FindPHIForConditionForwarding(ConstantInt *CaseValue, - BasicBlock *BB, - int *PhiIndex) { + BasicBlock *BB, int *PhiIndex) { if (BB->getFirstNonPHIOrDbg() != BB->getTerminator()) return nullptr; // BB must be empty to be a candidate for simplification. if (!BB->getSinglePredecessor()) @@ -3897,7 +4091,8 @@ static PHINode *FindPHIForConditionForwarding(ConstantInt *CaseValue, assert(Idx >= 0 && "PHI has no entry for predecessor?"); Value *InValue = PHI->getIncomingValue(Idx); - if (InValue != CaseValue) continue; + if (InValue != CaseValue) + continue; *PhiIndex = Idx; return PHI; @@ -3911,17 +4106,19 @@ static PHINode *FindPHIForConditionForwarding(ConstantInt *CaseValue, /// blocks of the switch can be folded away. /// Returns true if a change is made. static bool ForwardSwitchConditionToPHI(SwitchInst *SI) { - typedef DenseMap<PHINode*, SmallVector<int,4> > ForwardingNodesMap; + typedef DenseMap<PHINode *, SmallVector<int, 4>> ForwardingNodesMap; ForwardingNodesMap ForwardingNodes; - for (SwitchInst::CaseIt I = SI->case_begin(), E = SI->case_end(); I != E; ++I) { + for (SwitchInst::CaseIt I = SI->case_begin(), E = SI->case_end(); I != E; + ++I) { ConstantInt *CaseValue = I.getCaseValue(); BasicBlock *CaseDest = I.getCaseSuccessor(); int PhiIndex; - PHINode *PHI = FindPHIForConditionForwarding(CaseValue, CaseDest, - &PhiIndex); - if (!PHI) continue; + PHINode *PHI = + FindPHIForConditionForwarding(CaseValue, CaseDest, &PhiIndex); + if (!PHI) + continue; ForwardingNodes[PHI].push_back(PhiIndex); } @@ -3929,11 +4126,13 @@ static bool ForwardSwitchConditionToPHI(SwitchInst *SI) { bool Changed = false; for (ForwardingNodesMap::iterator I = ForwardingNodes.begin(), - E = ForwardingNodes.end(); I != E; ++I) { + E = ForwardingNodes.end(); + I != E; ++I) { PHINode *Phi = I->first; SmallVectorImpl<int> &Indexes = I->second; - if (Indexes.size() < 2) continue; + if (Indexes.size() < 2) + continue; for (size_t I = 0, E = Indexes.size(); I != E; ++I) Phi->setIncomingValue(Indexes[I], SI->getCondition()); @@ -3954,17 +4153,16 @@ static bool ValidLookupTableConstant(Constant *C) { if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) return CE->isGEPWithNoNotionalOverIndexing(); - return isa<ConstantFP>(C) || - isa<ConstantInt>(C) || - isa<ConstantPointerNull>(C) || - isa<GlobalValue>(C) || - isa<UndefValue>(C); + return isa<ConstantFP>(C) || isa<ConstantInt>(C) || + isa<ConstantPointerNull>(C) || isa<GlobalValue>(C) || + isa<UndefValue>(C); } /// If V is a Constant, return it. Otherwise, try to look up /// its constant value in ConstantPool, returning 0 if it's not there. -static Constant *LookupConstant(Value *V, - const SmallDenseMap<Value*, Constant*>& ConstantPool) { +static Constant * +LookupConstant(Value *V, + const SmallDenseMap<Value *, Constant *> &ConstantPool) { if (Constant *C = dyn_cast<Constant>(V)) return C; return ConstantPool.lookup(V); @@ -4001,7 +4199,7 @@ ConstantFold(Instruction *I, const DataLayout &DL, COps[1], DL); } - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), COps, DL); + return ConstantFoldInstOperands(I, COps, DL); } /// Try to determine the resulting constant values in phi nodes @@ -4018,7 +4216,7 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, // If CaseDest is empty except for some side-effect free instructions through // which we can constant-propagate the CaseVal, continue to its successor. 
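GetCaseResults, whose body continues below, pins the switch condition to the case value and constant-folds any side-effect-free instructions in the case block, so the value each phi receives for that case becomes a known constant; those constants are what the select and lookup-table conversions later consume. Roughly, in source terms (an illustrative example, not from the tree), this is the kind of block it can see through:

int scale(int v) {
  int r;
  switch (v) {
  case 5:
    r = v * 2 + 1;   // with v pinned to 5 this folds to the constant 11
    break;
  default:
    r = 0;
    break;
  }
  return r;
}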
- SmallDenseMap<Value*, Constant*> ConstantPool; + SmallDenseMap<Value *, Constant *> ConstantPool; ConstantPool.insert(std::make_pair(SI->getCondition(), CaseVal)); for (BasicBlock::iterator I = CaseDest->begin(), E = CaseDest->end(); I != E; ++I) { @@ -4068,8 +4266,8 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, if (Idx == -1) continue; - Constant *ConstVal = LookupConstant(PHI->getIncomingValue(Idx), - ConstantPool); + Constant *ConstVal = + LookupConstant(PHI->getIncomingValue(Idx), ConstantPool); if (!ConstVal) return false; @@ -4086,16 +4284,16 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, // Helper function used to add CaseVal to the list of cases that generate // Result. static void MapCaseToResult(ConstantInt *CaseVal, - SwitchCaseResultVectorTy &UniqueResults, - Constant *Result) { + SwitchCaseResultVectorTy &UniqueResults, + Constant *Result) { for (auto &I : UniqueResults) { if (I.first == Result) { I.second.push_back(CaseVal); return; } } - UniqueResults.push_back(std::make_pair(Result, - SmallVector<ConstantInt*, 4>(1, CaseVal))); + UniqueResults.push_back( + std::make_pair(Result, SmallVector<ConstantInt *, 4>(1, CaseVal))); } // Helper function that initializes a map containing @@ -4137,7 +4335,7 @@ static bool InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, DefaultResult = DefaultResults.size() == 1 ? DefaultResults.begin()->second : nullptr; if ((!DefaultResult && - !isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg()))) + !isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg()))) return false; return true; @@ -4154,12 +4352,11 @@ static bool InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, // default: // return 4; // } -static Value * -ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, - Constant *DefaultResult, Value *Condition, - IRBuilder<> &Builder) { +static Value *ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, + Constant *DefaultResult, Value *Condition, + IRBuilder<> &Builder) { assert(ResultVector.size() == 2 && - "We should have exactly two unique results at this point"); + "We should have exactly two unique results at this point"); // If we are selecting between only two cases transform into a simple // select or a two-way select if default is possible. if (ResultVector[0].second.size() == 1 && @@ -4177,8 +4374,8 @@ ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, } Value *const ValueCompare = Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp"); - return Builder.CreateSelect(ValueCompare, ResultVector[0].first, SelectValue, - "switch.select"); + return Builder.CreateSelect(ValueCompare, ResultVector[0].first, + SelectValue, "switch.select"); } return nullptr; @@ -4227,9 +4424,8 @@ static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, assert(PHI != nullptr && "PHI for value select not found"); Builder.SetInsertPoint(SI); - Value *SelectValue = ConvertTwoCaseSwitch( - UniqueResults, - DefaultResult, Cond, Builder); + Value *SelectValue = + ConvertTwoCaseSwitch(UniqueResults, DefaultResult, Cond, Builder); if (SelectValue) { RemoveSwitchAfterSelectConversion(SI, PHI, SelectValue, Builder); return true; @@ -4239,62 +4435,62 @@ static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, } namespace { - /// This class represents a lookup table that can be used to replace a switch. 
- class SwitchLookupTable { - public: - /// Create a lookup table to use as a switch replacement with the contents - /// of Values, using DefaultValue to fill any holes in the table. - SwitchLookupTable( - Module &M, uint64_t TableSize, ConstantInt *Offset, - const SmallVectorImpl<std::pair<ConstantInt *, Constant *>> &Values, - Constant *DefaultValue, const DataLayout &DL); - - /// Build instructions with Builder to retrieve the value at - /// the position given by Index in the lookup table. - Value *BuildLookup(Value *Index, IRBuilder<> &Builder); - - /// Return true if a table with TableSize elements of - /// type ElementType would fit in a target-legal register. - static bool WouldFitInRegister(const DataLayout &DL, uint64_t TableSize, - Type *ElementType); - - private: - // Depending on the contents of the table, it can be represented in - // different ways. - enum { - // For tables where each element contains the same value, we just have to - // store that single value and return it for each lookup. - SingleValueKind, - - // For tables where there is a linear relationship between table index - // and values. We calculate the result with a simple multiplication - // and addition instead of a table lookup. - LinearMapKind, - - // For small tables with integer elements, we can pack them into a bitmap - // that fits into a target-legal register. Values are retrieved by - // shift and mask operations. - BitMapKind, - - // The table is stored as an array of values. Values are retrieved by load - // instructions from the table. - ArrayKind - } Kind; - - // For SingleValueKind, this is the single value. - Constant *SingleValue; - - // For BitMapKind, this is the bitmap. - ConstantInt *BitMap; - IntegerType *BitMapElementTy; - - // For LinearMapKind, these are the constants used to derive the value. - ConstantInt *LinearOffset; - ConstantInt *LinearMultiplier; - - // For ArrayKind, this is the array. - GlobalVariable *Array; - }; +/// This class represents a lookup table that can be used to replace a switch. +class SwitchLookupTable { +public: + /// Create a lookup table to use as a switch replacement with the contents + /// of Values, using DefaultValue to fill any holes in the table. + SwitchLookupTable( + Module &M, uint64_t TableSize, ConstantInt *Offset, + const SmallVectorImpl<std::pair<ConstantInt *, Constant *>> &Values, + Constant *DefaultValue, const DataLayout &DL); + + /// Build instructions with Builder to retrieve the value at + /// the position given by Index in the lookup table. + Value *BuildLookup(Value *Index, IRBuilder<> &Builder); + + /// Return true if a table with TableSize elements of + /// type ElementType would fit in a target-legal register. + static bool WouldFitInRegister(const DataLayout &DL, uint64_t TableSize, + Type *ElementType); + +private: + // Depending on the contents of the table, it can be represented in + // different ways. + enum { + // For tables where each element contains the same value, we just have to + // store that single value and return it for each lookup. + SingleValueKind, + + // For tables where there is a linear relationship between table index + // and values. We calculate the result with a simple multiplication + // and addition instead of a table lookup. + LinearMapKind, + + // For small tables with integer elements, we can pack them into a bitmap + // that fits into a target-legal register. Values are retrieved by + // shift and mask operations. + BitMapKind, + + // The table is stored as an array of values. 
Values are retrieved by load + // instructions from the table. + ArrayKind + } Kind; + + // For SingleValueKind, this is the single value. + Constant *SingleValue; + + // For BitMapKind, this is the bitmap. + ConstantInt *BitMap; + IntegerType *BitMapElementTy; + + // For LinearMapKind, these are the constants used to derive the value. + ConstantInt *LinearOffset; + ConstantInt *LinearMultiplier; + + // For ArrayKind, this is the array. + GlobalVariable *Array; +}; } SwitchLookupTable::SwitchLookupTable( @@ -4312,14 +4508,13 @@ SwitchLookupTable::SwitchLookupTable( Type *ValueType = Values.begin()->second->getType(); // Build up the table contents. - SmallVector<Constant*, 64> TableContents(TableSize); + SmallVector<Constant *, 64> TableContents(TableSize); for (size_t I = 0, E = Values.size(); I != E; ++I) { ConstantInt *CaseVal = Values[I].first; Constant *CaseRes = Values[I].second; assert(CaseRes->getType() == ValueType); - uint64_t Idx = (CaseVal->getValue() - Offset->getValue()) - .getLimitedValue(); + uint64_t Idx = (CaseVal->getValue() - Offset->getValue()).getLimitedValue(); TableContents[Idx] = CaseRes; if (CaseRes != SingleValue) @@ -4407,65 +4602,62 @@ SwitchLookupTable::SwitchLookupTable( ArrayType *ArrayTy = ArrayType::get(ValueType, TableSize); Constant *Initializer = ConstantArray::get(ArrayTy, TableContents); - Array = new GlobalVariable(M, ArrayTy, /*constant=*/ true, - GlobalVariable::PrivateLinkage, - Initializer, + Array = new GlobalVariable(M, ArrayTy, /*constant=*/true, + GlobalVariable::PrivateLinkage, Initializer, "switch.table"); - Array->setUnnamedAddr(true); + Array->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); Kind = ArrayKind; } Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { switch (Kind) { - case SingleValueKind: - return SingleValue; - case LinearMapKind: { - // Derive the result value from the input value. - Value *Result = Builder.CreateIntCast(Index, LinearMultiplier->getType(), - false, "switch.idx.cast"); - if (!LinearMultiplier->isOne()) - Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult"); - if (!LinearOffset->isZero()) - Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset"); - return Result; - } - case BitMapKind: { - // Type of the bitmap (e.g. i59). - IntegerType *MapTy = BitMap->getType(); - - // Cast Index to the same type as the bitmap. - // Note: The Index is <= the number of elements in the table, so - // truncating it to the width of the bitmask is safe. - Value *ShiftAmt = Builder.CreateZExtOrTrunc(Index, MapTy, "switch.cast"); - - // Multiply the shift amount by the element width. - ShiftAmt = Builder.CreateMul(ShiftAmt, - ConstantInt::get(MapTy, BitMapElementTy->getBitWidth()), - "switch.shiftamt"); - - // Shift down. - Value *DownShifted = Builder.CreateLShr(BitMap, ShiftAmt, - "switch.downshift"); - // Mask off. - return Builder.CreateTrunc(DownShifted, BitMapElementTy, - "switch.masked"); - } - case ArrayKind: { - // Make sure the table index will not overflow when treated as signed. 
- IntegerType *IT = cast<IntegerType>(Index->getType()); - uint64_t TableSize = Array->getInitializer()->getType() - ->getArrayNumElements(); - if (TableSize > (1ULL << (IT->getBitWidth() - 1))) - Index = Builder.CreateZExt(Index, - IntegerType::get(IT->getContext(), - IT->getBitWidth() + 1), - "switch.tableidx.zext"); - - Value *GEPIndices[] = { Builder.getInt32(0), Index }; - Value *GEP = Builder.CreateInBoundsGEP(Array->getValueType(), Array, - GEPIndices, "switch.gep"); - return Builder.CreateLoad(GEP, "switch.load"); - } + case SingleValueKind: + return SingleValue; + case LinearMapKind: { + // Derive the result value from the input value. + Value *Result = Builder.CreateIntCast(Index, LinearMultiplier->getType(), + false, "switch.idx.cast"); + if (!LinearMultiplier->isOne()) + Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult"); + if (!LinearOffset->isZero()) + Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset"); + return Result; + } + case BitMapKind: { + // Type of the bitmap (e.g. i59). + IntegerType *MapTy = BitMap->getType(); + + // Cast Index to the same type as the bitmap. + // Note: The Index is <= the number of elements in the table, so + // truncating it to the width of the bitmask is safe. + Value *ShiftAmt = Builder.CreateZExtOrTrunc(Index, MapTy, "switch.cast"); + + // Multiply the shift amount by the element width. + ShiftAmt = Builder.CreateMul( + ShiftAmt, ConstantInt::get(MapTy, BitMapElementTy->getBitWidth()), + "switch.shiftamt"); + + // Shift down. + Value *DownShifted = + Builder.CreateLShr(BitMap, ShiftAmt, "switch.downshift"); + // Mask off. + return Builder.CreateTrunc(DownShifted, BitMapElementTy, "switch.masked"); + } + case ArrayKind: { + // Make sure the table index will not overflow when treated as signed. + IntegerType *IT = cast<IntegerType>(Index->getType()); + uint64_t TableSize = + Array->getInitializer()->getType()->getArrayNumElements(); + if (TableSize > (1ULL << (IT->getBitWidth() - 1))) + Index = Builder.CreateZExt( + Index, IntegerType::get(IT->getContext(), IT->getBitWidth() + 1), + "switch.tableidx.zext"); + + Value *GEPIndices[] = {Builder.getInt32(0), Index}; + Value *GEP = Builder.CreateInBoundsGEP(Array->getValueType(), Array, + GEPIndices, "switch.gep"); + return Builder.CreateLoad(GEP, "switch.load"); + } } llvm_unreachable("Unknown lookup table kind!"); } @@ -4480,7 +4672,7 @@ bool SwitchLookupTable::WouldFitInRegister(const DataLayout &DL, // are <= 15, we could try to narrow the type. // Avoid overflow, fitsInLegalInteger uses unsigned int for the width. - if (TableSize >= UINT_MAX/IT->getBitWidth()) + if (TableSize >= UINT_MAX / IT->getBitWidth()) return false; return DL.fitsInLegalInteger(TableSize * IT->getBitWidth()); } @@ -4503,8 +4695,9 @@ ShouldBuildLookupTable(SwitchInst *SI, uint64_t TableSize, HasIllegalType = HasIllegalType || !TTI.isTypeLegal(Ty); // Saturate this flag to false. - AllTablesFitInRegister = AllTablesFitInRegister && - SwitchLookupTable::WouldFitInRegister(DL, TableSize, Ty); + AllTablesFitInRegister = + AllTablesFitInRegister && + SwitchLookupTable::WouldFitInRegister(DL, TableSize, Ty); // If both flags saturate, we're done. NOTE: This *only* works with // saturating flags, and all flags have to saturate first due to the @@ -4547,9 +4740,10 @@ ShouldBuildLookupTable(SwitchInst *SI, uint64_t TableSize, /// ... /// \endcode /// Jump threading will then eliminate the second if(cond). 
-static void reuseTableCompare(User *PhiUser, BasicBlock *PhiBlock, - BranchInst *RangeCheckBranch, Constant *DefaultValue, - const SmallVectorImpl<std::pair<ConstantInt*, Constant*> >& Values) { +static void reuseTableCompare( + User *PhiUser, BasicBlock *PhiBlock, BranchInst *RangeCheckBranch, + Constant *DefaultValue, + const SmallVectorImpl<std::pair<ConstantInt *, Constant *>> &Values) { ICmpInst *CmpInst = dyn_cast<ICmpInst>(PhiUser); if (!CmpInst) @@ -4578,13 +4772,13 @@ static void reuseTableCompare(User *PhiUser, BasicBlock *PhiBlock, // compare result. for (auto ValuePair : Values) { Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(), - ValuePair.second, CmpOp1, true); + ValuePair.second, CmpOp1, true); if (!CaseConst || CaseConst == DefaultConst) return; assert((CaseConst == TrueConst || CaseConst == FalseConst) && "Expect true or false as compare result."); } - + // Check if the branch instruction dominates the phi node. It's a simple // dominance check, but sufficient for our needs. // Although this check is invariant in the calling loops, it's better to do it @@ -4602,9 +4796,9 @@ static void reuseTableCompare(User *PhiUser, BasicBlock *PhiBlock, ++NumTableCmpReuses; } else { // The compare yields the same result, just inverted. We can replace it. - Value *InvertedTableCmp = BinaryOperator::CreateXor(RangeCmp, - ConstantInt::get(RangeCmp->getType(), 1), "inverted.cmp", - RangeCheckBranch); + Value *InvertedTableCmp = BinaryOperator::CreateXor( + RangeCmp, ConstantInt::get(RangeCmp->getType(), 1), "inverted.cmp", + RangeCheckBranch); CmpInst->replaceAllUsesWith(InvertedTableCmp); ++NumTableCmpReuses; } @@ -4629,7 +4823,8 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // GEP needs a runtime relocation in PIC code. We should just build one big // string and lookup indices into that. - // Ignore switches with less than three cases. Lookup tables will not make them + // Ignore switches with less than three cases. Lookup tables will not make + // them // faster, so we don't analyze them. if (SI->getNumCases() < 3) return false; @@ -4642,11 +4837,11 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, ConstantInt *MaxCaseVal = CI.getCaseValue(); BasicBlock *CommonDest = nullptr; - typedef SmallVector<std::pair<ConstantInt*, Constant*>, 4> ResultListTy; - SmallDenseMap<PHINode*, ResultListTy> ResultLists; - SmallDenseMap<PHINode*, Constant*> DefaultResults; - SmallDenseMap<PHINode*, Type*> ResultTypes; - SmallVector<PHINode*, 4> PHIs; + typedef SmallVector<std::pair<ConstantInt *, Constant *>, 4> ResultListTy; + SmallDenseMap<PHINode *, ResultListTy> ResultLists; + SmallDenseMap<PHINode *, Constant *> DefaultResults; + SmallDenseMap<PHINode *, Type *> ResultTypes; + SmallVector<PHINode *, 4> PHIs; for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) { ConstantInt *CaseVal = CI.getCaseValue(); @@ -4656,7 +4851,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, MaxCaseVal = CaseVal; // Resulting value at phi nodes for this case value. - typedef SmallVector<std::pair<PHINode*, Constant*>, 4> ResultsTy; + typedef SmallVector<std::pair<PHINode *, Constant *>, 4> ResultsTy; ResultsTy Results; if (!GetCaseResults(SI, CaseVal, CI.getCaseSuccessor(), &CommonDest, Results, DL)) @@ -4684,14 +4879,14 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // If the table has holes, we need a constant result for the default case // or a bitmask that fits in a register. 
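When the case values do not fully cover the table, SwitchToLookupTable needs either a constant default result to fill the holes or a bitmask that fits in a register: one bit per table slot, set for real cases. A minimal sketch of that hole check in plain C++ (the mask value and names are made up):

#include <cstdint>

static const uint64_t kHoleMask = 0b1011;  // slots 0, 1 and 3 are real cases; slot 2 is a hole

// Mirrors the emitted sequence: shift the mask right by the table index and
// test the low bit; a clear bit routes control to the default destination.
bool indexHitsRealCase(uint64_t tableIndex) {
  return (kHoleMask >> tableIndex) & 1;
}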
- SmallVector<std::pair<PHINode*, Constant*>, 4> DefaultResultsList; + SmallVector<std::pair<PHINode *, Constant *>, 4> DefaultResultsList; bool HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResultsList, DL); bool NeedMask = (TableHasHoles && !HasDefaultResults); if (NeedMask) { // As an extra penalty for the validity test we require more cases. - if (SI->getNumCases() < 4) // FIXME: Find best threshold value (benchmark). + if (SI->getNumCases() < 4) // FIXME: Find best threshold value (benchmark). return false; if (!DL.fitsInLegalInteger(TableSize)) return false; @@ -4708,15 +4903,13 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // Create the BB that does the lookups. Module &Mod = *CommonDest->getParent()->getParent(); - BasicBlock *LookupBB = BasicBlock::Create(Mod.getContext(), - "switch.lookup", - CommonDest->getParent(), - CommonDest); + BasicBlock *LookupBB = BasicBlock::Create( + Mod.getContext(), "switch.lookup", CommonDest->getParent(), CommonDest); // Compute the table index value. Builder.SetInsertPoint(SI); - Value *TableIndex = Builder.CreateSub(SI->getCondition(), MinCaseVal, - "switch.tableidx"); + Value *TableIndex = + Builder.CreateSub(SI->getCondition(), MinCaseVal, "switch.tableidx"); // Compute the maximum table size representable by the integer type we are // switching upon. @@ -4739,9 +4932,10 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // Note: We call removeProdecessor later since we need to be able to get the // PHI value for the default case in case we're using a bit mask. } else { - Value *Cmp = Builder.CreateICmpULT(TableIndex, ConstantInt::get( - MinCaseVal->getType(), TableSize)); - RangeCheckBranch = Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest()); + Value *Cmp = Builder.CreateICmpULT( + TableIndex, ConstantInt::get(MinCaseVal->getType(), TableSize)); + RangeCheckBranch = + Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest()); } // Populate the BB that does the lookups. @@ -4753,10 +4947,8 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // and we create a new LookupBB. BasicBlock *MaskBB = LookupBB; MaskBB->setName("switch.hole_check"); - LookupBB = BasicBlock::Create(Mod.getContext(), - "switch.lookup", - CommonDest->getParent(), - CommonDest); + LookupBB = BasicBlock::Create(Mod.getContext(), "switch.lookup", + CommonDest->getParent(), CommonDest); // Make the mask's bitwidth at least 8bit and a power-of-2 to avoid // unnecessary illegal types. @@ -4766,8 +4958,8 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // Build bitmask; fill in a 1 bit for every case. const ResultListTy &ResultList = ResultLists[PHIs[0]]; for (size_t I = 0, E = ResultList.size(); I != E; ++I) { - uint64_t Idx = (ResultList[I].first->getValue() - - MinCaseVal->getValue()).getLimitedValue(); + uint64_t Idx = (ResultList[I].first->getValue() - MinCaseVal->getValue()) + .getLimitedValue(); MaskInt |= One << Idx; } ConstantInt *TableMask = ConstantInt::get(Mod.getContext(), MaskInt); @@ -4776,13 +4968,11 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // If this bit is 0 (meaning hole) jump to the default destination, // else continue with table lookup. 
IntegerType *MapTy = TableMask->getType(); - Value *MaskIndex = Builder.CreateZExtOrTrunc(TableIndex, MapTy, - "switch.maskindex"); - Value *Shifted = Builder.CreateLShr(TableMask, MaskIndex, - "switch.shifted"); - Value *LoBit = Builder.CreateTrunc(Shifted, - Type::getInt1Ty(Mod.getContext()), - "switch.lobit"); + Value *MaskIndex = + Builder.CreateZExtOrTrunc(TableIndex, MapTy, "switch.maskindex"); + Value *Shifted = Builder.CreateLShr(TableMask, MaskIndex, "switch.shifted"); + Value *LoBit = Builder.CreateTrunc( + Shifted, Type::getInt1Ty(Mod.getContext()), "switch.lobit"); Builder.CreateCondBr(LoBit, LookupBB, SI->getDefaultDest()); Builder.SetInsertPoint(LookupBB); @@ -4905,7 +5095,8 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) { Dest->removePredecessor(BB); IBI->removeDestination(i); - --i; --e; + --i; + --e; Changed = true; } } @@ -4968,27 +5159,28 @@ static bool TryToMergeLandingPad(LandingPadInst *LPad, BranchInst *BI, LandingPadInst *LPad2 = dyn_cast<LandingPadInst>(I); if (!LPad2 || !LPad2->isIdenticalTo(LPad)) continue; - for (++I; isa<DbgInfoIntrinsic>(I); ++I) {} + for (++I; isa<DbgInfoIntrinsic>(I); ++I) { + } BranchInst *BI2 = dyn_cast<BranchInst>(I); if (!BI2 || !BI2->isIdenticalTo(BI)) continue; - // We've found an identical block. Update our predeccessors to take that + // We've found an identical block. Update our predecessors to take that // path instead and make ourselves dead. SmallSet<BasicBlock *, 16> Preds; Preds.insert(pred_begin(BB), pred_end(BB)); for (BasicBlock *Pred : Preds) { InvokeInst *II = cast<InvokeInst>(Pred->getTerminator()); - assert(II->getNormalDest() != BB && - II->getUnwindDest() == BB && "unexpected successor"); + assert(II->getNormalDest() != BB && II->getUnwindDest() == BB && + "unexpected successor"); II->setUnwindDest(OtherPred); } // The debug info in OtherPred doesn't cover the merged control flow that // used to go through BB. We need to delete it or update it. - for (auto I = OtherPred->begin(), E = OtherPred->end(); - I != E;) { - Instruction &Inst = *I; I++; + for (auto I = OtherPred->begin(), E = OtherPred->end(); I != E;) { + Instruction &Inst = *I; + I++; if (isa<DbgInfoIntrinsic>(Inst)) Inst.eraseFromParent(); } @@ -5007,15 +5199,22 @@ static bool TryToMergeLandingPad(LandingPadInst *LPad, BranchInst *BI, return false; } -bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ +bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, + IRBuilder<> &Builder) { BasicBlock *BB = BI->getParent(); if (SinkCommon && SinkThenElseCodeToEnd(BI)) return true; // If the Terminator is the only non-phi instruction, simplify the block. + // if LoopHeader is provided, check if the block is a loop header + // (This is for early invocations before loop simplify and vectorization + // to keep canonical loop forms for nested loops. + // These blocks can be eliminated when the pass is invoked later + // in the back-end.) BasicBlock::iterator I = BB->getFirstNonPHIOrDbg()->getIterator(); if (I->isTerminator() && BB != &BB->getParent()->getEntryBlock() && + (!LoopHeaders || !LoopHeaders->count(BB)) && TryToSimplifyUncondBranchFromEmptyBlock(BB)) return true; @@ -5034,9 +5233,9 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ // See if we can merge an empty landing pad block with another which is // equivalent. 
if (LandingPadInst *LPad = dyn_cast<LandingPadInst>(I)) { - for (++I; isa<DbgInfoIntrinsic>(I); ++I) {} - if (I->isTerminator() && - TryToMergeLandingPad(LPad, BI, BB)) + for (++I; isa<DbgInfoIntrinsic>(I); ++I) { + } + if (I->isTerminator() && TryToMergeLandingPad(LPad, BI, BB)) return true; } @@ -5081,7 +5280,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (&*I == BI) { if (FoldValueComparisonIntoPredecessors(BI, Builder)) return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; - } else if (&*I == cast<Instruction>(BI->getCondition())){ + } else if (&*I == cast<Instruction>(BI->getCondition())) { ++I; // Ignore dbg intrinsics. while (isa<DbgInfoIntrinsic>(I)) @@ -5095,6 +5294,30 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (SimplifyBranchOnICmpChain(BI, Builder, DL)) return true; + // If this basic block has a single dominating predecessor block and the + // dominating block's condition implies BI's condition, we know the direction + // of the BI branch. + if (BasicBlock *Dom = BB->getSinglePredecessor()) { + auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator()); + if (PBI && PBI->isConditional() && + PBI->getSuccessor(0) != PBI->getSuccessor(1) && + (PBI->getSuccessor(0) == BB || PBI->getSuccessor(1) == BB)) { + bool CondIsFalse = PBI->getSuccessor(1) == BB; + Optional<bool> Implication = isImpliedCondition( + PBI->getCondition(), BI->getCondition(), DL, CondIsFalse); + if (Implication) { + // Turn this into a branch on constant. + auto *OldCond = BI->getCondition(); + ConstantInt *CI = *Implication + ? ConstantInt::getTrue(BB->getContext()) + : ConstantInt::getFalse(BB->getContext()); + BI->setCondition(CI); + RecursivelyDeleteTriviallyDeadInstructions(OldCond); + return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + } + } + } + // If this basic block is ONLY a compare and a branch, and if a predecessor // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. @@ -5149,7 +5372,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (PBI != BI && PBI->isConditional()) if (mergeConditionalStores(PBI, BI)) return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; - + return false; } @@ -5162,7 +5385,7 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I) { if (I->use_empty()) return false; - if (C->isNullValue()) { + if (C->isNullValue() || isa<UndefValue>(C)) { // Only look at the first use, avoid hurting compile time with long uselists User *Use = *I->user_begin(); @@ -5189,7 +5412,12 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I) { // Store to null is undefined. if (StoreInst *SI = dyn_cast<StoreInst>(Use)) if (!SI->isVolatile()) - return SI->getPointerAddressSpace() == 0 && SI->getPointerOperand() == I; + return SI->getPointerAddressSpace() == 0 && + SI->getPointerOperand() == I; + + // A call to null is undefined. + if (auto CS = CallSite(Use)) + return CS.getCalledValue() == I; } return false; } @@ -5210,8 +5438,8 @@ static bool removeUndefIntroducingPredecessor(BasicBlock *BB) { if (BI->isUnconditional()) Builder.CreateUnreachable(); else - Builder.CreateBr(BI->getSuccessor(0) == BB ? BI->getSuccessor(1) : - BI->getSuccessor(0)); + Builder.CreateBr(BI->getSuccessor(0) == BB ? 
BI->getSuccessor(1) + : BI->getSuccessor(0)); BI->eraseFromParent(); return true; } @@ -5229,8 +5457,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { // Remove basic blocks that have no predecessors (except the entry block)... // or that just have themself as a predecessor. These are unreachable. - if ((pred_empty(BB) && - BB != &BB->getParent()->getEntryBlock()) || + if ((pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) || BB->getSinglePredecessor() == BB) { DEBUG(dbgs() << "Removing BB: \n" << *BB); DeleteDeadBlock(BB); @@ -5265,25 +5492,33 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { Builder.SetInsertPoint(BB->getTerminator()); if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) { if (BI->isUnconditional()) { - if (SimplifyUncondBranch(BI, Builder)) return true; + if (SimplifyUncondBranch(BI, Builder)) + return true; } else { - if (SimplifyCondBranch(BI, Builder)) return true; + if (SimplifyCondBranch(BI, Builder)) + return true; } } else if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) { - if (SimplifyReturn(RI, Builder)) return true; + if (SimplifyReturn(RI, Builder)) + return true; } else if (ResumeInst *RI = dyn_cast<ResumeInst>(BB->getTerminator())) { - if (SimplifyResume(RI, Builder)) return true; + if (SimplifyResume(RI, Builder)) + return true; } else if (CleanupReturnInst *RI = - dyn_cast<CleanupReturnInst>(BB->getTerminator())) { - if (SimplifyCleanupReturn(RI)) return true; + dyn_cast<CleanupReturnInst>(BB->getTerminator())) { + if (SimplifyCleanupReturn(RI)) + return true; } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) { - if (SimplifySwitch(SI, Builder)) return true; + if (SimplifySwitch(SI, Builder)) + return true; } else if (UnreachableInst *UI = - dyn_cast<UnreachableInst>(BB->getTerminator())) { - if (SimplifyUnreachable(UI)) return true; + dyn_cast<UnreachableInst>(BB->getTerminator())) { + if (SimplifyUnreachable(UI)) + return true; } else if (IndirectBrInst *IBI = - dyn_cast<IndirectBrInst>(BB->getTerminator())) { - if (SimplifyIndirectBr(IBI)) return true; + dyn_cast<IndirectBrInst>(BB->getTerminator())) { + if (SimplifyIndirectBr(IBI)) + return true; } return Changed; @@ -5295,7 +5530,9 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { /// of the CFG. It returns true if a modification was made. 
/// bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, - unsigned BonusInstThreshold, AssumptionCache *AC) { + unsigned BonusInstThreshold, AssumptionCache *AC, + SmallPtrSetImpl<BasicBlock *> *LoopHeaders) { return SimplifyCFGOpt(TTI, BB->getModule()->getDataLayout(), - BonusInstThreshold, AC).run(BB); + BonusInstThreshold, AC, LoopHeaders) + .run(BB); } diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index ddd8775a8431..6b1d3dc41330 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -25,7 +25,6 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -71,14 +70,12 @@ namespace { bool eliminateIdentitySCEV(Instruction *UseInst, Instruction *IVOperand); + bool eliminateOverflowIntrinsic(CallInst *CI); bool eliminateIVUser(Instruction *UseInst, Instruction *IVOperand); void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand); void eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand, bool IsSigned); bool strengthenOverflowingOperation(BinaryOperator *OBO, Value *IVOperand); - - Instruction *splitOverflowIntrinsic(Instruction *IVUser, - const DominatorTree *DT); }; } @@ -183,9 +180,8 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { DeadInsts.emplace_back(ICmp); DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n'); } else if (isa<PHINode>(IVOperand) && - SE->isLoopInvariantPredicate(Pred, S, X, ICmpLoop, - InvariantPredicate, InvariantLHS, - InvariantRHS)) { + SE->isLoopInvariantPredicate(Pred, S, X, L, InvariantPredicate, + InvariantLHS, InvariantRHS)) { // Rewrite the comparison to a loop invariant comparison if it can be done // cheaply, where cheaply means "we don't need to emit any new @@ -201,9 +197,48 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { NewRHS = ICmp->getOperand(S == InvariantRHS ? IVOperIdx : (1 - IVOperIdx)); - for (Value *Incoming : cast<PHINode>(IVOperand)->incoming_values()) { - if (NewLHS && NewRHS) - break; + auto *PN = cast<PHINode>(IVOperand); + for (unsigned i = 0, e = PN->getNumIncomingValues(); + i != e && (!NewLHS || !NewRHS); + ++i) { + + // If this is a value incoming from the backedge, then it cannot be a loop + // invariant value (since we know that IVOperand is an induction variable). + if (L->contains(PN->getIncomingBlock(i))) + continue; + + // NB! This following assert does not fundamentally have to be true, but + // it is true today given how SCEV analyzes induction variables. + // Specifically, today SCEV will *not* recognize %iv as an induction + // variable in the following case: + // + // define void @f(i32 %k) { + // entry: + // br i1 undef, label %r, label %l + // + // l: + // %k.inc.l = add i32 %k, 1 + // br label %loop + // + // r: + // %k.inc.r = add i32 %k, 1 + // br label %loop + // + // loop: + // %iv = phi i32 [ %k.inc.l, %l ], [ %k.inc.r, %r ], [ %iv.inc, %loop ] + // %iv.inc = add i32 %iv, 1 + // br label %loop + // } + // + // but if it starts to, at some point, then the assertion below will have + // to be changed to a runtime check. 
+ + Value *Incoming = PN->getIncomingValue(i); + +#ifndef NDEBUG + if (auto *I = dyn_cast<Instruction>(Incoming)) + assert(DT->dominates(I, ICmp) && "Should be a unique loop dominating value!"); +#endif const SCEV *IncomingS = SE->getSCEV(Incoming); @@ -280,6 +315,108 @@ void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, DeadInsts.emplace_back(Rem); } +bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) { + auto *F = CI->getCalledFunction(); + if (!F) + return false; + + typedef const SCEV *(ScalarEvolution::*OperationFunctionTy)( + const SCEV *, const SCEV *, SCEV::NoWrapFlags); + typedef const SCEV *(ScalarEvolution::*ExtensionFunctionTy)( + const SCEV *, Type *); + + OperationFunctionTy Operation; + ExtensionFunctionTy Extension; + + Instruction::BinaryOps RawOp; + + // We always have exactly one of nsw or nuw. If NoSignedOverflow is false, we + // have nuw. + bool NoSignedOverflow; + + switch (F->getIntrinsicID()) { + default: + return false; + + case Intrinsic::sadd_with_overflow: + Operation = &ScalarEvolution::getAddExpr; + Extension = &ScalarEvolution::getSignExtendExpr; + RawOp = Instruction::Add; + NoSignedOverflow = true; + break; + + case Intrinsic::uadd_with_overflow: + Operation = &ScalarEvolution::getAddExpr; + Extension = &ScalarEvolution::getZeroExtendExpr; + RawOp = Instruction::Add; + NoSignedOverflow = false; + break; + + case Intrinsic::ssub_with_overflow: + Operation = &ScalarEvolution::getMinusSCEV; + Extension = &ScalarEvolution::getSignExtendExpr; + RawOp = Instruction::Sub; + NoSignedOverflow = true; + break; + + case Intrinsic::usub_with_overflow: + Operation = &ScalarEvolution::getMinusSCEV; + Extension = &ScalarEvolution::getZeroExtendExpr; + RawOp = Instruction::Sub; + NoSignedOverflow = false; + break; + } + + const SCEV *LHS = SE->getSCEV(CI->getArgOperand(0)); + const SCEV *RHS = SE->getSCEV(CI->getArgOperand(1)); + + auto *NarrowTy = cast<IntegerType>(LHS->getType()); + auto *WideTy = + IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2); + + const SCEV *A = + (SE->*Extension)((SE->*Operation)(LHS, RHS, SCEV::FlagAnyWrap), WideTy); + const SCEV *B = + (SE->*Operation)((SE->*Extension)(LHS, WideTy), + (SE->*Extension)(RHS, WideTy), SCEV::FlagAnyWrap); + + if (A != B) + return false; + + // Proved no overflow, nuke the overflow check and, if possible, the overflow + // intrinsic as well. + + BinaryOperator *NewResult = BinaryOperator::Create( + RawOp, CI->getArgOperand(0), CI->getArgOperand(1), "", CI); + + if (NoSignedOverflow) + NewResult->setHasNoSignedWrap(true); + else + NewResult->setHasNoUnsignedWrap(true); + + SmallVector<ExtractValueInst *, 4> ToDelete; + + for (auto *U : CI->users()) { + if (auto *EVI = dyn_cast<ExtractValueInst>(U)) { + if (EVI->getIndices()[0] == 1) + EVI->replaceAllUsesWith(ConstantInt::getFalse(CI->getContext())); + else { + assert(EVI->getIndices()[0] == 0 && "Only two possibilities!"); + EVI->replaceAllUsesWith(NewResult); + } + ToDelete.push_back(EVI); + } + } + + for (auto *EVI : ToDelete) + EVI->eraseFromParent(); + + if (CI->use_empty()) + CI->eraseFromParent(); + + return true; +} + /// Eliminate an operation that consumes a simple IV and has no observable /// side-effect given the range of IV values. IVOperand is guaranteed SCEVable, /// but UseInst may not be. 
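The new eliminateOverflowIntrinsic above proves that a *.with.overflow intrinsic cannot overflow by comparing, in SCEV terms, the extension of the narrow result against the same operation performed on pre-extended operands at twice the width; only if the two expressions are identical does it rewrite the call to a plain nsw/nuw add or sub and fold the overflow flag to false. SCEV does this symbolically over all values the expressions can take; the identity it relies on can be sanity-checked on concrete values with a small sketch like this (plain C++, unsigned 8-bit addition):

#include <cstdint>

// a + b cannot overflow in 8 bits exactly when widening the narrow result
// agrees with performing the addition at the wider width.
bool uadd8CannotOverflow(uint8_t a, uint8_t b) {
  uint16_t wide = uint16_t(a) + uint16_t(b);           // op(zext(a), zext(b))
  uint16_t narrowThenWide = uint16_t(uint8_t(a + b));  // zext(op(a, b))
  return wide == narrowThenWide;
}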
@@ -297,6 +434,10 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, } } + if (auto *CI = dyn_cast<CallInst>(UseInst)) + if (eliminateOverflowIntrinsic(CI)) + return true; + if (eliminateIdentitySCEV(UseInst, IVOperand)) return true; @@ -408,69 +549,6 @@ bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, return Changed; } -/// \brief Split sadd.with.overflow into add + sadd.with.overflow to allow -/// analysis and optimization. -/// -/// \return A new value representing the non-overflowing add if possible, -/// otherwise return the original value. -Instruction *SimplifyIndvar::splitOverflowIntrinsic(Instruction *IVUser, - const DominatorTree *DT) { - IntrinsicInst *II = dyn_cast<IntrinsicInst>(IVUser); - if (!II || II->getIntrinsicID() != Intrinsic::sadd_with_overflow) - return IVUser; - - // Find a branch guarded by the overflow check. - BranchInst *Branch = nullptr; - Instruction *AddVal = nullptr; - for (User *U : II->users()) { - if (ExtractValueInst *ExtractInst = dyn_cast<ExtractValueInst>(U)) { - if (ExtractInst->getNumIndices() != 1) - continue; - if (ExtractInst->getIndices()[0] == 0) - AddVal = ExtractInst; - else if (ExtractInst->getIndices()[0] == 1 && ExtractInst->hasOneUse()) - Branch = dyn_cast<BranchInst>(ExtractInst->user_back()); - } - } - if (!AddVal || !Branch) - return IVUser; - - BasicBlock *ContinueBB = Branch->getSuccessor(1); - if (std::next(pred_begin(ContinueBB)) != pred_end(ContinueBB)) - return IVUser; - - // Check if all users of the add are provably NSW. - bool AllNSW = true; - for (Use &U : AddVal->uses()) { - if (Instruction *UseInst = dyn_cast<Instruction>(U.getUser())) { - BasicBlock *UseBB = UseInst->getParent(); - if (PHINode *PHI = dyn_cast<PHINode>(UseInst)) - UseBB = PHI->getIncomingBlock(U); - if (!DT->dominates(ContinueBB, UseBB)) { - AllNSW = false; - break; - } - } - } - if (!AllNSW) - return IVUser; - - // Go for it... - IRBuilder<> Builder(IVUser); - Instruction *AddInst = dyn_cast<Instruction>( - Builder.CreateNSWAdd(II->getOperand(0), II->getOperand(1))); - - // The caller expects the new add to have the same form as the intrinsic. The - // IV operand position must be the same. - assert((AddInst->getOpcode() == Instruction::Add && - AddInst->getOperand(0) == II->getOperand(0)) && - "Bad add instruction created from overflow intrinsic."); - - AddVal->replaceAllUsesWith(AddInst); - DeadInsts.emplace_back(AddVal); - return AddInst; -} - /// Add all uses of Def to the current IV's worklist. static void pushIVUsers( Instruction *Def, @@ -545,12 +623,6 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { // Bypass back edges to avoid extra work. 
if (UseInst == CurrIV) continue; - if (V && V->shouldSplitOverflowInstrinsics()) { - UseInst = splitOverflowIntrinsic(UseInst, V->getDomTree()); - if (!UseInst) - continue; - } - Instruction *IVOperand = UseOper.second; for (unsigned N = 0; IVOperand; ++N) { assert(N <= Simplified.size() && "runaway iteration"); diff --git a/lib/Transforms/Utils/SimplifyInstructions.cpp b/lib/Transforms/Utils/SimplifyInstructions.cpp index d5377f9a4c1f..df299067094f 100644 --- a/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -14,7 +14,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/SimplifyInstructions.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" @@ -27,12 +27,60 @@ #include "llvm/IR/Type.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Scalar.h" using namespace llvm; #define DEBUG_TYPE "instsimplify" STATISTIC(NumSimplified, "Number of redundant instructions removed"); +static bool runImpl(Function &F, const DominatorTree *DT, const TargetLibraryInfo *TLI, + AssumptionCache *AC) { + const DataLayout &DL = F.getParent()->getDataLayout(); + SmallPtrSet<const Instruction*, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; + bool Changed = false; + + do { + for (BasicBlock *BB : depth_first(&F.getEntryBlock())) + // Here be subtlety: the iterator must be incremented before the loop + // body (not sure why), so a range-for loop won't work here. + for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { + Instruction *I = &*BI++; + // The first time through the loop ToSimplify is empty and we try to + // simplify all instructions. On later iterations ToSimplify is not + // empty and we only bother simplifying instructions that are in it. + if (!ToSimplify->empty() && !ToSimplify->count(I)) + continue; + // Don't waste time simplifying unused instructions. + if (!I->use_empty()) + if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { + // Mark all uses for resimplification next time round the loop. + for (User *U : I->users()) + Next->insert(cast<Instruction>(U)); + I->replaceAllUsesWith(V); + ++NumSimplified; + Changed = true; + } + bool res = RecursivelyDeleteTriviallyDeadInstructions(I, TLI); + if (res) { + // RecursivelyDeleteTriviallyDeadInstruction can remove + // more than one instruction, so simply incrementing the + // iterator does not work. When instructions get deleted + // re-iterate instead. + BI = BB->begin(); BE = BB->end(); + Changed |= res; + } + } + + // Place the list of instructions to simplify on the next loop iteration + // into ToSimplify. + std::swap(ToSimplify, Next); + Next->clear(); + } while (!ToSimplify->empty()); + + return Changed; +} + namespace { struct InstSimplifier : public FunctionPass { static char ID; // Pass identification, replacement for typeid @@ -48,56 +96,17 @@ namespace { /// runOnFunction - Remove instructions that simplify. bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + const DominatorTreeWrapperPass *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); const DominatorTree *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; - const DataLayout &DL = F.getParent()->getDataLayout(); const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - SmallPtrSet<const Instruction*, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; - bool Changed = false; - - do { - for (BasicBlock *BB : depth_first(&F.getEntryBlock())) - // Here be subtlety: the iterator must be incremented before the loop - // body (not sure why), so a range-for loop won't work here. - for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { - Instruction *I = &*BI++; - // The first time through the loop ToSimplify is empty and we try to - // simplify all instructions. On later iterations ToSimplify is not - // empty and we only bother simplifying instructions that are in it. - if (!ToSimplify->empty() && !ToSimplify->count(I)) - continue; - // Don't waste time simplifying unused instructions. - if (!I->use_empty()) - if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { - // Mark all uses for resimplification next time round the loop. - for (User *U : I->users()) - Next->insert(cast<Instruction>(U)); - I->replaceAllUsesWith(V); - ++NumSimplified; - Changed = true; - } - bool res = RecursivelyDeleteTriviallyDeadInstructions(I, TLI); - if (res) { - // RecursivelyDeleteTriviallyDeadInstruction can remove - // more than one instruction, so simply incrementing the - // iterator does not work. When instructions get deleted - // re-iterate instead. - BI = BB->begin(); BE = BB->end(); - Changed |= res; - } - } - - // Place the list of instructions to simplify on the next loop iteration - // into ToSimplify. - std::swap(ToSimplify, Next); - Next->clear(); - } while (!ToSimplify->empty()); - - return Changed; + return runImpl(F, DT, TLI, AC); } }; } @@ -115,3 +124,15 @@ char &llvm::InstructionSimplifierID = InstSimplifier::ID; FunctionPass *llvm::createInstructionSimplifierPass() { return new InstSimplifier(); } + +PreservedAnalyses InstSimplifierPass::run(Function &F, + AnalysisManager<Function> &AM) { + auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + bool Changed = runImpl(F, DT, &TLI, &AC); + if (!Changed) + return PreservedAnalyses::all(); + // FIXME: This should also 'preserve the CFG'. + return PreservedAnalyses::none(); +} diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 2f3c31128cf0..c2986951e48f 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -29,7 +29,6 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" @@ -104,101 +103,11 @@ static bool hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, } } -/// \brief Check whether we can use unsafe floating point math for -/// the function passed as input. -static bool canUseUnsafeFPMath(Function *F) { - - // FIXME: For finer-grain optimization, we need intrinsics to have the same - // fast-math flag decorations that are applied to FP instructions. For now, - // we have to rely on the function-level unsafe-fp-math attribute to do this - // optimization because there's no other way to express that the call can be - // relaxed. 
- if (F->hasFnAttribute("unsafe-fp-math")) { - Attribute Attr = F->getFnAttribute("unsafe-fp-math"); - if (Attr.getValueAsString() == "true") - return true; - } - return false; -} - -/// \brief Returns whether \p F matches the signature expected for the -/// string/memory copying library function \p Func. -/// Acceptable functions are st[rp][n]?cpy, memove, memcpy, and memset. -/// Their fortified (_chk) counterparts are also accepted. -static bool checkStringCopyLibFuncSignature(Function *F, LibFunc::Func Func) { - const DataLayout &DL = F->getParent()->getDataLayout(); - FunctionType *FT = F->getFunctionType(); - LLVMContext &Context = F->getContext(); - Type *PCharTy = Type::getInt8PtrTy(Context); - Type *SizeTTy = DL.getIntPtrType(Context); - unsigned NumParams = FT->getNumParams(); - - // All string libfuncs return the same type as the first parameter. - if (FT->getReturnType() != FT->getParamType(0)) - return false; - - switch (Func) { - default: - llvm_unreachable("Can't check signature for non-string-copy libfunc."); - case LibFunc::stpncpy_chk: - case LibFunc::strncpy_chk: - --NumParams; // fallthrough - case LibFunc::stpncpy: - case LibFunc::strncpy: { - if (NumParams != 3 || FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != PCharTy || !FT->getParamType(2)->isIntegerTy()) - return false; - break; - } - case LibFunc::strcpy_chk: - case LibFunc::stpcpy_chk: - --NumParams; // fallthrough - case LibFunc::stpcpy: - case LibFunc::strcpy: { - if (NumParams != 2 || FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != PCharTy) - return false; - break; - } - case LibFunc::memmove_chk: - case LibFunc::memcpy_chk: - --NumParams; // fallthrough - case LibFunc::memmove: - case LibFunc::memcpy: { - if (NumParams != 3 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || FT->getParamType(2) != SizeTTy) - return false; - break; - } - case LibFunc::memset_chk: - --NumParams; // fallthrough - case LibFunc::memset: { - if (NumParams != 3 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || FT->getParamType(2) != SizeTTy) - return false; - break; - } - } - // If this is a fortified libcall, the last parameter is a size_t. - if (NumParams == FT->getNumParams() - 1) - return FT->getParamType(FT->getNumParams() - 1) == SizeTTy; - return true; -} - //===----------------------------------------------------------------------===// // String and Memory Library Call Optimizations //===----------------------------------------------------------------------===// Value *LibCallSimplifier::optimizeStrCat(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // Verify the "strcat" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2|| - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - FT->getParamType(1) != FT->getReturnType()) - return nullptr; - // Extract some information from the instruction Value *Dst = CI->getArgOperand(0); Value *Src = CI->getArgOperand(1); @@ -220,7 +129,7 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, IRBuilder<> &B) { // We need to find the end of the destination string. That's where the // memory is to be moved to. We just generate a call to strlen. 
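emitStrLenMemCpy implements the strcat/strncat rewrite used when the appended string has a known length: call strlen once to find the end of the destination, then copy Len+1 bytes so the terminating NUL comes along. In source terms the transformation looks roughly like this (buffer contents and names are invented for illustration):

#include <cstring>

void appendSuffix(char *dst) {
  // strcat(dst, "ab");            // original call; the suffix length (2) is known
  char *end = dst + strlen(dst);   // find the end of the destination string
  memcpy(end, "ab", 3);            // copy the two characters plus the trailing '\0'
}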
- Value *DstLen = EmitStrLen(Dst, B, DL, TLI); + Value *DstLen = emitStrLen(Dst, B, DL, TLI); if (!DstLen) return nullptr; @@ -238,15 +147,6 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, } Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // Verify the "strncat" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - FT->getParamType(1) != FT->getReturnType() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; - // Extract some information from the instruction. Value *Dst = CI->getArgOperand(0); Value *Src = CI->getArgOperand(1); @@ -281,13 +181,7 @@ Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - // Verify the "strchr" function prototype. FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - !FT->getParamType(1)->isIntegerTy(32)) - return nullptr; - Value *SrcStr = CI->getArgOperand(0); // If the second operand is non-constant, see if we can compute the length @@ -298,7 +192,7 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32. return nullptr; - return EmitMemChr(SrcStr, CI->getArgOperand(1), // include nul. + return emitMemChr(SrcStr, CI->getArgOperand(1), // include nul. ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len), B, DL, TLI); } @@ -308,7 +202,7 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { StringRef Str; if (!getConstantStringInfo(SrcStr, Str)) { if (CharC->isZero()) // strchr(p, 0) -> p + strlen(p) - return B.CreateGEP(B.getInt8Ty(), SrcStr, EmitStrLen(SrcStr, B, DL, TLI), + return B.CreateGEP(B.getInt8Ty(), SrcStr, emitStrLen(SrcStr, B, DL, TLI), "strchr"); return nullptr; } @@ -326,14 +220,6 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // Verify the "strrchr" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - !FT->getParamType(1)->isIntegerTy(32)) - return nullptr; - Value *SrcStr = CI->getArgOperand(0); ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); @@ -345,7 +231,7 @@ Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilder<> &B) { if (!getConstantStringInfo(SrcStr, Str)) { // strrchr(s, 0) -> strchr(s, 0) if (CharC->isZero()) - return EmitStrChr(SrcStr, '\0', B, TLI); + return emitStrChr(SrcStr, '\0', B, TLI); return nullptr; } @@ -361,14 +247,6 @@ Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // Verify the "strcmp" function prototype. 
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getReturnType()->isIntegerTy(32) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); if (Str1P == Str2P) // strcmp(x,x) -> 0 return ConstantInt::get(CI->getType(), 0); @@ -392,7 +270,7 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { uint64_t Len1 = GetStringLength(Str1P); uint64_t Len2 = GetStringLength(Str2P); if (Len1 && Len2) { - return EmitMemCmp(Str1P, Str2P, + return emitMemCmp(Str1P, Str2P, ConstantInt::get(DL.getIntPtrType(CI->getContext()), std::min(Len1, Len2)), B, DL, TLI); @@ -402,15 +280,6 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // Verify the "strncmp" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || !FT->getReturnType()->isIntegerTy(32) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); if (Str1P == Str2P) // strncmp(x,x,n) -> 0 return ConstantInt::get(CI->getType(), 0); @@ -426,7 +295,7 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { return ConstantInt::get(CI->getType(), 0); if (Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1) - return EmitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); + return emitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); StringRef Str1, Str2; bool HasStr1 = getConstantStringInfo(Str1P, Str1); @@ -450,11 +319,6 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::strcpy)) - return nullptr; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); if (Dst == Src) // strcpy(x,x) -> x return Src; @@ -473,12 +337,9 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::stpcpy)) - return nullptr; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) - Value *StrLen = EmitStrLen(Src, B, DL, TLI); + Value *StrLen = emitStrLen(Src, B, DL, TLI); return StrLen ? 
B.CreateInBoundsGEP(B.getInt8Ty(), Dst, StrLen) : nullptr; } @@ -500,9 +361,6 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::strncpy)) - return nullptr; - Value *Dst = CI->getArgOperand(0); Value *Src = CI->getArgOperand(1); Value *LenOp = CI->getArgOperand(2); @@ -540,18 +398,63 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 1 || FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - Value *Src = CI->getArgOperand(0); // Constant folding: strlen("xyz") -> 3 if (uint64_t Len = GetStringLength(Src)) return ConstantInt::get(CI->getType(), Len - 1); + // If s is a constant pointer pointing to a string literal, we can fold + // strlen(s + x) to strlen(s) - x, when x is known to be in the range + // [0, strlen(s)] or the string has a single null terminator '\0' at the end. + // We only try to simplify strlen when the pointer s points to an array + // of i8. Otherwise, we would need to scale the offset x before doing the + // subtraction. This will make the optimization more complex, and it's not + // very useful because calling strlen for a pointer of other types is + // very uncommon. + if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) { + if (!isGEPBasedOnPointerToString(GEP)) + return nullptr; + + StringRef Str; + if (getConstantStringInfo(GEP->getOperand(0), Str, 0, false)) { + size_t NullTermIdx = Str.find('\0'); + + // If the string does not have '\0', leave it to strlen to compute + // its length. + if (NullTermIdx == StringRef::npos) + return nullptr; + + Value *Offset = GEP->getOperand(2); + unsigned BitWidth = Offset->getType()->getIntegerBitWidth(); + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + computeKnownBits(Offset, KnownZero, KnownOne, DL, 0, nullptr, CI, + nullptr); + KnownZero.flipAllBits(); + size_t ArrSize = + cast<ArrayType>(GEP->getSourceElementType())->getNumElements(); + + // KnownZero's bits are flipped, so zeros in KnownZero now represent + // bits known to be zeros in Offset, and ones in KnowZero represent + // bits unknown in Offset. Therefore, Offset is known to be in range + // [0, NullTermIdx] when the flipped KnownZero is non-negative and + // unsigned-less-than NullTermIdx. + // + // If Offset is not provably in the range [0, NullTermIdx], we can still + // optimize if we can prove that the program has undefined behavior when + // Offset is outside that range. That is the case when GEP->getOperand(0) + // is a pointer to an object whose memory extent is NullTermIdx+1. + if ((KnownZero.isNonNegative() && KnownZero.ule(NullTermIdx)) || + (GEP->isInBounds() && isa<GlobalVariable>(GEP->getOperand(0)) && + NullTermIdx == ArrSize - 1)) + return B.CreateSub(ConstantInt::get(CI->getType(), NullTermIdx), + Offset); + } + + return nullptr; + } + // strlen(x?"foo":"bars") --> x ? 
3 : 4 if (SelectInst *SI = dyn_cast<SelectInst>(Src)) { uint64_t LenTrue = GetStringLength(SI->getTrueValue()); @@ -576,13 +479,6 @@ Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - FT->getReturnType() != FT->getParamType(0)) - return nullptr; - StringRef S1, S2; bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); @@ -604,19 +500,12 @@ Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilder<> &B) { // strpbrk(s, "a") -> strchr(s, 'a') if (HasS2 && S2.size() == 1) - return EmitStrChr(CI->getArgOperand(0), S2[0], B, TLI); + return emitStrChr(CI->getArgOperand(0), S2[0], B, TLI); return nullptr; } Value *LibCallSimplifier::optimizeStrTo(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - if ((FT->getNumParams() != 2 && FT->getNumParams() != 3) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy()) - return nullptr; - Value *EndPtr = CI->getArgOperand(1); if (isa<ConstantPointerNull>(EndPtr)) { // With a null EndPtr, this function won't capture the main argument. @@ -628,13 +517,6 @@ Value *LibCallSimplifier::optimizeStrTo(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeStrSpn(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - StringRef S1, S2; bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); @@ -656,13 +538,6 @@ Value *LibCallSimplifier::optimizeStrSpn(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeStrCSpn(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - StringRef S1, S2; bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); @@ -681,29 +556,22 @@ Value *LibCallSimplifier::optimizeStrCSpn(CallInst *CI, IRBuilder<> &B) { // strcspn(s, "") -> strlen(s) if (HasS2 && S2.empty()) - return EmitStrLen(CI->getArgOperand(0), B, DL, TLI); + return emitStrLen(CI->getArgOperand(0), B, DL, TLI); return nullptr; } Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isPointerTy()) - return nullptr; - // fold strstr(x, x) -> x. 
if (CI->getArgOperand(0) == CI->getArgOperand(1)) return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0 if (isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) { - Value *StrLen = EmitStrLen(CI->getArgOperand(1), B, DL, TLI); + Value *StrLen = emitStrLen(CI->getArgOperand(1), B, DL, TLI); if (!StrLen) return nullptr; - Value *StrNCmp = EmitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1), + Value *StrNCmp = emitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1), StrLen, B, DL, TLI); if (!StrNCmp) return nullptr; @@ -734,28 +602,20 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilder<> &B) { return Constant::getNullValue(CI->getType()); // strstr("abcd", "bc") -> gep((char*)"abcd", 1) - Value *Result = CastToCStr(CI->getArgOperand(0), B); + Value *Result = castToCStr(CI->getArgOperand(0), B); Result = B.CreateConstInBoundsGEP1_64(Result, Offset, "strstr"); return B.CreateBitCast(Result, CI->getType()); } // fold strstr(x, "y") -> strchr(x, 'y'). if (HasStr2 && ToFindStr.size() == 1) { - Value *StrChr = EmitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI); + Value *StrChr = emitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI); return StrChr ? B.CreateBitCast(StrChr, CI->getType()) : nullptr; } return nullptr; } Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy(32) || - !FT->getParamType(2)->isIntegerTy() || - !FT->getReturnType()->isPointerTy()) - return nullptr; - Value *SrcStr = CI->getArgOperand(0); ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); @@ -834,13 +694,6 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy(32)) - return nullptr; - Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); if (LHS == RHS) // memcmp(s,s,x) -> 0 @@ -857,9 +710,9 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS if (Len == 1) { - Value *LHSV = B.CreateZExt(B.CreateLoad(CastToCStr(LHS, B), "lhsc"), + Value *LHSV = B.CreateZExt(B.CreateLoad(castToCStr(LHS, B), "lhsc"), CI->getType(), "lhsv"); - Value *RHSV = B.CreateZExt(B.CreateLoad(CastToCStr(RHS, B), "rhsc"), + Value *RHSV = B.CreateZExt(B.CreateLoad(castToCStr(RHS, B), "rhsc"), CI->getType(), "rhsv"); return B.CreateSub(LHSV, RHSV, "chardiff"); } @@ -909,11 +762,6 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memcpy)) - return nullptr; - // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), 1); @@ -921,23 +769,81 @@ Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, 
IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memmove)) - return nullptr; - // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), 1); return CI->getArgOperand(0); } -Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); +// TODO: Does this belong in BuildLibCalls or should all of those similar +// functions be moved here? +static Value *emitCalloc(Value *Num, Value *Size, const AttributeSet &Attrs, + IRBuilder<> &B, const TargetLibraryInfo &TLI) { + LibFunc::Func Func; + if (!TLI.getLibFunc("calloc", Func) || !TLI.has(Func)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + const DataLayout &DL = M->getDataLayout(); + IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext())); + Value *Calloc = M->getOrInsertFunction("calloc", Attrs, B.getInt8PtrTy(), + PtrType, PtrType, nullptr); + CallInst *CI = B.CreateCall(Calloc, { Num, Size }, "calloc"); + + if (const auto *F = dyn_cast<Function>(Calloc->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); + + return CI; +} - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memset)) +/// Fold memset[_chk](malloc(n), 0, n) --> calloc(1, n). +static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B, + const TargetLibraryInfo &TLI) { + // This has to be a memset of zeros (bzero). + auto *FillValue = dyn_cast<ConstantInt>(Memset->getArgOperand(1)); + if (!FillValue || FillValue->getZExtValue() != 0) return nullptr; + // TODO: We should handle the case where the malloc has more than one use. + // This is necessary to optimize common patterns such as when the result of + // the malloc is checked against null or when a memset intrinsic is used in + // place of a memset library call. + auto *Malloc = dyn_cast<CallInst>(Memset->getArgOperand(0)); + if (!Malloc || !Malloc->hasOneUse()) + return nullptr; + + // Is the inner call really malloc()? + Function *InnerCallee = Malloc->getCalledFunction(); + LibFunc::Func Func; + if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) || + Func != LibFunc::malloc) + return nullptr; + + // The memset must cover the same number of bytes that are malloc'd. + if (Memset->getArgOperand(2) != Malloc->getArgOperand(0)) + return nullptr; + + // Replace the malloc with a calloc. We need the data layout to know what the + // actual size of a 'size_t' parameter is. 
+ B.SetInsertPoint(Malloc->getParent(), ++Malloc->getIterator()); + const DataLayout &DL = Malloc->getModule()->getDataLayout(); + IntegerType *SizeType = DL.getIntPtrType(B.GetInsertBlock()->getContext()); + Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1), + Malloc->getArgOperand(0), Malloc->getAttributes(), + B, TLI); + if (!Calloc) + return nullptr; + + Malloc->replaceAllUsesWith(Calloc); + Malloc->eraseFromParent(); + + return Calloc; +} + +Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { + if (auto *Calloc = foldMallocMemset(CI, B, *TLI)) + return Calloc; + // memset(p, v, n) -> llvm.memset(p, v, n, 1) Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); @@ -970,34 +876,12 @@ static Value *valueHasFloatPrecision(Value *Val) { return nullptr; } -/// Any floating-point library function that we're trying to simplify will have -/// a signature of the form: fptype foo(fptype param1, fptype param2, ...). -/// CheckDoubleTy indicates that 'fptype' must be 'double'. -static bool matchesFPLibFunctionSignature(const Function *F, unsigned NumParams, - bool CheckDoubleTy) { - FunctionType *FT = F->getFunctionType(); - if (FT->getNumParams() != NumParams) - return false; - - // The return type must match what we're looking for. - Type *RetTy = FT->getReturnType(); - if (CheckDoubleTy ? !RetTy->isDoubleTy() : !RetTy->isFloatingPointTy()) - return false; - - // Each parameter must match the return type, and therefore, match every other - // parameter too. - for (const Type *ParamTy : FT->params()) - if (ParamTy != RetTy) - return false; - - return true; -} - /// Shrink double -> float for unary functions like 'floor'. static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, bool CheckRetType) { Function *Callee = CI->getCalledFunction(); - if (!matchesFPLibFunctionSignature(Callee, 1, true)) + // We know this libcall has a valid prototype, but we don't know which. + if (!CI->getType()->isDoubleTy()) return nullptr; if (CheckRetType) { @@ -1026,7 +910,7 @@ static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, V = B.CreateCall(F, V); } else { // The call is a library call rather than an intrinsic. - V = EmitUnaryFloatFnCall(V, Callee->getName(), B, Callee->getAttributes()); + V = emitUnaryFloatFnCall(V, Callee->getName(), B, Callee->getAttributes()); } return B.CreateFPExt(V, B.getDoubleTy()); @@ -1035,7 +919,8 @@ static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, /// Shrink double -> float for binary functions like 'fmin/fmax'. static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - if (!matchesFPLibFunctionSignature(Callee, 2, true)) + // We know this libcall has a valid prototype, but we don't know which. + if (!CI->getType()->isDoubleTy()) return nullptr; // If this is something like 'fmin((double)floatval1, (double)floatval2)', @@ -1054,7 +939,7 @@ static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { // fmin((double)floatval1, (double)floatval2) // -> (double)fminf(floatval1, floatval2) // TODO: Handle intrinsics in the same way as in optimizeUnaryDoubleFP(). 
- Value *V = EmitBinaryFloatFnCall(V1, V2, Callee->getName(), B, + Value *V = emitBinaryFloatFnCall(V1, V2, Callee->getName(), B, Callee->getAttributes()); return B.CreateFPExt(V, B.getDoubleTy()); } @@ -1066,13 +951,6 @@ Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) { if (UnsafeFPShrink && Name == "cos" && hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, true); - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 1 argument of FP type, which matches the - // result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; - // cos(-x) -> cos(x) Value *Op1 = CI->getArgOperand(0); if (BinaryOperator::isFNeg(Op1)) { @@ -1114,14 +992,6 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { if (UnsafeFPShrink && Name == "pow" && hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, true); - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. - if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; - Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { // pow(1.0, x) -> 1.0 @@ -1131,19 +1001,16 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { if (Op1C->isExactlyValue(2.0) && hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp2, LibFunc::exp2f, LibFunc::exp2l)) - return EmitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp2), B, + return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp2), B, Callee->getAttributes()); // pow(10.0, x) -> exp10(x) if (Op1C->isExactlyValue(10.0) && hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, LibFunc::exp10l)) - return EmitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, + return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, Callee->getAttributes()); } - // FIXME: Use instruction-level FMF. - bool UnsafeFPMath = canUseUnsafeFPMath(CI->getParent()->getParent()); - // pow(exp(x), y) -> exp(x * y) // pow(exp2(x), y) -> exp2(x * y) // We enable these only with fast-math. Besides rounding differences, the @@ -1159,7 +1026,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); Value *FMul = B.CreateFMul(OpC->getArgOperand(0), Op2, "mul"); - return EmitUnaryFloatFnCall(FMul, OpCCallee->getName(), B, + return emitUnaryFloatFnCall(FMul, OpCCallee->getName(), B, OpCCallee->getAttributes()); } } @@ -1181,7 +1048,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { if (CI->hasUnsafeAlgebra()) { IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); - return EmitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B, + return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B, Callee->getAttributes()); } @@ -1191,9 +1058,9 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { // TODO: In finite-only mode, this could be just fabs(sqrt(x)). 
Value *Inf = ConstantFP::getInfinity(CI->getType()); Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); - Value *Sqrt = EmitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); + Value *Sqrt = emitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); Value *FAbs = - EmitUnaryFloatFnCall(Sqrt, "fabs", B, Callee->getAttributes()); + emitUnaryFloatFnCall(Sqrt, "fabs", B, Callee->getAttributes()); Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); return Sel; @@ -1207,7 +1074,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip"); // In -ffast-math, generate repeated fmul instead of generating pow(x, n). - if (UnsafeFPMath) { + if (CI->hasUnsafeAlgebra()) { APFloat V = abs(Op2C->getValueAPF()); // We limit to a max of 7 fmul(s). Thus max exponent is 32. // This transformation applies to integer exponents only. @@ -1224,6 +1091,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { // So we first convert V to something which could be converted to double. bool ignored; V.convert(APFloat::IEEEdouble, APFloat::rmTowardZero, &ignored); + + // TODO: Should the new instructions propagate the 'fast' flag of the pow()? Value *FMul = getPow(InnerChain, V.convertToDouble(), B); // For negative exponents simply compute the reciprocal. if (Op2C->isNegative()) @@ -1236,19 +1105,11 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - Function *Caller = CI->getParent()->getParent(); Value *Ret = nullptr; StringRef Name = Callee->getName(); if (UnsafeFPShrink && Name == "exp2" && hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, true); - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 1 argument of FP type, which matches the - // result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; - Value *Op = CI->getArgOperand(0); // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 @@ -1273,11 +1134,11 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { if (!Op->getType()->isFloatTy()) One = ConstantExpr::getFPExtend(One, Op->getType()); - Module *M = Caller->getParent(); - Value *Callee = + Module *M = CI->getModule(); + Value *NewCallee = M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), Op->getType(), B.getInt32Ty(), nullptr); - CallInst *CI = B.CreateCall(Callee, {One, LdExpArg}); + CallInst *CI = B.CreateCall(NewCallee, {One, LdExpArg}); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1294,12 +1155,6 @@ Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) { if (Name == "fabs" && hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, false); - FunctionType *FT = Callee->getFunctionType(); - // Make sure this has 1 argument of FP type which matches the result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; - Value *Op = CI->getArgOperand(0); if (Instruction *I = dyn_cast<Instruction>(Op)) { // Fold fabs(x * x) -> x * x; any squared FP value must already be positive. 
@@ -1311,21 +1166,14 @@ Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); // If we can shrink the call to a float function rather than a double // function, do that first. - Function *Callee = CI->getCalledFunction(); StringRef Name = Callee->getName(); if ((Name == "fmin" || Name == "fmax") && hasFloatVersion(Name)) if (Value *Ret = optimizeBinaryDoubleFP(CI, B)) return Ret; - // Make sure this has 2 arguments of FP type which match the result type. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - !FT->getParamType(0)->isFloatingPointTy()) - return nullptr; - IRBuilder<>::FastMathFlagGuard Guard(B); FastMathFlags FMF; if (CI->hasUnsafeAlgebra()) { @@ -1360,13 +1208,6 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { StringRef Name = Callee->getName(); if (UnsafeFPShrink && hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, true); - FunctionType *FT = Callee->getFunctionType(); - - // Just make sure this has 1 argument of FP type, which matches the - // result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; if (!CI->hasUnsafeAlgebra()) return Ret; @@ -1392,7 +1233,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow)) return B.CreateFMul(OpC->getArgOperand(1), - EmitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B, + emitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B, Callee->getAttributes()), "mul"); // log(exp2(y)) -> y*log(2) @@ -1400,7 +1241,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { TLI->has(Func) && Func == LibFunc::exp2) return B.CreateFMul( OpC->getArgOperand(0), - EmitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0), + emitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0), Callee->getName(), B, Callee->getAttributes()), "logmul"); return Ret; @@ -1408,21 +1249,11 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - Value *Ret = nullptr; if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" || Callee->getIntrinsicID() == Intrinsic::sqrt)) Ret = optimizeUnaryDoubleFP(CI, B, true); - // FIXME: Refactor - this check is repeated all over this file and even in the - // preceding call to shrink double -> float. - - // Make sure this has 1 argument of FP type, which matches the result type. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; - if (!CI->hasUnsafeAlgebra()) return Ret; @@ -1489,13 +1320,6 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) { StringRef Name = Callee->getName(); if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, true); - FunctionType *FT = Callee->getFunctionType(); - - // Just make sure this has 1 argument of FP type, which matches the - // result type. 
- if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; Value *Op1 = CI->getArgOperand(0); auto *OpC = dyn_cast<CallInst>(Op1); @@ -1519,13 +1343,65 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) { return Ret; } -static bool isTrigLibCall(CallInst *CI); +static bool isTrigLibCall(CallInst *CI) { + // We can only hope to do anything useful if we can ignore things like errno + // and floating-point exceptions. + // We already checked the prototype. + return CI->hasFnAttr(Attribute::NoUnwind) && + CI->hasFnAttr(Attribute::ReadNone); +} + static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, bool UseFloat, Value *&Sin, Value *&Cos, - Value *&SinCos); + Value *&SinCos) { + Type *ArgTy = Arg->getType(); + Type *ResTy; + StringRef Name; -Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilder<> &B) { + Triple T(OrigCallee->getParent()->getTargetTriple()); + if (UseFloat) { + Name = "__sincospif_stret"; + + assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now"); + // x86_64 can't use {float, float} since that would be returned in both + // xmm0 and xmm1, which isn't what a real struct would do. + ResTy = T.getArch() == Triple::x86_64 + ? static_cast<Type *>(VectorType::get(ArgTy, 2)) + : static_cast<Type *>(StructType::get(ArgTy, ArgTy, nullptr)); + } else { + Name = "__sincospi_stret"; + ResTy = StructType::get(ArgTy, ArgTy, nullptr); + } + + Module *M = OrigCallee->getParent(); + Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), + ResTy, ArgTy, nullptr); + + if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { + // If the argument is an instruction, it must dominate all uses so put our + // sincos call there. + B.SetInsertPoint(ArgInst->getParent(), ++ArgInst->getIterator()); + } else { + // Otherwise (e.g. for a constant) the beginning of the function is as + // good a place as any. + BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock(); + B.SetInsertPoint(&EntryBB, EntryBB.begin()); + } + + SinCos = B.CreateCall(Callee, Arg, "sincospi"); + + if (SinCos->getType()->isStructTy()) { + Sin = B.CreateExtractValue(SinCos, 0, "sinpi"); + Cos = B.CreateExtractValue(SinCos, 1, "cospi"); + } else { + Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0), + "sinpi"); + Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), + "cospi"); + } +} +Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilder<> &B) { // Make sure the prototype is as expected, otherwise the rest of the // function is probably invalid and likely to abort. if (!isTrigLibCall(CI)) @@ -1541,9 +1417,9 @@ Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilder<> &B) { // Look for all compatible sinpi, cospi and sincospi calls with the same // argument. If there are enough (in some sense) we can make the // substitution. + Function *F = CI->getFunction(); for (User *U : Arg->users()) - classifyArgUse(U, CI->getParent(), IsFloat, SinCalls, CosCalls, - SinCosCalls); + classifyArgUse(U, F, IsFloat, SinCalls, CosCalls, SinCosCalls); // It's only worthwhile if both sinpi and cospi are actually used. 
if (SinCosCalls.empty() && (SinCalls.empty() || CosCalls.empty())) @@ -1559,35 +1435,23 @@ Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilder<> &B) { return nullptr; } -static bool isTrigLibCall(CallInst *CI) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - - // We can only hope to do anything useful if we can ignore things like errno - // and floating-point exceptions. - bool AttributesSafe = - CI->hasFnAttr(Attribute::NoUnwind) && CI->hasFnAttr(Attribute::ReadNone); - - // Other than that we need float(float) or double(double) - return AttributesSafe && FT->getNumParams() == 1 && - FT->getReturnType() == FT->getParamType(0) && - (FT->getParamType(0)->isFloatTy() || - FT->getParamType(0)->isDoubleTy()); -} - -void -LibCallSimplifier::classifyArgUse(Value *Val, BasicBlock *BB, bool IsFloat, - SmallVectorImpl<CallInst *> &SinCalls, - SmallVectorImpl<CallInst *> &CosCalls, - SmallVectorImpl<CallInst *> &SinCosCalls) { +void LibCallSimplifier::classifyArgUse( + Value *Val, Function *F, bool IsFloat, + SmallVectorImpl<CallInst *> &SinCalls, + SmallVectorImpl<CallInst *> &CosCalls, + SmallVectorImpl<CallInst *> &SinCosCalls) { CallInst *CI = dyn_cast<CallInst>(Val); if (!CI) return; + // Don't consider calls in other functions. + if (CI->getFunction() != F) + return; + Function *Callee = CI->getCalledFunction(); LibFunc::Func Func; - if (!Callee || !TLI->getLibFunc(Callee->getName(), Func) || !TLI->has(Func) || + if (!Callee || !TLI->getLibFunc(*Callee, Func) || !TLI->has(Func) || !isTrigLibCall(CI)) return; @@ -1614,69 +1478,12 @@ void LibCallSimplifier::replaceTrigInsts(SmallVectorImpl<CallInst *> &Calls, replaceAllUsesWith(C, Res); } -void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, - bool UseFloat, Value *&Sin, Value *&Cos, Value *&SinCos) { - Type *ArgTy = Arg->getType(); - Type *ResTy; - StringRef Name; - - Triple T(OrigCallee->getParent()->getTargetTriple()); - if (UseFloat) { - Name = "__sincospif_stret"; - - assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now"); - // x86_64 can't use {float, float} since that would be returned in both - // xmm0 and xmm1, which isn't what a real struct would do. - ResTy = T.getArch() == Triple::x86_64 - ? static_cast<Type *>(VectorType::get(ArgTy, 2)) - : static_cast<Type *>(StructType::get(ArgTy, ArgTy, nullptr)); - } else { - Name = "__sincospi_stret"; - ResTy = StructType::get(ArgTy, ArgTy, nullptr); - } - - Module *M = OrigCallee->getParent(); - Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), - ResTy, ArgTy, nullptr); - - if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { - // If the argument is an instruction, it must dominate all uses so put our - // sincos call there. - B.SetInsertPoint(ArgInst->getParent(), ++ArgInst->getIterator()); - } else { - // Otherwise (e.g. for a constant) the beginning of the function is as - // good a place as any. 
- BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock(); - B.SetInsertPoint(&EntryBB, EntryBB.begin()); - } - - SinCos = B.CreateCall(Callee, Arg, "sincospi"); - - if (SinCos->getType()->isStructTy()) { - Sin = B.CreateExtractValue(SinCos, 0, "sinpi"); - Cos = B.CreateExtractValue(SinCos, 1, "cospi"); - } else { - Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0), - "sinpi"); - Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), - "cospi"); - } -} - //===----------------------------------------------------------------------===// // Integer Library Call Optimizations //===----------------------------------------------------------------------===// -static bool checkIntUnaryReturnAndParam(Function *Callee) { - FunctionType *FT = Callee->getFunctionType(); - return FT->getNumParams() == 1 && FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0)->isIntegerTy(); -} - Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - if (!checkIntUnaryReturnAndParam(Callee)) - return nullptr; Value *Op = CI->getArgOperand(0); // Constant fold. @@ -1700,13 +1507,6 @@ Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - // We require integer(integer) where the types agree. - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - FT->getParamType(0) != FT->getReturnType()) - return nullptr; - // abs(x) -> x >s -1 ? x : -x Value *Op = CI->getArgOperand(0); Value *Pos = @@ -1716,9 +1516,6 @@ Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeIsDigit(CallInst *CI, IRBuilder<> &B) { - if (!checkIntUnaryReturnAndParam(CI->getCalledFunction())) - return nullptr; - // isdigit(c) -> (c-'0') <u 10 Value *Op = CI->getArgOperand(0); Op = B.CreateSub(Op, B.getInt32('0'), "isdigittmp"); @@ -1727,9 +1524,6 @@ Value *LibCallSimplifier::optimizeIsDigit(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeIsAscii(CallInst *CI, IRBuilder<> &B) { - if (!checkIntUnaryReturnAndParam(CI->getCalledFunction())) - return nullptr; - // isascii(c) -> c <u 128 Value *Op = CI->getArgOperand(0); Op = B.CreateICmpULT(Op, B.getInt32(128), "isascii"); @@ -1737,9 +1531,6 @@ Value *LibCallSimplifier::optimizeIsAscii(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeToAscii(CallInst *CI, IRBuilder<> &B) { - if (!checkIntUnaryReturnAndParam(CI->getCalledFunction())) - return nullptr; - // toascii(c) -> c & 0x7f return B.CreateAnd(CI->getArgOperand(0), ConstantInt::get(CI->getType(), 0x7F)); @@ -1753,6 +1544,7 @@ static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg); Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilder<> &B, int StreamArg) { + Function *Callee = CI->getCalledFunction(); // Error reporting calls should be cold, mark them as such. // This applies even to non-builtin calls: it is only a hint and applies to // functions that the frontend might not understand as builtins. @@ -1761,8 +1553,6 @@ Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilder<> &B, // Improving Static Branch Prediction in a Compiler // Brian L. Deitrich, Ben-Chung Cheng, Wen-mei W. Hwu // Proceedings of PACT'98, Oct. 
1998, IEEE - Function *Callee = CI->getCalledFunction(); - if (!CI->hasFnAttr(Attribute::Cold) && isReportingError(Callee, CI, StreamArg)) { CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); @@ -1808,12 +1598,18 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilder<> &B) { if (!CI->use_empty()) return nullptr; - // printf("x") -> putchar('x'), even for '%'. - if (FormatStr.size() == 1) { - Value *Res = EmitPutChar(B.getInt32(FormatStr[0]), B, TLI); - if (CI->use_empty() || !Res) - return Res; - return B.CreateIntCast(Res, CI->getType(), true); + // printf("x") -> putchar('x'), even for "%" and "%%". + if (FormatStr.size() == 1 || FormatStr == "%%") + return emitPutChar(B.getInt32(FormatStr[0]), B, TLI); + + // printf("%s", "a") --> putchar('a') + if (FormatStr == "%s" && CI->getNumArgOperands() > 1) { + StringRef ChrStr; + if (!getConstantStringInfo(CI->getOperand(1), ChrStr)) + return nullptr; + if (ChrStr.size() != 1) + return nullptr; + return emitPutChar(B.getInt32(ChrStr[0]), B, TLI); } // printf("foo\n") --> puts("foo") @@ -1823,40 +1619,26 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilder<> &B) { // pass to be run after this pass, to merge duplicate strings. FormatStr = FormatStr.drop_back(); Value *GV = B.CreateGlobalString(FormatStr, "str"); - Value *NewCI = EmitPutS(GV, B, TLI); - return (CI->use_empty() || !NewCI) - ? NewCI - : ConstantInt::get(CI->getType(), FormatStr.size() + 1); + return emitPutS(GV, B, TLI); } // Optimize specific format strings. // printf("%c", chr) --> putchar(chr) if (FormatStr == "%c" && CI->getNumArgOperands() > 1 && - CI->getArgOperand(1)->getType()->isIntegerTy()) { - Value *Res = EmitPutChar(CI->getArgOperand(1), B, TLI); - - if (CI->use_empty() || !Res) - return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + CI->getArgOperand(1)->getType()->isIntegerTy()) + return emitPutChar(CI->getArgOperand(1), B, TLI); // printf("%s\n", str) --> puts(str) if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 && - CI->getArgOperand(1)->getType()->isPointerTy()) { - return EmitPutS(CI->getArgOperand(1), B, TLI); - } + CI->getArgOperand(1)->getType()->isPointerTy()) + return emitPutS(CI->getArgOperand(1), B, TLI); return nullptr; } Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - // Require one fixed pointer argument and an integer/void result. 
FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || - !(FT->getReturnType()->isIntegerTy() || FT->getReturnType()->isVoidTy())) - return nullptr; - if (Value *V = optimizePrintFString(CI, B)) { return V; } @@ -1909,7 +1691,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); - Value *Ptr = CastToCStr(CI->getArgOperand(0), B); + Value *Ptr = castToCStr(CI->getArgOperand(0), B); B.CreateStore(V, Ptr); Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); B.CreateStore(B.getInt8(0), Ptr); @@ -1922,7 +1704,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { if (!CI->getArgOperand(2)->getType()->isPointerTy()) return nullptr; - Value *Len = EmitStrLen(CI->getArgOperand(2), B, DL, TLI); + Value *Len = emitStrLen(CI->getArgOperand(2), B, DL, TLI); if (!Len) return nullptr; Value *IncLen = @@ -1937,13 +1719,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - // Require two fixed pointer arguments and an integer result. FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - if (Value *V = optimizeSPrintFString(CI, B)) { return V; } @@ -1982,7 +1758,7 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) { if (FormatStr[i] == '%') // Could handle %% -> % if we cared. return nullptr; // We found a format specifier. - return EmitFWrite( + return emitFWrite( CI->getArgOperand(1), ConstantInt::get(DL.getIntPtrType(CI->getContext()), FormatStr.size()), CI->getArgOperand(0), B, DL, TLI); @@ -1999,27 +1775,21 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) { // fprintf(F, "%c", chr) --> fputc(chr, F) if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; - return EmitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, TLI); + return emitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, TLI); } if (FormatStr[1] == 's') { // fprintf(F, "%s", str) --> fputs(str, F) if (!CI->getArgOperand(2)->getType()->isPointerTy()) return nullptr; - return EmitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, TLI); + return emitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, TLI); } return nullptr; } Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - // Require two fixed paramters as pointers and integer result. FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - if (Value *V = optimizeFPrintFString(CI, B)) { return V; } @@ -2041,16 +1811,6 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilder<> &B) { optimizeErrorReporting(CI, B, 3); - Function *Callee = CI->getCalledFunction(); - // Require a pointer, an integer, an integer, a pointer, returning integer. 
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 4 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - !FT->getParamType(2)->isIntegerTy() || - !FT->getParamType(3)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - // Get the element size and count. ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); @@ -2065,8 +1825,8 @@ Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilder<> &B) { // If this is writing one byte, turn it into fputc. // This optimisation is only valid, if the return value is unused. if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F) - Value *Char = B.CreateLoad(CastToCStr(CI->getArgOperand(0), B), "char"); - Value *NewCI = EmitFPutC(Char, CI->getArgOperand(3), B, TLI); + Value *Char = B.CreateLoad(castToCStr(CI->getArgOperand(0), B), "char"); + Value *NewCI = emitFPutC(Char, CI->getArgOperand(3), B, TLI); return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr; } @@ -2076,12 +1836,13 @@ Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilder<> &B) { optimizeErrorReporting(CI, B, 1); - Function *Callee = CI->getCalledFunction(); + // Don't rewrite fputs to fwrite when optimising for size because fwrite + // requires more arguments and thus extra MOVs are required. + if (CI->getParent()->getParent()->optForSize()) + return nullptr; - // Require two pointers. Also, we can't optimize if return value is used. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || !CI->use_empty()) + // We can't optimize if return value is used. + if (!CI->use_empty()) return nullptr; // fputs(s,F) --> fwrite(s,1,strlen(s),F) @@ -2090,20 +1851,13 @@ Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilder<> &B) { return nullptr; // Known to have no uses (see above). - return EmitFWrite( + return emitFWrite( CI->getArgOperand(0), ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len - 1), CI->getArgOperand(1), B, DL, TLI); } Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // Require one fixed pointer argument and an integer/void result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || - !(FT->getReturnType()->isIntegerTy() || FT->getReturnType()->isVoidTy())) - return nullptr; - // Check for a constant string. StringRef Str; if (!getConstantStringInfo(CI->getArgOperand(0), Str)) @@ -2111,7 +1865,7 @@ Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { if (Str.empty() && CI->use_empty()) { // puts("") -> putchar('\n') - Value *Res = EmitPutChar(B.getInt32('\n'), B, TLI); + Value *Res = emitPutChar(B.getInt32('\n'), B, TLI); if (CI->use_empty() || !Res) return Res; return B.CreateIntCast(Res, CI->getType(), true); @@ -2133,10 +1887,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, IRBuilder<> &Builder) { LibFunc::Func Func; Function *Callee = CI->getCalledFunction(); - StringRef FuncName = Callee->getName(); - // Check for string/memory library functions. - if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func)) { + if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { // Make sure we never change the calling convention. 
assert((ignoreCallingConv(Func) || CI->getCallingConv() == llvm::CallingConv::C) && @@ -2208,10 +1960,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { IRBuilder<> Builder(CI, /*FPMathTag=*/nullptr, OpBundles); bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C; - // Command-line parameter overrides function attribute. + // Command-line parameter overrides instruction attribute. if (EnableUnsafeFPShrink.getNumOccurrences() > 0) UnsafeFPShrink = EnableUnsafeFPShrink; - else if (canUseUnsafeFPMath(Callee)) + else if (isa<FPMathOperator>(CI) && CI->hasUnsafeAlgebra()) UnsafeFPShrink = true; // First, check for intrinsics. @@ -2229,6 +1981,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return optimizeLog(CI, Builder); case Intrinsic::sqrt: return optimizeSqrt(CI, Builder); + // TODO: Use foldMallocMemset() with memset intrinsic. default: return nullptr; } @@ -2253,7 +2006,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { } // Then check for known library functions. - if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func)) { + if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { // We never change the calling convention. if (!ignoreCallingConv(Func) && !isCallingConvC) return nullptr; @@ -2457,11 +2210,6 @@ bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memcpy_chk)) - return nullptr; - if (isFortifiedCallFoldable(CI, 3, 2, false)) { B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), 1); @@ -2472,11 +2220,6 @@ Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memmove_chk)) - return nullptr; - if (isFortifiedCallFoldable(CI, 3, 2, false)) { B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), 1); @@ -2487,10 +2230,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - - if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memset_chk)) - return nullptr; + // TODO: Try foldMallocMemset() here. if (isFortifiedCallFoldable(CI, 3, 2, false)) { Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); @@ -2506,16 +2246,12 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, Function *Callee = CI->getCalledFunction(); StringRef Name = Callee->getName(); const DataLayout &DL = CI->getModule()->getDataLayout(); - - if (!checkStringCopyLibFuncSignature(Callee, Func)) - return nullptr; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1), *ObjSize = CI->getArgOperand(2); // __stpcpy_chk(x,x,...) -> x+strlen(x) if (Func == LibFunc::stpcpy_chk && !OnlyLowerUnknownSize && Dst == Src) { - Value *StrLen = EmitStrLen(Src, B, DL, TLI); + Value *StrLen = emitStrLen(Src, B, DL, TLI); return StrLen ? B.CreateInBoundsGEP(B.getInt8Ty(), Dst, StrLen) : nullptr; } @@ -2525,7 +2261,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, // TODO: It might be nice to get a maximum length out of the possible // string lengths for varying. 
if (isFortifiedCallFoldable(CI, 2, 1, true)) - return EmitStrCpy(Dst, Src, B, TLI, Name.substr(2, 6)); + return emitStrCpy(Dst, Src, B, TLI, Name.substr(2, 6)); if (OnlyLowerUnknownSize) return nullptr; @@ -2537,7 +2273,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, Type *SizeTTy = DL.getIntPtrType(CI->getContext()); Value *LenV = ConstantInt::get(SizeTTy, Len); - Value *Ret = EmitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI); + Value *Ret = emitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI); // If the function was an __stpcpy_chk, and we were able to fold it into // a __memcpy_chk, we still need to return the correct end pointer. if (Ret && Func == LibFunc::stpcpy_chk) @@ -2550,11 +2286,8 @@ Value *FortifiedLibCallSimplifier::optimizeStrpNCpyChk(CallInst *CI, LibFunc::Func Func) { Function *Callee = CI->getCalledFunction(); StringRef Name = Callee->getName(); - - if (!checkStringCopyLibFuncSignature(Callee, Func)) - return nullptr; if (isFortifiedCallFoldable(CI, 3, 2, false)) { - Value *Ret = EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), + Value *Ret = emitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, TLI, Name.substr(2, 7)); return Ret; } @@ -2577,15 +2310,15 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) { LibFunc::Func Func; Function *Callee = CI->getCalledFunction(); - StringRef FuncName = Callee->getName(); SmallVector<OperandBundleDef, 2> OpBundles; CI->getOperandBundlesAsDefs(OpBundles); IRBuilder<> Builder(CI, /*FPMathTag=*/nullptr, OpBundles); bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C; - // First, check that this is a known library functions. - if (!TLI->getLibFunc(FuncName, Func)) + // First, check that this is a known library functions and that the prototype + // is correct. + if (!TLI->getLibFunc(*Callee, Func)) return nullptr; // We never change the calling convention. 
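Note (illustration only, not part of the vendor patch): among the SimplifyLibCalls.cpp hunks above is a new foldMallocMemset() helper that rewrites memset(malloc(n), 0, n) into calloc(1, n). The sketch below shows the kind of caller pattern that fold targets; the function name zeroed_buffer is made up for this example, and whether the rewrite actually fires depends on the optimization pipeline in use.

#include <cstdlib>
#include <cstring>

// Allocate n bytes and zero them. The null check is deliberately omitted so
// that the malloc result's only use is the memset call, which is the
// condition the fold requires (see the TODO in foldMallocMemset() above).
void *zeroed_buffer(size_t n) {
  // Before simplification: a malloc followed by a same-length memset of 0.
  // With this change the pair may be emitted as a single calloc(1, n).
  return std::memset(std::malloc(n), 0, n);
}

The single-use restriction is also why variants that first test the pointer for null are not yet handled; the TODO in foldMallocMemset() above calls that case out explicitly.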
diff --git a/lib/Transforms/Utils/SplitModule.cpp b/lib/Transforms/Utils/SplitModule.cpp index ad6b782caf8b..e9a368f4faa4 100644 --- a/lib/Transforms/Utils/SplitModule.cpp +++ b/lib/Transforms/Utils/SplitModule.cpp @@ -13,19 +13,184 @@ // //===----------------------------------------------------------------------===// +#define DEBUG_TYPE "split-module" + #include "llvm/Transforms/Utils/SplitModule.h" +#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalObject.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/MD5.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Cloning.h" +#include <queue> using namespace llvm; +namespace { +typedef EquivalenceClasses<const GlobalValue *> ClusterMapType; +typedef DenseMap<const Comdat *, const GlobalValue *> ComdatMembersType; +typedef DenseMap<const GlobalValue *, unsigned> ClusterIDMapType; +} + +static void addNonConstUser(ClusterMapType &GVtoClusterMap, + const GlobalValue *GV, const User *U) { + assert((!isa<Constant>(U) || isa<GlobalValue>(U)) && "Bad user"); + + if (const Instruction *I = dyn_cast<Instruction>(U)) { + const GlobalValue *F = I->getParent()->getParent(); + GVtoClusterMap.unionSets(GV, F); + } else if (isa<GlobalIndirectSymbol>(U) || isa<Function>(U) || + isa<GlobalVariable>(U)) { + GVtoClusterMap.unionSets(GV, cast<GlobalValue>(U)); + } else { + llvm_unreachable("Underimplemented use case"); + } +} + +// Adds all GlobalValue users of V to the same cluster as GV. +static void addAllGlobalValueUsers(ClusterMapType &GVtoClusterMap, + const GlobalValue *GV, const Value *V) { + for (auto *U : V->users()) { + SmallVector<const User *, 4> Worklist; + Worklist.push_back(U); + while (!Worklist.empty()) { + const User *UU = Worklist.pop_back_val(); + // For each constant that is not a GV (a pure const) recurse. + if (isa<Constant>(UU) && !isa<GlobalValue>(UU)) { + Worklist.append(UU->user_begin(), UU->user_end()); + continue; + } + addNonConstUser(GVtoClusterMap, GV, UU); + } + } +} + +// Find partitions for module in the way that no locals need to be +// globalized. +// Try to balance pack those partitions into N files since this roughly equals +// thread balancing for the backend codegen step. +static void findPartitions(Module *M, ClusterIDMapType &ClusterIDMap, + unsigned N) { + // At this point module should have the proper mix of globals and locals. + // As we attempt to partition this module, we must not change any + // locals to globals. + DEBUG(dbgs() << "Partition module with (" << M->size() << ")functions\n"); + ClusterMapType GVtoClusterMap; + ComdatMembersType ComdatMembers; + + auto recordGVSet = [&GVtoClusterMap, &ComdatMembers](GlobalValue &GV) { + if (GV.isDeclaration()) + return; + + if (!GV.hasName()) + GV.setName("__llvmsplit_unnamed"); + + // Comdat groups must not be partitioned. For comdat groups that contain + // locals, record all their members here so we can keep them together. + // Comdat groups that only contain external globals are already handled by + // the MD5-based partitioning. 
+ if (const Comdat *C = GV.getComdat()) { + auto &Member = ComdatMembers[C]; + if (Member) + GVtoClusterMap.unionSets(Member, &GV); + else + Member = &GV; + } + + // For aliases we should not separate them from their aliasees regardless + // of linkage. + if (auto *GIS = dyn_cast<GlobalIndirectSymbol>(&GV)) { + if (const GlobalObject *Base = GIS->getBaseObject()) + GVtoClusterMap.unionSets(&GV, Base); + } + + if (const Function *F = dyn_cast<Function>(&GV)) { + for (const BasicBlock &BB : *F) { + BlockAddress *BA = BlockAddress::lookup(&BB); + if (!BA || !BA->isConstantUsed()) + continue; + addAllGlobalValueUsers(GVtoClusterMap, F, BA); + } + } + + if (GV.hasLocalLinkage()) + addAllGlobalValueUsers(GVtoClusterMap, &GV, &GV); + }; + + std::for_each(M->begin(), M->end(), recordGVSet); + std::for_each(M->global_begin(), M->global_end(), recordGVSet); + std::for_each(M->alias_begin(), M->alias_end(), recordGVSet); + + // Assigned all GVs to merged clusters while balancing number of objects in + // each. + auto CompareClusters = [](const std::pair<unsigned, unsigned> &a, + const std::pair<unsigned, unsigned> &b) { + if (a.second || b.second) + return a.second > b.second; + else + return a.first > b.first; + }; + + std::priority_queue<std::pair<unsigned, unsigned>, + std::vector<std::pair<unsigned, unsigned>>, + decltype(CompareClusters)> + BalancinQueue(CompareClusters); + // Pre-populate priority queue with N slot blanks. + for (unsigned i = 0; i < N; ++i) + BalancinQueue.push(std::make_pair(i, 0)); + + typedef std::pair<unsigned, ClusterMapType::iterator> SortType; + SmallVector<SortType, 64> Sets; + SmallPtrSet<const GlobalValue *, 32> Visited; + + // To guarantee determinism, we have to sort SCC according to size. + // When size is the same, use leader's name. + for (ClusterMapType::iterator I = GVtoClusterMap.begin(), + E = GVtoClusterMap.end(); I != E; ++I) + if (I->isLeader()) + Sets.push_back( + std::make_pair(std::distance(GVtoClusterMap.member_begin(I), + GVtoClusterMap.member_end()), I)); + + std::sort(Sets.begin(), Sets.end(), [](const SortType &a, const SortType &b) { + if (a.first == b.first) + return a.second->getData()->getName() > b.second->getData()->getName(); + else + return a.first > b.first; + }); + + for (auto &I : Sets) { + unsigned CurrentClusterID = BalancinQueue.top().first; + unsigned CurrentClusterSize = BalancinQueue.top().second; + BalancinQueue.pop(); + + DEBUG(dbgs() << "Root[" << CurrentClusterID << "] cluster_size(" << I.first + << ") ----> " << I.second->getData()->getName() << "\n"); + + for (ClusterMapType::member_iterator MI = + GVtoClusterMap.findLeader(I.second); + MI != GVtoClusterMap.member_end(); ++MI) { + if (!Visited.insert(*MI).second) + continue; + DEBUG(dbgs() << "----> " << (*MI)->getName() + << ((*MI)->hasLocalLinkage() ? " l " : " e ") << "\n"); + Visited.insert(*MI); + ClusterIDMap[*MI] = CurrentClusterID; + CurrentClusterSize++; + } + // Add this set size to the number of entries in this cluster. + BalancinQueue.push(std::make_pair(CurrentClusterID, CurrentClusterSize)); + } +} + static void externalize(GlobalValue *GV) { if (GV->hasLocalLinkage()) { GV->setLinkage(GlobalValue::ExternalLinkage); @@ -40,8 +205,8 @@ static void externalize(GlobalValue *GV) { // Returns whether GV should be in partition (0-based) I of N. 
static bool isInPartition(const GlobalValue *GV, unsigned I, unsigned N) { - if (auto GA = dyn_cast<GlobalAlias>(GV)) - if (const GlobalObject *Base = GA->getBaseObject()) + if (auto *GIS = dyn_cast<GlobalIndirectSymbol>(GV)) + if (const GlobalObject *Base = GIS->getBaseObject()) GV = Base; StringRef Name; @@ -62,21 +227,34 @@ static bool isInPartition(const GlobalValue *GV, unsigned I, unsigned N) { void llvm::SplitModule( std::unique_ptr<Module> M, unsigned N, - std::function<void(std::unique_ptr<Module> MPart)> ModuleCallback) { - for (Function &F : *M) - externalize(&F); - for (GlobalVariable &GV : M->globals()) - externalize(&GV); - for (GlobalAlias &GA : M->aliases()) - externalize(&GA); + function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback, + bool PreserveLocals) { + if (!PreserveLocals) { + for (Function &F : *M) + externalize(&F); + for (GlobalVariable &GV : M->globals()) + externalize(&GV); + for (GlobalAlias &GA : M->aliases()) + externalize(&GA); + for (GlobalIFunc &GIF : M->ifuncs()) + externalize(&GIF); + } + + // This performs splitting without a need for externalization, which might not + // always be possible. + ClusterIDMapType ClusterIDMap; + findPartitions(M.get(), ClusterIDMap, N); // FIXME: We should be able to reuse M as the last partition instead of // cloning it. - for (unsigned I = 0; I != N; ++I) { + for (unsigned I = 0; I < N; ++I) { ValueToValueMapTy VMap; std::unique_ptr<Module> MPart( - CloneModule(M.get(), VMap, [=](const GlobalValue *GV) { - return isInPartition(GV, I, N); + CloneModule(M.get(), VMap, [&](const GlobalValue *GV) { + if (ClusterIDMap.count(GV)) + return (ClusterIDMap[GV] == I); + else + return isInPartition(GV, I, N); })); if (I != 0) MPart->setModuleInlineAsm(""); diff --git a/lib/Transforms/Utils/SymbolRewriter.cpp b/lib/Transforms/Utils/SymbolRewriter.cpp index 1d1f602b041d..7523ca527b68 100644 --- a/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/lib/Transforms/Utils/SymbolRewriter.cpp @@ -58,7 +58,6 @@ //===----------------------------------------------------------------------===// #define DEBUG_TYPE "symbol-rewriter" -#include "llvm/CodeGen/Passes.h" #include "llvm/Pass.h" #include "llvm/ADT/SmallString.h" #include "llvm/IR/LegacyPassManager.h" diff --git a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp index 6b1d1dae5f01..9385f825523c 100644 --- a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp +++ b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp @@ -66,9 +66,7 @@ bool UnifyFunctionExitNodes::runOnFunction(Function &F) { "UnifiedUnreachableBlock", &F); new UnreachableInst(F.getContext(), UnreachableBlock); - for (std::vector<BasicBlock*>::iterator I = UnreachableBlocks.begin(), - E = UnreachableBlocks.end(); I != E; ++I) { - BasicBlock *BB = *I; + for (BasicBlock *BB : UnreachableBlocks) { BB->getInstList().pop_back(); // Remove the unreachable inst. BranchInst::Create(UnreachableBlock, BB); } @@ -104,10 +102,7 @@ bool UnifyFunctionExitNodes::runOnFunction(Function &F) { // Loop over all of the blocks, replacing the return instruction with an // unconditional branch. // - for (std::vector<BasicBlock*>::iterator I = ReturningBlocks.begin(), - E = ReturningBlocks.end(); I != E; ++I) { - BasicBlock *BB = *I; - + for (BasicBlock *BB : ReturningBlocks) { // Add an incoming element to the PHI node for every return instruction that // is merging into this new block... 
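Globals that do not land in any recorded cluster still have to be assigned deterministically, which is what the MD5-based isInPartition path (partly elided in the hunk above) is for: every clone of the module hashes the symbol name and keeps only the globals that fall in its own slot. A rough stand-in for that idea, using std::hash instead of MD5 and hypothetical symbol names:

    #include <cstdio>
    #include <functional>
    #include <string>

    // Hypothetical stand-in: pick one of N partitions purely from a hash of the
    // symbol name, so every module clone makes the same decision independently.
    static unsigned partitionForName(const std::string &Name, unsigned N) {
      return static_cast<unsigned>(std::hash<std::string>{}(Name) % N);
    }

    int main() {
      const unsigned N = 4;
      const char *Syms[] = {"main", "helper", "__llvmsplit_unnamed", "lookup_table"};
      for (const char *Sym : Syms)
        std::printf("%-22s -> partition %u of %u\n", Sym, partitionForName(Sym, N), N);
      return 0;
    }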
if (PN) diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp index ed4f45c6a615..8f85f19efe38 100644 --- a/lib/Transforms/Utils/Utils.cpp +++ b/lib/Transforms/Utils/Utils.cpp @@ -21,17 +21,20 @@ using namespace llvm; /// initializeTransformUtils - Initialize all passes in the TransformUtils /// library. void llvm::initializeTransformUtils(PassRegistry &Registry) { - initializeAddDiscriminatorsPass(Registry); + initializeAddDiscriminatorsLegacyPassPass(Registry); initializeBreakCriticalEdgesPass(Registry); initializeInstNamerPass(Registry); - initializeLCSSAPass(Registry); + initializeLCSSAWrapperPassPass(Registry); initializeLoopSimplifyPass(Registry); initializeLowerInvokePass(Registry); initializeLowerSwitchPass(Registry); - initializePromotePassPass(Registry); + initializeNameAnonFunctionPass(Registry); + initializePromoteLegacyPassPass(Registry); initializeUnifyFunctionExitNodesPass(Registry); initializeInstSimplifierPass(Registry); initializeMetaRenamerPass(Registry); + initializeMemorySSAWrapperPassPass(Registry); + initializeMemorySSAPrinterLegacyPassPass(Registry); } /// LLVMInitializeTransformUtils - C binding for initializeTransformUtilsPasses. diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index f47ddb9f064f..2eade8cbe8ef 100644 --- a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -13,9 +13,13 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/ValueMapper.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" @@ -25,25 +29,326 @@ using namespace llvm; // Out of line method to get vtable etc for class. void ValueMapTypeRemapper::anchor() {} void ValueMaterializer::anchor() {} -void ValueMaterializer::materializeInitFor(GlobalValue *New, GlobalValue *Old) { -} -Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, - ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer) { - ValueToValueMapTy::iterator I = VM.find(V); - +namespace { + +/// A basic block used in a BlockAddress whose function body is not yet +/// materialized. +struct DelayedBasicBlock { + BasicBlock *OldBB; + std::unique_ptr<BasicBlock> TempBB; + + // Explicit move for MSVC. 
+ DelayedBasicBlock(DelayedBasicBlock &&X) + : OldBB(std::move(X.OldBB)), TempBB(std::move(X.TempBB)) {} + DelayedBasicBlock &operator=(DelayedBasicBlock &&X) { + OldBB = std::move(X.OldBB); + TempBB = std::move(X.TempBB); + return *this; + } + + DelayedBasicBlock(const BlockAddress &Old) + : OldBB(Old.getBasicBlock()), + TempBB(BasicBlock::Create(Old.getContext())) {} +}; + +struct WorklistEntry { + enum EntryKind { + MapGlobalInit, + MapAppendingVar, + MapGlobalAliasee, + RemapFunction + }; + struct GVInitTy { + GlobalVariable *GV; + Constant *Init; + }; + struct AppendingGVTy { + GlobalVariable *GV; + Constant *InitPrefix; + }; + struct GlobalAliaseeTy { + GlobalAlias *GA; + Constant *Aliasee; + }; + + unsigned Kind : 2; + unsigned MCID : 29; + unsigned AppendingGVIsOldCtorDtor : 1; + unsigned AppendingGVNumNewMembers; + union { + GVInitTy GVInit; + AppendingGVTy AppendingGV; + GlobalAliaseeTy GlobalAliasee; + Function *RemapF; + } Data; +}; + +struct MappingContext { + ValueToValueMapTy *VM; + ValueMaterializer *Materializer = nullptr; + + /// Construct a MappingContext with a value map and materializer. + explicit MappingContext(ValueToValueMapTy &VM, + ValueMaterializer *Materializer = nullptr) + : VM(&VM), Materializer(Materializer) {} +}; + +class MDNodeMapper; +class Mapper { + friend class MDNodeMapper; + +#ifndef NDEBUG + DenseSet<GlobalValue *> AlreadyScheduled; +#endif + + RemapFlags Flags; + ValueMapTypeRemapper *TypeMapper; + unsigned CurrentMCID = 0; + SmallVector<MappingContext, 2> MCs; + SmallVector<WorklistEntry, 4> Worklist; + SmallVector<DelayedBasicBlock, 1> DelayedBBs; + SmallVector<Constant *, 16> AppendingInits; + +public: + Mapper(ValueToValueMapTy &VM, RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, ValueMaterializer *Materializer) + : Flags(Flags), TypeMapper(TypeMapper), + MCs(1, MappingContext(VM, Materializer)) {} + + /// ValueMapper should explicitly call \a flush() before destruction. + ~Mapper() { assert(!hasWorkToDo() && "Expected to be flushed"); } + + bool hasWorkToDo() const { return !Worklist.empty(); } + + unsigned + registerAlternateMappingContext(ValueToValueMapTy &VM, + ValueMaterializer *Materializer = nullptr) { + MCs.push_back(MappingContext(VM, Materializer)); + return MCs.size() - 1; + } + + void addFlags(RemapFlags Flags); + + Value *mapValue(const Value *V); + void remapInstruction(Instruction *I); + void remapFunction(Function &F); + + Constant *mapConstant(const Constant *C) { + return cast_or_null<Constant>(mapValue(C)); + } + + /// Map metadata. + /// + /// Find the mapping for MD. Guarantees that the return will be resolved + /// (not an MDNode, or MDNode::isResolved() returns true). 
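WorklistEntry above is a compact tagged union: a 2-bit Kind field records which member of the anonymous union is live, and the MCID payload is packed into the remaining bits of the same word. A stripped-down, hypothetical example of the same layout trick (names and payloads are invented for illustration):

    #include <cstdio>

    // Stripped-down tagged union in the style of WorklistEntry: a small Kind
    // bit-field selects the live union member, extra state shares the same word.
    struct Task {
      enum EntryKind : unsigned { PrintInt, PrintText };

      unsigned Kind : 2;     // which Data member is active
      unsigned Context : 30; // extra payload packed into the same word
      union {
        int IntVal;
        const char *Text;
      } Data;
    };

    static void run(const Task &T) {
      switch (T.Kind) {
      case Task::PrintInt:
        std::printf("[ctx %u] int: %d\n", T.Context, T.Data.IntVal);
        break;
      case Task::PrintText:
        std::printf("[ctx %u] text: %s\n", T.Context, T.Data.Text);
        break;
      }
    }

    int main() {
      Task A{};
      A.Kind = Task::PrintInt;
      A.Context = 1;
      A.Data.IntVal = 42;

      Task B{};
      B.Kind = Task::PrintText;
      B.Context = 2;
      B.Data.Text = "hello";

      run(A);
      run(B);
      return 0;
    }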
+ Metadata *mapMetadata(const Metadata *MD); + + void scheduleMapGlobalInitializer(GlobalVariable &GV, Constant &Init, + unsigned MCID); + void scheduleMapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, + bool IsOldCtorDtor, + ArrayRef<Constant *> NewMembers, + unsigned MCID); + void scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee, + unsigned MCID); + void scheduleRemapFunction(Function &F, unsigned MCID); + + void flush(); + +private: + void mapGlobalInitializer(GlobalVariable &GV, Constant &Init); + void mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, + bool IsOldCtorDtor, + ArrayRef<Constant *> NewMembers); + void mapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee); + void remapFunction(Function &F, ValueToValueMapTy &VM); + + ValueToValueMapTy &getVM() { return *MCs[CurrentMCID].VM; } + ValueMaterializer *getMaterializer() { return MCs[CurrentMCID].Materializer; } + + Value *mapBlockAddress(const BlockAddress &BA); + + /// Map metadata that doesn't require visiting operands. + Optional<Metadata *> mapSimpleMetadata(const Metadata *MD); + + Metadata *mapToMetadata(const Metadata *Key, Metadata *Val); + Metadata *mapToSelf(const Metadata *MD); +}; + +class MDNodeMapper { + Mapper &M; + + /// Data about a node in \a UniquedGraph. + struct Data { + bool HasChanged = false; + unsigned ID = ~0u; + TempMDNode Placeholder; + + Data() {} + Data(Data &&X) + : HasChanged(std::move(X.HasChanged)), ID(std::move(X.ID)), + Placeholder(std::move(X.Placeholder)) {} + Data &operator=(Data &&X) { + HasChanged = std::move(X.HasChanged); + ID = std::move(X.ID); + Placeholder = std::move(X.Placeholder); + return *this; + } + }; + + /// A graph of uniqued nodes. + struct UniquedGraph { + SmallDenseMap<const Metadata *, Data, 32> Info; // Node properties. + SmallVector<MDNode *, 16> POT; // Post-order traversal. + + /// Propagate changed operands through the post-order traversal. + /// + /// Iteratively update \a Data::HasChanged for each node based on \a + /// Data::HasChanged of its operands, until fixed point. + void propagateChanges(); + + /// Get a forward reference to a node to use as an operand. + Metadata &getFwdReference(MDNode &Op); + }; + + /// Worklist of distinct nodes whose operands need to be remapped. + SmallVector<MDNode *, 16> DistinctWorklist; + + // Storage for a UniquedGraph. + SmallDenseMap<const Metadata *, Data, 32> InfoStorage; + SmallVector<MDNode *, 16> POTStorage; + +public: + MDNodeMapper(Mapper &M) : M(M) {} + + /// Map a metadata node (and its transitive operands). + /// + /// Map all the (unmapped) nodes in the subgraph under \c N. The iterative + /// algorithm handles distinct nodes and uniqued node subgraphs using + /// different strategies. + /// + /// Distinct nodes are immediately mapped and added to \a DistinctWorklist + /// using \a mapDistinctNode(). Their mapping can always be computed + /// immediately without visiting operands, even if their operands change. + /// + /// The mapping for uniqued nodes depends on whether their operands change. + /// \a mapTopLevelUniquedNode() traverses the transitive uniqued subgraph of + /// a node to calculate uniqued node mappings in bulk. Distinct leafs are + /// added to \a DistinctWorklist with \a mapDistinctNode(). + /// + /// After mapping \c N itself, this function remaps the operands of the + /// distinct nodes in \a DistinctWorklist until the entire subgraph under \c + /// N has been mapped. 
+ Metadata *map(const MDNode &N); + +private: + /// Map a top-level uniqued node and the uniqued subgraph underneath it. + /// + /// This builds up a post-order traversal of the (unmapped) uniqued subgraph + /// underneath \c FirstN and calculates the nodes' mapping. Each node uses + /// the identity mapping (\a Mapper::mapToSelf()) as long as all of its + /// operands uses the identity mapping. + /// + /// The algorithm works as follows: + /// + /// 1. \a createPOT(): traverse the uniqued subgraph under \c FirstN and + /// save the post-order traversal in the given \a UniquedGraph, tracking + /// nodes' operands change. + /// + /// 2. \a UniquedGraph::propagateChanges(): propagate changed operands + /// through the \a UniquedGraph until fixed point, following the rule + /// that if a node changes, any node that references must also change. + /// + /// 3. \a mapNodesInPOT(): map the uniqued nodes, creating new uniqued nodes + /// (referencing new operands) where necessary. + Metadata *mapTopLevelUniquedNode(const MDNode &FirstN); + + /// Try to map the operand of an \a MDNode. + /// + /// If \c Op is already mapped, return the mapping. If it's not an \a + /// MDNode, compute and return the mapping. If it's a distinct \a MDNode, + /// return the result of \a mapDistinctNode(). + /// + /// \return None if \c Op is an unmapped uniqued \a MDNode. + /// \post getMappedOp(Op) only returns None if this returns None. + Optional<Metadata *> tryToMapOperand(const Metadata *Op); + + /// Map a distinct node. + /// + /// Return the mapping for the distinct node \c N, saving the result in \a + /// DistinctWorklist for later remapping. + /// + /// \pre \c N is not yet mapped. + /// \pre \c N.isDistinct(). + MDNode *mapDistinctNode(const MDNode &N); + + /// Get a previously mapped node. + Optional<Metadata *> getMappedOp(const Metadata *Op) const; + + /// Create a post-order traversal of an unmapped uniqued node subgraph. + /// + /// This traverses the metadata graph deeply enough to map \c FirstN. It + /// uses \a tryToMapOperand() (via \a Mapper::mapSimplifiedNode()), so any + /// metadata that has already been mapped will not be part of the POT. + /// + /// Each node that has a changed operand from outside the graph (e.g., a + /// distinct node, an already-mapped uniqued node, or \a ConstantAsMetadata) + /// is marked with \a Data::HasChanged. + /// + /// \return \c true if any nodes in \c G have \a Data::HasChanged. + /// \post \c G.POT is a post-order traversal ending with \c FirstN. + /// \post \a Data::hasChanged in \c G.Info indicates whether any node needs + /// to change because of operands outside the graph. + bool createPOT(UniquedGraph &G, const MDNode &FirstN); + + /// Visit the operands of a uniqued node in the POT. + /// + /// Visit the operands in the range from \c I to \c E, returning the first + /// uniqued node we find that isn't yet in \c G. \c I is always advanced to + /// where to continue the loop through the operands. + /// + /// This sets \c HasChanged if any of the visited operands change. + MDNode *visitOperands(UniquedGraph &G, MDNode::op_iterator &I, + MDNode::op_iterator E, bool &HasChanged); + + /// Map all the nodes in the given uniqued graph. + /// + /// This visits all the nodes in \c G in post-order, using the identity + /// mapping or creating a new node depending on \a Data::HasChanged. + /// + /// \pre \a getMappedOp() returns None for nodes in \c G, but not for any of + /// their operands outside of \c G. 
+ /// \pre \a Data::HasChanged is true for a node in \c G iff any of its + /// operands have changed. + /// \post \a getMappedOp() returns the mapped node for every node in \c G. + void mapNodesInPOT(UniquedGraph &G); + + /// Remap a node's operands using the given functor. + /// + /// Iterate through the operands of \c N and update them in place using \c + /// mapOperand. + /// + /// \pre N.isDistinct() or N.isTemporary(). + template <class OperandMapper> + void remapOperands(MDNode &N, OperandMapper mapOperand); +}; + +} // end namespace + +Value *Mapper::mapValue(const Value *V) { + ValueToValueMapTy::iterator I = getVM().find(V); + // If the value already exists in the map, use it. - if (I != VM.end() && I->second) return I->second; - + if (I != getVM().end()) { + assert(I->second && "Unexpected null mapping"); + return I->second; + } + // If we have a materializer and it can materialize a value, use that. - if (Materializer) { - if (Value *NewV = - Materializer->materializeDeclFor(const_cast<Value *>(V))) { - VM[V] = NewV; - if (auto *NewGV = dyn_cast<GlobalValue>(NewV)) - Materializer->materializeInitFor( - NewGV, const_cast<GlobalValue *>(cast<GlobalValue>(V))); + if (auto *Materializer = getMaterializer()) { + if (Value *NewV = Materializer->materialize(const_cast<Value *>(V))) { + getVM()[V] = NewV; return NewV; } } @@ -51,13 +356,9 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, // Global values do not need to be seeded into the VM if they // are using the identity mapping. if (isa<GlobalValue>(V)) { - if (Flags & RF_NullMapMissingGlobalValues) { - assert(!(Flags & RF_IgnoreMissingEntries) && - "Illegal to specify both RF_NullMapMissingGlobalValues and " - "RF_IgnoreMissingEntries"); + if (Flags & RF_NullMapMissingGlobalValues) return nullptr; - } - return VM[V] = const_cast<Value*>(V); + return getVM()[V] = const_cast<Value *>(V); } if (const InlineAsm *IA = dyn_cast<InlineAsm>(V)) { @@ -70,28 +371,39 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, V = InlineAsm::get(NewTy, IA->getAsmString(), IA->getConstraintString(), IA->hasSideEffects(), IA->isAlignStack()); } - - return VM[V] = const_cast<Value*>(V); + + return getVM()[V] = const_cast<Value *>(V); } if (const auto *MDV = dyn_cast<MetadataAsValue>(V)) { const Metadata *MD = MDV->getMetadata(); + + if (auto *LAM = dyn_cast<LocalAsMetadata>(MD)) { + // Look through to grab the local value. + if (Value *LV = mapValue(LAM->getValue())) { + if (V == LAM->getValue()) + return const_cast<Value *>(V); + return MetadataAsValue::get(V->getContext(), ValueAsMetadata::get(LV)); + } + + // FIXME: always return nullptr once Verifier::verifyDominatesUse() + // ensures metadata operands only reference defined SSA values. + return (Flags & RF_IgnoreMissingLocals) + ? nullptr + : MetadataAsValue::get(V->getContext(), + MDTuple::get(V->getContext(), None)); + } + // If this is a module-level metadata and we know that nothing at the module // level is changing, then use an identity mapping. - if (!isa<LocalAsMetadata>(MD) && (Flags & RF_NoModuleLevelChanges)) - return VM[V] = const_cast<Value *>(V); - - auto *MappedMD = MapMetadata(MD, VM, Flags, TypeMapper, Materializer); - if (MD == MappedMD || (!MappedMD && (Flags & RF_IgnoreMissingEntries))) - return VM[V] = const_cast<Value *>(V); - - // FIXME: This assert crashes during bootstrap, but I think it should be - // correct. For now, just match behaviour from before the metadata/value - // split. 
- // - // assert((MappedMD || (Flags & RF_NullMapMissingGlobalValues)) && - // "Referenced metadata value not in value map"); - return VM[V] = MetadataAsValue::get(V->getContext(), MappedMD); + if (Flags & RF_NoModuleLevelChanges) + return getVM()[V] = const_cast<Value *>(V); + + // Map the metadata and turn it into a value. + auto *MappedMD = mapMetadata(MD); + if (MD == MappedMD) + return getVM()[V] = const_cast<Value *>(V); + return getVM()[V] = MetadataAsValue::get(V->getContext(), MappedMD); } // Okay, this either must be a constant (which may or may not be mappable) or @@ -99,25 +411,31 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, Constant *C = const_cast<Constant*>(dyn_cast<Constant>(V)); if (!C) return nullptr; - - if (BlockAddress *BA = dyn_cast<BlockAddress>(C)) { - Function *F = - cast<Function>(MapValue(BA->getFunction(), VM, Flags, TypeMapper, Materializer)); - BasicBlock *BB = cast_or_null<BasicBlock>(MapValue(BA->getBasicBlock(), VM, - Flags, TypeMapper, Materializer)); - return VM[V] = BlockAddress::get(F, BB ? BB : BA->getBasicBlock()); - } - + + if (BlockAddress *BA = dyn_cast<BlockAddress>(C)) + return mapBlockAddress(*BA); + + auto mapValueOrNull = [this](Value *V) { + auto Mapped = mapValue(V); + assert((Mapped || (Flags & RF_NullMapMissingGlobalValues)) && + "Unexpected null mapping for constant operand without " + "NullMapMissingGlobalValues flag"); + return Mapped; + }; + // Otherwise, we have some other constant to remap. Start by checking to see // if all operands have an identity remapping. unsigned OpNo = 0, NumOperands = C->getNumOperands(); Value *Mapped = nullptr; for (; OpNo != NumOperands; ++OpNo) { Value *Op = C->getOperand(OpNo); - Mapped = MapValue(Op, VM, Flags, TypeMapper, Materializer); - if (Mapped != C) break; + Mapped = mapValueOrNull(Op); + if (!Mapped) + return nullptr; + if (Mapped != Op) + break; } - + // See if the type mapper wants to remap the type as well. Type *NewTy = C->getType(); if (TypeMapper) @@ -126,23 +444,26 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, // If the result type and all operands match up, then just insert an identity // mapping. if (OpNo == NumOperands && NewTy == C->getType()) - return VM[V] = C; - + return getVM()[V] = C; + // Okay, we need to create a new constant. We've already processed some or // all of the operands, set them all up now. SmallVector<Constant*, 8> Ops; Ops.reserve(NumOperands); for (unsigned j = 0; j != OpNo; ++j) Ops.push_back(cast<Constant>(C->getOperand(j))); - + // If one of the operands mismatch, push it and the other mapped operands. if (OpNo != NumOperands) { Ops.push_back(cast<Constant>(Mapped)); - + // Map the rest of the operands that aren't processed yet. 
- for (++OpNo; OpNo != NumOperands; ++OpNo) - Ops.push_back(MapValue(cast<Constant>(C->getOperand(OpNo)), VM, - Flags, TypeMapper, Materializer)); + for (++OpNo; OpNo != NumOperands; ++OpNo) { + Mapped = mapValueOrNull(C->getOperand(OpNo)); + if (!Mapped) + return nullptr; + Ops.push_back(cast<Constant>(Mapped)); + } } Type *NewSrcTy = nullptr; if (TypeMapper) @@ -150,309 +471,407 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, NewSrcTy = TypeMapper->remapType(GEPO->getSourceElementType()); if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) - return VM[V] = CE->getWithOperands(Ops, NewTy, false, NewSrcTy); + return getVM()[V] = CE->getWithOperands(Ops, NewTy, false, NewSrcTy); if (isa<ConstantArray>(C)) - return VM[V] = ConstantArray::get(cast<ArrayType>(NewTy), Ops); + return getVM()[V] = ConstantArray::get(cast<ArrayType>(NewTy), Ops); if (isa<ConstantStruct>(C)) - return VM[V] = ConstantStruct::get(cast<StructType>(NewTy), Ops); + return getVM()[V] = ConstantStruct::get(cast<StructType>(NewTy), Ops); if (isa<ConstantVector>(C)) - return VM[V] = ConstantVector::get(Ops); + return getVM()[V] = ConstantVector::get(Ops); // If this is a no-operand constant, it must be because the type was remapped. if (isa<UndefValue>(C)) - return VM[V] = UndefValue::get(NewTy); + return getVM()[V] = UndefValue::get(NewTy); if (isa<ConstantAggregateZero>(C)) - return VM[V] = ConstantAggregateZero::get(NewTy); + return getVM()[V] = ConstantAggregateZero::get(NewTy); assert(isa<ConstantPointerNull>(C)); - return VM[V] = ConstantPointerNull::get(cast<PointerType>(NewTy)); -} - -static Metadata *mapToMetadata(ValueToValueMapTy &VM, const Metadata *Key, - Metadata *Val, ValueMaterializer *Materializer, - RemapFlags Flags) { - VM.MD()[Key].reset(Val); - if (Materializer && !(Flags & RF_HaveUnmaterializedMetadata)) { - auto *N = dyn_cast_or_null<MDNode>(Val); - // Need to invoke this once we have non-temporary MD. - if (!N || !N->isTemporary()) - Materializer->replaceTemporaryMetadata(Key, Val); + return getVM()[V] = ConstantPointerNull::get(cast<PointerType>(NewTy)); +} + +Value *Mapper::mapBlockAddress(const BlockAddress &BA) { + Function *F = cast<Function>(mapValue(BA.getFunction())); + + // F may not have materialized its initializer. In that case, create a + // dummy basic block for now, and replace it once we've materialized all + // the initializers. + BasicBlock *BB; + if (F->empty()) { + DelayedBBs.push_back(DelayedBasicBlock(BA)); + BB = DelayedBBs.back().TempBB.get(); + } else { + BB = cast_or_null<BasicBlock>(mapValue(BA.getBasicBlock())); } - return Val; + + return getVM()[&BA] = BlockAddress::get(F, BB ? 
BB : BA.getBasicBlock()); } -static Metadata *mapToSelf(ValueToValueMapTy &VM, const Metadata *MD, - ValueMaterializer *Materializer, RemapFlags Flags) { - return mapToMetadata(VM, MD, const_cast<Metadata *>(MD), Materializer, Flags); +Metadata *Mapper::mapToMetadata(const Metadata *Key, Metadata *Val) { + getVM().MD()[Key].reset(Val); + return Val; } -static Metadata *MapMetadataImpl(const Metadata *MD, - SmallVectorImpl<MDNode *> &DistinctWorklist, - ValueToValueMapTy &VM, RemapFlags Flags, - ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer); +Metadata *Mapper::mapToSelf(const Metadata *MD) { + return mapToMetadata(MD, const_cast<Metadata *>(MD)); +} -static Metadata *mapMetadataOp(Metadata *Op, - SmallVectorImpl<MDNode *> &DistinctWorklist, - ValueToValueMapTy &VM, RemapFlags Flags, - ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer) { +Optional<Metadata *> MDNodeMapper::tryToMapOperand(const Metadata *Op) { if (!Op) return nullptr; - if (Materializer && !Materializer->isMetadataNeeded(Op)) + if (Optional<Metadata *> MappedOp = M.mapSimpleMetadata(Op)) { +#ifndef NDEBUG + if (auto *CMD = dyn_cast<ConstantAsMetadata>(Op)) + assert((!*MappedOp || M.getVM().count(CMD->getValue()) || + M.getVM().getMappedMD(Op)) && + "Expected Value to be memoized"); + else + assert((isa<MDString>(Op) || M.getVM().getMappedMD(Op)) && + "Expected result to be memoized"); +#endif + return *MappedOp; + } + + const MDNode &N = *cast<MDNode>(Op); + if (N.isDistinct()) + return mapDistinctNode(N); + return None; +} + +MDNode *MDNodeMapper::mapDistinctNode(const MDNode &N) { + assert(N.isDistinct() && "Expected a distinct node"); + assert(!M.getVM().getMappedMD(&N) && "Expected an unmapped node"); + DistinctWorklist.push_back(cast<MDNode>( + (M.Flags & RF_MoveDistinctMDs) + ? M.mapToSelf(&N) + : M.mapToMetadata(&N, MDNode::replaceWithDistinct(N.clone())))); + return DistinctWorklist.back(); +} + +static ConstantAsMetadata *wrapConstantAsMetadata(const ConstantAsMetadata &CMD, + Value *MappedV) { + if (CMD.getValue() == MappedV) + return const_cast<ConstantAsMetadata *>(&CMD); + return MappedV ? ConstantAsMetadata::getConstant(MappedV) : nullptr; +} + +Optional<Metadata *> MDNodeMapper::getMappedOp(const Metadata *Op) const { + if (!Op) return nullptr; - if (Metadata *MappedOp = MapMetadataImpl(Op, DistinctWorklist, VM, Flags, - TypeMapper, Materializer)) - return MappedOp; - // Use identity map if MappedOp is null and we can ignore missing entries. - if (Flags & RF_IgnoreMissingEntries) + if (Optional<Metadata *> MappedOp = M.getVM().getMappedMD(Op)) + return *MappedOp; + + if (isa<MDString>(Op)) + return const_cast<Metadata *>(Op); + + if (auto *CMD = dyn_cast<ConstantAsMetadata>(Op)) + return wrapConstantAsMetadata(*CMD, M.getVM().lookup(CMD->getValue())); + + return None; +} + +Metadata &MDNodeMapper::UniquedGraph::getFwdReference(MDNode &Op) { + auto Where = Info.find(&Op); + assert(Where != Info.end() && "Expected a valid reference"); + + auto &OpD = Where->second; + if (!OpD.HasChanged) return Op; - // FIXME: This assert crashes during bootstrap, but I think it should be - // correct. For now, just match behaviour from before the metadata/value - // split. - // - // assert((Flags & RF_NullMapMissingGlobalValues) && - // "Referenced metadata not in value map!"); - return nullptr; + // Lazily construct a temporary node. + if (!OpD.Placeholder) + OpD.Placeholder = Op.clone(); + + return *OpD.Placeholder; } -/// Resolve uniquing cycles involving the given metadata. 
-static void resolveCycles(Metadata *MD, bool AllowTemps) { - if (auto *N = dyn_cast_or_null<MDNode>(MD)) { - if (AllowTemps && N->isTemporary()) - return; - if (!N->isResolved()) { - if (AllowTemps) - // Note that this will drop RAUW support on any temporaries, which - // blocks uniquing. If this ends up being an issue, in the future - // we can experiment with delaying resolving these nodes until - // after metadata is fully materialized (i.e. when linking metadata - // as a postpass after function importing). - N->resolveNonTemporaries(); - else - N->resolveCycles(); - } +template <class OperandMapper> +void MDNodeMapper::remapOperands(MDNode &N, OperandMapper mapOperand) { + assert(!N.isUniqued() && "Expected distinct or temporary nodes"); + for (unsigned I = 0, E = N.getNumOperands(); I != E; ++I) { + Metadata *Old = N.getOperand(I); + Metadata *New = mapOperand(Old); + + if (Old != New) + N.replaceOperandWith(I, New); } } -/// Remap the operands of an MDNode. -/// -/// If \c Node is temporary, uniquing cycles are ignored. If \c Node is -/// distinct, uniquing cycles are resolved as they're found. -/// -/// \pre \c Node.isDistinct() or \c Node.isTemporary(). -static bool remapOperands(MDNode &Node, - SmallVectorImpl<MDNode *> &DistinctWorklist, - ValueToValueMapTy &VM, RemapFlags Flags, - ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer) { - assert(!Node.isUniqued() && "Expected temporary or distinct node"); - const bool IsDistinct = Node.isDistinct(); - - bool AnyChanged = false; - for (unsigned I = 0, E = Node.getNumOperands(); I != E; ++I) { - Metadata *Old = Node.getOperand(I); - Metadata *New = mapMetadataOp(Old, DistinctWorklist, VM, Flags, TypeMapper, - Materializer); - if (Old != New) { - AnyChanged = true; - Node.replaceOperandWith(I, New); - - // Resolve uniquing cycles underneath distinct nodes on the fly so they - // don't infect later operands. - if (IsDistinct) - resolveCycles(New, Flags & RF_HaveUnmaterializedMetadata); +namespace { +/// An entry in the worklist for the post-order traversal. +struct POTWorklistEntry { + MDNode *N; ///< Current node. + MDNode::op_iterator Op; ///< Current operand of \c N. + + /// Keep a flag of whether operands have changed in the worklist to avoid + /// hitting the map in \a UniquedGraph. + bool HasChanged = false; + + POTWorklistEntry(MDNode &N) : N(&N), Op(N.op_begin()) {} +}; +} // end namespace + +bool MDNodeMapper::createPOT(UniquedGraph &G, const MDNode &FirstN) { + assert(G.Info.empty() && "Expected a fresh traversal"); + assert(FirstN.isUniqued() && "Expected uniqued node in POT"); + + // Construct a post-order traversal of the uniqued subgraph under FirstN. + bool AnyChanges = false; + SmallVector<POTWorklistEntry, 16> Worklist; + Worklist.push_back(POTWorklistEntry(const_cast<MDNode &>(FirstN))); + (void)G.Info[&FirstN]; + while (!Worklist.empty()) { + // Start or continue the traversal through the this node's operands. + auto &WE = Worklist.back(); + if (MDNode *N = visitOperands(G, WE.Op, WE.N->op_end(), WE.HasChanged)) { + // Push a new node to traverse first. + Worklist.push_back(POTWorklistEntry(*N)); + continue; } + + // Push the node onto the POT. + assert(WE.N->isUniqued() && "Expected only uniqued nodes"); + assert(WE.Op == WE.N->op_end() && "Expected to visit all operands"); + auto &D = G.Info[WE.N]; + AnyChanges |= D.HasChanged = WE.HasChanged; + D.ID = G.POT.size(); + G.POT.push_back(WE.N); + + // Pop the node off the worklist. 
+ Worklist.pop_back(); } + return AnyChanges; +} - return AnyChanged; -} - -/// Map a distinct MDNode. -/// -/// Whether distinct nodes change is independent of their operands. If \a -/// RF_MoveDistinctMDs, then they are reused, and their operands remapped in -/// place; effectively, they're moved from one graph to another. Otherwise, -/// they're cloned/duplicated, and the new copy's operands are remapped. -static Metadata *mapDistinctNode(const MDNode *Node, - SmallVectorImpl<MDNode *> &DistinctWorklist, - ValueToValueMapTy &VM, RemapFlags Flags, - ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer) { - assert(Node->isDistinct() && "Expected distinct node"); - - MDNode *NewMD; - if (Flags & RF_MoveDistinctMDs) - NewMD = const_cast<MDNode *>(Node); - else - NewMD = MDNode::replaceWithDistinct(Node->clone()); - - // Remap operands later. - DistinctWorklist.push_back(NewMD); - return mapToMetadata(VM, Node, NewMD, Materializer, Flags); -} - -/// \brief Map a uniqued MDNode. -/// -/// Uniqued nodes may not need to be recreated (they may map to themselves). -static Metadata *mapUniquedNode(const MDNode *Node, - SmallVectorImpl<MDNode *> &DistinctWorklist, - ValueToValueMapTy &VM, RemapFlags Flags, - ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer) { - assert(((Flags & RF_HaveUnmaterializedMetadata) || Node->isUniqued()) && - "Expected uniqued node"); - - // Create a temporary node and map it upfront in case we have a uniquing - // cycle. If necessary, this mapping will get updated by RAUW logic before - // returning. - auto ClonedMD = Node->clone(); - mapToMetadata(VM, Node, ClonedMD.get(), Materializer, Flags); - if (!remapOperands(*ClonedMD, DistinctWorklist, VM, Flags, TypeMapper, - Materializer)) { - // No operands changed, so use the original. - ClonedMD->replaceAllUsesWith(const_cast<MDNode *>(Node)); - // Even though replaceAllUsesWith would have replaced the value map - // entry, we need to explictly map with the final non-temporary node - // to replace any temporary metadata via the callback. - return mapToSelf(VM, Node, Materializer, Flags); +MDNode *MDNodeMapper::visitOperands(UniquedGraph &G, MDNode::op_iterator &I, + MDNode::op_iterator E, bool &HasChanged) { + while (I != E) { + Metadata *Op = *I++; // Increment even on early return. + if (Optional<Metadata *> MappedOp = tryToMapOperand(Op)) { + // Check if the operand changes. + HasChanged |= Op != *MappedOp; + continue; + } + + // A uniqued metadata node. + MDNode &OpN = *cast<MDNode>(Op); + assert(OpN.isUniqued() && + "Only uniqued operands cannot be mapped immediately"); + if (G.Info.insert(std::make_pair(&OpN, Data())).second) + return &OpN; // This is a new one. Return it. } + return nullptr; +} - // Uniquify the cloned node. Explicitly map it with the final non-temporary - // node so that replacement of temporary metadata via the callback occurs. 
- return mapToMetadata(VM, Node, - MDNode::replaceWithUniqued(std::move(ClonedMD)), - Materializer, Flags); +void MDNodeMapper::UniquedGraph::propagateChanges() { + bool AnyChanges; + do { + AnyChanges = false; + for (MDNode *N : POT) { + auto &D = Info[N]; + if (D.HasChanged) + continue; + + if (!llvm::any_of(N->operands(), [&](const Metadata *Op) { + auto Where = Info.find(Op); + return Where != Info.end() && Where->second.HasChanged; + })) + continue; + + AnyChanges = D.HasChanged = true; + } + } while (AnyChanges); } -static Metadata *MapMetadataImpl(const Metadata *MD, - SmallVectorImpl<MDNode *> &DistinctWorklist, - ValueToValueMapTy &VM, RemapFlags Flags, - ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer) { - // If the value already exists in the map, use it. - if (Metadata *NewMD = VM.MD().lookup(MD).get()) - return NewMD; +void MDNodeMapper::mapNodesInPOT(UniquedGraph &G) { + // Construct uniqued nodes, building forward references as necessary. + SmallVector<MDNode *, 16> CyclicNodes; + for (auto *N : G.POT) { + auto &D = G.Info[N]; + if (!D.HasChanged) { + // The node hasn't changed. + M.mapToSelf(N); + continue; + } - if (isa<MDString>(MD)) - return mapToSelf(VM, MD, Materializer, Flags); - - if (isa<ConstantAsMetadata>(MD)) - if ((Flags & RF_NoModuleLevelChanges)) - return mapToSelf(VM, MD, Materializer, Flags); - - if (const auto *VMD = dyn_cast<ValueAsMetadata>(MD)) { - Value *MappedV = - MapValue(VMD->getValue(), VM, Flags, TypeMapper, Materializer); - if (VMD->getValue() == MappedV || - (!MappedV && (Flags & RF_IgnoreMissingEntries))) - return mapToSelf(VM, MD, Materializer, Flags); - - // FIXME: This assert crashes during bootstrap, but I think it should be - // correct. For now, just match behaviour from before the metadata/value - // split. - // - // assert((MappedV || (Flags & RF_NullMapMissingGlobalValues)) && - // "Referenced metadata not in value map!"); - if (MappedV) - return mapToMetadata(VM, MD, ValueAsMetadata::get(MappedV), Materializer, - Flags); - return nullptr; + // Remember whether this node had a placeholder. + bool HadPlaceholder(D.Placeholder); + + // Clone the uniqued node and remap the operands. + TempMDNode ClonedN = D.Placeholder ? std::move(D.Placeholder) : N->clone(); + remapOperands(*ClonedN, [this, &D, &G](Metadata *Old) { + if (Optional<Metadata *> MappedOp = getMappedOp(Old)) + return *MappedOp; + assert(G.Info[Old].ID > D.ID && "Expected a forward reference"); + return &G.getFwdReference(*cast<MDNode>(Old)); + }); + + auto *NewN = MDNode::replaceWithUniqued(std::move(ClonedN)); + M.mapToMetadata(N, NewN); + + // Nodes that were referenced out of order in the POT are involved in a + // uniquing cycle. + if (HadPlaceholder) + CyclicNodes.push_back(NewN); } - // Note: this cast precedes the Flags check so we always get its associated - // assertion. - const MDNode *Node = cast<MDNode>(MD); + // Resolve cycles. + for (auto *N : CyclicNodes) + if (!N->isResolved()) + N->resolveCycles(); +} - // If this is a module-level metadata and we know that nothing at the - // module level is changing, then use an identity mapping. - if (Flags & RF_NoModuleLevelChanges) - return mapToSelf(VM, MD, Materializer, Flags); +Metadata *MDNodeMapper::map(const MDNode &N) { + assert(DistinctWorklist.empty() && "MDNodeMapper::map is not recursive"); + assert(!(M.Flags & RF_NoModuleLevelChanges) && + "MDNodeMapper::map assumes module-level changes"); // Require resolved nodes whenever metadata might be remapped. 
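The doc comments above describe the uniqued-metadata algorithm in three phases: build a post-order traversal of the not-yet-mapped subgraph (createPOT), propagate a has-changed flag through it to a fixed point (propagateChanges), then rebuild only the nodes that changed (mapNodesInPOT). A small standalone sketch of the first two phases on an ordinary string-keyed DAG (the graph and node names are made up):

    #include <cstdio>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    typedef std::map<std::string, std::vector<std::string>> Graph;

    // Iterative post-order traversal with an explicit worklist, in the spirit of
    // MDNodeMapper::createPOT: one (node, next-operand-index) entry per node.
    static std::vector<std::string> postOrder(const Graph &G, const std::string &Root) {
      std::vector<std::string> POT;
      std::map<std::string, bool> Seen;
      std::vector<std::pair<std::string, size_t>> Worklist;
      Worklist.push_back({Root, 0});
      Seen[Root] = true;
      while (!Worklist.empty()) {
        auto &WE = Worklist.back();
        const std::vector<std::string> &Ops = G.at(WE.first);
        if (WE.second < Ops.size()) {
          const std::string Op = Ops[WE.second++];
          if (!Seen[Op]) {
            Seen[Op] = true;
            Worklist.push_back({Op, 0}); // visit the operand first
          }
          continue;
        }
        POT.push_back(WE.first); // all operands done: emit the node
        Worklist.pop_back();
      }
      return POT;
    }

    int main() {
      // D is the only node that "changed"; E is unrelated to D.
      Graph G = {{"A", {"B", "C", "E"}}, {"B", {"D"}}, {"C", {"D"}},
                 {"D", {}}, {"E", {}}};
      std::vector<std::string> POT = postOrder(G, "A");

      // Fixed-point pass like UniquedGraph::propagateChanges(): a node changes
      // iff any of its operands changed.
      std::map<std::string, bool> Changed = {{"D", true}};
      bool Any;
      do {
        Any = false;
        for (const std::string &N : POT) {
          if (Changed[N])
            continue;
          for (const std::string &Op : G.at(N))
            if (Changed[Op]) {
              Changed[N] = Any = true;
              break;
            }
        }
      } while (Any);

      for (const std::string &N : POT)
        std::printf("%s changed=%d\n", N.c_str(), Changed[N] ? 1 : 0);
      return 0;
    }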
- assert(((Flags & RF_HaveUnmaterializedMetadata) || Node->isResolved()) && - "Unexpected unresolved node"); - - if (Materializer && Node->isTemporary()) { - assert(Flags & RF_HaveUnmaterializedMetadata); - Metadata *TempMD = - Materializer->mapTemporaryMetadata(const_cast<Metadata *>(MD)); - // If the above callback returned an existing temporary node, use it - // instead of the current temporary node. This happens when earlier - // function importing passes already created and saved a temporary - // metadata node for the same value id. - if (TempMD) { - mapToMetadata(VM, MD, TempMD, Materializer, Flags); - return TempMD; - } + assert(N.isResolved() && "Unexpected unresolved node"); + + Metadata *MappedN = + N.isUniqued() ? mapTopLevelUniquedNode(N) : mapDistinctNode(N); + while (!DistinctWorklist.empty()) + remapOperands(*DistinctWorklist.pop_back_val(), [this](Metadata *Old) { + if (Optional<Metadata *> MappedOp = tryToMapOperand(Old)) + return *MappedOp; + return mapTopLevelUniquedNode(*cast<MDNode>(Old)); + }); + return MappedN; +} + +Metadata *MDNodeMapper::mapTopLevelUniquedNode(const MDNode &FirstN) { + assert(FirstN.isUniqued() && "Expected uniqued node"); + + // Create a post-order traversal of uniqued nodes under FirstN. + UniquedGraph G; + if (!createPOT(G, FirstN)) { + // Return early if no nodes have changed. + for (const MDNode *N : G.POT) + M.mapToSelf(N); + return &const_cast<MDNode &>(FirstN); } - if (Node->isDistinct()) - return mapDistinctNode(Node, DistinctWorklist, VM, Flags, TypeMapper, - Materializer); + // Update graph with all nodes that have changed. + G.propagateChanges(); - return mapUniquedNode(Node, DistinctWorklist, VM, Flags, TypeMapper, - Materializer); + // Map all the nodes in the graph. + mapNodesInPOT(G); + + // Return the original node, remapped. + return *getMappedOp(&FirstN); } -Metadata *llvm::MapMetadata(const Metadata *MD, ValueToValueMapTy &VM, - RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer) { - SmallVector<MDNode *, 8> DistinctWorklist; - Metadata *NewMD = MapMetadataImpl(MD, DistinctWorklist, VM, Flags, TypeMapper, - Materializer); +namespace { - // When there are no module-level changes, it's possible that the metadata - // graph has temporaries. Skip the logic to resolve cycles, since it's - // unnecessary (and invalid) in that case. - if (Flags & RF_NoModuleLevelChanges) - return NewMD; +struct MapMetadataDisabler { + ValueToValueMapTy &VM; - // Resolve cycles involving the entry metadata. - resolveCycles(NewMD, Flags & RF_HaveUnmaterializedMetadata); + MapMetadataDisabler(ValueToValueMapTy &VM) : VM(VM) { + VM.disableMapMetadata(); + } + ~MapMetadataDisabler() { VM.enableMapMetadata(); } +}; - // Remap the operands of distinct MDNodes. - while (!DistinctWorklist.empty()) - remapOperands(*DistinctWorklist.pop_back_val(), DistinctWorklist, VM, Flags, - TypeMapper, Materializer); +} // end namespace - return NewMD; +Optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) { + // If the value already exists in the map, use it. + if (Optional<Metadata *> NewMD = getVM().getMappedMD(MD)) + return *NewMD; + + if (isa<MDString>(MD)) + return const_cast<Metadata *>(MD); + + // This is a module-level metadata. If nothing at the module level is + // changing, use an identity mapping. + if ((Flags & RF_NoModuleLevelChanges)) + return const_cast<Metadata *>(MD); + + if (auto *CMD = dyn_cast<ConstantAsMetadata>(MD)) { + // Disallow recursion into metadata mapping through mapValue. 
+ MapMetadataDisabler MMD(getVM()); + + // Don't memoize ConstantAsMetadata. Instead of lasting until the + // LLVMContext is destroyed, they can be deleted when the GlobalValue they + // reference is destructed. These aren't super common, so the extra + // indirection isn't that expensive. + return wrapConstantAsMetadata(*CMD, mapValue(CMD->getValue())); + } + + assert(isa<MDNode>(MD) && "Expected a metadata node"); + + return None; } -MDNode *llvm::MapMetadata(const MDNode *MD, ValueToValueMapTy &VM, - RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer) { - return cast<MDNode>(MapMetadata(static_cast<const Metadata *>(MD), VM, Flags, - TypeMapper, Materializer)); +Metadata *Mapper::mapMetadata(const Metadata *MD) { + assert(MD && "Expected valid metadata"); + assert(!isa<LocalAsMetadata>(MD) && "Unexpected local metadata"); + + if (Optional<Metadata *> NewMD = mapSimpleMetadata(MD)) + return *NewMD; + + return MDNodeMapper(*this).map(*cast<MDNode>(MD)); +} + +void Mapper::flush() { + // Flush out the worklist of global values. + while (!Worklist.empty()) { + WorklistEntry E = Worklist.pop_back_val(); + CurrentMCID = E.MCID; + switch (E.Kind) { + case WorklistEntry::MapGlobalInit: + E.Data.GVInit.GV->setInitializer(mapConstant(E.Data.GVInit.Init)); + break; + case WorklistEntry::MapAppendingVar: { + unsigned PrefixSize = AppendingInits.size() - E.AppendingGVNumNewMembers; + mapAppendingVariable(*E.Data.AppendingGV.GV, + E.Data.AppendingGV.InitPrefix, + E.AppendingGVIsOldCtorDtor, + makeArrayRef(AppendingInits).slice(PrefixSize)); + AppendingInits.resize(PrefixSize); + break; + } + case WorklistEntry::MapGlobalAliasee: + E.Data.GlobalAliasee.GA->setAliasee( + mapConstant(E.Data.GlobalAliasee.Aliasee)); + break; + case WorklistEntry::RemapFunction: + remapFunction(*E.Data.RemapF); + break; + } + } + CurrentMCID = 0; + + // Finish logic for block addresses now that all global values have been + // handled. + while (!DelayedBBs.empty()) { + DelayedBasicBlock DBB = DelayedBBs.pop_back_val(); + BasicBlock *BB = cast_or_null<BasicBlock>(mapValue(DBB.OldBB)); + DBB.TempBB->replaceAllUsesWith(BB ? BB : DBB.OldBB); + } } -/// RemapInstruction - Convert the instruction operands from referencing the -/// current values into those specified by VMap. -/// -void llvm::RemapInstruction(Instruction *I, ValueToValueMapTy &VMap, - RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, - ValueMaterializer *Materializer){ +void Mapper::remapInstruction(Instruction *I) { // Remap operands. - for (User::op_iterator op = I->op_begin(), E = I->op_end(); op != E; ++op) { - Value *V = MapValue(*op, VMap, Flags, TypeMapper, Materializer); + for (Use &Op : I->operands()) { + Value *V = mapValue(Op); // If we aren't ignoring missing entries, assert that something happened. if (V) - *op = V; + Op = V; else - assert((Flags & RF_IgnoreMissingEntries) && + assert((Flags & RF_IgnoreMissingLocals) && "Referenced value not in value map!"); } // Remap phi nodes' incoming blocks. if (PHINode *PN = dyn_cast<PHINode>(I)) { for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { - Value *V = MapValue(PN->getIncomingBlock(i), VMap, Flags); + Value *V = mapValue(PN->getIncomingBlock(i)); // If we aren't ignoring missing entries, assert that something happened. 
if (V) PN->setIncomingBlock(i, cast<BasicBlock>(V)); else - assert((Flags & RF_IgnoreMissingEntries) && + assert((Flags & RF_IgnoreMissingLocals) && "Referenced block not in value map!"); } } @@ -462,11 +881,11 @@ void llvm::RemapInstruction(Instruction *I, ValueToValueMapTy &VMap, I->getAllMetadata(MDs); for (const auto &MI : MDs) { MDNode *Old = MI.second; - MDNode *New = MapMetadata(Old, VMap, Flags, TypeMapper, Materializer); + MDNode *New = cast_or_null<MDNode>(mapMetadata(Old)); if (New != Old) I->setMetadata(MI.first, New); } - + if (!TypeMapper) return; @@ -491,3 +910,213 @@ void llvm::RemapInstruction(Instruction *I, ValueToValueMapTy &VMap, } I->mutateType(TypeMapper->remapType(I->getType())); } + +void Mapper::remapFunction(Function &F) { + // Remap the operands. + for (Use &Op : F.operands()) + if (Op) + Op = mapValue(Op); + + // Remap the metadata attachments. + SmallVector<std::pair<unsigned, MDNode *>, 8> MDs; + F.getAllMetadata(MDs); + F.clearMetadata(); + for (const auto &I : MDs) + F.addMetadata(I.first, *cast<MDNode>(mapMetadata(I.second))); + + // Remap the argument types. + if (TypeMapper) + for (Argument &A : F.args()) + A.mutateType(TypeMapper->remapType(A.getType())); + + // Remap the instructions. + for (BasicBlock &BB : F) + for (Instruction &I : BB) + remapInstruction(&I); +} + +void Mapper::mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, + bool IsOldCtorDtor, + ArrayRef<Constant *> NewMembers) { + SmallVector<Constant *, 16> Elements; + if (InitPrefix) { + unsigned NumElements = + cast<ArrayType>(InitPrefix->getType())->getNumElements(); + for (unsigned I = 0; I != NumElements; ++I) + Elements.push_back(InitPrefix->getAggregateElement(I)); + } + + PointerType *VoidPtrTy; + Type *EltTy; + if (IsOldCtorDtor) { + // FIXME: This upgrade is done during linking to support the C API. See + // also IRLinker::linkAppendingVarProto() in IRMover.cpp. 
+ VoidPtrTy = Type::getInt8Ty(GV.getContext())->getPointerTo(); + auto &ST = *cast<StructType>(NewMembers.front()->getType()); + Type *Tys[3] = {ST.getElementType(0), ST.getElementType(1), VoidPtrTy}; + EltTy = StructType::get(GV.getContext(), Tys, false); + } + + for (auto *V : NewMembers) { + Constant *NewV; + if (IsOldCtorDtor) { + auto *S = cast<ConstantStruct>(V); + auto *E1 = mapValue(S->getOperand(0)); + auto *E2 = mapValue(S->getOperand(1)); + Value *Null = Constant::getNullValue(VoidPtrTy); + NewV = + ConstantStruct::get(cast<StructType>(EltTy), E1, E2, Null, nullptr); + } else { + NewV = cast_or_null<Constant>(mapValue(V)); + } + Elements.push_back(NewV); + } + + GV.setInitializer(ConstantArray::get( + cast<ArrayType>(GV.getType()->getElementType()), Elements)); +} + +void Mapper::scheduleMapGlobalInitializer(GlobalVariable &GV, Constant &Init, + unsigned MCID) { + assert(AlreadyScheduled.insert(&GV).second && "Should not reschedule"); + assert(MCID < MCs.size() && "Invalid mapping context"); + + WorklistEntry WE; + WE.Kind = WorklistEntry::MapGlobalInit; + WE.MCID = MCID; + WE.Data.GVInit.GV = &GV; + WE.Data.GVInit.Init = &Init; + Worklist.push_back(WE); +} + +void Mapper::scheduleMapAppendingVariable(GlobalVariable &GV, + Constant *InitPrefix, + bool IsOldCtorDtor, + ArrayRef<Constant *> NewMembers, + unsigned MCID) { + assert(AlreadyScheduled.insert(&GV).second && "Should not reschedule"); + assert(MCID < MCs.size() && "Invalid mapping context"); + + WorklistEntry WE; + WE.Kind = WorklistEntry::MapAppendingVar; + WE.MCID = MCID; + WE.Data.AppendingGV.GV = &GV; + WE.Data.AppendingGV.InitPrefix = InitPrefix; + WE.AppendingGVIsOldCtorDtor = IsOldCtorDtor; + WE.AppendingGVNumNewMembers = NewMembers.size(); + Worklist.push_back(WE); + AppendingInits.append(NewMembers.begin(), NewMembers.end()); +} + +void Mapper::scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee, + unsigned MCID) { + assert(AlreadyScheduled.insert(&GA).second && "Should not reschedule"); + assert(MCID < MCs.size() && "Invalid mapping context"); + + WorklistEntry WE; + WE.Kind = WorklistEntry::MapGlobalAliasee; + WE.MCID = MCID; + WE.Data.GlobalAliasee.GA = &GA; + WE.Data.GlobalAliasee.Aliasee = &Aliasee; + Worklist.push_back(WE); +} + +void Mapper::scheduleRemapFunction(Function &F, unsigned MCID) { + assert(AlreadyScheduled.insert(&F).second && "Should not reschedule"); + assert(MCID < MCs.size() && "Invalid mapping context"); + + WorklistEntry WE; + WE.Kind = WorklistEntry::RemapFunction; + WE.MCID = MCID; + WE.Data.RemapF = &F; + Worklist.push_back(WE); +} + +void Mapper::addFlags(RemapFlags Flags) { + assert(!hasWorkToDo() && "Expected to have flushed the worklist"); + this->Flags = this->Flags | Flags; +} + +static Mapper *getAsMapper(void *pImpl) { + return reinterpret_cast<Mapper *>(pImpl); +} + +namespace { + +class FlushingMapper { + Mapper &M; + +public: + explicit FlushingMapper(void *pImpl) : M(*getAsMapper(pImpl)) { + assert(!M.hasWorkToDo() && "Expected to be flushed"); + } + ~FlushingMapper() { M.flush(); } + Mapper *operator->() const { return &M; } +}; + +} // end namespace + +ValueMapper::ValueMapper(ValueToValueMapTy &VM, RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) + : pImpl(new Mapper(VM, Flags, TypeMapper, Materializer)) {} + +ValueMapper::~ValueMapper() { delete getAsMapper(pImpl); } + +unsigned +ValueMapper::registerAlternateMappingContext(ValueToValueMapTy &VM, + ValueMaterializer *Materializer) { + return 
getAsMapper(pImpl)->registerAlternateMappingContext(VM, Materializer); +} + +void ValueMapper::addFlags(RemapFlags Flags) { + FlushingMapper(pImpl)->addFlags(Flags); +} + +Value *ValueMapper::mapValue(const Value &V) { + return FlushingMapper(pImpl)->mapValue(&V); +} + +Constant *ValueMapper::mapConstant(const Constant &C) { + return cast_or_null<Constant>(mapValue(C)); +} + +Metadata *ValueMapper::mapMetadata(const Metadata &MD) { + return FlushingMapper(pImpl)->mapMetadata(&MD); +} + +MDNode *ValueMapper::mapMDNode(const MDNode &N) { + return cast_or_null<MDNode>(mapMetadata(N)); +} + +void ValueMapper::remapInstruction(Instruction &I) { + FlushingMapper(pImpl)->remapInstruction(&I); +} + +void ValueMapper::remapFunction(Function &F) { + FlushingMapper(pImpl)->remapFunction(F); +} + +void ValueMapper::scheduleMapGlobalInitializer(GlobalVariable &GV, + Constant &Init, + unsigned MCID) { + getAsMapper(pImpl)->scheduleMapGlobalInitializer(GV, Init, MCID); +} + +void ValueMapper::scheduleMapAppendingVariable(GlobalVariable &GV, + Constant *InitPrefix, + bool IsOldCtorDtor, + ArrayRef<Constant *> NewMembers, + unsigned MCID) { + getAsMapper(pImpl)->scheduleMapAppendingVariable( + GV, InitPrefix, IsOldCtorDtor, NewMembers, MCID); +} + +void ValueMapper::scheduleMapGlobalAliasee(GlobalAlias &GA, Constant &Aliasee, + unsigned MCID) { + getAsMapper(pImpl)->scheduleMapGlobalAliasee(GA, Aliasee, MCID); +} + +void ValueMapper::scheduleRemapFunction(Function &F, unsigned MCID) { + getAsMapper(pImpl)->scheduleRemapFunction(F, MCID); +} diff --git a/lib/Transforms/Vectorize/BBVectorize.cpp b/lib/Transforms/Vectorize/BBVectorize.cpp index 8844d574a79d..af594cb751aa 100644 --- a/lib/Transforms/Vectorize/BBVectorize.cpp +++ b/lib/Transforms/Vectorize/BBVectorize.cpp @@ -397,7 +397,7 @@ namespace { Instruction *I, Instruction *J); bool vectorizeBB(BasicBlock &BB) { - if (skipOptnoneFunction(BB)) + if (skipBasicBlock(BB)) return false; if (!DT->isReachableFromEntry(&BB)) { DEBUG(dbgs() << "BBV: skipping unreachable " << BB.getName() << @@ -886,9 +886,16 @@ namespace { Type *DestTy = C->getDestTy(); if (!DestTy->isSingleValueType()) return false; - } else if (isa<SelectInst>(I)) { + } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { if (!Config.VectorizeSelect) return false; + // We can vectorize a select if either all operands are scalars, + // or all operands are vectors. Trying to "widen" a select between + // vectors that has a scalar condition results in a malformed select. + // FIXME: We could probably be smarter about this by rewriting the select + // with different types instead. 
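The new ValueMapper facade shown above wraps the Mapper pimpl so clients can remap values, metadata, and whole functions against a ValueToValueMapTy without calling the old free functions. A hedged usage sketch against these 3.9-era headers; the function and the pre-filled value map are assumed to come from an earlier cloning step:

    #include "llvm/IR/Function.h"
    #include "llvm/Transforms/Utils/ValueMapper.h"

    using namespace llvm;

    // Sketch: remap a function whose body was cloned earlier; VM is assumed to
    // already hold the old-value -> new-value pairs produced by that cloning.
    static void remapClonedFunction(Function &F, ValueToValueMapTy &VM) {
      ValueMapper Mapper(VM, RF_IgnoreMissingLocals, /*TypeMapper=*/nullptr,
                         /*Materializer=*/nullptr);
      // Remaps F's operands, metadata attachments, and every instruction in F,
      // reusing one memoized mapping for the whole walk.
      Mapper.remapFunction(F);
    }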
+ return (SI->getCondition()->getType()->isVectorTy() == + SI->getTrueValue()->getType()->isVectorTy()); } else if (isa<CmpInst>(I)) { if (!Config.VectorizeCmp) return false; @@ -1117,16 +1124,25 @@ namespace { } if (IID && TTI) { + FastMathFlags FMFCI; + if (auto *FPMOCI = dyn_cast<FPMathOperator>(CI)) + FMFCI = FPMOCI->getFastMathFlags(); + SmallVector<Type*, 4> Tys; for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) Tys.push_back(CI->getArgOperand(i)->getType()); - unsigned ICost = TTI->getIntrinsicInstrCost(IID, IT1, Tys); + unsigned ICost = TTI->getIntrinsicInstrCost(IID, IT1, Tys, FMFCI); Tys.clear(); CallInst *CJ = cast<CallInst>(J); + + FastMathFlags FMFCJ; + if (auto *FPMOCJ = dyn_cast<FPMathOperator>(CJ)) + FMFCJ = FPMOCJ->getFastMathFlags(); + for (unsigned i = 0, ie = CJ->getNumArgOperands(); i != ie; ++i) Tys.push_back(CJ->getArgOperand(i)->getType()); - unsigned JCost = TTI->getIntrinsicInstrCost(IID, JT1, Tys); + unsigned JCost = TTI->getIntrinsicInstrCost(IID, JT1, Tys, FMFCJ); Tys.clear(); assert(CI->getNumArgOperands() == CJ->getNumArgOperands() && @@ -1140,8 +1156,10 @@ namespace { CJ->getArgOperand(i)->getType())); } + FastMathFlags FMFV = FMFCI; + FMFV &= FMFCJ; Type *RetTy = getVecTypeForPair(IT1, JT1); - unsigned VCost = TTI->getIntrinsicInstrCost(IID, RetTy, Tys); + unsigned VCost = TTI->getIntrinsicInstrCost(IID, RetTy, Tys, FMFV); if (VCost > ICost + JCost) return false; @@ -1259,7 +1277,7 @@ namespace { bool JAfterStart = IAfterStart; BasicBlock::iterator J = std::next(I); for (unsigned ss = 0; J != E && ss <= Config.SearchLimit; ++J, ++ss) { - if (&*J == Start) + if (J == Start) JAfterStart = true; // Determine if J uses I, if so, exit the loop. diff --git a/lib/Transforms/Vectorize/CMakeLists.txt b/lib/Transforms/Vectorize/CMakeLists.txt index 905c069cf851..23c2ab025f37 100644 --- a/lib/Transforms/Vectorize/CMakeLists.txt +++ b/lib/Transforms/Vectorize/CMakeLists.txt @@ -1,8 +1,9 @@ add_llvm_library(LLVMVectorize BBVectorize.cpp - Vectorize.cpp + LoadStoreVectorizer.cpp LoopVectorize.cpp SLPVectorizer.cpp + Vectorize.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp new file mode 100644 index 000000000000..c8906bde15e0 --- /dev/null +++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -0,0 +1,999 @@ +//===----- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Triple.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Vectorize.h" + +using namespace llvm; + +#define DEBUG_TYPE "load-store-vectorizer" +STATISTIC(NumVectorInstructions, "Number of vector accesses generated"); +STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized"); + +namespace { + +// TODO: Remove this +static const unsigned TargetBaseAlign = 4; + +class Vectorizer { + typedef SmallVector<Value *, 8> ValueList; + typedef MapVector<Value *, ValueList> ValueListMap; + + Function &F; + AliasAnalysis &AA; + DominatorTree &DT; + ScalarEvolution &SE; + TargetTransformInfo &TTI; + const DataLayout &DL; + IRBuilder<> Builder; + ValueListMap StoreRefs; + ValueListMap LoadRefs; + +public: + Vectorizer(Function &F, AliasAnalysis &AA, DominatorTree &DT, + ScalarEvolution &SE, TargetTransformInfo &TTI) + : F(F), AA(AA), DT(DT), SE(SE), TTI(TTI), + DL(F.getParent()->getDataLayout()), Builder(SE.getContext()) {} + + bool run(); + +private: + Value *getPointerOperand(Value *I); + + unsigned getPointerAddressSpace(Value *I); + + unsigned getAlignment(LoadInst *LI) const { + unsigned Align = LI->getAlignment(); + if (Align != 0) + return Align; + + return DL.getABITypeAlignment(LI->getType()); + } + + unsigned getAlignment(StoreInst *SI) const { + unsigned Align = SI->getAlignment(); + if (Align != 0) + return Align; + + return DL.getABITypeAlignment(SI->getValueOperand()->getType()); + } + + bool isConsecutiveAccess(Value *A, Value *B); + + /// After vectorization, reorder the instructions that I depends on + /// (the instructions defining its operands), to ensure they dominate I. + void reorder(Instruction *I); + + /// Returns the first and the last instructions in Chain. + std::pair<BasicBlock::iterator, BasicBlock::iterator> + getBoundaryInstrs(ArrayRef<Value *> Chain); + + /// Erases the original instructions after vectorizing. + void eraseInstructions(ArrayRef<Value *> Chain); + + /// "Legalize" the vector type that would be produced by combining \p + /// ElementSizeBits elements in \p Chain. Break into two pieces such that the + /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is + /// expected to have more than 4 elements. + std::pair<ArrayRef<Value *>, ArrayRef<Value *>> + splitOddVectorElts(ArrayRef<Value *> Chain, unsigned ElementSizeBits); + + /// Checks for instructions which may affect the memory accessed + /// in the chain between \p From and \p To. Returns Index, where + /// \p Chain[0, Index) is the largest vectorizable chain prefix. + /// The elements of \p Chain should be all loads or all stores. 
+ unsigned getVectorizablePrefixEndIdx(ArrayRef<Value *> Chain, + BasicBlock::iterator From, + BasicBlock::iterator To); + + /// Collects load and store instructions to vectorize. + void collectInstructions(BasicBlock *BB); + + /// Processes the collected instructions, the \p Map. The elements of \p Map + /// should be all loads or all stores. + bool vectorizeChains(ValueListMap &Map); + + /// Finds the load/stores to consecutive memory addresses and vectorizes them. + bool vectorizeInstructions(ArrayRef<Value *> Instrs); + + /// Vectorizes the load instructions in Chain. + bool vectorizeLoadChain(ArrayRef<Value *> Chain, + SmallPtrSet<Value *, 16> *InstructionsProcessed); + + /// Vectorizes the store instructions in Chain. + bool vectorizeStoreChain(ArrayRef<Value *> Chain, + SmallPtrSet<Value *, 16> *InstructionsProcessed); + + /// Check if this load/store access is misaligned accesses + bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, + unsigned Alignment); +}; + +class LoadStoreVectorizer : public FunctionPass { +public: + static char ID; + + LoadStoreVectorizer() : FunctionPass(ID) { + initializeLoadStoreVectorizerPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + + const char *getPassName() const override { + return "GPU Load and Store Vectorizer"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.setPreservesCFG(); + } +}; +} + +INITIALIZE_PASS_BEGIN(LoadStoreVectorizer, DEBUG_TYPE, + "Vectorize load and Store instructions", false, false) +INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(LoadStoreVectorizer, DEBUG_TYPE, + "Vectorize load and store instructions", false, false) + +char LoadStoreVectorizer::ID = 0; + +Pass *llvm::createLoadStoreVectorizerPass() { + return new LoadStoreVectorizer(); +} + +bool LoadStoreVectorizer::runOnFunction(Function &F) { + // Don't vectorize when the attribute NoImplicitFloat is used. + if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat)) + return false; + + AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + TargetTransformInfo &TTI = + getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + + Vectorizer V(F, AA, DT, SE, TTI); + return V.run(); +} + +// Vectorizer Implementation +bool Vectorizer::run() { + bool Changed = false; + + // Scan the blocks in the function in post order. 
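+ // For each block, candidate loads and stores are first grouped by their
+ // underlying object, and then the load chains and store chains are
+ // vectorized independently.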
+ for (BasicBlock *BB : post_order(&F)) { + collectInstructions(BB); + Changed |= vectorizeChains(LoadRefs); + Changed |= vectorizeChains(StoreRefs); + } + + return Changed; +} + +Value *Vectorizer::getPointerOperand(Value *I) { + if (LoadInst *LI = dyn_cast<LoadInst>(I)) + return LI->getPointerOperand(); + if (StoreInst *SI = dyn_cast<StoreInst>(I)) + return SI->getPointerOperand(); + return nullptr; +} + +unsigned Vectorizer::getPointerAddressSpace(Value *I) { + if (LoadInst *L = dyn_cast<LoadInst>(I)) + return L->getPointerAddressSpace(); + if (StoreInst *S = dyn_cast<StoreInst>(I)) + return S->getPointerAddressSpace(); + return -1; +} + +// FIXME: Merge with llvm::isConsecutiveAccess +bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { + Value *PtrA = getPointerOperand(A); + Value *PtrB = getPointerOperand(B); + unsigned ASA = getPointerAddressSpace(A); + unsigned ASB = getPointerAddressSpace(B); + + // Check that the address spaces match and that the pointers are valid. + if (!PtrA || !PtrB || (ASA != ASB)) + return false; + + // Make sure that A and B are different pointers of the same size type. + unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA); + Type *PtrATy = PtrA->getType()->getPointerElementType(); + Type *PtrBTy = PtrB->getType()->getPointerElementType(); + if (PtrA == PtrB || + DL.getTypeStoreSize(PtrATy) != DL.getTypeStoreSize(PtrBTy) || + DL.getTypeStoreSize(PtrATy->getScalarType()) != + DL.getTypeStoreSize(PtrBTy->getScalarType())) + return false; + + APInt Size(PtrBitWidth, DL.getTypeStoreSize(PtrATy)); + + APInt OffsetA(PtrBitWidth, 0), OffsetB(PtrBitWidth, 0); + PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); + PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); + + APInt OffsetDelta = OffsetB - OffsetA; + + // Check if they are based on the same pointer. That makes the offsets + // sufficient. + if (PtrA == PtrB) + return OffsetDelta == Size; + + // Compute the necessary base pointer delta to have the necessary final delta + // equal to the size. + APInt BaseDelta = Size - OffsetDelta; + + // Compute the distance with SCEV between the base pointers. + const SCEV *PtrSCEVA = SE.getSCEV(PtrA); + const SCEV *PtrSCEVB = SE.getSCEV(PtrB); + const SCEV *C = SE.getConstant(BaseDelta); + const SCEV *X = SE.getAddExpr(PtrSCEVA, C); + if (X == PtrSCEVB) + return true; + + // Sometimes even this doesn't work, because SCEV can't always see through + // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking + // things the hard way. + + // Look through GEPs after checking they're the same except for the last + // index. + GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(getPointerOperand(A)); + GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(getPointerOperand(B)); + if (!GEPA || !GEPB || GEPA->getNumOperands() != GEPB->getNumOperands()) + return false; + unsigned FinalIndex = GEPA->getNumOperands() - 1; + for (unsigned i = 0; i < FinalIndex; i++) + if (GEPA->getOperand(i) != GEPB->getOperand(i)) + return false; + + Instruction *OpA = dyn_cast<Instruction>(GEPA->getOperand(FinalIndex)); + Instruction *OpB = dyn_cast<Instruction>(GEPB->getOperand(FinalIndex)); + if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() || + OpA->getType() != OpB->getType()) + return false; + + // Only look through a ZExt/SExt. 
+ if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA)) + return false; + + bool Signed = isa<SExtInst>(OpA); + + OpA = dyn_cast<Instruction>(OpA->getOperand(0)); + OpB = dyn_cast<Instruction>(OpB->getOperand(0)); + if (!OpA || !OpB || OpA->getType() != OpB->getType()) + return false; + + // Now we need to prove that adding 1 to OpA won't overflow. + bool Safe = false; + // First attempt: if OpB is an add with NSW/NUW, and OpB is 1 added to OpA, + // we're okay. + if (OpB->getOpcode() == Instruction::Add && + isa<ConstantInt>(OpB->getOperand(1)) && + cast<ConstantInt>(OpB->getOperand(1))->getSExtValue() > 0) { + if (Signed) + Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap(); + else + Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap(); + } + + unsigned BitWidth = OpA->getType()->getScalarSizeInBits(); + + // Second attempt: + // If any bits are known to be zero other than the sign bit in OpA, we can + // add 1 to it while guaranteeing no overflow of any sort. + if (!Safe) { + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + computeKnownBits(OpA, KnownZero, KnownOne, DL, 0, nullptr, OpA, &DT); + KnownZero &= ~APInt::getHighBitsSet(BitWidth, 1); + if (KnownZero != 0) + Safe = true; + } + + if (!Safe) + return false; + + const SCEV *OffsetSCEVA = SE.getSCEV(OpA); + const SCEV *OffsetSCEVB = SE.getSCEV(OpB); + const SCEV *One = SE.getConstant(APInt(BitWidth, 1)); + const SCEV *X2 = SE.getAddExpr(OffsetSCEVA, One); + return X2 == OffsetSCEVB; +} + +void Vectorizer::reorder(Instruction *I) { + SmallPtrSet<Instruction *, 16> InstructionsToMove; + SmallVector<Instruction *, 16> Worklist; + + Worklist.push_back(I); + while (!Worklist.empty()) { + Instruction *IW = Worklist.pop_back_val(); + int NumOperands = IW->getNumOperands(); + for (int i = 0; i < NumOperands; i++) { + Instruction *IM = dyn_cast<Instruction>(IW->getOperand(i)); + if (!IM || IM->getOpcode() == Instruction::PHI) + continue; + + if (!DT.dominates(IM, I)) { + InstructionsToMove.insert(IM); + Worklist.push_back(IM); + assert(IM->getParent() == IW->getParent() && + "Instructions to move should be in the same basic block"); + } + } + } + + // All instructions to move should follow I. Start from I, not from begin(). + for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E; + ++BBI) { + if (!is_contained(InstructionsToMove, &*BBI)) + continue; + Instruction *IM = &*BBI; + --BBI; + IM->removeFromParent(); + IM->insertBefore(I); + } +} + +std::pair<BasicBlock::iterator, BasicBlock::iterator> +Vectorizer::getBoundaryInstrs(ArrayRef<Value *> Chain) { + Instruction *C0 = cast<Instruction>(Chain[0]); + BasicBlock::iterator FirstInstr = C0->getIterator(); + BasicBlock::iterator LastInstr = C0->getIterator(); + + BasicBlock *BB = C0->getParent(); + unsigned NumFound = 0; + for (Instruction &I : *BB) { + if (!is_contained(Chain, &I)) + continue; + + ++NumFound; + if (NumFound == 1) { + FirstInstr = I.getIterator(); + } + if (NumFound == Chain.size()) { + LastInstr = I.getIterator(); + break; + } + } + + // Range is [first, last). + return std::make_pair(FirstInstr, ++LastInstr); +} + +void Vectorizer::eraseInstructions(ArrayRef<Value *> Chain) { + SmallVector<Instruction *, 16> Instrs; + for (Value *V : Chain) { + Value *PtrOperand = getPointerOperand(V); + assert(PtrOperand && "Instruction must have a pointer operand."); + Instrs.push_back(cast<Instruction>(V)); + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand)) + Instrs.push_back(GEP); + } + + // Erase instructions. 
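+ // Only instructions whose uses have all been replaced (use_empty) are
+ // deleted here; a GEP that still feeds another access is left in place.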
+ for (Value *V : Instrs) { + Instruction *Instr = cast<Instruction>(V); + if (Instr->use_empty()) + Instr->eraseFromParent(); + } +} + +std::pair<ArrayRef<Value *>, ArrayRef<Value *>> +Vectorizer::splitOddVectorElts(ArrayRef<Value *> Chain, + unsigned ElementSizeBits) { + unsigned ElemSizeInBytes = ElementSizeBits / 8; + unsigned SizeInBytes = ElemSizeInBytes * Chain.size(); + unsigned NumRight = (SizeInBytes % 4) / ElemSizeInBytes; + unsigned NumLeft = Chain.size() - NumRight; + return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft)); +} + +unsigned Vectorizer::getVectorizablePrefixEndIdx(ArrayRef<Value *> Chain, + BasicBlock::iterator From, + BasicBlock::iterator To) { + SmallVector<std::pair<Value *, unsigned>, 16> MemoryInstrs; + SmallVector<std::pair<Value *, unsigned>, 16> ChainInstrs; + + unsigned InstrIdx = 0; + for (auto I = From; I != To; ++I, ++InstrIdx) { + if (isa<LoadInst>(I) || isa<StoreInst>(I)) { + if (!is_contained(Chain, &*I)) + MemoryInstrs.push_back({&*I, InstrIdx}); + else + ChainInstrs.push_back({&*I, InstrIdx}); + } else if (I->mayHaveSideEffects()) { + DEBUG(dbgs() << "LSV: Found side-effecting operation: " << *I << '\n'); + return 0; + } + } + + assert(Chain.size() == ChainInstrs.size() && + "All instructions in the Chain must exist in [From, To)."); + + unsigned ChainIdx = 0; + for (auto EntryChain : ChainInstrs) { + Value *ChainInstrValue = EntryChain.first; + unsigned ChainInstrIdx = EntryChain.second; + for (auto EntryMem : MemoryInstrs) { + Value *MemInstrValue = EntryMem.first; + unsigned MemInstrIdx = EntryMem.second; + if (isa<LoadInst>(MemInstrValue) && isa<LoadInst>(ChainInstrValue)) + continue; + + // We can ignore the alias as long as the load comes before the store, + // because that means we won't be moving the load past the store to + // vectorize it (the vectorized load is inserted at the location of the + // first load in the chain). + if (isa<StoreInst>(MemInstrValue) && isa<LoadInst>(ChainInstrValue) && + ChainInstrIdx < MemInstrIdx) + continue; + + // Same case, but in reverse. + if (isa<LoadInst>(MemInstrValue) && isa<StoreInst>(ChainInstrValue) && + ChainInstrIdx > MemInstrIdx) + continue; + + Instruction *M0 = cast<Instruction>(MemInstrValue); + Instruction *M1 = cast<Instruction>(ChainInstrValue); + + if (!AA.isNoAlias(MemoryLocation::get(M0), MemoryLocation::get(M1))) { + DEBUG({ + Value *Ptr0 = getPointerOperand(M0); + Value *Ptr1 = getPointerOperand(M1); + + dbgs() << "LSV: Found alias.\n" + " Aliasing instruction and pointer:\n" + << *MemInstrValue << " aliases " << *Ptr0 << '\n' + << " Aliased instruction and pointer:\n" + << *ChainInstrValue << " aliases " << *Ptr1 << '\n'; + }); + + return ChainIdx; + } + } + ChainIdx++; + } + return Chain.size(); +} + +void Vectorizer::collectInstructions(BasicBlock *BB) { + LoadRefs.clear(); + StoreRefs.clear(); + + for (Instruction &I : *BB) { + if (!I.mayReadOrWriteMemory()) + continue; + + if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { + if (!LI->isSimple()) + continue; + + Type *Ty = LI->getType(); + if (!VectorType::isValidElementType(Ty->getScalarType())) + continue; + + // Skip weird non-byte sizes. They probably aren't worth the effort of + // handling correctly. + unsigned TySize = DL.getTypeSizeInBits(Ty); + if (TySize < 8) + continue; + + Value *Ptr = LI->getPointerOperand(); + unsigned AS = Ptr->getType()->getPointerAddressSpace(); + unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); + + // No point in looking at these if they're too big to vectorize. 
+ if (TySize > VecRegSize / 2) + continue; + + // Make sure all the users of a vector are constant-index extracts. + if (isa<VectorType>(Ty) && !all_of(LI->users(), [LI](const User *U) { + const Instruction *UI = cast<Instruction>(U); + return isa<ExtractElementInst>(UI) && + isa<ConstantInt>(UI->getOperand(1)); + })) + continue; + + // TODO: Target hook to filter types. + + // Save the load locations. + Value *ObjPtr = GetUnderlyingObject(Ptr, DL); + LoadRefs[ObjPtr].push_back(LI); + + } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { + if (!SI->isSimple()) + continue; + + Type *Ty = SI->getValueOperand()->getType(); + if (!VectorType::isValidElementType(Ty->getScalarType())) + continue; + + // Skip weird non-byte sizes. They probably aren't worth the effort of + // handling correctly. + unsigned TySize = DL.getTypeSizeInBits(Ty); + if (TySize < 8) + continue; + + Value *Ptr = SI->getPointerOperand(); + unsigned AS = Ptr->getType()->getPointerAddressSpace(); + unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); + if (TySize > VecRegSize / 2) + continue; + + if (isa<VectorType>(Ty) && !all_of(SI->users(), [SI](const User *U) { + const Instruction *UI = cast<Instruction>(U); + return isa<ExtractElementInst>(UI) && + isa<ConstantInt>(UI->getOperand(1)); + })) + continue; + + // Save store location. + Value *ObjPtr = GetUnderlyingObject(Ptr, DL); + StoreRefs[ObjPtr].push_back(SI); + } + } +} + +bool Vectorizer::vectorizeChains(ValueListMap &Map) { + bool Changed = false; + + for (const std::pair<Value *, ValueList> &Chain : Map) { + unsigned Size = Chain.second.size(); + if (Size < 2) + continue; + + DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n"); + + // Process the stores in chunks of 64. + for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) { + unsigned Len = std::min<unsigned>(CE - CI, 64); + ArrayRef<Value *> Chunk(&Chain.second[CI], Len); + Changed |= vectorizeInstructions(Chunk); + } + } + + return Changed; +} + +bool Vectorizer::vectorizeInstructions(ArrayRef<Value *> Instrs) { + DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size() << " instructions.\n"); + SmallSetVector<int, 16> Heads, Tails; + int ConsecutiveChain[64]; + + // Do a quadratic search on all of the given stores and find all of the pairs + // of stores that follow each other. + for (int i = 0, e = Instrs.size(); i < e; ++i) { + ConsecutiveChain[i] = -1; + for (int j = e - 1; j >= 0; --j) { + if (i == j) + continue; + + if (isConsecutiveAccess(Instrs[i], Instrs[j])) { + if (ConsecutiveChain[i] != -1) { + int CurDistance = std::abs(ConsecutiveChain[i] - i); + int NewDistance = std::abs(ConsecutiveChain[i] - j); + if (j < i || NewDistance > CurDistance) + continue; // Should not insert. + } + + Tails.insert(j); + Heads.insert(i); + ConsecutiveChain[i] = j; + } + } + } + + bool Changed = false; + SmallPtrSet<Value *, 16> InstructionsProcessed; + + for (int Head : Heads) { + if (InstructionsProcessed.count(Instrs[Head])) + continue; + bool longerChainExists = false; + for (unsigned TIt = 0; TIt < Tails.size(); TIt++) + if (Head == Tails[TIt] && + !InstructionsProcessed.count(Instrs[Heads[TIt]])) { + longerChainExists = true; + break; + } + if (longerChainExists) + continue; + + // We found an instr that starts a chain. Now follow the chain and try to + // vectorize it. 
+ SmallVector<Value *, 16> Operands; + int I = Head; + while (I != -1 && (Tails.count(I) || Heads.count(I))) { + if (InstructionsProcessed.count(Instrs[I])) + break; + + Operands.push_back(Instrs[I]); + I = ConsecutiveChain[I]; + } + + bool Vectorized = false; + if (isa<LoadInst>(*Operands.begin())) + Vectorized = vectorizeLoadChain(Operands, &InstructionsProcessed); + else + Vectorized = vectorizeStoreChain(Operands, &InstructionsProcessed); + + Changed |= Vectorized; + } + + return Changed; +} + +bool Vectorizer::vectorizeStoreChain( + ArrayRef<Value *> Chain, SmallPtrSet<Value *, 16> *InstructionsProcessed) { + StoreInst *S0 = cast<StoreInst>(Chain[0]); + + // If the vector has an int element, default to int for the whole load. + Type *StoreTy; + for (const auto &V : Chain) { + StoreTy = cast<StoreInst>(V)->getValueOperand()->getType(); + if (StoreTy->isIntOrIntVectorTy()) + break; + + if (StoreTy->isPtrOrPtrVectorTy()) { + StoreTy = Type::getIntNTy(F.getParent()->getContext(), + DL.getTypeSizeInBits(StoreTy)); + break; + } + } + + unsigned Sz = DL.getTypeSizeInBits(StoreTy); + unsigned AS = S0->getPointerAddressSpace(); + unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); + unsigned VF = VecRegSize / Sz; + unsigned ChainSize = Chain.size(); + + if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) { + InstructionsProcessed->insert(Chain.begin(), Chain.end()); + return false; + } + + BasicBlock::iterator First, Last; + std::tie(First, Last) = getBoundaryInstrs(Chain); + unsigned StopChain = getVectorizablePrefixEndIdx(Chain, First, Last); + if (StopChain == 0) { + // There exists a side effect instruction, no vectorization possible. + InstructionsProcessed->insert(Chain.begin(), Chain.end()); + return false; + } + if (StopChain == 1) { + // Failed after the first instruction. Discard it and try the smaller chain. + InstructionsProcessed->insert(Chain.front()); + return false; + } + + // Update Chain to the valid vectorizable subchain. + Chain = Chain.slice(0, StopChain); + ChainSize = Chain.size(); + + // Store size should be 1B, 2B or multiple of 4B. + // TODO: Target hook for size constraint? + unsigned SzInBytes = (Sz / 8) * ChainSize; + if (SzInBytes > 2 && SzInBytes % 4 != 0) { + DEBUG(dbgs() << "LSV: Size should be 1B, 2B " + "or multiple of 4B. Splitting.\n"); + if (SzInBytes == 3) + return vectorizeStoreChain(Chain.slice(0, ChainSize - 1), + InstructionsProcessed); + + auto Chains = splitOddVectorElts(Chain, Sz); + return vectorizeStoreChain(Chains.first, InstructionsProcessed) | + vectorizeStoreChain(Chains.second, InstructionsProcessed); + } + + VectorType *VecTy; + VectorType *VecStoreTy = dyn_cast<VectorType>(StoreTy); + if (VecStoreTy) + VecTy = VectorType::get(StoreTy->getScalarType(), + Chain.size() * VecStoreTy->getNumElements()); + else + VecTy = VectorType::get(StoreTy, Chain.size()); + + // If it's more than the max vector size, break it into two pieces. + // TODO: Target hook to control types to split to. + if (ChainSize > VF) { + DEBUG(dbgs() << "LSV: Vector factor is too big." + " Creating two separate arrays.\n"); + return vectorizeStoreChain(Chain.slice(0, VF), InstructionsProcessed) | + vectorizeStoreChain(Chain.slice(VF), InstructionsProcessed); + } + + DEBUG({ + dbgs() << "LSV: Stores to vectorize:\n"; + for (Value *V : Chain) + V->dump(); + }); + + // We won't try again to vectorize the elements of the chain, regardless of + // whether we succeed below. + InstructionsProcessed->insert(Chain.begin(), Chain.end()); + + // Check alignment restrictions. 
+ unsigned Alignment = getAlignment(S0); + + // If the store is going to be misaligned, don't vectorize it. + if (accessIsMisaligned(SzInBytes, AS, Alignment)) { + if (S0->getPointerAddressSpace() != 0) + return false; + + // If we're storing to an object on the stack, we control its alignment, + // so we can cheat and change it! + Value *V = GetUnderlyingObject(S0->getPointerOperand(), DL); + if (AllocaInst *AI = dyn_cast_or_null<AllocaInst>(V)) { + AI->setAlignment(TargetBaseAlign); + Alignment = TargetBaseAlign; + } else { + return false; + } + } + + // Set insert point. + Builder.SetInsertPoint(&*Last); + + Value *Vec = UndefValue::get(VecTy); + + if (VecStoreTy) { + unsigned VecWidth = VecStoreTy->getNumElements(); + for (unsigned I = 0, E = Chain.size(); I != E; ++I) { + StoreInst *Store = cast<StoreInst>(Chain[I]); + for (unsigned J = 0, NE = VecStoreTy->getNumElements(); J != NE; ++J) { + unsigned NewIdx = J + I * VecWidth; + Value *Extract = Builder.CreateExtractElement(Store->getValueOperand(), + Builder.getInt32(J)); + if (Extract->getType() != StoreTy->getScalarType()) + Extract = Builder.CreateBitCast(Extract, StoreTy->getScalarType()); + + Value *Insert = + Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(NewIdx)); + Vec = Insert; + } + } + } else { + for (unsigned I = 0, E = Chain.size(); I != E; ++I) { + StoreInst *Store = cast<StoreInst>(Chain[I]); + Value *Extract = Store->getValueOperand(); + if (Extract->getType() != StoreTy->getScalarType()) + Extract = + Builder.CreateBitOrPointerCast(Extract, StoreTy->getScalarType()); + + Value *Insert = + Builder.CreateInsertElement(Vec, Extract, Builder.getInt32(I)); + Vec = Insert; + } + } + + Value *Bitcast = + Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS)); + StoreInst *SI = cast<StoreInst>(Builder.CreateStore(Vec, Bitcast)); + propagateMetadata(SI, Chain); + SI->setAlignment(Alignment); + + eraseInstructions(Chain); + ++NumVectorInstructions; + NumScalarsVectorized += Chain.size(); + return true; +} + +bool Vectorizer::vectorizeLoadChain( + ArrayRef<Value *> Chain, SmallPtrSet<Value *, 16> *InstructionsProcessed) { + LoadInst *L0 = cast<LoadInst>(Chain[0]); + + // If the vector has an int element, default to int for the whole load. + Type *LoadTy; + for (const auto &V : Chain) { + LoadTy = cast<LoadInst>(V)->getType(); + if (LoadTy->isIntOrIntVectorTy()) + break; + + if (LoadTy->isPtrOrPtrVectorTy()) { + LoadTy = Type::getIntNTy(F.getParent()->getContext(), + DL.getTypeSizeInBits(LoadTy)); + break; + } + } + + unsigned Sz = DL.getTypeSizeInBits(LoadTy); + unsigned AS = L0->getPointerAddressSpace(); + unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); + unsigned VF = VecRegSize / Sz; + unsigned ChainSize = Chain.size(); + + if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) { + InstructionsProcessed->insert(Chain.begin(), Chain.end()); + return false; + } + + BasicBlock::iterator First, Last; + std::tie(First, Last) = getBoundaryInstrs(Chain); + unsigned StopChain = getVectorizablePrefixEndIdx(Chain, First, Last); + if (StopChain == 0) { + // There exists a side effect instruction, no vectorization possible. + InstructionsProcessed->insert(Chain.begin(), Chain.end()); + return false; + } + if (StopChain == 1) { + // Failed after the first instruction. Discard it and try the smaller chain. + InstructionsProcessed->insert(Chain.front()); + return false; + } + + // Update Chain to the valid vectorizable subchain. 
+ Chain = Chain.slice(0, StopChain); + ChainSize = Chain.size(); + + // Load size should be 1B, 2B or multiple of 4B. + // TODO: Should size constraint be a target hook? + unsigned SzInBytes = (Sz / 8) * ChainSize; + if (SzInBytes > 2 && SzInBytes % 4 != 0) { + DEBUG(dbgs() << "LSV: Size should be 1B, 2B " + "or multiple of 4B. Splitting.\n"); + if (SzInBytes == 3) + return vectorizeLoadChain(Chain.slice(0, ChainSize - 1), + InstructionsProcessed); + auto Chains = splitOddVectorElts(Chain, Sz); + return vectorizeLoadChain(Chains.first, InstructionsProcessed) | + vectorizeLoadChain(Chains.second, InstructionsProcessed); + } + + VectorType *VecTy; + VectorType *VecLoadTy = dyn_cast<VectorType>(LoadTy); + if (VecLoadTy) + VecTy = VectorType::get(LoadTy->getScalarType(), + Chain.size() * VecLoadTy->getNumElements()); + else + VecTy = VectorType::get(LoadTy, Chain.size()); + + // If it's more than the max vector size, break it into two pieces. + // TODO: Target hook to control types to split to. + if (ChainSize > VF) { + DEBUG(dbgs() << "LSV: Vector factor is too big. " + "Creating two separate arrays.\n"); + return vectorizeLoadChain(Chain.slice(0, VF), InstructionsProcessed) | + vectorizeLoadChain(Chain.slice(VF), InstructionsProcessed); + } + + // We won't try again to vectorize the elements of the chain, regardless of + // whether we succeed below. + InstructionsProcessed->insert(Chain.begin(), Chain.end()); + + // Check alignment restrictions. + unsigned Alignment = getAlignment(L0); + + // If the load is going to be misaligned, don't vectorize it. + if (accessIsMisaligned(SzInBytes, AS, Alignment)) { + if (L0->getPointerAddressSpace() != 0) + return false; + + // If we're loading from an object on the stack, we control its alignment, + // so we can cheat and change it! + Value *V = GetUnderlyingObject(L0->getPointerOperand(), DL); + if (AllocaInst *AI = dyn_cast_or_null<AllocaInst>(V)) { + AI->setAlignment(TargetBaseAlign); + Alignment = TargetBaseAlign; + } else { + return false; + } + } + + DEBUG({ + dbgs() << "LSV: Loads to vectorize:\n"; + for (Value *V : Chain) + V->dump(); + }); + + // Set insert point. + Builder.SetInsertPoint(&*First); + + Value *Bitcast = + Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS)); + + LoadInst *LI = cast<LoadInst>(Builder.CreateLoad(Bitcast)); + propagateMetadata(LI, Chain); + LI->setAlignment(Alignment); + + if (VecLoadTy) { + SmallVector<Instruction *, 16> InstrsToErase; + SmallVector<Instruction *, 16> InstrsToReorder; + InstrsToReorder.push_back(cast<Instruction>(Bitcast)); + + unsigned VecWidth = VecLoadTy->getNumElements(); + for (unsigned I = 0, E = Chain.size(); I != E; ++I) { + for (auto Use : Chain[I]->users()) { + Instruction *UI = cast<Instruction>(Use); + unsigned Idx = cast<ConstantInt>(UI->getOperand(1))->getZExtValue(); + unsigned NewIdx = Idx + I * VecWidth; + Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(NewIdx)); + Instruction *Extracted = cast<Instruction>(V); + if (Extracted->getType() != UI->getType()) + Extracted = cast<Instruction>( + Builder.CreateBitCast(Extracted, UI->getType())); + + // Replace the old instruction. 
+ UI->replaceAllUsesWith(Extracted); + InstrsToErase.push_back(UI); + } + } + + for (Instruction *ModUser : InstrsToReorder) + reorder(ModUser); + + for (auto I : InstrsToErase) + I->eraseFromParent(); + } else { + SmallVector<Instruction *, 16> InstrsToReorder; + InstrsToReorder.push_back(cast<Instruction>(Bitcast)); + + for (unsigned I = 0, E = Chain.size(); I != E; ++I) { + Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(I)); + Instruction *Extracted = cast<Instruction>(V); + Instruction *UI = cast<Instruction>(Chain[I]); + if (Extracted->getType() != UI->getType()) { + Extracted = cast<Instruction>( + Builder.CreateBitOrPointerCast(Extracted, UI->getType())); + } + + // Replace the old instruction. + UI->replaceAllUsesWith(Extracted); + } + + for (Instruction *ModUser : InstrsToReorder) + reorder(ModUser); + } + + eraseInstructions(Chain); + + ++NumVectorInstructions; + NumScalarsVectorized += Chain.size(); + return true; +} + +bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, + unsigned Alignment) { + bool Fast = false; + bool Allows = TTI.allowsMisalignedMemoryAccesses(SzInBytes * 8, AddressSpace, + Alignment, &Fast); + // TODO: Remove TargetBaseAlign + return !(Allows && Fast) && (Alignment % SzInBytes) != 0 && + (Alignment % TargetBaseAlign) != 0; +} diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 17c25dfffc10..8b85e320d3b2 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -46,7 +46,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Vectorize.h" +#include "llvm/Transforms/Vectorize/LoopVectorize.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" @@ -56,23 +56,15 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/AliasSetTracker.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" @@ -98,10 +90,10 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Analysis/VectorUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Vectorize.h" #include <algorithm> -#include <functional> #include <map> #include <tuple> @@ -115,37 +107,21 @@ STATISTIC(LoopsVectorized, "Number of loops vectorized"); STATISTIC(LoopsAnalyzed, "Number of loops analyzed for vectorization"); static cl::opt<bool> -EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, - cl::desc("Enable if-conversion during vectorization.")); + 
EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, + cl::desc("Enable if-conversion during vectorization.")); /// We don't vectorize loops with a known constant trip count below this number. -static cl::opt<unsigned> -TinyTripCountVectorThreshold("vectorizer-min-trip-count", cl::init(16), - cl::Hidden, - cl::desc("Don't vectorize loops with a constant " - "trip count that is smaller than this " - "value.")); +static cl::opt<unsigned> TinyTripCountVectorThreshold( + "vectorizer-min-trip-count", cl::init(16), cl::Hidden, + cl::desc("Don't vectorize loops with a constant " + "trip count that is smaller than this " + "value.")); static cl::opt<bool> MaximizeBandwidth( "vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden, cl::desc("Maximize bandwidth when selecting vectorization factor which " "will be determined by the smallest type in loop.")); -/// This enables versioning on the strides of symbolically striding memory -/// accesses in code like the following. -/// for (i = 0; i < N; ++i) -/// A[i * Stride1] += B[i * Stride2] ... -/// -/// Will be roughly translated to -/// if (Stride1 == 1 && Stride2 == 1) { -/// for (i = 0; i < N; i+=4) -/// A[i:i+3] += ... -/// } else -/// ... -static cl::opt<bool> EnableMemAccessVersioning( - "enable-mem-access-versioning", cl::init(true), cl::Hidden, - cl::desc("Enable symbolic stride memory access versioning")); - static cl::opt<bool> EnableInterleavedMemAccesses( "enable-interleaved-mem-accesses", cl::init(false), cl::Hidden, cl::desc("Enable vectorization on interleaved memory accesses in a loop")); @@ -262,7 +238,7 @@ public: /// A helper function for converting Scalar types to vector types. /// If the incoming type is void, we return void. If the VF is 1, we return /// the scalar type. -static Type* ToVectorTy(Type *Scalar, unsigned VF) { +static Type *ToVectorTy(Type *Scalar, unsigned VF) { if (Scalar->isVoidTy() || VF == 1) return Scalar; return VectorType::get(Scalar, VF); @@ -313,21 +289,25 @@ public: InnerLoopVectorizer(Loop *OrigLoop, PredicatedScalarEvolution &PSE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, unsigned VecWidth, - unsigned UnrollFactor) + const TargetTransformInfo *TTI, AssumptionCache *AC, + unsigned VecWidth, unsigned UnrollFactor) : OrigLoop(OrigLoop), PSE(PSE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), - VF(VecWidth), UF(UnrollFactor), Builder(PSE.getSE()->getContext()), - Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor), - TripCount(nullptr), VectorTripCount(nullptr), Legal(nullptr), - AddedSafetyChecks(false) {} + AC(AC), VF(VecWidth), UF(UnrollFactor), + Builder(PSE.getSE()->getContext()), Induction(nullptr), + OldInduction(nullptr), WidenMap(UnrollFactor), TripCount(nullptr), + VectorTripCount(nullptr), Legal(nullptr), AddedSafetyChecks(false) {} // Perform the actual loop widening (vectorization). // MinimumBitWidths maps scalar integer values to the smallest bitwidth they // can be validly truncated to. The cost model has assumed this truncation - // will happen when vectorizing. + // will happen when vectorizing. VecValuesToIgnore contains scalar values + // that the cost model has chosen to ignore because they will not be + // vectorized. 
void vectorize(LoopVectorizationLegality *L, - MapVector<Instruction*,uint64_t> MinimumBitWidths) { - MinBWs = MinimumBitWidths; + const MapVector<Instruction *, uint64_t> &MinimumBitWidths, + SmallPtrSetImpl<const Value *> &VecValuesToIgnore) { + MinBWs = &MinimumBitWidths; + ValuesNotWidened = &VecValuesToIgnore; Legal = L; // Create a new empty loop. Unlink the old loop and connect the new one. createEmptyLoop(); @@ -337,33 +317,41 @@ public: } // Return true if any runtime check is added. - bool IsSafetyChecksAdded() { - return AddedSafetyChecks; - } + bool areSafetyChecksAdded() { return AddedSafetyChecks; } virtual ~InnerLoopVectorizer() {} protected: /// A small list of PHINodes. - typedef SmallVector<PHINode*, 4> PhiVector; + typedef SmallVector<PHINode *, 4> PhiVector; /// When we unroll loops we have multiple vector values for each scalar. /// This data structure holds the unrolled and vectorized values that /// originated from one scalar instruction. - typedef SmallVector<Value*, 2> VectorParts; + typedef SmallVector<Value *, 2> VectorParts; // When we if-convert we need to create edge masks. We have to cache values // so that we don't end up with exponential recursion/IR. - typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>, - VectorParts> EdgeMaskCache; + typedef DenseMap<std::pair<BasicBlock *, BasicBlock *>, VectorParts> + EdgeMaskCache; /// Create an empty loop, based on the loop ranges of the old loop. void createEmptyLoop(); + + /// Set up the values of the IVs correctly when exiting the vector loop. + void fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II, + Value *CountRoundDown, Value *EndValue, + BasicBlock *MiddleBlock); + /// Create a new induction variable inside L. PHINode *createInductionVariable(Loop *L, Value *Start, Value *End, Value *Step, Instruction *DL); /// Copy and widen the instructions from the old loop. virtual void vectorizeLoop(); + /// Fix a first-order recurrence. This is the second phase of vectorizing + /// this phi node. + void fixFirstOrderRecurrence(PHINode *Phi); + /// \brief The Loop exit block may have single value PHI nodes where the /// incoming value is 'Undef'. While vectorizing we only handled real values /// that were defined inside the loop. Here we fix the 'undef case'. @@ -372,7 +360,7 @@ protected: /// Shrinks vector element sizes based on information in "MinBWs". void truncateToMinimalBitwidths(); - + /// A helper function that computes the predicate of the block BB, assuming /// that the header block of the loop is set to True. It returns the *entry* /// mask for the block BB. @@ -383,12 +371,12 @@ protected: /// A helper function to vectorize a single BB within the innermost loop. void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV); - + /// Vectorize a single PHINode in a block. This method handles the induction /// variable canonicalization. It supports both VF = 1 for unrolled loops and /// arbitrary length vectors. - void widenPHIInstruction(Instruction *PN, VectorParts &Entry, - unsigned UF, unsigned VF, PhiVector *PV); + void widenPHIInstruction(Instruction *PN, VectorParts &Entry, unsigned UF, + unsigned VF, PhiVector *PV); /// Insert the new loop to the loop hierarchy and pass manager /// and update the analysis passes. @@ -399,7 +387,7 @@ protected: /// scalarized instruction behind an if block predicated on the control /// dependence of the instruction. 
virtual void scalarizeInstruction(Instruction *Instr, - bool IfPredicateStore=false); + bool IfPredicateStore = false); /// Vectorize Load and Store instructions, virtual void vectorizeMemoryInstruction(Instruction *Instr); @@ -415,6 +403,26 @@ protected: /// to each vector element of Val. The sequence starts at StartIndex. virtual Value *getStepVector(Value *Val, int StartIdx, Value *Step); + /// Compute scalar induction steps. \p ScalarIV is the scalar induction + /// variable on which to base the steps, \p Step is the size of the step, and + /// \p EntryVal is the value from the original loop that maps to the steps. + /// Note that \p EntryVal doesn't have to be an induction variable (e.g., it + /// can be a truncate instruction). + void buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal); + + /// Create a vector induction phi node based on an existing scalar one. This + /// currently only works for integer induction variables with a constant + /// step. If \p TruncType is non-null, instead of widening the original IV, + /// we widen a version of the IV truncated to \p TruncType. + void createVectorIntInductionPHI(const InductionDescriptor &II, + VectorParts &Entry, IntegerType *TruncType); + + /// Widen an integer induction variable \p IV. If \p Trunc is provided, the + /// induction variable will first be truncated to the corresponding type. The + /// widened values are placed in \p Entry. + void widenIntInduction(PHINode *IV, VectorParts &Entry, + TruncInst *Trunc = nullptr); + /// When we go over instructions in the basic block we rely on previous /// values within the current basic block or on loop invariant values. /// When we widen (vectorize) values we place them in the map. If the values @@ -445,6 +453,24 @@ protected: /// Emit bypass checks to check any memory assumptions we may have made. void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); + /// Add additional metadata to \p To that was not present on \p Orig. + /// + /// Currently this is used to add the noalias annotations based on the + /// inserted memchecks. Use this for instructions that are *cloned* into the + /// vector loop. + void addNewMetadata(Instruction *To, const Instruction *Orig); + + /// Add metadata from one instruction to another. + /// + /// This includes both the original MDs from \p From and additional ones (\see + /// addNewMetadata). Use this for *newly created* instructions in the vector + /// loop. + void addMetadata(Instruction *To, Instruction *From); + + /// \brief Similar to the previous function but it adds the metadata to a + /// vector of instructions. + void addMetadata(ArrayRef<Value *> To, Instruction *From); + /// This is a helper class that holds the vectorizer state. It maps scalar /// instructions to vector instructions. When the code is 'unrolled' then /// then a single scalar value is mapped to multiple vector parts. The parts @@ -501,6 +527,15 @@ protected: const TargetLibraryInfo *TLI; /// Target Transform Info. const TargetTransformInfo *TTI; + /// Assumption Cache. + AssumptionCache *AC; + + /// \brief LoopVersioning. It's only set up (non-null) if memchecks were + /// used. + /// + /// This is currently only used to add no-alias metadata based on the + /// memchecks. The actually versioning is performed manually. + std::unique_ptr<LoopVersioning> LVer; /// The vectorization SIMD factor to use. Each vector will have this many /// vector elements. @@ -522,11 +557,11 @@ protected: BasicBlock *LoopScalarPreHeader; /// Middle Block between the vector and the scalar. 
BasicBlock *LoopMiddleBlock; - ///The ExitBlock of the scalar loop. + /// The ExitBlock of the scalar loop. BasicBlock *LoopExitBlock; - ///The vector loop body. - SmallVector<BasicBlock *, 4> LoopVectorBody; - ///The scalar loop body. + /// The vector loop body. + BasicBlock *LoopVectorBody; + /// The scalar loop body. BasicBlock *LoopScalarBody; /// A list of all bypass blocks. The first block is the entry of the loop. SmallVector<BasicBlock *, 4> LoopBypassBlocks; @@ -537,9 +572,20 @@ protected: PHINode *OldInduction; /// Maps scalars to widened vectors. ValueMap WidenMap; + + /// A map of induction variables from the original loop to their + /// corresponding VF * UF scalarized values in the vectorized loop. The + /// purpose of ScalarIVMap is similar to that of WidenMap. Whereas WidenMap + /// maps original loop values to their vector versions in the new loop, + /// ScalarIVMap maps induction variables from the original loop that are not + /// vectorized to their scalar equivalents in the vector loop. Maintaining a + /// separate map for scalarized induction variables allows us to avoid + /// unnecessary scalar-to-vector-to-scalar conversions. + DenseMap<Value *, SmallVector<Value *, 8>> ScalarIVMap; + /// Store instructions that should be predicated, as a pair /// <StoreInst, Predicate> - SmallVector<std::pair<StoreInst*,Value*>, 4> PredicatedStores; + SmallVector<std::pair<StoreInst *, Value *>, 4> PredicatedStores; EdgeMaskCache MaskCache; /// Trip count of the original loop. Value *TripCount; @@ -549,10 +595,15 @@ protected: /// Map of scalar integer values to the smallest bitwidth they can be legally /// represented as. The vector equivalents of these values should be truncated /// to this type. - MapVector<Instruction*,uint64_t> MinBWs; + const MapVector<Instruction *, uint64_t> *MinBWs; + + /// A set of values that should not be widened. This is taken from + /// VecValuesToIgnore in the cost model. + SmallPtrSetImpl<const Value *> *ValuesNotWidened; + LoopVectorizationLegality *Legal; - // Record whether runtime check is added. + // Record whether runtime checks are added. bool AddedSafetyChecks; }; @@ -561,8 +612,10 @@ public: InnerLoopUnroller(Loop *OrigLoop, PredicatedScalarEvolution &PSE, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, unsigned UnrollFactor) - : InnerLoopVectorizer(OrigLoop, PSE, LI, DT, TLI, TTI, 1, UnrollFactor) {} + const TargetTransformInfo *TTI, AssumptionCache *AC, + unsigned UnrollFactor) + : InnerLoopVectorizer(OrigLoop, PSE, LI, DT, TLI, TTI, AC, 1, + UnrollFactor) {} private: void scalarizeInstruction(Instruction *Instr, @@ -618,36 +671,26 @@ static std::string getDebugLocString(const Loop *L) { } #endif -/// \brief Propagate known metadata from one instruction to another. -static void propagateMetadata(Instruction *To, const Instruction *From) { - SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; - From->getAllMetadataOtherThanDebugLoc(Metadata); - - for (auto M : Metadata) { - unsigned Kind = M.first; - - // These are safe to transfer (this is safe for TBAA, even when we - // if-convert, because should that metadata have had a control dependency - // on the condition, and thus actually aliased with some other - // non-speculated memory access when the condition was false, this would be - // caught by the runtime overlap checks). 
- if (Kind != LLVMContext::MD_tbaa && - Kind != LLVMContext::MD_alias_scope && - Kind != LLVMContext::MD_noalias && - Kind != LLVMContext::MD_fpmath && - Kind != LLVMContext::MD_nontemporal) - continue; +void InnerLoopVectorizer::addNewMetadata(Instruction *To, + const Instruction *Orig) { + // If the loop was versioned with memchecks, add the corresponding no-alias + // metadata. + if (LVer && (isa<LoadInst>(Orig) || isa<StoreInst>(Orig))) + LVer->annotateInstWithNoAlias(To, Orig); +} - To->setMetadata(Kind, M.second); - } +void InnerLoopVectorizer::addMetadata(Instruction *To, + Instruction *From) { + propagateMetadata(To, From); + addNewMetadata(To, From); } -/// \brief Propagate known metadata from one instruction to a vector of others. -static void propagateMetadata(SmallVectorImpl<Value *> &To, - const Instruction *From) { - for (Value *V : To) +void InnerLoopVectorizer::addMetadata(ArrayRef<Value *> To, + Instruction *From) { + for (Value *V : To) { if (Instruction *I = dyn_cast<Instruction>(V)) - propagateMetadata(I, From); + addMetadata(I, From); + } } /// \brief The group of interleaved loads/stores sharing the same stride and @@ -785,8 +828,9 @@ private: class InterleavedAccessInfo { public: InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, - DominatorTree *DT) - : PSE(PSE), TheLoop(L), DT(DT) {} + DominatorTree *DT, LoopInfo *LI) + : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(nullptr), + RequiresScalarEpilogue(false) {} ~InterleavedAccessInfo() { SmallSet<InterleaveGroup *, 4> DelSet; @@ -806,6 +850,14 @@ public: return InterleaveGroupMap.count(Instr); } + /// \brief Return the maximum interleave factor of all interleaved groups. + unsigned getMaxInterleaveFactor() const { + unsigned MaxFactor = 1; + for (auto &Entry : InterleaveGroupMap) + MaxFactor = std::max(MaxFactor, Entry.second->getFactor()); + return MaxFactor; + } + /// \brief Get the interleave group that \p Instr belongs to. /// /// \returns nullptr if doesn't have such group. @@ -815,6 +867,13 @@ public: return nullptr; } + /// \brief Returns true if an interleaved group that may access memory + /// out-of-bounds requires a scalar epilogue iteration for correctness. + bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } + + /// \brief Initialize the LoopAccessInfo used for dependence checking. + void setLAI(const LoopAccessInfo *Info) { LAI = Info; } + private: /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. /// Simplifies SCEV expressions in the context of existing SCEV assumptions. @@ -823,24 +882,39 @@ private: PredicatedScalarEvolution &PSE; Loop *TheLoop; DominatorTree *DT; + LoopInfo *LI; + const LoopAccessInfo *LAI; + + /// True if the loop may contain non-reversed interleaved groups with + /// out-of-bounds accesses. We ensure we don't speculatively access memory + /// out-of-bounds by executing at least one scalar epilogue iteration. + bool RequiresScalarEpilogue; /// Holds the relationships between the members and the interleave group. DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap; + /// Holds dependences among the memory accesses in the loop. It maps a source + /// access to a set of dependent sink accesses. + DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences; + /// \brief The descriptor for a strided memory access. 
struct StrideDescriptor { - StrideDescriptor(int Stride, const SCEV *Scev, unsigned Size, + StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, unsigned Align) : Stride(Stride), Scev(Scev), Size(Size), Align(Align) {} - StrideDescriptor() : Stride(0), Scev(nullptr), Size(0), Align(0) {} + StrideDescriptor() = default; - int Stride; // The access's stride. It is negative for a reverse access. - const SCEV *Scev; // The scalar expression of this access - unsigned Size; // The size of the memory object. - unsigned Align; // The alignment of this access. + // The access's stride. It is negative for a reverse access. + int64_t Stride = 0; + const SCEV *Scev = nullptr; // The scalar expression of this access + uint64_t Size = 0; // The size of the memory object. + unsigned Align = 0; // The alignment of this access. }; + /// \brief A type for holding instructions and their stride descriptors. + typedef std::pair<Instruction *, StrideDescriptor> StrideEntry; + /// \brief Create a new interleave group with the given instruction \p Instr, /// stride \p Stride and alignment \p Align. /// @@ -863,9 +937,86 @@ private: } /// \brief Collect all the accesses with a constant stride in program order. - void collectConstStridedAccesses( - MapVector<Instruction *, StrideDescriptor> &StrideAccesses, + void collectConstStrideAccesses( + MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, const ValueToValueMap &Strides); + + /// \brief Returns true if \p Stride is allowed in an interleaved group. + static bool isStrided(int Stride) { + unsigned Factor = std::abs(Stride); + return Factor >= 2 && Factor <= MaxInterleaveGroupFactor; + } + + /// \brief Returns true if \p BB is a predicated block. + bool isPredicated(BasicBlock *BB) const { + return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); + } + + /// \brief Returns true if LoopAccessInfo can be used for dependence queries. + bool areDependencesValid() const { + return LAI && LAI->getDepChecker().getDependences(); + } + + /// \brief Returns true if memory accesses \p A and \p B can be reordered, if + /// necessary, when constructing interleaved groups. + /// + /// \p A must precede \p B in program order. We return false if reordering is + /// not necessary or is prevented because \p A and \p B may be dependent. + bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, + StrideEntry *B) const { + + // Code motion for interleaved accesses can potentially hoist strided loads + // and sink strided stores. The code below checks the legality of the + // following two conditions: + // + // 1. Potentially moving a strided load (B) before any store (A) that + // precedes B, or + // + // 2. Potentially moving a strided store (A) after any load or store (B) + // that A precedes. + // + // It's legal to reorder A and B if we know there isn't a dependence from A + // to B. Note that this determination is conservative since some + // dependences could potentially be reordered safely. + + // A is potentially the source of a dependence. + auto *Src = A->first; + auto SrcDes = A->second; + + // B is potentially the sink of a dependence. + auto *Sink = B->first; + auto SinkDes = B->second; + + // Code motion for interleaved accesses can't violate WAR dependences. + // Thus, reordering is legal if the source isn't a write. + if (!Src->mayWriteToMemory()) + return true; + + // At least one of the accesses must be strided. 
+ if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride)) + return true; + + // If dependence information is not available from LoopAccessInfo, + // conservatively assume the instructions can't be reordered. + if (!areDependencesValid()) + return false; + + // If we know there is a dependence from source to sink, assume the + // instructions can't be reordered. Otherwise, reordering is legal. + return !Dependences.count(Src) || !Dependences.lookup(Src).count(Sink); + } + + /// \brief Collect the dependences from LoopAccessInfo. + /// + /// We process the dependences once during the interleaved access analysis to + /// enable constant-time dependence queries. + void collectDependences() { + if (!areDependencesValid()) + return; + auto *Deps = LAI->getDepChecker().getDependences(); + for (auto Dep : *Deps) + Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI)); + } }; /// Utility class for getting and setting loop vectorizer hints in the form @@ -878,20 +1029,16 @@ private: /// for example 'force', means a decision has been made. So, we need to be /// careful NOT to add them if the user hasn't specifically asked so. class LoopVectorizeHints { - enum HintKind { - HK_WIDTH, - HK_UNROLL, - HK_FORCE - }; + enum HintKind { HK_WIDTH, HK_UNROLL, HK_FORCE }; /// Hint - associates name and validation with the hint value. struct Hint { - const char * Name; + const char *Name; unsigned Value; // This may have to change for non-numeric values. HintKind Kind; - Hint(const char * Name, unsigned Value, HintKind Kind) - : Name(Name), Value(Value), Kind(Kind) { } + Hint(const char *Name, unsigned Value, HintKind Kind) + : Name(Name), Value(Value), Kind(Kind) {} bool validate(unsigned Val) { switch (Kind) { @@ -916,6 +1063,9 @@ class LoopVectorizeHints { /// Return the loop metadata prefix. static StringRef Prefix() { return "llvm.loop."; } + /// True if there is any unsafe math in the loop. + bool PotentiallyUnsafe; + public: enum ForceKind { FK_Undefined = -1, ///< Not selected. @@ -928,7 +1078,7 @@ public: HK_WIDTH), Interleave("interleave.count", DisableInterleaving, HK_UNROLL), Force("vectorize.enable", FK_Undefined, HK_FORCE), - TheLoop(L) { + PotentiallyUnsafe(false), TheLoop(L) { // Populate values with existing loop metadata. getHintsFromMetadata(); @@ -1005,16 +1155,17 @@ public: unsigned getWidth() const { return Width.Value; } unsigned getInterleave() const { return Interleave.Value; } enum ForceKind getForce() const { return (ForceKind)Force.Value; } + + /// \brief If hints are provided that force vectorization, use the AlwaysPrint + /// pass name to force the frontend to print the diagnostic. const char *vectorizeAnalysisPassName() const { - // If hints are provided that don't disable vectorization use the - // AlwaysPrint pass name to force the frontend to print the diagnostic. if (getWidth() == 1) return LV_NAME; if (getForce() == LoopVectorizeHints::FK_Disabled) return LV_NAME; if (getForce() == LoopVectorizeHints::FK_Undefined && getWidth() == 0) return LV_NAME; - return DiagnosticInfo::AlwaysPrint; + return DiagnosticInfoOptimizationRemarkAnalysis::AlwaysPrint; } bool allowReordering() const { @@ -1026,6 +1177,17 @@ public: return getForce() == LoopVectorizeHints::FK_Enabled || getWidth() > 1; } + bool isPotentiallyUnsafe() const { + // Avoid FP vectorization if the target is unsure about proper support. + // This may be related to the SIMD unit in the target not handling + // IEEE 754 FP ops properly, or bad single-to-double promotions. 
+ // Otherwise, a sequence of vectorized loops, even without reduction, + // could lead to different end results on the destination vectors. + return getForce() != LoopVectorizeHints::FK_Enabled && PotentiallyUnsafe; + } + + void setPotentiallyUnsafe() { PotentiallyUnsafe = true; } + private: /// Find hints specified in the loop metadata and update local values. void getHintsFromMetadata() { @@ -1071,7 +1233,8 @@ private: Name = Name.substr(Prefix().size(), StringRef::npos); const ConstantInt *C = mdconst::dyn_extract<ConstantInt>(Arg); - if (!C) return; + if (!C) + return; unsigned Val = C->getZExtValue(); Hint *Hints[] = {&Width, &Interleave, &Force}; @@ -1097,7 +1260,7 @@ private: /// Matches metadata with hint name. bool matchesHintMetadataName(MDNode *Node, ArrayRef<Hint> HintTypes) { - MDString* Name = dyn_cast<MDString>(Node->getOperand(0)); + MDString *Name = dyn_cast<MDString>(Node->getOperand(0)); if (!Name) return false; @@ -1181,17 +1344,17 @@ static void emitMissedWarning(Function *F, Loop *L, /// induction variable and the different reduction variables. class LoopVectorizationLegality { public: - LoopVectorizationLegality(Loop *L, PredicatedScalarEvolution &PSE, - DominatorTree *DT, TargetLibraryInfo *TLI, - AliasAnalysis *AA, Function *F, - const TargetTransformInfo *TTI, - LoopAccessAnalysis *LAA, - LoopVectorizationRequirements *R, - const LoopVectorizeHints *H) + LoopVectorizationLegality( + Loop *L, PredicatedScalarEvolution &PSE, DominatorTree *DT, + TargetLibraryInfo *TLI, AliasAnalysis *AA, Function *F, + const TargetTransformInfo *TTI, + std::function<const LoopAccessInfo &(Loop &)> *GetLAA, LoopInfo *LI, + LoopVectorizationRequirements *R, LoopVectorizeHints *H) : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TheFunction(F), - TTI(TTI), DT(DT), LAA(LAA), LAI(nullptr), InterleaveInfo(PSE, L, DT), - Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), - Requirements(R), Hints(H) {} + TTI(TTI), DT(DT), GetLAA(GetLAA), LAI(nullptr), + InterleaveInfo(PSE, L, DT, LI), Induction(nullptr), + WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), + Hints(H) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. @@ -1199,7 +1362,11 @@ public: /// InductionList saves induction variables and maps them to the /// induction descriptor. - typedef MapVector<PHINode*, InductionDescriptor> InductionList; + typedef MapVector<PHINode *, InductionDescriptor> InductionList; + + /// RecurrenceSet contains the phi nodes that are recurrences other than + /// inductions and reductions. + typedef SmallPtrSet<const PHINode *, 8> RecurrenceSet; /// Returns true if it is legal to vectorize this loop. /// This does not mean that it is profitable to vectorize this @@ -1215,6 +1382,9 @@ public: /// Returns the induction variables found in the loop. InductionList *getInductionVars() { return &Inductions; } + /// Return the first-order recurrences found in the loop. + RecurrenceSet *getFirstOrderRecurrences() { return &FirstOrderRecurrences; } + /// Returns the widest induction type. Type *getWidestInductionType() { return WidestIndTy; } @@ -1224,11 +1394,14 @@ public: /// Returns True if PN is a reduction variable in this loop. bool isReductionVariable(PHINode *PN) { return Reductions.count(PN); } + /// Returns True if Phi is a first-order recurrence in this loop. + bool isFirstOrderRecurrence(const PHINode *Phi); + /// Return true if the block BB needs to be predicated in order for the loop /// to be vectorized. 
bool blockNeedsPredication(BasicBlock *BB); - /// Check if this pointer is consecutive when vectorizing. This happens + /// Check if this pointer is consecutive when vectorizing. This happens /// when the last index of the GEP is the induction variable, or that the /// pointer itself is an induction variable. /// This check allows us to vectorize A[idx] into a wide load/store. @@ -1242,35 +1415,39 @@ public: bool isUniform(Value *V); /// Returns true if this instruction will remain scalar after vectorization. - bool isUniformAfterVectorization(Instruction* I) { return Uniforms.count(I); } + bool isUniformAfterVectorization(Instruction *I) { return Uniforms.count(I); } /// Returns the information that we collected about runtime memory check. const RuntimePointerChecking *getRuntimePointerChecking() const { return LAI->getRuntimePointerChecking(); } - const LoopAccessInfo *getLAI() const { - return LAI; - } + const LoopAccessInfo *getLAI() const { return LAI; } /// \brief Check if \p Instr belongs to any interleaved access group. bool isAccessInterleaved(Instruction *Instr) { return InterleaveInfo.isInterleaved(Instr); } + /// \brief Return the maximum interleave factor of all interleaved groups. + unsigned getMaxInterleaveFactor() const { + return InterleaveInfo.getMaxInterleaveFactor(); + } + /// \brief Get the interleaved access group that \p Instr belongs to. const InterleaveGroup *getInterleavedAccessGroup(Instruction *Instr) { return InterleaveInfo.getInterleaveGroup(Instr); } + /// \brief Returns true if an interleaved group requires a scalar iteration + /// to handle accesses with gaps. + bool requiresScalarEpilogue() const { + return InterleaveInfo.requiresScalarEpilogue(); + } + unsigned getMaxSafeDepDistBytes() { return LAI->getMaxSafeDepDistBytes(); } - bool hasStride(Value *V) { return StrideSet.count(V); } - bool mustCheckStrides() { return !StrideSet.empty(); } - SmallPtrSet<Value *, 8>::iterator strides_begin() { - return StrideSet.begin(); - } - SmallPtrSet<Value *, 8>::iterator strides_end() { return StrideSet.end(); } + bool hasStride(Value *V) { return LAI->hasStride(V); } /// Returns true if the target machine supports masked store operation /// for the given \p DataType and kind of access to \p Ptr. @@ -1282,20 +1459,24 @@ public: bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { return isConsecutivePtr(Ptr) && TTI->isLegalMaskedLoad(DataType); } - /// Returns true if vector representation of the instruction \p I - /// requires mask. - bool isMaskRequired(const Instruction* I) { - return (MaskedOp.count(I) != 0); - } - unsigned getNumStores() const { - return LAI->getNumStores(); + /// Returns true if the target machine supports masked scatter operation + /// for the given \p DataType. + bool isLegalMaskedScatter(Type *DataType) { + return TTI->isLegalMaskedScatter(DataType); } - unsigned getNumLoads() const { - return LAI->getNumLoads(); - } - unsigned getNumPredStores() const { - return NumPredStores; + /// Returns true if the target machine supports masked gather operation + /// for the given \p DataType. + bool isLegalMaskedGather(Type *DataType) { + return TTI->isLegalMaskedGather(DataType); } + + /// Returns true if vector representation of the instruction \p I + /// requires mask. 
+ bool isMaskRequired(const Instruction *I) { return (MaskedOp.count(I) != 0); } + unsigned getNumStores() const { return LAI->getNumStores(); } + unsigned getNumLoads() const { return LAI->getNumLoads(); } + unsigned getNumPredStores() const { return NumPredStores; } + private: /// Check if a single basic block loop is vectorizable. /// At this point we know that this is a loop with a constant trip count @@ -1320,11 +1501,11 @@ private: /// and we know that we can read from them without segfault. bool blockCanBePredicated(BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs); - /// \brief Collect memory access with loop invariant strides. - /// - /// Looks for accesses like "a[i * StrideA]" where "StrideA" is loop - /// invariant. - void collectStridedAccess(Value *LoadOrStoreInst); + /// Updates the vectorization state by adding \p Phi to the inductions list. + /// This can set \p Phi as the main induction of the loop if \p Phi is a + /// better choice for the main induction than the existing one. + void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID, + SmallPtrSetImpl<Value *> &AllowedExit); /// Report an analysis message to assist the user in diagnosing loops that are /// not vectorized. These are handled as LoopAccessReport rather than @@ -1334,6 +1515,16 @@ private: emitAnalysisDiag(TheFunction, TheLoop, *Hints, Message); } + /// \brief If an access has a symbolic strides, this maps the pointer value to + /// the stride symbol. + const ValueToValueMap *getSymbolicStrides() { + // FIXME: Currently, the set of symbolic strides is sometimes queried before + // it's collected. This happens from canVectorizeWithIfConvert, when the + // pointer is checked to reference consecutive elements suitable for a + // masked access. + return LAI ? &LAI->getSymbolicStrides() : nullptr; + } + unsigned NumPredStores; /// The loop that we evaluate. @@ -1353,7 +1544,7 @@ private: /// Dominator Tree. DominatorTree *DT; // LoopAccess analysis. - LoopAccessAnalysis *LAA; + std::function<const LoopAccessInfo &(Loop &)> *GetLAA; // And the loop-accesses info corresponding to this loop. This pointer is // null until canVectorizeMemory sets it up. const LoopAccessInfo *LAI; @@ -1373,15 +1564,17 @@ private: /// Notice that inductions don't need to start at zero and that induction /// variables can be pointers. InductionList Inductions; + /// Holds the phi nodes that are first-order recurrences. + RecurrenceSet FirstOrderRecurrences; /// Holds the widest induction type encountered. Type *WidestIndTy; - /// Allowed outside users. This holds the reduction + /// Allowed outside users. This holds the induction and reduction /// vars which can be accessed from outside the loop. - SmallPtrSet<Value*, 4> AllowedExit; + SmallPtrSet<Value *, 4> AllowedExit; /// This set holds the variables which are known to be uniform after /// vectorization. - SmallPtrSet<Instruction*, 4> Uniforms; + SmallPtrSet<Instruction *, 4> Uniforms; /// Can we assume the absence of NaNs. bool HasFunNoNaNAttr; @@ -1390,10 +1583,7 @@ private: LoopVectorizationRequirements *Requirements; /// Used to emit an analysis of any legality issues. - const LoopVectorizeHints *Hints; - - ValueToValueMap Strides; - SmallPtrSet<Value *, 8> StrideSet; + LoopVectorizeHints *Hints; /// While vectorizing these instructions we have to generate a /// call to the appropriate masked intrinsic @@ -1409,20 +1599,19 @@ private: /// different operations. 
class LoopVectorizationCostModel { public: - LoopVectorizationCostModel(Loop *L, ScalarEvolution *SE, LoopInfo *LI, - LoopVectorizationLegality *Legal, + LoopVectorizationCostModel(Loop *L, PredicatedScalarEvolution &PSE, + LoopInfo *LI, LoopVectorizationLegality *Legal, const TargetTransformInfo &TTI, const TargetLibraryInfo *TLI, DemandedBits *DB, AssumptionCache *AC, const Function *F, - const LoopVectorizeHints *Hints, - SmallPtrSetImpl<const Value *> &ValuesToIgnore) - : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), - TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {} + const LoopVectorizeHints *Hints) + : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), + AC(AC), TheFunction(F), Hints(Hints) {} /// Information about vectorization costs struct VectorizationFactor { unsigned Width; // Vector width with best cost - unsigned Cost; // Cost of the loop with that width + unsigned Cost; // Cost of the loop with that width }; /// \return The most profitable vectorization factor and the cost of that VF. /// This method checks every power of two up to VF. If UserVF is not ZERO @@ -1462,19 +1651,34 @@ public: /// \return Returns information about the register usages of the loop for the /// given vectorization factors. - SmallVector<RegisterUsage, 8> - calculateRegisterUsage(const SmallVector<unsigned, 8> &VFs); + SmallVector<RegisterUsage, 8> calculateRegisterUsage(ArrayRef<unsigned> VFs); + + /// Collect values we want to ignore in the cost model. + void collectValuesToIgnore(); private: + /// The vectorization cost is a combination of the cost itself and a boolean + /// indicating whether any of the contributing operations will actually + /// operate on + /// vector values after type legalization in the backend. If this latter value + /// is + /// false, then all operations will be scalarized (i.e. no vectorization has + /// actually taken place). + typedef std::pair<unsigned, bool> VectorizationCostTy; + /// Returns the expected execution cost. The unit of the cost does /// not matter because we use the 'cost' units to compare different /// vector widths. The cost that is returned is *not* normalized by /// the factor width. - unsigned expectedCost(unsigned VF); + VectorizationCostTy expectedCost(unsigned VF); /// Returns the execution time cost of an instruction for a given vector /// width. Vector width of one means scalar. - unsigned getInstructionCost(Instruction *I, unsigned VF); + VectorizationCostTy getInstructionCost(Instruction *I, unsigned VF); + + /// The cost-computation logic from getInstructionCost which provides + /// the vector type as an output parameter. + unsigned getInstructionCost(Instruction *I, unsigned VF, Type *&VectorTy); /// Returns whether the instruction is a load or store and will be a emitted /// as a vector operation. @@ -1492,12 +1696,12 @@ public: /// Map of scalar integer values to the smallest bitwidth they can be legally /// represented as. The vector equivalents of these values should be truncated /// to this type. - MapVector<Instruction*,uint64_t> MinBWs; + MapVector<Instruction *, uint64_t> MinBWs; /// The loop that we evaluate. Loop *TheLoop; - /// Scev analysis. - ScalarEvolution *SE; + /// Predicated scalar evolution analysis. + PredicatedScalarEvolution &PSE; /// Loop Info analysis. LoopInfo *LI; /// Vectorization legality. @@ -1506,13 +1710,17 @@ public: const TargetTransformInfo &TTI; /// Target Library Info. 
const TargetLibraryInfo *TLI; - /// Demanded bits analysis + /// Demanded bits analysis. DemandedBits *DB; + /// Assumption cache. + AssumptionCache *AC; const Function *TheFunction; - // Loop Vectorize Hint. + /// Loop Vectorize Hint. const LoopVectorizeHints *Hints; - // Values to ignore in the cost model. - const SmallPtrSetImpl<const Value *> &ValuesToIgnore; + /// Values to ignore in the cost model. + SmallPtrSet<const Value *, 16> ValuesToIgnore; + /// Values to ignore in the cost model when VF > 1. + SmallPtrSet<const Value *, 16> VecValuesToIgnore; }; /// \brief This holds vectorization requirements that must be verified late in @@ -1588,328 +1796,35 @@ struct LoopVectorize : public FunctionPass { static char ID; explicit LoopVectorize(bool NoUnrolling = false, bool AlwaysVectorize = true) - : FunctionPass(ID), - DisableUnrolling(NoUnrolling), - AlwaysVectorize(AlwaysVectorize) { + : FunctionPass(ID) { + Impl.DisableUnrolling = NoUnrolling; + Impl.AlwaysVectorize = AlwaysVectorize; initializeLoopVectorizePass(*PassRegistry::getPassRegistry()); } - ScalarEvolution *SE; - LoopInfo *LI; - TargetTransformInfo *TTI; - DominatorTree *DT; - BlockFrequencyInfo *BFI; - TargetLibraryInfo *TLI; - DemandedBits *DB; - AliasAnalysis *AA; - AssumptionCache *AC; - LoopAccessAnalysis *LAA; - bool DisableUnrolling; - bool AlwaysVectorize; - - BlockFrequency ColdEntryFreq; + LoopVectorizePass Impl; bool runOnFunction(Function &F) override { - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - BFI = &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - TLI = TLIP ? &TLIP->getTLI() : nullptr; - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - LAA = &getAnalysis<LoopAccessAnalysis>(); - DB = &getAnalysis<DemandedBits>(); - - // Compute some weights outside of the loop over the loops. Compute this - // using a BranchProbability to re-use its scaling math. - const BranchProbability ColdProb(1, 5); // 20% - ColdEntryFreq = BlockFrequency(BFI->getEntryFreq()) * ColdProb; - - // Don't attempt if - // 1. the target claims to have no vector registers, and - // 2. interleaving won't help ILP. - // - // The second condition is necessary because, even if the target has no - // vector registers, loop vectorization may still enable scalar - // interleaving. - if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2) + if (skipFunction(F)) return false; - // Build up a worklist of inner-loops to vectorize. This is necessary as - // the act of vectorizing or partially unrolling a loop creates new loops - // and can invalidate iterators across the loops. - SmallVector<Loop *, 8> Worklist; - - for (Loop *L : *LI) - addInnerLoop(*L, Worklist); - - LoopsAnalyzed += Worklist.size(); - - // Now walk the identified inner loops. - bool Changed = false; - while (!Worklist.empty()) - Changed |= processLoop(Worklist.pop_back_val()); - - // Process each loop nest in the function. - return Changed; - } - - static void AddRuntimeUnrollDisableMetaData(Loop *L) { - SmallVector<Metadata *, 4> MDs; - // Reserve first location for self reference to the LoopID metadata node. 
- MDs.push_back(nullptr); - bool IsUnrollMetadata = false; - MDNode *LoopID = L->getLoopID(); - if (LoopID) { - // First find existing loop unrolling disable metadata. - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { - MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); - if (MD) { - const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); - IsUnrollMetadata = - S && S->getString().startswith("llvm.loop.unroll.disable"); - } - MDs.push_back(LoopID->getOperand(i)); - } - } - - if (!IsUnrollMetadata) { - // Add runtime unroll disable metadata. - LLVMContext &Context = L->getHeader()->getContext(); - SmallVector<Metadata *, 1> DisableOperands; - DisableOperands.push_back( - MDString::get(Context, "llvm.loop.unroll.runtime.disable")); - MDNode *DisableNode = MDNode::get(Context, DisableOperands); - MDs.push_back(DisableNode); - MDNode *NewLoopID = MDNode::get(Context, MDs); - // Set operand 0 to refer to the loop id itself. - NewLoopID->replaceOperandWith(0, NewLoopID); - L->setLoopID(NewLoopID); - } - } - - bool processLoop(Loop *L) { - assert(L->empty() && "Only process inner loops."); - -#ifndef NDEBUG - const std::string DebugLocStr = getDebugLocString(L); -#endif /* NDEBUG */ - - DEBUG(dbgs() << "\nLV: Checking a loop in \"" - << L->getHeader()->getParent()->getName() << "\" from " - << DebugLocStr << "\n"); - - LoopVectorizeHints Hints(L, DisableUnrolling); - - DEBUG(dbgs() << "LV: Loop hints:" - << " force=" - << (Hints.getForce() == LoopVectorizeHints::FK_Disabled - ? "disabled" - : (Hints.getForce() == LoopVectorizeHints::FK_Enabled - ? "enabled" - : "?")) << " width=" << Hints.getWidth() - << " unroll=" << Hints.getInterleave() << "\n"); - - // Function containing loop - Function *F = L->getHeader()->getParent(); - - // Looking at the diagnostic output is the only way to determine if a loop - // was vectorized (other than looking at the IR or machine code), so it - // is important to generate an optimization remark for each loop. Most of - // these messages are generated by emitOptimizationRemarkAnalysis. Remarks - // generated by emitOptimizationRemark and emitOptimizationRemarkMissed are - // less verbose reporting vectorized loops and unvectorized loops that may - // benefit from vectorization, respectively. - - if (!Hints.allowVectorization(F, L, AlwaysVectorize)) { - DEBUG(dbgs() << "LV: Loop hints prevent vectorization.\n"); - return false; - } - - // Check the loop for a trip count threshold: - // do not vectorize loops with a tiny trip count. - const unsigned TC = SE->getSmallConstantTripCount(L); - if (TC > 0u && TC < TinyTripCountVectorThreshold) { - DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " - << "This loop is not worth vectorizing."); - if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) - DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); - else { - DEBUG(dbgs() << "\n"); - emitAnalysisDiag(F, L, Hints, VectorizationReport() - << "vectorization is not beneficial " - "and is not explicitly forced"); - return false; - } - } - - PredicatedScalarEvolution PSE(*SE); - - // Check if it is legal to vectorize the loop. - LoopVectorizationRequirements Requirements; - LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, LAA, - &Requirements, &Hints); - if (!LVL.canVectorize()) { - DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); - emitMissedWarning(F, L, Hints); - return false; - } - - // Collect values we want to ignore in the cost model. 
This includes - // type-promoting instructions we identified during reduction detection. - SmallPtrSet<const Value *, 32> ValuesToIgnore; - CodeMetrics::collectEphemeralValues(L, AC, ValuesToIgnore); - for (auto &Reduction : *LVL.getReductionVars()) { - RecurrenceDescriptor &RedDes = Reduction.second; - SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts(); - ValuesToIgnore.insert(Casts.begin(), Casts.end()); - } - - // Use the cost model. - LoopVectorizationCostModel CM(L, PSE.getSE(), LI, &LVL, *TTI, TLI, DB, AC, - F, &Hints, ValuesToIgnore); - - // Check the function attributes to find out if this function should be - // optimized for size. - bool OptForSize = Hints.getForce() != LoopVectorizeHints::FK_Enabled && - F->optForSize(); - - // Compute the weighted frequency of this loop being executed and see if it - // is less than 20% of the function entry baseline frequency. Note that we - // always have a canonical loop here because we think we *can* vectorize. - // FIXME: This is hidden behind a flag due to pervasive problems with - // exactly what block frequency models. - if (LoopVectorizeWithBlockFrequency) { - BlockFrequency LoopEntryFreq = BFI->getBlockFreq(L->getLoopPreheader()); - if (Hints.getForce() != LoopVectorizeHints::FK_Enabled && - LoopEntryFreq < ColdEntryFreq) - OptForSize = true; - } - - // Check the function attributes to see if implicit floats are allowed. - // FIXME: This check doesn't seem possibly correct -- what if the loop is - // an integer loop and the vector instructions selected are purely integer - // vector instructions? - if (F->hasFnAttribute(Attribute::NoImplicitFloat)) { - DEBUG(dbgs() << "LV: Can't vectorize when the NoImplicitFloat" - "attribute is used.\n"); - emitAnalysisDiag( - F, L, Hints, - VectorizationReport() - << "loop not vectorized due to NoImplicitFloat attribute"); - emitMissedWarning(F, L, Hints); - return false; - } - - // Select the optimal vectorization factor. - const LoopVectorizationCostModel::VectorizationFactor VF = - CM.selectVectorizationFactor(OptForSize); - - // Select the interleave count. - unsigned IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost); - - // Get user interleave count. - unsigned UserIC = Hints.getInterleave(); - - // Identify the diagnostic messages that should be produced. - std::string VecDiagMsg, IntDiagMsg; - bool VectorizeLoop = true, InterleaveLoop = true; - - if (Requirements.doesNotMeet(F, L, Hints)) { - DEBUG(dbgs() << "LV: Not vectorizing: loop did not meet vectorization " - "requirements.\n"); - emitMissedWarning(F, L, Hints); - return false; - } - - if (VF.Width == 1) { - DEBUG(dbgs() << "LV: Vectorization is possible but not beneficial.\n"); - VecDiagMsg = - "the cost-model indicates that vectorization is not beneficial"; - VectorizeLoop = false; - } - - if (IC == 1 && UserIC <= 1) { - // Tell the user interleaving is not beneficial. - DEBUG(dbgs() << "LV: Interleaving is not beneficial.\n"); - IntDiagMsg = - "the cost-model indicates that interleaving is not beneficial"; - InterleaveLoop = false; - if (UserIC == 1) - IntDiagMsg += - " and is explicitly disabled or interleave count is set to 1"; - } else if (IC > 1 && UserIC == 1) { - // Tell the user interleaving is beneficial, but it explicitly disabled. 
- DEBUG(dbgs() - << "LV: Interleaving is beneficial but is explicitly disabled."); - IntDiagMsg = "the cost-model indicates that interleaving is beneficial " - "but is explicitly disabled or interleave count is set to 1"; - InterleaveLoop = false; - } - - // Override IC if user provided an interleave count. - IC = UserIC > 0 ? UserIC : IC; - - // Emit diagnostic messages, if any. - const char *VAPassName = Hints.vectorizeAnalysisPassName(); - if (!VectorizeLoop && !InterleaveLoop) { - // Do not vectorize or interleaving the loop. - emitOptimizationRemarkAnalysis(F->getContext(), VAPassName, *F, - L->getStartLoc(), VecDiagMsg); - emitOptimizationRemarkAnalysis(F->getContext(), LV_NAME, *F, - L->getStartLoc(), IntDiagMsg); - return false; - } else if (!VectorizeLoop && InterleaveLoop) { - DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); - emitOptimizationRemarkAnalysis(F->getContext(), VAPassName, *F, - L->getStartLoc(), VecDiagMsg); - } else if (VectorizeLoop && !InterleaveLoop) { - DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " - << DebugLocStr << '\n'); - emitOptimizationRemarkAnalysis(F->getContext(), LV_NAME, *F, - L->getStartLoc(), IntDiagMsg); - } else if (VectorizeLoop && InterleaveLoop) { - DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " - << DebugLocStr << '\n'); - DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); - } - - if (!VectorizeLoop) { - assert(IC > 1 && "interleave count should not be 1 or 0"); - // If we decided that it is not legal to vectorize the loop then - // interleave it. - InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, IC); - Unroller.vectorize(&LVL, CM.MinBWs); - - emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), - Twine("interleaved loop (interleaved count: ") + - Twine(IC) + ")"); - } else { - // If we decided that it is *legal* to vectorize the loop then do it. - InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, VF.Width, IC); - LB.vectorize(&LVL, CM.MinBWs); - ++LoopsVectorized; - - // Add metadata to disable runtime unrolling scalar loop when there's no - // runtime check about strides and memory. Because at this situation, - // scalar loop is rarely used not worthy to be unrolled. - if (!LB.IsSafetyChecksAdded()) - AddRuntimeUnrollDisableMetaData(L); - - // Report the vectorization decision. - emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), - Twine("vectorized loop (vectorization width: ") + - Twine(VF.Width) + ", interleaved count: " + - Twine(IC) + ")"); - } + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *BFI = &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + auto *TLI = TLIP ? &TLIP->getTLI() : nullptr; + auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); + auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); - // Mark the loop as already vectorized to avoid vectorizing again. 
- Hints.setAlreadyVectorized(); + std::function<const LoopAccessInfo &(Loop &)> GetLAA = + [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; - DEBUG(verifyFunction(*L->getHeader()->getParent())); - return true; + return Impl.runImpl(F, *SE, *LI, *TTI, *DT, *BFI, TLI, *DB, *AA, *AC, + GetLAA); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -1922,15 +1837,13 @@ struct LoopVectorize : public FunctionPass { AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<LoopAccessAnalysis>(); - AU.addRequired<DemandedBits>(); + AU.addRequired<LoopAccessLegacyAnalysis>(); + AU.addRequired<DemandedBitsWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<BasicAAWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } - }; } // end anonymous namespace @@ -1943,9 +1856,7 @@ struct LoopVectorize : public FunctionPass { Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { // We need to place the broadcast of invariant variables outside the loop. Instruction *Instr = dyn_cast<Instruction>(V); - bool NewInstr = - (Instr && std::find(LoopVectorBody.begin(), LoopVectorBody.end(), - Instr->getParent()) != LoopVectorBody.end()); + bool NewInstr = (Instr && Instr->getParent() == LoopVectorBody); bool Invariant = OrigLoop->isLoopInvariant(V) && !NewInstr; // Place the code for broadcasting invariant variables in the new preheader. @@ -1959,6 +1870,111 @@ Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { return Shuf; } +void InnerLoopVectorizer::createVectorIntInductionPHI( + const InductionDescriptor &II, VectorParts &Entry, IntegerType *TruncType) { + Value *Start = II.getStartValue(); + ConstantInt *Step = II.getConstIntStepValue(); + assert(Step && "Can not widen an IV with a non-constant step"); + + // Construct the initial value of the vector IV in the vector loop preheader + auto CurrIP = Builder.saveIP(); + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + if (TruncType) { + Step = ConstantInt::getSigned(TruncType, Step->getSExtValue()); + Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); + } + Value *SplatStart = Builder.CreateVectorSplat(VF, Start); + Value *SteppedStart = getStepVector(SplatStart, 0, Step); + Builder.restoreIP(CurrIP); + + Value *SplatVF = + ConstantVector::getSplat(VF, ConstantInt::getSigned(Start->getType(), + VF * Step->getSExtValue())); + // We may need to add the step a number of times, depending on the unroll + // factor. The last of those goes into the PHI. + PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind", + &*LoopVectorBody->getFirstInsertionPt()); + Value *LastInduction = VecInd; + for (unsigned Part = 0; Part < UF; ++Part) { + Entry[Part] = LastInduction; + LastInduction = Builder.CreateAdd(LastInduction, SplatVF, "step.add"); + } + + VecInd->addIncoming(SteppedStart, LoopVectorPreHeader); + VecInd->addIncoming(LastInduction, LoopVectorBody); +} + +void InnerLoopVectorizer::widenIntInduction(PHINode *IV, VectorParts &Entry, + TruncInst *Trunc) { + + auto II = Legal->getInductionVars()->find(IV); + assert(II != Legal->getInductionVars()->end() && "IV is not an induction"); + + auto ID = II->second; + assert(IV->getType() == ID.getStartValue()->getType() && "Types must match"); + + // If a truncate instruction was provided, get the smaller type. + auto *TruncType = Trunc ? 
cast<IntegerType>(Trunc->getType()) : nullptr; + + // The step of the induction. + Value *Step = nullptr; + + // If the induction variable has a constant integer step value, go ahead and + // get it now. + if (ID.getConstIntStepValue()) + Step = ID.getConstIntStepValue(); + + // Try to create a new independent vector induction variable. If we can't + // create the phi node, we will splat the scalar induction variable in each + // loop iteration. + if (VF > 1 && IV->getType() == Induction->getType() && Step && + !ValuesNotWidened->count(IV)) + return createVectorIntInductionPHI(ID, Entry, TruncType); + + // The scalar value to broadcast. This will be derived from the canonical + // induction variable. + Value *ScalarIV = nullptr; + + // Define the scalar induction variable and step values. If we were given a + // truncation type, truncate the canonical induction variable and constant + // step. Otherwise, derive these values from the induction descriptor. + if (TruncType) { + assert(Step && "Truncation requires constant integer step"); + auto StepInt = cast<ConstantInt>(Step)->getSExtValue(); + ScalarIV = Builder.CreateCast(Instruction::Trunc, Induction, TruncType); + Step = ConstantInt::getSigned(TruncType, StepInt); + } else { + ScalarIV = Induction; + auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + if (IV != OldInduction) { + ScalarIV = Builder.CreateSExtOrTrunc(ScalarIV, IV->getType()); + ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL); + ScalarIV->setName("offset.idx"); + } + if (!Step) { + SCEVExpander Exp(*PSE.getSE(), DL, "induction"); + Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(), + &*Builder.GetInsertPoint()); + } + } + + // Splat the scalar induction variable, and build the necessary step vectors. + Value *Broadcasted = getBroadcastInstrs(ScalarIV); + for (unsigned Part = 0; Part < UF; ++Part) + Entry[Part] = getStepVector(Broadcasted, VF * Part, Step); + + // If an induction variable is only used for counting loop iterations or + // calculating addresses, it doesn't need to be widened. Create scalar steps + // that can be used by instructions we will later scalarize. Note that the + // addition of the scalar steps will not increase the number of instructions + // in the loop in the common case prior to InstCombine. We will be trading + // one vector extract for each scalar step. + if (VF > 1 && ValuesNotWidened->count(IV)) { + auto *EntryVal = Trunc ? cast<Value>(Trunc) : IV; + buildScalarSteps(ScalarIV, Step, EntryVal); + } +} + Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step) { assert(Val->getType()->isVectorTy() && "Must be a vector"); @@ -1970,7 +1986,7 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Type *ITy = Val->getType()->getScalarType(); VectorType *Ty = cast<VectorType>(Val->getType()); int VLen = Ty->getNumElements(); - SmallVector<Constant*, 8> Indices; + SmallVector<Constant *, 8> Indices; // Create a vector of consecutive numbers from zero to VF. for (int i = 0; i < VLen; ++i) @@ -1987,6 +2003,27 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, return Builder.CreateAdd(Val, Step, "induction"); } +void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, + Value *EntryVal) { + + // We shouldn't have to build scalar steps if we aren't vectorizing. + assert(VF > 1 && "VF should be greater than one"); + + // Get the value type and ensure it and the step have the same integer type. 
+ Type *ScalarIVTy = ScalarIV->getType()->getScalarType(); + assert(ScalarIVTy->isIntegerTy() && ScalarIVTy == Step->getType() && + "Val and Step should have the same integer type"); + + // Compute the scalar steps and save the results in ScalarIVMap. + for (unsigned Part = 0; Part < UF; ++Part) + for (unsigned I = 0; I < VF; ++I) { + auto *StartIdx = ConstantInt::get(ScalarIVTy, VF * Part + I); + auto *Mul = Builder.CreateMul(StartIdx, Step); + auto *Add = Builder.CreateAdd(ScalarIV, Mul); + ScalarIVMap[EntryVal].push_back(Add); + } +} + int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { assert(Ptr->getType()->isPointerTy() && "Unexpected non-ptr"); auto *SE = PSE.getSE(); @@ -1994,7 +2031,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { if (Ptr->getType()->getPointerElementType()->isAggregateType()) return 0; - // If this value is a pointer induction variable we know it is consecutive. + // If this value is a pointer induction variable, we know it is consecutive. PHINode *Phi = dyn_cast_or_null<PHINode>(Ptr); if (Phi && Inductions.count(Phi)) { InductionDescriptor II = Inductions[Phi]; @@ -2008,7 +2045,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { unsigned NumOperands = Gep->getNumOperands(); Value *GpPtr = Gep->getPointerOperand(); // If this GEP value is a consecutive pointer induction variable and all of - // the indices are constant then we know it is consecutive. We can + // the indices are constant, then we know it is consecutive. Phi = dyn_cast<PHINode>(GpPtr); if (Phi && Inductions.count(Phi)) { @@ -2038,7 +2075,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { // We can emit wide load/stores only if the last non-zero index is the // induction variable. const SCEV *Last = nullptr; - if (!Strides.count(Gep)) + if (!getSymbolicStrides() || !getSymbolicStrides()->count(Gep)) Last = PSE.getSCEV(Gep->getOperand(InductionOperand)); else { // Because of the multiplication by a stride we can have a s/zext cast. @@ -2050,7 +2087,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { // %idxprom = zext i32 %mul to i64 << Safe cast. // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom // - Last = replaceSymbolicStrideSCEV(PSE, Strides, + Last = replaceSymbolicStrideSCEV(PSE, *getSymbolicStrides(), Gep->getOperand(InductionOperand), Gep); if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last)) Last = @@ -2076,7 +2113,7 @@ bool LoopVectorizationLegality::isUniform(Value *V) { return LAI->isUniform(V); } -InnerLoopVectorizer::VectorParts& +InnerLoopVectorizer::VectorParts & InnerLoopVectorizer::getVectorValue(Value *V) { assert(V != Induction && "The new induction variable should not be used."); assert(!V->getType()->isVectorTy() && "Can't widen a vector"); @@ -2097,7 +2134,7 @@ InnerLoopVectorizer::getVectorValue(Value *V) { Value *InnerLoopVectorizer::reverseVector(Value *Vec) { assert(Vec->getType()->isVectorTy() && "Invalid type"); - SmallVector<Constant*, 8> ShuffleMask; + SmallVector<Constant *, 8> ShuffleMask; for (unsigned i = 0; i < VF; ++i) ShuffleMask.push_back(Builder.getInt32(VF - i - 1)); @@ -2308,7 +2345,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Group->isReverse() ? 
reverseVector(StridedVec) : StridedVec; } - propagateMetadata(NewLoadInstr, Instr); + addMetadata(NewLoadInstr, Instr); } return; } @@ -2326,7 +2363,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { assert(Member && "Fail to get a member from an interleaved store group"); Value *StoredVec = - getVectorValue(dyn_cast<StoreInst>(Member)->getValueOperand())[Part]; + getVectorValue(cast<StoreInst>(Member)->getValueOperand())[Part]; if (Group->isReverse()) StoredVec = reverseVector(StoredVec); @@ -2347,7 +2384,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Instruction *NewStoreInstr = Builder.CreateAlignedStore(IVec, NewPtrs[Part], Group->getAlignment()); - propagateMetadata(NewStoreInstr, Instr); + addMetadata(NewStoreInstr, Instr); } } @@ -2372,8 +2409,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { if (!Alignment) Alignment = DL.getABITypeAlignment(ScalarDataTy); unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace(); - unsigned ScalarAllocatedSize = DL.getTypeAllocSize(ScalarDataTy); - unsigned VectorElementSize = DL.getTypeStoreSize(DataTy) / VF; + uint64_t ScalarAllocatedSize = DL.getTypeAllocSize(ScalarDataTy); + uint64_t VectorElementSize = DL.getTypeStoreSize(DataTy) / VF; if (SI && Legal->blockNeedsPredication(SI->getParent()) && !Legal->isMaskRequired(SI)) @@ -2382,69 +2419,115 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { if (ScalarAllocatedSize != VectorElementSize) return scalarizeInstruction(Instr); - // If the pointer is loop invariant or if it is non-consecutive, - // scalarize the load. + // If the pointer is loop invariant scalarize the load. + if (LI && Legal->isUniform(Ptr)) + return scalarizeInstruction(Instr); + + // If the pointer is non-consecutive and gather/scatter is not supported + // scalarize the instruction. int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); bool Reverse = ConsecutiveStride < 0; - bool UniformLoad = LI && Legal->isUniform(Ptr); - if (!ConsecutiveStride || UniformLoad) + bool CreateGatherScatter = + !ConsecutiveStride && ((LI && Legal->isLegalMaskedGather(ScalarDataTy)) || + (SI && Legal->isLegalMaskedScatter(ScalarDataTy))); + + if (!ConsecutiveStride && !CreateGatherScatter) return scalarizeInstruction(Instr); Constant *Zero = Builder.getInt32(0); VectorParts &Entry = WidenMap.get(Instr); + VectorParts VectorGep; // Handle consecutive loads/stores. GetElementPtrInst *Gep = getGEPInstruction(Ptr); - if (Gep && Legal->isInductionVariable(Gep->getPointerOperand())) { - setDebugLocFromInst(Builder, Gep); - Value *PtrOperand = Gep->getPointerOperand(); - Value *FirstBasePtr = getVectorValue(PtrOperand)[0]; - FirstBasePtr = Builder.CreateExtractElement(FirstBasePtr, Zero); - - // Create the new GEP with the new induction variable. - GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone()); - Gep2->setOperand(0, FirstBasePtr); - Gep2->setName("gep.indvar.base"); - Ptr = Builder.Insert(Gep2); - } else if (Gep) { - setDebugLocFromInst(Builder, Gep); - assert(PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getPointerOperand()), - OrigLoop) && - "Base ptr must be invariant"); - - // The last index does not have to be the induction. It can be - // consecutive and be a function of the index. For example A[I+1]; - unsigned NumOperands = Gep->getNumOperands(); - unsigned InductionOperand = getGEPInductionOperand(Gep); - // Create the new GEP with the new induction variable. 
- GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone()); - - for (unsigned i = 0; i < NumOperands; ++i) { - Value *GepOperand = Gep->getOperand(i); - Instruction *GepOperandInst = dyn_cast<Instruction>(GepOperand); - - // Update last index or loop invariant instruction anchored in loop. - if (i == InductionOperand || - (GepOperandInst && OrigLoop->contains(GepOperandInst))) { - assert((i == InductionOperand || - PSE.getSE()->isLoopInvariant(PSE.getSCEV(GepOperandInst), - OrigLoop)) && - "Must be last index or loop invariant"); - - VectorParts &GEPParts = getVectorValue(GepOperand); - Value *Index = GEPParts[0]; - Index = Builder.CreateExtractElement(Index, Zero); - Gep2->setOperand(i, Index); - Gep2->setName("gep.indvar.idx"); + if (ConsecutiveStride) { + if (Gep && Legal->isInductionVariable(Gep->getPointerOperand())) { + setDebugLocFromInst(Builder, Gep); + Value *PtrOperand = Gep->getPointerOperand(); + Value *FirstBasePtr = getVectorValue(PtrOperand)[0]; + FirstBasePtr = Builder.CreateExtractElement(FirstBasePtr, Zero); + + // Create the new GEP with the new induction variable. + GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone()); + Gep2->setOperand(0, FirstBasePtr); + Gep2->setName("gep.indvar.base"); + Ptr = Builder.Insert(Gep2); + } else if (Gep) { + setDebugLocFromInst(Builder, Gep); + assert(PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getPointerOperand()), + OrigLoop) && + "Base ptr must be invariant"); + // The last index does not have to be the induction. It can be + // consecutive and be a function of the index. For example A[I+1]; + unsigned NumOperands = Gep->getNumOperands(); + unsigned InductionOperand = getGEPInductionOperand(Gep); + // Create the new GEP with the new induction variable. + GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone()); + + for (unsigned i = 0; i < NumOperands; ++i) { + Value *GepOperand = Gep->getOperand(i); + Instruction *GepOperandInst = dyn_cast<Instruction>(GepOperand); + + // Update last index or loop invariant instruction anchored in loop. + if (i == InductionOperand || + (GepOperandInst && OrigLoop->contains(GepOperandInst))) { + assert((i == InductionOperand || + PSE.getSE()->isLoopInvariant(PSE.getSCEV(GepOperandInst), + OrigLoop)) && + "Must be last index or loop invariant"); + + VectorParts &GEPParts = getVectorValue(GepOperand); + + // If GepOperand is an induction variable, and there's a scalarized + // version of it available, use it. Otherwise, we will need to create + // an extractelement instruction. + Value *Index = ScalarIVMap.count(GepOperand) + ? ScalarIVMap[GepOperand][0] + : Builder.CreateExtractElement(GEPParts[0], Zero); + + Gep2->setOperand(i, Index); + Gep2->setName("gep.indvar.idx"); + } } + Ptr = Builder.Insert(Gep2); + } else { // No GEP + // Use the induction element ptr. + assert(isa<PHINode>(Ptr) && "Invalid induction ptr"); + setDebugLocFromInst(Builder, Ptr); + VectorParts &PtrVal = getVectorValue(Ptr); + Ptr = Builder.CreateExtractElement(PtrVal[0], Zero); } - Ptr = Builder.Insert(Gep2); } else { - // Use the induction element ptr. - assert(isa<PHINode>(Ptr) && "Invalid induction ptr"); - setDebugLocFromInst(Builder, Ptr); - VectorParts &PtrVal = getVectorValue(Ptr); - Ptr = Builder.CreateExtractElement(PtrVal[0], Zero); + // At this point we should vector version of GEP for Gather or Scatter + assert(CreateGatherScatter && "The instruction should be scalarized"); + if (Gep) { + // Vectorizing GEP, across UF parts. 
We want to get a vector value for base + // and each index that's defined inside the loop, even if it is + // loop-invariant but wasn't hoisted out. Otherwise we want to keep them + // scalar. + SmallVector<VectorParts, 4> OpsV; + for (Value *Op : Gep->operands()) { + Instruction *SrcInst = dyn_cast<Instruction>(Op); + if (SrcInst && OrigLoop->contains(SrcInst)) + OpsV.push_back(getVectorValue(Op)); + else + OpsV.push_back(VectorParts(UF, Op)); + } + for (unsigned Part = 0; Part < UF; ++Part) { + SmallVector<Value *, 4> Ops; + Value *GEPBasePtr = OpsV[0][Part]; + for (unsigned i = 1; i < Gep->getNumOperands(); i++) + Ops.push_back(OpsV[i][Part]); + Value *NewGep = Builder.CreateGEP(GEPBasePtr, Ops, "VectorGep"); + cast<GetElementPtrInst>(NewGep)->setIsInBounds(Gep->isInBounds()); + assert(NewGep->getType()->isVectorTy() && "Expected vector GEP"); + + NewGep = + Builder.CreateBitCast(NewGep, VectorType::get(Ptr->getType(), VF)); + VectorGep.push_back(NewGep); + } + } else + VectorGep = getVectorValue(Ptr); } VectorParts Mask = createBlockInMask(Instr->getParent()); @@ -2458,62 +2541,78 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { VectorParts StoredVal = getVectorValue(SI->getValueOperand()); for (unsigned Part = 0; Part < UF; ++Part) { + Instruction *NewSI = nullptr; + if (CreateGatherScatter) { + Value *MaskPart = Legal->isMaskRequired(SI) ? Mask[Part] : nullptr; + NewSI = Builder.CreateMaskedScatter(StoredVal[Part], VectorGep[Part], + Alignment, MaskPart); + } else { + // Calculate the pointer for the specific unroll-part. + Value *PartPtr = + Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(Part * VF)); + + if (Reverse) { + // If we store to reverse consecutive memory locations, then we need + // to reverse the order of elements in the stored value. + StoredVal[Part] = reverseVector(StoredVal[Part]); + // If the address is consecutive but reversed, then the + // wide store needs to start at the last vector element. + PartPtr = + Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); + PartPtr = + Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); + Mask[Part] = reverseVector(Mask[Part]); + } + + Value *VecPtr = + Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); + + if (Legal->isMaskRequired(SI)) + NewSI = Builder.CreateMaskedStore(StoredVal[Part], VecPtr, Alignment, + Mask[Part]); + else + NewSI = + Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment); + } + addMetadata(NewSI, SI); + } + return; + } + + // Handle loads. + assert(LI && "Must have a load instruction"); + setDebugLocFromInst(Builder, LI); + for (unsigned Part = 0; Part < UF; ++Part) { + Instruction *NewLI; + if (CreateGatherScatter) { + Value *MaskPart = Legal->isMaskRequired(LI) ? Mask[Part] : nullptr; + NewLI = Builder.CreateMaskedGather(VectorGep[Part], Alignment, MaskPart, + 0, "wide.masked.gather"); + Entry[Part] = NewLI; + } else { // Calculate the pointer for the specific unroll-part. Value *PartPtr = Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(Part * VF)); if (Reverse) { - // If we store to reverse consecutive memory locations, then we need - // to reverse the order of elements in the stored value. - StoredVal[Part] = reverseVector(StoredVal[Part]); // If the address is consecutive but reversed, then the - // wide store needs to start at the last vector element. + // wide load needs to start at the last vector element. 
PartPtr = Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); PartPtr = Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); Mask[Part] = reverseVector(Mask[Part]); } - Value *VecPtr = Builder.CreateBitCast(PartPtr, - DataTy->getPointerTo(AddressSpace)); - - Instruction *NewSI; - if (Legal->isMaskRequired(SI)) - NewSI = Builder.CreateMaskedStore(StoredVal[Part], VecPtr, Alignment, - Mask[Part]); - else - NewSI = Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment); - propagateMetadata(NewSI, SI); - } - return; - } - - // Handle loads. - assert(LI && "Must have a load instruction"); - setDebugLocFromInst(Builder, LI); - for (unsigned Part = 0; Part < UF; ++Part) { - // Calculate the pointer for the specific unroll-part. - Value *PartPtr = - Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(Part * VF)); - - if (Reverse) { - // If the address is consecutive but reversed, then the - // wide load needs to start at the last vector element. - PartPtr = Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); - PartPtr = Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); - Mask[Part] = reverseVector(Mask[Part]); + Value *VecPtr = + Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); + if (Legal->isMaskRequired(LI)) + NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask[Part], + UndefValue::get(DataTy), + "wide.masked.load"); + else + NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); + Entry[Part] = Reverse ? reverseVector(NewLI) : NewLI; } - - Instruction* NewLI; - Value *VecPtr = Builder.CreateBitCast(PartPtr, - DataTy->getPointerTo(AddressSpace)); - if (Legal->isMaskRequired(LI)) - NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask[Part], - UndefValue::get(DataTy), - "wide.masked.load"); - else - NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); - propagateMetadata(NewLI, LI); - Entry[Part] = Reverse ? reverseVector(NewLI) : NewLI; + addMetadata(NewLI, LI); } } @@ -2526,9 +2625,7 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, setDebugLocFromInst(Builder, Instr); // Find all of the vectorized parameters. - for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - Value *SrcOp = Instr->getOperand(op); - + for (Value *SrcOp : Instr->operands()) { // If we are accessing the old induction variable, use the new one. if (SrcOp == OldInduction) { Params.push_back(getVectorValue(SrcOp)); @@ -2536,7 +2633,7 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, } // Try using previously calculated values. - Instruction *SrcInst = dyn_cast<Instruction>(SrcOp); + auto *SrcInst = dyn_cast<Instruction>(SrcOp); // If the src is an instruction that appeared earlier in the basic block, // then it should already be vectorized. @@ -2558,8 +2655,9 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, // Does this instruction return a value ? bool IsVoidRetTy = Instr->getType()->isVoidTy(); - Value *UndefVec = IsVoidRetTy ? nullptr : - UndefValue::get(VectorType::get(Instr->getType(), VF)); + Value *UndefVec = + IsVoidRetTy ? nullptr + : UndefValue::get(VectorType::get(Instr->getType(), VF)); // Create a new entry in the WidenMap and initialize it to Undef or Null. VectorParts &VecResults = WidenMap.splat(Instr, UndefVec); @@ -2589,16 +2687,28 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, Cloned->setName(Instr->getName() + ".cloned"); // Replace the operands of the cloned instructions with extracted scalars. 
for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - Value *Op = Params[op][Part]; - // Param is a vector. Need to extract the right lane. - if (Op->getType()->isVectorTy()) - Op = Builder.CreateExtractElement(Op, Builder.getInt32(Width)); - Cloned->setOperand(op, Op); + + // If the operand is an induction variable, and there's a scalarized + // version of it available, use it. Otherwise, we will need to create + // an extractelement instruction if vectorizing. + auto *NewOp = Params[op][Part]; + auto *ScalarOp = Instr->getOperand(op); + if (ScalarIVMap.count(ScalarOp)) + NewOp = ScalarIVMap[ScalarOp][VF * Part + Width]; + else if (NewOp->getType()->isVectorTy()) + NewOp = Builder.CreateExtractElement(NewOp, Builder.getInt32(Width)); + Cloned->setOperand(op, NewOp); } + addNewMetadata(Cloned, Instr); // Place the cloned scalar in the new loop. Builder.Insert(Cloned); + // If we just cloned a new assumption, add it the assumption cache. + if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); + // If the original scalar returns a value we need to place it in a vector // so that future users will be able to use it. if (!IsVoidRetTy) @@ -2606,8 +2716,8 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, Builder.getInt32(Width)); // End if-block. if (IfPredicateStore) - PredicatedStores.push_back(std::make_pair(cast<StoreInst>(Cloned), - Cmp)); + PredicatedStores.push_back( + std::make_pair(cast<StoreInst>(Cloned), Cmp)); } } } @@ -2627,7 +2737,7 @@ PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, auto *Induction = Builder.CreatePHI(Start->getType(), 2, "index"); Builder.SetInsertPoint(Latch->getTerminator()); - + // Create i+1 and fill the PHINode. Value *Next = Builder.CreateAdd(Induction, Step, "index.next"); Induction->addIncoming(Start, L->getLoopPreheader()); @@ -2635,7 +2745,7 @@ PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, // Create the compare. Value *ICmp = Builder.CreateICmpEQ(Next, End); Builder.CreateCondBr(ICmp, L->getExitBlock(), Header); - + // Now we have two terminators. Remove the old one from the block. Latch->getTerminator()->eraseFromParent(); @@ -2649,12 +2759,12 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); // Find the loop boundaries. ScalarEvolution *SE = PSE.getSE(); - const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(OrigLoop); + const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount(); assert(BackedgeTakenCount != SE->getCouldNotCompute() && "Invalid loop count"); Type *IdxTy = Legal->getWidestInductionType(); - + // The exit count might have the type of i64 while the phi is i32. This can // happen if we have an induction variable that is sign extended before the // compare. The only way that we get a backedge taken count is that the @@ -2664,7 +2774,7 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { IdxTy->getPrimitiveSizeInBits()) BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount, IdxTy); BackedgeTakenCount = SE->getNoopOrZeroExtend(BackedgeTakenCount, IdxTy); - + // Get the total trip count from the count by adding 1. 
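  // As a rough worked example (hypothetical loop, assuming it executes at
  // least once): for
  //
  //   for (i = 0; i < n; ++i) { ... }
  //
  // the backedge-taken count is n - 1, so the exit count formed below is
  //   ExitCount = BackedgeTakenCount + 1 = n,
  // and this is the trip count value that the minimum-iterations check and
  // the vector trip count are later derived from.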
const SCEV *ExitCount = SE->getAddExpr( BackedgeTakenCount, SE->getOne(BackedgeTakenCount->getType())); @@ -2681,9 +2791,8 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { if (TripCount->getType()->isPointerTy()) TripCount = - CastInst::CreatePointerCast(TripCount, IdxTy, - "exitcount.ptrcnt.to.int", - L->getLoopPreheader()->getTerminator()); + CastInst::CreatePointerCast(TripCount, IdxTy, "exitcount.ptrcnt.to.int", + L->getLoopPreheader()->getTerminator()); return TripCount; } @@ -2691,16 +2800,30 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { Value *InnerLoopVectorizer::getOrCreateVectorTripCount(Loop *L) { if (VectorTripCount) return VectorTripCount; - + Value *TC = getOrCreateTripCount(L); IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); - - // Now we need to generate the expression for N - (N % VF), which is - // the part that the vectorized body will execute. - // The loop step is equal to the vectorization factor (num of SIMD elements) - // times the unroll factor (num of SIMD instructions). + + // Now we need to generate the expression for the part of the loop that the + // vectorized body will execute. This is equal to N - (N % Step) if scalar + // iterations are not required for correctness, or N - Step, otherwise. Step + // is equal to the vectorization factor (number of SIMD elements) times the + // unroll factor (number of SIMD instructions). Constant *Step = ConstantInt::get(TC->getType(), VF * UF); Value *R = Builder.CreateURem(TC, Step, "n.mod.vf"); + + // If there is a non-reversed interleaved group that may speculatively access + // memory out-of-bounds, we need to ensure that there will be at least one + // iteration of the scalar epilogue loop. Thus, if the step evenly divides + // the trip count, we set the remainder to be equal to the step. If the step + // does not evenly divide the trip count, no adjustment is necessary since + // there will already be scalar iterations. Note that the minimum iterations + // check ensures that N >= Step. + if (VF > 1 && Legal->requiresScalarEpilogue()) { + auto *IsZero = Builder.CreateICmpEQ(R, ConstantInt::get(R->getType(), 0)); + R = Builder.CreateSelect(IsZero, Step, R); + } + VectorTripCount = Builder.CreateSub(TC, R, "n.vec"); return VectorTripCount; @@ -2714,13 +2837,15 @@ void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, // Generate code to check that the loop's trip count that we computed by // adding one to the backedge-taken count will not overflow. - Value *CheckMinIters = - Builder.CreateICmpULT(Count, - ConstantInt::get(Count->getType(), VF * UF), - "min.iters.check"); - - BasicBlock *NewBB = BB->splitBasicBlock(BB->getTerminator(), - "min.iters.checked"); + Value *CheckMinIters = Builder.CreateICmpULT( + Count, ConstantInt::get(Count->getType(), VF * UF), "min.iters.check"); + + BasicBlock *NewBB = + BB->splitBasicBlock(BB->getTerminator(), "min.iters.checked"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), @@ -2733,7 +2858,7 @@ void InnerLoopVectorizer::emitVectorLoopEnteredCheck(Loop *L, Value *TC = getOrCreateVectorTripCount(L); BasicBlock *BB = L->getLoopPreheader(); IRBuilder<> Builder(BB->getTerminator()); - + // Now, compare the new count to zero. 
If it is zero skip the vector loop and // jump to the scalar loop. Value *Cmp = Builder.CreateICmpEQ(TC, Constant::getNullValue(TC->getType()), @@ -2741,8 +2866,11 @@ void InnerLoopVectorizer::emitVectorLoopEnteredCheck(Loop *L, // Generate code to check that the loop's trip count that we computed by // adding one to the backedge-taken count will not overflow. - BasicBlock *NewBB = BB->splitBasicBlock(BB->getTerminator(), - "vector.ph"); + BasicBlock *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), @@ -2768,6 +2896,10 @@ void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { // Create a new block containing the stride check. BB->setName("vector.scevcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), @@ -2776,8 +2908,7 @@ void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { AddedSafetyChecks = true; } -void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, - BasicBlock *Bypass) { +void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { BasicBlock *BB = L->getLoopPreheader(); // Generate the code that checks in runtime if arrays overlap. We put the @@ -2793,14 +2924,23 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, // Create a new block containing the memory check. BB->setName("vector.memcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); if (L->getParentLoop()) L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); ReplaceInstWithInst(BB->getTerminator(), BranchInst::Create(Bypass, NewBB, MemRuntimeCheck)); LoopBypassBlocks.push_back(BB); AddedSafetyChecks = true; -} + // We currently don't use LoopVersioning for the actual loop cloning but we + // still use it to add the noalias metadata. + LVer = llvm::make_unique<LoopVersioning>(*Legal->getLAI(), OrigLoop, LI, DT, + PSE.getSE()); + LVer->prepareNoAliasMetadata(); +} void InnerLoopVectorizer::createEmptyLoop() { /* @@ -2859,12 +2999,12 @@ void InnerLoopVectorizer::createEmptyLoop() { BasicBlock *VecBody = VectorPH->splitBasicBlock(VectorPH->getTerminator(), "vector.body"); BasicBlock *MiddleBlock = - VecBody->splitBasicBlock(VecBody->getTerminator(), "middle.block"); + VecBody->splitBasicBlock(VecBody->getTerminator(), "middle.block"); BasicBlock *ScalarPH = - MiddleBlock->splitBasicBlock(MiddleBlock->getTerminator(), "scalar.ph"); + MiddleBlock->splitBasicBlock(MiddleBlock->getTerminator(), "scalar.ph"); // Create and register the new vector loop. 
- Loop* Lp = new Loop(); + Loop *Lp = new Loop(); Loop *ParentLoop = OrigLoop->getParentLoop(); // Insert the new loop into the loop nest and register the new basic blocks @@ -2899,15 +3039,15 @@ void InnerLoopVectorizer::createEmptyLoop() { // checks into a separate block to make the more common case of few elements // faster. emitMemRuntimeChecks(Lp, ScalarPH); - + // Generate the induction variable. // The loop step is equal to the vectorization factor (num of SIMD elements) // times the unroll factor (num of SIMD instructions). Value *CountRoundDown = getOrCreateVectorTripCount(Lp); Constant *Step = ConstantInt::get(IdxTy, VF * UF); Induction = - createInductionVariable(Lp, StartIdx, CountRoundDown, Step, - getDebugLocFromInstOrOperands(OldInduction)); + createInductionVariable(Lp, StartIdx, CountRoundDown, Step, + getDebugLocFromInstOrOperands(OldInduction)); // We are going to resume the execution of the scalar loop. // Go over all of the induction variables that we found and fix the @@ -2920,16 +3060,14 @@ void InnerLoopVectorizer::createEmptyLoop() { // This variable saves the new starting index for the scalar loop. It is used // to test if there are any tail iterations left once the vector loop has // completed. - LoopVectorizationLegality::InductionList::iterator I, E; LoopVectorizationLegality::InductionList *List = Legal->getInductionVars(); - for (I = List->begin(), E = List->end(); I != E; ++I) { - PHINode *OrigPhi = I->first; - InductionDescriptor II = I->second; + for (auto &InductionEntry : *List) { + PHINode *OrigPhi = InductionEntry.first; + InductionDescriptor II = InductionEntry.second; // Create phi nodes to merge from the backedge-taken check block. - PHINode *BCResumeVal = PHINode::Create(OrigPhi->getType(), 3, - "bc.resume.val", - ScalarPH->getTerminator()); + PHINode *BCResumeVal = PHINode::Create( + OrigPhi->getType(), 3, "bc.resume.val", ScalarPH->getTerminator()); Value *EndValue; if (OrigPhi == OldInduction) { // We know what the end value is. @@ -2937,9 +3075,9 @@ void InnerLoopVectorizer::createEmptyLoop() { } else { IRBuilder<> B(LoopBypassBlocks.back()->getTerminator()); Value *CRD = B.CreateSExtOrTrunc(CountRoundDown, - II.getStepValue()->getType(), - "cast.crd"); - EndValue = II.transform(B, CRD); + II.getStep()->getType(), "cast.crd"); + const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + EndValue = II.transform(B, CRD, PSE.getSE(), DL); EndValue->setName("ind.end"); } @@ -2947,22 +3085,25 @@ void InnerLoopVectorizer::createEmptyLoop() { // or the value at the end of the vectorized loop. BCResumeVal->addIncoming(EndValue, MiddleBlock); + // Fix up external users of the induction variable. + fixupIVUsers(OrigPhi, II, CountRoundDown, EndValue, MiddleBlock); + // Fix the scalar body counter (PHI node). unsigned BlockIdx = OrigPhi->getBasicBlockIndex(ScalarPH); // The old induction's phi node in the scalar body needs the truncated // value. - for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I) - BCResumeVal->addIncoming(II.getStartValue(), LoopBypassBlocks[I]); + for (BasicBlock *BB : LoopBypassBlocks) + BCResumeVal->addIncoming(II.getStartValue(), BB); OrigPhi->setIncomingValue(BlockIdx, BCResumeVal); } // Add a check in the middle block to see if we have completed // all of the iterations in the first vector loop. // If (N - N%VF) == N, then we *don't* need to run the remainder. 
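For an integer induction, the end value produced by II.transform above reduces to simple arithmetic. A hedged plain C++ sketch, assuming a constant step and illustrative names:

#include <cstdint>

// Resume value for the scalar loop (bc.resume.val / ind.end): the value the
// induction reaches after CountRoundDown iterations of the original loop.
int64_t inductionEndValue(int64_t Start, int64_t Step,
                          uint64_t CountRoundDown) {
  return Start + static_cast<int64_t>(CountRoundDown) * Step;
}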
- Value *CmpN = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, Count, - CountRoundDown, "cmp.n", - MiddleBlock->getTerminator()); + Value *CmpN = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, Count, + CountRoundDown, "cmp.n", MiddleBlock->getTerminator()); ReplaceInstWithInst(MiddleBlock->getTerminator(), BranchInst::Create(ExitBlock, ScalarPH, CmpN)); @@ -2974,13 +3115,79 @@ void InnerLoopVectorizer::createEmptyLoop() { LoopScalarPreHeader = ScalarPH; LoopMiddleBlock = MiddleBlock; LoopExitBlock = ExitBlock; - LoopVectorBody.push_back(VecBody); + LoopVectorBody = VecBody; LoopScalarBody = OldBasicBlock; + // Keep all loop hints from the original loop on the vector loop (we'll + // replace the vectorizer-specific hints below). + if (MDNode *LID = OrigLoop->getLoopID()) + Lp->setLoopID(LID); + LoopVectorizeHints Hints(Lp, true); Hints.setAlreadyVectorized(); } +// Fix up external users of the induction variable. At this point, we are +// in LCSSA form, with all external PHIs that use the IV having one input value, +// coming from the remainder loop. We need those PHIs to also have a correct +// value for the IV when arriving directly from the middle block. +void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, + const InductionDescriptor &II, + Value *CountRoundDown, Value *EndValue, + BasicBlock *MiddleBlock) { + // There are two kinds of external IV usages - those that use the value + // computed in the last iteration (the PHI) and those that use the penultimate + // value (the value that feeds into the phi from the loop latch). + // We allow both, but they, obviously, have different values. + + assert(OrigLoop->getExitBlock() && "Expected a single exit block"); + + DenseMap<Value *, Value *> MissingVals; + + // An external user of the last iteration's value should see the value that + // the remainder loop uses to initialize its own IV. + Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoop->getLoopLatch()); + for (User *U : PostInc->users()) { + Instruction *UI = cast<Instruction>(U); + if (!OrigLoop->contains(UI)) { + assert(isa<PHINode>(UI) && "Expected LCSSA form"); + MissingVals[UI] = EndValue; + } + } + + // An external user of the penultimate value need to see EndValue - Step. + // The simplest way to get this is to recompute it from the constituent SCEVs, + // that is Start + (Step * (CRD - 1)). + for (User *U : OrigPhi->users()) { + auto *UI = cast<Instruction>(U); + if (!OrigLoop->contains(UI)) { + const DataLayout &DL = + OrigLoop->getHeader()->getModule()->getDataLayout(); + assert(isa<PHINode>(UI) && "Expected LCSSA form"); + + IRBuilder<> B(MiddleBlock->getTerminator()); + Value *CountMinusOne = B.CreateSub( + CountRoundDown, ConstantInt::get(CountRoundDown->getType(), 1)); + Value *CMO = B.CreateSExtOrTrunc(CountMinusOne, II.getStep()->getType(), + "cast.cmo"); + Value *Escape = II.transform(B, CMO, PSE.getSE(), DL); + Escape->setName("ind.escape"); + MissingVals[UI] = Escape; + } + } + + for (auto &I : MissingVals) { + PHINode *PHI = cast<PHINode>(I.first); + // One corner case we have to handle is two IVs "chasing" each-other, + // that is %IV2 = phi [...], [ %IV1, %latch ] + // In this case, if IV1 has an external use, we need to avoid adding both + // "last value of IV1" and "penultimate value of IV2". So, verify that we + // don't already have an incoming value for the middle block. 
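fixupIVUsers() distinguishes two kinds of external uses, and for an integer induction with a constant step the two values it wires into the exit PHIs reduce to the arithmetic below (plain C++ sketch, illustrative names):

#include <cstdint>

struct IVExitValues {
  int64_t EndValue; // seen by users of the post-increment value
  int64_t Escape;   // ind.escape: seen by users of the phi (penultimate value)
};

IVExitValues ivExitValues(int64_t Start, int64_t Step,
                          uint64_t CountRoundDown) {
  return {Start + static_cast<int64_t>(CountRoundDown) * Step,
          Start + static_cast<int64_t>(CountRoundDown - 1) * Step};
}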
+ if (PHI->getBasicBlockIndex(MiddleBlock) == -1) + PHI->addIncoming(I.second, MiddleBlock); + } +} + namespace { struct CSEDenseMapInfo { static bool canHandle(Instruction *I) { @@ -3007,48 +3214,31 @@ struct CSEDenseMapInfo { }; } -/// \brief Check whether this block is a predicated block. -/// Due to if predication of stores we might create a sequence of "if(pred) a[i] -/// = ...; " blocks. We start with one vectorized basic block. For every -/// conditional block we split this vectorized block. Therefore, every second -/// block will be a predicated one. -static bool isPredicatedBlock(unsigned BlockNum) { - return BlockNum % 2; -} - ///\brief Perform cse of induction variable instructions. -static void cse(SmallVector<BasicBlock *, 4> &BBs) { +static void cse(BasicBlock *BB) { // Perform simple cse. SmallDenseMap<Instruction *, Instruction *, 4, CSEDenseMapInfo> CSEMap; - for (unsigned i = 0, e = BBs.size(); i != e; ++i) { - BasicBlock *BB = BBs[i]; - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { - Instruction *In = &*I++; - - if (!CSEDenseMapInfo::canHandle(In)) - continue; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { + Instruction *In = &*I++; - // Check if we can replace this instruction with any of the - // visited instructions. - if (Instruction *V = CSEMap.lookup(In)) { - In->replaceAllUsesWith(V); - In->eraseFromParent(); - continue; - } - // Ignore instructions in conditional blocks. We create "if (pred) a[i] = - // ...;" blocks for predicated stores. Every second block is a predicated - // block. - if (isPredicatedBlock(i)) - continue; + if (!CSEDenseMapInfo::canHandle(In)) + continue; - CSEMap[In] = In; + // Check if we can replace this instruction with any of the + // visited instructions. + if (Instruction *V = CSEMap.lookup(In)) { + In->replaceAllUsesWith(V); + In->eraseFromParent(); + continue; } + + CSEMap[In] = In; } } /// \brief Adds a 'fast' flag to floating point operations. static Value *addFastMathFlag(Value *V) { - if (isa<FPMathOperator>(V)){ + if (isa<FPMathOperator>(V)) { FastMathFlags Flags; Flags.setUnsafeAlgebra(); cast<Instruction>(V)->setFastMathFlags(Flags); @@ -3066,11 +3256,11 @@ static unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract, assert(Ty->isVectorTy() && "Can only scalarize vectors"); unsigned Cost = 0; - for (int i = 0, e = Ty->getVectorNumElements(); i < e; ++i) { + for (unsigned I = 0, E = Ty->getVectorNumElements(); I < E; ++I) { if (Insert) - Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, i); + Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, I); if (Extract) - Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, Ty, i); + Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, Ty, I); } return Cost; @@ -3101,15 +3291,15 @@ static unsigned getVectorCallCost(CallInst *CI, unsigned VF, // Compute corresponding vector type for return value and arguments. Type *RetTy = ToVectorTy(ScalarRetTy, VF); - for (unsigned i = 0, ie = ScalarTys.size(); i != ie; ++i) - Tys.push_back(ToVectorTy(ScalarTys[i], VF)); + for (Type *ScalarTy : ScalarTys) + Tys.push_back(ToVectorTy(ScalarTy, VF)); // Compute costs of unpacking argument values for the scalar calls and // packing the return values to a vector. 
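The cost of falling back to VF scalar calls, as computed by getVectorCallCost() above, is the scalar call cost times VF plus the insert/extract overhead of packing the return values and unpacking the arguments. A plain C++ sketch with stand-in per-lane costs (the real numbers come from TargetTransformInfo):

// Illustrative: cost of scalarizing a call across VF lanes.
unsigned scalarizedCallCost(unsigned ScalarCallCost, unsigned VF,
                            unsigned NumArgs, unsigned InsertCostPerLane,
                            unsigned ExtractCostPerLane) {
  unsigned Scalarization = VF * InsertCostPerLane            // pack results
                           + VF * NumArgs * ExtractCostPerLane; // unpack args
  return ScalarCallCost * VF + Scalarization;
}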
unsigned ScalarizationCost = getScalarizationOverhead(RetTy, true, false, TTI); - for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) - ScalarizationCost += getScalarizationOverhead(Tys[i], false, true, TTI); + for (Type *Ty : Tys) + ScalarizationCost += getScalarizationOverhead(Ty, false, true, TTI); unsigned Cost = ScalarCallCost * VF + ScalarizationCost; @@ -3134,25 +3324,29 @@ static unsigned getVectorCallCost(CallInst *CI, unsigned VF, static unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF, const TargetTransformInfo &TTI, const TargetLibraryInfo *TLI) { - Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); assert(ID && "Expected intrinsic call!"); Type *RetTy = ToVectorTy(CI->getType(), VF); SmallVector<Type *, 4> Tys; - for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) - Tys.push_back(ToVectorTy(CI->getArgOperand(i)->getType(), VF)); + for (Value *ArgOperand : CI->arg_operands()) + Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); - return TTI.getIntrinsicInstrCost(ID, RetTy, Tys); + FastMathFlags FMF; + if (auto *FPMO = dyn_cast<FPMathOperator>(CI)) + FMF = FPMO->getFastMathFlags(); + + return TTI.getIntrinsicInstrCost(ID, RetTy, Tys, FMF); } static Type *smallestIntegerVectorType(Type *T1, Type *T2) { - IntegerType *I1 = cast<IntegerType>(T1->getVectorElementType()); - IntegerType *I2 = cast<IntegerType>(T2->getVectorElementType()); + auto *I1 = cast<IntegerType>(T1->getVectorElementType()); + auto *I2 = cast<IntegerType>(T2->getVectorElementType()); return I1->getBitWidth() < I2->getBitWidth() ? T1 : T2; } static Type *largestIntegerVectorType(Type *T1, Type *T2) { - IntegerType *I1 = cast<IntegerType>(T1->getVectorElementType()); - IntegerType *I2 = cast<IntegerType>(T2->getVectorElementType()); + auto *I1 = cast<IntegerType>(T1->getVectorElementType()); + auto *I2 = cast<IntegerType>(T2->getVectorElementType()); return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2; } @@ -3161,21 +3355,22 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { // truncated version of `I` and reextend its result. InstCombine runs // later and will remove any ext/trunc pairs. // - for (auto &KV : MinBWs) { + SmallPtrSet<Value *, 4> Erased; + for (const auto &KV : *MinBWs) { VectorParts &Parts = WidenMap.get(KV.first); for (Value *&I : Parts) { - if (I->use_empty()) + if (Erased.count(I) || I->use_empty() || !isa<Instruction>(I)) continue; Type *OriginalTy = I->getType(); - Type *ScalarTruncatedTy = IntegerType::get(OriginalTy->getContext(), - KV.second); + Type *ScalarTruncatedTy = + IntegerType::get(OriginalTy->getContext(), KV.second); Type *TruncatedTy = VectorType::get(ScalarTruncatedTy, OriginalTy->getVectorNumElements()); if (TruncatedTy == OriginalTy) continue; IRBuilder<> B(cast<Instruction>(I)); - auto ShrinkOperand = [&](Value *V) -> Value* { + auto ShrinkOperand = [&](Value *V) -> Value * { if (auto *ZI = dyn_cast<ZExtInst>(V)) if (ZI->getSrcTy() == TruncatedTy) return ZI->getOperand(0); @@ -3185,50 +3380,59 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { // The actual instruction modification depends on the instruction type, // unfortunately. 
Value *NewI = nullptr; - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) { - NewI = B.CreateBinOp(BO->getOpcode(), - ShrinkOperand(BO->getOperand(0)), + if (auto *BO = dyn_cast<BinaryOperator>(I)) { + NewI = B.CreateBinOp(BO->getOpcode(), ShrinkOperand(BO->getOperand(0)), ShrinkOperand(BO->getOperand(1))); cast<BinaryOperator>(NewI)->copyIRFlags(I); - } else if (ICmpInst *CI = dyn_cast<ICmpInst>(I)) { - NewI = B.CreateICmp(CI->getPredicate(), - ShrinkOperand(CI->getOperand(0)), - ShrinkOperand(CI->getOperand(1))); - } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { + } else if (auto *CI = dyn_cast<ICmpInst>(I)) { + NewI = + B.CreateICmp(CI->getPredicate(), ShrinkOperand(CI->getOperand(0)), + ShrinkOperand(CI->getOperand(1))); + } else if (auto *SI = dyn_cast<SelectInst>(I)) { NewI = B.CreateSelect(SI->getCondition(), ShrinkOperand(SI->getTrueValue()), ShrinkOperand(SI->getFalseValue())); - } else if (CastInst *CI = dyn_cast<CastInst>(I)) { + } else if (auto *CI = dyn_cast<CastInst>(I)) { switch (CI->getOpcode()) { - default: llvm_unreachable("Unhandled cast!"); + default: + llvm_unreachable("Unhandled cast!"); case Instruction::Trunc: NewI = ShrinkOperand(CI->getOperand(0)); break; case Instruction::SExt: - NewI = B.CreateSExtOrTrunc(CI->getOperand(0), - smallestIntegerVectorType(OriginalTy, - TruncatedTy)); + NewI = B.CreateSExtOrTrunc( + CI->getOperand(0), + smallestIntegerVectorType(OriginalTy, TruncatedTy)); break; case Instruction::ZExt: - NewI = B.CreateZExtOrTrunc(CI->getOperand(0), - smallestIntegerVectorType(OriginalTy, - TruncatedTy)); + NewI = B.CreateZExtOrTrunc( + CI->getOperand(0), + smallestIntegerVectorType(OriginalTy, TruncatedTy)); break; } - } else if (ShuffleVectorInst *SI = dyn_cast<ShuffleVectorInst>(I)) { + } else if (auto *SI = dyn_cast<ShuffleVectorInst>(I)) { auto Elements0 = SI->getOperand(0)->getType()->getVectorNumElements(); - auto *O0 = - B.CreateZExtOrTrunc(SI->getOperand(0), - VectorType::get(ScalarTruncatedTy, Elements0)); + auto *O0 = B.CreateZExtOrTrunc( + SI->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements0)); auto Elements1 = SI->getOperand(1)->getType()->getVectorNumElements(); - auto *O1 = - B.CreateZExtOrTrunc(SI->getOperand(1), - VectorType::get(ScalarTruncatedTy, Elements1)); + auto *O1 = B.CreateZExtOrTrunc( + SI->getOperand(1), VectorType::get(ScalarTruncatedTy, Elements1)); NewI = B.CreateShuffleVector(O0, O1, SI->getMask()); } else if (isa<LoadInst>(I)) { // Don't do anything with the operands, just extend the result. continue; + } else if (auto *IE = dyn_cast<InsertElementInst>(I)) { + auto Elements = IE->getOperand(0)->getType()->getVectorNumElements(); + auto *O0 = B.CreateZExtOrTrunc( + IE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); + auto *O1 = B.CreateZExtOrTrunc(IE->getOperand(1), ScalarTruncatedTy); + NewI = B.CreateInsertElement(O0, O1, IE->getOperand(2)); + } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) { + auto Elements = EE->getOperand(0)->getType()->getVectorNumElements(); + auto *O0 = B.CreateZExtOrTrunc( + EE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); + NewI = B.CreateExtractElement(O0, EE->getOperand(2)); } else { llvm_unreachable("Unhandled instruction type!"); } @@ -3238,12 +3442,13 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy); I->replaceAllUsesWith(Res); cast<Instruction>(I)->eraseFromParent(); + Erased.insert(I); I = Res; } } // We'll have created a bunch of ZExts that are now parentless. 
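The shrink-then-reextend rewrite performed by truncateToMinimalBitwidths() has a straightforward scalar analogue: compute in the narrow type and widen the result afterwards, which is only valid when the MinBWs analysis proved the extra bits are never demanded. A hedged plain C++ sketch for a 32-bit add known to need only 8 bits:

#include <cstdint>

// Valid only under the assumption that users need at most the low 8 bits of
// the result (the analogue of what MinBWs records for the instruction).
uint32_t addInMinimalWidth(uint32_t A, uint32_t B) {
  uint8_t NarrowA = static_cast<uint8_t>(A);    // "trunc"
  uint8_t NarrowB = static_cast<uint8_t>(B);
  uint8_t NarrowSum = static_cast<uint8_t>(NarrowA + NarrowB);
  return static_cast<uint32_t>(NarrowSum);      // "zext" back to the original type
}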
Clean up. - for (auto &KV : MinBWs) { + for (const auto &KV : *MinBWs) { VectorParts &Parts = WidenMap.get(KV.first); for (Value *&I : Parts) { ZExtInst *Inst = dyn_cast<ZExtInst>(I); @@ -3266,15 +3471,14 @@ void InnerLoopVectorizer::vectorizeLoop() { //===------------------------------------------------===// Constant *Zero = Builder.getInt32(0); - // In order to support reduction variables we need to be able to vectorize - // Phi nodes. Phi nodes have cycles, so we need to vectorize them in two - // stages. First, we create a new vector PHI node with no incoming edges. - // We use this value when we vectorize all of the instructions that use the - // PHI. Next, after all of the instructions in the block are complete we - // add the new incoming edges to the PHI. At this point all of the - // instructions in the basic block are vectorized, so we can use them to - // construct the PHI. - PhiVector RdxPHIsToFix; + // In order to support recurrences we need to be able to vectorize Phi nodes. + // Phi nodes have cycles, so we need to vectorize them in two stages. First, + // we create a new vector PHI node with no incoming edges. We use this value + // when we vectorize all of the instructions that use the PHI. Next, after + // all of the instructions in the block are complete we add the new incoming + // edges to the PHI. At this point all of the instructions in the basic block + // are vectorized, so we can use them to construct the PHI. + PhiVector PHIsToFix; // Scan the loop in a topological order to ensure that defs are vectorized // before users. @@ -3282,33 +3486,32 @@ void InnerLoopVectorizer::vectorizeLoop() { DFS.perform(LI); // Vectorize all of the blocks in the original loop. - for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(), - be = DFS.endRPO(); bb != be; ++bb) - vectorizeBlockInLoop(*bb, &RdxPHIsToFix); + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) + vectorizeBlockInLoop(BB, &PHIsToFix); // Insert truncates and extends for any truncated instructions as hints to // InstCombine. if (VF > 1) truncateToMinimalBitwidths(); - - // At this point every instruction in the original loop is widened to - // a vector form. We are almost done. Now, we need to fix the PHI nodes - // that we vectorized. The PHI nodes are currently empty because we did - // not want to introduce cycles. Notice that the remaining PHI nodes - // that we need to fix are reduction variables. - - // Create the 'reduced' values for each of the induction vars. - // The reduced values are the vector values that we scalarize and combine - // after the loop is finished. - for (PhiVector::iterator it = RdxPHIsToFix.begin(), e = RdxPHIsToFix.end(); - it != e; ++it) { - PHINode *RdxPhi = *it; - assert(RdxPhi && "Unable to recover vectorized PHI"); - - // Find the reduction variable descriptor. - assert(Legal->isReductionVariable(RdxPhi) && + + // At this point every instruction in the original loop is widened to a + // vector form. Now we need to fix the recurrences in PHIsToFix. These PHI + // nodes are currently empty because we did not want to introduce cycles. + // This is the second stage of vectorizing recurrences. + for (PHINode *Phi : PHIsToFix) { + assert(Phi && "Unable to recover vectorized PHI"); + + // Handle first-order recurrences that need to be fixed. + if (Legal->isFirstOrderRecurrence(Phi)) { + fixFirstOrderRecurrence(Phi); + continue; + } + + // If the phi node is not a first-order recurrence, it must be a reduction. + // Get it's reduction variable descriptor. 
+ assert(Legal->isReductionVariable(Phi) && "Unable to find the reduction variable"); - RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[RdxPhi]; + RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[Phi]; RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind(); TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue(); @@ -3363,18 +3566,18 @@ void InnerLoopVectorizer::vectorizeLoop() { // Reductions do not have to start at zero. They can start with // any loop invariant values. - VectorParts &VecRdxPhi = WidenMap.get(RdxPhi); + VectorParts &VecRdxPhi = WidenMap.get(Phi); BasicBlock *Latch = OrigLoop->getLoopLatch(); - Value *LoopVal = RdxPhi->getIncomingValueForBlock(Latch); + Value *LoopVal = Phi->getIncomingValueForBlock(Latch); VectorParts &Val = getVectorValue(LoopVal); for (unsigned part = 0; part < UF; ++part) { // Make sure to add the reduction stat value only to the // first unroll part. Value *StartVal = (part == 0) ? VectorStart : Identity; - cast<PHINode>(VecRdxPhi[part])->addIncoming(StartVal, - LoopVectorPreHeader); - cast<PHINode>(VecRdxPhi[part])->addIncoming(Val[part], - LoopVectorBody.back()); + cast<PHINode>(VecRdxPhi[part]) + ->addIncoming(StartVal, LoopVectorPreHeader); + cast<PHINode>(VecRdxPhi[part]) + ->addIncoming(Val[part], LoopVectorBody); } // Before each round, move the insertion point right between @@ -3389,9 +3592,9 @@ void InnerLoopVectorizer::vectorizeLoop() { // If the vector reduction can be performed in a smaller type, we truncate // then extend the loop exit value to enable InstCombine to evaluate the // entire expression in the smaller type. - if (VF > 1 && RdxPhi->getType() != RdxDesc.getRecurrenceType()) { + if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) { Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF); - Builder.SetInsertPoint(LoopVectorBody.back()->getTerminator()); + Builder.SetInsertPoint(LoopVectorBody->getTerminator()); for (unsigned part = 0; part < UF; ++part) { Value *Trunc = Builder.CreateTrunc(RdxParts[part], RdxVecTy); Value *Extnd = RdxDesc.isSigned() ? Builder.CreateSExt(Trunc, VecTy) @@ -3432,21 +3635,19 @@ void InnerLoopVectorizer::vectorizeLoop() { assert(isPowerOf2_32(VF) && "Reduction emission only supported for pow2 vectors!"); Value *TmpVec = ReducedPartRdx; - SmallVector<Constant*, 32> ShuffleMask(VF, nullptr); + SmallVector<Constant *, 32> ShuffleMask(VF, nullptr); for (unsigned i = VF; i != 1; i >>= 1) { // Move the upper half of the vector to the lower half. - for (unsigned j = 0; j != i/2; ++j) - ShuffleMask[j] = Builder.getInt32(i/2 + j); + for (unsigned j = 0; j != i / 2; ++j) + ShuffleMask[j] = Builder.getInt32(i / 2 + j); // Fill the rest of the mask with undef. - std::fill(&ShuffleMask[i/2], ShuffleMask.end(), + std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), UndefValue::get(Builder.getInt32Ty())); - Value *Shuf = - Builder.CreateShuffleVector(TmpVec, - UndefValue::get(TmpVec->getType()), - ConstantVector::get(ShuffleMask), - "rdx.shuf"); + Value *Shuf = Builder.CreateShuffleVector( + TmpVec, UndefValue::get(TmpVec->getType()), + ConstantVector::get(ShuffleMask), "rdx.shuf"); if (Op != Instruction::ICmp && Op != Instruction::FCmp) // Floating point operations had to be 'fast' to enable the reduction. @@ -3458,21 +3659,21 @@ void InnerLoopVectorizer::vectorizeLoop() { } // The result is in the first element of the vector. 
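The reduction epilogue above folds the vector in log2(VF) steps: each round moves the upper half onto the lower half and combines lane-wise, leaving the result in lane 0. A plain C++ model of the same idea for an add reduction (the IR version uses shufflevector plus the reduction opcode):

#include <vector>

int horizontalAdd(std::vector<int> Lanes) { // Lanes.size() is a power of two
  for (size_t Width = Lanes.size(); Width != 1; Width /= 2)
    for (size_t J = 0; J != Width / 2; ++J)
      Lanes[J] += Lanes[J + Width / 2];     // rdx.shuf followed by bin.rdx
  return Lanes[0];                          // extract element 0
}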
- ReducedPartRdx = Builder.CreateExtractElement(TmpVec, - Builder.getInt32(0)); + ReducedPartRdx = + Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); // If the reduction can be performed in a smaller type, we need to extend // the reduction to the wider type before we branch to the original loop. - if (RdxPhi->getType() != RdxDesc.getRecurrenceType()) + if (Phi->getType() != RdxDesc.getRecurrenceType()) ReducedPartRdx = RdxDesc.isSigned() - ? Builder.CreateSExt(ReducedPartRdx, RdxPhi->getType()) - : Builder.CreateZExt(ReducedPartRdx, RdxPhi->getType()); + ? Builder.CreateSExt(ReducedPartRdx, Phi->getType()) + : Builder.CreateZExt(ReducedPartRdx, Phi->getType()); } // Create a phi node that merges control-flow from the backedge-taken check // block and the middle block. - PHINode *BCBlockPhi = PHINode::Create(RdxPhi->getType(), 2, "bc.merge.rdx", + PHINode *BCBlockPhi = PHINode::Create(Phi->getType(), 2, "bc.merge.rdx", LoopScalarPreHeader->getTerminator()); for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I) BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]); @@ -3483,9 +3684,11 @@ void InnerLoopVectorizer::vectorizeLoop() { // We know that the loop is in LCSSA form. We need to update the // PHI nodes in the exit blocks. for (BasicBlock::iterator LEI = LoopExitBlock->begin(), - LEE = LoopExitBlock->end(); LEI != LEE; ++LEI) { + LEE = LoopExitBlock->end(); + LEI != LEE; ++LEI) { PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI); - if (!LCSSAPhi) break; + if (!LCSSAPhi) + break; // All PHINodes need to have a single entry edge, or two if // we already fixed them. @@ -3498,30 +3701,30 @@ void InnerLoopVectorizer::vectorizeLoop() { LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); break; } - }// end of the LCSSA phi scan. + } // end of the LCSSA phi scan. // Fix the scalar loop reduction variable with the incoming reduction sum // from the vector body and from the backedge value. int IncomingEdgeBlockIdx = - (RdxPhi)->getBasicBlockIndex(OrigLoop->getLoopLatch()); + Phi->getBasicBlockIndex(OrigLoop->getLoopLatch()); assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index"); // Pick the other block. int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 0 : 1); - (RdxPhi)->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi); - (RdxPhi)->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); - }// end of for each redux variable. + Phi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi); + Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); + } // end of for each Phi in PHIsToFix. fixLCSSAPHIs(); // Make sure DomTree is updated. updateAnalysis(); - + // Predicate any stores. for (auto KV : PredicatedStores) { BasicBlock::iterator I(KV.first); auto *BB = SplitBlock(I->getParent(), &*std::next(I), DT, LI); auto *T = SplitBlockAndInsertIfThen(KV.second, &*I, /*Unreachable=*/false, - /*BranchWeights=*/nullptr, DT); + /*BranchWeights=*/nullptr, DT, LI); I->moveBefore(T); I->getParent()->setName("pred.store.if"); BB->setName("pred.store.continue"); @@ -3531,11 +3734,162 @@ void InnerLoopVectorizer::vectorizeLoop() { cse(LoopVectorBody); } +void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { + + // This is the second phase of vectorizing first-order recurrences. An + // overview of the transformation is described below. Suppose we have the + // following loop. + // + // for (int i = 0; i < n; ++i) + // b[i] = a[i] - a[i - 1]; + // + // There is a first-order recurrence on "a". 
For this loop, the shorthand + // scalar IR looks like: + // + // scalar.ph: + // s_init = a[-1] + // br scalar.body + // + // scalar.body: + // i = phi [0, scalar.ph], [i+1, scalar.body] + // s1 = phi [s_init, scalar.ph], [s2, scalar.body] + // s2 = a[i] + // b[i] = s2 - s1 + // br cond, scalar.body, ... + // + // In this example, s1 is a recurrence because it's value depends on the + // previous iteration. In the first phase of vectorization, we created a + // temporary value for s1. We now complete the vectorization and produce the + // shorthand vector IR shown below (for VF = 4, UF = 1). + // + // vector.ph: + // v_init = vector(..., ..., ..., a[-1]) + // br vector.body + // + // vector.body + // i = phi [0, vector.ph], [i+4, vector.body] + // v1 = phi [v_init, vector.ph], [v2, vector.body] + // v2 = a[i, i+1, i+2, i+3]; + // v3 = vector(v1(3), v2(0, 1, 2)) + // b[i, i+1, i+2, i+3] = v2 - v3 + // br cond, vector.body, middle.block + // + // middle.block: + // x = v2(3) + // br scalar.ph + // + // scalar.ph: + // s_init = phi [x, middle.block], [a[-1], otherwise] + // br scalar.body + // + // After execution completes the vector loop, we extract the next value of + // the recurrence (x) to use as the initial value in the scalar loop. + + // Get the original loop preheader and single loop latch. + auto *Preheader = OrigLoop->getLoopPreheader(); + auto *Latch = OrigLoop->getLoopLatch(); + + // Get the initial and previous values of the scalar recurrence. + auto *ScalarInit = Phi->getIncomingValueForBlock(Preheader); + auto *Previous = Phi->getIncomingValueForBlock(Latch); + + // Create a vector from the initial value. + auto *VectorInit = ScalarInit; + if (VF > 1) { + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + VectorInit = Builder.CreateInsertElement( + UndefValue::get(VectorType::get(VectorInit->getType(), VF)), VectorInit, + Builder.getInt32(VF - 1), "vector.recur.init"); + } + + // We constructed a temporary phi node in the first phase of vectorization. + // This phi node will eventually be deleted. + auto &PhiParts = getVectorValue(Phi); + Builder.SetInsertPoint(cast<Instruction>(PhiParts[0])); + + // Create a phi node for the new recurrence. The current value will either be + // the initial value inserted into a vector or loop-varying vector value. + auto *VecPhi = Builder.CreatePHI(VectorInit->getType(), 2, "vector.recur"); + VecPhi->addIncoming(VectorInit, LoopVectorPreHeader); + + // Get the vectorized previous value. We ensured the previous values was an + // instruction when detecting the recurrence. + auto &PreviousParts = getVectorValue(Previous); + + // Set the insertion point to be after this instruction. We ensured the + // previous value dominated all uses of the phi when detecting the + // recurrence. + Builder.SetInsertPoint( + &*++BasicBlock::iterator(cast<Instruction>(PreviousParts[UF - 1]))); + + // We will construct a vector for the recurrence by combining the values for + // the current and previous iterations. This is the required shuffle mask. + SmallVector<Constant *, 8> ShuffleMask(VF); + ShuffleMask[0] = Builder.getInt32(VF - 1); + for (unsigned I = 1; I < VF; ++I) + ShuffleMask[I] = Builder.getInt32(I + VF - 1); + + // The vector from which to take the initial value for the current iteration + // (actual or unrolled). Initially, this is the vector phi node. + Value *Incoming = VecPhi; + + // Shuffle the current and previous vector and update the vector parts. + for (unsigned Part = 0; Part < UF; ++Part) { + auto *Shuffle = + VF > 1 + ? 
Builder.CreateShuffleVector(Incoming, PreviousParts[Part], + ConstantVector::get(ShuffleMask)) + : Incoming; + PhiParts[Part]->replaceAllUsesWith(Shuffle); + cast<Instruction>(PhiParts[Part])->eraseFromParent(); + PhiParts[Part] = Shuffle; + Incoming = PreviousParts[Part]; + } + + // Fix the latch value of the new recurrence in the vector loop. + VecPhi->addIncoming(Incoming, LI->getLoopFor(LoopVectorBody)->getLoopLatch()); + + // Extract the last vector element in the middle block. This will be the + // initial value for the recurrence when jumping to the scalar loop. + auto *Extract = Incoming; + if (VF > 1) { + Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); + Extract = Builder.CreateExtractElement(Extract, Builder.getInt32(VF - 1), + "vector.recur.extract"); + } + + // Fix the initial value of the original recurrence in the scalar loop. + Builder.SetInsertPoint(&*LoopScalarPreHeader->begin()); + auto *Start = Builder.CreatePHI(Phi->getType(), 2, "scalar.recur.init"); + for (auto *BB : predecessors(LoopScalarPreHeader)) { + auto *Incoming = BB == LoopMiddleBlock ? Extract : ScalarInit; + Start->addIncoming(Incoming, BB); + } + + Phi->setIncomingValue(Phi->getBasicBlockIndex(LoopScalarPreHeader), Start); + Phi->setName("scalar.recur"); + + // Finally, fix users of the recurrence outside the loop. The users will need + // either the last value of the scalar recurrence or the last value of the + // vector recurrence we extracted in the middle block. Since the loop is in + // LCSSA form, we just need to find the phi node for the original scalar + // recurrence in the exit block, and then add an edge for the middle block. + for (auto &I : *LoopExitBlock) { + auto *LCSSAPhi = dyn_cast<PHINode>(&I); + if (!LCSSAPhi) + break; + if (LCSSAPhi->getIncomingValue(0) == Phi) { + LCSSAPhi->addIncoming(Extract, LoopMiddleBlock); + break; + } + } +} + void InnerLoopVectorizer::fixLCSSAPHIs() { - for (BasicBlock::iterator LEI = LoopExitBlock->begin(), - LEE = LoopExitBlock->end(); LEI != LEE; ++LEI) { - PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI); - if (!LCSSAPhi) break; + for (Instruction &LEI : *LoopExitBlock) { + auto *LCSSAPhi = dyn_cast<PHINode>(&LEI); + if (!LCSSAPhi) + break; if (LCSSAPhi->getNumIncomingValues() == 1) LCSSAPhi->addIncoming(UndefValue::get(LCSSAPhi->getType()), LoopMiddleBlock); @@ -3548,7 +3902,7 @@ InnerLoopVectorizer::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { "Invalid edge"); // Look for cached value. - std::pair<BasicBlock*, BasicBlock*> Edge(Src, Dst); + std::pair<BasicBlock *, BasicBlock *> Edge(Src, Dst); EdgeMaskCache::iterator ECEntryIt = MaskCache.find(Edge); if (ECEntryIt != MaskCache.end()) return ECEntryIt->second; @@ -3604,15 +3958,15 @@ InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) { void InnerLoopVectorizer::widenPHIInstruction( Instruction *PN, InnerLoopVectorizer::VectorParts &Entry, unsigned UF, unsigned VF, PhiVector *PV) { - PHINode* P = cast<PHINode>(PN); - // Handle reduction variables: - if (Legal->isReductionVariable(P)) { + PHINode *P = cast<PHINode>(PN); + // Handle recurrences. + if (Legal->isReductionVariable(P) || Legal->isFirstOrderRecurrence(P)) { for (unsigned part = 0; part < UF; ++part) { // This is phase one of vectorizing PHIs. - Type *VecTy = (VF == 1) ? PN->getType() : - VectorType::get(PN->getType(), VF); + Type *VecTy = + (VF == 1) ? 
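For VF = 4 the shuffle mask built above is {3, 4, 5, 6}: lane 0 takes the last lane of the incoming (previous-iteration) vector and lanes 1-3 take the first three lanes of the current vector. A plain C++ model of that shuffle (illustrative scalar code rather than IR):

#include <array>

std::array<int, 4> recurrenceVector(const std::array<int, 4> &Incoming,
                                    const std::array<int, 4> &Previous) {
  // shufflevector(Incoming, Previous, <3, 4, 5, 6>)
  return {Incoming[3], Previous[0], Previous[1], Previous[2]};
}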
PN->getType() : VectorType::get(PN->getType(), VF); Entry[part] = PHINode::Create( - VecTy, 2, "vec.phi", &*LoopVectorBody.back()->getFirstInsertionPt()); + VecTy, 2, "vec.phi", &*LoopVectorBody->getFirstInsertionPt()); } PV->push_back(P); return; @@ -3635,21 +3989,20 @@ void InnerLoopVectorizer::widenPHIInstruction( // SELECT(Mask2, In2, // ( ...))) for (unsigned In = 0; In < NumIncoming; In++) { - VectorParts Cond = createEdgeMask(P->getIncomingBlock(In), - P->getParent()); + VectorParts Cond = + createEdgeMask(P->getIncomingBlock(In), P->getParent()); VectorParts &In0 = getVectorValue(P->getIncomingValue(In)); for (unsigned part = 0; part < UF; ++part) { // We might have single edge PHIs (blocks) - use an identity // 'select' for the first PHI operand. if (In == 0) - Entry[part] = Builder.CreateSelect(Cond[part], In0[part], - In0[part]); + Entry[part] = Builder.CreateSelect(Cond[part], In0[part], In0[part]); else // Select between the current value and the previous incoming edge // based on the incoming mask. - Entry[part] = Builder.CreateSelect(Cond[part], In0[part], - Entry[part], "predphi"); + Entry[part] = Builder.CreateSelect(Cond[part], In0[part], Entry[part], + "predphi"); } } return; @@ -3657,85 +4010,68 @@ void InnerLoopVectorizer::widenPHIInstruction( // This PHINode must be an induction variable. // Make sure that we know about it. - assert(Legal->getInductionVars()->count(P) && - "Not an induction variable"); + assert(Legal->getInductionVars()->count(P) && "Not an induction variable"); InductionDescriptor II = Legal->getInductionVars()->lookup(P); + const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); // FIXME: The newly created binary instructions should contain nsw/nuw flags, // which can be found from the original scalar operations. switch (II.getKind()) { - case InductionDescriptor::IK_NoInduction: - llvm_unreachable("Unknown induction"); - case InductionDescriptor::IK_IntInduction: { - assert(P->getType() == II.getStartValue()->getType() && - "Types must match"); - // Handle other induction variables that are now based on the - // canonical one. - Value *V = Induction; - if (P != OldInduction) { - V = Builder.CreateSExtOrTrunc(Induction, P->getType()); - V = II.transform(Builder, V); - V->setName("offset.idx"); + case InductionDescriptor::IK_NoInduction: + llvm_unreachable("Unknown induction"); + case InductionDescriptor::IK_IntInduction: + return widenIntInduction(P, Entry); + case InductionDescriptor::IK_PtrInduction: + // Handle the pointer induction variable case. + assert(P->getType()->isPointerTy() && "Unexpected type."); + // This is the normalized GEP that starts counting at zero. + Value *PtrInd = Induction; + PtrInd = Builder.CreateSExtOrTrunc(PtrInd, II.getStep()->getType()); + // This is the vector of results. Notice that we don't generate + // vector geps because scalar geps result in better code. + for (unsigned part = 0; part < UF; ++part) { + if (VF == 1) { + int EltIndex = part; + Constant *Idx = ConstantInt::get(PtrInd->getType(), EltIndex); + Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); + Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); + SclrGep->setName("next.gep"); + Entry[part] = SclrGep; + continue; } - Value *Broadcasted = getBroadcastInstrs(V); - // After broadcasting the induction variable we need to make the vector - // consecutive by adding 0, 1, 2, etc. 
- for (unsigned part = 0; part < UF; ++part) - Entry[part] = getStepVector(Broadcasted, VF * part, II.getStepValue()); - return; - } - case InductionDescriptor::IK_PtrInduction: - // Handle the pointer induction variable case. - assert(P->getType()->isPointerTy() && "Unexpected type."); - // This is the normalized GEP that starts counting at zero. - Value *PtrInd = Induction; - PtrInd = Builder.CreateSExtOrTrunc(PtrInd, II.getStepValue()->getType()); - // This is the vector of results. Notice that we don't generate - // vector geps because scalar geps result in better code. - for (unsigned part = 0; part < UF; ++part) { - if (VF == 1) { - int EltIndex = part; - Constant *Idx = ConstantInt::get(PtrInd->getType(), EltIndex); - Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); - Value *SclrGep = II.transform(Builder, GlobalIdx); - SclrGep->setName("next.gep"); - Entry[part] = SclrGep; - continue; - } - Value *VecVal = UndefValue::get(VectorType::get(P->getType(), VF)); - for (unsigned int i = 0; i < VF; ++i) { - int EltIndex = i + part * VF; - Constant *Idx = ConstantInt::get(PtrInd->getType(), EltIndex); - Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); - Value *SclrGep = II.transform(Builder, GlobalIdx); - SclrGep->setName("next.gep"); - VecVal = Builder.CreateInsertElement(VecVal, SclrGep, - Builder.getInt32(i), - "insert.gep"); - } - Entry[part] = VecVal; + Value *VecVal = UndefValue::get(VectorType::get(P->getType(), VF)); + for (unsigned int i = 0; i < VF; ++i) { + int EltIndex = i + part * VF; + Constant *Idx = ConstantInt::get(PtrInd->getType(), EltIndex); + Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); + Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); + SclrGep->setName("next.gep"); + VecVal = Builder.CreateInsertElement(VecVal, SclrGep, + Builder.getInt32(i), "insert.gep"); } - return; + Entry[part] = VecVal; + } + return; } } void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { // For each instruction in the old loop. - for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { - VectorParts &Entry = WidenMap.get(&*it); + for (Instruction &I : *BB) { + VectorParts &Entry = WidenMap.get(&I); - switch (it->getOpcode()) { + switch (I.getOpcode()) { case Instruction::Br: // Nothing to do for PHIs and BR, since we already took care of the // loop control flow instructions. continue; case Instruction::PHI: { // Vectorize PHINodes. - widenPHIInstruction(&*it, Entry, UF, VF, PV); + widenPHIInstruction(&I, Entry, UF, VF, PV); continue; - }// End of PHI. + } // End of PHI. case Instruction::Add: case Instruction::FAdd: @@ -3756,10 +4092,10 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { case Instruction::Or: case Instruction::Xor: { // Just widen binops. - BinaryOperator *BinOp = dyn_cast<BinaryOperator>(it); + auto *BinOp = cast<BinaryOperator>(&I); setDebugLocFromInst(Builder, BinOp); - VectorParts &A = getVectorValue(it->getOperand(0)); - VectorParts &B = getVectorValue(it->getOperand(1)); + VectorParts &A = getVectorValue(BinOp->getOperand(0)); + VectorParts &B = getVectorValue(BinOp->getOperand(1)); // Use this vector value for all users of the original instruction. 
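The pointer-induction handling above deliberately emits one scalar GEP per lane ("next.gep") rather than a vector GEP. Each lane's address is just the base plus a lane-dependent offset; a plain C++ sketch, assuming a step of one element of size ElementSize (names are illustrative):

#include <cstdint>

// Address computed for lane "Lane" of unroll part "Part", where Index is the
// value of the normalized (counting-from-zero) induction for this iteration.
const char *laneAddress(const char *Base, uint64_t ElementSize, uint64_t Index,
                        unsigned Part, unsigned VF, unsigned Lane) {
  return Base + ElementSize * (Index + uint64_t(Part) * VF + Lane);
}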
for (unsigned Part = 0; Part < UF; ++Part) { @@ -3771,7 +4107,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { Entry[Part] = V; } - propagateMetadata(Entry, &*it); + addMetadata(Entry, BinOp); break; } case Instruction::Select: { @@ -3780,58 +4116,58 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { // instruction with a scalar condition. Otherwise, use vector-select. auto *SE = PSE.getSE(); bool InvariantCond = - SE->isLoopInvariant(PSE.getSCEV(it->getOperand(0)), OrigLoop); - setDebugLocFromInst(Builder, &*it); + SE->isLoopInvariant(PSE.getSCEV(I.getOperand(0)), OrigLoop); + setDebugLocFromInst(Builder, &I); // The condition can be loop invariant but still defined inside the // loop. This means that we can't just use the original 'cond' value. // We have to take the 'vectorized' value and pick the first lane. // Instcombine will make this a no-op. - VectorParts &Cond = getVectorValue(it->getOperand(0)); - VectorParts &Op0 = getVectorValue(it->getOperand(1)); - VectorParts &Op1 = getVectorValue(it->getOperand(2)); - - Value *ScalarCond = (VF == 1) ? Cond[0] : - Builder.CreateExtractElement(Cond[0], Builder.getInt32(0)); + VectorParts &Cond = getVectorValue(I.getOperand(0)); + VectorParts &Op0 = getVectorValue(I.getOperand(1)); + VectorParts &Op1 = getVectorValue(I.getOperand(2)); + + Value *ScalarCond = + (VF == 1) + ? Cond[0] + : Builder.CreateExtractElement(Cond[0], Builder.getInt32(0)); for (unsigned Part = 0; Part < UF; ++Part) { Entry[Part] = Builder.CreateSelect( - InvariantCond ? ScalarCond : Cond[Part], - Op0[Part], - Op1[Part]); + InvariantCond ? ScalarCond : Cond[Part], Op0[Part], Op1[Part]); } - propagateMetadata(Entry, &*it); + addMetadata(Entry, &I); break; } case Instruction::ICmp: case Instruction::FCmp: { // Widen compares. Generate vector compares. - bool FCmp = (it->getOpcode() == Instruction::FCmp); - CmpInst *Cmp = dyn_cast<CmpInst>(it); - setDebugLocFromInst(Builder, &*it); - VectorParts &A = getVectorValue(it->getOperand(0)); - VectorParts &B = getVectorValue(it->getOperand(1)); + bool FCmp = (I.getOpcode() == Instruction::FCmp); + auto *Cmp = dyn_cast<CmpInst>(&I); + setDebugLocFromInst(Builder, Cmp); + VectorParts &A = getVectorValue(Cmp->getOperand(0)); + VectorParts &B = getVectorValue(Cmp->getOperand(1)); for (unsigned Part = 0; Part < UF; ++Part) { Value *C = nullptr; if (FCmp) { C = Builder.CreateFCmp(Cmp->getPredicate(), A[Part], B[Part]); - cast<FCmpInst>(C)->copyFastMathFlags(&*it); + cast<FCmpInst>(C)->copyFastMathFlags(Cmp); } else { C = Builder.CreateICmp(Cmp->getPredicate(), A[Part], B[Part]); } Entry[Part] = C; } - propagateMetadata(Entry, &*it); + addMetadata(Entry, &I); break; } case Instruction::Store: case Instruction::Load: - vectorizeMemoryInstruction(&*it); - break; + vectorizeMemoryInstruction(&I); + break; case Instruction::ZExt: case Instruction::SExt: case Instruction::FPToUI: @@ -3844,58 +4180,52 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { - CastInst *CI = dyn_cast<CastInst>(it); - setDebugLocFromInst(Builder, &*it); - /// Optimize the special case where the source is the induction - /// variable. Notice that we can only optimize the 'trunc' case - /// because: a. FP conversions lose precision, b. sext/zext may wrap, - /// c. other casts depend on pointer size. 
- if (CI->getOperand(0) == OldInduction && - it->getOpcode() == Instruction::Trunc) { - Value *ScalarCast = Builder.CreateCast(CI->getOpcode(), Induction, - CI->getType()); - Value *Broadcasted = getBroadcastInstrs(ScalarCast); - InductionDescriptor II = - Legal->getInductionVars()->lookup(OldInduction); - Constant *Step = ConstantInt::getSigned( - CI->getType(), II.getStepValue()->getSExtValue()); - for (unsigned Part = 0; Part < UF; ++Part) - Entry[Part] = getStepVector(Broadcasted, VF * Part, Step); - propagateMetadata(Entry, &*it); + auto *CI = dyn_cast<CastInst>(&I); + setDebugLocFromInst(Builder, CI); + + // Optimize the special case where the source is a constant integer + // induction variable. Notice that we can only optimize the 'trunc' case + // because (a) FP conversions lose precision, (b) sext/zext may wrap, and + // (c) other casts depend on pointer size. + auto ID = Legal->getInductionVars()->lookup(OldInduction); + if (isa<TruncInst>(CI) && CI->getOperand(0) == OldInduction && + ID.getConstIntStepValue()) { + widenIntInduction(OldInduction, Entry, cast<TruncInst>(CI)); + addMetadata(Entry, &I); break; } + /// Vectorize casts. - Type *DestTy = (VF == 1) ? CI->getType() : - VectorType::get(CI->getType(), VF); + Type *DestTy = + (VF == 1) ? CI->getType() : VectorType::get(CI->getType(), VF); - VectorParts &A = getVectorValue(it->getOperand(0)); + VectorParts &A = getVectorValue(CI->getOperand(0)); for (unsigned Part = 0; Part < UF; ++Part) Entry[Part] = Builder.CreateCast(CI->getOpcode(), A[Part], DestTy); - propagateMetadata(Entry, &*it); + addMetadata(Entry, &I); break; } case Instruction::Call: { // Ignore dbg intrinsics. - if (isa<DbgInfoIntrinsic>(it)) + if (isa<DbgInfoIntrinsic>(I)) break; - setDebugLocFromInst(Builder, &*it); + setDebugLocFromInst(Builder, &I); Module *M = BB->getParent()->getParent(); - CallInst *CI = cast<CallInst>(it); + auto *CI = cast<CallInst>(&I); StringRef FnName = CI->getCalledFunction()->getName(); Function *F = CI->getCalledFunction(); Type *RetTy = ToVectorTy(CI->getType(), VF); SmallVector<Type *, 4> Tys; - for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) - Tys.push_back(ToVectorTy(CI->getArgOperand(i)->getType(), VF)); - - Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); - if (ID && - (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || - ID == Intrinsic::lifetime_start)) { - scalarizeInstruction(&*it); + for (Value *ArgOperand : CI->arg_operands()) + Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); + + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || + ID == Intrinsic::lifetime_start)) { + scalarizeInstruction(&I); break; } // The flag shows whether we use Intrinsic or a usual Call for vectorized @@ -3906,7 +4236,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { bool UseVectorIntrinsic = ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; if (!UseVectorIntrinsic && NeedToScalarize) { - scalarizeInstruction(&*it); + scalarizeInstruction(&I); break; } @@ -3944,19 +4274,27 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { } } assert(VectorF && "Can't create vector function."); - Entry[Part] = Builder.CreateCall(VectorF, Args); + + SmallVector<OperandBundleDef, 1> OpBundles; + CI->getOperandBundlesAsDefs(OpBundles); + CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); + + if (isa<FPMathOperator>(V)) + V->copyFastMathFlags(CI); + + Entry[Part] = V; 
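The call-handling code above picks among three lowering strategies. A compact plain C++ sketch of the decision described in its comments (the cost inputs stand in for the TTI/TLI queries; names are illustrative):

enum class CallLowering { VectorIntrinsic, VectorLibCall, Scalarize };

CallLowering chooseCallLowering(bool HasVectorIntrinsicID,
                                unsigned IntrinsicCost, unsigned CallCost,
                                bool NeedToScalarize) {
  if (HasVectorIntrinsicID && IntrinsicCost <= CallCost)
    return CallLowering::VectorIntrinsic;  // widen to the intrinsic
  if (!NeedToScalarize)
    return CallLowering::VectorLibCall;    // a vector library version exists
  return CallLowering::Scalarize;          // fall back to VF scalar calls
}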
} - propagateMetadata(Entry, &*it); + addMetadata(Entry, &I); break; } default: // All other instructions are unsupported. Scalarize them. - scalarizeInstruction(&*it); + scalarizeInstruction(&I); break; - }// end of switch. - }// end of for_each instr. + } // end of switch. + } // end of for_each instr. } void InnerLoopVectorizer::updateAnalysis() { @@ -3967,16 +4305,11 @@ void InnerLoopVectorizer::updateAnalysis() { assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) && "Entry does not dominate exit."); - for (unsigned I = 1, E = LoopBypassBlocks.size(); I != E; ++I) - DT->addNewBlock(LoopBypassBlocks[I], LoopBypassBlocks[I-1]); - DT->addNewBlock(LoopVectorPreHeader, LoopBypassBlocks.back()); - // We don't predicate stores by this point, so the vector body should be a // single loop. - assert(LoopVectorBody.size() == 1 && "Expected single block loop!"); - DT->addNewBlock(LoopVectorBody[0], LoopVectorPreHeader); + DT->addNewBlock(LoopVectorBody, LoopVectorPreHeader); - DT->addNewBlock(LoopMiddleBlock, LoopVectorBody.back()); + DT->addNewBlock(LoopMiddleBlock, LoopVectorBody); DT->addNewBlock(LoopScalarPreHeader, LoopBypassBlocks[0]); DT->changeImmediateDominator(LoopScalarBody, LoopScalarPreHeader); DT->changeImmediateDominator(LoopExitBlock, LoopBypassBlocks[0]); @@ -3989,12 +4322,12 @@ void InnerLoopVectorizer::updateAnalysis() { /// Phi nodes with constant expressions that can trap are not safe to if /// convert. static bool canIfConvertPHINodes(BasicBlock *BB) { - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { - PHINode *Phi = dyn_cast<PHINode>(I); + for (Instruction &I : *BB) { + auto *Phi = dyn_cast<PHINode>(&I); if (!Phi) return true; - for (unsigned p = 0, e = Phi->getNumIncomingValues(); p != e; ++p) - if (Constant *C = dyn_cast<Constant>(Phi->getIncomingValue(p))) + for (Value *V : Phi->incoming_values()) + if (auto *C = dyn_cast<Constant>(V)) if (C->canTrap()) return false; } @@ -4013,27 +4346,21 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { SmallPtrSet<Value *, 8> SafePointes; // Collect safe addresses. - for (Loop::block_iterator BI = TheLoop->block_begin(), - BE = TheLoop->block_end(); BI != BE; ++BI) { - BasicBlock *BB = *BI; - + for (BasicBlock *BB : TheLoop->blocks()) { if (blockNeedsPredication(BB)) continue; - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { - if (LoadInst *LI = dyn_cast<LoadInst>(I)) + for (Instruction &I : *BB) { + if (auto *LI = dyn_cast<LoadInst>(&I)) SafePointes.insert(LI->getPointerOperand()); - else if (StoreInst *SI = dyn_cast<StoreInst>(I)) + else if (auto *SI = dyn_cast<StoreInst>(&I)) SafePointes.insert(SI->getPointerOperand()); } } // Collect the blocks that need predication. BasicBlock *Header = TheLoop->getHeader(); - for (Loop::block_iterator BI = TheLoop->block_begin(), - BE = TheLoop->block_end(); BI != BE; ++BI) { - BasicBlock *BB = *BI; - + for (BasicBlock *BB : TheLoop->blocks()) { // We don't support switch statements inside loops. if (!isa<BranchInst>(BB->getTerminator())) { emitAnalysis(VectorizationReport(BB->getTerminator()) @@ -4063,9 +4390,8 @@ bool LoopVectorizationLegality::canVectorize() { // We must have a loop in canonical form. Loops with indirectbr in them cannot // be canonicalized. 
if (!TheLoop->getLoopPreheader()) { - emitAnalysis( - VectorizationReport() << - "loop control flow is not understood by vectorizer"); + emitAnalysis(VectorizationReport() + << "loop control flow is not understood by vectorizer"); return false; } @@ -4077,17 +4403,15 @@ bool LoopVectorizationLegality::canVectorize() { // We must have a single backedge. if (TheLoop->getNumBackEdges() != 1) { - emitAnalysis( - VectorizationReport() << - "loop control flow is not understood by vectorizer"); + emitAnalysis(VectorizationReport() + << "loop control flow is not understood by vectorizer"); return false; } // We must have a single exiting block. if (!TheLoop->getExitingBlock()) { - emitAnalysis( - VectorizationReport() << - "loop control flow is not understood by vectorizer"); + emitAnalysis(VectorizationReport() + << "loop control flow is not understood by vectorizer"); return false; } @@ -4095,15 +4419,14 @@ bool LoopVectorizationLegality::canVectorize() { // checked at the end of each iteration. With that we can assume that all // instructions in the loop are executed the same number of times. if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { - emitAnalysis( - VectorizationReport() << - "loop control flow is not understood by vectorizer"); + emitAnalysis(VectorizationReport() + << "loop control flow is not understood by vectorizer"); return false; } // We need to have a loop header. - DEBUG(dbgs() << "LV: Found a loop: " << - TheLoop->getHeader()->getName() << '\n'); + DEBUG(dbgs() << "LV: Found a loop: " << TheLoop->getHeader()->getName() + << '\n'); // Check if we can if-convert non-single-bb loops. unsigned NumBlocks = TheLoop->getNumBlocks(); @@ -4113,7 +4436,7 @@ bool LoopVectorizationLegality::canVectorize() { } // ScalarEvolution needs to be able to find the exit count. - const SCEV *ExitCount = PSE.getSE()->getBackedgeTakenCount(TheLoop); + const SCEV *ExitCount = PSE.getBackedgeTakenCount(); if (ExitCount == PSE.getSE()->getCouldNotCompute()) { emitAnalysis(VectorizationReport() << "could not determine number of loop iterations"); @@ -4150,7 +4473,7 @@ bool LoopVectorizationLegality::canVectorize() { // Analyze interleaved memory accesses. if (UseInterleaved) - InterleaveInfo.analyzeInterleaving(Strides); + InterleaveInfo.analyzeInterleaving(*getSymbolicStrides()); unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) @@ -4182,7 +4505,7 @@ static Type *convertPointerToIntegerType(const DataLayout &DL, Type *Ty) { return Ty; } -static Type* getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { +static Type *getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { Ty0 = convertPointerToIntegerType(DL, Ty0); Ty1 = convertPointerToIntegerType(DL, Ty1); if (Ty0->getScalarSizeInBits() > Ty1->getScalarSizeInBits()) @@ -4193,11 +4516,11 @@ static Type* getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { /// \brief Check that the instruction has outside loop users and is not an /// identified reduction variable. static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, - SmallPtrSetImpl<Value *> &Reductions) { - // Reduction instructions are allowed to have exit users. All other - // instructions must not have external users. - if (!Reductions.count(Inst)) - //Check that all of the users of the loop are inside the BB. + SmallPtrSetImpl<Value *> &AllowedExit) { + // Reduction and Induction instructions are allowed to have exit users. All + // other instructions must not have external users. 
+ if (!AllowedExit.count(Inst)) + // Check that all of the users of the loop are inside the BB. for (User *U : Inst->users()) { Instruction *UI = cast<Instruction>(U); // This user may be a reduction exit value. @@ -4209,31 +4532,61 @@ static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, return false; } +void LoopVectorizationLegality::addInductionPhi( + PHINode *Phi, const InductionDescriptor &ID, + SmallPtrSetImpl<Value *> &AllowedExit) { + Inductions[Phi] = ID; + Type *PhiTy = Phi->getType(); + const DataLayout &DL = Phi->getModule()->getDataLayout(); + + // Get the widest type. + if (!WidestIndTy) + WidestIndTy = convertPointerToIntegerType(DL, PhiTy); + else + WidestIndTy = getWiderType(DL, PhiTy, WidestIndTy); + + // Int inductions are special because we only allow one IV. + if (ID.getKind() == InductionDescriptor::IK_IntInduction && + ID.getConstIntStepValue() && + ID.getConstIntStepValue()->isOne() && + isa<Constant>(ID.getStartValue()) && + cast<Constant>(ID.getStartValue())->isNullValue()) { + + // Use the phi node with the widest type as induction. Use the last + // one if there are multiple (no good reason for doing this other + // than it is expedient). We've checked that it begins at zero and + // steps by one, so this is a canonical induction variable. + if (!Induction || PhiTy == WidestIndTy) + Induction = Phi; + } + + // Both the PHI node itself, and the "post-increment" value feeding + // back into the PHI node may have external users. + AllowedExit.insert(Phi); + AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch())); + + DEBUG(dbgs() << "LV: Found an induction variable.\n"); + return; +} + bool LoopVectorizationLegality::canVectorizeInstrs() { BasicBlock *Header = TheLoop->getHeader(); // Look for the attribute signaling the absence of NaNs. Function &F = *Header->getParent(); - const DataLayout &DL = F.getParent()->getDataLayout(); - if (F.hasFnAttribute("no-nans-fp-math")) - HasFunNoNaNAttr = - F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; + HasFunNoNaNAttr = + F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; // For each block in the loop. - for (Loop::block_iterator bb = TheLoop->block_begin(), - be = TheLoop->block_end(); bb != be; ++bb) { - + for (BasicBlock *BB : TheLoop->blocks()) { // Scan the instructions in the block and look for hazards. - for (BasicBlock::iterator it = (*bb)->begin(), e = (*bb)->end(); it != e; - ++it) { - - if (PHINode *Phi = dyn_cast<PHINode>(it)) { + for (Instruction &I : *BB) { + if (auto *Phi = dyn_cast<PHINode>(&I)) { Type *PhiTy = Phi->getType(); // Check that this PHI type is allowed. - if (!PhiTy->isIntegerTy() && - !PhiTy->isFloatingPointTy() && + if (!PhiTy->isIntegerTy() && !PhiTy->isFloatingPointTy() && !PhiTy->isPointerTy()) { - emitAnalysis(VectorizationReport(&*it) + emitAnalysis(VectorizationReport(Phi) << "loop control flow is not understood by vectorizer"); DEBUG(dbgs() << "LV: Found an non-int non-pointer PHI.\n"); return false; @@ -4242,61 +4595,25 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // If this PHINode is not in the header block, then we know that we // can convert it to select during if-conversion. No need to check if // the PHIs in this block are induction or reduction variables. - if (*bb != Header) { + if (BB != Header) { // Check that this instruction has no outside users or is an // identified reduction value with an outside user. 
- if (!hasOutsideLoopUser(TheLoop, &*it, AllowedExit)) + if (!hasOutsideLoopUser(TheLoop, Phi, AllowedExit)) continue; - emitAnalysis(VectorizationReport(&*it) << - "value could not be identified as " - "an induction or reduction variable"); + emitAnalysis(VectorizationReport(Phi) + << "value could not be identified as " + "an induction or reduction variable"); return false; } // We only allow if-converted PHIs with exactly two incoming values. if (Phi->getNumIncomingValues() != 2) { - emitAnalysis(VectorizationReport(&*it) + emitAnalysis(VectorizationReport(Phi) << "control flow not understood by vectorizer"); DEBUG(dbgs() << "LV: Found an invalid PHI.\n"); return false; } - InductionDescriptor ID; - if (InductionDescriptor::isInductionPHI(Phi, PSE.getSE(), ID)) { - Inductions[Phi] = ID; - // Get the widest type. - if (!WidestIndTy) - WidestIndTy = convertPointerToIntegerType(DL, PhiTy); - else - WidestIndTy = getWiderType(DL, PhiTy, WidestIndTy); - - // Int inductions are special because we only allow one IV. - if (ID.getKind() == InductionDescriptor::IK_IntInduction && - ID.getStepValue()->isOne() && - isa<Constant>(ID.getStartValue()) && - cast<Constant>(ID.getStartValue())->isNullValue()) { - // Use the phi node with the widest type as induction. Use the last - // one if there are multiple (no good reason for doing this other - // than it is expedient). We've checked that it begins at zero and - // steps by one, so this is a canonical induction variable. - if (!Induction || PhiTy == WidestIndTy) - Induction = Phi; - } - - DEBUG(dbgs() << "LV: Found an induction variable.\n"); - - // Until we explicitly handle the case of an induction variable with - // an outside loop user we have to give up vectorizing this loop. - if (hasOutsideLoopUser(TheLoop, &*it, AllowedExit)) { - emitAnalysis(VectorizationReport(&*it) << - "use of induction value outside of the " - "loop is not handled by vectorizer"); - return false; - } - - continue; - } - RecurrenceDescriptor RedDes; if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, RedDes)) { if (RedDes.hasUnsafeAlgebra()) @@ -4306,22 +4623,41 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { continue; } - emitAnalysis(VectorizationReport(&*it) << - "value that could not be identified as " - "reduction is used outside the loop"); - DEBUG(dbgs() << "LV: Found an unidentified PHI."<< *Phi <<"\n"); + InductionDescriptor ID; + if (InductionDescriptor::isInductionPHI(Phi, PSE, ID)) { + addInductionPhi(Phi, ID, AllowedExit); + continue; + } + + if (RecurrenceDescriptor::isFirstOrderRecurrence(Phi, TheLoop, DT)) { + FirstOrderRecurrences.insert(Phi); + continue; + } + + // As a last resort, coerce the PHI to a AddRec expression + // and re-try classifying it a an induction PHI. + if (InductionDescriptor::isInductionPHI(Phi, PSE, ID, true)) { + addInductionPhi(Phi, ID, AllowedExit); + continue; + } + + emitAnalysis(VectorizationReport(Phi) + << "value that could not be identified as " + "reduction is used outside the loop"); + DEBUG(dbgs() << "LV: Found an unidentified PHI." << *Phi << "\n"); return false; - }// end of PHI handling + } // end of PHI handling // We handle calls that: // * Are debug info intrinsics. // * Have a mapping to an IR intrinsic. // * Have a vector version available. 
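Aside: besides reductions and inductions, the new PHI classification above also records first-order recurrences, i.e. header PHIs that carry a value from the previous iteration without being either of the former. A minimal hand-written example of such a loop (illustrative only, not from the patch):

// 'prev' is a first-order recurrence: each iteration uses the value produced
// in the previous iteration, but the PHI is neither an induction nor a
// reduction. With the change above such PHIs are collected in
// FirstOrderRecurrences instead of causing the loop to be rejected.
void diff_adjacent(int *out, const int *in, int n) {
  int prev = 0;
  for (int i = 0; i < n; ++i) {
    out[i] = in[i] - prev; // uses the value from iteration i-1
    prev = in[i];          // feeds the next iteration
  }
}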
- CallInst *CI = dyn_cast<CallInst>(it); - if (CI && !getIntrinsicIDForCall(CI, TLI) && !isa<DbgInfoIntrinsic>(CI) && + auto *CI = dyn_cast<CallInst>(&I); + if (CI && !getVectorIntrinsicIDForCall(CI, TLI) && + !isa<DbgInfoIntrinsic>(CI) && !(CI->getCalledFunction() && TLI && TLI->isFunctionVectorizable(CI->getCalledFunction()->getName()))) { - emitAnalysis(VectorizationReport(&*it) + emitAnalysis(VectorizationReport(CI) << "call instruction cannot be vectorized"); DEBUG(dbgs() << "LV: Found a non-intrinsic, non-libfunc callsite.\n"); return false; @@ -4329,11 +4665,11 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Intrinsics such as powi,cttz and ctlz are legal to vectorize if the // second argument is the same (i.e. loop invariant) - if (CI && - hasVectorInstrinsicScalarOpd(getIntrinsicIDForCall(CI, TLI), 1)) { + if (CI && hasVectorInstrinsicScalarOpd( + getVectorIntrinsicIDForCall(CI, TLI), 1)) { auto *SE = PSE.getSE(); if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) { - emitAnalysis(VectorizationReport(&*it) + emitAnalysis(VectorizationReport(CI) << "intrinsic instruction cannot be vectorized"); DEBUG(dbgs() << "LV: Found unvectorizable intrinsic " << *CI << "\n"); return false; @@ -4342,40 +4678,44 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Check that the instruction return type is vectorizable. // Also, we can't vectorize extractelement instructions. - if ((!VectorType::isValidElementType(it->getType()) && - !it->getType()->isVoidTy()) || isa<ExtractElementInst>(it)) { - emitAnalysis(VectorizationReport(&*it) + if ((!VectorType::isValidElementType(I.getType()) && + !I.getType()->isVoidTy()) || + isa<ExtractElementInst>(I)) { + emitAnalysis(VectorizationReport(&I) << "instruction return type cannot be vectorized"); DEBUG(dbgs() << "LV: Found unvectorizable type.\n"); return false; } // Check that the stored type is vectorizable. - if (StoreInst *ST = dyn_cast<StoreInst>(it)) { + if (auto *ST = dyn_cast<StoreInst>(&I)) { Type *T = ST->getValueOperand()->getType(); if (!VectorType::isValidElementType(T)) { - emitAnalysis(VectorizationReport(ST) << - "store instruction cannot be vectorized"); + emitAnalysis(VectorizationReport(ST) + << "store instruction cannot be vectorized"); return false; } - if (EnableMemAccessVersioning) - collectStridedAccess(ST); - } - if (EnableMemAccessVersioning) - if (LoadInst *LI = dyn_cast<LoadInst>(it)) - collectStridedAccess(LI); + // FP instructions can allow unsafe algebra, thus vectorizable by + // non-IEEE-754 compliant SIMD units. + // This applies to floating-point math operations and calls, not memory + // operations, shuffles, or casts, as they don't change precision or + // semantics. + } else if (I.getType()->isFloatingPointTy() && (CI || I.isBinaryOp()) && + !I.hasUnsafeAlgebra()) { + DEBUG(dbgs() << "LV: Found FP op with unsafe algebra.\n"); + Hints->setPotentiallyUnsafe(); + } // Reduction instructions are allowed to have exit users. // All other instructions must not have external users. - if (hasOutsideLoopUser(TheLoop, &*it, AllowedExit)) { - emitAnalysis(VectorizationReport(&*it) << - "value cannot be used outside the loop"); + if (hasOutsideLoopUser(TheLoop, &I, AllowedExit)) { + emitAnalysis(VectorizationReport(&I) + << "value cannot be used outside the loop"); return false; } } // next instr. 
- } if (!Induction) { @@ -4396,64 +4736,90 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { return true; } -void LoopVectorizationLegality::collectStridedAccess(Value *MemAccess) { - Value *Ptr = nullptr; - if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess)) - Ptr = LI->getPointerOperand(); - else if (StoreInst *SI = dyn_cast<StoreInst>(MemAccess)) - Ptr = SI->getPointerOperand(); - else - return; - - Value *Stride = getStrideFromPointer(Ptr, PSE.getSE(), TheLoop); - if (!Stride) - return; - - DEBUG(dbgs() << "LV: Found a strided access that we can version"); - DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n"); - Strides[Ptr] = Stride; - StrideSet.insert(Stride); -} - void LoopVectorizationLegality::collectLoopUniforms() { // We now know that the loop is vectorizable! // Collect variables that will remain uniform after vectorization. - std::vector<Value*> Worklist; - BasicBlock *Latch = TheLoop->getLoopLatch(); - // Start with the conditional branch and walk up the block. - Worklist.push_back(Latch->getTerminator()->getOperand(0)); + // If V is not an instruction inside the current loop, it is a Value + // outside of the scope which we are interesting in. + auto isOutOfScope = [&](Value *V) -> bool { + Instruction *I = dyn_cast<Instruction>(V); + return (!I || !TheLoop->contains(I)); + }; + + SetVector<Instruction *> Worklist; + BasicBlock *Latch = TheLoop->getLoopLatch(); + // Start with the conditional branch. + if (!isOutOfScope(Latch->getTerminator()->getOperand(0))) { + Instruction *Cmp = cast<Instruction>(Latch->getTerminator()->getOperand(0)); + Worklist.insert(Cmp); + DEBUG(dbgs() << "LV: Found uniform instruction: " << *Cmp << "\n"); + } // Also add all consecutive pointer values; these values will be uniform - // after vectorization (and subsequent cleanup) and, until revectorization is - // supported, all dependencies must also be uniform. - for (Loop::block_iterator B = TheLoop->block_begin(), - BE = TheLoop->block_end(); B != BE; ++B) - for (BasicBlock::iterator I = (*B)->begin(), IE = (*B)->end(); - I != IE; ++I) - if (I->getType()->isPointerTy() && isConsecutivePtr(&*I)) - Worklist.insert(Worklist.end(), I->op_begin(), I->op_end()); - - while (!Worklist.empty()) { - Instruction *I = dyn_cast<Instruction>(Worklist.back()); - Worklist.pop_back(); - - // Look at instructions inside this loop. - // Stop when reaching PHI nodes. - // TODO: we need to follow values all over the loop, not only in this block. - if (!I || !TheLoop->contains(I) || isa<PHINode>(I)) - continue; + // after vectorization (and subsequent cleanup). + for (auto *BB : TheLoop->blocks()) { + for (auto &I : *BB) { + if (I.getType()->isPointerTy() && isConsecutivePtr(&I)) { + Worklist.insert(&I); + DEBUG(dbgs() << "LV: Found uniform instruction: " << I << "\n"); + } + } + } - // This is a known uniform. - Uniforms.insert(I); + // Expand Worklist in topological order: whenever a new instruction + // is added , its users should be either already inside Worklist, or + // out of scope. It ensures a uniform instruction will only be used + // by uniform instructions or out of scope instructions. + unsigned idx = 0; + do { + Instruction *I = Worklist[idx++]; - // Insert all operands. 
- Worklist.insert(Worklist.end(), I->op_begin(), I->op_end()); + for (auto OV : I->operand_values()) { + if (isOutOfScope(OV)) + continue; + auto *OI = cast<Instruction>(OV); + if (all_of(OI->users(), [&](User *U) -> bool { + return isOutOfScope(U) || Worklist.count(cast<Instruction>(U)); + })) { + Worklist.insert(OI); + DEBUG(dbgs() << "LV: Found uniform instruction: " << *OI << "\n"); + } + } + } while (idx != Worklist.size()); + + // For an instruction to be added into Worklist above, all its users inside + // the current loop should be already added into Worklist. This condition + // cannot be true for phi instructions which is always in a dependence loop. + // Because any instruction in the dependence cycle always depends on others + // in the cycle to be added into Worklist first, the result is no ones in + // the cycle will be added into Worklist in the end. + // That is why we process PHI separately. + for (auto &Induction : *getInductionVars()) { + auto *PN = Induction.first; + auto *UpdateV = PN->getIncomingValueForBlock(TheLoop->getLoopLatch()); + if (all_of(PN->users(), + [&](User *U) -> bool { + return U == UpdateV || isOutOfScope(U) || + Worklist.count(cast<Instruction>(U)); + }) && + all_of(UpdateV->users(), [&](User *U) -> bool { + return U == PN || isOutOfScope(U) || + Worklist.count(cast<Instruction>(U)); + })) { + Worklist.insert(cast<Instruction>(PN)); + Worklist.insert(cast<Instruction>(UpdateV)); + DEBUG(dbgs() << "LV: Found uniform instruction: " << *PN << "\n"); + DEBUG(dbgs() << "LV: Found uniform instruction: " << *UpdateV << "\n"); + } } + + Uniforms.insert(Worklist.begin(), Worklist.end()); } bool LoopVectorizationLegality::canVectorizeMemory() { - LAI = &LAA->getInfo(TheLoop, Strides); + LAI = &(*GetLAA)(*TheLoop); + InterleaveInfo.setLAI(LAI); auto &OptionalReport = LAI->getReport(); if (OptionalReport) emitAnalysis(VectorizationReport(*OptionalReport)); @@ -4469,13 +4835,13 @@ bool LoopVectorizationLegality::canVectorizeMemory() { } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); - PSE.addPredicate(LAI->PSE.getUnionPredicate()); + PSE.addPredicate(LAI->getPSE().getUnionPredicate()); return true; } bool LoopVectorizationLegality::isInductionVariable(const Value *V) { - Value *In0 = const_cast<Value*>(V); + Value *In0 = const_cast<Value *>(V); PHINode *PN = dyn_cast_or_null<PHINode>(In0); if (!PN) return false; @@ -4483,67 +4849,73 @@ bool LoopVectorizationLegality::isInductionVariable(const Value *V) { return Inductions.count(PN); } -bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) { +bool LoopVectorizationLegality::isFirstOrderRecurrence(const PHINode *Phi) { + return FirstOrderRecurrences.count(Phi); +} + +bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) { return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); } -bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB, - SmallPtrSetImpl<Value *> &SafePtrs) { - - for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { +bool LoopVectorizationLegality::blockCanBePredicated( + BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs) { + const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel(); + + for (Instruction &I : *BB) { // Check that we don't have a constant expression that can trap as operand. 
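Aside: the rewritten collectLoopUniforms above grows the uniform set to a fixed point: an operand joins only once every in-loop user of it is already uniform or out of scope, and loop-header PHIs are handled separately because they sit on a cycle. A standalone sketch of that closure computation over a toy use-def graph (toy types and indices, not the LLVM classes; PHIs are not modelled):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <set>
#include <vector>

// Toy stand-in for an in-loop instruction: 'operands' and 'users' are indices
// into the same array; -1 marks a value defined outside the loop.
struct ToyInstr {
  std::vector<int> operands;
  std::vector<int> users;
};

// Start from the seeds and add an operand only when all of its users are
// already uniform (mirroring the SetVector expansion above).
std::set<int> collectUniforms(const std::vector<ToyInstr> &Loop,
                              std::vector<int> Seeds) {
  std::vector<int> Worklist = Seeds;
  std::set<int> InList(Seeds.begin(), Seeds.end());
  for (std::size_t Idx = 0; Idx < Worklist.size(); ++Idx) {
    for (int Op : Loop[Worklist[Idx]].operands) {
      if (Op < 0 || InList.count(Op)) // out of scope or already taken
        continue;
      const std::vector<int> &Users = Loop[Op].users;
      bool AllUsersUniform =
          std::all_of(Users.begin(), Users.end(),
                      [&](int U) { return U < 0 || InList.count(U); });
      if (AllUsersUniform) {
        InList.insert(Op);
        Worklist.push_back(Op);
      }
    }
  }
  return InList;
}

int main() {
  // 0: latch compare (seed), operand is 1
  // 1: add, used only by 0            -> becomes uniform
  // 2: load, used by 1 and by 3       -> stays non-uniform (3 is not uniform)
  // 3: fmul, no users in this example
  std::vector<ToyInstr> Loop = {{{1}, {}}, {{2}, {0}}, {{-1}, {1, 3}}, {{2}, {}}};
  for (int I : collectUniforms(Loop, {0}))
    std::printf("uniform: %d\n", I); // prints 0 and 1
  return 0;
}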
- for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end(); - OI != OE; ++OI) { - if (Constant *C = dyn_cast<Constant>(*OI)) + for (Value *Operand : I.operands()) { + if (auto *C = dyn_cast<Constant>(Operand)) if (C->canTrap()) return false; } // We might be able to hoist the load. - if (it->mayReadFromMemory()) { - LoadInst *LI = dyn_cast<LoadInst>(it); + if (I.mayReadFromMemory()) { + auto *LI = dyn_cast<LoadInst>(&I); if (!LI) return false; if (!SafePtrs.count(LI->getPointerOperand())) { - if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand())) { + if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand()) || + isLegalMaskedGather(LI->getType())) { MaskedOp.insert(LI); continue; } + // !llvm.mem.parallel_loop_access implies if-conversion safety. + if (IsAnnotatedParallel) + continue; return false; } } // We don't predicate stores at the moment. - if (it->mayWriteToMemory()) { - StoreInst *SI = dyn_cast<StoreInst>(it); + if (I.mayWriteToMemory()) { + auto *SI = dyn_cast<StoreInst>(&I); // We only support predication of stores in basic blocks with one // predecessor. if (!SI) return false; + // Build a masked store if it is legal for the target. + if (isLegalMaskedStore(SI->getValueOperand()->getType(), + SI->getPointerOperand()) || + isLegalMaskedScatter(SI->getValueOperand()->getType())) { + MaskedOp.insert(SI); + continue; + } + bool isSafePtr = (SafePtrs.count(SI->getPointerOperand()) != 0); bool isSinglePredecessor = SI->getParent()->getSinglePredecessor(); - + if (++NumPredStores > NumberOfStoresToPredicate || !isSafePtr || - !isSinglePredecessor) { - // Build a masked store if it is legal for the target, otherwise - // scalarize the block. - bool isLegalMaskedOp = - isLegalMaskedStore(SI->getValueOperand()->getType(), - SI->getPointerOperand()); - if (isLegalMaskedOp) { - --NumPredStores; - MaskedOp.insert(SI); - continue; - } + !isSinglePredecessor) return false; - } } - if (it->mayThrow()) + if (I.mayThrow()) return false; // The instructions below can trap. - switch (it->getOpcode()) { - default: continue; + switch (I.getOpcode()) { + default: + continue; case Instruction::UDiv: case Instruction::SDiv: case Instruction::URem: @@ -4555,199 +4927,273 @@ bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB, return true; } -void InterleavedAccessInfo::collectConstStridedAccesses( - MapVector<Instruction *, StrideDescriptor> &StrideAccesses, +void InterleavedAccessInfo::collectConstStrideAccesses( + MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, const ValueToValueMap &Strides) { - // Holds load/store instructions in program order. - SmallVector<Instruction *, 16> AccessList; - for (auto *BB : TheLoop->getBlocks()) { - bool IsPred = LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); + auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + // Since it's desired that the load/store instructions be maintained in + // "program order" for the interleaved access analysis, we have to visit the + // blocks in the loop in reverse postorder (i.e., in a topological order). + // Such an ordering will ensure that any load/store that may be executed + // before a second load/store will precede the second load/store in + // AccessStrideInfo. 
+ LoopBlocksDFS DFS(TheLoop); + DFS.perform(LI); + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) for (auto &I : *BB) { - if (!isa<LoadInst>(&I) && !isa<StoreInst>(&I)) + auto *LI = dyn_cast<LoadInst>(&I); + auto *SI = dyn_cast<StoreInst>(&I); + if (!LI && !SI) continue; - // FIXME: Currently we can't handle mixed accesses and predicated accesses - if (IsPred) - return; - - AccessList.push_back(&I); - } - } - - if (AccessList.empty()) - return; - - auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); - for (auto I : AccessList) { - LoadInst *LI = dyn_cast<LoadInst>(I); - StoreInst *SI = dyn_cast<StoreInst>(I); - - Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); - int Stride = isStridedPtr(PSE, Ptr, TheLoop, Strides); - - // The factor of the corresponding interleave group. - unsigned Factor = std::abs(Stride); - // Ignore the access if the factor is too small or too large. - if (Factor < 2 || Factor > MaxInterleaveGroupFactor) - continue; + Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); + int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides); - const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); - PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType()); - unsigned Size = DL.getTypeAllocSize(PtrTy->getElementType()); + const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); + PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType()); + uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); - // An alignment of 0 means target ABI alignment. - unsigned Align = LI ? LI->getAlignment() : SI->getAlignment(); - if (!Align) - Align = DL.getABITypeAlignment(PtrTy->getElementType()); + // An alignment of 0 means target ABI alignment. + unsigned Align = LI ? LI->getAlignment() : SI->getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(PtrTy->getElementType()); - StrideAccesses[I] = StrideDescriptor(Stride, Scev, Size, Align); - } + AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size, Align); + } } -// Analyze interleaved accesses and collect them into interleave groups. +// Analyze interleaved accesses and collect them into interleaved load and +// store groups. +// +// When generating code for an interleaved load group, we effectively hoist all +// loads in the group to the location of the first load in program order. When +// generating code for an interleaved store group, we sink all stores to the +// location of the last store. This code motion can change the order of load +// and store instructions and may break dependences. +// +// The code generation strategy mentioned above ensures that we won't violate +// any write-after-read (WAR) dependences. // -// Notice that the vectorization on interleaved groups will change instruction -// orders and may break dependences. But the memory dependence check guarantees -// that there is no overlap between two pointers of different strides, element -// sizes or underlying bases. +// E.g., for the WAR dependence: a = A[i]; // (1) +// A[i] = b; // (2) // -// For pointers sharing the same stride, element size and underlying base, no -// need to worry about Read-After-Write dependences and Write-After-Read +// The store group of (2) is always inserted at or below (2), and the load +// group of (1) is always inserted at or above (1). Thus, the instructions will +// never be reordered. All other dependences are checked to ensure the +// correctness of the instruction reordering. 
+// +// The algorithm visits all memory accesses in the loop in bottom-up program +// order. Program order is established by traversing the blocks in the loop in +// reverse postorder when collecting the accesses. +// +// We visit the memory accesses in bottom-up order because it can simplify the +// construction of store groups in the presence of write-after-write (WAW) // dependences. // -// E.g. The RAW dependence: A[i] = a; -// b = A[i]; -// This won't exist as it is a store-load forwarding conflict, which has -// already been checked and forbidden in the dependence check. +// E.g., for the WAW dependence: A[i] = a; // (1) +// A[i] = b; // (2) +// A[i + 1] = c; // (3) // -// E.g. The WAR dependence: a = A[i]; // (1) -// A[i] = b; // (2) -// The store group of (2) is always inserted at or below (2), and the load group -// of (1) is always inserted at or above (1). The dependence is safe. +// We will first create a store group with (3) and (2). (1) can't be added to +// this group because it and (2) are dependent. However, (1) can be grouped +// with other accesses that may precede it in program order. Note that a +// bottom-up order does not imply that WAW dependences should not be checked. void InterleavedAccessInfo::analyzeInterleaving( const ValueToValueMap &Strides) { DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); - // Holds all the stride accesses. - MapVector<Instruction *, StrideDescriptor> StrideAccesses; - collectConstStridedAccesses(StrideAccesses, Strides); + // Holds all accesses with a constant stride. + MapVector<Instruction *, StrideDescriptor> AccessStrideInfo; + collectConstStrideAccesses(AccessStrideInfo, Strides); - if (StrideAccesses.empty()) + if (AccessStrideInfo.empty()) return; + // Collect the dependences in the loop. + collectDependences(); + // Holds all interleaved store groups temporarily. SmallSetVector<InterleaveGroup *, 4> StoreGroups; // Holds all interleaved load groups temporarily. SmallSetVector<InterleaveGroup *, 4> LoadGroups; - // Search the load-load/write-write pair B-A in bottom-up order and try to - // insert B into the interleave group of A according to 3 rules: - // 1. A and B have the same stride. - // 2. A and B have the same memory object size. - // 3. B belongs to the group according to the distance. + // Search in bottom-up program order for pairs of accesses (A and B) that can + // form interleaved load or store groups. In the algorithm below, access A + // precedes access B in program order. We initialize a group for B in the + // outer loop of the algorithm, and then in the inner loop, we attempt to + // insert each A into B's group if: + // + // 1. A and B have the same stride, + // 2. A and B have the same memory object size, and + // 3. A belongs in B's group according to its distance from B. // - // The bottom-up order can avoid breaking the Write-After-Write dependences - // between two pointers of the same base. - // E.g. A[i] = a; (1) - // A[i] = b; (2) - // A[i+1] = c (3) - // We form the group (2)+(3) in front, so (1) has to form groups with accesses - // above (1), which guarantees that (1) is always above (2). 
- for (auto I = StrideAccesses.rbegin(), E = StrideAccesses.rend(); I != E; - ++I) { - Instruction *A = I->first; - StrideDescriptor DesA = I->second; - - InterleaveGroup *Group = getInterleaveGroup(A); - if (!Group) { - DEBUG(dbgs() << "LV: Creating an interleave group with:" << *A << '\n'); - Group = createInterleaveGroup(A, DesA.Stride, DesA.Align); + // Special care is taken to ensure group formation will not break any + // dependences. + for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend(); + BI != E; ++BI) { + Instruction *B = BI->first; + StrideDescriptor DesB = BI->second; + + // Initialize a group for B if it has an allowable stride. Even if we don't + // create a group for B, we continue with the bottom-up algorithm to ensure + // we don't break any of B's dependences. + InterleaveGroup *Group = nullptr; + if (isStrided(DesB.Stride)) { + Group = getInterleaveGroup(B); + if (!Group) { + DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B << '\n'); + Group = createInterleaveGroup(B, DesB.Stride, DesB.Align); + } + if (B->mayWriteToMemory()) + StoreGroups.insert(Group); + else + LoadGroups.insert(Group); } - if (A->mayWriteToMemory()) - StoreGroups.insert(Group); - else - LoadGroups.insert(Group); + for (auto AI = std::next(BI); AI != E; ++AI) { + Instruction *A = AI->first; + StrideDescriptor DesA = AI->second; + + // Our code motion strategy implies that we can't have dependences + // between accesses in an interleaved group and other accesses located + // between the first and last member of the group. Note that this also + // means that a group can't have more than one member at a given offset. + // The accesses in a group can have dependences with other accesses, but + // we must ensure we don't extend the boundaries of the group such that + // we encompass those dependent accesses. + // + // For example, assume we have the sequence of accesses shown below in a + // stride-2 loop: + // + // (1, 2) is a group | A[i] = a; // (1) + // | A[i-1] = b; // (2) | + // A[i-3] = c; // (3) + // A[i] = d; // (4) | (2, 4) is not a group + // + // Because accesses (2) and (3) are dependent, we can group (2) with (1) + // but not with (4). If we did, the dependent access (3) would be within + // the boundaries of the (2, 4) group. + if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) { + + // If a dependence exists and A is already in a group, we know that A + // must be a store since A precedes B and WAR dependences are allowed. + // Thus, A would be sunk below B. We release A's group to prevent this + // illegal code motion. A will then be free to form another group with + // instructions that precede it. + if (isInterleaved(A)) { + InterleaveGroup *StoreGroup = getInterleaveGroup(A); + StoreGroups.remove(StoreGroup); + releaseGroup(StoreGroup); + } - for (auto II = std::next(I); II != E; ++II) { - Instruction *B = II->first; - StrideDescriptor DesB = II->second; + // If a dependence exists and A is not already in a group (or it was + // and we just released it), B might be hoisted above A (if B is a + // load) or another store might be sunk below A (if B is a store). In + // either case, we can't add additional instructions to B's group. B + // will only form a group with instructions that it precedes. + break; + } - // Ignore if B is already in a group or B is a different memory operation. - if (isInterleaved(B) || A->mayReadFromMemory() != B->mayReadFromMemory()) + // At this point, we've checked for illegal code motion. 
If either A or B + // isn't strided, there's nothing left to do. + if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride)) continue; - // Check the rule 1 and 2. - if (DesB.Stride != DesA.Stride || DesB.Size != DesA.Size) + // Ignore A if it's already in a group or isn't the same kind of memory + // operation as B. + if (isInterleaved(A) || A->mayReadFromMemory() != B->mayReadFromMemory()) continue; - // Calculate the distance and prepare for the rule 3. - const SCEVConstant *DistToA = dyn_cast<SCEVConstant>( - PSE.getSE()->getMinusSCEV(DesB.Scev, DesA.Scev)); - if (!DistToA) + // Check rules 1 and 2. Ignore A if its stride or size is different from + // that of B. + if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size) continue; - int DistanceToA = DistToA->getAPInt().getSExtValue(); + // Calculate the distance from A to B. + const SCEVConstant *DistToB = dyn_cast<SCEVConstant>( + PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev)); + if (!DistToB) + continue; + int64_t DistanceToB = DistToB->getAPInt().getSExtValue(); + + // Check rule 3. Ignore A if its distance to B is not a multiple of the + // size. + if (DistanceToB % static_cast<int64_t>(DesB.Size)) + continue; - // Skip if the distance is not multiple of size as they are not in the - // same group. - if (DistanceToA % static_cast<int>(DesA.Size)) + // Ignore A if either A or B is in a predicated block. Although we + // currently prevent group formation for predicated accesses, we may be + // able to relax this limitation in the future once we handle more + // complicated blocks. + if (isPredicated(A->getParent()) || isPredicated(B->getParent())) continue; - // The index of B is the index of A plus the related index to A. - int IndexB = - Group->getIndex(A) + DistanceToA / static_cast<int>(DesA.Size); + // The index of A is the index of B plus A's distance to B in multiples + // of the size. + int IndexA = + Group->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size); - // Try to insert B into the group. - if (Group->insertMember(B, IndexB, DesB.Align)) { - DEBUG(dbgs() << "LV: Inserted:" << *B << '\n' - << " into the interleave group with" << *A << '\n'); - InterleaveGroupMap[B] = Group; + // Try to insert A into B's group. + if (Group->insertMember(A, IndexA, DesA.Align)) { + DEBUG(dbgs() << "LV: Inserted:" << *A << '\n' + << " into the interleave group with" << *B << '\n'); + InterleaveGroupMap[A] = Group; // Set the first load in program order as the insert position. - if (B->mayReadFromMemory()) - Group->setInsertPos(B); + if (A->mayReadFromMemory()) + Group->setInsertPos(A); } - } // Iteration on instruction B - } // Iteration on instruction A + } // Iteration over A accesses. + } // Iteration over B accesses. // Remove interleaved store groups with gaps. for (InterleaveGroup *Group : StoreGroups) if (Group->getNumMembers() != Group->getFactor()) releaseGroup(Group); - // Remove interleaved load groups that don't have the first and last member. - // This guarantees that we won't do speculative out of bounds loads. + // If there is a non-reversed interleaved load group with gaps, we will need + // to execute at least one scalar epilogue iteration. This will ensure that + // we don't speculatively access memory out-of-bounds. Note that we only need + // to look for a member at index factor - 1, since every group must have a + // member at index zero. 
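Aside: the membership test in the rewritten grouping loop reduces to distance arithmetic. A can join B's group only when the strides and element sizes match and the constant distance from A to B is a whole number of elements, in which case A's index is B's index plus that multiple. A standalone restatement with toy descriptors (constant byte offsets stand in for the SCEV difference):

#include <cstdint>
#include <cstdio>

// Toy access descriptor; the real code derives the distance from SCEV.
struct AccessDesc {
  int64_t Stride;     // stride in elements between iterations
  uint64_t Size;      // element size in bytes
  int64_t ByteOffset; // constant byte offset of the access within the pattern
};

// Mirror of rules 1-3 above: same stride, same size, and a distance that is a
// whole number of elements. On success *IndexA receives A's slot relative to
// B's IndexB.
static bool computeMemberIndex(const AccessDesc &A, const AccessDesc &B,
                               int IndexB, int *IndexA) {
  if (A.Stride != B.Stride || A.Size != B.Size)        // rules 1 and 2
    return false;
  int64_t DistanceToB = A.ByteOffset - B.ByteOffset;   // DesA.Scev - DesB.Scev
  if (DistanceToB % static_cast<int64_t>(B.Size) != 0) // rule 3
    return false;
  *IndexA =
      IndexB + static_cast<int>(DistanceToB / static_cast<int64_t>(B.Size));
  return true;
}

int main() {
  // Stride-2 loop over 32-bit elements: B reads A[2*i], A reads A[2*i + 1].
  AccessDesc B = {2, 4, 0}, A = {2, 4, 4};
  int IndexA;
  if (computeMemberIndex(A, B, /*IndexB=*/0, &IndexA))
    std::printf("A is member %d of B's group\n", IndexA); // prints 1
  return 0;
}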
for (InterleaveGroup *Group : LoadGroups) - if (!Group->getMember(0) || !Group->getMember(Group->getFactor() - 1)) - releaseGroup(Group); + if (!Group->getMember(Group->getFactor() - 1)) { + if (Group->isReverse()) { + releaseGroup(Group); + } else { + DEBUG(dbgs() << "LV: Interleaved group requires epilogue iteration.\n"); + RequiresScalarEpilogue = true; + } + } } LoopVectorizationCostModel::VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { // Width 1 means no vectorize - VectorizationFactor Factor = { 1U, 0U }; + VectorizationFactor Factor = {1U, 0U}; if (OptForSize && Legal->getRuntimePointerChecking()->Need) { - emitAnalysis(VectorizationReport() << - "runtime pointer checks needed. Enable vectorization of this " - "loop with '#pragma clang loop vectorize(enable)' when " - "compiling with -Os/-Oz"); - DEBUG(dbgs() << - "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n"); + emitAnalysis( + VectorizationReport() + << "runtime pointer checks needed. Enable vectorization of this " + "loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz"); + DEBUG(dbgs() + << "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n"); return Factor; } if (!EnableCondStoresVectorization && Legal->getNumPredStores()) { - emitAnalysis(VectorizationReport() << - "store that is conditionally executed prevents vectorization"); + emitAnalysis( + VectorizationReport() + << "store that is conditionally executed prevents vectorization"); DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n"); return Factor; } // Find the trip count. - unsigned TC = SE->getSmallConstantTripCount(TheLoop); + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); @@ -4755,16 +5201,25 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); unsigned WidestRegister = TTI.getRegisterBitWidth(true); unsigned MaxSafeDepDist = -1U; + + // Get the maximum safe dependence distance in bits computed by LAA. If the + // loop contains any interleaved accesses, we divide the dependence distance + // by the maximum interleave factor of all interleaved groups. Note that + // although the division ensures correctness, this is a fairly conservative + // computation because the maximum distance computed by LAA may not involve + // any of the interleaved accesses. if (Legal->getMaxSafeDepDistBytes() != -1U) - MaxSafeDepDist = Legal->getMaxSafeDepDistBytes() * 8; - WidestRegister = ((WidestRegister < MaxSafeDepDist) ? - WidestRegister : MaxSafeDepDist); + MaxSafeDepDist = + Legal->getMaxSafeDepDistBytes() * 8 / Legal->getMaxInterleaveFactor(); + + WidestRegister = + ((WidestRegister < MaxSafeDepDist) ? 
WidestRegister : MaxSafeDepDist); unsigned MaxVectorSize = WidestRegister / WidestType; DEBUG(dbgs() << "LV: The Smallest and Widest types: " << SmallestType << " / " << WidestType << " bits.\n"); - DEBUG(dbgs() << "LV: The Widest register is: " - << WidestRegister << " bits.\n"); + DEBUG(dbgs() << "LV: The Widest register is: " << WidestRegister + << " bits.\n"); if (MaxVectorSize == 0) { DEBUG(dbgs() << "LV: The target has no vector registers.\n"); @@ -4772,7 +5227,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { } assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" - " into one vector!"); + " into one vector!"); unsigned VF = MaxVectorSize; if (MaximizeBandwidth && !OptForSize) { @@ -4800,9 +5255,9 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { if (OptForSize) { // If we are unable to calculate the trip count then don't try to vectorize. if (TC < 2) { - emitAnalysis - (VectorizationReport() << - "unable to calculate the loop count due to complex control flow"); + emitAnalysis( + VectorizationReport() + << "unable to calculate the loop count due to complex control flow"); DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); return Factor; } @@ -4815,11 +5270,11 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { else { // If the trip count that we found modulo the vectorization factor is not // zero then we require a tail. - emitAnalysis(VectorizationReport() << - "cannot optimize for size and vectorize at the " - "same time. Enable vectorization of this loop " - "with '#pragma clang loop vectorize(enable)' " - "when compiling with -Os/-Oz"); + emitAnalysis(VectorizationReport() + << "cannot optimize for size and vectorize at the " + "same time. Enable vectorization of this loop " + "with '#pragma clang loop vectorize(enable)' " + "when compiling with -Os/-Oz"); DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); return Factor; } @@ -4834,7 +5289,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { return Factor; } - float Cost = expectedCost(1); + float Cost = expectedCost(1).first; #ifndef NDEBUG const float ScalarCost = Cost; #endif /* NDEBUG */ @@ -4845,16 +5300,23 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { // Ignore scalar width, because the user explicitly wants vectorization. if (ForceVectorization && VF > 1) { Width = 2; - Cost = expectedCost(Width) / (float)Width; + Cost = expectedCost(Width).first / (float)Width; } - for (unsigned i=2; i <= VF; i*=2) { + for (unsigned i = 2; i <= VF; i *= 2) { // Notice that the vector loop needs to be executed less times, so // we need to divide the cost of the vector loops by the width of // the vector elements. 
- float VectorCost = expectedCost(i) / (float)i; - DEBUG(dbgs() << "LV: Vector loop of width " << i << " costs: " << - (int)VectorCost << ".\n"); + VectorizationCostTy C = expectedCost(i); + float VectorCost = C.first / (float)i; + DEBUG(dbgs() << "LV: Vector loop of width " << i + << " costs: " << (int)VectorCost << ".\n"); + if (!C.second && !ForceVectorization) { + DEBUG( + dbgs() << "LV: Not considering vector loop of width " << i + << " because it will not generate any vector instructions.\n"); + continue; + } if (VectorCost < Cost) { Cost = VectorCost; Width = i; @@ -4864,7 +5326,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { DEBUG(if (ForceVectorization && Width > 1 && Cost >= ScalarCost) dbgs() << "LV: Vectorization seems to be not beneficial, " << "but was forced by a user.\n"); - DEBUG(dbgs() << "LV: Selecting VF: "<< Width << ".\n"); + DEBUG(dbgs() << "LV: Selecting VF: " << Width << ".\n"); Factor.Width = Width; Factor.Cost = Width * Cost; return Factor; @@ -4877,25 +5339,22 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { const DataLayout &DL = TheFunction->getParent()->getDataLayout(); // For each block. - for (Loop::block_iterator bb = TheLoop->block_begin(), - be = TheLoop->block_end(); bb != be; ++bb) { - BasicBlock *BB = *bb; - + for (BasicBlock *BB : TheLoop->blocks()) { // For each instruction in the loop. - for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { - Type *T = it->getType(); + for (Instruction &I : *BB) { + Type *T = I.getType(); // Skip ignored values. - if (ValuesToIgnore.count(&*it)) + if (ValuesToIgnore.count(&I)) continue; // Only examine Loads, Stores and PHINodes. - if (!isa<LoadInst>(it) && !isa<StoreInst>(it) && !isa<PHINode>(it)) + if (!isa<LoadInst>(I) && !isa<StoreInst>(I) && !isa<PHINode>(I)) continue; // Examine PHI nodes that are reduction variables. Update the type to // account for the recurrence type. - if (PHINode *PN = dyn_cast<PHINode>(it)) { + if (auto *PN = dyn_cast<PHINode>(&I)) { if (!Legal->isReductionVariable(PN)) continue; RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[PN]; @@ -4903,13 +5362,13 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { } // Examine the stored values. - if (StoreInst *ST = dyn_cast<StoreInst>(it)) + if (auto *ST = dyn_cast<StoreInst>(&I)) T = ST->getValueOperand()->getType(); // Ignore loaded pointer types and stored pointer types that are not // consecutive. However, we do want to take consecutive stores/loads of // pointer vectors into account. - if (T->isPointerTy() && !isConsecutiveLoadOrStore(&*it)) + if (T->isPointerTy() && !isConsecutiveLoadOrStore(&I)) continue; MinWidth = std::min(MinWidth, @@ -4949,13 +5408,13 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, return 1; // Do not interleave loops with a relatively small trip count. 
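Aside: the width-selection loop above divides the expected cost of each candidate width by the width and keeps the cheapest, so a vector body only wins when its per-lane cost beats the scalar baseline. A minimal sketch of that comparison; the cost callback is only a stand-in for expectedCost(), and the numbers in main() are invented:

#include <cstdio>
#include <functional>

// Pick the vectorization factor with the lowest cost per scalar iteration,
// scanning powers of two up to MaxVF, as the loop above does.
static unsigned selectWidth(unsigned MaxVF,
                            const std::function<float(unsigned)> &ExpectedCost) {
  unsigned Width = 1;
  float Cost = ExpectedCost(1); // scalar baseline (width 1)
  for (unsigned VF = 2; VF <= MaxVF; VF *= 2) {
    float VectorCost = ExpectedCost(VF) / static_cast<float>(VF);
    if (VectorCost < Cost) {
      Cost = VectorCost;
      Width = VF;
    }
  }
  return Width;
}

int main() {
  // Toy costs: the body costs 8 scalar units; vector versions pay a small
  // overhead but amortize it over the lanes.
  auto Cost = [](unsigned VF) { return VF == 1 ? 8.0f : 8.0f + 2.0f * VF; };
  std::printf("selected VF = %u\n", selectWidth(8, Cost)); // VF 8: (8+16)/8 = 3
  return 0;
}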
- unsigned TC = SE->getSmallConstantTripCount(TheLoop); + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); if (TC > 1 && TC < TinyTripCountInterleaveThreshold) return 1; unsigned TargetNumRegisters = TTI.getNumberOfRegisters(VF > 1); - DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters << - " registers\n"); + DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters + << " registers\n"); if (VF == 1) { if (ForceTargetNumScalarRegs.getNumOccurrences() > 0) @@ -5002,7 +5461,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // If we did not calculate the cost for VF (because the user selected the VF) // then we calculate the cost of VF here. if (LoopCost == 0) - LoopCost = expectedCost(VF); + LoopCost = expectedCost(VF).first; // Clamp the calculated IC to be between the 1 and the max interleave count // that the target allows. @@ -5044,8 +5503,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // by this point), we can increase the critical path length if the loop // we're interleaving is inside another loop. Limit, by default to 2, so the // critical path only gets increased by one reduction operation. - if (Legal->getReductionVars()->size() && - TheLoop->getLoopDepth() > 1) { + if (Legal->getReductionVars()->size() && TheLoop->getLoopDepth() > 1) { unsigned F = static_cast<unsigned>(MaxNestedScalarReductionIC); SmallIC = std::min(SmallIC, F); StoresIC = std::min(StoresIC, F); @@ -5075,8 +5533,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, } SmallVector<LoopVectorizationCostModel::RegisterUsage, 8> -LoopVectorizationCostModel::calculateRegisterUsage( - const SmallVector<unsigned, 8> &VFs) { +LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { // This function calculates the register usage by measuring the highest number // of values that are alive at a single location. Obviously, this is a very // rough estimation. We scan the loop in a topological order in order and @@ -5103,31 +5560,30 @@ LoopVectorizationCostModel::calculateRegisterUsage( // Each 'key' in the map opens a new interval. The values // of the map are the index of the 'last seen' usage of the // instruction that is the key. - typedef DenseMap<Instruction*, unsigned> IntervalMap; + typedef DenseMap<Instruction *, unsigned> IntervalMap; // Maps instruction to its index. - DenseMap<unsigned, Instruction*> IdxToInstr; + DenseMap<unsigned, Instruction *> IdxToInstr; // Marks the end of each interval. IntervalMap EndPoint; // Saves the list of instruction indices that are used in the loop. - SmallSet<Instruction*, 8> Ends; + SmallSet<Instruction *, 8> Ends; // Saves the list of values that are used in the loop but are // defined outside the loop, such as arguments and constants. - SmallPtrSet<Value*, 8> LoopInvariants; + SmallPtrSet<Value *, 8> LoopInvariants; unsigned Index = 0; - for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(), - be = DFS.endRPO(); bb != be; ++bb) { - RU.NumInstructions += (*bb)->size(); - for (Instruction &I : **bb) { + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { + RU.NumInstructions += BB->size(); + for (Instruction &I : *BB) { IdxToInstr[Index++] = &I; // Save the end location of each USE. 
- for (unsigned i = 0; i < I.getNumOperands(); ++i) { - Value *U = I.getOperand(i); - Instruction *Instr = dyn_cast<Instruction>(U); + for (Value *U : I.operands()) { + auto *Instr = dyn_cast<Instruction>(U); // Ignore non-instruction values such as arguments, constants, etc. - if (!Instr) continue; + if (!Instr) + continue; // If this instruction is outside the loop then record it and continue. if (!TheLoop->contains(Instr)) { @@ -5143,15 +5599,14 @@ LoopVectorizationCostModel::calculateRegisterUsage( } // Saves the list of intervals that end with the index in 'key'. - typedef SmallVector<Instruction*, 2> InstrList; + typedef SmallVector<Instruction *, 2> InstrList; DenseMap<unsigned, InstrList> TransposeEnds; // Transpose the EndPoints to a list of values that end at each index. - for (IntervalMap::iterator it = EndPoint.begin(), e = EndPoint.end(); - it != e; ++it) - TransposeEnds[it->second].push_back(it->first); + for (auto &Interval : EndPoint) + TransposeEnds[Interval.second].push_back(Interval.first); - SmallSet<Instruction*, 8> OpenIntervals; + SmallSet<Instruction *, 8> OpenIntervals; // Get the size of the widest register. unsigned MaxSafeDepDist = -1U; @@ -5168,6 +5623,8 @@ LoopVectorizationCostModel::calculateRegisterUsage( // A lambda that gets the register usage for the given type and VF. auto GetRegUsage = [&DL, WidestRegister](Type *Ty, unsigned VF) { + if (Ty->isTokenTy()) + return 0U; unsigned TypeSize = DL.getTypeSizeInBits(Ty->getScalarType()); return std::max<unsigned>(1, VF * TypeSize / WidestRegister); }; @@ -5175,16 +5632,17 @@ LoopVectorizationCostModel::calculateRegisterUsage( for (unsigned int i = 0; i < Index; ++i) { Instruction *I = IdxToInstr[i]; // Ignore instructions that are never used within the loop. - if (!Ends.count(I)) continue; - - // Skip ignored values. - if (ValuesToIgnore.count(I)) + if (!Ends.count(I)) continue; // Remove all of the instructions that end at this location. InstrList &List = TransposeEnds[i]; - for (unsigned int j = 0, e = List.size(); j < e; ++j) - OpenIntervals.erase(List[j]); + for (Instruction *ToRemove : List) + OpenIntervals.erase(ToRemove); + + // Skip ignored values. + if (ValuesToIgnore.count(I)) + continue; // For each VF find the maximum usage of registers. for (unsigned j = 0, e = VFs.size(); j < e; ++j) { @@ -5195,8 +5653,12 @@ LoopVectorizationCostModel::calculateRegisterUsage( // Count the number of live intervals. unsigned RegUsage = 0; - for (auto Inst : OpenIntervals) + for (auto Inst : OpenIntervals) { + // Skip ignored values for VF > 1. + if (VecValuesToIgnore.count(Inst)) + continue; RegUsage += GetRegUsage(Inst->getType(), VFs[j]); + } MaxUsages[j] = std::max(MaxUsages[j], RegUsage); } @@ -5216,7 +5678,7 @@ LoopVectorizationCostModel::calculateRegisterUsage( Invariant += GetRegUsage(Inst->getType(), VFs[i]); } - DEBUG(dbgs() << "LV(REG): VF = " << VFs[i] << '\n'); + DEBUG(dbgs() << "LV(REG): VF = " << VFs[i] << '\n'); DEBUG(dbgs() << "LV(REG): Found max usage: " << MaxUsages[i] << '\n'); DEBUG(dbgs() << "LV(REG): Found invariant usage: " << Invariant << '\n'); DEBUG(dbgs() << "LV(REG): LoopSize: " << RU.NumInstructions << '\n'); @@ -5229,48 +5691,62 @@ LoopVectorizationCostModel::calculateRegisterUsage( return RUs; } -unsigned LoopVectorizationCostModel::expectedCost(unsigned VF) { - unsigned Cost = 0; +LoopVectorizationCostModel::VectorizationCostTy +LoopVectorizationCostModel::expectedCost(unsigned VF) { + VectorizationCostTy Cost; // For each block. 
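Aside: the GetRegUsage lambda above estimates how many vector registers a live value occupies at a given VF as max(1, VF * TypeSizeInBits / WidestRegisterBits). A standalone restatement with a few worked values (the register width and type sizes are examples, not any particular target):

#include <algorithm>
#include <cstdio>

// Estimated number of vector registers a value of TypeSizeBits occupies at a
// given VF, with WidestRegisterBits-wide registers (same formula as the
// GetRegUsage lambda above; token types are ignored here).
static unsigned regUsage(unsigned TypeSizeBits, unsigned VF,
                         unsigned WidestRegisterBits) {
  return std::max(1u, VF * TypeSizeBits / WidestRegisterBits);
}

int main() {
  std::printf("%u\n", regUsage(32, 16, 256)); // 16 x i32 = 512 bits -> 2 regs
  std::printf("%u\n", regUsage(32, 4, 256));  //  4 x i32 = 128 bits -> 1 reg
  std::printf("%u\n", regUsage(64, 8, 256));  //  8 x i64 = 512 bits -> 2 regs
  return 0;
}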
- for (Loop::block_iterator bb = TheLoop->block_begin(), - be = TheLoop->block_end(); bb != be; ++bb) { - unsigned BlockCost = 0; - BasicBlock *BB = *bb; + for (BasicBlock *BB : TheLoop->blocks()) { + VectorizationCostTy BlockCost; // For each instruction in the old loop. - for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + for (Instruction &I : *BB) { // Skip dbg intrinsics. - if (isa<DbgInfoIntrinsic>(it)) + if (isa<DbgInfoIntrinsic>(I)) continue; // Skip ignored values. - if (ValuesToIgnore.count(&*it)) + if (ValuesToIgnore.count(&I)) continue; - unsigned C = getInstructionCost(&*it, VF); + VectorizationCostTy C = getInstructionCost(&I, VF); // Check if we should override the cost. if (ForceTargetInstructionCost.getNumOccurrences() > 0) - C = ForceTargetInstructionCost; + C.first = ForceTargetInstructionCost; - BlockCost += C; - DEBUG(dbgs() << "LV: Found an estimated cost of " << C << " for VF " << - VF << " For instruction: " << *it << '\n'); + BlockCost.first += C.first; + BlockCost.second |= C.second; + DEBUG(dbgs() << "LV: Found an estimated cost of " << C.first << " for VF " + << VF << " For instruction: " << I << '\n'); } // We assume that if-converted blocks have a 50% chance of being executed. // When the code is scalar then some of the blocks are avoided due to CF. // When the code is vectorized we execute all code paths. - if (VF == 1 && Legal->blockNeedsPredication(*bb)) - BlockCost /= 2; + if (VF == 1 && Legal->blockNeedsPredication(BB)) + BlockCost.first /= 2; - Cost += BlockCost; + Cost.first += BlockCost.first; + Cost.second |= BlockCost.second; } return Cost; } +/// \brief Check if the load/store instruction \p I may be translated into +/// gather/scatter during vectorization. +/// +/// Pointer \p Ptr specifies address in memory for the given scalar memory +/// instruction. We need it to retrieve data type. +/// Using gather/scatter is possible when it is supported by target. +static bool isGatherOrScatterLegal(Instruction *I, Value *Ptr, + LoopVectorizationLegality *Legal) { + auto *DataTy = cast<PointerType>(Ptr->getType())->getElementType(); + return (isa<LoadInst>(I) && Legal->isLegalMaskedGather(DataTy)) || + (isa<StoreInst>(I) && Legal->isLegalMaskedScatter(DataTy)); +} + /// \brief Check whether the address computation for a non-consecutive memory /// access looks like an unlikely candidate for being merged into the indexing /// mode. @@ -5284,7 +5760,7 @@ static bool isLikelyComplexAddressComputation(Value *Ptr, LoopVectorizationLegality *Legal, ScalarEvolution *SE, const Loop *TheLoop) { - GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr); + auto *Gep = dyn_cast<GetElementPtrInst>(Ptr); if (!Gep) return true; @@ -5309,7 +5785,7 @@ static bool isLikelyComplexAddressComputation(Value *Ptr, // Check the step is constant. const SCEV *Step = AddRec->getStepRecurrence(*SE); // Calculate the pointer stride and check if it is consecutive. - const SCEVConstant *C = dyn_cast<SCEVConstant>(Step); + const auto *C = dyn_cast<SCEVConstant>(Step); if (!C) return true; @@ -5329,17 +5805,29 @@ static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) { Legal->hasStride(I->getOperand(1)); } -unsigned +LoopVectorizationCostModel::VectorizationCostTy LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { // If we know that this instruction will remain uniform, check the cost of // the scalar version. 
if (Legal->isUniformAfterVectorization(I)) VF = 1; + Type *VectorTy; + unsigned C = getInstructionCost(I, VF, VectorTy); + + bool TypeNotScalarized = + VF > 1 && !VectorTy->isVoidTy() && TTI.getNumberOfParts(VectorTy) < VF; + return VectorizationCostTy(C, TypeNotScalarized); +} + +unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, + unsigned VF, + Type *&VectorTy) { Type *RetTy = I->getType(); if (VF > 1 && MinBWs.count(I)) RetTy = IntegerType::get(RetTy->getContext(), MinBWs[I]); - Type *VectorTy = ToVectorTy(RetTy, VF); + VectorTy = ToVectorTy(RetTy, VF); + auto SE = PSE.getSE(); // TODO: We need to estimate the cost of intrinsic calls. switch (I->getOpcode()) { @@ -5352,9 +5840,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { case Instruction::Br: { return TTI.getCFInstrCost(I->getOpcode()); } - case Instruction::PHI: - //TODO: IF-converted IFs become selects. + case Instruction::PHI: { + auto *Phi = cast<PHINode>(I); + + // First-order recurrences are replaced by vector shuffles inside the loop. + if (VF > 1 && Legal->isFirstOrderRecurrence(Phi)) + return TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, + VectorTy, VF - 1, VectorTy); + + // TODO: IF-converted IFs become selects. return 0; + } case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -5379,9 +5875,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { // Certain instructions can be cheaper to vectorize if they have a constant // second vector operand. One example of this are shifts on x86. TargetTransformInfo::OperandValueKind Op1VK = - TargetTransformInfo::OK_AnyValue; + TargetTransformInfo::OK_AnyValue; TargetTransformInfo::OperandValueKind Op2VK = - TargetTransformInfo::OK_AnyValue; + TargetTransformInfo::OK_AnyValue; TargetTransformInfo::OperandValueProperties Op1VP = TargetTransformInfo::OP_None; TargetTransformInfo::OperandValueProperties Op2VP = @@ -5432,20 +5928,28 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { case Instruction::Load: { StoreInst *SI = dyn_cast<StoreInst>(I); LoadInst *LI = dyn_cast<LoadInst>(I); - Type *ValTy = (SI ? SI->getValueOperand()->getType() : - LI->getType()); + Type *ValTy = (SI ? SI->getValueOperand()->getType() : LI->getType()); VectorTy = ToVectorTy(ValTy, VF); unsigned Alignment = SI ? SI->getAlignment() : LI->getAlignment(); - unsigned AS = SI ? SI->getPointerAddressSpace() : - LI->getPointerAddressSpace(); + unsigned AS = + SI ? SI->getPointerAddressSpace() : LI->getPointerAddressSpace(); Value *Ptr = SI ? SI->getPointerOperand() : LI->getPointerOperand(); // We add the cost of address computation here instead of with the gep // instruction because only here we know whether the operation is // scalarized. if (VF == 1) return TTI.getAddressComputationCost(VectorTy) + - TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); + TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); + + if (LI && Legal->isUniform(Ptr)) { + // Scalar load + broadcast + unsigned Cost = TTI.getAddressComputationCost(ValTy->getScalarType()); + Cost += TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), + Alignment, AS); + return Cost + + TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, ValTy); + } // For an interleaved access, calculate the total cost of the whole // interleave group. 
@@ -5463,7 +5967,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { VectorTy->getVectorNumElements() * InterleaveFactor); // Holds the indices of existing members in an interleaved load group. - // An interleaved store group doesn't need this as it dones't allow gaps. + // An interleaved store group doesn't need this as it doesn't allow gaps. SmallVector<unsigned, 4> Indices; if (LI) { for (unsigned i = 0; i < InterleaveFactor; i++) @@ -5489,13 +5993,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { // Scalarized loads/stores. int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); + bool UseGatherOrScatter = + (ConsecutiveStride == 0) && isGatherOrScatterLegal(I, Ptr, Legal); + bool Reverse = ConsecutiveStride < 0; const DataLayout &DL = I->getModule()->getDataLayout(); - unsigned ScalarAllocatedSize = DL.getTypeAllocSize(ValTy); - unsigned VectorElementSize = DL.getTypeStoreSize(VectorTy) / VF; - if (!ConsecutiveStride || ScalarAllocatedSize != VectorElementSize) { + uint64_t ScalarAllocatedSize = DL.getTypeAllocSize(ValTy); + uint64_t VectorElementSize = DL.getTypeStoreSize(VectorTy) / VF; + if ((!ConsecutiveStride && !UseGatherOrScatter) || + ScalarAllocatedSize != VectorElementSize) { bool IsComplexComputation = - isLikelyComplexAddressComputation(Ptr, Legal, SE, TheLoop); + isLikelyComplexAddressComputation(Ptr, Legal, SE, TheLoop); unsigned Cost = 0; // The cost of extracting from the value vector and pointer vector. Type *PtrTy = ToVectorTy(Ptr->getType(), VF); @@ -5505,29 +6013,36 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { // In case of STORE, the cost of ExtractElement from the vector. // In case of LOAD, the cost of InsertElement into the returned // vector. - Cost += TTI.getVectorInstrCost(SI ? Instruction::ExtractElement : - Instruction::InsertElement, - VectorTy, i); + Cost += TTI.getVectorInstrCost(SI ? Instruction::ExtractElement + : Instruction::InsertElement, + VectorTy, i); } // The cost of the scalar loads/stores. Cost += VF * TTI.getAddressComputationCost(PtrTy, IsComplexComputation); - Cost += VF * TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), - Alignment, AS); + Cost += VF * + TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), + Alignment, AS); return Cost; } - // Wide load/stores. unsigned Cost = TTI.getAddressComputationCost(VectorTy); + if (UseGatherOrScatter) { + assert(ConsecutiveStride == 0 && + "Gather/Scatter are not used for consecutive stride"); + return Cost + + TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr, + Legal->isMaskRequired(I), Alignment); + } + // Wide load/stores. 
if (Legal->isMaskRequired(I)) - Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, - AS); + Cost += + TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); else Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); if (Reverse) - Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, - VectorTy, 0); + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); return Cost; } case Instruction::ZExt: @@ -5548,7 +6063,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { Legal->isInductionVariable(I->getOperand(0))) return TTI.getCastInstrCost(I->getOpcode(), I->getType(), I->getOperand(0)->getType()); - + Type *SrcScalarTy = I->getOperand(0)->getType(); Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF); if (VF > 1 && MinBWs.count(I)) { @@ -5560,23 +6075,23 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { Type *MinVecTy = VectorTy; if (I->getOpcode() == Instruction::Trunc) { SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy); - VectorTy = largestIntegerVectorType(ToVectorTy(I->getType(), VF), - MinVecTy); + VectorTy = + largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); } else if (I->getOpcode() == Instruction::ZExt || I->getOpcode() == Instruction::SExt) { SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy); - VectorTy = smallestIntegerVectorType(ToVectorTy(I->getType(), VF), - MinVecTy); + VectorTy = + smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); } } - + return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy); } case Instruction::Call: { bool NeedToScalarize; CallInst *CI = cast<CallInst>(I); unsigned CallCost = getVectorCallCost(CI, VF, TTI, TLI, NeedToScalarize); - if (getIntrinsicIDForCall(CI, TLI)) + if (getVectorIntrinsicIDForCall(CI, TLI)) return std::min(CallCost, getVectorIntrinsicCost(CI, VF, TTI, TLI)); return CallCost; } @@ -5587,10 +6102,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { unsigned Cost = 0; if (!RetTy->isVoidTy() && VF != 1) { - unsigned InsCost = TTI.getVectorInstrCost(Instruction::InsertElement, - VectorTy); - unsigned ExtCost = TTI.getVectorInstrCost(Instruction::ExtractElement, - VectorTy); + unsigned InsCost = + TTI.getVectorInstrCost(Instruction::InsertElement, VectorTy); + unsigned ExtCost = + TTI.getVectorInstrCost(Instruction::ExtractElement, VectorTy); // The cost of inserting the results plus extracting each one of the // operands. @@ -5602,7 +6117,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { Cost += VF * TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy); return Cost; } - }// end of switch. + } // end of switch. 
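Aside: for a scalarized access, the cost accumulated above is one insert or extract per lane plus VF scalar address computations and VF scalar memory operations, whereas a consecutive access pays a single address computation and one wide memory operation (plus a shuffle when reversed, or the masked cost when predicated). A worked comparison with invented per-operation costs; no real TTI numbers are implied:

#include <cstdio>

int main() {
  const unsigned VF = 4;
  const unsigned InsertExtractCost = 1; // per lane, insert or extract
  const unsigned ScalarMemOpCost = 1;   // one scalar load or store
  const unsigned ScalarAddrCost = 1;    // address computation, scalarized
  const unsigned VectorMemOpCost = 1;   // one wide load or store
  const unsigned VectorAddrCost = 0;    // address computation, vector case

  // Scalarized load: one insert per loaded element, plus VF scalar address
  // computations and VF scalar loads.
  unsigned Scalarized =
      VF * InsertExtractCost + VF * (ScalarAddrCost + ScalarMemOpCost);

  // Consecutive (wide) load: one address computation plus one vector load.
  unsigned Wide = VectorAddrCost + VectorMemOpCost;

  std::printf("scalarized: %u, wide: %u\n", Scalarized, Wide); // 12 vs 1
  return 0;
}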
} char LoopVectorize::ID = 0; @@ -5616,31 +6131,101 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LCSSA) +INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis) -INITIALIZE_PASS_DEPENDENCY(DemandedBits) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false) namespace llvm { - Pass *createLoopVectorizePass(bool NoUnrolling, bool AlwaysVectorize) { - return new LoopVectorize(NoUnrolling, AlwaysVectorize); - } +Pass *createLoopVectorizePass(bool NoUnrolling, bool AlwaysVectorize) { + return new LoopVectorize(NoUnrolling, AlwaysVectorize); +} } bool LoopVectorizationCostModel::isConsecutiveLoadOrStore(Instruction *Inst) { // Check for a store. - if (StoreInst *ST = dyn_cast<StoreInst>(Inst)) + if (auto *ST = dyn_cast<StoreInst>(Inst)) return Legal->isConsecutivePtr(ST->getPointerOperand()) != 0; // Check for a load. - if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) + if (auto *LI = dyn_cast<LoadInst>(Inst)) return Legal->isConsecutivePtr(LI->getPointerOperand()) != 0; return false; } +void LoopVectorizationCostModel::collectValuesToIgnore() { + // Ignore ephemeral values. + CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore); + + // Ignore type-promoting instructions we identified during reduction + // detection. + for (auto &Reduction : *Legal->getReductionVars()) { + RecurrenceDescriptor &RedDes = Reduction.second; + SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts(); + VecValuesToIgnore.insert(Casts.begin(), Casts.end()); + } + + // Ignore induction phis that are only used in either GetElementPtr or ICmp + // instruction to exit loop. Induction variables usually have large types and + // can have big impact when estimating register usage. + // This is for when VF > 1. + for (auto &Induction : *Legal->getInductionVars()) { + auto *PN = Induction.first; + auto *UpdateV = PN->getIncomingValueForBlock(TheLoop->getLoopLatch()); + + // Check that the PHI is only used by the induction increment (UpdateV) or + // by GEPs. Then check that UpdateV is only used by a compare instruction, + // the loop header PHI, or by GEPs. + // FIXME: Need precise def-use analysis to determine if this instruction + // variable will be vectorized. + if (all_of(PN->users(), + [&](const User *U) -> bool { + return U == UpdateV || isa<GetElementPtrInst>(U); + }) && + all_of(UpdateV->users(), [&](const User *U) -> bool { + return U == PN || isa<ICmpInst>(U) || isa<GetElementPtrInst>(U); + })) { + VecValuesToIgnore.insert(PN); + VecValuesToIgnore.insert(UpdateV); + } + } + + // Ignore instructions that will not be vectorized. + // This is for when VF > 1. + for (BasicBlock *BB : TheLoop->blocks()) { + for (auto &Inst : *BB) { + switch (Inst.getOpcode()) + case Instruction::GetElementPtr: { + // Ignore GEP if its last operand is an induction variable so that it is + // a consecutive load/store and won't be vectorized as scatter/gather + // pattern. 
+ + GetElementPtrInst *Gep = cast<GetElementPtrInst>(&Inst); + unsigned NumOperands = Gep->getNumOperands(); + unsigned InductionOperand = getGEPInductionOperand(Gep); + bool GepToIgnore = true; + + // Check that all of the gep indices are uniform except for the + // induction operand. + for (unsigned i = 0; i != NumOperands; ++i) { + if (i != InductionOperand && + !PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), + TheLoop)) { + GepToIgnore = false; + break; + } + } + + if (GepToIgnore) + VecValuesToIgnore.insert(&Inst); + break; + } + } + } +} void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, bool IfPredicateStore) { @@ -5651,9 +6236,7 @@ void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, setDebugLocFromInst(Builder, Instr); // Find all of the vectorized parameters. - for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - Value *SrcOp = Instr->getOperand(op); - + for (Value *SrcOp : Instr->operands()) { // If we are accessing the old induction variable, use the new one. if (SrcOp == OldInduction) { Params.push_back(getVectorValue(SrcOp)); @@ -5683,8 +6266,7 @@ void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, // Does this instruction return a value ? bool IsVoidRetTy = Instr->getType()->isVoidTy(); - Value *UndefVec = IsVoidRetTy ? nullptr : - UndefValue::get(Instr->getType()); + Value *UndefVec = IsVoidRetTy ? nullptr : UndefValue::get(Instr->getType()); // Create a new entry in the WidenMap and initialize it to Undef or Null. VectorParts &VecResults = WidenMap.splat(Instr, UndefVec); @@ -5711,43 +6293,43 @@ void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, } Instruction *Cloned = Instr->clone(); - if (!IsVoidRetTy) - Cloned->setName(Instr->getName() + ".cloned"); - // Replace the operands of the cloned instructions with extracted scalars. - for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - Value *Op = Params[op][Part]; - Cloned->setOperand(op, Op); - } + if (!IsVoidRetTy) + Cloned->setName(Instr->getName() + ".cloned"); + // Replace the operands of the cloned instructions with extracted scalars. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + Value *Op = Params[op][Part]; + Cloned->setOperand(op, Op); + } - // Place the cloned scalar in the new loop. - Builder.Insert(Cloned); + // Place the cloned scalar in the new loop. + Builder.Insert(Cloned); - // If the original scalar returns a value we need to place it in a vector - // so that future users will be able to use it. - if (!IsVoidRetTy) - VecResults[Part] = Cloned; + // If we just cloned a new assumption, add it the assumption cache. + if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); - // End if-block. - if (IfPredicateStore) - PredicatedStores.push_back(std::make_pair(cast<StoreInst>(Cloned), - Cmp)); + // If the original scalar returns a value we need to place it in a vector + // so that future users will be able to use it. + if (!IsVoidRetTy) + VecResults[Part] = Cloned; + + // End if-block. 
+ if (IfPredicateStore) + PredicatedStores.push_back(std::make_pair(cast<StoreInst>(Cloned), Cmp)); } } void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) { - StoreInst *SI = dyn_cast<StoreInst>(Instr); + auto *SI = dyn_cast<StoreInst>(Instr); bool IfPredicateStore = (SI && Legal->blockNeedsPredication(SI->getParent())); return scalarizeInstruction(Instr, IfPredicateStore); } -Value *InnerLoopUnroller::reverseVector(Value *Vec) { - return Vec; -} +Value *InnerLoopUnroller::reverseVector(Value *Vec) { return Vec; } -Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { - return V; -} +Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; } Value *InnerLoopUnroller::getStepVector(Value *Val, int StartIdx, Value *Step) { // When unrolling and the VF is 1, we only need to add a simple scalar. @@ -5756,3 +6338,346 @@ Value *InnerLoopUnroller::getStepVector(Value *Val, int StartIdx, Value *Step) { Constant *C = ConstantInt::get(ITy, StartIdx); return Builder.CreateAdd(Val, Builder.CreateMul(C, Step), "induction"); } + +static void AddRuntimeUnrollDisableMetaData(Loop *L) { + SmallVector<Metadata *, 4> MDs; + // Reserve first location for self reference to the LoopID metadata node. + MDs.push_back(nullptr); + bool IsUnrollMetadata = false; + MDNode *LoopID = L->getLoopID(); + if (LoopID) { + // First find existing loop unrolling disable metadata. + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + auto *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (MD) { + const auto *S = dyn_cast<MDString>(MD->getOperand(0)); + IsUnrollMetadata = + S && S->getString().startswith("llvm.loop.unroll.disable"); + } + MDs.push_back(LoopID->getOperand(i)); + } + } + + if (!IsUnrollMetadata) { + // Add runtime unroll disable metadata. + LLVMContext &Context = L->getHeader()->getContext(); + SmallVector<Metadata *, 1> DisableOperands; + DisableOperands.push_back( + MDString::get(Context, "llvm.loop.unroll.runtime.disable")); + MDNode *DisableNode = MDNode::get(Context, DisableOperands); + MDs.push_back(DisableNode); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + L->setLoopID(NewLoopID); + } +} + +bool LoopVectorizePass::processLoop(Loop *L) { + assert(L->empty() && "Only process inner loops."); + +#ifndef NDEBUG + const std::string DebugLocStr = getDebugLocString(L); +#endif /* NDEBUG */ + + DEBUG(dbgs() << "\nLV: Checking a loop in \"" + << L->getHeader()->getParent()->getName() << "\" from " + << DebugLocStr << "\n"); + + LoopVectorizeHints Hints(L, DisableUnrolling); + + DEBUG(dbgs() << "LV: Loop hints:" + << " force=" + << (Hints.getForce() == LoopVectorizeHints::FK_Disabled + ? "disabled" + : (Hints.getForce() == LoopVectorizeHints::FK_Enabled + ? "enabled" + : "?")) + << " width=" << Hints.getWidth() + << " unroll=" << Hints.getInterleave() << "\n"); + + // Function containing loop + Function *F = L->getHeader()->getParent(); + + // Looking at the diagnostic output is the only way to determine if a loop + // was vectorized (other than looking at the IR or machine code), so it + // is important to generate an optimization remark for each loop. Most of + // these messages are generated by emitOptimizationRemarkAnalysis. Remarks + // generated by emitOptimizationRemark and emitOptimizationRemarkMissed are + // less verbose reporting vectorized loops and unvectorized loops that may + // benefit from vectorization, respectively. 
+ + if (!Hints.allowVectorization(F, L, AlwaysVectorize)) { + DEBUG(dbgs() << "LV: Loop hints prevent vectorization.\n"); + return false; + } + + // Check the loop for a trip count threshold: + // do not vectorize loops with a tiny trip count. + const unsigned TC = SE->getSmallConstantTripCount(L); + if (TC > 0u && TC < TinyTripCountVectorThreshold) { + DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " + << "This loop is not worth vectorizing."); + if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) + DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); + else { + DEBUG(dbgs() << "\n"); + emitAnalysisDiag(F, L, Hints, VectorizationReport() + << "vectorization is not beneficial " + "and is not explicitly forced"); + return false; + } + } + + PredicatedScalarEvolution PSE(*SE, *L); + + // Check if it is legal to vectorize the loop. + LoopVectorizationRequirements Requirements; + LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, GetLAA, LI, + &Requirements, &Hints); + if (!LVL.canVectorize()) { + DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); + emitMissedWarning(F, L, Hints); + return false; + } + + // Use the cost model. + LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, F, + &Hints); + CM.collectValuesToIgnore(); + + // Check the function attributes to find out if this function should be + // optimized for size. + bool OptForSize = + Hints.getForce() != LoopVectorizeHints::FK_Enabled && F->optForSize(); + + // Compute the weighted frequency of this loop being executed and see if it + // is less than 20% of the function entry baseline frequency. Note that we + // always have a canonical loop here because we think we *can* vectorize. + // FIXME: This is hidden behind a flag due to pervasive problems with + // exactly what block frequency models. + if (LoopVectorizeWithBlockFrequency) { + BlockFrequency LoopEntryFreq = BFI->getBlockFreq(L->getLoopPreheader()); + if (Hints.getForce() != LoopVectorizeHints::FK_Enabled && + LoopEntryFreq < ColdEntryFreq) + OptForSize = true; + } + + // Check the function attributes to see if implicit floats are allowed. + // FIXME: This check doesn't seem possibly correct -- what if the loop is + // an integer loop and the vector instructions selected are purely integer + // vector instructions? + if (F->hasFnAttribute(Attribute::NoImplicitFloat)) { + DEBUG(dbgs() << "LV: Can't vectorize when the NoImplicitFloat" + "attribute is used.\n"); + emitAnalysisDiag( + F, L, Hints, + VectorizationReport() + << "loop not vectorized due to NoImplicitFloat attribute"); + emitMissedWarning(F, L, Hints); + return false; + } + + // Check if the target supports potentially unsafe FP vectorization. + // FIXME: Add a check for the type of safety issue (denormal, signaling) + // for the target we're vectorizing for, to make sure none of the + // additional fp-math flags can help. + if (Hints.isPotentiallyUnsafe() && + TTI->isFPVectorizationPotentiallyUnsafe()) { + DEBUG(dbgs() << "LV: Potentially unsafe FP op prevents vectorization.\n"); + emitAnalysisDiag(F, L, Hints, + VectorizationReport() + << "loop not vectorized due to unsafe FP support."); + emitMissedWarning(F, L, Hints); + return false; + } + + // Select the optimal vectorization factor. + const LoopVectorizationCostModel::VectorizationFactor VF = + CM.selectVectorizationFactor(OptForSize); + + // Select the interleave count. + unsigned IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost); + + // Get user interleave count. 
+ unsigned UserIC = Hints.getInterleave(); + + // Identify the diagnostic messages that should be produced. + std::string VecDiagMsg, IntDiagMsg; + bool VectorizeLoop = true, InterleaveLoop = true; + + if (Requirements.doesNotMeet(F, L, Hints)) { + DEBUG(dbgs() << "LV: Not vectorizing: loop did not meet vectorization " + "requirements.\n"); + emitMissedWarning(F, L, Hints); + return false; + } + + if (VF.Width == 1) { + DEBUG(dbgs() << "LV: Vectorization is possible but not beneficial.\n"); + VecDiagMsg = + "the cost-model indicates that vectorization is not beneficial"; + VectorizeLoop = false; + } + + if (IC == 1 && UserIC <= 1) { + // Tell the user interleaving is not beneficial. + DEBUG(dbgs() << "LV: Interleaving is not beneficial.\n"); + IntDiagMsg = + "the cost-model indicates that interleaving is not beneficial"; + InterleaveLoop = false; + if (UserIC == 1) + IntDiagMsg += + " and is explicitly disabled or interleave count is set to 1"; + } else if (IC > 1 && UserIC == 1) { + // Tell the user interleaving is beneficial, but it explicitly disabled. + DEBUG(dbgs() + << "LV: Interleaving is beneficial but is explicitly disabled."); + IntDiagMsg = "the cost-model indicates that interleaving is beneficial " + "but is explicitly disabled or interleave count is set to 1"; + InterleaveLoop = false; + } + + // Override IC if user provided an interleave count. + IC = UserIC > 0 ? UserIC : IC; + + // Emit diagnostic messages, if any. + const char *VAPassName = Hints.vectorizeAnalysisPassName(); + if (!VectorizeLoop && !InterleaveLoop) { + // Do not vectorize or interleaving the loop. + emitOptimizationRemarkAnalysis(F->getContext(), VAPassName, *F, + L->getStartLoc(), VecDiagMsg); + emitOptimizationRemarkAnalysis(F->getContext(), LV_NAME, *F, + L->getStartLoc(), IntDiagMsg); + return false; + } else if (!VectorizeLoop && InterleaveLoop) { + DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); + emitOptimizationRemarkAnalysis(F->getContext(), VAPassName, *F, + L->getStartLoc(), VecDiagMsg); + } else if (VectorizeLoop && !InterleaveLoop) { + DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " + << DebugLocStr << '\n'); + emitOptimizationRemarkAnalysis(F->getContext(), LV_NAME, *F, + L->getStartLoc(), IntDiagMsg); + } else if (VectorizeLoop && InterleaveLoop) { + DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " + << DebugLocStr << '\n'); + DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); + } + + if (!VectorizeLoop) { + assert(IC > 1 && "interleave count should not be 1 or 0"); + // If we decided that it is not legal to vectorize the loop, then + // interleave it. + InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, AC, IC); + Unroller.vectorize(&LVL, CM.MinBWs, CM.VecValuesToIgnore); + + emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), + Twine("interleaved loop (interleaved count: ") + + Twine(IC) + ")"); + } else { + // If we decided that it is *legal* to vectorize the loop, then do it. + InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, VF.Width, IC); + LB.vectorize(&LVL, CM.MinBWs, CM.VecValuesToIgnore); + ++LoopsVectorized; + + // Add metadata to disable runtime unrolling a scalar loop when there are + // no runtime checks about strides and memory. A scalar loop that is + // rarely used is not worth unrolling. + if (!LB.areSafetyChecksAdded()) + AddRuntimeUnrollDisableMetaData(L); + + // Report the vectorization decision. 
+ emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), + Twine("vectorized loop (vectorization width: ") + + Twine(VF.Width) + ", interleaved count: " + + Twine(IC) + ")"); + } + + // Mark the loop as already vectorized to avoid vectorizing again. + Hints.setAlreadyVectorized(); + + DEBUG(verifyFunction(*L->getHeader()->getParent())); + return true; +} + +bool LoopVectorizePass::runImpl( + Function &F, ScalarEvolution &SE_, LoopInfo &LI_, TargetTransformInfo &TTI_, + DominatorTree &DT_, BlockFrequencyInfo &BFI_, TargetLibraryInfo *TLI_, + DemandedBits &DB_, AliasAnalysis &AA_, AssumptionCache &AC_, + std::function<const LoopAccessInfo &(Loop &)> &GetLAA_) { + + SE = &SE_; + LI = &LI_; + TTI = &TTI_; + DT = &DT_; + BFI = &BFI_; + TLI = TLI_; + AA = &AA_; + AC = &AC_; + GetLAA = &GetLAA_; + DB = &DB_; + + // Compute some weights outside of the loop over the loops. Compute this + // using a BranchProbability to re-use its scaling math. + const BranchProbability ColdProb(1, 5); // 20% + ColdEntryFreq = BlockFrequency(BFI->getEntryFreq()) * ColdProb; + + // Don't attempt if + // 1. the target claims to have no vector registers, and + // 2. interleaving won't help ILP. + // + // The second condition is necessary because, even if the target has no + // vector registers, loop vectorization may still enable scalar + // interleaving. + if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2) + return false; + + // Build up a worklist of inner-loops to vectorize. This is necessary as + // the act of vectorizing or partially unrolling a loop creates new loops + // and can invalidate iterators across the loops. + SmallVector<Loop *, 8> Worklist; + + for (Loop *L : *LI) + addInnerLoop(*L, Worklist); + + LoopsAnalyzed += Worklist.size(); + + // Now walk the identified inner loops. + bool Changed = false; + while (!Worklist.empty()) + Changed |= processLoop(Worklist.pop_back_val()); + + // Process each loop nest in the function. + return Changed; + +} + + +PreservedAnalyses LoopVectorizePass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &BFI = AM.getResult<BlockFrequencyAnalysis>(F); + auto *TLI = AM.getCachedResult<TargetLibraryAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DB = AM.getResult<DemandedBitsAnalysis>(F); + + auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); + std::function<const LoopAccessInfo &(Loop &)> GetLAA = + [&](Loop &L) -> const LoopAccessInfo & { + return LAM.getResult<LoopAccessAnalysis>(L); + }; + bool Changed = runImpl(F, SE, LI, TTI, DT, BFI, TLI, DB, AA, AC, GetLAA); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<LoopAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<BasicAA>(); + PA.preserve<GlobalsAA>(); + return PA; +} diff --git a/lib/Transforms/Vectorize/Makefile b/lib/Transforms/Vectorize/Makefile deleted file mode 100644 index 86c36585f23f..000000000000 --- a/lib/Transforms/Vectorize/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Transforms/Vectorize/Makefile -----------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. 
-# -##===----------------------------------------------------------------------===## - -LEVEL = ../../.. -LIBRARYNAME = LLVMVectorize -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common - diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index f69a4e52c7e1..8a3c4d14fecb 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -15,21 +15,17 @@ // "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks. // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Vectorize.h" -#include "llvm/ADT/MapVector.h" +#include "llvm/Transforms/Vectorize/SLPVectorizer.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" @@ -44,12 +40,12 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Analysis/VectorUtils.h" +#include "llvm/Transforms/Vectorize.h" #include <algorithm> -#include <map> #include <memory> using namespace llvm; +using namespace slpvectorizer; #define SV_NAME "slp-vectorizer" #define DEBUG_TYPE "SLP" @@ -82,11 +78,11 @@ static cl::opt<int> ScheduleRegionSizeBudget("slp-schedule-budget", cl::init(100000), cl::Hidden, cl::desc("Limit the size of the SLP scheduling region per block")); -namespace { +static cl::opt<int> MinVectorRegSizeOption( + "slp-min-reg-size", cl::init(128), cl::Hidden, + cl::desc("Attempt to vectorize for this register size in bits")); // FIXME: Set this via cl::opt to allow overriding. -static const unsigned MinVecRegSize = 128; - static const unsigned RecursionMaxDepth = 12; // Limit the number of alias checks. The limit is chosen so that @@ -134,8 +130,8 @@ static BasicBlock *getSameBlock(ArrayRef<Value *> VL) { /// \returns True if all of the values in \p VL are constants. static bool allConstant(ArrayRef<Value *> VL) { - for (unsigned i = 0, e = VL.size(); i < e; ++i) - if (!isa<Constant>(VL[i])) + for (Value *i : VL) + if (!isa<Constant>(i)) return false; return true; } @@ -223,46 +219,6 @@ static void propagateIRFlags(Value *I, ArrayRef<Value *> VL) { } } -/// \returns \p I after propagating metadata from \p VL. 
-static Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL) { - Instruction *I0 = cast<Instruction>(VL[0]); - SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; - I0->getAllMetadataOtherThanDebugLoc(Metadata); - - for (unsigned i = 0, n = Metadata.size(); i != n; ++i) { - unsigned Kind = Metadata[i].first; - MDNode *MD = Metadata[i].second; - - for (int i = 1, e = VL.size(); MD && i != e; i++) { - Instruction *I = cast<Instruction>(VL[i]); - MDNode *IMD = I->getMetadata(Kind); - - switch (Kind) { - default: - MD = nullptr; // Remove unknown metadata - break; - case LLVMContext::MD_tbaa: - MD = MDNode::getMostGenericTBAA(MD, IMD); - break; - case LLVMContext::MD_alias_scope: - MD = MDNode::getMostGenericAliasScope(MD, IMD); - break; - case LLVMContext::MD_noalias: - MD = MDNode::intersect(MD, IMD); - break; - case LLVMContext::MD_fpmath: - MD = MDNode::getMostGenericFPMath(MD, IMD); - break; - case LLVMContext::MD_nontemporal: - MD = MDNode::intersect(MD, IMD); - break; - } - } - I->setMetadata(Kind, MD); - } - return I; -} - /// \returns The type that all of the values in \p VL have or null if there /// are different types. static Type* getSameType(ArrayRef<Value *> VL) { @@ -274,36 +230,17 @@ static Type* getSameType(ArrayRef<Value *> VL) { return Ty; } -/// \returns True if the ExtractElement instructions in VL can be vectorized -/// to use the original vector. -static bool CanReuseExtract(ArrayRef<Value *> VL) { - assert(Instruction::ExtractElement == getSameOpcode(VL) && "Invalid opcode"); - // Check if all of the extracts come from the same vector and from the - // correct offset. - Value *VL0 = VL[0]; - ExtractElementInst *E0 = cast<ExtractElementInst>(VL0); - Value *Vec = E0->getOperand(0); - - // We have to extract from the same vector type. - unsigned NElts = Vec->getType()->getVectorNumElements(); - - if (NElts != VL.size()) - return false; - - // Check that all of the indices extract from the correct offset. - ConstantInt *CI = dyn_cast<ConstantInt>(E0->getOperand(1)); - if (!CI || CI->getZExtValue()) - return false; - - for (unsigned i = 1, e = VL.size(); i < e; ++i) { - ExtractElementInst *E = cast<ExtractElementInst>(VL[i]); +/// \returns True if Extract{Value,Element} instruction extracts element Idx. +static bool matchExtractIndex(Instruction *E, unsigned Idx, unsigned Opcode) { + assert(Opcode == Instruction::ExtractElement || + Opcode == Instruction::ExtractValue); + if (Opcode == Instruction::ExtractElement) { ConstantInt *CI = dyn_cast<ConstantInt>(E->getOperand(1)); - - if (!CI || CI->getZExtValue() != i || E->getOperand(0) != Vec) - return false; + return CI && CI->getZExtValue() == Idx; + } else { + ExtractValueInst *EI = cast<ExtractValueInst>(E); + return EI->getNumIndices() == 1 && *EI->idx_begin() == Idx; } - - return true; } /// \returns True if in-tree use also needs extract. This refers to @@ -323,7 +260,7 @@ static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, } case Instruction::Call: { CallInst *CI = cast<CallInst>(UserInst); - Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); if (hasVectorInstrinsicScalarOpd(ID, 1)) { return (CI->getArgOperand(1) == Scalar); } @@ -353,6 +290,8 @@ static bool isSimple(Instruction *I) { return true; } +namespace llvm { +namespace slpvectorizer { /// Bottom Up SLP Vectorizer. 
class BoUpSLP { public: @@ -363,11 +302,24 @@ public: BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti, TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li, - DominatorTree *Dt, AssumptionCache *AC) + DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB, + const DataLayout *DL) : NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func), - SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), - Builder(Se->getContext()) { + SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), DB(DB), + DL(DL), Builder(Se->getContext()) { CodeMetrics::collectEphemeralValues(F, AC, EphValues); + // Use the vector register size specified by the target unless overridden + // by a command-line option. + // TODO: It would be better to limit the vectorization factor based on + // data type rather than just register size. For example, x86 AVX has + // 256-bit registers, but it does not support integer operations + // at that width (that requires AVX2). + if (MaxVectorRegSizeOption.getNumOccurrences()) + MaxVecRegSize = MaxVectorRegSizeOption; + else + MaxVecRegSize = TTI->getRegisterBitWidth(true); + + MinVecRegSize = MinVectorRegSizeOption; } /// \brief Vectorize the tree that starts with the elements in \p VL. @@ -399,11 +351,9 @@ public: BlockScheduling *BS = Iter.second.get(); BS->clear(); } + MinBWs.clear(); } - /// \returns true if the memory operations A and B are consecutive. - bool isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL); - /// \brief Perform LICM and CSE on the newly generated gather sequences. void optimizeGatherSequence(); @@ -412,6 +362,32 @@ public: return NumLoadsWantToChangeOrder > NumLoadsWantToKeepOrder; } + /// \return The vector element size in bits to use when vectorizing the + /// expression tree ending at \p V. If V is a store, the size is the width of + /// the stored value. Otherwise, the size is the width of the largest loaded + /// value reaching V. This method is used by the vectorizer to calculate + /// vectorization factors. + unsigned getVectorElementSize(Value *V); + + /// Compute the minimum type sizes required to represent the entries in a + /// vectorizable tree. + void computeMinimumValueSizes(); + + // \returns maximum vector register size as set by TTI or overridden by cl::opt. + unsigned getMaxVecRegSize() const { + return MaxVecRegSize; + } + + // \returns minimum vector register size as set by cl::opt. + unsigned getMinVecRegSize() const { + return MinVecRegSize; + } + + /// \brief Check if ArrayType or StructType is isomorphic to some VectorType. + /// + /// \returns number of elements in vector if isomorphism exists, 0 otherwise. + unsigned canMapToVector(Type *T, const DataLayout &DL) const; + private: struct TreeEntry; @@ -421,6 +397,10 @@ private: /// This is the recursive part of buildTree. void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth); + /// \returns True if the ExtractElement/ExtractValue instructions in VL can + /// be vectorized to use the original vector (or aggregate "bitcast" to a vector). + bool canReuseExtract(ArrayRef<Value *> VL, unsigned Opcode) const; + /// Vectorize a single entry in the tree. Value *vectorizeTree(TreeEntry *E); @@ -431,14 +411,6 @@ private: /// vectorized, or NULL. They may happen in cycles. Value *alreadyVectorized(ArrayRef<Value *> VL) const; - /// \brief Take the pointer operand from the Load/Store instruction. - /// \returns NULL if this is not a valid Load/Store instruction. 
- static Value *getPointerOperand(Value *I); - - /// \brief Take the address space operand from the Load/Store instruction. - /// \returns -1 if this is not a valid Load/Store instruction. - static unsigned getAddressSpaceOperand(Value *I); - /// \returns the scalarization cost for this type. Scalarization in this /// context means the creation of vectors from a group of scalars. int getGatherCost(Type *Ty); @@ -719,8 +691,11 @@ private: }; #ifndef NDEBUG - friend raw_ostream &operator<<(raw_ostream &os, - const BoUpSLP::ScheduleData &SD); + friend inline raw_ostream &operator<<(raw_ostream &os, + const BoUpSLP::ScheduleData &SD) { + SD.dump(os); + return os; + } #endif /// Contains all scheduling data for a basic block. @@ -917,16 +892,21 @@ private: AliasAnalysis *AA; LoopInfo *LI; DominatorTree *DT; + AssumptionCache *AC; + DemandedBits *DB; + const DataLayout *DL; + unsigned MaxVecRegSize; // This is set by TTI or overridden by cl::opt. + unsigned MinVecRegSize; // Set by cl::opt (default: 128). /// Instruction builder to construct the vectorized tree. IRBuilder<> Builder; + + /// A map of scalar integer values to the smallest bit width with which they + /// can legally be represented. + MapVector<Value *, uint64_t> MinBWs; }; -#ifndef NDEBUG -raw_ostream &operator<<(raw_ostream &os, const BoUpSLP::ScheduleData &SD) { - SD.dump(os); - return os; -} -#endif +} // end namespace llvm +} // end namespace slpvectorizer void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ArrayRef<Value *> UserIgnoreLst) { @@ -937,8 +917,8 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, buildTree_rec(Roots, 0); // Collect the values that we need to extract from the tree. - for (int EIdx = 0, EE = VectorizableTree.size(); EIdx < EE; ++EIdx) { - TreeEntry *Entry = &VectorizableTree[EIdx]; + for (TreeEntry &EIdx : VectorizableTree) { + TreeEntry *Entry = &EIdx; // For each lane: for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) { @@ -987,7 +967,7 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { - bool SameTy = getSameType(VL); (void)SameTy; + bool SameTy = allConstant(VL) || getSameType(VL); (void)SameTy; bool isAltShuffle = false; assert(SameTy && "Invalid types!"); @@ -1138,16 +1118,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. - for (unsigned j = 0; j < VL.size(); ++j) - Operands.push_back(cast<PHINode>(VL[j])->getIncomingValueForBlock( + for (Value *j : VL) + Operands.push_back(cast<PHINode>(j)->getIncomingValueForBlock( PH->getIncomingBlock(i))); buildTree_rec(Operands, Depth + 1); } return; } + case Instruction::ExtractValue: case Instruction::ExtractElement: { - bool Reuse = CanReuseExtract(VL); + bool Reuse = canReuseExtract(VL, Opcode); if (Reuse) { DEBUG(dbgs() << "SLP: Reusing extract sequence.\n"); } else { @@ -1164,11 +1145,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { // loading/storing it as an i8 struct. If we vectorize loads/stores from // such a struct we read/write packed bits disagreeing with the // unvectorized version. 
- const DataLayout &DL = F->getParent()->getDataLayout(); Type *ScalarTy = VL[0]->getType(); - if (DL.getTypeSizeInBits(ScalarTy) != - DL.getTypeAllocSizeInBits(ScalarTy)) { + if (DL->getTypeSizeInBits(ScalarTy) != + DL->getTypeAllocSizeInBits(ScalarTy)) { BS.cancelScheduling(VL); newTreeEntry(VL, false); DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); @@ -1184,8 +1164,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { return; } - if (!isConsecutiveAccess(VL[i], VL[i + 1], DL)) { - if (VL.size() == 2 && isConsecutiveAccess(VL[1], VL[0], DL)) { + if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) { + if (VL.size() == 2 && isConsecutiveAccess(VL[1], VL[0], *DL, *SE)) { ++NumLoadsWantToChangeOrder; } BS.cancelScheduling(VL); @@ -1227,8 +1207,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. - for (unsigned j = 0; j < VL.size(); ++j) - Operands.push_back(cast<Instruction>(VL[j])->getOperand(i)); + for (Value *j : VL) + Operands.push_back(cast<Instruction>(j)->getOperand(i)); buildTree_rec(Operands, Depth+1); } @@ -1256,8 +1236,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. - for (unsigned j = 0; j < VL.size(); ++j) - Operands.push_back(cast<Instruction>(VL[j])->getOperand(i)); + for (Value *j : VL) + Operands.push_back(cast<Instruction>(j)->getOperand(i)); buildTree_rec(Operands, Depth+1); } @@ -1298,8 +1278,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. - for (unsigned j = 0; j < VL.size(); ++j) - Operands.push_back(cast<Instruction>(VL[j])->getOperand(i)); + for (Value *j : VL) + Operands.push_back(cast<Instruction>(j)->getOperand(i)); buildTree_rec(Operands, Depth+1); } @@ -1346,18 +1326,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0, e = 2; i < e; ++i) { ValueList Operands; // Prepare the operand vector. - for (unsigned j = 0; j < VL.size(); ++j) - Operands.push_back(cast<Instruction>(VL[j])->getOperand(i)); + for (Value *j : VL) + Operands.push_back(cast<Instruction>(j)->getOperand(i)); buildTree_rec(Operands, Depth + 1); } return; } case Instruction::Store: { - const DataLayout &DL = F->getParent()->getDataLayout(); // Check if the stores are consecutive or of we need to swizzle them. 
for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) - if (!isConsecutiveAccess(VL[i], VL[i + 1], DL)) { + if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) { BS.cancelScheduling(VL); newTreeEntry(VL, false); DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); @@ -1368,8 +1347,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { DEBUG(dbgs() << "SLP: added a vector of stores.\n"); ValueList Operands; - for (unsigned j = 0; j < VL.size(); ++j) - Operands.push_back(cast<Instruction>(VL[j])->getOperand(0)); + for (Value *j : VL) + Operands.push_back(cast<Instruction>(j)->getOperand(0)); buildTree_rec(Operands, Depth + 1); return; @@ -1379,7 +1358,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { CallInst *CI = cast<CallInst>(VL[0]); // Check if this is an Intrinsic call or something that can be // represented by an intrinsic call - Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); if (!isTriviallyVectorizable(ID)) { BS.cancelScheduling(VL); newTreeEntry(VL, false); @@ -1393,7 +1372,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 1, e = VL.size(); i != e; ++i) { CallInst *CI2 = dyn_cast<CallInst>(VL[i]); if (!CI2 || CI2->getCalledFunction() != Int || - getIntrinsicIDForCall(CI2, TLI) != ID) { + getVectorIntrinsicIDForCall(CI2, TLI) != ID || + !CI->hasIdenticalOperandBundleSchema(*CI2)) { BS.cancelScheduling(VL); newTreeEntry(VL, false); DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i] @@ -1413,14 +1393,25 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { return; } } + // Verify that the bundle operands are identical between the two calls. + if (CI->hasOperandBundles() && + !std::equal(CI->op_begin() + CI->getBundleOperandsStartIndex(), + CI->op_begin() + CI->getBundleOperandsEndIndex(), + CI2->op_begin() + CI2->getBundleOperandsStartIndex())) { + BS.cancelScheduling(VL); + newTreeEntry(VL, false); + DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" << *CI << "!=" + << *VL[i] << '\n'); + return; + } } newTreeEntry(VL, true); for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { ValueList Operands; // Prepare the operand vector. - for (unsigned j = 0; j < VL.size(); ++j) { - CallInst *CI2 = dyn_cast<CallInst>(VL[j]); + for (Value *j : VL) { + CallInst *CI2 = dyn_cast<CallInst>(j); Operands.push_back(CI2->getArgOperand(i)); } buildTree_rec(Operands, Depth + 1); @@ -1451,8 +1442,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; // Prepare the operand vector. 
- for (unsigned j = 0; j < VL.size(); ++j) - Operands.push_back(cast<Instruction>(VL[j])->getOperand(i)); + for (Value *j : VL) + Operands.push_back(cast<Instruction>(j)->getOperand(i)); buildTree_rec(Operands, Depth + 1); } @@ -1466,6 +1457,74 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { } } +unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { + unsigned N; + Type *EltTy; + auto *ST = dyn_cast<StructType>(T); + if (ST) { + N = ST->getNumElements(); + EltTy = *ST->element_begin(); + } else { + N = cast<ArrayType>(T)->getNumElements(); + EltTy = cast<ArrayType>(T)->getElementType(); + } + if (!isValidElementType(EltTy)) + return 0; + uint64_t VTSize = DL.getTypeStoreSizeInBits(VectorType::get(EltTy, N)); + if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize || VTSize != DL.getTypeStoreSizeInBits(T)) + return 0; + if (ST) { + // Check that struct is homogeneous. + for (const auto *Ty : ST->elements()) + if (Ty != EltTy) + return 0; + } + return N; +} + +bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, unsigned Opcode) const { + assert(Opcode == Instruction::ExtractElement || + Opcode == Instruction::ExtractValue); + assert(Opcode == getSameOpcode(VL) && "Invalid opcode"); + // Check if all of the extracts come from the same vector and from the + // correct offset. + Value *VL0 = VL[0]; + Instruction *E0 = cast<Instruction>(VL0); + Value *Vec = E0->getOperand(0); + + // We have to extract from a vector/aggregate with the same number of elements. + unsigned NElts; + if (Opcode == Instruction::ExtractValue) { + const DataLayout &DL = E0->getModule()->getDataLayout(); + NElts = canMapToVector(Vec->getType(), DL); + if (!NElts) + return false; + // Check if load can be rewritten as load of vector. + LoadInst *LI = dyn_cast<LoadInst>(Vec); + if (!LI || !LI->isSimple() || !LI->hasNUses(VL.size())) + return false; + } else { + NElts = Vec->getType()->getVectorNumElements(); + } + + if (NElts != VL.size()) + return false; + + // Check that all of the indices extract from the correct offset. + if (!matchExtractIndex(E0, 0, Opcode)) + return false; + + for (unsigned i = 1, e = VL.size(); i < e; ++i) { + Instruction *E = cast<Instruction>(VL[i]); + if (!matchExtractIndex(E, i, Opcode)) + return false; + if (E->getOperand(0) != Vec) + return false; + } + + return true; +} + int BoUpSLP::getEntryCost(TreeEntry *E) { ArrayRef<Value*> VL = E->Scalars; @@ -1474,6 +1533,12 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { ScalarTy = SI->getValueOperand()->getType(); VectorType *VecTy = VectorType::get(ScalarTy, VL.size()); + // If we have computed a smaller type for the expression, update VecTy so + // that the costs will be accurate. + if (MinBWs.count(VL[0])) + VecTy = VectorType::get(IntegerType::get(F->getContext(), MinBWs[VL[0]]), + VL.size()); + if (E->NeedToGather) { if (allConstant(VL)) return 0; @@ -1489,11 +1554,12 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { case Instruction::PHI: { return 0; } + case Instruction::ExtractValue: case Instruction::ExtractElement: { - if (CanReuseExtract(VL)) { + if (canReuseExtract(VL, Opcode)) { int DeadCost = 0; for (unsigned i = 0, e = VL.size(); i < e; ++i) { - ExtractElementInst *E = cast<ExtractElementInst>(VL[i]); + Instruction *E = cast<Instruction>(VL[i]); if (E->hasOneUse()) // Take credit for instruction that will become dead. 
DeadCost += @@ -1527,7 +1593,14 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { } case Instruction::FCmp: case Instruction::ICmp: - case Instruction::Select: + case Instruction::Select: { + // Calculate the cost of this instruction. + VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size()); + int ScalarCost = VecTy->getNumElements() * + TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty()); + int VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy); + return VecCost - ScalarCost; + } case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -1546,59 +1619,48 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { case Instruction::And: case Instruction::Or: case Instruction::Xor: { - // Calculate the cost of this instruction. - int ScalarCost = 0; - int VecCost = 0; - if (Opcode == Instruction::FCmp || Opcode == Instruction::ICmp || - Opcode == Instruction::Select) { - VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size()); - ScalarCost = VecTy->getNumElements() * - TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty()); - VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy); - } else { - // Certain instructions can be cheaper to vectorize if they have a - // constant second vector operand. - TargetTransformInfo::OperandValueKind Op1VK = - TargetTransformInfo::OK_AnyValue; - TargetTransformInfo::OperandValueKind Op2VK = - TargetTransformInfo::OK_UniformConstantValue; - TargetTransformInfo::OperandValueProperties Op1VP = - TargetTransformInfo::OP_None; - TargetTransformInfo::OperandValueProperties Op2VP = - TargetTransformInfo::OP_None; - - // If all operands are exactly the same ConstantInt then set the - // operand kind to OK_UniformConstantValue. - // If instead not all operands are constants, then set the operand kind - // to OK_AnyValue. If all operands are constants but not the same, - // then set the operand kind to OK_NonUniformConstantValue. - ConstantInt *CInt = nullptr; - for (unsigned i = 0; i < VL.size(); ++i) { - const Instruction *I = cast<Instruction>(VL[i]); - if (!isa<ConstantInt>(I->getOperand(1))) { - Op2VK = TargetTransformInfo::OK_AnyValue; - break; - } - if (i == 0) { - CInt = cast<ConstantInt>(I->getOperand(1)); - continue; - } - if (Op2VK == TargetTransformInfo::OK_UniformConstantValue && - CInt != cast<ConstantInt>(I->getOperand(1))) - Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; + // Certain instructions can be cheaper to vectorize if they have a + // constant second vector operand. + TargetTransformInfo::OperandValueKind Op1VK = + TargetTransformInfo::OK_AnyValue; + TargetTransformInfo::OperandValueKind Op2VK = + TargetTransformInfo::OK_UniformConstantValue; + TargetTransformInfo::OperandValueProperties Op1VP = + TargetTransformInfo::OP_None; + TargetTransformInfo::OperandValueProperties Op2VP = + TargetTransformInfo::OP_None; + + // If all operands are exactly the same ConstantInt then set the + // operand kind to OK_UniformConstantValue. + // If instead not all operands are constants, then set the operand kind + // to OK_AnyValue. If all operands are constants but not the same, + // then set the operand kind to OK_NonUniformConstantValue. 
+ ConstantInt *CInt = nullptr; + for (unsigned i = 0; i < VL.size(); ++i) { + const Instruction *I = cast<Instruction>(VL[i]); + if (!isa<ConstantInt>(I->getOperand(1))) { + Op2VK = TargetTransformInfo::OK_AnyValue; + break; + } + if (i == 0) { + CInt = cast<ConstantInt>(I->getOperand(1)); + continue; } - // FIXME: Currently cost of model modification for division by - // power of 2 is handled only for X86. Add support for other targets. - if (Op2VK == TargetTransformInfo::OK_UniformConstantValue && CInt && - CInt->getValue().isPowerOf2()) - Op2VP = TargetTransformInfo::OP_PowerOf2; - - ScalarCost = VecTy->getNumElements() * - TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, Op2VK, - Op1VP, Op2VP); - VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy, Op1VK, Op2VK, - Op1VP, Op2VP); + if (Op2VK == TargetTransformInfo::OK_UniformConstantValue && + CInt != cast<ConstantInt>(I->getOperand(1))) + Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; } + // FIXME: Currently cost of model modification for division by power of + // 2 is handled for X86 and AArch64. Add support for other targets. + if (Op2VK == TargetTransformInfo::OK_UniformConstantValue && CInt && + CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_PowerOf2; + + int ScalarCost = VecTy->getNumElements() * + TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, + Op2VK, Op1VP, Op2VP); + int VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy, Op1VK, Op2VK, + Op1VP, Op2VP); return VecCost - ScalarCost; } case Instruction::GetElementPtr: { @@ -1617,21 +1679,25 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { } case Instruction::Load: { // Cost of wide load - cost of scalar loads. + unsigned alignment = dyn_cast<LoadInst>(VL0)->getAlignment(); int ScalarLdCost = VecTy->getNumElements() * - TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0); - int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, VecTy, 1, 0); + TTI->getMemoryOpCost(Instruction::Load, ScalarTy, alignment, 0); + int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, + VecTy, alignment, 0); return VecLdCost - ScalarLdCost; } case Instruction::Store: { // We know that we can merge the stores. Calculate the cost. + unsigned alignment = dyn_cast<StoreInst>(VL0)->getAlignment(); int ScalarStCost = VecTy->getNumElements() * - TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0); - int VecStCost = TTI->getMemoryOpCost(Instruction::Store, VecTy, 1, 0); + TTI->getMemoryOpCost(Instruction::Store, ScalarTy, alignment, 0); + int VecStCost = TTI->getMemoryOpCost(Instruction::Store, + VecTy, alignment, 0); return VecStCost - ScalarStCost; } case Instruction::Call: { CallInst *CI = cast<CallInst>(VL0); - Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); // Calculate the cost of the scalar and vector calls. 
SmallVector<Type*, 4> ScalarTys, VecTys; @@ -1641,10 +1707,14 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { VecTy->getNumElements())); } + FastMathFlags FMF; + if (auto *FPMO = dyn_cast<FPMathOperator>(CI)) + FMF = FPMO->getFastMathFlags(); + int ScalarCallCost = VecTy->getNumElements() * - TTI->getIntrinsicInstrCost(ID, ScalarTy, ScalarTys); + TTI->getIntrinsicInstrCost(ID, ScalarTy, ScalarTys, FMF); - int VecCallCost = TTI->getIntrinsicInstrCost(ID, VecTy, VecTys); + int VecCallCost = TTI->getIntrinsicInstrCost(ID, VecTy, VecTys, FMF); DEBUG(dbgs() << "SLP: Call cost "<< VecCallCost - ScalarCallCost << " (" << VecCallCost << "-" << ScalarCallCost << ")" @@ -1659,8 +1729,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { TargetTransformInfo::OK_AnyValue; int ScalarCost = 0; int VecCost = 0; - for (unsigned i = 0; i < VL.size(); ++i) { - Instruction *I = cast<Instruction>(VL[i]); + for (Value *i : VL) { + Instruction *I = cast<Instruction>(i); if (!I) break; ScalarCost += @@ -1715,8 +1785,8 @@ int BoUpSLP::getSpillCost() { SmallPtrSet<Instruction*, 4> LiveValues; Instruction *PrevInst = nullptr; - for (unsigned N = 0; N < VectorizableTree.size(); ++N) { - Instruction *Inst = dyn_cast<Instruction>(VectorizableTree[N].Scalars[0]); + for (const auto &N : VectorizableTree) { + Instruction *Inst = dyn_cast<Instruction>(N.Scalars[0]); if (!Inst) continue; @@ -1725,6 +1795,13 @@ int BoUpSLP::getSpillCost() { continue; } + // Update LiveValues. + LiveValues.erase(PrevInst); + for (auto &J : PrevInst->operands()) { + if (isa<Instruction>(&*J) && ScalarToTreeEntry.count(&*J)) + LiveValues.insert(cast<Instruction>(&*J)); + } + DEBUG( dbgs() << "SLP: #LV: " << LiveValues.size(); for (auto *X : LiveValues) @@ -1733,13 +1810,6 @@ int BoUpSLP::getSpillCost() { Inst->dump(); ); - // Update LiveValues. - LiveValues.erase(PrevInst); - for (auto &J : PrevInst->operands()) { - if (isa<Instruction>(&*J) && ScalarToTreeEntry.count(&*J)) - LiveValues.insert(cast<Instruction>(&*J)); - } - // Now find the sequence of instructions between PrevInst and Inst. BasicBlock::reverse_iterator InstIt(Inst->getIterator()), PrevInstIt(PrevInst->getIterator()); @@ -1763,7 +1833,6 @@ int BoUpSLP::getSpillCost() { PrevInst = Inst; } - DEBUG(dbgs() << "SLP: SpillCost=" << Cost << "\n"); return Cost; } @@ -1785,7 +1854,7 @@ int BoUpSLP::getTreeCost() { for (TreeEntry &TE : VectorizableTree) { int C = getEntryCost(&TE); DEBUG(dbgs() << "SLP: Adding cost " << C << " for bundle that starts with " - << TE.Scalars[0] << " .\n"); + << *TE.Scalars[0] << ".\n"); Cost += C; } @@ -1802,15 +1871,29 @@ int BoUpSLP::getTreeCost() { if (EphValues.count(EU.User)) continue; - VectorType *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); - ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, - EU.Lane); + // If we plan to rewrite the tree in a smaller type, we will need to sign + // extend the extracted value back to the original type. Here, we account + // for the extract and the added cost of the sign extend if needed. 
+ auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); + auto *ScalarRoot = VectorizableTree[0].Scalars[0]; + if (MinBWs.count(ScalarRoot)) { + auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]); + VecTy = VectorType::get(MinTy, BundleWidth); + ExtractCost += TTI->getExtractWithExtendCost( + Instruction::SExt, EU.Scalar->getType(), VecTy, EU.Lane); + } else { + ExtractCost += + TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane); + } } - Cost += getSpillCost(); + int SpillCost = getSpillCost(); + Cost += SpillCost + ExtractCost; - DEBUG(dbgs() << "SLP: Total Cost " << Cost + ExtractCost<< ".\n"); - return Cost + ExtractCost; + DEBUG(dbgs() << "SLP: Spill Cost = " << SpillCost << ".\n" + << "SLP: Extract Cost = " << ExtractCost << ".\n" + << "SLP: Total Cost = " << Cost << ".\n"); + return Cost; } int BoUpSLP::getGatherCost(Type *Ty) { @@ -1830,63 +1913,6 @@ int BoUpSLP::getGatherCost(ArrayRef<Value *> VL) { return getGatherCost(VecTy); } -Value *BoUpSLP::getPointerOperand(Value *I) { - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - return LI->getPointerOperand(); - if (StoreInst *SI = dyn_cast<StoreInst>(I)) - return SI->getPointerOperand(); - return nullptr; -} - -unsigned BoUpSLP::getAddressSpaceOperand(Value *I) { - if (LoadInst *L = dyn_cast<LoadInst>(I)) - return L->getPointerAddressSpace(); - if (StoreInst *S = dyn_cast<StoreInst>(I)) - return S->getPointerAddressSpace(); - return -1; -} - -bool BoUpSLP::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL) { - Value *PtrA = getPointerOperand(A); - Value *PtrB = getPointerOperand(B); - unsigned ASA = getAddressSpaceOperand(A); - unsigned ASB = getAddressSpaceOperand(B); - - // Check that the address spaces match and that the pointers are valid. - if (!PtrA || !PtrB || (ASA != ASB)) - return false; - - // Make sure that A and B are different pointers of the same type. - if (PtrA == PtrB || PtrA->getType() != PtrB->getType()) - return false; - - unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA); - Type *Ty = cast<PointerType>(PtrA->getType())->getElementType(); - APInt Size(PtrBitWidth, DL.getTypeStoreSize(Ty)); - - APInt OffsetA(PtrBitWidth, 0), OffsetB(PtrBitWidth, 0); - PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); - PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); - - APInt OffsetDelta = OffsetB - OffsetA; - - // Check if they are based on the same pointer. That makes the offsets - // sufficient. - if (PtrA == PtrB) - return OffsetDelta == Size; - - // Compute the necessary base pointer delta to have the necessary final delta - // equal to the size. - APInt BaseDelta = Size - OffsetDelta; - - // Otherwise compute the distance with SCEV between the base pointers. - const SCEV *PtrSCEVA = SE->getSCEV(PtrA); - const SCEV *PtrSCEVB = SE->getSCEV(PtrB); - const SCEV *C = SE->getConstant(BaseDelta); - const SCEV *X = SE->getAddExpr(PtrSCEVA, C); - return X == PtrSCEVB; -} - // Reorder commutative operations in alternate shuffle if the resulting vectors // are consecutive loads. This would allow us to vectorize the tree. 
// If we have something like- @@ -1899,12 +1925,10 @@ bool BoUpSLP::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL) { void BoUpSLP::reorderAltShuffleOperands(ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, SmallVectorImpl<Value *> &Right) { - const DataLayout &DL = F->getParent()->getDataLayout(); - // Push left and right operands of binary operation into Left and Right - for (unsigned i = 0, e = VL.size(); i < e; ++i) { - Left.push_back(cast<Instruction>(VL[i])->getOperand(0)); - Right.push_back(cast<Instruction>(VL[i])->getOperand(1)); + for (Value *i : VL) { + Left.push_back(cast<Instruction>(i)->getOperand(0)); + Right.push_back(cast<Instruction>(i)->getOperand(1)); } // Reorder if we have a commutative operation and consecutive access @@ -1914,10 +1938,11 @@ void BoUpSLP::reorderAltShuffleOperands(ArrayRef<Value *> VL, if (LoadInst *L1 = dyn_cast<LoadInst>(Right[j + 1])) { Instruction *VL1 = cast<Instruction>(VL[j]); Instruction *VL2 = cast<Instruction>(VL[j + 1]); - if (isConsecutiveAccess(L, L1, DL) && VL1->isCommutative()) { + if (VL1->isCommutative() && isConsecutiveAccess(L, L1, *DL, *SE)) { std::swap(Left[j], Right[j]); continue; - } else if (isConsecutiveAccess(L, L1, DL) && VL2->isCommutative()) { + } else if (VL2->isCommutative() && + isConsecutiveAccess(L, L1, *DL, *SE)) { std::swap(Left[j + 1], Right[j + 1]); continue; } @@ -1928,10 +1953,11 @@ void BoUpSLP::reorderAltShuffleOperands(ArrayRef<Value *> VL, if (LoadInst *L1 = dyn_cast<LoadInst>(Left[j + 1])) { Instruction *VL1 = cast<Instruction>(VL[j]); Instruction *VL2 = cast<Instruction>(VL[j + 1]); - if (isConsecutiveAccess(L, L1, DL) && VL1->isCommutative()) { + if (VL1->isCommutative() && isConsecutiveAccess(L, L1, *DL, *SE)) { std::swap(Left[j], Right[j]); continue; - } else if (isConsecutiveAccess(L, L1, DL) && VL2->isCommutative()) { + } else if (VL2->isCommutative() && + isConsecutiveAccess(L, L1, *DL, *SE)) { std::swap(Left[j + 1], Right[j + 1]); continue; } @@ -2061,8 +2087,6 @@ void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, if (SplatRight || SplatLeft) return; - const DataLayout &DL = F->getParent()->getDataLayout(); - // Finally check if we can get longer vectorizable chain by reordering // without breaking the good operand order detected above. // E.g. 
If we have something like- @@ -2081,7 +2105,7 @@ void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, for (unsigned j = 0; j < VL.size() - 1; ++j) { if (LoadInst *L = dyn_cast<LoadInst>(Left[j])) { if (LoadInst *L1 = dyn_cast<LoadInst>(Right[j + 1])) { - if (isConsecutiveAccess(L, L1, DL)) { + if (isConsecutiveAccess(L, L1, *DL, *SE)) { std::swap(Left[j + 1], Right[j + 1]); continue; } @@ -2089,7 +2113,7 @@ void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, } if (LoadInst *L = dyn_cast<LoadInst>(Right[j])) { if (LoadInst *L1 = dyn_cast<LoadInst>(Left[j + 1])) { - if (isConsecutiveAccess(L, L1, DL)) { + if (isConsecutiveAccess(L, L1, *DL, *SE)) { std::swap(Left[j + 1], Right[j + 1]); continue; } @@ -2185,7 +2209,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return Gather(E->Scalars, VecTy); } - const DataLayout &DL = F->getParent()->getDataLayout(); unsigned Opcode = getSameOpcode(E->Scalars); switch (Opcode) { @@ -2225,13 +2248,25 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } case Instruction::ExtractElement: { - if (CanReuseExtract(E->Scalars)) { + if (canReuseExtract(E->Scalars, Instruction::ExtractElement)) { Value *V = VL0->getOperand(0); E->VectorizedValue = V; return V; } return Gather(E->Scalars, VecTy); } + case Instruction::ExtractValue: { + if (canReuseExtract(E->Scalars, Instruction::ExtractValue)) { + LoadInst *LI = cast<LoadInst>(VL0->getOperand(0)); + Builder.SetInsertPoint(LI); + PointerType *PtrTy = PointerType::get(VecTy, LI->getPointerAddressSpace()); + Value *Ptr = Builder.CreateBitCast(LI->getOperand(0), PtrTy); + LoadInst *V = Builder.CreateAlignedLoad(Ptr, LI->getAlignment()); + E->VectorizedValue = V; + return propagateMetadata(V, E->Scalars); + } + return Gather(E->Scalars, VecTy); + } case Instruction::ZExt: case Instruction::SExt: case Instruction::FPToUI: @@ -2382,7 +2417,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { unsigned Alignment = LI->getAlignment(); LI = Builder.CreateLoad(VecPtr); if (!Alignment) { - Alignment = DL.getABITypeAlignment(ScalarLoadTy); + Alignment = DL->getABITypeAlignment(ScalarLoadTy); } LI->setAlignment(Alignment); E->VectorizedValue = LI; @@ -2413,7 +2448,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ExternalUser(SI->getPointerOperand(), cast<User>(VecPtr), 0)); if (!Alignment) { - Alignment = DL.getABITypeAlignment(SI->getValueOperand()->getType()); + Alignment = DL->getABITypeAlignment(SI->getValueOperand()->getType()); } S->setAlignment(Alignment); E->VectorizedValue = S; @@ -2481,10 +2516,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Module *M = F->getParent(); - Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); Type *Tys[] = { VectorType::get(CI->getType(), E->Scalars.size()) }; Function *CF = Intrinsic::getDeclaration(M, ID, Tys); - Value *V = Builder.CreateCall(CF, OpVecs); + SmallVector<OperandBundleDef, 1> OpBundles; + CI->getOperandBundlesAsDefs(OpBundles); + Value *V = Builder.CreateCall(CF, OpVecs, OpBundles); // The scalar argument uses an in-tree scalar so we add the new vectorized // call to ExternalUses list to make sure that an extract will be @@ -2559,15 +2596,28 @@ Value *BoUpSLP::vectorizeTree() { } Builder.SetInsertPoint(&F->getEntryBlock().front()); - vectorizeTree(&VectorizableTree[0]); + auto *VectorRoot = vectorizeTree(&VectorizableTree[0]); + + // If the vectorized tree can be rewritten in a smaller type, we truncate the + // vectorized root. InstCombine will then rewrite the entire expression. 
We + // sign extend the extracted values below. + auto *ScalarRoot = VectorizableTree[0].Scalars[0]; + if (MinBWs.count(ScalarRoot)) { + if (auto *I = dyn_cast<Instruction>(VectorRoot)) + Builder.SetInsertPoint(&*++BasicBlock::iterator(I)); + auto BundleWidth = VectorizableTree[0].Scalars.size(); + auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]); + auto *VecTy = VectorType::get(MinTy, BundleWidth); + auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy); + VectorizableTree[0].VectorizedValue = Trunc; + } DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n"); // Extract all of the elements with the external uses. - for (UserList::iterator it = ExternalUses.begin(), e = ExternalUses.end(); - it != e; ++it) { - Value *Scalar = it->Scalar; - llvm::User *User = it->User; + for (const auto &ExternalUse : ExternalUses) { + Value *Scalar = ExternalUse.Scalar; + llvm::User *User = ExternalUse.User; // Skip users that we already RAUW. This happens when one instruction // has multiple uses of the same value. @@ -2583,15 +2633,24 @@ Value *BoUpSLP::vectorizeTree() { Value *Vec = E->VectorizedValue; assert(Vec && "Can't find vectorizable value"); - Value *Lane = Builder.getInt32(it->Lane); + Value *Lane = Builder.getInt32(ExternalUse.Lane); // Generate extracts for out-of-tree users. // Find the insertion point for the extractelement lane. - if (isa<Instruction>(Vec)){ + if (auto *VecI = dyn_cast<Instruction>(Vec)) { if (PHINode *PH = dyn_cast<PHINode>(User)) { for (int i = 0, e = PH->getNumIncomingValues(); i != e; ++i) { if (PH->getIncomingValue(i) == Scalar) { - Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator()); + TerminatorInst *IncomingTerminator = + PH->getIncomingBlock(i)->getTerminator(); + if (isa<CatchSwitchInst>(IncomingTerminator)) { + Builder.SetInsertPoint(VecI->getParent(), + std::next(VecI->getIterator())); + } else { + Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator()); + } Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MinBWs.count(ScalarRoot)) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); CSEBlocks.insert(PH->getIncomingBlock(i)); PH->setOperand(i, Ex); } @@ -2599,12 +2658,16 @@ Value *BoUpSLP::vectorizeTree() { } else { Builder.SetInsertPoint(cast<Instruction>(User)); Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MinBWs.count(ScalarRoot)) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); CSEBlocks.insert(cast<Instruction>(User)->getParent()); User->replaceUsesOfWith(Scalar, Ex); } } else { Builder.SetInsertPoint(&F->getEntryBlock().front()); Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MinBWs.count(ScalarRoot)) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); CSEBlocks.insert(&F->getEntryBlock()); User->replaceUsesOfWith(Scalar, Ex); } @@ -2613,8 +2676,8 @@ Value *BoUpSLP::vectorizeTree() { } // For each vectorized value: - for (int EIdx = 0, EE = VectorizableTree.size(); EIdx < EE; ++EIdx) { - TreeEntry *Entry = &VectorizableTree[EIdx]; + for (TreeEntry &EIdx : VectorizableTree) { + TreeEntry *Entry = &EIdx; // For each lane: for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) { @@ -2655,9 +2718,8 @@ void BoUpSLP::optimizeGatherSequence() { DEBUG(dbgs() << "SLP: Optimizing " << GatherSeq.size() << " gather sequences instructions.\n"); // LICM InsertElementInst sequences. 
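// A rough standalone model (not the LLVM code) of the hoisting test applied
// to the gather (insertelement) sequences below: an instruction inside the
// loop whose operands are all defined outside the loop can be moved to the
// preheader. Inst and the block ids are hypothetical stand-ins.
#include <unordered_set>
#include <vector>

struct Inst {
  int DefBlock;                 // block the instruction currently lives in
  std::vector<int> OperandDefs; // blocks defining its operands
};

static bool canHoistToPreheader(const Inst &I,
                                const std::unordered_set<int> &LoopBlocks) {
  if (!LoopBlocks.count(I.DefBlock))
    return false; // not inside the loop, nothing to hoist
  for (int B : I.OperandDefs)
    if (LoopBlocks.count(B))
      return false; // an operand is computed inside the loop
  return true;
}

int main() {
  std::unordered_set<int> Loop{2, 3};
  Inst Gather{2, {0, 1}};  // in the loop, operands defined before it
  Inst Dependent{2, {3}};  // uses a value computed inside the loop
  return (canHoistToPreheader(Gather, Loop) &&
          !canHoistToPreheader(Dependent, Loop)) ? 0 : 1;
}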
- for (SetVector<Instruction *>::iterator it = GatherSeq.begin(), - e = GatherSeq.end(); it != e; ++it) { - InsertElementInst *Insert = dyn_cast<InsertElementInst>(*it); + for (Instruction *it : GatherSeq) { + InsertElementInst *Insert = dyn_cast<InsertElementInst>(it); if (!Insert) continue; @@ -2718,12 +2780,10 @@ void BoUpSLP::optimizeGatherSequence() { // Check if we can replace this instruction with any of the // visited instructions. - for (SmallVectorImpl<Instruction *>::iterator v = Visited.begin(), - ve = Visited.end(); - v != ve; ++v) { - if (In->isIdenticalTo(*v) && - DT->dominates((*v)->getParent(), In->getParent())) { - In->replaceAllUsesWith(*v); + for (Instruction *v : Visited) { + if (In->isIdenticalTo(v) && + DT->dominates(v->getParent(), In->getParent())) { + In->replaceAllUsesWith(v); eraseInstruction(In); In = nullptr; break; @@ -3139,90 +3199,265 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { BS->ScheduleStart = nullptr; } -/// The SLPVectorizer Pass. -struct SLPVectorizer : public FunctionPass { - typedef SmallVector<StoreInst *, 8> StoreList; - typedef MapVector<Value *, StoreList> StoreListMap; +unsigned BoUpSLP::getVectorElementSize(Value *V) { + // If V is a store, just return the width of the stored value without + // traversing the expression tree. This is the common case. + if (auto *Store = dyn_cast<StoreInst>(V)) + return DL->getTypeSizeInBits(Store->getValueOperand()->getType()); + + // If V is not a store, we can traverse the expression tree to find loads + // that feed it. The type of the loaded value may indicate a more suitable + // width than V's type. We want to base the vector element size on the width + // of memory operations where possible. + SmallVector<Instruction *, 16> Worklist; + SmallPtrSet<Instruction *, 16> Visited; + if (auto *I = dyn_cast<Instruction>(V)) + Worklist.push_back(I); + + // Traverse the expression tree in bottom-up order looking for loads. If we + // encounter an instruciton we don't yet handle, we give up. + auto MaxWidth = 0u; + auto FoundUnknownInst = false; + while (!Worklist.empty() && !FoundUnknownInst) { + auto *I = Worklist.pop_back_val(); + Visited.insert(I); + + // We should only be looking at scalar instructions here. If the current + // instruction has a vector type, give up. + auto *Ty = I->getType(); + if (isa<VectorType>(Ty)) + FoundUnknownInst = true; + + // If the current instruction is a load, update MaxWidth to reflect the + // width of the loaded value. + else if (isa<LoadInst>(I)) + MaxWidth = std::max<unsigned>(MaxWidth, DL->getTypeSizeInBits(Ty)); + + // Otherwise, we need to visit the operands of the instruction. We only + // handle the interesting cases from buildTree here. If an operand is an + // instruction we haven't yet visited, we add it to the worklist. + else if (isa<PHINode>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) || + isa<CmpInst>(I) || isa<SelectInst>(I) || isa<BinaryOperator>(I)) { + for (Use &U : I->operands()) + if (auto *J = dyn_cast<Instruction>(U.get())) + if (!Visited.count(J)) + Worklist.push_back(J); + } - /// Pass identification, replacement for typeid - static char ID; + // If we don't yet handle the instruction, give up. + else + FoundUnknownInst = true; + } - explicit SLPVectorizer() : FunctionPass(ID) { - initializeSLPVectorizerPass(*PassRegistry::getPassRegistry()); + // If we didn't encounter a memory access in the expression tree, or if we + // gave up for some reason, just return the width of V. 
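// A standalone sketch (not the LLVM code) of the bottom-up walk that
// getVectorElementSize() performs above: follow operands through a small set
// of handled instructions, record the widest load seen, and fall back to the
// value's own width if the walk hits something unhandled. Node and its
// fields are hypothetical stand-ins for IR instructions.
#include <algorithm>
#include <unordered_set>
#include <vector>

struct Node {
  enum Kind { Load, Binary, Unknown } K;
  unsigned WidthInBits;
  std::vector<Node *> Operands;
};

static unsigned maxLoadWidth(Node *Root, unsigned Fallback) {
  std::vector<Node *> Worklist{Root};
  std::unordered_set<Node *> Visited;
  unsigned MaxWidth = 0;
  bool FoundUnknown = false;
  while (!Worklist.empty() && !FoundUnknown) {
    Node *N = Worklist.back();
    Worklist.pop_back();
    Visited.insert(N);
    if (N->K == Node::Load)
      MaxWidth = std::max(MaxWidth, N->WidthInBits);
    else if (N->K == Node::Binary) {
      for (Node *Op : N->Operands)
        if (!Visited.count(Op))
          Worklist.push_back(Op);
    } else
      FoundUnknown = true; // give up on anything the walk does not model
  }
  return (!MaxWidth || FoundUnknown) ? Fallback : MaxWidth;
}

int main() {
  Node L1{Node::Load, 8, {}}, L2{Node::Load, 16, {}};
  Node Add{Node::Binary, 32, {&L1, &L2}};
  return maxLoadWidth(&Add, /*Fallback=*/32) == 16 ? 0 : 1;
}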
+ if (!MaxWidth || FoundUnknownInst) + return DL->getTypeSizeInBits(V->getType()); + + // Otherwise, return the maximum width we found. + return MaxWidth; +} + +// Determine if a value V in a vectorizable expression Expr can be demoted to a +// smaller type with a truncation. We collect the values that will be demoted +// in ToDemote and additional roots that require investigating in Roots. +static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr, + SmallVectorImpl<Value *> &ToDemote, + SmallVectorImpl<Value *> &Roots) { + + // We can always demote constants. + if (isa<Constant>(V)) { + ToDemote.push_back(V); + return true; } - ScalarEvolution *SE; - TargetTransformInfo *TTI; - TargetLibraryInfo *TLI; - AliasAnalysis *AA; - LoopInfo *LI; - DominatorTree *DT; - AssumptionCache *AC; + // If the value is not an instruction in the expression with only one use, it + // cannot be demoted. + auto *I = dyn_cast<Instruction>(V); + if (!I || !I->hasOneUse() || !Expr.count(I)) + return false; - bool runOnFunction(Function &F) override { - if (skipOptnoneFunction(F)) + switch (I->getOpcode()) { + + // We can always demote truncations and extensions. Since truncations can + // seed additional demotion, we save the truncated value. + case Instruction::Trunc: + Roots.push_back(I->getOperand(0)); + case Instruction::ZExt: + case Instruction::SExt: + break; + + // We can demote certain binary operations if we can demote both of their + // operands. + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + if (!collectValuesToDemote(I->getOperand(0), Expr, ToDemote, Roots) || + !collectValuesToDemote(I->getOperand(1), Expr, ToDemote, Roots)) return false; + break; - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - TLI = TLIP ? &TLIP->getTLI() : nullptr; - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - - StoreRefs.clear(); - bool Changed = false; - - // If the target claims to have no vector registers don't attempt - // vectorization. - if (!TTI->getNumberOfRegisters(true)) + // We can demote selects if we can demote their true and false values. + case Instruction::Select: { + SelectInst *SI = cast<SelectInst>(I); + if (!collectValuesToDemote(SI->getTrueValue(), Expr, ToDemote, Roots) || + !collectValuesToDemote(SI->getFalseValue(), Expr, ToDemote, Roots)) return false; + break; + } - // Use the vector register size specified by the target unless overridden - // by a command-line option. - // TODO: It would be better to limit the vectorization factor based on - // data type rather than just register size. For example, x86 AVX has - // 256-bit registers, but it does not support integer operations - // at that width (that requires AVX2). - if (MaxVectorRegSizeOption.getNumOccurrences()) - MaxVecRegSize = MaxVectorRegSizeOption; - else - MaxVecRegSize = TTI->getRegisterBitWidth(true); + // We can demote phis if we can demote all their incoming operands. Note that + // we don't need to worry about cycles since we ensure single use above. 
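// A toy version (not the LLVM code) of the recursion in
// collectValuesToDemote(): constants are always demotable, add-like nodes
// are demotable only when all operands are, and anything unmodelled blocks
// demotion. Expr and its kinds are hypothetical stand-ins.
#include <vector>

struct Expr {
  enum Kind { Constant, AddLike, Other } K;
  std::vector<Expr *> Operands;
};

static bool canDemote(const Expr *E, std::vector<const Expr *> &ToDemote) {
  switch (E->K) {
  case Expr::Constant:
    break;
  case Expr::AddLike:
    for (const Expr *Op : E->Operands)
      if (!canDemote(Op, ToDemote))
        return false;
    break;
  case Expr::Other:
    return false; // conservatively give up
  }
  ToDemote.push_back(E); // record the value that can be demoted
  return true;
}

int main() {
  Expr C1{Expr::Constant, {}}, C2{Expr::Constant, {}};
  Expr Sum{Expr::AddLike, {&C1, &C2}};
  std::vector<const Expr *> ToDemote;
  return (canDemote(&Sum, ToDemote) && ToDemote.size() == 3) ? 0 : 1;
}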
+ case Instruction::PHI: { + PHINode *PN = cast<PHINode>(I); + for (Value *IncValue : PN->incoming_values()) + if (!collectValuesToDemote(IncValue, Expr, ToDemote, Roots)) + return false; + break; + } - // Don't vectorize when the attribute NoImplicitFloat is used. - if (F.hasFnAttribute(Attribute::NoImplicitFloat)) - return false; + // Otherwise, conservatively give up. + default: + return false; + } - DEBUG(dbgs() << "SLP: Analyzing blocks in " << F.getName() << ".\n"); + // Record the value that we can demote. + ToDemote.push_back(V); + return true; +} - // Use the bottom up slp vectorizer to construct chains that start with - // store instructions. - BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC); +void BoUpSLP::computeMinimumValueSizes() { + // If there are no external uses, the expression tree must be rooted by a + // store. We can't demote in-memory values, so there is nothing to do here. + if (ExternalUses.empty()) + return; - // A general note: the vectorizer must use BoUpSLP::eraseInstruction() to - // delete instructions. + // We only attempt to truncate integer expressions. + auto &TreeRoot = VectorizableTree[0].Scalars; + auto *TreeRootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType()); + if (!TreeRootIT) + return; - // Scan the blocks in the function in post order. - for (auto BB : post_order(&F.getEntryBlock())) { - // Vectorize trees that end at stores. - if (unsigned count = collectStores(BB, R)) { - (void)count; - DEBUG(dbgs() << "SLP: Found " << count << " stores to vectorize.\n"); - Changed |= vectorizeStoreChains(R); - } + // If the expression is not rooted by a store, these roots should have + // external uses. We will rely on InstCombine to rewrite the expression in + // the narrower type. However, InstCombine only rewrites single-use values. + // This means that if a tree entry other than a root is used externally, it + // must have multiple uses and InstCombine will not rewrite it. The code + // below ensures that only the roots are used externally. + SmallPtrSet<Value *, 32> Expr(TreeRoot.begin(), TreeRoot.end()); + for (auto &EU : ExternalUses) + if (!Expr.erase(EU.Scalar)) + return; + if (!Expr.empty()) + return; - // Vectorize trees that end at reductions. - Changed |= vectorizeChainsInBlock(BB, R); - } + // Collect the scalar values of the vectorizable expression. We will use this + // context to determine which values can be demoted. If we see a truncation, + // we mark it as seeding another demotion. + for (auto &Entry : VectorizableTree) + Expr.insert(Entry.Scalars.begin(), Entry.Scalars.end()); - if (Changed) { - R.optimizeGatherSequence(); - DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n"); - DEBUG(verifyFunction(F)); + // Ensure the roots of the vectorizable tree don't form a cycle. They must + // have a single external user that is not in the vectorizable tree. + for (auto *Root : TreeRoot) + if (!Root->hasOneUse() || Expr.count(*Root->user_begin())) + return; + + // Conservatively determine if we can actually truncate the roots of the + // expression. Collect the values that can be demoted in ToDemote and + // additional roots that require investigating in Roots. + SmallVector<Value *, 32> ToDemote; + SmallVector<Value *, 4> Roots; + for (auto *Root : TreeRoot) + if (!collectValuesToDemote(Root, Expr, ToDemote, Roots)) + return; + + // The maximum bit width required to represent all the values that can be + // demoted without loss of precision. It would be safe to truncate the roots + // of the expression to this width. 
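// The width arithmetic that follows, sketched on a single scalar (not the
// LLVM code): the smallest signed width that holds a value losslessly,
// rounded up to a power of two the way NextPowerOf2() rounds MaxBitWidth.
#include <cstdint>

static unsigned minSignedBits(int32_t V) {
  uint32_t U = V < 0 ? ~static_cast<uint32_t>(V) : static_cast<uint32_t>(V);
  unsigned Bits = 0;
  while (U) { ++Bits; U >>= 1; } // position of the highest magnitude bit
  return Bits + 1;               // plus one bit for the sign
}

static unsigned nextPow2(unsigned N) {
  unsigned P = 1;
  while (P < N) P <<= 1;
  return P;
}

int main() {
  // 100 fits in 8 signed bits; -129 needs 9 and is rounded up to 16.
  return (nextPow2(minSignedBits(100)) == 8 &&
          nextPow2(minSignedBits(-129)) == 16) ? 0 : 1;
}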
+ auto MaxBitWidth = 8u; + + // We first check if all the bits of the roots are demanded. If they're not, + // we can truncate the roots to this narrower type. + for (auto *Root : TreeRoot) { + auto Mask = DB->getDemandedBits(cast<Instruction>(Root)); + MaxBitWidth = std::max<unsigned>( + Mask.getBitWidth() - Mask.countLeadingZeros(), MaxBitWidth); + } + + // If all the bits of the roots are demanded, we can try a little harder to + // compute a narrower type. This can happen, for example, if the roots are + // getelementptr indices. InstCombine promotes these indices to the pointer + // width. Thus, all their bits are technically demanded even though the + // address computation might be vectorized in a smaller type. + // + // We start by looking at each entry that can be demoted. We compute the + // maximum bit width required to store the scalar by using ValueTracking to + // compute the number of high-order bits we can truncate. + if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType())) { + MaxBitWidth = 8u; + for (auto *Scalar : ToDemote) { + auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, 0, DT); + auto NumTypeBits = DL->getTypeSizeInBits(Scalar->getType()); + MaxBitWidth = std::max<unsigned>(NumTypeBits - NumSignBits, MaxBitWidth); } - return Changed; + } + + // Round MaxBitWidth up to the next power-of-two. + if (!isPowerOf2_64(MaxBitWidth)) + MaxBitWidth = NextPowerOf2(MaxBitWidth); + + // If the maximum bit width we compute is less than the with of the roots' + // type, we can proceed with the narrowing. Otherwise, do nothing. + if (MaxBitWidth >= TreeRootIT->getBitWidth()) + return; + + // If we can truncate the root, we must collect additional values that might + // be demoted as a result. That is, those seeded by truncations we will + // modify. + while (!Roots.empty()) + collectValuesToDemote(Roots.pop_back_val(), Expr, ToDemote, Roots); + + // Finally, map the values we can demote to the maximum bit with we computed. + for (auto *Scalar : ToDemote) + MinBWs[Scalar] = MaxBitWidth; +} + +namespace { +/// The SLPVectorizer Pass. +struct SLPVectorizer : public FunctionPass { + SLPVectorizerPass Impl; + + /// Pass identification, replacement for typeid + static char ID; + + explicit SLPVectorizer() : FunctionPass(ID) { + initializeSLPVectorizerPass(*PassRegistry::getPassRegistry()); + } + + + bool doInitialization(Module &M) override { + return false; + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + auto *TLI = TLIP ? 
&TLIP->getTLI() : nullptr; + auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); + + return Impl.runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -3233,51 +3468,105 @@ struct SLPVectorizer : public FunctionPass { AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<DemandedBitsWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.setPreservesCFG(); } +}; +} // end anonymous namespace -private: +PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) { + auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); + auto *TTI = &AM.getResult<TargetIRAnalysis>(F); + auto *TLI = AM.getCachedResult<TargetLibraryAnalysis>(F); + auto *AA = &AM.getResult<AAManager>(F); + auto *LI = &AM.getResult<LoopAnalysis>(F); + auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); + auto *AC = &AM.getResult<AssumptionAnalysis>(F); + auto *DB = &AM.getResult<DemandedBitsAnalysis>(F); + + bool Changed = runImpl(F, SE, TTI, TLI, AA, LI, DT, AC, DB); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<LoopAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<AAManager>(); + PA.preserve<GlobalsAA>(); + return PA; +} - /// \brief Collect memory references and sort them according to their base - /// object. We sort the stores to their base objects to reduce the cost of the - /// quadratic search on the stores. TODO: We can further reduce this cost - /// if we flush the chain creation every time we run into a memory barrier. - unsigned collectStores(BasicBlock *BB, BoUpSLP &R); +bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, + TargetTransformInfo *TTI_, + TargetLibraryInfo *TLI_, AliasAnalysis *AA_, + LoopInfo *LI_, DominatorTree *DT_, + AssumptionCache *AC_, DemandedBits *DB_) { + SE = SE_; + TTI = TTI_; + TLI = TLI_; + AA = AA_; + LI = LI_; + DT = DT_; + AC = AC_; + DB = DB_; + DL = &F.getParent()->getDataLayout(); + + Stores.clear(); + GEPs.clear(); + bool Changed = false; - /// \brief Try to vectorize a chain that starts at two arithmetic instrs. - bool tryToVectorizePair(Value *A, Value *B, BoUpSLP &R); + // If the target claims to have no vector registers don't attempt + // vectorization. + if (!TTI->getNumberOfRegisters(true)) + return false; - /// \brief Try to vectorize a list of operands. - /// \@param BuildVector A list of users to ignore for the purpose of - /// scheduling and that don't need extracting. - /// \returns true if a value was vectorized. - bool tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, - ArrayRef<Value *> BuildVector = None, - bool allowReorder = false); + // Don't vectorize when the attribute NoImplicitFloat is used. 
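// A minimal sketch (not the LLVM code) of the structure used above: one
// shared runImpl() with a thin legacy-pass entry point that gathers its own
// analyses and a new-pass-manager entry point that receives them. Analyses
// and Impl are hypothetical stand-ins for the analysis results and pass
// state.
struct Analyses { unsigned VectorRegisterBits; };

struct Impl {
  // The single implementation both entry points share.
  bool runImpl(const Analyses &A) { return A.VectorRegisterBits >= 128; }
};

// Legacy-style entry point: fetches its own inputs, then delegates.
static bool legacyRun(Impl &I) { return I.runImpl(Analyses{256}); }

// Managed entry point: inputs are handed in, then it delegates the same way.
static bool managedRun(Impl &I, const Analyses &A) { return I.runImpl(A); }

int main() {
  Impl I;
  return (legacyRun(I) && managedRun(I, Analyses{256})) ? 0 : 1;
}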
+ if (F.hasFnAttribute(Attribute::NoImplicitFloat)) + return false; - /// \brief Try to vectorize a chain that may start at the operands of \V; - bool tryToVectorize(BinaryOperator *V, BoUpSLP &R); + DEBUG(dbgs() << "SLP: Analyzing blocks in " << F.getName() << ".\n"); - /// \brief Vectorize the stores that were collected in StoreRefs. - bool vectorizeStoreChains(BoUpSLP &R); + // Use the bottom up slp vectorizer to construct chains that start with + // store instructions. + BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB, DL); - /// \brief Scan the basic block and look for patterns that are likely to start - /// a vectorization chain. - bool vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R); + // A general note: the vectorizer must use BoUpSLP::eraseInstruction() to + // delete instructions. - bool vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold, - BoUpSLP &R, unsigned VecRegSize); + // Scan the blocks in the function in post order. + for (auto BB : post_order(&F.getEntryBlock())) { + collectSeedInstructions(BB); - bool vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold, - BoUpSLP &R); -private: - StoreListMap StoreRefs; - unsigned MaxVecRegSize; // This is set by TTI or overridden by cl::opt. -}; + // Vectorize trees that end at stores. + if (!Stores.empty()) { + DEBUG(dbgs() << "SLP: Found stores for " << Stores.size() + << " underlying objects.\n"); + Changed |= vectorizeStoreChains(R); + } + + // Vectorize trees that end at reductions. + Changed |= vectorizeChainsInBlock(BB, R); + + // Vectorize the index computations of getelementptr instructions. This + // is primarily intended to catch gather-like idioms ending at + // non-consecutive loads. + if (!GEPs.empty()) { + DEBUG(dbgs() << "SLP: Found GEPs for " << GEPs.size() + << " underlying objects.\n"); + Changed |= vectorizeGEPIndices(BB, R); + } + } + + if (Changed) { + R.optimizeGatherSequence(); + DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n"); + DEBUG(verifyFunction(F)); + } + return Changed; +} /// \brief Check that the Values in the slice in VL array are still existent in /// the WeakVH array. 
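// A standalone sketch (not the LLVM code) of the post-order block walk that
// drives the loop above: successors are visited before the block itself, so
// seeds are collected and vectorized bottom-up through the CFG. Block is a
// hypothetical stand-in for BasicBlock.
#include <unordered_set>
#include <vector>

struct Block {
  int Id;
  std::vector<Block *> Succs;
};

static void postOrder(Block *B, std::unordered_set<Block *> &Seen,
                      std::vector<Block *> &Out) {
  if (!Seen.insert(B).second)
    return;
  for (Block *S : B->Succs)
    postOrder(S, Seen, Out);
  Out.push_back(B); // successors first, then the block itself
}

int main() {
  Block Exit{2, {}}, Body{1, {&Exit}}, Entry{0, {&Body, &Exit}};
  std::unordered_set<Block *> Seen;
  std::vector<Block *> Order;
  postOrder(&Entry, Seen, Order);
  // Visitation order for seeding/vectorization: Exit, Body, Entry.
  return (Order.size() == 3 && Order.back() == &Entry) ? 0 : 1;
}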
@@ -3290,15 +3579,13 @@ static bool hasValueBeenRAUWed(ArrayRef<Value *> VL, ArrayRef<WeakVH> VH, return !std::equal(VL.begin(), VL.end(), VH.begin()); } -bool SLPVectorizer::vectorizeStoreChain(ArrayRef<Value *> Chain, - int CostThreshold, BoUpSLP &R, - unsigned VecRegSize) { +bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, + int CostThreshold, BoUpSLP &R, + unsigned VecRegSize) { unsigned ChainLen = Chain.size(); DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen << "\n"); - Type *StoreTy = cast<StoreInst>(Chain[0])->getValueOperand()->getType(); - auto &DL = cast<StoreInst>(Chain[0])->getModule()->getDataLayout(); - unsigned Sz = DL.getTypeSizeInBits(StoreTy); + unsigned Sz = R.getVectorElementSize(Chain[0]); unsigned VF = VecRegSize / Sz; if (!isPowerOf2_32(Sz) || VF < 2) @@ -3322,6 +3609,7 @@ bool SLPVectorizer::vectorizeStoreChain(ArrayRef<Value *> Chain, ArrayRef<Value *> Operands = Chain.slice(i, VF); R.buildTree(Operands); + R.computeMinimumValueSizes(); int Cost = R.getTreeCost(); @@ -3339,8 +3627,8 @@ bool SLPVectorizer::vectorizeStoreChain(ArrayRef<Value *> Chain, return Changed; } -bool SLPVectorizer::vectorizeStores(ArrayRef<StoreInst *> Stores, - int costThreshold, BoUpSLP &R) { +bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, + int costThreshold, BoUpSLP &R) { SetVector<StoreInst *> Heads, Tails; SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain; @@ -3353,7 +3641,6 @@ bool SLPVectorizer::vectorizeStores(ArrayRef<StoreInst *> Stores, // all of the pairs of stores that follow each other. SmallVector<unsigned, 16> IndexQueue; for (unsigned i = 0, e = Stores.size(); i < e; ++i) { - const DataLayout &DL = Stores[i]->getModule()->getDataLayout(); IndexQueue.clear(); // If a store has multiple consecutive store candidates, search Stores // array according to the sequence: from i+1 to e, then from i-1 to 0. @@ -3366,7 +3653,7 @@ bool SLPVectorizer::vectorizeStores(ArrayRef<StoreInst *> Stores, IndexQueue.push_back(j - 1); for (auto &k : IndexQueue) { - if (R.isConsecutiveAccess(Stores[i], Stores[k], DL)) { + if (isConsecutiveAccess(Stores[i], Stores[k], *DL, *SE)) { Tails.insert(Stores[k]); Heads.insert(Stores[i]); ConsecutiveChain[Stores[i]] = Stores[k]; @@ -3396,7 +3683,7 @@ bool SLPVectorizer::vectorizeStores(ArrayRef<StoreInst *> Stores, // FIXME: Is division-by-2 the correct step? Should we assert that the // register size is a power-of-2? - for (unsigned Size = MaxVecRegSize; Size >= MinVecRegSize; Size /= 2) { + for (unsigned Size = R.getMaxVecRegSize(); Size >= R.getMinVecRegSize(); Size /= 2) { if (vectorizeStoreChain(Operands, costThreshold, R, Size)) { // Mark the vectorized stores so that we don't vectorize them again. VectorizedStores.insert(Operands.begin(), Operands.end()); @@ -3409,45 +3696,53 @@ bool SLPVectorizer::vectorizeStores(ArrayRef<StoreInst *> Stores, return Changed; } +void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) { -unsigned SLPVectorizer::collectStores(BasicBlock *BB, BoUpSLP &R) { - unsigned count = 0; - StoreRefs.clear(); - const DataLayout &DL = BB->getModule()->getDataLayout(); - for (Instruction &I : *BB) { - StoreInst *SI = dyn_cast<StoreInst>(&I); - if (!SI) - continue; - - // Don't touch volatile stores. - if (!SI->isSimple()) - continue; + // Initialize the collections. We will make a single pass over the block. + Stores.clear(); + GEPs.clear(); - // Check that the pointer points to scalars. 
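// A small sketch (not the LLVM code) of how the register-size loop above
// produces candidate vectorization factors: try the widest register first,
// halve the size each round, and derive VF from the element width returned
// by getVectorElementSize().
#include <vector>

static std::vector<unsigned> candidateVFs(unsigned MaxRegBits,
                                          unsigned MinRegBits,
                                          unsigned ElemBits) {
  std::vector<unsigned> VFs;
  for (unsigned Size = MaxRegBits; Size >= MinRegBits; Size /= 2) {
    unsigned VF = Size / ElemBits;
    if (VF >= 2) // fewer than two lanes is not worth vectorizing
      VFs.push_back(VF);
  }
  return VFs;
}

int main() {
  // 256-bit down to 128-bit registers with 32-bit elements: VF 8, then 4.
  auto VFs = candidateVFs(256, 128, 32);
  return (VFs.size() == 2 && VFs[0] == 8 && VFs[1] == 4) ? 0 : 1;
}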
- Type *Ty = SI->getValueOperand()->getType(); - if (!isValidElementType(Ty)) - continue; + // Visit the store and getelementptr instructions in BB and organize them in + // Stores and GEPs according to the underlying objects of their pointer + // operands. + for (Instruction &I : *BB) { - // Find the base pointer. - Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), DL); + // Ignore store instructions that are volatile or have a pointer operand + // that doesn't point to a scalar type. + if (auto *SI = dyn_cast<StoreInst>(&I)) { + if (!SI->isSimple()) + continue; + if (!isValidElementType(SI->getValueOperand()->getType())) + continue; + Stores[GetUnderlyingObject(SI->getPointerOperand(), *DL)].push_back(SI); + } - // Save the store locations. - StoreRefs[Ptr].push_back(SI); - count++; + // Ignore getelementptr instructions that have more than one index, a + // constant index, or a pointer operand that doesn't point to a scalar + // type. + else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { + auto Idx = GEP->idx_begin()->get(); + if (GEP->getNumIndices() > 1 || isa<Constant>(Idx)) + continue; + if (!isValidElementType(Idx->getType())) + continue; + if (GEP->getType()->isVectorTy()) + continue; + GEPs[GetUnderlyingObject(GEP->getPointerOperand(), *DL)].push_back(GEP); + } } - return count; } -bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) { +bool SLPVectorizerPass::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) { if (!A || !B) return false; Value *VL[] = { A, B }; return tryToVectorizeList(VL, R, None, true); } -bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, - ArrayRef<Value *> BuildVector, - bool allowReorder) { +bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, + ArrayRef<Value *> BuildVector, + bool allowReorder) { if (VL.size() < 2) return false; @@ -3459,13 +3754,11 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, return false; unsigned Opcode0 = I0->getOpcode(); - const DataLayout &DL = I0->getModule()->getDataLayout(); - Type *Ty0 = I0->getType(); - unsigned Sz = DL.getTypeSizeInBits(Ty0); // FIXME: Register size should be a parameter to this function, so we can // try different vectorization factors. 
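// A toy model (not the LLVM code) of the bucketing that
// collectSeedInstructions() performs above: accesses are grouped by the
// underlying object of their pointer so later passes only compare accesses
// within one group. Access and the string keys are hypothetical stand-ins.
#include <map>
#include <string>
#include <vector>

struct Access { std::string UnderlyingObject; int Offset; };

static std::map<std::string, std::vector<Access>>
groupByObject(const std::vector<Access> &Accesses) {
  std::map<std::string, std::vector<Access>> Groups;
  for (const Access &A : Accesses)
    Groups[A.UnderlyingObject].push_back(A);
  return Groups;
}

int main() {
  auto Groups = groupByObject({{"a", 0}, {"b", 4}, {"a", 4}});
  // Both accesses to "a" land in one bucket, the access to "b" in another.
  return (Groups["a"].size() == 2 && Groups["b"].size() == 1) ? 0 : 1;
}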
- unsigned VF = MinVecRegSize / Sz; + unsigned Sz = R.getVectorElementSize(I0); + unsigned VF = R.getMinVecRegSize() / Sz; for (Value *V : VL) { Type *Ty = V->getType(); @@ -3513,6 +3806,7 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, Value *ReorderedOps[] = { Ops[1], Ops[0] }; R.buildTree(ReorderedOps, None); } + R.computeMinimumValueSizes(); int Cost = R.getTreeCost(); if (Cost < -SLPCostThreshold) { @@ -3529,15 +3823,16 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, Instruction *InsertAfter = cast<Instruction>(BuildVectorSlice.back()); unsigned VecIdx = 0; for (auto &V : BuildVectorSlice) { - IRBuilder<true, NoFolder> Builder( - InsertAfter->getParent(), ++BasicBlock::iterator(InsertAfter)); - InsertElementInst *IE = cast<InsertElementInst>(V); + IRBuilder<NoFolder> Builder(InsertAfter->getParent(), + ++BasicBlock::iterator(InsertAfter)); + Instruction *I = cast<Instruction>(V); + assert(isa<InsertElementInst>(I) || isa<InsertValueInst>(I)); Instruction *Extract = cast<Instruction>(Builder.CreateExtractElement( VectorizedRoot, Builder.getInt32(VecIdx++))); - IE->setOperand(1, Extract); - IE->removeFromParent(); - IE->insertAfter(Extract); - InsertAfter = IE; + I->setOperand(1, Extract); + I->removeFromParent(); + I->insertAfter(Extract); + InsertAfter = I; } } // Move to the next bundle. @@ -3549,7 +3844,7 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, return Changed; } -bool SLPVectorizer::tryToVectorize(BinaryOperator *V, BoUpSLP &R) { +bool SLPVectorizerPass::tryToVectorize(BinaryOperator *V, BoUpSLP &R) { if (!V) return false; @@ -3662,9 +3957,14 @@ public: /// The width of one full horizontal reduction operation. unsigned ReduxWidth; - HorizontalReduction() - : ReductionRoot(nullptr), ReductionPHI(nullptr), ReductionOpcode(0), - ReducedValueOpcode(0), IsPairwiseReduction(false), ReduxWidth(0) {} + /// Minimal width of available vector registers. It's used to determine + /// ReduxWidth. + unsigned MinVecRegSize; + + HorizontalReduction(unsigned MinVecRegSize) + : ReductionRoot(nullptr), ReductionPHI(nullptr), ReductionOpcode(0), + ReducedValueOpcode(0), IsPairwiseReduction(false), ReduxWidth(0), + MinVecRegSize(MinVecRegSize) {} /// \brief Try to find a reduction tree. bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) { @@ -3779,6 +4079,7 @@ public: for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) { V.buildTree(makeArrayRef(&ReducedVals[i], ReduxWidth), ReductionOps); + V.computeMinimumValueSizes(); // Estimate cost. int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]); @@ -3928,6 +4229,25 @@ static bool findBuildVector(InsertElementInst *FirstInsertElem, return false; } +/// \brief Like findBuildVector, but looks backwards for construction of aggregate. +/// +/// \return true if it matches. 
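// A scalar sketch (not the LLVM code) of the pairwise reduction shape that
// HorizontalReduction above tries to match: adjacent lanes are combined
// until one value remains, halving the live width each round. Assumes a
// power-of-two number of lanes.
#include <vector>

static int pairwiseReduce(std::vector<int> Lanes) {
  while (Lanes.size() > 1) {
    std::vector<int> Next;
    for (size_t I = 0; I + 1 < Lanes.size(); I += 2)
      Next.push_back(Lanes[I] + Lanes[I + 1]); // combine adjacent lanes
    Lanes = std::move(Next);
  }
  return Lanes.front();
}

int main() {
  // ((1+2) + (3+4)) == 10, computed in log2(width) rounds.
  return pairwiseReduce({1, 2, 3, 4}) == 10 ? 0 : 1;
}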
+static bool findBuildAggregate(InsertValueInst *IV, + SmallVectorImpl<Value *> &BuildVector, + SmallVectorImpl<Value *> &BuildVectorOpds) { + if (!IV->hasOneUse()) + return false; + Value *V = IV->getAggregateOperand(); + if (!isa<UndefValue>(V)) { + InsertValueInst *I = dyn_cast<InsertValueInst>(V); + if (!I || !findBuildAggregate(I, BuildVector, BuildVectorOpds)) + return false; + } + BuildVector.push_back(IV); + BuildVectorOpds.push_back(IV->getInsertedValueOperand()); + return true; +} + static bool PhiTypeSorterFunc(Value *V, Value *V2) { return V->getType() < V2->getType(); } @@ -3991,11 +4311,12 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P, /// \returns true if a horizontal reduction was matched and reduced. /// \returns false if a horizontal reduction was not matched. static bool canMatchHorizontalReduction(PHINode *P, BinaryOperator *BI, - BoUpSLP &R, TargetTransformInfo *TTI) { + BoUpSLP &R, TargetTransformInfo *TTI, + unsigned MinRegSize) { if (!ShouldVectorizeHor) return false; - HorizontalReduction HorRdx; + HorizontalReduction HorRdx(MinRegSize); if (!HorRdx.matchAssociativeReduction(P, BI)) return false; @@ -4008,7 +4329,7 @@ static bool canMatchHorizontalReduction(PHINode *P, BinaryOperator *BI, return HorRdx.tryToReduce(R, TTI); } -bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { +bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { bool Changed = false; SmallVector<Value *, 4> Incoming; SmallSet<Value *, 16> VisitedInstrs; @@ -4083,7 +4404,7 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { continue; // Try to match and vectorize a horizontal reduction. - if (canMatchHorizontalReduction(P, BI, R, TTI)) { + if (canMatchHorizontalReduction(P, BI, R, TTI, R.getMinVecRegSize())) { Changed = true; it = BB->begin(); e = BB->end(); @@ -4110,7 +4431,8 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { if (StoreInst *SI = dyn_cast<StoreInst>(it)) if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(SI->getValueOperand())) { - if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI) || + if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI, + R.getMinVecRegSize()) || tryToVectorize(BinOp, R)) { Changed = true; it = BB->begin(); @@ -4178,16 +4500,121 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { continue; } + + // Try to vectorize trees that start at insertvalue instructions feeding into + // a store. + if (StoreInst *SI = dyn_cast<StoreInst>(it)) { + if (InsertValueInst *LastInsertValue = dyn_cast<InsertValueInst>(SI->getValueOperand())) { + const DataLayout &DL = BB->getModule()->getDataLayout(); + if (R.canMapToVector(SI->getValueOperand()->getType(), DL)) { + SmallVector<Value *, 16> BuildVector; + SmallVector<Value *, 16> BuildVectorOpds; + if (!findBuildAggregate(LastInsertValue, BuildVector, BuildVectorOpds)) + continue; + + DEBUG(dbgs() << "SLP: store of array mappable to vector: " << *SI << "\n"); + if (tryToVectorizeList(BuildVectorOpds, R, BuildVector, false)) { + Changed = true; + it = BB->begin(); + e = BB->end(); + } + continue; + } + } + } } return Changed; } -bool SLPVectorizer::vectorizeStoreChains(BoUpSLP &R) { +bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) { + auto Changed = false; + for (auto &Entry : GEPs) { + + // If the getelementptr list has fewer than two elements, there's nothing + // to do. 
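// A toy model (not the LLVM code) of the backwards walk in
// findBuildAggregate() above: each "insert" adds one element on top of the
// previous aggregate, and recursing to the bottom before recording operands
// recovers them in construction order. Insert is a hypothetical stand-in
// for InsertValueInst.
#include <vector>

struct Insert {
  Insert *Prev; // earlier insert, or nullptr for the initial undef value
  int Operand;  // the value inserted at this step
};

static bool collectBuild(const Insert *I, std::vector<int> &Operands) {
  if (I->Prev && !collectBuild(I->Prev, Operands))
    return false;
  Operands.push_back(I->Operand);
  return true;
}

int main() {
  Insert I0{nullptr, 10}, I1{&I0, 20}, I2{&I1, 30};
  std::vector<int> Ops;
  // Starting from the last insert recovers {10, 20, 30}.
  return (collectBuild(&I2, Ops) && Ops.size() == 3 && Ops[0] == 10) ? 0 : 1;
}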
+ if (Entry.second.size() < 2) + continue; + + DEBUG(dbgs() << "SLP: Analyzing a getelementptr list of length " + << Entry.second.size() << ".\n"); + + // We process the getelementptr list in chunks of 16 (like we do for + // stores) to minimize compile-time. + for (unsigned BI = 0, BE = Entry.second.size(); BI < BE; BI += 16) { + auto Len = std::min<unsigned>(BE - BI, 16); + auto GEPList = makeArrayRef(&Entry.second[BI], Len); + + // Initialize a set a candidate getelementptrs. Note that we use a + // SetVector here to preserve program order. If the index computations + // are vectorizable and begin with loads, we want to minimize the chance + // of having to reorder them later. + SetVector<Value *> Candidates(GEPList.begin(), GEPList.end()); + + // Some of the candidates may have already been vectorized after we + // initially collected them. If so, the WeakVHs will have nullified the + // values, so remove them from the set of candidates. + Candidates.remove(nullptr); + + // Remove from the set of candidates all pairs of getelementptrs with + // constant differences. Such getelementptrs are likely not good + // candidates for vectorization in a bottom-up phase since one can be + // computed from the other. We also ensure all candidate getelementptr + // indices are unique. + for (int I = 0, E = GEPList.size(); I < E && Candidates.size() > 1; ++I) { + auto *GEPI = cast<GetElementPtrInst>(GEPList[I]); + if (!Candidates.count(GEPI)) + continue; + auto *SCEVI = SE->getSCEV(GEPList[I]); + for (int J = I + 1; J < E && Candidates.size() > 1; ++J) { + auto *GEPJ = cast<GetElementPtrInst>(GEPList[J]); + auto *SCEVJ = SE->getSCEV(GEPList[J]); + if (isa<SCEVConstant>(SE->getMinusSCEV(SCEVI, SCEVJ))) { + Candidates.remove(GEPList[I]); + Candidates.remove(GEPList[J]); + } else if (GEPI->idx_begin()->get() == GEPJ->idx_begin()->get()) { + Candidates.remove(GEPList[J]); + } + } + } + + // We break out of the above computation as soon as we know there are + // fewer than two candidates remaining. + if (Candidates.size() < 2) + continue; + + // Add the single, non-constant index of each candidate to the bundle. We + // ensured the indices met these constraints when we originally collected + // the getelementptrs. + SmallVector<Value *, 16> Bundle(Candidates.size()); + auto BundleIndex = 0u; + for (auto *V : Candidates) { + auto *GEP = cast<GetElementPtrInst>(V); + auto *GEPIdx = GEP->idx_begin()->get(); + assert(GEP->getNumIndices() == 1 || !isa<Constant>(GEPIdx)); + Bundle[BundleIndex++] = GEPIdx; + } + + // Try and vectorize the indices. We are currently only interested in + // gather-like cases of the form: + // + // ... = g[a[0] - b[0]] + g[a[1] - b[1]] + ... + // + // where the loads of "a", the loads of "b", and the subtractions can be + // performed in parallel. It's likely that detecting this pattern in a + // bottom-up phase will be simpler and less costly than building a + // full-blown top-down phase beginning at the consecutive loads. + Changed |= tryToVectorizeList(Bundle, R); + } + } + return Changed; +} + +bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { bool Changed = false; // Attempt to sort and vectorize each of the store-groups. 
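// A toy model (not the LLVM code) of the candidate pruning in
// vectorizeGEPIndices() above: an index is a symbolic base plus a constant
// offset, two indices with the same base differ by a constant and are both
// dropped, and whatever survives forms the bundle handed to
// tryToVectorizeList(). Index and the string bases are hypothetical.
#include <set>
#include <string>
#include <utility>
#include <vector>

using Index = std::pair<std::string, int>; // {symbolic base, constant offset}

static std::set<Index> pruneConstantDiffs(const std::vector<Index> &Indices) {
  std::set<Index> Candidates(Indices.begin(), Indices.end());
  for (size_t I = 0; I < Indices.size() && Candidates.size() > 1; ++I)
    for (size_t J = I + 1; J < Indices.size() && Candidates.size() > 1; ++J)
      if (Indices[I].first == Indices[J].first) { // difference is a constant
        Candidates.erase(Indices[I]);
        Candidates.erase(Indices[J]);
      }
  return Candidates;
}

int main() {
  // a+0 and a+1 differ by a constant and are pruned; b+0 and c+0 survive.
  auto Kept = pruneConstantDiffs({{"a", 0}, {"a", 1}, {"b", 0}, {"c", 0}});
  return Kept.size() == 2 ? 0 : 1;
}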
- for (StoreListMap::iterator it = StoreRefs.begin(), e = StoreRefs.end(); - it != e; ++it) { + for (StoreListMap::iterator it = Stores.begin(), e = Stores.end(); it != e; + ++it) { if (it->second.size() < 2) continue; @@ -4207,8 +4634,6 @@ bool SLPVectorizer::vectorizeStoreChains(BoUpSLP &R) { return Changed; } -} // end anonymous namespace - char SLPVectorizer::ID = 0; static const char lv_name[] = "SLP Vectorizer"; INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false) @@ -4217,6 +4642,7 @@ INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false) namespace llvm { diff --git a/lib/Transforms/Vectorize/Vectorize.cpp b/lib/Transforms/Vectorize/Vectorize.cpp index 6e002fd5d5db..28e0b2eb9866 100644 --- a/lib/Transforms/Vectorize/Vectorize.cpp +++ b/lib/Transforms/Vectorize/Vectorize.cpp @@ -29,6 +29,7 @@ void llvm::initializeVectorization(PassRegistry &Registry) { initializeBBVectorizePass(Registry); initializeLoopVectorizePass(Registry); initializeSLPVectorizerPass(Registry); + initializeLoadStoreVectorizerPass(Registry); } void LLVMInitializeVectorization(LLVMPassRegistryRef R) { |