| Field | Value | Date |
|---|---|---|
| author | Dimitry Andric <dim@FreeBSD.org> | 2023-09-02 21:17:18 +0000 |
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-12-08 17:34:50 +0000 |
| commit | 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e (patch) | |
| tree | 62f873df87c7c675557a179e0c4c83fe9f3087bc /contrib/llvm-project/llvm/lib/Analysis | |
| parent | cf037972ea8863e2bab7461d77345367d2c1e054 (diff) | |
| parent | 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff) | |
| download | src-06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e.tar.gz, src-06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e.zip | |
Merge llvm-project main llvmorg-17-init-19304-gd0b54bb50e51
This updates llvm, clang, compiler-rt, libc++, libunwind, lld, lldb and
openmp to llvm-project main llvmorg-17-init-19304-gd0b54bb50e51, the
last commit before the upstream release/17.x branch was created.
PR: 273753
MFC after: 1 month
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis')
69 files changed, 6044 insertions, 6623 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysis.cpp index 9e24f6b87bdb..7b2f91f5392a 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysis.cpp @@ -227,12 +227,12 @@ ModRefInfo AAResults::getModRefInfo(const CallBase *Call, // We can completely ignore inaccessible memory here, because MemoryLocations // can only reference accessible memory. auto ME = getMemoryEffects(Call, AAQI) - .getWithoutLoc(MemoryEffects::InaccessibleMem); + .getWithoutLoc(IRMemLocation::InaccessibleMem); if (ME.doesNotAccessMemory()) return ModRefInfo::NoModRef; - ModRefInfo ArgMR = ME.getModRef(MemoryEffects::ArgMem); - ModRefInfo OtherMR = ME.getWithoutLoc(MemoryEffects::ArgMem).getModRef(); + ModRefInfo ArgMR = ME.getModRef(IRMemLocation::ArgMem); + ModRefInfo OtherMR = ME.getWithoutLoc(IRMemLocation::ArgMem).getModRef(); if ((ArgMR | OtherMR) != OtherMR) { // Refine the modref info for argument memory. We only bother to do this // if ArgMR is not a subset of OtherMR, otherwise this won't have an impact @@ -442,15 +442,15 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, ModRefInfo MR) { } raw_ostream &llvm::operator<<(raw_ostream &OS, MemoryEffects ME) { - for (MemoryEffects::Location Loc : MemoryEffects::locations()) { + for (IRMemLocation Loc : MemoryEffects::locations()) { switch (Loc) { - case MemoryEffects::ArgMem: + case IRMemLocation::ArgMem: OS << "ArgMem: "; break; - case MemoryEffects::InaccessibleMem: + case IRMemLocation::InaccessibleMem: OS << "InaccessibleMem: "; break; - case MemoryEffects::Other: + case IRMemLocation::Other: OS << "Other: "; break; } @@ -768,10 +768,6 @@ INITIALIZE_PASS_DEPENDENCY(TypeBasedAAWrapperPass) INITIALIZE_PASS_END(AAResultsWrapperPass, "aa", "Function Alias Analysis Results", false, true) -FunctionPass *llvm::createAAResultsWrapperPass() { - return new AAResultsWrapperPass(); -} - /// Run the wrapper pass to rebuild an aggregation over known AA passes. /// /// This is the legacy pass manager's interface to the new-style AA results @@ -840,29 +836,6 @@ AAManager::Result AAManager::run(Function &F, FunctionAnalysisManager &AM) { return R; } -AAResults llvm::createLegacyPMAAResults(Pass &P, Function &F, - BasicAAResult &BAR) { - AAResults AAR(P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F)); - - // Add in our explicitly constructed BasicAA results. - if (!DisableBasicAA) - AAR.addAAResult(BAR); - - // Populate the results with the other currently available AAs. 
- if (auto *WrapperPass = - P.getAnalysisIfAvailable<ScopedNoAliasAAWrapperPass>()) - AAR.addAAResult(WrapperPass->getResult()); - if (auto *WrapperPass = P.getAnalysisIfAvailable<TypeBasedAAWrapperPass>()) - AAR.addAAResult(WrapperPass->getResult()); - if (auto *WrapperPass = P.getAnalysisIfAvailable<GlobalsAAWrapperPass>()) - AAR.addAAResult(WrapperPass->getResult()); - if (auto *WrapperPass = P.getAnalysisIfAvailable<ExternalAAWrapperPass>()) - if (WrapperPass->CB) - WrapperPass->CB(P, F, AAR); - - return AAR; -} - bool llvm::isNoAliasCall(const Value *V) { if (const auto *Call = dyn_cast<CallBase>(V)) return Call->hasRetAttr(Attribute::NoAlias); @@ -935,14 +908,3 @@ bool llvm::isNotVisibleOnUnwind(const Value *Object, return false; } - -void llvm::getAAResultsAnalysisUsage(AnalysisUsage &AU) { - // This function needs to be in sync with llvm::createLegacyPMAAResults -- if - // more alias analyses are added to llvm::createLegacyPMAAResults, they need - // to be added here also. - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addUsedIfAvailable<ScopedNoAliasAAWrapperPass>(); - AU.addUsedIfAvailable<TypeBasedAAWrapperPass>(); - AU.addUsedIfAvailable<GlobalsAAWrapperPass>(); - AU.addUsedIfAvailable<ExternalAAWrapperPass>(); -} diff --git a/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysisSummary.cpp b/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysisSummary.cpp deleted file mode 100644 index a91791c0b4d5..000000000000 --- a/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysisSummary.cpp +++ /dev/null @@ -1,104 +0,0 @@ -#include "AliasAnalysisSummary.h" -#include "llvm/IR/Argument.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/Compiler.h" - -namespace llvm { -namespace cflaa { - -namespace { -const unsigned AttrEscapedIndex = 0; -const unsigned AttrUnknownIndex = 1; -const unsigned AttrGlobalIndex = 2; -const unsigned AttrCallerIndex = 3; -const unsigned AttrFirstArgIndex = 4; -const unsigned AttrLastArgIndex = NumAliasAttrs; -const unsigned AttrMaxNumArgs = AttrLastArgIndex - AttrFirstArgIndex; - -// It would be *slightly* prettier if we changed these to AliasAttrs, but it -// seems that both GCC and MSVC emit dynamic initializers for const bitsets. -using AliasAttr = unsigned; -const AliasAttr AttrNone = 0; -const AliasAttr AttrEscaped = 1 << AttrEscapedIndex; -const AliasAttr AttrUnknown = 1 << AttrUnknownIndex; -const AliasAttr AttrGlobal = 1 << AttrGlobalIndex; -const AliasAttr AttrCaller = 1 << AttrCallerIndex; -const AliasAttr ExternalAttrMask = AttrEscaped | AttrUnknown | AttrGlobal; -} - -AliasAttrs getAttrNone() { return AttrNone; } - -AliasAttrs getAttrUnknown() { return AttrUnknown; } -bool hasUnknownAttr(AliasAttrs Attr) { return Attr.test(AttrUnknownIndex); } - -AliasAttrs getAttrCaller() { return AttrCaller; } -bool hasCallerAttr(AliasAttrs Attr) { return Attr.test(AttrCaller); } -bool hasUnknownOrCallerAttr(AliasAttrs Attr) { - return Attr.test(AttrUnknownIndex) || Attr.test(AttrCallerIndex); -} - -AliasAttrs getAttrEscaped() { return AttrEscaped; } -bool hasEscapedAttr(AliasAttrs Attr) { return Attr.test(AttrEscapedIndex); } - -static AliasAttr argNumberToAttr(unsigned ArgNum) { - if (ArgNum >= AttrMaxNumArgs) - return AttrUnknown; - // N.B. MSVC complains if we use `1U` here, since AliasAttr' ctor takes - // an unsigned long long. 
- return AliasAttr(1ULL << (ArgNum + AttrFirstArgIndex)); -} - -AliasAttrs getGlobalOrArgAttrFromValue(const Value &Val) { - if (isa<GlobalValue>(Val)) - return AttrGlobal; - - if (auto *Arg = dyn_cast<Argument>(&Val)) - // Only pointer arguments should have the argument attribute, - // because things can't escape through scalars without us seeing a - // cast, and thus, interaction with them doesn't matter. - if (!Arg->hasNoAliasAttr() && Arg->getType()->isPointerTy()) - return argNumberToAttr(Arg->getArgNo()); - return AttrNone; -} - -bool isGlobalOrArgAttr(AliasAttrs Attr) { - return Attr.reset(AttrEscapedIndex) - .reset(AttrUnknownIndex) - .reset(AttrCallerIndex) - .any(); -} - -AliasAttrs getExternallyVisibleAttrs(AliasAttrs Attr) { - return Attr & AliasAttrs(ExternalAttrMask); -} - -std::optional<InstantiatedValue> -instantiateInterfaceValue(InterfaceValue IValue, CallBase &Call) { - auto Index = IValue.Index; - auto *V = (Index == 0) ? &Call : Call.getArgOperand(Index - 1); - if (V->getType()->isPointerTy()) - return InstantiatedValue{V, IValue.DerefLevel}; - return std::nullopt; -} - -std::optional<InstantiatedRelation> -instantiateExternalRelation(ExternalRelation ERelation, CallBase &Call) { - auto From = instantiateInterfaceValue(ERelation.From, Call); - if (!From) - return std::nullopt; - auto To = instantiateInterfaceValue(ERelation.To, Call); - if (!To) - return std::nullopt; - return InstantiatedRelation{*From, *To, ERelation.Offset}; -} - -std::optional<InstantiatedAttr> -instantiateExternalAttribute(ExternalAttribute EAttr, CallBase &Call) { - auto Value = instantiateInterfaceValue(EAttr.IValue, Call); - if (!Value) - return std::nullopt; - return InstantiatedAttr{*Value, EAttr.Attr}; -} -} -} diff --git a/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysisSummary.h b/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysisSummary.h deleted file mode 100644 index ab337bad22c7..000000000000 --- a/contrib/llvm-project/llvm/lib/Analysis/AliasAnalysisSummary.h +++ /dev/null @@ -1,268 +0,0 @@ -//=====- CFLSummary.h - Abstract stratified sets implementation. --------=====// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// \file -/// This file defines various utility types and functions useful to -/// summary-based alias analysis. -/// -/// Summary-based analysis, also known as bottom-up analysis, is a style of -/// interprocedrual static analysis that tries to analyze the callees before the -/// callers get analyzed. The key idea of summary-based analysis is to first -/// process each function independently, outline its behavior in a condensed -/// summary, and then instantiate the summary at the callsite when the said -/// function is called elsewhere. This is often in contrast to another style -/// called top-down analysis, in which callers are always analyzed first before -/// the callees. -/// -/// In a summary-based analysis, functions must be examined independently and -/// out-of-context. We have no information on the state of the memory, the -/// arguments, the global values, and anything else external to the function. To -/// carry out the analysis conservative assumptions have to be made about those -/// external states. 
In exchange for the potential loss of precision, the -/// summary we obtain this way is highly reusable, which makes the analysis -/// easier to scale to large programs even if carried out context-sensitively. -/// -/// Currently, all CFL-based alias analyses adopt the summary-based approach -/// and therefore heavily rely on this header. -/// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H -#define LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H - -#include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/SmallVector.h" -#include <bitset> -#include <optional> - -namespace llvm { - -class CallBase; -class Value; - -namespace cflaa { - -//===----------------------------------------------------------------------===// -// AliasAttr related stuffs -//===----------------------------------------------------------------------===// - -/// The number of attributes that AliasAttr should contain. Attributes are -/// described below, and 32 was an arbitrary choice because it fits nicely in 32 -/// bits (because we use a bitset for AliasAttr). -static const unsigned NumAliasAttrs = 32; - -/// These are attributes that an alias analysis can use to mark certain special -/// properties of a given pointer. Refer to the related functions below to see -/// what kinds of attributes are currently defined. -typedef std::bitset<NumAliasAttrs> AliasAttrs; - -/// Attr represent whether the said pointer comes from an unknown source -/// (such as opaque memory or an integer cast). -AliasAttrs getAttrNone(); - -/// AttrUnknown represent whether the said pointer comes from a source not known -/// to alias analyses (such as opaque memory or an integer cast). -AliasAttrs getAttrUnknown(); -bool hasUnknownAttr(AliasAttrs); - -/// AttrCaller represent whether the said pointer comes from a source not known -/// to the current function but known to the caller. Values pointed to by the -/// arguments of the current function have this attribute set -AliasAttrs getAttrCaller(); -bool hasCallerAttr(AliasAttrs); -bool hasUnknownOrCallerAttr(AliasAttrs); - -/// AttrEscaped represent whether the said pointer comes from a known source but -/// escapes to the unknown world (e.g. casted to an integer, or passed as an -/// argument to opaque function). Unlike non-escaped pointers, escaped ones may -/// alias pointers coming from unknown sources. -AliasAttrs getAttrEscaped(); -bool hasEscapedAttr(AliasAttrs); - -/// AttrGlobal represent whether the said pointer is a global value. -/// AttrArg represent whether the said pointer is an argument, and if so, what -/// index the argument has. -AliasAttrs getGlobalOrArgAttrFromValue(const Value &); -bool isGlobalOrArgAttr(AliasAttrs); - -/// Given an AliasAttrs, return a new AliasAttrs that only contains attributes -/// meaningful to the caller. This function is primarily used for -/// interprocedural analysis -/// Currently, externally visible AliasAttrs include AttrUnknown, AttrGlobal, -/// and AttrEscaped -AliasAttrs getExternallyVisibleAttrs(AliasAttrs); - -//===----------------------------------------------------------------------===// -// Function summary related stuffs -//===----------------------------------------------------------------------===// - -/// The maximum number of arguments we can put into a summary. 
-static const unsigned MaxSupportedArgsInSummary = 50; - -/// We use InterfaceValue to describe parameters/return value, as well as -/// potential memory locations that are pointed to by parameters/return value, -/// of a function. -/// Index is an integer which represents a single parameter or a return value. -/// When the index is 0, it refers to the return value. Non-zero index i refers -/// to the i-th parameter. -/// DerefLevel indicates the number of dereferences one must perform on the -/// parameter/return value to get this InterfaceValue. -struct InterfaceValue { - unsigned Index; - unsigned DerefLevel; -}; - -inline bool operator==(InterfaceValue LHS, InterfaceValue RHS) { - return LHS.Index == RHS.Index && LHS.DerefLevel == RHS.DerefLevel; -} -inline bool operator!=(InterfaceValue LHS, InterfaceValue RHS) { - return !(LHS == RHS); -} -inline bool operator<(InterfaceValue LHS, InterfaceValue RHS) { - return LHS.Index < RHS.Index || - (LHS.Index == RHS.Index && LHS.DerefLevel < RHS.DerefLevel); -} -inline bool operator>(InterfaceValue LHS, InterfaceValue RHS) { - return RHS < LHS; -} -inline bool operator<=(InterfaceValue LHS, InterfaceValue RHS) { - return !(RHS < LHS); -} -inline bool operator>=(InterfaceValue LHS, InterfaceValue RHS) { - return !(LHS < RHS); -} - -// We use UnknownOffset to represent pointer offsets that cannot be determined -// at compile time. Note that MemoryLocation::UnknownSize cannot be used here -// because we require a signed value. -static const int64_t UnknownOffset = INT64_MAX; - -inline int64_t addOffset(int64_t LHS, int64_t RHS) { - if (LHS == UnknownOffset || RHS == UnknownOffset) - return UnknownOffset; - // FIXME: Do we need to guard against integer overflow here? - return LHS + RHS; -} - -/// We use ExternalRelation to describe an externally visible aliasing relations -/// between parameters/return value of a function. -struct ExternalRelation { - InterfaceValue From, To; - int64_t Offset; -}; - -inline bool operator==(ExternalRelation LHS, ExternalRelation RHS) { - return LHS.From == RHS.From && LHS.To == RHS.To && LHS.Offset == RHS.Offset; -} -inline bool operator!=(ExternalRelation LHS, ExternalRelation RHS) { - return !(LHS == RHS); -} -inline bool operator<(ExternalRelation LHS, ExternalRelation RHS) { - if (LHS.From < RHS.From) - return true; - if (LHS.From > RHS.From) - return false; - if (LHS.To < RHS.To) - return true; - if (LHS.To > RHS.To) - return false; - return LHS.Offset < RHS.Offset; -} -inline bool operator>(ExternalRelation LHS, ExternalRelation RHS) { - return RHS < LHS; -} -inline bool operator<=(ExternalRelation LHS, ExternalRelation RHS) { - return !(RHS < LHS); -} -inline bool operator>=(ExternalRelation LHS, ExternalRelation RHS) { - return !(LHS < RHS); -} - -/// We use ExternalAttribute to describe an externally visible AliasAttrs -/// for parameters/return value. -struct ExternalAttribute { - InterfaceValue IValue; - AliasAttrs Attr; -}; - -/// AliasSummary is just a collection of ExternalRelation and ExternalAttribute -struct AliasSummary { - // RetParamRelations is a collection of ExternalRelations. - SmallVector<ExternalRelation, 8> RetParamRelations; - - // RetParamAttributes is a collection of ExternalAttributes. 
- SmallVector<ExternalAttribute, 8> RetParamAttributes; -}; - -/// This is the result of instantiating InterfaceValue at a particular call -struct InstantiatedValue { - Value *Val; - unsigned DerefLevel; -}; -std::optional<InstantiatedValue> -instantiateInterfaceValue(InterfaceValue IValue, CallBase &Call); - -inline bool operator==(InstantiatedValue LHS, InstantiatedValue RHS) { - return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel; -} -inline bool operator!=(InstantiatedValue LHS, InstantiatedValue RHS) { - return !(LHS == RHS); -} -inline bool operator<(InstantiatedValue LHS, InstantiatedValue RHS) { - return std::less<Value *>()(LHS.Val, RHS.Val) || - (LHS.Val == RHS.Val && LHS.DerefLevel < RHS.DerefLevel); -} -inline bool operator>(InstantiatedValue LHS, InstantiatedValue RHS) { - return RHS < LHS; -} -inline bool operator<=(InstantiatedValue LHS, InstantiatedValue RHS) { - return !(RHS < LHS); -} -inline bool operator>=(InstantiatedValue LHS, InstantiatedValue RHS) { - return !(LHS < RHS); -} - -/// This is the result of instantiating ExternalRelation at a particular -/// callsite -struct InstantiatedRelation { - InstantiatedValue From, To; - int64_t Offset; -}; -std::optional<InstantiatedRelation> -instantiateExternalRelation(ExternalRelation ERelation, CallBase &Call); - -/// This is the result of instantiating ExternalAttribute at a particular -/// callsite -struct InstantiatedAttr { - InstantiatedValue IValue; - AliasAttrs Attr; -}; -std::optional<InstantiatedAttr> -instantiateExternalAttribute(ExternalAttribute EAttr, CallBase &Call); -} - -template <> struct DenseMapInfo<cflaa::InstantiatedValue> { - static inline cflaa::InstantiatedValue getEmptyKey() { - return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getEmptyKey(), - DenseMapInfo<unsigned>::getEmptyKey()}; - } - static inline cflaa::InstantiatedValue getTombstoneKey() { - return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getTombstoneKey(), - DenseMapInfo<unsigned>::getTombstoneKey()}; - } - static unsigned getHashValue(const cflaa::InstantiatedValue &IV) { - return DenseMapInfo<std::pair<Value *, unsigned>>::getHashValue( - std::make_pair(IV.Val, IV.DerefLevel)); - } - static bool isEqual(const cflaa::InstantiatedValue &LHS, - const cflaa::InstantiatedValue &RHS) { - return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel; - } -}; -} - -#endif diff --git a/contrib/llvm-project/llvm/lib/Analysis/AliasSetTracker.cpp b/contrib/llvm-project/llvm/lib/Analysis/AliasSetTracker.cpp index 1c9ebadf3649..91b889116dfa 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/AliasSetTracker.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/AliasSetTracker.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/MemoryLocation.h" diff --git a/contrib/llvm-project/llvm/lib/Analysis/Analysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/Analysis.cpp index c1b843d74600..5461ce07af0b 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/Analysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/Analysis.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "llvm-c/Analysis.h" -#include "llvm-c/Initialization.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" @@ -35,7 +34,6 @@ void 
llvm::initializeAnalysis(PassRegistry &Registry) { initializeCycleInfoWrapperPassPass(Registry); initializeDependenceAnalysisWrapperPassPass(Registry); initializeDelinearizationPass(Registry); - initializeDemandedBitsWrapperPassPass(Registry); initializeDominanceFrontierWrapperPassPass(Registry); initializeDomViewerWrapperPassPass(Registry); initializeDomPrinterWrapperPassPass(Registry); @@ -55,16 +53,9 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeLazyBlockFrequencyInfoPassPass(Registry); initializeLazyValueInfoWrapperPassPass(Registry); initializeLazyValueInfoPrinterPass(Registry); - initializeLegacyDivergenceAnalysisPass(Registry); - initializeLintLegacyPassPass(Registry); initializeLoopInfoWrapperPassPass(Registry); - initializeMemDepPrinterPass(Registry); - initializeMemDerefPrinterPass(Registry); initializeMemoryDependenceWrapperPassPass(Registry); - initializeModuleDebugInfoLegacyPrinterPass(Registry); initializeModuleSummaryIndexWrapperPassPass(Registry); - initializeMustExecutePrinterPass(Registry); - initializeMustBeExecutedContextPrinterPass(Registry); initializeOptimizationRemarkEmitterWrapperPassPass(Registry); initializePhiValuesWrapperPassPass(Registry); initializePostDominatorTreeWrapperPassPass(Registry); @@ -82,15 +73,6 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeScopedNoAliasAAWrapperPassPass(Registry); initializeLCSSAVerificationPassPass(Registry); initializeMemorySSAWrapperPassPass(Registry); - initializeMemorySSAPrinterLegacyPassPass(Registry); -} - -void LLVMInitializeAnalysis(LLVMPassRegistryRef R) { - initializeAnalysis(*unwrap(R)); -} - -void LLVMInitializeIPA(LLVMPassRegistryRef R) { - initializeAnalysis(*unwrap(R)); } LLVMBool LLVMVerifyModule(LLVMModuleRef M, LLVMVerifierFailureAction Action, diff --git a/contrib/llvm-project/llvm/lib/Analysis/AssumeBundleQueries.cpp b/contrib/llvm-project/llvm/lib/Analysis/AssumeBundleQueries.cpp index 110cddb4a065..7440dbd29ccf 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/AssumeBundleQueries.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/AssumeBundleQueries.cpp @@ -162,7 +162,7 @@ llvm::getKnowledgeForValue(const Value *V, return RetainedKnowledge::none(); if (AC) { for (AssumptionCache::ResultElem &Elem : AC->assumptionsFor(V)) { - auto *II = dyn_cast_or_null<AssumeInst>(Elem.Assume); + auto *II = cast_or_null<AssumeInst>(Elem.Assume); if (!II || Elem.Index == AssumptionCache::ExprResultIdx) continue; if (RetainedKnowledge RK = getKnowledgeFromBundle( diff --git a/contrib/llvm-project/llvm/lib/Analysis/AssumptionCache.cpp b/contrib/llvm-project/llvm/lib/Analysis/AssumptionCache.cpp index 2d648ccee46c..b439dc1e6a76 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/AssumptionCache.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/AssumptionCache.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file contains a pass that keeps track of @llvm.assume and -// @llvm.experimental.guard intrinsics in the functions of a module. +// This file contains a pass that keeps track of @llvm.assume intrinsics in +// the functions of a module. 
// //===----------------------------------------------------------------------===// @@ -87,7 +87,7 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI, AddAffected(Cond); CmpInst::Predicate Pred; - if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) { + if (match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B)))) { AddAffected(A); AddAffected(B); @@ -128,7 +128,18 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI, if (match(A, m_Add(m_Value(X), m_ConstantInt())) && match(B, m_ConstantInt())) AddAffected(X); + } else if (CmpInst::isFPPredicate(Pred)) { + // fcmp fneg(x), y + // fcmp fabs(x), y + // fcmp fneg(fabs(x)), y + if (match(A, m_FNeg(m_Value(A)))) + AddAffected(A); + if (match(A, m_FAbs(m_Value(A)))) + AddAffected(A); } + } else if (match(Cond, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A), + m_Value(B)))) { + AddAffected(A); } if (TTI) { @@ -140,7 +151,7 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI, } } -void AssumptionCache::updateAffectedValues(CondGuardInst *CI) { +void AssumptionCache::updateAffectedValues(AssumeInst *CI) { SmallVector<AssumptionCache::ResultElem, 16> Affected; findAffectedValues(CI, TTI, Affected); @@ -153,7 +164,7 @@ void AssumptionCache::updateAffectedValues(CondGuardInst *CI) { } } -void AssumptionCache::unregisterAssumption(CondGuardInst *CI) { +void AssumptionCache::unregisterAssumption(AssumeInst *CI) { SmallVector<AssumptionCache::ResultElem, 16> Affected; findAffectedValues(CI, TTI, Affected); @@ -217,7 +228,7 @@ void AssumptionCache::scanFunction() { // to this cache. for (BasicBlock &B : F) for (Instruction &I : B) - if (isa<CondGuardInst>(&I)) + if (isa<AssumeInst>(&I)) AssumeHandles.push_back({&I, ExprResultIdx}); // Mark the scan as complete. @@ -225,10 +236,10 @@ void AssumptionCache::scanFunction() { // Update affected values. for (auto &A : AssumeHandles) - updateAffectedValues(cast<CondGuardInst>(A)); + updateAffectedValues(cast<AssumeInst>(A)); } -void AssumptionCache::registerAssumption(CondGuardInst *CI) { +void AssumptionCache::registerAssumption(AssumeInst *CI) { // If we haven't scanned the function yet, just drop this assumption. It will // be found when we scan later. 
if (!Scanned) @@ -238,9 +249,9 @@ void AssumptionCache::registerAssumption(CondGuardInst *CI) { #ifndef NDEBUG assert(CI->getParent() && - "Cannot a register CondGuardInst not in a basic block"); + "Cannot register @llvm.assume call not in a basic block"); assert(&F == CI->getParent()->getParent() && - "Cannot a register CondGuardInst not in this function"); + "Cannot register @llvm.assume call not in this function"); // We expect the number of assumptions to be small, so in an asserts build // check that we don't accumulate duplicates and that all assumptions point @@ -252,8 +263,8 @@ void AssumptionCache::registerAssumption(CondGuardInst *CI) { assert(&F == cast<Instruction>(VH)->getParent()->getParent() && "Cached assumption not inside this function!"); - assert(isa<CondGuardInst>(VH) && - "Cached something other than CondGuardInst!"); + assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) && + "Cached something other than a call to @llvm.assume!"); assert(AssumptionSet.insert(VH).second && "Cache contains multiple copies of a call!"); } diff --git a/contrib/llvm-project/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/BasicAliasAnalysis.cpp index dc728c1cbfeb..16e0e1f66524 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/BasicAliasAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/BasicAliasAnalysis.cpp @@ -461,6 +461,17 @@ struct VariableGEPIndex { /// True if all operations in this expression are NSW. bool IsNSW; + /// True if the index should be subtracted rather than added. We don't simply + /// negate the Scale, to avoid losing the NSW flag: X - INT_MIN*1 may be + /// non-wrapping, while X + INT_MIN*(-1) wraps. + bool IsNegated; + + bool hasNegatedScaleOf(const VariableGEPIndex &Other) const { + if (IsNegated == Other.IsNegated) + return Scale == -Other.Scale; + return Scale == Other.Scale; + } + void dump() const { print(dbgs()); dbgs() << "\n"; @@ -470,7 +481,9 @@ struct VariableGEPIndex { << ", zextbits=" << Val.ZExtBits << ", sextbits=" << Val.SExtBits << ", truncbits=" << Val.TruncBits - << ", scale=" << Scale << ")"; + << ", scale=" << Scale + << ", nsw=" << IsNSW + << ", negated=" << IsNegated << ")"; } }; } @@ -659,7 +672,8 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL, Scale = adjustToIndexSize(Scale, IndexSize); if (!!Scale) { - VariableGEPIndex Entry = {LE.Val, Scale, CxtI, LE.IsNSW}; + VariableGEPIndex Entry = {LE.Val, Scale, CxtI, LE.IsNSW, + /* IsNegated */ false}; Decomposed.VarIndices.push_back(Entry); } } @@ -864,9 +878,11 @@ ModRefInfo BasicAAResult::getModRefInfo(const CallBase *Call, if (!AI->isStaticAlloca() && isIntrinsicCall(Call, Intrinsic::stackrestore)) return ModRefInfo::Mod; - // If the pointer is to a locally allocated object that does not escape, - // then the call can not mod/ref the pointer unless the call takes the pointer - // as an argument, and itself doesn't capture it. + // A call can access a locally allocated object either because it is passed as + // an argument to the call, or because it has escaped prior to the call. + // + // Make sure the object has not escaped here, and then check that none of the + // call arguments alias the object below. 
if (!isa<Constant>(Object) && Call != Object && AAQI.CI->isNotCapturedBeforeOrAt(Object, Call)) { @@ -877,12 +893,7 @@ ModRefInfo BasicAAResult::getModRefInfo(const CallBase *Call, unsigned OperandNo = 0; for (auto CI = Call->data_operands_begin(), CE = Call->data_operands_end(); CI != CE; ++CI, ++OperandNo) { - // Only look at the no-capture or byval pointer arguments. If this - // pointer were passed to arguments that were neither of these, then it - // couldn't be no-capture. - if (!(*CI)->getType()->isPointerTy() || - (!Call->doesNotCapture(OperandNo) && OperandNo < Call->arg_size() && - !Call->isByValArgument(OperandNo))) + if (!(*CI)->getType()->isPointerTy()) continue; // Call doesn't access memory through this operand, so we don't care @@ -1134,8 +1145,8 @@ AliasResult BasicAAResult::aliasGEP( const APInt &Scale = Index.Scale; APInt ScaleForGCD = Scale; if (!Index.IsNSW) - ScaleForGCD = APInt::getOneBitSet(Scale.getBitWidth(), - Scale.countTrailingZeros()); + ScaleForGCD = + APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero()); if (i == 0) GCD = ScaleForGCD.abs(); @@ -1154,9 +1165,14 @@ AliasResult BasicAAResult::aliasGEP( assert(OffsetRange.getBitWidth() == Scale.getBitWidth() && "Bit widths are normalized to MaxIndexSize"); if (Index.IsNSW) - OffsetRange = OffsetRange.add(CR.smul_sat(ConstantRange(Scale))); + CR = CR.smul_sat(ConstantRange(Scale)); + else + CR = CR.smul_fast(ConstantRange(Scale)); + + if (Index.IsNegated) + OffsetRange = OffsetRange.sub(CR); else - OffsetRange = OffsetRange.add(CR.smul_fast(ConstantRange(Scale))); + OffsetRange = OffsetRange.add(CR); } // We now have accesses at two offsets from the same base: @@ -1223,7 +1239,7 @@ AliasResult BasicAAResult::aliasGEP( // inequality of values across loop iterations. const VariableGEPIndex &Var0 = DecompGEP1.VarIndices[0]; const VariableGEPIndex &Var1 = DecompGEP1.VarIndices[1]; - if (Var0.Scale == -Var1.Scale && Var0.Val.TruncBits == 0 && + if (Var0.hasNegatedScaleOf(Var1) && Var0.Val.TruncBits == 0 && Var0.Val.hasSameCastsAs(Var1.Val) && !AAQI.MayBeCrossIteration && isKnownNonEqual(Var0.Val.V, Var1.Val.V, DL, &AC, /* CxtI */ nullptr, DT)) @@ -1516,6 +1532,8 @@ AliasResult BasicAAResult::aliasCheck(const Value *V1, LocationSize V1Size, assert(OBU.Inputs.size() == 2); const Value *Hint1 = OBU.Inputs[0].get(); const Value *Hint2 = OBU.Inputs[1].get(); + // This is often a no-op; instcombine rewrites this for us. No-op + // getUnderlyingObject calls are fast, though. const Value *HintO1 = getUnderlyingObject(Hint1); const Value *HintO2 = getUnderlyingObject(Hint2); @@ -1702,6 +1720,13 @@ void BasicAAResult::subtractDecomposedGEPs(DecomposedGEP &DestGEP, !Dest.Val.hasSameCastsAs(Src.Val)) continue; + // Normalize IsNegated if we're going to lose the NSW flag anyway. + if (Dest.IsNegated) { + Dest.Scale = -Dest.Scale; + Dest.IsNegated = false; + Dest.IsNSW = false; + } + // If we found it, subtract off Scale V's from the entry in Dest. If it // goes to zero, remove the entry. if (Dest.Scale != Src.Scale) { @@ -1716,7 +1741,8 @@ void BasicAAResult::subtractDecomposedGEPs(DecomposedGEP &DestGEP, // If we didn't consume this entry, add it to the end of the Dest list. 
if (!Found) { - VariableGEPIndex Entry = {Src.Val, -Src.Scale, Src.CxtI, Src.IsNSW}; + VariableGEPIndex Entry = {Src.Val, Src.Scale, Src.CxtI, Src.IsNSW, + /* IsNegated */ true}; DestGEP.VarIndices.push_back(Entry); } } @@ -1738,7 +1764,7 @@ bool BasicAAResult::constantOffsetHeuristic(const DecomposedGEP &GEP, const VariableGEPIndex &Var0 = GEP.VarIndices[0], &Var1 = GEP.VarIndices[1]; if (Var0.Val.TruncBits != 0 || !Var0.Val.hasSameCastsAs(Var1.Val) || - Var0.Scale != -Var1.Scale || + !Var0.hasNegatedScaleOf(Var1) || Var0.Val.V->getType() != Var1.Val.V->getType()) return false; @@ -1825,10 +1851,3 @@ void BasicAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequiredTransitive<DominatorTreeWrapperPass>(); AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); } - -BasicAAResult llvm::createLegacyPMBasicAAResult(Pass &P, Function &F) { - return BasicAAResult( - F.getParent()->getDataLayout(), F, - P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F), - P.getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F)); -} diff --git a/contrib/llvm-project/llvm/lib/Analysis/BlockFrequencyInfo.cpp b/contrib/llvm-project/llvm/lib/Analysis/BlockFrequencyInfo.cpp index dd84336da604..b18d04cc73db 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/BlockFrequencyInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/BlockFrequencyInfo.cpp @@ -333,9 +333,10 @@ bool BlockFrequencyInfoWrapperPass::runOnFunction(Function &F) { AnalysisKey BlockFrequencyAnalysis::Key; BlockFrequencyInfo BlockFrequencyAnalysis::run(Function &F, FunctionAnalysisManager &AM) { + auto &BP = AM.getResult<BranchProbabilityAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); BlockFrequencyInfo BFI; - BFI.calculate(F, AM.getResult<BranchProbabilityAnalysis>(F), - AM.getResult<LoopAnalysis>(F)); + BFI.calculate(F, BP, LI); return BFI; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/contrib/llvm-project/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp index 0945c5688f1f..82b1e3b9eede 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SmallString.h" #include "llvm/Config/llvm-config.h" #include "llvm/IR/Function.h" #include "llvm/Support/BlockFrequency.h" @@ -59,7 +60,7 @@ cl::opt<double> IterativeBFIPrecision( "iterative-bfi-precision", cl::init(1e-12), cl::Hidden, cl::desc("Iterative inference: delta convergence precision; smaller values " "typically lead to better results at the cost of worsen runtime")); -} +} // namespace llvm ScaledNumber<uint64_t> BlockMass::toScaled() const { if (isFull()) @@ -256,7 +257,7 @@ void Distribution::normalize() { if (DidOverflow) Shift = 33; else if (Total > UINT32_MAX) - Shift = 33 - countLeadingZeros(Total); + Shift = 33 - llvm::countl_zero(Total); // Early exit if nothing needs to be scaled. 
if (!Shift) { diff --git a/contrib/llvm-project/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/contrib/llvm-project/llvm/lib/Analysis/BranchProbabilityInfo.cpp index 7931001d0a2b..b45deccd913d 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/BranchProbabilityInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -1163,7 +1163,7 @@ void BranchProbabilityInfo::copyEdgeProbabilities(BasicBlock *Src, assert(NumSuccessors == Dst->getTerminator()->getNumSuccessors()); if (NumSuccessors == 0) return; // Nothing to set. - if (this->Probs.find(std::make_pair(Src, 0)) == this->Probs.end()) + if (!this->Probs.contains(std::make_pair(Src, 0))) return; // No probability is set for edges from Src. Keep the same for Dst. Handles.insert(BasicBlockCallbackVH(Dst, this)); @@ -1175,6 +1175,14 @@ void BranchProbabilityInfo::copyEdgeProbabilities(BasicBlock *Src, } } +void BranchProbabilityInfo::swapSuccEdgesProbabilities(const BasicBlock *Src) { + assert(Src->getTerminator()->getNumSuccessors() == 2); + if (!Probs.contains(std::make_pair(Src, 0))) + return; // No probability is set for edges from Src + assert(Probs.contains(std::make_pair(Src, 1))); + std::swap(Probs[std::make_pair(Src, 0)], Probs[std::make_pair(Src, 1)]); +} + raw_ostream & BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS, const BasicBlock *Src, @@ -1303,11 +1311,12 @@ void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS, AnalysisKey BranchProbabilityAnalysis::Key; BranchProbabilityInfo BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) { + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); BranchProbabilityInfo BPI; - BPI.calculate(F, AM.getResult<LoopAnalysis>(F), - &AM.getResult<TargetLibraryAnalysis>(F), - &AM.getResult<DominatorTreeAnalysis>(F), - &AM.getResult<PostDominatorTreeAnalysis>(F)); + BPI.calculate(F, LI, &TLI, &DT, &PDT); return BPI; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/CFGPrinter.cpp b/contrib/llvm-project/llvm/lib/Analysis/CFGPrinter.cpp index f8eba1a00f28..f05dd6852d6d 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/CFGPrinter.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/CFGPrinter.cpp @@ -325,8 +325,7 @@ bool DOTGraphTraits<DOTFuncInfo *>::isNodeHidden(const BasicBlock *Node, return true; } if (HideUnreachablePaths || HideDeoptimizePaths) { - if (isOnDeoptOrUnreachablePath.find(Node) == - isOnDeoptOrUnreachablePath.end()) + if (!isOnDeoptOrUnreachablePath.contains(Node)) computeDeoptOrUnreachablePaths(Node->getParent()); return isOnDeoptOrUnreachablePath[Node]; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/CGSCCPassManager.cpp b/contrib/llvm-project/llvm/lib/Analysis/CGSCCPassManager.cpp index 2de19884014c..facb9c897da3 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/CGSCCPassManager.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/CGSCCPassManager.cpp @@ -86,11 +86,6 @@ PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &, PreservedAnalyses PassPA = Pass->run(*C, AM, G, UR); - if (UR.InvalidatedSCCs.count(C)) - PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass, PassPA); - else - PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C, PassPA); - // Update the SCC if necessary. C = UR.UpdatedC ? 
UR.UpdatedC : C; if (UR.UpdatedC) { @@ -107,6 +102,7 @@ PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &, // If the CGSCC pass wasn't able to provide a valid updated SCC, the // current SCC may simply need to be skipped if invalid. if (UR.InvalidatedSCCs.count(C)) { + PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass, PassPA); LLVM_DEBUG(dbgs() << "Skipping invalidated root or island SCC!\n"); break; } @@ -117,6 +113,8 @@ PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &, // Update the analysis manager as each pass runs and potentially // invalidates analyses. AM.invalidate(*C, PassPA); + + PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C, PassPA); } // Before we mark all of *this* SCC's analyses as preserved below, intersect @@ -276,11 +274,6 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) { PreservedAnalyses PassPA = Pass->run(*C, CGAM, CG, UR); - if (UR.InvalidatedSCCs.count(C)) - PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass, PassPA); - else - PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C, PassPA); - // Update the SCC and RefSCC if necessary. C = UR.UpdatedC ? UR.UpdatedC : C; @@ -301,6 +294,7 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) { // If the CGSCC pass wasn't able to provide a valid updated SCC, // the current SCC may simply need to be skipped if invalid. if (UR.InvalidatedSCCs.count(C)) { + PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass, PassPA); LLVM_DEBUG(dbgs() << "Skipping invalidated root or island SCC!\n"); break; } @@ -316,6 +310,8 @@ ModuleToPostOrderCGSCCPassAdaptor::run(Module &M, ModuleAnalysisManager &AM) { // processed. CGAM.invalidate(*C, PassPA); + PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C, PassPA); + // The pass may have restructured the call graph and refined the // current SCC and/or RefSCC. We need to update our current SCC and // RefSCC pointers to follow these. Also, when the current SCC is @@ -408,25 +404,27 @@ PreservedAnalyses DevirtSCCRepeatedPass::run(LazyCallGraph::SCC &InitialC, PreservedAnalyses PassPA = Pass->run(*C, AM, CG, UR); - if (UR.InvalidatedSCCs.count(C)) - PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass, PassPA); - else - PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C, PassPA); - PA.intersect(PassPA); - // If the SCC structure has changed, bail immediately and let the outer - // CGSCC layer handle any iteration to reflect the refined structure. - if (UR.UpdatedC && UR.UpdatedC != C) - break; - // If the CGSCC pass wasn't able to provide a valid updated SCC, the // current SCC may simply need to be skipped if invalid. if (UR.InvalidatedSCCs.count(C)) { + PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass, PassPA); LLVM_DEBUG(dbgs() << "Skipping invalidated root or island SCC!\n"); break; } + // Update the analysis manager with each run and intersect the total set + // of preserved analyses so we're ready to iterate. + AM.invalidate(*C, PassPA); + + PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C, PassPA); + + // If the SCC structure has changed, bail immediately and let the outer + // CGSCC layer handle any iteration to reflect the refined structure. + if (UR.UpdatedC && UR.UpdatedC != C) + break; + assert(C->begin() != C->end() && "Cannot have an empty SCC!"); // Check whether any of the handles were devirtualized. @@ -490,10 +488,6 @@ PreservedAnalyses DevirtSCCRepeatedPass::run(LazyCallGraph::SCC &InitialC, // Move over the new call counts in preparation for iterating. 
CallCounts = std::move(NewCallCounts); - - // Update the analysis manager with each run and intersect the total set - // of preserved analyses so we're ready to iterate. - AM.invalidate(*C, PassPA); } // Note that we don't add any preserved entries here unlike a more normal @@ -539,14 +533,13 @@ PreservedAnalyses CGSCCToFunctionPassAdaptor::run(LazyCallGraph::SCC &C, continue; PreservedAnalyses PassPA = Pass->run(F, FAM); - PI.runAfterPass<Function>(*Pass, F, PassPA); // We know that the function pass couldn't have invalidated any other // function's analyses (that's the contract of a function pass), so // directly handle the function analysis manager's invalidation here. FAM.invalidate(F, EagerlyInvalidate ? PreservedAnalyses::none() : PassPA); - if (NoRerun) - (void)FAM.getResult<ShouldNotRunFunctionPassesAnalysis>(F); + + PI.runAfterPass<Function>(*Pass, F, PassPA); // Then intersect the preserved set so that invalidation of module // analyses will eventually occur when the module pass completes. diff --git a/contrib/llvm-project/llvm/lib/Analysis/CallGraphSCCPass.cpp b/contrib/llvm-project/llvm/lib/Analysis/CallGraphSCCPass.cpp index d66f1e261780..307dddd51ece 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/CallGraphSCCPass.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/CallGraphSCCPass.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/IR/AbstractCallSite.h" #include "llvm/IR/Function.h" diff --git a/contrib/llvm-project/llvm/lib/Analysis/CaptureTracking.cpp b/contrib/llvm-project/llvm/lib/Analysis/CaptureTracking.cpp index 7f3a2b49aca9..00e096af3110 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/CaptureTracking.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/CaptureTracking.cpp @@ -58,17 +58,16 @@ CaptureTracker::~CaptureTracker() = default; bool CaptureTracker::shouldExplore(const Use *U) { return true; } bool CaptureTracker::isDereferenceableOrNull(Value *O, const DataLayout &DL) { - // An inbounds GEP can either be a valid pointer (pointing into - // or to the end of an allocation), or be null in the default - // address space. So for an inbounds GEP there is no way to let - // the pointer escape using clever GEP hacking because doing so - // would make the pointer point outside of the allocated object - // and thus make the GEP result a poison value. Similarly, other - // dereferenceable pointers cannot be manipulated without producing - // poison. - if (auto *GEP = dyn_cast<GetElementPtrInst>(O)) - if (GEP->isInBounds()) - return true; + // We want comparisons to null pointers to not be considered capturing, + // but need to guard against cases like gep(p, -ptrtoint(p2)) == null, + // which are equivalent to p == p2 and would capture the pointer. + // + // A dereferenceable pointer is a case where this is known to be safe, + // because the pointer resulting from such a construction would not be + // dereferenceable. + // + // It is not sufficient to check for inbounds GEP here, because GEP with + // zero offset is always inbounds. 
bool CanBeNull, CanBeFreed; return O->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed); } @@ -80,7 +79,10 @@ namespace { const SmallPtrSetImpl<const Value *> &EphValues, bool ReturnCaptures) : EphValues(EphValues), ReturnCaptures(ReturnCaptures) {} - void tooManyUses() override { Captured = true; } + void tooManyUses() override { + LLVM_DEBUG(dbgs() << "Captured due to too many uses\n"); + Captured = true; + } bool captured(const Use *U) override { if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures) @@ -89,6 +91,8 @@ namespace { if (EphValues.contains(U->getUser())) return false; + LLVM_DEBUG(dbgs() << "Captured by: " << *U->getUser() << "\n"); + Captured = true; return true; } @@ -233,12 +237,16 @@ bool llvm::PointerMayBeCaptured(const Value *V, bool ReturnCaptures, // take advantage of this. (void)StoreCaptures; + LLVM_DEBUG(dbgs() << "Captured?: " << *V << " = "); + SimpleCaptureTracker SCT(EphValues, ReturnCaptures); PointerMayBeCaptured(V, &SCT, MaxUsesToExplore); if (SCT.Captured) ++NumCaptured; - else + else { ++NumNotCaptured; + LLVM_DEBUG(dbgs() << "not captured\n"); + } return SCT.Captured; } @@ -403,12 +411,7 @@ UseCaptureKind llvm::DetermineUseCaptureKind( return UseCaptureKind::NO_CAPTURE; } } - // Comparison against value stored in global variable. Given the pointer - // does not escape, its value cannot be guessed and stored separately in a - // global variable. - auto *LI = dyn_cast<LoadInst>(I->getOperand(OtherIdx)); - if (LI && isa<GlobalVariable>(LI->getPointerOperand())) - return UseCaptureKind::NO_CAPTURE; + // Otherwise, be conservative. There are crazy ways to capture pointers // using comparisons. return UseCaptureKind::MAY_CAPTURE; diff --git a/contrib/llvm-project/llvm/lib/Analysis/CmpInstAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/CmpInstAnalysis.cpp index 20b1df6e1495..d6407e875073 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/CmpInstAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/CmpInstAnalysis.cpp @@ -79,7 +79,7 @@ bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, using namespace PatternMatch; const APInt *C; - if (!match(RHS, m_APInt(C))) + if (!match(RHS, m_APIntAllowUndef(C))) return false; switch (Pred) { diff --git a/contrib/llvm-project/llvm/lib/Analysis/ConstantFolding.cpp b/contrib/llvm-project/llvm/lib/Analysis/ConstantFolding.cpp index 6a2d6ba767e7..38cccb3ea3c2 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ConstantFolding.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ConstantFolding.cpp @@ -235,7 +235,8 @@ Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { ShiftAmt += isLittleEndian ? SrcBitSize : -SrcBitSize; // Mix it in. 
- Elt = ConstantExpr::getOr(Elt, Src); + Elt = ConstantFoldBinaryOpOperands(Instruction::Or, Elt, Src, DL); + assert(Elt && "Constant folding cannot fail on plain integers"); } Result.push_back(Elt); } @@ -429,18 +430,16 @@ bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, unsigned char *CurPtr, return true; if (auto *CI = dyn_cast<ConstantInt>(C)) { - if (CI->getBitWidth() > 64 || - (CI->getBitWidth() & 7) != 0) + if ((CI->getBitWidth() & 7) != 0) return false; - - uint64_t Val = CI->getZExtValue(); + const APInt &Val = CI->getValue(); unsigned IntBytes = unsigned(CI->getBitWidth()/8); for (unsigned i = 0; i != BytesLeft && ByteOffset != IntBytes; ++i) { - int n = ByteOffset; + unsigned n = ByteOffset; if (!DL.isLittleEndian()) n = IntBytes - n - 1; - CurPtr[i] = (unsigned char)(Val >> (n * 8)); + CurPtr[i] = Val.extractBits(8, n * 8).getZExtValue(); ++ByteOffset; } return true; @@ -501,16 +500,22 @@ bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, unsigned char *CurPtr, if (isa<ConstantArray>(C) || isa<ConstantVector>(C) || isa<ConstantDataSequential>(C)) { - uint64_t NumElts; + uint64_t NumElts, EltSize; Type *EltTy; if (auto *AT = dyn_cast<ArrayType>(C->getType())) { NumElts = AT->getNumElements(); EltTy = AT->getElementType(); + EltSize = DL.getTypeAllocSize(EltTy); } else { NumElts = cast<FixedVectorType>(C->getType())->getNumElements(); EltTy = cast<FixedVectorType>(C->getType())->getElementType(); + // TODO: For non-byte-sized vectors, current implementation assumes there is + // padding to the next byte boundary between elements. + if (!DL.typeSizeEqualsStoreSize(EltTy)) + return false; + + EltSize = DL.getTypeStoreSize(EltTy); } - uint64_t EltSize = DL.getTypeAllocSize(EltTy); uint64_t Index = ByteOffset / EltSize; uint64_t Offset = ByteOffset - Index * EltSize; @@ -713,7 +718,7 @@ Constant *llvm::ConstantFoldLoadFromConst(Constant *C, Type *Ty, return Result; // Try hard to fold loads from bitcasted strange and non-type-safe things. - if (Offset.getMinSignedBits() <= 64) + if (Offset.getSignificantBits() <= 64) if (Constant *Result = FoldReinterpretLoadFromConst(C, Ty, Offset.getSExtValue(), DL)) return Result; @@ -729,26 +734,23 @@ Constant *llvm::ConstantFoldLoadFromConst(Constant *C, Type *Ty, Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty, APInt Offset, const DataLayout &DL) { + // We can only fold loads from constant globals with a definitive initializer. + // Check this upfront, to skip expensive offset calculations. + auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(C)); + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) + return nullptr; + C = cast<Constant>(C->stripAndAccumulateConstantOffsets( DL, Offset, /* AllowNonInbounds */ true)); - if (auto *GV = dyn_cast<GlobalVariable>(C)) - if (GV->isConstant() && GV->hasDefinitiveInitializer()) - if (Constant *Result = ConstantFoldLoadFromConst(GV->getInitializer(), Ty, - Offset, DL)) - return Result; + if (C == GV) + if (Constant *Result = ConstantFoldLoadFromConst(GV->getInitializer(), Ty, + Offset, DL)) + return Result; // If this load comes from anywhere in a uniform constant global, the value // is always the same, regardless of the loaded offset. 
- if (auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(C))) { - if (GV->isConstant() && GV->hasDefinitiveInitializer()) { - if (Constant *Res = - ConstantFoldLoadFromUniformValue(GV->getInitializer(), Ty)) - return Res; - } - } - - return nullptr; + return ConstantFoldLoadFromUniformValue(GV->getInitializer(), Ty); } Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty, @@ -825,7 +827,8 @@ Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, Constant *Op1, /// If array indices are not pointer-sized integers, explicitly cast them so /// that they aren't implicitly casted by the getelementptr. Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops, - Type *ResultTy, std::optional<unsigned> InRangeIndex, + Type *ResultTy, bool InBounds, + std::optional<unsigned> InRangeIndex, const DataLayout &DL, const TargetLibraryInfo *TLI) { Type *IntIdxTy = DL.getIndexType(ResultTy); Type *IntIdxScalarTy = IntIdxTy->getScalarType(); @@ -854,23 +857,21 @@ Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops, return nullptr; Constant *C = ConstantExpr::getGetElementPtr( - SrcElemTy, Ops[0], NewIdxs, /*InBounds=*/false, InRangeIndex); + SrcElemTy, Ops[0], NewIdxs, InBounds, InRangeIndex); return ConstantFoldConstant(C, DL, TLI); } /// Strip the pointer casts, but preserve the address space information. -Constant *StripPtrCastKeepAS(Constant *Ptr) { +// TODO: This probably doesn't make sense with opaque pointers. +static Constant *StripPtrCastKeepAS(Constant *Ptr) { assert(Ptr->getType()->isPointerTy() && "Not a pointer type"); auto *OldPtrTy = cast<PointerType>(Ptr->getType()); Ptr = cast<Constant>(Ptr->stripPointerCasts()); auto *NewPtrTy = cast<PointerType>(Ptr->getType()); // Preserve the address space number of the pointer. - if (NewPtrTy->getAddressSpace() != OldPtrTy->getAddressSpace()) { - Ptr = ConstantExpr::getPointerCast( - Ptr, PointerType::getWithSamePointeeType(NewPtrTy, - OldPtrTy->getAddressSpace())); - } + if (NewPtrTy->getAddressSpace() != OldPtrTy->getAddressSpace()) + Ptr = ConstantExpr::getPointerCast(Ptr, OldPtrTy); return Ptr; } @@ -889,7 +890,8 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP, return nullptr; if (Constant *C = CastGEPIndices(SrcElemTy, Ops, ResTy, - GEP->getInRangeIndex(), DL, TLI)) + GEP->isInBounds(), GEP->getInRangeIndex(), + DL, TLI)) return C; Constant *Ptr = Ops[0]; @@ -952,14 +954,10 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP, // Otherwise form a regular getelementptr. Recompute the indices so that // we eliminate over-indexing of the notional static type array bounds. // This makes it easy to determine if the getelementptr is "inbounds". - // Also, this helps GlobalOpt do SROA on GlobalVariables. - // For GEPs of GlobalValues, use the value type even for opaque pointers. - // Otherwise use an i8 GEP. + // For GEPs of GlobalValues, use the value type, otherwise use an i8 GEP. if (auto *GV = dyn_cast<GlobalValue>(Ptr)) SrcElemTy = GV->getValueType(); - else if (!PTy->isOpaque()) - SrcElemTy = PTy->getNonOpaquePointerElementType(); else SrcElemTy = Type::getInt8Ty(Ptr->getContext()); @@ -1002,18 +1000,8 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP, } // Create a GEP. 
- Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs, - InBounds, InRangeIndex); - assert( - cast<PointerType>(C->getType())->isOpaqueOrPointeeTypeMatches(ElemTy) && - "Computed GetElementPtr has unexpected type!"); - - // If we ended up indexing a member with a type that doesn't match - // the type of what the original indices indexed, add a cast. - if (C->getType() != ResTy) - C = FoldBitCast(C, ResTy, DL); - - return C; + return ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs, InBounds, + InRangeIndex); } /// Attempt to constant fold an instruction with the @@ -1053,11 +1041,15 @@ Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode, return ConstantFoldCastOperand(Opcode, Ops[0], DestTy, DL); if (auto *GEP = dyn_cast<GEPOperator>(InstOrCE)) { + Type *SrcElemTy = GEP->getSourceElementType(); + if (!ConstantExpr::isSupportedGetElementPtr(SrcElemTy)) + return nullptr; + if (Constant *C = SymbolicallyEvaluateGEP(GEP, Ops, DL, TLI)) return C; - return ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), Ops[0], - Ops.slice(1), GEP->isInBounds(), + return ConstantExpr::getGetElementPtr(SrcElemTy, Ops[0], Ops.slice(1), + GEP->isInBounds(), GEP->getInRangeIndex()); } @@ -1086,7 +1078,7 @@ Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode, } return nullptr; case Instruction::Select: - return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]); + return ConstantFoldSelectInstruction(Ops[0], Ops[1], Ops[2]); case Instruction::ExtractElement: return ConstantExpr::getExtractElement(Ops[0], Ops[1]); case Instruction::ExtractValue: @@ -1323,7 +1315,11 @@ Constant *llvm::ConstantFoldCompareInstOperands( // Flush any denormal constant float input according to denormal handling // mode. Ops0 = FlushFPConstant(Ops0, I, /* IsOutput */ false); + if (!Ops0) + return nullptr; Ops1 = FlushFPConstant(Ops1, I, /* IsOutput */ false); + if (!Ops1) + return nullptr; return ConstantExpr::getCompare(Predicate, Ops0, Ops1); } @@ -1358,6 +1354,10 @@ Constant *llvm::FlushFPConstant(Constant *Operand, const Instruction *I, return Operand; const APFloat &APF = CFP->getValueAPF(); + // TODO: Should this canonicalize nans? + if (!APF.isDenormal()) + return Operand; + Type *Ty = CFP->getType(); DenormalMode DenormMode = I->getFunction()->getDenormalMode(Ty->getFltSemantics()); @@ -1366,7 +1366,8 @@ Constant *llvm::FlushFPConstant(Constant *Operand, const Instruction *I, switch (Mode) { default: llvm_unreachable("unknown denormal mode"); - return Operand; + case DenormalMode::Dynamic: + return nullptr; case DenormalMode::IEEE: return Operand; case DenormalMode::PreserveSign: @@ -1392,7 +1393,11 @@ Constant *llvm::ConstantFoldFPInstOperands(unsigned Opcode, Constant *LHS, if (Instruction::isBinaryOp(Opcode)) { // Flush denormal inputs if needed. Constant *Op0 = FlushFPConstant(LHS, I, /* IsOutput */ false); + if (!Op0) + return nullptr; Constant *Op1 = FlushFPConstant(RHS, I, /* IsOutput */ false); + if (!Op1) + return nullptr; // Calculate constant result. 
Constant *C = ConstantFoldBinaryOpOperands(Opcode, Op0, Op1, DL); @@ -1571,6 +1576,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) { case Intrinsic::powi: case Intrinsic::fma: case Intrinsic::fmuladd: + case Intrinsic::frexp: case Intrinsic::fptoui_sat: case Intrinsic::fptosi_sat: case Intrinsic::convert_from_fp16: @@ -1966,13 +1972,25 @@ static Constant *constantFoldCanonicalize(const Type *Ty, const CallBase *CI, if (Src.isDenormal() && CI->getParent() && CI->getFunction()) { DenormalMode DenormMode = CI->getFunction()->getDenormalMode(Src.getSemantics()); + if (DenormMode == DenormalMode::getIEEE()) + return ConstantFP::get(CI->getContext(), Src); + + if (DenormMode.Input == DenormalMode::Dynamic) + return nullptr; + + // If we know if either input or output is flushed, we can fold. + if ((DenormMode.Input == DenormalMode::Dynamic && + DenormMode.Output == DenormalMode::IEEE) || + (DenormMode.Input == DenormalMode::IEEE && + DenormMode.Output == DenormalMode::Dynamic)) return nullptr; bool IsPositive = (!Src.isNegative() || DenormMode.Input == DenormalMode::PositiveZero || (DenormMode.Output == DenormalMode::PositiveZero && DenormMode.Input == DenormalMode::IEEE)); + return ConstantFP::get(CI->getContext(), APFloat::getZero(Src.getSemantics(), !IsPositive)); } @@ -2398,7 +2416,7 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, case Intrinsic::bswap: return ConstantInt::get(Ty->getContext(), Op->getValue().byteSwap()); case Intrinsic::ctpop: - return ConstantInt::get(Ty, Op->getValue().countPopulation()); + return ConstantInt::get(Ty, Op->getValue().popcount()); case Intrinsic::bitreverse: return ConstantInt::get(Ty->getContext(), Op->getValue().reverseBits()); case Intrinsic::convert_from_fp16: { @@ -2580,7 +2598,7 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, // The legacy behaviour is that multiplying +/- 0.0 by anything, even // NaN or infinity, gives +0.0. 
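// Illustration (a standalone sketch, not part of the LLVM sources): the
// FlushFPConstant/canonicalize hunks above only act on denormal inputs and
// now bail out (return nullptr) when the function's denormal mode is
// Dynamic, i.e. unknown at compile time. The snippet below shows what
// flushing a denormal operand means numerically; the PreserveSign and
// PositiveZero behaviors are stated as assumptions about the standard
// DenormalMode semantics, since those cases are truncated in the hunk.
#include <cmath>
#include <cstdio>
#include <limits>

enum class Mode { IEEE, PreserveSign, PositiveZero };

static float flushDenormal(float V, Mode M) {
  // Only denormal (subnormal) values are affected; everything else is
  // returned unchanged, matching the early "!APF.isDenormal()" return above.
  if (std::fpclassify(V) != FP_SUBNORMAL || M == Mode::IEEE)
    return V;
  return M == Mode::PreserveSign ? std::copysign(0.0f, V) : 0.0f;
}

int main() {
  float D = -std::numeric_limits<float>::denorm_min();
  std::printf("%g %g %g\n",
              flushDenormal(D, Mode::IEEE),          // tiny negative value
              flushDenormal(D, Mode::PreserveSign),  // -0
              flushDenormal(D, Mode::PositiveZero)); // 0
}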
if (Op1V.isZero() || Op2V.isZero()) - return ConstantFP::getNullValue(Ty); + return ConstantFP::getZero(Ty); return ConstantFP::get(Ty->getContext(), Op1V * Op2V); } @@ -2633,18 +2651,18 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, } else if (auto *Op2C = dyn_cast<ConstantInt>(Operands[1])) { switch (IntrinsicID) { case Intrinsic::is_fpclass: { - uint32_t Mask = Op2C->getZExtValue(); + FPClassTest Mask = static_cast<FPClassTest>(Op2C->getZExtValue()); bool Result = ((Mask & fcSNan) && Op1V.isNaN() && Op1V.isSignaling()) || ((Mask & fcQNan) && Op1V.isNaN() && !Op1V.isSignaling()) || - ((Mask & fcNegInf) && Op1V.isInfinity() && Op1V.isNegative()) || + ((Mask & fcNegInf) && Op1V.isNegInfinity()) || ((Mask & fcNegNormal) && Op1V.isNormal() && Op1V.isNegative()) || ((Mask & fcNegSubnormal) && Op1V.isDenormal() && Op1V.isNegative()) || ((Mask & fcNegZero) && Op1V.isZero() && Op1V.isNegative()) || ((Mask & fcPosZero) && Op1V.isZero() && !Op1V.isNegative()) || ((Mask & fcPosSubnormal) && Op1V.isDenormal() && !Op1V.isNegative()) || ((Mask & fcPosNormal) && Op1V.isNormal() && !Op1V.isNegative()) || - ((Mask & fcPosInf) && Op1V.isInfinity() && !Op1V.isNegative()); + ((Mask & fcPosInf) && Op1V.isPosInfinity()); return ConstantInt::get(Ty, Result); } default: @@ -2804,9 +2822,9 @@ static Constant *ConstantFoldScalarCall2(StringRef Name, if (!C0) return Constant::getNullValue(Ty); if (IntrinsicID == Intrinsic::cttz) - return ConstantInt::get(Ty, C0->countTrailingZeros()); + return ConstantInt::get(Ty, C0->countr_zero()); else - return ConstantInt::get(Ty, C0->countLeadingZeros()); + return ConstantInt::get(Ty, C0->countl_zero()); case Intrinsic::abs: assert(C1 && "Must be constant int"); @@ -3265,6 +3283,69 @@ static Constant *ConstantFoldScalableVectorCall( return nullptr; } +static std::pair<Constant *, Constant *> +ConstantFoldScalarFrexpCall(Constant *Op, Type *IntTy) { + if (isa<PoisonValue>(Op)) + return {Op, PoisonValue::get(IntTy)}; + + auto *ConstFP = dyn_cast<ConstantFP>(Op); + if (!ConstFP) + return {}; + + const APFloat &U = ConstFP->getValueAPF(); + int FrexpExp; + APFloat FrexpMant = frexp(U, FrexpExp, APFloat::rmNearestTiesToEven); + Constant *Result0 = ConstantFP::get(ConstFP->getType(), FrexpMant); + + // The exponent is an "unspecified value" for inf/nan. We use zero to avoid + // using undef. + Constant *Result1 = FrexpMant.isFinite() ? ConstantInt::get(IntTy, FrexpExp) + : ConstantInt::getNullValue(IntTy); + return {Result0, Result1}; +} + +/// Handle intrinsics that return tuples, which may be tuples of vectors. 
+static Constant * +ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID, + StructType *StTy, ArrayRef<Constant *> Operands, + const DataLayout &DL, const TargetLibraryInfo *TLI, + const CallBase *Call) { + + switch (IntrinsicID) { + case Intrinsic::frexp: { + Type *Ty0 = StTy->getContainedType(0); + Type *Ty1 = StTy->getContainedType(1)->getScalarType(); + + if (auto *FVTy0 = dyn_cast<FixedVectorType>(Ty0)) { + SmallVector<Constant *, 4> Results0(FVTy0->getNumElements()); + SmallVector<Constant *, 4> Results1(FVTy0->getNumElements()); + + for (unsigned I = 0, E = FVTy0->getNumElements(); I != E; ++I) { + Constant *Lane = Operands[0]->getAggregateElement(I); + std::tie(Results0[I], Results1[I]) = + ConstantFoldScalarFrexpCall(Lane, Ty1); + if (!Results0[I]) + return nullptr; + } + + return ConstantStruct::get(StTy, ConstantVector::get(Results0), + ConstantVector::get(Results1)); + } + + auto [Result0, Result1] = ConstantFoldScalarFrexpCall(Operands[0], Ty1); + if (!Result0) + return nullptr; + return ConstantStruct::get(StTy, Result0, Result1); + } + default: + // TODO: Constant folding of vector intrinsics that fall through here does + // not work (e.g. overflow intrinsics) + return ConstantFoldScalarCall(Name, IntrinsicID, StTy, Operands, TLI, Call); + } + + return nullptr; +} + } // end anonymous namespace Constant *llvm::ConstantFoldCall(const CallBase *Call, Function *F, @@ -3276,7 +3357,8 @@ Constant *llvm::ConstantFoldCall(const CallBase *Call, Function *F, return nullptr; // If this is not an intrinsic and not recognized as a library call, bail out. - if (F->getIntrinsicID() == Intrinsic::not_intrinsic) { + Intrinsic::ID IID = F->getIntrinsicID(); + if (IID == Intrinsic::not_intrinsic) { if (!TLI) return nullptr; LibFunc LibF; @@ -3288,19 +3370,20 @@ Constant *llvm::ConstantFoldCall(const CallBase *Call, Function *F, Type *Ty = F->getReturnType(); if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) return ConstantFoldFixedVectorCall( - Name, F->getIntrinsicID(), FVTy, Operands, - F->getParent()->getDataLayout(), TLI, Call); + Name, IID, FVTy, Operands, F->getParent()->getDataLayout(), TLI, Call); if (auto *SVTy = dyn_cast<ScalableVectorType>(Ty)) return ConstantFoldScalableVectorCall( - Name, F->getIntrinsicID(), SVTy, Operands, - F->getParent()->getDataLayout(), TLI, Call); + Name, IID, SVTy, Operands, F->getParent()->getDataLayout(), TLI, Call); + + if (auto *StTy = dyn_cast<StructType>(Ty)) + return ConstantFoldStructCall(Name, IID, StTy, Operands, + F->getParent()->getDataLayout(), TLI, Call); // TODO: If this is a library function, we already discovered that above, // so we should pass the LibFunc, not the name (and it might be better // still to separate intrinsic handling from libcalls). 
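// Illustration (a standalone sketch, not part of the LLVM sources): the new
// ConstantFoldScalarFrexpCall/ConstantFoldStructCall hunks above fold
// llvm.frexp into a {mantissa, exponent} pair. The snippet below shows the
// underlying math via std::frexp, including the convention noted above that
// the exponent is left at zero for non-finite inputs instead of undef.
#include <cmath>
#include <cstdio>

struct FrexpResult {
  double Fraction; // magnitude in [0.5, 1) for finite non-zero inputs
  int Exponent;    // X == Fraction * 2^Exponent
};

static FrexpResult foldFrexp(double X) {
  FrexpResult R;
  R.Fraction = std::frexp(X, &R.Exponent);
  if (!std::isfinite(X))
    R.Exponent = 0; // "unspecified" per the C standard; the folder picks 0
  return R;
}

int main() {
  FrexpResult A = foldFrexp(8.0);   //  8.0  ==  0.5  * 2^4
  FrexpResult B = foldFrexp(-0.75); // -0.75 == -0.75 * 2^0
  std::printf("%g * 2^%d, %g * 2^%d\n", A.Fraction, A.Exponent, B.Fraction,
              B.Exponent);
}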
- return ConstantFoldScalarCall(Name, F->getIntrinsicID(), Ty, Operands, TLI, - Call); + return ConstantFoldScalarCall(Name, IID, Ty, Operands, TLI, Call); } bool llvm::isMathLibCallNoop(const CallBase *Call, diff --git a/contrib/llvm-project/llvm/lib/Analysis/ConstraintSystem.cpp b/contrib/llvm-project/llvm/lib/Analysis/ConstraintSystem.cpp index 49bc5381841c..8a802515b6f4 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ConstraintSystem.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ConstraintSystem.cpp @@ -10,6 +10,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/MathExtras.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" #include <string> @@ -27,114 +28,169 @@ bool ConstraintSystem::eliminateUsingFM() { // IEEE conference on Supercomputing. IEEE, 1991. assert(!Constraints.empty() && "should only be called for non-empty constraint systems"); - unsigned NumVariables = Constraints[0].size(); - SmallVector<SmallVector<int64_t, 8>, 4> NewSystem; - unsigned NumConstraints = Constraints.size(); uint32_t NewGCD = 1; - // FIXME do not use copy - for (unsigned R1 = 0; R1 < NumConstraints; R1++) { - if (Constraints[R1][1] == 0) { - SmallVector<int64_t, 8> NR; - NR.push_back(Constraints[R1][0]); - for (unsigned i = 2; i < NumVariables; i++) { - NR.push_back(Constraints[R1][i]); - } - NewSystem.push_back(std::move(NR)); - continue; + unsigned LastIdx = NumVariables - 1; + + // First, either remove the variable in place if it is 0 or add the row to + // RemainingRows and remove it from the system. + SmallVector<SmallVector<Entry, 8>, 4> RemainingRows; + for (unsigned R1 = 0; R1 < Constraints.size();) { + SmallVector<Entry, 8> &Row1 = Constraints[R1]; + if (getLastCoefficient(Row1, LastIdx) == 0) { + if (Row1.size() > 0 && Row1.back().Id == LastIdx) + Row1.pop_back(); + R1++; + } else { + std::swap(Constraints[R1], Constraints.back()); + RemainingRows.push_back(std::move(Constraints.back())); + Constraints.pop_back(); } + } + // Process rows where the variable is != 0. + unsigned NumRemainingConstraints = RemainingRows.size(); + for (unsigned R1 = 0; R1 < NumRemainingConstraints; R1++) { // FIXME do not use copy - for (unsigned R2 = R1 + 1; R2 < NumConstraints; R2++) { + for (unsigned R2 = R1 + 1; R2 < NumRemainingConstraints; R2++) { if (R1 == R2) continue; - // FIXME: can we do better than just dropping things here? 
- if (Constraints[R2][1] == 0) - continue; + int64_t UpperLast = getLastCoefficient(RemainingRows[R2], LastIdx); + int64_t LowerLast = getLastCoefficient(RemainingRows[R1], LastIdx); + assert( + UpperLast != 0 && LowerLast != 0 && + "RemainingRows should only contain rows where the variable is != 0"); - if ((Constraints[R1][1] < 0 && Constraints[R2][1] < 0) || - (Constraints[R1][1] > 0 && Constraints[R2][1] > 0)) + if ((LowerLast < 0 && UpperLast < 0) || (LowerLast > 0 && UpperLast > 0)) continue; unsigned LowerR = R1; unsigned UpperR = R2; - if (Constraints[UpperR][1] < 0) + if (UpperLast < 0) { std::swap(LowerR, UpperR); + std::swap(LowerLast, UpperLast); + } - SmallVector<int64_t, 8> NR; - for (unsigned I = 0; I < NumVariables; I++) { - if (I == 1) - continue; - + SmallVector<Entry, 8> NR; + unsigned IdxUpper = 0; + unsigned IdxLower = 0; + auto &LowerRow = RemainingRows[LowerR]; + auto &UpperRow = RemainingRows[UpperR]; + while (true) { + if (IdxUpper >= UpperRow.size() || IdxLower >= LowerRow.size()) + break; int64_t M1, M2, N; - if (MulOverflow(Constraints[UpperR][I], - ((-1) * Constraints[LowerR][1] / GCD), M1)) + int64_t UpperV = 0; + int64_t LowerV = 0; + uint16_t CurrentId = std::numeric_limits<uint16_t>::max(); + if (IdxUpper < UpperRow.size()) { + CurrentId = std::min(UpperRow[IdxUpper].Id, CurrentId); + } + if (IdxLower < LowerRow.size()) { + CurrentId = std::min(LowerRow[IdxLower].Id, CurrentId); + } + + if (IdxUpper < UpperRow.size() && UpperRow[IdxUpper].Id == CurrentId) { + UpperV = UpperRow[IdxUpper].Coefficient; + IdxUpper++; + } + + if (MulOverflow(UpperV, ((-1) * LowerLast / GCD), M1)) return false; - if (MulOverflow(Constraints[LowerR][I], - (Constraints[UpperR][1] / GCD), M2)) + if (IdxLower < LowerRow.size() && LowerRow[IdxLower].Id == CurrentId) { + LowerV = LowerRow[IdxLower].Coefficient; + IdxLower++; + } + + if (MulOverflow(LowerV, (UpperLast / GCD), M2)) return false; if (AddOverflow(M1, M2, N)) return false; - NR.push_back(N); + if (N == 0) + continue; + NR.emplace_back(N, CurrentId); - NewGCD = APIntOps::GreatestCommonDivisor({32, (uint32_t)NR.back()}, - {32, NewGCD}) - .getZExtValue(); + NewGCD = + APIntOps::GreatestCommonDivisor({32, (uint32_t)N}, {32, NewGCD}) + .getZExtValue(); } - NewSystem.push_back(std::move(NR)); + if (NR.empty()) + continue; + Constraints.push_back(std::move(NR)); // Give up if the new system gets too big. 
- if (NewSystem.size() > 500) + if (Constraints.size() > 500) return false; } } - Constraints = std::move(NewSystem); + NumVariables -= 1; GCD = NewGCD; return true; } bool ConstraintSystem::mayHaveSolutionImpl() { - while (!Constraints.empty() && Constraints[0].size() > 1) { + while (!Constraints.empty() && NumVariables > 1) { if (!eliminateUsingFM()) return true; } - if (Constraints.empty() || Constraints[0].size() > 1) + if (Constraints.empty() || NumVariables > 1) return true; - return all_of(Constraints, [](auto &R) { return R[0] >= 0; }); + return all_of(Constraints, [](auto &R) { + if (R.empty()) + return true; + if (R[0].Id == 0) + return R[0].Coefficient >= 0; + return true; + }); } -void ConstraintSystem::dump(ArrayRef<std::string> Names) const { +SmallVector<std::string> ConstraintSystem::getVarNamesList() const { + SmallVector<std::string> Names(Value2Index.size(), ""); +#ifndef NDEBUG + for (auto &[V, Index] : Value2Index) { + std::string OperandName; + if (V->getName().empty()) + OperandName = V->getNameOrAsOperand(); + else + OperandName = std::string("%") + V->getName().str(); + Names[Index - 1] = OperandName; + } +#endif + return Names; +} + +void ConstraintSystem::dump() const { +#ifndef NDEBUG if (Constraints.empty()) return; - + SmallVector<std::string> Names = getVarNamesList(); for (const auto &Row : Constraints) { SmallVector<std::string, 16> Parts; - for (unsigned I = 1, S = Row.size(); I < S; ++I) { - if (Row[I] == 0) + for (unsigned I = 0, S = Row.size(); I < S; ++I) { + if (Row[I].Id >= NumVariables) + break; + if (Row[I].Id == 0) continue; std::string Coefficient; - if (Row[I] != 1) - Coefficient = std::to_string(Row[I]) + " * "; - Parts.push_back(Coefficient + Names[I - 1]); + if (Row[I].Coefficient != 1) + Coefficient = std::to_string(Row[I].Coefficient) + " * "; + Parts.push_back(Coefficient + Names[Row[I].Id - 1]); } - assert(!Parts.empty() && "need to have at least some parts"); + // assert(!Parts.empty() && "need to have at least some parts"); + int64_t ConstPart = 0; + if (Row[0].Id == 0) + ConstPart = Row[0].Coefficient; LLVM_DEBUG(dbgs() << join(Parts, std::string(" + ")) - << " <= " << std::to_string(Row[0]) << "\n"); + << " <= " << std::to_string(ConstPart) << "\n"); } -} - -void ConstraintSystem::dump() const { - SmallVector<std::string, 16> Names; - for (unsigned i = 1; i < Constraints.back().size(); ++i) - Names.push_back("x" + std::to_string(i)); - LLVM_DEBUG(dbgs() << "---\n"); - dump(Names); +#endif } bool ConstraintSystem::mayHaveSolution() { + LLVM_DEBUG(dbgs() << "---\n"); LLVM_DEBUG(dump()); bool HasSolution = mayHaveSolutionImpl(); LLVM_DEBUG(dbgs() << (HasSolution ? "sat" : "unsat") << "\n"); @@ -150,6 +206,8 @@ bool ConstraintSystem::isConditionImplied(SmallVector<int64_t, 8> R) const { // If there is no solution with the negation of R added to the system, the // condition must hold based on the existing constraints. 
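// Illustration (a standalone sketch, not part of the LLVM sources): the
// reworked eliminateUsingFM() above stores each constraint as a sparse list
// of {Coefficient, Id} entries meaning sum(Coefficient * x_Id) <= Constant,
// with Id 0 holding the constant term, and eliminates the last variable one
// Fourier-Motzkin step at a time. The snippet below shows the core
// combination of an "upper" row (positive coefficient on the eliminated
// variable) with a "lower" row (negative coefficient) so that the variable
// cancels; the GCD scaling and overflow checks of the real code are omitted.
#include <cstdint>
#include <cstdio>
#include <map>

using Row = std::map<unsigned, int64_t>; // Id -> Coefficient, Id 0 = constant

static Row combine(const Row &Upper, const Row &Lower, unsigned EliminatedId) {
  int64_t UpperC = Upper.at(EliminatedId); // > 0
  int64_t LowerC = Lower.at(EliminatedId); // < 0
  Row Result;
  // Multiply by positive factors so the inequality directions are preserved
  // and the eliminated variable's coefficient sums to zero.
  for (const auto &[Id, C] : Upper)
    Result[Id] += C * -LowerC;
  for (const auto &[Id, C] : Lower)
    Result[Id] += C * UpperC;
  Result.erase(EliminatedId); // zero by construction
  return Result;
}

int main() {
  // x1 + 2*x2 <= 10  and  3*x1 - x2 <= 4   combine to   7*x1 <= 18.
  Row Upper = {{0, 10}, {1, 1}, {2, 2}};
  Row Lower = {{0, 4}, {1, 3}, {2, -1}};
  for (const auto &[Id, C] : combine(Upper, Lower, 2))
    std::printf("Id %u -> %lld\n", Id, (long long)C);
}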
R = ConstraintSystem::negate(R); + if (R.empty()) + return false; auto NewSystem = *this; NewSystem.addVariableRow(R); diff --git a/contrib/llvm-project/llvm/lib/Analysis/CycleAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/CycleAnalysis.cpp index 17998123fce7..41a95a4fa220 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/CycleAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/CycleAnalysis.cpp @@ -17,9 +17,6 @@ namespace llvm { class Module; } -template class llvm::GenericCycleInfo<SSAContext>; -template class llvm::GenericCycle<SSAContext>; - CycleInfo CycleAnalysis::run(Function &F, FunctionAnalysisManager &) { CycleInfo CI; CI.compute(F); diff --git a/contrib/llvm-project/llvm/lib/Analysis/DDG.cpp b/contrib/llvm-project/llvm/lib/Analysis/DDG.cpp index da64ef153960..a0774096c512 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/DDG.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/DDG.cpp @@ -241,11 +241,10 @@ bool DataDependenceGraph::addNode(DDGNode &N) { } const PiBlockDDGNode *DataDependenceGraph::getPiBlock(const NodeType &N) const { - if (PiBlockMap.find(&N) == PiBlockMap.end()) + if (!PiBlockMap.contains(&N)) return nullptr; auto *Pi = PiBlockMap.find(&N)->second; - assert(PiBlockMap.find(Pi) == PiBlockMap.end() && - "Nested pi-blocks detected."); + assert(!PiBlockMap.contains(Pi) && "Nested pi-blocks detected."); return Pi; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/DemandedBits.cpp b/contrib/llvm-project/llvm/lib/Analysis/DemandedBits.cpp index e01ed48be376..c5017bf52498 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/DemandedBits.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/DemandedBits.cpp @@ -34,8 +34,6 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" @@ -48,30 +46,6 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "demanded-bits" -char DemandedBitsWrapperPass::ID = 0; - -INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits", - "Demanded bits analysis", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits", - "Demanded bits analysis", false, false) - -DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) { - initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry()); -} - -void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesCFG(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.setPreservesAll(); -} - -void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const { - DB->print(OS); -} - static bool isAlwaysLive(Instruction *I) { return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() || I->mayHaveSideEffects(); @@ -109,7 +83,7 @@ void DemandedBits::determineLiveOperandBits( default: break; case Instruction::Call: case Instruction::Invoke: - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI)) { + if (const auto *II = dyn_cast<IntrinsicInst>(UserI)) { switch (II->getIntrinsicID()) { default: break; case Intrinsic::bswap: @@ -170,7 +144,7 @@ void DemandedBits::determineLiveOperandBits( case Intrinsic::smin: // If low bits of result are not demanded, they are also not demanded // for the min/max operands. 
- AB = APInt::getBitsSetFrom(BitWidth, AOut.countTrailingZeros()); + AB = APInt::getBitsSetFrom(BitWidth, AOut.countr_zero()); break; } } @@ -206,7 +180,7 @@ void DemandedBits::determineLiveOperandBits( // If the shift is nuw/nsw, then the high bits are not dead // (because we've promised that they *must* be zero). - const ShlOperator *S = cast<ShlOperator>(UserI); + const auto *S = cast<ShlOperator>(UserI); if (S->hasNoSignedWrap()) AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1); else if (S->hasNoUnsignedWrap()) @@ -310,17 +284,6 @@ void DemandedBits::determineLiveOperandBits( } } -bool DemandedBitsWrapperPass::runOnFunction(Function &F) { - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - DB.emplace(F, AC, DT); - return false; -} - -void DemandedBitsWrapperPass::releaseMemory() { - DB.reset(); -} - void DemandedBits::performAnalysis() { if (Analyzed) // Analysis already completed for this function. @@ -353,7 +316,7 @@ void DemandedBits::performAnalysis() { // Non-integer-typed instructions... for (Use &OI : I.operands()) { - if (Instruction *J = dyn_cast<Instruction>(OI)) { + if (auto *J = dyn_cast<Instruction>(OI)) { Type *T = J->getType(); if (T->isIntOrIntVectorTy()) AliveBits[J] = APInt::getAllOnes(T->getScalarSizeInBits()); @@ -394,7 +357,7 @@ void DemandedBits::performAnalysis() { for (Use &OI : UserI->operands()) { // We also want to detect dead uses of arguments, but will only store // demanded bits for instructions. - Instruction *I = dyn_cast<Instruction>(OI); + auto *I = dyn_cast<Instruction>(OI); if (!I && !isa<Argument>(OI)) continue; @@ -447,7 +410,7 @@ APInt DemandedBits::getDemandedBits(Instruction *I) { APInt DemandedBits::getDemandedBits(Use *U) { Type *T = (*U)->getType(); - Instruction *UserI = cast<Instruction>(U->getUser()); + auto *UserI = cast<Instruction>(U->getUser()); const DataLayout &DL = UserI->getModule()->getDataLayout(); unsigned BitWidth = DL.getTypeSizeInBits(T->getScalarType()); @@ -475,8 +438,7 @@ APInt DemandedBits::getDemandedBits(Use *U) { bool DemandedBits::isInstructionDead(Instruction *I) { performAnalysis(); - return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() && - !isAlwaysLive(I); + return !Visited.count(I) && !AliveBits.contains(I) && !isAlwaysLive(I); } bool DemandedBits::isUseDead(Use *U) { @@ -485,7 +447,7 @@ bool DemandedBits::isUseDead(Use *U) { return false; // Uses by always-live instructions are never dead. 
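// Illustration (a standalone sketch, not part of the LLVM sources): the
// DemandedBits hunks above are mostly mechanical APInt renames
// (countTrailingZeros -> countr_zero and friends); the analysis itself still
// propagates demanded bits backwards from a user to its operands. The
// snippet below shows the usual transfer for a left shift by a constant,
// stated as an assumption since the exact LLVM transfer is outside this
// excerpt: result bit i only depends on operand bit i - ShiftAmt, so the
// operand's demanded bits are the result's demanded bits shifted right. As
// the comment in the hunk notes, the real analysis additionally keeps the
// operand's high bits live when the shift carries nuw/nsw.
#include <cstdint>
#include <cstdio>

static uint32_t demandedOperandBitsForShl(uint32_t DemandedResultBits,
                                          unsigned ShiftAmt) {
  return DemandedResultBits >> ShiftAmt;
}

int main() {
  // Only the low 8 bits of (x << 3) are used, so only bits 0..4 of x matter.
  std::printf("0x%08x\n", demandedOperandBitsForShl(0x000000FFu, 3)); // 0x1f
}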
- Instruction *UserI = cast<Instruction>(U->getUser()); + auto *UserI = cast<Instruction>(U->getUser()); if (isAlwaysLive(UserI)) return false; @@ -515,6 +477,7 @@ void DemandedBits::print(raw_ostream &OS) { OS << *I << '\n'; }; + OS << "Printing analysis 'Demanded Bits Analysis' for function '" << F.getName() << "':\n"; performAnalysis(); for (auto &KV : AliveBits) { Instruction *I = KV.first; @@ -606,10 +569,6 @@ APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo, true); } -FunctionPass *llvm::createDemandedBitsWrapperPass() { - return new DemandedBitsWrapperPass(); -} - AnalysisKey DemandedBitsAnalysis::Key; DemandedBits DemandedBitsAnalysis::run(Function &F, diff --git a/contrib/llvm-project/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp b/contrib/llvm-project/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp index a91d2ffe6042..456d58660680 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp @@ -165,7 +165,6 @@ private: bool isLogging() const { return !!Logger; } std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB) override; - std::function<bool(CallBase &)> GetDefaultAdvice; const bool IsDoingInference; std::unique_ptr<TrainingLogger> Logger; @@ -280,10 +279,10 @@ TrainingLogger::TrainingLogger(StringRef LogFileName, append_range(FT, MUTR->extraOutputsForLoggingSpecs()); DefaultDecisionPos = FT.size(); - FT.push_back(TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1})); + FT.push_back(DefaultDecisionSpec); DecisionPos = FT.size(); - FT.push_back(TensorSpec::createSpec<int64_t>(DecisionName, {1})); + FT.push_back(InlineDecisionSpec); std::error_code EC; auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC); if (EC) @@ -331,8 +330,7 @@ DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor( std::unique_ptr<MLModelRunner> ModelRunner, std::function<bool(CallBase &)> GetDefaultAdvice, std::unique_ptr<TrainingLogger> Logger) - : MLInlineAdvisor(M, MAM, std::move(ModelRunner)), - GetDefaultAdvice(GetDefaultAdvice), + : MLInlineAdvisor(M, MAM, std::move(ModelRunner), GetDefaultAdvice), IsDoingInference(isa<ModelUnderTrainingRunner>(getModelRunner())), Logger(std::move(Logger)), InitialNativeSize(isLogging() ? getTotalSizeEstimate() : 0), diff --git a/contrib/llvm-project/llvm/lib/Analysis/DivergenceAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/DivergenceAnalysis.cpp deleted file mode 100644 index 02c40d2640c1..000000000000 --- a/contrib/llvm-project/llvm/lib/Analysis/DivergenceAnalysis.cpp +++ /dev/null @@ -1,409 +0,0 @@ -//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a general divergence analysis for loop vectorization -// and GPU programs. It determines which branches and values in a loop or GPU -// program are divergent. It can help branch optimizations such as jump -// threading and loop unswitching to make better decisions. -// -// GPU programs typically use the SIMD execution model, where multiple threads -// in the same execution group have to execute in lock-step. 
Therefore, if the -// code contains divergent branches (i.e., threads in a group do not agree on -// which path of the branch to take), the group of threads has to execute all -// the paths from that branch with different subsets of threads enabled until -// they re-converge. -// -// Due to this execution model, some optimizations such as jump -// threading and loop unswitching can interfere with thread re-convergence. -// Therefore, an analysis that computes which branches in a GPU program are -// divergent can help the compiler to selectively run these optimizations. -// -// This implementation is derived from the Vectorization Analysis of the -// Region Vectorizer (RV). The analysis is based on the approach described in -// -// An abstract interpretation for SPMD divergence -// on reducible control flow graphs. -// Julian Rosemann, Simon Moll and Sebastian Hack -// POPL '21 -// -// This implementation is generic in the sense that it does -// not itself identify original sources of divergence. -// Instead specialized adapter classes, (LoopDivergenceAnalysis) for loops and -// (DivergenceAnalysis) for functions, identify the sources of divergence -// (e.g., special variables that hold the thread ID or the iteration variable). -// -// The generic implementation propagates divergence to variables that are data -// or sync dependent on a source of divergence. -// -// While data dependency is a well-known concept, the notion of sync dependency -// is worth more explanation. Sync dependence characterizes the control flow -// aspect of the propagation of branch divergence. For example, -// -// %cond = icmp slt i32 %tid, 10 -// br i1 %cond, label %then, label %else -// then: -// br label %merge -// else: -// br label %merge -// merge: -// %a = phi i32 [ 0, %then ], [ 1, %else ] -// -// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid -// because %tid is not on its use-def chains, %a is sync dependent on %tid -// because the branch "br i1 %cond" depends on %tid and affects which value %a -// is assigned to. -// -// The sync dependence detection (which branch induces divergence in which join -// points) is implemented in the SyncDependenceAnalysis. -// -// The current implementation has the following limitations: -// 1. intra-procedural. It conservatively considers the arguments of a -// non-kernel-entry function and the return value of a function call as -// divergent. -// 2. memory as black box. It conservatively considers values loaded from -// generic or local address as divergent. This can be improved by leveraging -// pointer analysis and/or by modelling non-escaping memory objects in SSA -// as done in RV. 
-// -//===----------------------------------------------------------------------===// - -#include "llvm/Analysis/DivergenceAnalysis.h" -#include "llvm/ADT/PostOrderIterator.h" -#include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/PostDominators.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Value.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" - -using namespace llvm; - -#define DEBUG_TYPE "divergence" - -DivergenceAnalysisImpl::DivergenceAnalysisImpl( - const Function &F, const Loop *RegionLoop, const DominatorTree &DT, - const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) - : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), - IsLCSSAForm(IsLCSSAForm) {} - -bool DivergenceAnalysisImpl::markDivergent(const Value &DivVal) { - if (isAlwaysUniform(DivVal)) - return false; - assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal)); - assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); - return DivergentValues.insert(&DivVal).second; -} - -void DivergenceAnalysisImpl::addUniformOverride(const Value &UniVal) { - UniformOverrides.insert(&UniVal); -} - -bool DivergenceAnalysisImpl::isTemporalDivergent( - const BasicBlock &ObservingBlock, const Value &Val) const { - const auto *Inst = dyn_cast<const Instruction>(&Val); - if (!Inst) - return false; - // check whether any divergent loop carrying Val terminates before control - // proceeds to ObservingBlock - for (const auto *Loop = LI.getLoopFor(Inst->getParent()); - Loop != RegionLoop && !Loop->contains(&ObservingBlock); - Loop = Loop->getParentLoop()) { - if (DivergentLoops.contains(Loop)) - return true; - } - - return false; -} - -bool DivergenceAnalysisImpl::inRegion(const Instruction &I) const { - return I.getParent() && inRegion(*I.getParent()); -} - -bool DivergenceAnalysisImpl::inRegion(const BasicBlock &BB) const { - return RegionLoop ? 
RegionLoop->contains(&BB) : (BB.getParent() == &F); -} - -void DivergenceAnalysisImpl::pushUsers(const Value &V) { - const auto *I = dyn_cast<const Instruction>(&V); - - if (I && I->isTerminator()) { - analyzeControlDivergence(*I); - return; - } - - for (const auto *User : V.users()) { - const auto *UserInst = dyn_cast<const Instruction>(User); - if (!UserInst) - continue; - - // only compute divergent inside loop - if (!inRegion(*UserInst)) - continue; - - // All users of divergent values are immediate divergent - if (markDivergent(*UserInst)) - Worklist.push_back(UserInst); - } -} - -static const Instruction *getIfCarriedInstruction(const Use &U, - const Loop &DivLoop) { - const auto *I = dyn_cast<const Instruction>(&U); - if (!I) - return nullptr; - if (!DivLoop.contains(I)) - return nullptr; - return I; -} - -void DivergenceAnalysisImpl::analyzeTemporalDivergence( - const Instruction &I, const Loop &OuterDivLoop) { - if (isAlwaysUniform(I)) - return; - if (isDivergent(I)) - return; - - LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n"); - assert((isa<PHINode>(I) || !IsLCSSAForm) && - "In LCSSA form all users of loop-exiting defs are Phi nodes."); - for (const Use &Op : I.operands()) { - const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop); - if (!OpInst) - continue; - if (markDivergent(I)) - pushUsers(I); - return; - } -} - -// marks all users of loop-carried values of the loop headed by LoopHeader as -// divergent -void DivergenceAnalysisImpl::analyzeLoopExitDivergence( - const BasicBlock &DivExit, const Loop &OuterDivLoop) { - // All users are in immediate exit blocks - if (IsLCSSAForm) { - for (const auto &Phi : DivExit.phis()) { - analyzeTemporalDivergence(Phi, OuterDivLoop); - } - return; - } - - // For non-LCSSA we have to follow all live out edges wherever they may lead. - const BasicBlock &LoopHeader = *OuterDivLoop.getHeader(); - SmallVector<const BasicBlock *, 8> TaintStack; - TaintStack.push_back(&DivExit); - - // Otherwise potential users of loop-carried values could be anywhere in the - // dominance region of DivLoop (including its fringes for phi nodes) - DenseSet<const BasicBlock *> Visited; - Visited.insert(&DivExit); - - do { - auto *UserBlock = TaintStack.pop_back_val(); - - // don't spread divergence beyond the region - if (!inRegion(*UserBlock)) - continue; - - assert(!OuterDivLoop.contains(UserBlock) && - "irreducible control flow detected"); - - // phi nodes at the fringes of the dominance region - if (!DT.dominates(&LoopHeader, UserBlock)) { - // all PHI nodes of UserBlock become divergent - for (const auto &Phi : UserBlock->phis()) { - analyzeTemporalDivergence(Phi, OuterDivLoop); - } - continue; - } - - // Taint outside users of values carried by OuterDivLoop. - for (const auto &I : *UserBlock) { - analyzeTemporalDivergence(I, OuterDivLoop); - } - - // visit all blocks in the dominance region - for (const auto *SuccBlock : successors(UserBlock)) { - if (!Visited.insert(SuccBlock).second) { - continue; - } - TaintStack.push_back(SuccBlock); - } - } while (!TaintStack.empty()); -} - -void DivergenceAnalysisImpl::propagateLoopExitDivergence( - const BasicBlock &DivExit, const Loop &InnerDivLoop) { - LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n"); - - // Find outer-most loop that does not contain \p DivExit - const Loop *DivLoop = &InnerDivLoop; - const Loop *OuterDivLoop = DivLoop; - const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit); - const unsigned LoopExitDepth = - ExitLevelLoop ? 
ExitLevelLoop->getLoopDepth() : 0; - while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) { - DivergentLoops.insert(DivLoop); // all crossed loops are divergent - OuterDivLoop = DivLoop; - DivLoop = DivLoop->getParentLoop(); - } - LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName() - << "\n"); - - analyzeLoopExitDivergence(DivExit, *OuterDivLoop); -} - -// this is a divergent join point - mark all phi nodes as divergent and push -// them onto the stack. -void DivergenceAnalysisImpl::taintAndPushPhiNodes(const BasicBlock &JoinBlock) { - LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName() - << "\n"); - - // ignore divergence outside the region - if (!inRegion(JoinBlock)) { - return; - } - - // push non-divergent phi nodes in JoinBlock to the worklist - for (const auto &Phi : JoinBlock.phis()) { - if (isDivergent(Phi)) - continue; - // FIXME Theoretically ,the 'undef' value could be replaced by any other - // value causing spurious divergence. - if (Phi.hasConstantOrUndefValue()) - continue; - if (markDivergent(Phi)) - Worklist.push_back(&Phi); - } -} - -void DivergenceAnalysisImpl::analyzeControlDivergence(const Instruction &Term) { - LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName() - << "\n"); - - // Don't propagate divergence from unreachable blocks. - if (!DT.isReachableFromEntry(Term.getParent())) - return; - - const auto *BranchLoop = LI.getLoopFor(Term.getParent()); - - const auto &DivDesc = SDA.getJoinBlocks(Term); - - // Iterate over all blocks now reachable by a disjoint path join - for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { - taintAndPushPhiNodes(*JoinBlock); - } - - assert(DivDesc.LoopDivBlocks.empty() || BranchLoop); - for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) { - propagateLoopExitDivergence(*DivExitBlock, *BranchLoop); - } -} - -void DivergenceAnalysisImpl::compute() { - // Initialize worklist. - auto DivValuesCopy = DivergentValues; - for (const auto *DivVal : DivValuesCopy) { - assert(isDivergent(*DivVal) && "Worklist invariant violated!"); - pushUsers(*DivVal); - } - - // All values on the Worklist are divergent. - // Their users may not have been updated yed. 
- while (!Worklist.empty()) { - const Instruction &I = *Worklist.back(); - Worklist.pop_back(); - - // propagate value divergence to users - assert(isDivergent(I) && "Worklist invariant violated!"); - pushUsers(I); - } -} - -bool DivergenceAnalysisImpl::isAlwaysUniform(const Value &V) const { - return UniformOverrides.contains(&V); -} - -bool DivergenceAnalysisImpl::isDivergent(const Value &V) const { - return DivergentValues.contains(&V); -} - -bool DivergenceAnalysisImpl::isDivergentUse(const Use &U) const { - Value &V = *U.get(); - Instruction &I = *cast<Instruction>(U.getUser()); - return isDivergent(V) || isTemporalDivergent(*I.getParent(), V); -} - -DivergenceInfo::DivergenceInfo(Function &F, const DominatorTree &DT, - const PostDominatorTree &PDT, const LoopInfo &LI, - const TargetTransformInfo &TTI, - bool KnownReducible) - : F(F) { - if (!KnownReducible) { - using RPOTraversal = ReversePostOrderTraversal<const Function *>; - RPOTraversal FuncRPOT(&F); - if (containsIrreducibleCFG<const BasicBlock *, const RPOTraversal, - const LoopInfo>(FuncRPOT, LI)) { - ContainsIrreducible = true; - return; - } - } - SDA = std::make_unique<SyncDependenceAnalysis>(DT, PDT, LI); - DA = std::make_unique<DivergenceAnalysisImpl>(F, nullptr, DT, LI, *SDA, - /* LCSSA */ false); - for (auto &I : instructions(F)) { - if (TTI.isSourceOfDivergence(&I)) { - DA->markDivergent(I); - } else if (TTI.isAlwaysUniform(&I)) { - DA->addUniformOverride(I); - } - } - for (auto &Arg : F.args()) { - if (TTI.isSourceOfDivergence(&Arg)) { - DA->markDivergent(Arg); - } - } - - DA->compute(); -} - -AnalysisKey DivergenceAnalysis::Key; - -DivergenceAnalysis::Result -DivergenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) { - auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); - auto &LI = AM.getResult<LoopAnalysis>(F); - auto &TTI = AM.getResult<TargetIRAnalysis>(F); - - return DivergenceInfo(F, DT, PDT, LI, TTI, /* KnownReducible = */ false); -} - -PreservedAnalyses -DivergenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) { - auto &DI = FAM.getResult<DivergenceAnalysis>(F); - OS << "'Divergence Analysis' for function '" << F.getName() << "':\n"; - if (DI.hasDivergence()) { - for (auto &Arg : F.args()) { - OS << (DI.isDivergent(Arg) ? "DIVERGENT: " : " "); - OS << Arg << "\n"; - } - for (const BasicBlock &BB : F) { - OS << "\n " << BB.getName() << ":\n"; - for (const auto &I : BB.instructionsWithoutDebug()) { - OS << (DI.isDivergent(I) ? "DIVERGENT: " : " "); - OS << I << "\n"; - } - } - } - return PreservedAnalyses::all(); -} diff --git a/contrib/llvm-project/llvm/lib/Analysis/EHPersonalities.cpp b/contrib/llvm-project/llvm/lib/Analysis/EHPersonalities.cpp deleted file mode 100644 index 277ff6ba735f..000000000000 --- a/contrib/llvm-project/llvm/lib/Analysis/EHPersonalities.cpp +++ /dev/null @@ -1,143 +0,0 @@ -//===- EHPersonalities.cpp - Compute EH-related information ---------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "llvm/Analysis/EHPersonalities.h" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/ADT/Triple.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Instructions.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -using namespace llvm; - -/// See if the given exception handling personality function is one that we -/// understand. If so, return a description of it; otherwise return Unknown. -EHPersonality llvm::classifyEHPersonality(const Value *Pers) { - const GlobalValue *F = - Pers ? dyn_cast<GlobalValue>(Pers->stripPointerCasts()) : nullptr; - if (!F || !F->getValueType() || !F->getValueType()->isFunctionTy()) - return EHPersonality::Unknown; - return StringSwitch<EHPersonality>(F->getName()) - .Case("__gnat_eh_personality", EHPersonality::GNU_Ada) - .Case("__gxx_personality_v0", EHPersonality::GNU_CXX) - .Case("__gxx_personality_seh0", EHPersonality::GNU_CXX) - .Case("__gxx_personality_sj0", EHPersonality::GNU_CXX_SjLj) - .Case("__gcc_personality_v0", EHPersonality::GNU_C) - .Case("__gcc_personality_seh0", EHPersonality::GNU_C) - .Case("__gcc_personality_sj0", EHPersonality::GNU_C_SjLj) - .Case("__objc_personality_v0", EHPersonality::GNU_ObjC) - .Case("_except_handler3", EHPersonality::MSVC_X86SEH) - .Case("_except_handler4", EHPersonality::MSVC_X86SEH) - .Case("__C_specific_handler", EHPersonality::MSVC_TableSEH) - .Case("__CxxFrameHandler3", EHPersonality::MSVC_CXX) - .Case("ProcessCLRException", EHPersonality::CoreCLR) - .Case("rust_eh_personality", EHPersonality::Rust) - .Case("__gxx_wasm_personality_v0", EHPersonality::Wasm_CXX) - .Case("__xlcxx_personality_v1", EHPersonality::XL_CXX) - .Default(EHPersonality::Unknown); -} - -StringRef llvm::getEHPersonalityName(EHPersonality Pers) { - switch (Pers) { - case EHPersonality::GNU_Ada: return "__gnat_eh_personality"; - case EHPersonality::GNU_CXX: return "__gxx_personality_v0"; - case EHPersonality::GNU_CXX_SjLj: return "__gxx_personality_sj0"; - case EHPersonality::GNU_C: return "__gcc_personality_v0"; - case EHPersonality::GNU_C_SjLj: return "__gcc_personality_sj0"; - case EHPersonality::GNU_ObjC: return "__objc_personality_v0"; - case EHPersonality::MSVC_X86SEH: return "_except_handler3"; - case EHPersonality::MSVC_TableSEH: - return "__C_specific_handler"; - case EHPersonality::MSVC_CXX: return "__CxxFrameHandler3"; - case EHPersonality::CoreCLR: return "ProcessCLRException"; - case EHPersonality::Rust: return "rust_eh_personality"; - case EHPersonality::Wasm_CXX: return "__gxx_wasm_personality_v0"; - case EHPersonality::XL_CXX: - return "__xlcxx_personality_v1"; - case EHPersonality::Unknown: llvm_unreachable("Unknown EHPersonality!"); - } - - llvm_unreachable("Invalid EHPersonality!"); -} - -EHPersonality llvm::getDefaultEHPersonality(const Triple &T) { - if (T.isPS5()) - return EHPersonality::GNU_CXX; - else - return EHPersonality::GNU_C; -} - -bool llvm::canSimplifyInvokeNoUnwind(const Function *F) { - EHPersonality Personality = classifyEHPersonality(F->getPersonalityFn()); - // We can't simplify any invokes to nounwind functions if the personality - // function wants to catch asynch exceptions. The nounwind attribute only - // implies that the function does not throw synchronous exceptions. 
- return !isAsynchronousEHPersonality(Personality); -} - -DenseMap<BasicBlock *, ColorVector> llvm::colorEHFunclets(Function &F) { - SmallVector<std::pair<BasicBlock *, BasicBlock *>, 16> Worklist; - BasicBlock *EntryBlock = &F.getEntryBlock(); - DenseMap<BasicBlock *, ColorVector> BlockColors; - - // Build up the color map, which maps each block to its set of 'colors'. - // For any block B the "colors" of B are the set of funclets F (possibly - // including a root "funclet" representing the main function) such that - // F will need to directly contain B or a copy of B (where the term "directly - // contain" is used to distinguish from being "transitively contained" in - // a nested funclet). - // - // Note: Despite not being a funclet in the truest sense, a catchswitch is - // considered to belong to its own funclet for the purposes of coloring. - - DEBUG_WITH_TYPE("winehprepare-coloring", dbgs() << "\nColoring funclets for " - << F.getName() << "\n"); - - Worklist.push_back({EntryBlock, EntryBlock}); - - while (!Worklist.empty()) { - BasicBlock *Visiting; - BasicBlock *Color; - std::tie(Visiting, Color) = Worklist.pop_back_val(); - DEBUG_WITH_TYPE("winehprepare-coloring", - dbgs() << "Visiting " << Visiting->getName() << ", " - << Color->getName() << "\n"); - Instruction *VisitingHead = Visiting->getFirstNonPHI(); - if (VisitingHead->isEHPad()) { - // Mark this funclet head as a member of itself. - Color = Visiting; - } - // Note that this is a member of the given color. - ColorVector &Colors = BlockColors[Visiting]; - if (!is_contained(Colors, Color)) - Colors.push_back(Color); - else - continue; - - DEBUG_WITH_TYPE("winehprepare-coloring", - dbgs() << " Assigned color \'" << Color->getName() - << "\' to block \'" << Visiting->getName() - << "\'.\n"); - - BasicBlock *SuccColor = Color; - Instruction *Terminator = Visiting->getTerminator(); - if (auto *CatchRet = dyn_cast<CatchReturnInst>(Terminator)) { - Value *ParentPad = CatchRet->getCatchSwitchParentPad(); - if (isa<ConstantTokenNone>(ParentPad)) - SuccColor = EntryBlock; - else - SuccColor = cast<Instruction>(ParentPad)->getParent(); - } - - for (BasicBlock *Succ : successors(Visiting)) - Worklist.push_back({Succ, SuccColor}); - } - return BlockColors; -} diff --git a/contrib/llvm-project/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp index 782c11937507..6094f22a17fd 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp @@ -82,13 +82,15 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F, } FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo( - const Function &F, FunctionAnalysisManager &FAM) { + Function &F, FunctionAnalysisManager &FAM) { + return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F), + FAM.getResult<LoopAnalysis>(F)); +} + +FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo( + const Function &F, const DominatorTree &DT, const LoopInfo &LI) { FunctionPropertiesInfo FPI; - // The const casts are due to the getResult API - there's no mutation of F. 
- const auto &LI = FAM.getResult<LoopAnalysis>(const_cast<Function &>(F)); - const auto &DT = - FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F)); for (const auto &BB : F) if (DT.isReachableFromEntry(&BB)) FPI.reIncludeBB(BB); @@ -127,7 +129,7 @@ FunctionPropertiesPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { } FunctionPropertiesUpdater::FunctionPropertiesUpdater( - FunctionPropertiesInfo &FPI, const CallBase &CB) + FunctionPropertiesInfo &FPI, CallBase &CB) : FPI(FPI), CallSiteBB(*CB.getParent()), Caller(*CallSiteBB.getParent()) { assert(isa<CallInst>(CB) || isa<InvokeInst>(CB)); // For BBs that are likely to change, we subtract from feature totals their @@ -247,5 +249,13 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const { const auto &LI = FAM.getResult<LoopAnalysis>(const_cast<Function &>(Caller)); FPI.updateAggregateStats(Caller, LI); - assert(FPI == FunctionPropertiesInfo::getFunctionPropertiesInfo(Caller, FAM)); } + +bool FunctionPropertiesUpdater::isUpdateValid(Function &F, + const FunctionPropertiesInfo &FPI, + FunctionAnalysisManager &FAM) { + DominatorTree DT(F); + LoopInfo LI(DT); + auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI); + return FPI == Fresh; +}
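// Illustration (a standalone sketch with hypothetical types, not part of the
// LLVM sources): the FunctionPropertiesAnalysis hunk above drops the assert
// at the end of finish() and adds a separate isUpdateValid() that rebuilds
// DominatorTree/LoopInfo and compares a freshly computed
// FunctionPropertiesInfo against the incrementally updated one. The snippet
// below shows that general pattern: maintain a summary incrementally, and
// validate it by recomputing from scratch and comparing for equality.
#include <cassert>
#include <cstddef>
#include <vector>

struct Summary {
  std::size_t Count = 0;
  long Sum = 0;
  bool operator==(const Summary &O) const {
    return Count == O.Count && Sum == O.Sum;
  }
};

static Summary computeFromScratch(const std::vector<int> &Data) {
  Summary S;
  for (int V : Data) {
    ++S.Count;
    S.Sum += V;
  }
  return S;
}

static void applyAppend(Summary &S, int V) { // incremental update
  ++S.Count;
  S.Sum += V;
}

static bool isUpdateValid(const Summary &Incremental,
                          const std::vector<int> &Data) {
  return Incremental == computeFromScratch(Data); // fresh recomputation
}

int main() {
  std::vector<int> Data{1, 2, 3};
  Summary S = computeFromScratch(Data);
  Data.push_back(4);
  applyAppend(S, 4);
  assert(isUpdateValid(S, Data));
  return 0;
}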
\ No newline at end of file diff --git a/contrib/llvm-project/llvm/lib/Analysis/GuardUtils.cpp b/contrib/llvm-project/llvm/lib/Analysis/GuardUtils.cpp index cd132c56991f..40b898e96f3b 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/GuardUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/GuardUtils.cpp @@ -32,12 +32,19 @@ bool llvm::isGuardAsWidenableBranch(const User *U) { if (!parseWidenableBranch(U, Condition, WidenableCondition, GuardedBB, DeoptBB)) return false; - for (auto &Insn : *DeoptBB) { - if (match(&Insn, m_Intrinsic<Intrinsic::experimental_deoptimize>())) - return true; - if (Insn.mayHaveSideEffects()) + SmallPtrSet<const BasicBlock *, 2> Visited; + Visited.insert(DeoptBB); + do { + for (auto &Insn : *DeoptBB) { + if (match(&Insn, m_Intrinsic<Intrinsic::experimental_deoptimize>())) + return true; + if (Insn.mayHaveSideEffects()) + return false; + } + DeoptBB = DeoptBB->getUniqueSuccessor(); + if (!DeoptBB) return false; - } + } while (Visited.insert(DeoptBB).second); return false; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/IRSimilarityIdentifier.cpp b/contrib/llvm-project/llvm/lib/Analysis/IRSimilarityIdentifier.cpp index f471e32344cb..f029c8342fde 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/IRSimilarityIdentifier.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/IRSimilarityIdentifier.cpp @@ -14,6 +14,7 @@ #include "llvm/Analysis/IRSimilarityIdentifier.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" #include "llvm/IR/User.h" @@ -97,7 +98,8 @@ void IRInstructionData::setBranchSuccessors( int CurrentBlockNumber = static_cast<int>(BBNumIt->second); - for (BasicBlock *Successor : BI->successors()) { + for (Value *V : getBlockOperVals()) { + BasicBlock *Successor = cast<BasicBlock>(V); BBNumIt = BasicBlockToInteger.find(Successor); assert(BBNumIt != BasicBlockToInteger.end() && "Could not find number for BasicBlock!"); @@ -108,6 +110,25 @@ void IRInstructionData::setBranchSuccessors( } } +ArrayRef<Value *> IRInstructionData::getBlockOperVals() { + assert((isa<BranchInst>(Inst) || + isa<PHINode>(Inst)) && "Instruction must be branch or PHINode"); + + if (BranchInst *BI = dyn_cast<BranchInst>(Inst)) + return ArrayRef<Value *>( + std::next(OperVals.begin(), BI->isConditional() ? 1 : 0), + OperVals.end() + ); + + if (PHINode *PN = dyn_cast<PHINode>(Inst)) + return ArrayRef<Value *>( + std::next(OperVals.begin(), PN->getNumIncomingValues()), + OperVals.end() + ); + + return ArrayRef<Value *>(); +} + void IRInstructionData::setCalleeName(bool MatchByName) { CallInst *CI = dyn_cast<CallInst>(Inst); assert(CI && "Instruction must be call"); @@ -159,7 +180,6 @@ void IRInstructionData::setPHIPredecessors( int Relative = OtherBlockNumber - CurrentBlockNumber; RelativeBlockLocations.push_back(Relative); - RelativeBlockLocations.push_back(Relative); } } @@ -439,7 +459,7 @@ IRSimilarityCandidate::IRSimilarityCandidate(unsigned StartIdx, unsigned Len, // Map the operand values to an unsigned integer if it does not already // have an unsigned integer assigned to it. 
for (Value *Arg : ID->OperVals) - if (ValueToNumber.find(Arg) == ValueToNumber.end()) { + if (!ValueToNumber.contains(Arg)) { ValueToNumber.try_emplace(Arg, LocalValNumber); NumberToValue.try_emplace(LocalValNumber, Arg); LocalValNumber++; @@ -447,7 +467,7 @@ IRSimilarityCandidate::IRSimilarityCandidate(unsigned StartIdx, unsigned Len, // Mapping the instructions to an unsigned integer if it is not already // exist in the mapping. - if (ValueToNumber.find(ID->Inst) == ValueToNumber.end()) { + if (!ValueToNumber.contains(ID->Inst)) { ValueToNumber.try_emplace(ID->Inst, LocalValNumber); NumberToValue.try_emplace(LocalValNumber, ID->Inst); LocalValNumber++; @@ -464,7 +484,7 @@ IRSimilarityCandidate::IRSimilarityCandidate(unsigned StartIdx, unsigned Len, DenseSet<BasicBlock *> BBSet; getBasicBlocks(BBSet); for (BasicBlock *BB : BBSet) { - if (ValueToNumber.find(BB) != ValueToNumber.end()) + if (ValueToNumber.contains(BB)) continue; ValueToNumber.try_emplace(BB, LocalValNumber); @@ -698,11 +718,39 @@ bool IRSimilarityCandidate::compareCommutativeOperandMapping( return true; } +bool IRSimilarityCandidate::compareAssignmentMapping( + const unsigned InstValA, const unsigned &InstValB, + DenseMap<unsigned, DenseSet<unsigned>> &ValueNumberMappingA, + DenseMap<unsigned, DenseSet<unsigned>> &ValueNumberMappingB) { + DenseMap<unsigned, DenseSet<unsigned>>::iterator ValueMappingIt; + bool WasInserted; + std::tie(ValueMappingIt, WasInserted) = ValueNumberMappingA.insert( + std::make_pair(InstValA, DenseSet<unsigned>({InstValB}))); + if (!WasInserted && !ValueMappingIt->second.contains(InstValB)) + return false; + else if (ValueMappingIt->second.size() != 1) { + for (unsigned OtherVal : ValueMappingIt->second) { + if (OtherVal == InstValB) + continue; + if (!ValueNumberMappingA.contains(OtherVal)) + continue; + if (!ValueNumberMappingA[OtherVal].contains(InstValA)) + continue; + ValueNumberMappingA[OtherVal].erase(InstValA); + } + ValueNumberMappingA.erase(ValueMappingIt); + std::tie(ValueMappingIt, WasInserted) = ValueNumberMappingA.insert( + std::make_pair(InstValA, DenseSet<unsigned>({InstValB}))); + } + + return true; +} + bool IRSimilarityCandidate::checkRelativeLocations(RelativeLocMapping A, RelativeLocMapping B) { // Get the basic blocks the label refers to. - BasicBlock *ABB = static_cast<BasicBlock *>(A.OperVal); - BasicBlock *BBB = static_cast<BasicBlock *>(B.OperVal); + BasicBlock *ABB = cast<BasicBlock>(A.OperVal); + BasicBlock *BBB = cast<BasicBlock>(B.OperVal); // Get the basic blocks contained in each region. DenseSet<BasicBlock *> BasicBlockA; @@ -715,7 +763,7 @@ bool IRSimilarityCandidate::checkRelativeLocations(RelativeLocMapping A, bool BContained = BasicBlockB.contains(BBB); // Both blocks need to be contained in the region, or both need to be outside - // the reigon. + // the region. if (AContained != BContained) return false; @@ -755,8 +803,6 @@ bool IRSimilarityCandidate::compareStructure( // in one candidate to values in the other candidate. If we create a set with // one element, and that same element maps to the original element in the // candidate we have a good mapping. 
- DenseMap<unsigned, DenseSet<unsigned>>::iterator ValueMappingIt; - // Iterate over the instructions contained in each candidate unsigned SectionLength = A.getStartIdx() + A.getLength(); @@ -779,16 +825,13 @@ bool IRSimilarityCandidate::compareStructure( unsigned InstValA = A.ValueToNumber.find(IA)->second; unsigned InstValB = B.ValueToNumber.find(IB)->second; - bool WasInserted; // Ensure that the mappings for the instructions exists. - std::tie(ValueMappingIt, WasInserted) = ValueNumberMappingA.insert( - std::make_pair(InstValA, DenseSet<unsigned>({InstValB}))); - if (!WasInserted && !ValueMappingIt->second.contains(InstValB)) + if (!compareAssignmentMapping(InstValA, InstValB, ValueNumberMappingA, + ValueNumberMappingB)) return false; - - std::tie(ValueMappingIt, WasInserted) = ValueNumberMappingB.insert( - std::make_pair(InstValB, DenseSet<unsigned>({InstValA}))); - if (!WasInserted && !ValueMappingIt->second.contains(InstValA)) + + if (!compareAssignmentMapping(InstValB, InstValA, ValueNumberMappingB, + ValueNumberMappingA)) return false; // We have different paths for commutative instructions and non-commutative @@ -826,12 +869,22 @@ bool IRSimilarityCandidate::compareStructure( SmallVector<int, 4> &RelBlockLocsA = ItA->RelativeBlockLocations; SmallVector<int, 4> &RelBlockLocsB = ItB->RelativeBlockLocations; + ArrayRef<Value *> ABL = ItA->getBlockOperVals(); + ArrayRef<Value *> BBL = ItB->getBlockOperVals(); + + // Check to make sure that the number of operands, and branching locations + // between BranchInsts is the same. if (RelBlockLocsA.size() != RelBlockLocsB.size() && - OperValsA.size() != OperValsB.size()) + ABL.size() != BBL.size()) return false; + assert(RelBlockLocsA.size() == ABL.size() && + "Block information vectors not the same size."); + assert(RelBlockLocsB.size() == BBL.size() && + "Block information vectors not the same size."); + ZippedRelativeLocationsT ZippedRelativeLocations = - zip(RelBlockLocsA, RelBlockLocsB, OperValsA, OperValsB); + zip(RelBlockLocsA, RelBlockLocsB, ABL, BBL); if (any_of(ZippedRelativeLocations, [&A, &B](std::tuple<int, int, Value *, Value *> R) { return !checkRelativeLocations( @@ -1026,7 +1079,7 @@ void IRSimilarityCandidate::createCanonicalRelationFrom( // We can skip the BasicBlock if the canonical numbering has already been // found in a separate instruction. 
- if (NumberToCanonNum.find(BBGVNForCurrCand) != NumberToCanonNum.end()) + if (NumberToCanonNum.contains(BBGVNForCurrCand)) continue; // If the basic block is the starting block, then the shared instruction may @@ -1048,6 +1101,76 @@ void IRSimilarityCandidate::createCanonicalRelationFrom( } } +void IRSimilarityCandidate::createCanonicalRelationFrom( + IRSimilarityCandidate &SourceCand, IRSimilarityCandidate &SourceCandLarge, + IRSimilarityCandidate &TargetCandLarge) { + assert(!SourceCand.CanonNumToNumber.empty() && + "Canonical Relationship is non-empty"); + assert(!SourceCand.NumberToCanonNum.empty() && + "Canonical Relationship is non-empty"); + + assert(!SourceCandLarge.CanonNumToNumber.empty() && + "Canonical Relationship is non-empty"); + assert(!SourceCandLarge.NumberToCanonNum.empty() && + "Canonical Relationship is non-empty"); + + assert(!TargetCandLarge.CanonNumToNumber.empty() && + "Canonical Relationship is non-empty"); + assert(!TargetCandLarge.NumberToCanonNum.empty() && + "Canonical Relationship is non-empty"); + + assert(CanonNumToNumber.empty() && "Canonical Relationship is non-empty"); + assert(NumberToCanonNum.empty() && "Canonical Relationship is non-empty"); + + // We're going to use the larger candidates as a "bridge" to create the + // canonical number for the target candidate since we have idetified two + // candidates as subsequences of larger sequences, and therefore must be + // structurally similar. + for (std::pair<Value *, unsigned> &ValueNumPair : ValueToNumber) { + Value *CurrVal = ValueNumPair.first; + unsigned TargetCandGVN = ValueNumPair.second; + + // Find the numbering in the large candidate that surrounds the + // current candidate. + std::optional<unsigned> OLargeTargetGVN = TargetCandLarge.getGVN(CurrVal); + assert(OLargeTargetGVN.has_value() && "GVN not found for Value"); + + // Get the canonical numbering in the large target candidate. + std::optional<unsigned> OTargetCandCanon = + TargetCandLarge.getCanonicalNum(OLargeTargetGVN.value()); + assert(OTargetCandCanon.has_value() && + "Canononical Number not found for GVN"); + + // Get the GVN in the large source candidate from the canonical numbering. + std::optional<unsigned> OLargeSourceGVN = + SourceCandLarge.fromCanonicalNum(OTargetCandCanon.value()); + assert(OLargeSourceGVN.has_value() && + "GVN Number not found for Canonical Number"); + + // Get the Value from the GVN in the large source candidate. + std::optional<Value *> OLargeSourceV = + SourceCandLarge.fromGVN(OLargeSourceGVN.value()); + assert(OLargeSourceV.has_value() && "Value not found for GVN"); + + // Get the GVN number for the Value in the source candidate. + std::optional<unsigned> OSourceGVN = + SourceCand.getGVN(OLargeSourceV.value()); + assert(OSourceGVN.has_value() && "GVN Number not found for Value"); + + // Get the canonical numbering from the GVN/ + std::optional<unsigned> OSourceCanon = + SourceCand.getCanonicalNum(OSourceGVN.value()); + assert(OSourceCanon.has_value() && "Canon Number not found for GVN"); + + // Insert the canonical numbering and GVN pair into their respective + // mappings. 
+ CanonNumToNumber.insert( + std::make_pair(OSourceCanon.value(), TargetCandGVN)); + NumberToCanonNum.insert( + std::make_pair(TargetCandGVN, OSourceCanon.value())); + } +} + void IRSimilarityCandidate::createCanonicalMappingFor( IRSimilarityCandidate &CurrCand) { assert(CurrCand.CanonNumToNumber.size() == 0 && @@ -1065,6 +1188,81 @@ void IRSimilarityCandidate::createCanonicalMappingFor( } } +/// Look for larger IRSimilarityCandidates From the previously matched +/// IRSimilarityCandidates that fully contain \p CandA or \p CandB. If there is +/// an overlap, return a pair of structurally similar, larger +/// IRSimilarityCandidates. +/// +/// \param [in] CandA - The first candidate we are trying to determine the +/// structure of. +/// \param [in] CandB - The second candidate we are trying to determine the +/// structure of. +/// \param [in] IndexToIncludedCand - Mapping of index of the an instruction in +/// a circuit to the IRSimilarityCandidates that include this instruction. +/// \param [in] CandToOverallGroup - Mapping of IRSimilarityCandidate to a +/// number representing the structural group assigned to it. +static std::optional< + std::pair<IRSimilarityCandidate *, IRSimilarityCandidate *>> +CheckLargerCands( + IRSimilarityCandidate &CandA, IRSimilarityCandidate &CandB, + DenseMap<unsigned, DenseSet<IRSimilarityCandidate *>> &IndexToIncludedCand, + DenseMap<IRSimilarityCandidate *, unsigned> &CandToGroup) { + DenseMap<unsigned, IRSimilarityCandidate *> IncludedGroupAndCandA; + DenseMap<unsigned, IRSimilarityCandidate *> IncludedGroupAndCandB; + DenseSet<unsigned> IncludedGroupsA; + DenseSet<unsigned> IncludedGroupsB; + + // Find the overall similarity group numbers that fully contain the candidate, + // and record the larger candidate for each group. + auto IdxToCandidateIt = IndexToIncludedCand.find(CandA.getStartIdx()); + std::optional<std::pair<IRSimilarityCandidate *, IRSimilarityCandidate *>> + Result; + + unsigned CandAStart = CandA.getStartIdx(); + unsigned CandAEnd = CandA.getEndIdx(); + unsigned CandBStart = CandB.getStartIdx(); + unsigned CandBEnd = CandB.getEndIdx(); + if (IdxToCandidateIt == IndexToIncludedCand.end()) + return Result; + for (IRSimilarityCandidate *MatchedCand : IdxToCandidateIt->second) { + if (MatchedCand->getStartIdx() > CandAStart || + (MatchedCand->getEndIdx() < CandAEnd)) + continue; + unsigned GroupNum = CandToGroup.find(MatchedCand)->second; + IncludedGroupAndCandA.insert(std::make_pair(GroupNum, MatchedCand)); + IncludedGroupsA.insert(GroupNum); + } + + // Find the overall similarity group numbers that fully contain the next + // candidate, and record the larger candidate for each group. + IdxToCandidateIt = IndexToIncludedCand.find(CandBStart); + if (IdxToCandidateIt == IndexToIncludedCand.end()) + return Result; + for (IRSimilarityCandidate *MatchedCand : IdxToCandidateIt->second) { + if (MatchedCand->getStartIdx() > CandBStart || + MatchedCand->getEndIdx() < CandBEnd) + continue; + unsigned GroupNum = CandToGroup.find(MatchedCand)->second; + IncludedGroupAndCandB.insert(std::make_pair(GroupNum, MatchedCand)); + IncludedGroupsB.insert(GroupNum); + } + + // Find the intersection between the two groups, these are the groups where + // the larger candidates exist. + set_intersect(IncludedGroupsA, IncludedGroupsB); + + // If there is no intersection between the sets, then we cannot determine + // whether or not there is a match. + if (IncludedGroupsA.empty()) + return Result; + + // Create a pair that contains the larger candidates. 
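Because the two containing ("large") candidates have already been proven structurally similar, the three-argument createCanonicalRelationFrom above never re-runs the structural comparison; for each value in the target candidate it derives a canonical number purely by table lookups. Step by step, the chain it walks is:

    // For each Value *V in the target candidate, with GVN TargetGVN:
    //   LargeTargetGVN = TargetCandLarge.getGVN(V)                 (GVN in the containing candidate)
    //   Canon          = TargetCandLarge.getCanonicalNum(LargeTargetGVN)
    //   LargeSourceGVN = SourceCandLarge.fromCanonicalNum(Canon)   (hop across the already-matched large pair)
    //   SourceV        = SourceCandLarge.fromGVN(LargeSourceGVN)
    //   SourceGVN      = SourceCand.getGVN(SourceV)
    //   SourceCanon    = SourceCand.getCanonicalNum(SourceGVN)
    // Finally (SourceCanon, TargetGVN) goes into CanonNumToNumber and the reverse
    // pair goes into NumberToCanonNum.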
+ auto ItA = IncludedGroupAndCandA.find(*IncludedGroupsA.begin()); + auto ItB = IncludedGroupAndCandB.find(*IncludedGroupsA.begin()); + Result = std::make_pair(ItA->second, ItB->second); + return Result; +} + /// From the list of IRSimilarityCandidates, perform a comparison between each /// IRSimilarityCandidate to determine if there are overlapping /// IRInstructionData, or if they do not have the same structure. @@ -1074,9 +1272,16 @@ void IRSimilarityCandidate::createCanonicalMappingFor( /// \param [out] StructuralGroups - the mapping of unsigned integers to vector /// of IRSimilarityCandidates where each of the IRSimilarityCandidates in the /// vector are structurally similar to one another. +/// \param [in] IndexToIncludedCand - Mapping of index of the an instruction in +/// a circuit to the IRSimilarityCandidates that include this instruction. +/// \param [in] CandToOverallGroup - Mapping of IRSimilarityCandidate to a +/// number representing the structural group assigned to it. static void findCandidateStructures( std::vector<IRSimilarityCandidate> &CandsForRepSubstring, - DenseMap<unsigned, SimilarityGroup> &StructuralGroups) { + DenseMap<unsigned, SimilarityGroup> &StructuralGroups, + DenseMap<unsigned, DenseSet<IRSimilarityCandidate *>> &IndexToIncludedCand, + DenseMap<IRSimilarityCandidate *, unsigned> &CandToOverallGroup + ) { std::vector<IRSimilarityCandidate>::iterator CandIt, CandEndIt, InnerCandIt, InnerCandEndIt; @@ -1139,6 +1344,24 @@ static void findCandidateStructures( if (CandToGroupItInner != CandToGroup.end()) continue; + // Check if we have found structural similarity between two candidates + // that fully contains the first and second candidates. + std::optional<std::pair<IRSimilarityCandidate *, IRSimilarityCandidate *>> + LargerPair = CheckLargerCands( + *CandIt, *InnerCandIt, IndexToIncludedCand, CandToOverallGroup); + + // If a pair was found, it means that we can assume that these smaller + // substrings are also structurally similar. Use the larger candidates to + // determine the canonical mapping between the two sections. + if (LargerPair.has_value()) { + SameStructure = true; + InnerCandIt->createCanonicalRelationFrom( + *CandIt, *LargerPair.value().first, *LargerPair.value().second); + CandToGroup.insert(std::make_pair(&*InnerCandIt, OuterGroupNum)); + CurrentGroupPair->second.push_back(*InnerCandIt); + continue; + } + // Otherwise we determine if they have the same structure and add it to // vector if they match. ValueNumberMappingA.clear(); @@ -1165,24 +1388,58 @@ void IRSimilarityIdentifier::findCandidates( std::vector<SimilarityGroup> NewCandidateGroups; DenseMap<unsigned, SimilarityGroup> StructuralGroups; + DenseMap<unsigned, DenseSet<IRSimilarityCandidate *>> IndexToIncludedCand; + DenseMap<IRSimilarityCandidate *, unsigned> CandToGroup; // Iterate over the subsequences found by the Suffix Tree to create // IRSimilarityCandidates for each repeated subsequence and determine which // instances are structurally similar to one another. - for (SuffixTree::RepeatedSubstring &RS : ST) { + + // Sort the suffix tree from longest substring to shortest. 
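The sort that follows directly below is what makes the CheckLargerCands shortcut possible: by handling the longest repeated substrings first, any candidate that fully contains a shorter one has already been grouped (and registered in IndexToIncludedCand) by the time the shorter one is examined. A minimal sketch of the ordering, with a hypothetical RS struct standing in for SuffixTree::RepeatedSubstring:

    #include "llvm/ADT/STLExtras.h"
    #include <vector>

    struct RS { unsigned Length; };

    void sortLongestFirst(std::vector<RS> &Repeats) {
      // Descending by length; stable_sort keeps the discovery order of
      // equal-length repeats so the result stays deterministic.
      llvm::stable_sort(Repeats, [](const RS &L, const RS &R) {
        return L.Length > R.Length;
      });
    }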
+ std::vector<SuffixTree::RepeatedSubstring> RSes; + for (SuffixTree::RepeatedSubstring &RS : ST) + RSes.push_back(RS); + + llvm::stable_sort(RSes, [](const SuffixTree::RepeatedSubstring &LHS, + const SuffixTree::RepeatedSubstring &RHS) { + return LHS.Length > RHS.Length; + }); + for (SuffixTree::RepeatedSubstring &RS : RSes) { createCandidatesFromSuffixTree(Mapper, InstrList, IntegerMapping, RS, CandsForRepSubstring); if (CandsForRepSubstring.size() < 2) continue; - findCandidateStructures(CandsForRepSubstring, StructuralGroups); - for (std::pair<unsigned, SimilarityGroup> &Group : StructuralGroups) + findCandidateStructures(CandsForRepSubstring, StructuralGroups, + IndexToIncludedCand, CandToGroup); + for (std::pair<unsigned, SimilarityGroup> &Group : StructuralGroups) { // We only add the group if it contains more than one // IRSimilarityCandidate. If there is only one, that means there is no // other repeated subsequence with the same structure. - if (Group.second.size() > 1) + if (Group.second.size() > 1) { SimilarityCandidates->push_back(Group.second); + // Iterate over each candidate in the group, and add an entry for each + // instruction included with a mapping to a set of + // IRSimilarityCandidates that include that instruction. + for (IRSimilarityCandidate &IRCand : SimilarityCandidates->back()) { + for (unsigned Idx = IRCand.getStartIdx(), Edx = IRCand.getEndIdx(); + Idx <= Edx; ++Idx) { + DenseMap<unsigned, DenseSet<IRSimilarityCandidate *>>::iterator + IdIt; + IdIt = IndexToIncludedCand.find(Idx); + bool Inserted = false; + if (IdIt == IndexToIncludedCand.end()) + std::tie(IdIt, Inserted) = IndexToIncludedCand.insert( + std::make_pair(Idx, DenseSet<IRSimilarityCandidate *>())); + IdIt->second.insert(&IRCand); + } + // Add mapping of candidate to the overall similarity group number. + CandToGroup.insert( + std::make_pair(&IRCand, SimilarityCandidates->size() - 1)); + } + } + } CandsForRepSubstring.clear(); StructuralGroups.clear(); diff --git a/contrib/llvm-project/llvm/lib/Analysis/IVDescriptors.cpp b/contrib/llvm-project/llvm/lib/Analysis/IVDescriptors.cpp index 950541ace9d7..6c750b7baa40 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/IVDescriptors.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/IVDescriptors.cpp @@ -107,7 +107,7 @@ static std::pair<Type *, bool> computeRecurrenceType(Instruction *Exit, // must be positive (i.e., IsSigned = false), because if this were not the // case, the sign bit would have been demanded. 
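A few lines below, the explicit isPowerOf2_64 / NextPowerOf2 dance is replaced by llvm::bit_ceil from llvm/ADT/bit.h (the LLVM counterpart of C++20's std::bit_ceil), and the APInt call countLeadingZeros becomes the renamed countl_zero. bit_ceil rounds up to the next power of two and leaves values that already are one unchanged, for example:

    #include "llvm/ADT/bit.h"

    static_assert(llvm::bit_ceil(12u) == 16u, "rounds up to the next power of two");
    static_assert(llvm::bit_ceil(8u) == 8u, "already a power of two: unchanged");
    static_assert(llvm::bit_ceil(1u) == 1u, "0 and 1 both map to 1");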
auto Mask = DB->getDemandedBits(Exit); - MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros(); + MaxBitWidth = Mask.getBitWidth() - Mask.countl_zero(); } if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) { @@ -128,8 +128,7 @@ static std::pair<Type *, bool> computeRecurrenceType(Instruction *Exit, ++MaxBitWidth; } } - if (!isPowerOf2_64(MaxBitWidth)) - MaxBitWidth = NextPowerOf2(MaxBitWidth); + MaxBitWidth = llvm::bit_ceil(MaxBitWidth); return std::make_pair(Type::getIntNTy(Exit->getContext(), MaxBitWidth), IsSigned); @@ -707,6 +706,10 @@ RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind, return InstDesc(Kind == RecurKind::FMin, I); if (match(I, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_Value()))) return InstDesc(Kind == RecurKind::FMax, I); + if (match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(), m_Value()))) + return InstDesc(Kind == RecurKind::FMinimum, I); + if (match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(), m_Value()))) + return InstDesc(Kind == RecurKind::FMaximum, I); return InstDesc(false, I); } @@ -746,15 +749,21 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) { return InstDesc(false, I); Value *Op1, *Op2; - if ((m_FAdd(m_Value(Op1), m_Value(Op2)).match(I1) || - m_FSub(m_Value(Op1), m_Value(Op2)).match(I1)) && - I1->isFast()) - return InstDesc(Kind == RecurKind::FAdd, SI); + if (!(((m_FAdd(m_Value(Op1), m_Value(Op2)).match(I1) || + m_FSub(m_Value(Op1), m_Value(Op2)).match(I1)) && + I1->isFast()) || + (m_FMul(m_Value(Op1), m_Value(Op2)).match(I1) && (I1->isFast())) || + ((m_Add(m_Value(Op1), m_Value(Op2)).match(I1) || + m_Sub(m_Value(Op1), m_Value(Op2)).match(I1))) || + (m_Mul(m_Value(Op1), m_Value(Op2)).match(I1)))) + return InstDesc(false, I); - if (m_FMul(m_Value(Op1), m_Value(Op2)).match(I1) && (I1->isFast())) - return InstDesc(Kind == RecurKind::FMul, SI); + Instruction *IPhi = isa<PHINode>(*Op1) ? dyn_cast<Instruction>(Op1) + : dyn_cast<Instruction>(Op2); + if (!IPhi || IPhi != FalseVal) + return InstDesc(false, I); - return InstDesc(false, I); + return InstDesc(true, SI); } RecurrenceDescriptor::InstDesc @@ -787,7 +796,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi, return InstDesc(Kind == RecurKind::FAdd, I, I->hasAllowReassoc() ? nullptr : I); case Instruction::Select: - if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul) + if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul || + Kind == RecurKind::Add || Kind == RecurKind::Mul) return isConditionalRdxPattern(Kind, I); [[fallthrough]]; case Instruction::FCmp: @@ -795,11 +805,18 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi, case Instruction::Call: if (isSelectCmpRecurrenceKind(Kind)) return isSelectCmpPattern(L, OrigPhi, I, Prev); + auto HasRequiredFMF = [&]() { + if (FuncFMF.noNaNs() && FuncFMF.noSignedZeros()) + return true; + if (isa<FPMathOperator>(I) && I->hasNoNaNs() && I->hasNoSignedZeros()) + return true; + // minimum and maximum intrinsics do not require nsz and nnan flags since + // NaN and signed zeroes are propagated in the intrinsic implementation. 
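  // For context, the semantic difference that makes the exemption above safe,
  // shown with the libm analogue of minnum:
  //   fmin(NaN, 1.0) == 1.0          : llvm.minnum / llvm.maxnum drop the NaN, so a
  //                                    min/max reduction over them needs nnan (and
  //                                    nsz for the signed-zero case).
  //   llvm.minimum(NaN, 1.0) is NaN  : llvm.minimum / llvm.maximum propagate NaN by
  //                                    definition and also order -0.0 before +0.0,
  //                                    so no extra fast-math flags are required.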
+ return match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(), m_Value())) || + match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(), m_Value())); + }; if (isIntMinMaxRecurrenceKind(Kind) || - (((FuncFMF.noNaNs() && FuncFMF.noSignedZeros()) || - (isa<FPMathOperator>(I) && I->hasNoNaNs() && - I->hasNoSignedZeros())) && - isFPMinMaxRecurrenceKind(Kind))) + (HasRequiredFMF() && isFPMinMaxRecurrenceKind(Kind))) return isMinMaxPattern(I, Kind, Prev); else if (isFMulAddIntrinsic(I)) return InstDesc(Kind == RecurKind::FMulAdd, I, @@ -917,13 +934,22 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, LLVM_DEBUG(dbgs() << "Found an FMulAdd reduction PHI." << *Phi << "\n"); return true; } + if (AddReductionVar(Phi, RecurKind::FMaximum, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { + LLVM_DEBUG(dbgs() << "Found a float MAXIMUM reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RecurKind::FMinimum, TheLoop, FMF, RedDes, DB, AC, DT, + SE)) { + LLVM_DEBUG(dbgs() << "Found a float MINIMUM reduction PHI." << *Phi << "\n"); + return true; + } // Not a reduction of known type. return false; } -bool RecurrenceDescriptor::isFixedOrderRecurrence( - PHINode *Phi, Loop *TheLoop, - MapVector<Instruction *, Instruction *> &SinkAfter, DominatorTree *DT) { +bool RecurrenceDescriptor::isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop, + DominatorTree *DT) { // Ensure the phi node is in the loop header and has two incoming values. if (Phi->getParent() != TheLoop->getHeader() || @@ -959,8 +985,7 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence( Previous = dyn_cast<Instruction>(PrevPhi->getIncomingValueForBlock(Latch)); } - if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous) || - SinkAfter.count(Previous)) // Cannot rely on dominance due to motion. + if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous)) return false; // Ensure every user of the phi node (recursively) is dominated by the @@ -969,27 +994,16 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence( // loop. // TODO: Consider extending this sinking to handle memory instructions. - // We optimistically assume we can sink all users after Previous. Keep a set - // of instructions to sink after Previous ordered by dominance in the common - // basic block. It will be applied to SinkAfter if all users can be sunk. - auto CompareByComesBefore = [](const Instruction *A, const Instruction *B) { - return A->comesBefore(B); - }; - std::set<Instruction *, decltype(CompareByComesBefore)> InstrsToSink( - CompareByComesBefore); - + SmallPtrSet<Value *, 8> Seen; BasicBlock *PhiBB = Phi->getParent(); SmallVector<Instruction *, 8> WorkList; auto TryToPushSinkCandidate = [&](Instruction *SinkCandidate) { - // Already sunk SinkCandidate. - if (SinkCandidate->getParent() == PhiBB && - InstrsToSink.find(SinkCandidate) != InstrsToSink.end()) - return true; - // Cyclic dependence. if (Previous == SinkCandidate) return false; + if (!Seen.insert(SinkCandidate).second) + return true; if (DT->dominates(Previous, SinkCandidate)) // We already are good w/o sinking. return true; @@ -999,55 +1013,12 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence( SinkCandidate->mayReadFromMemory() || SinkCandidate->isTerminator()) return false; - // Avoid sinking an instruction multiple times (if multiple operands are - // fixed order recurrences) by sinking once - after the latest 'previous' - // instruction. 
- auto It = SinkAfter.find(SinkCandidate); - if (It != SinkAfter.end()) { - auto *OtherPrev = It->second; - // Find the earliest entry in the 'sink-after' chain. The last entry in - // the chain is the original 'Previous' for a recurrence handled earlier. - auto EarlierIt = SinkAfter.find(OtherPrev); - while (EarlierIt != SinkAfter.end()) { - Instruction *EarlierInst = EarlierIt->second; - EarlierIt = SinkAfter.find(EarlierInst); - // Bail out if order has not been preserved. - if (EarlierIt != SinkAfter.end() && - !DT->dominates(EarlierInst, OtherPrev)) - return false; - OtherPrev = EarlierInst; - } - // Bail out if order has not been preserved. - if (OtherPrev != It->second && !DT->dominates(It->second, OtherPrev)) - return false; - - // SinkCandidate is already being sunk after an instruction after - // Previous. Nothing left to do. - if (DT->dominates(Previous, OtherPrev) || Previous == OtherPrev) - return true; - - // If there are other instructions to be sunk after SinkCandidate, remove - // and re-insert SinkCandidate can break those instructions. Bail out for - // simplicity. - if (any_of(SinkAfter, - [SinkCandidate](const std::pair<Instruction *, Instruction *> &P) { - return P.second == SinkCandidate; - })) - return false; - - // Otherwise, Previous comes after OtherPrev and SinkCandidate needs to be - // re-sunk to Previous, instead of sinking to OtherPrev. Remove - // SinkCandidate from SinkAfter to ensure it's insert position is updated. - SinkAfter.erase(SinkCandidate); - } - // If we reach a PHI node that is not dominated by Previous, we reached a // header PHI. No need for sinking. if (isa<PHINode>(SinkCandidate)) return true; // Sink User tentatively and check its users - InstrsToSink.insert(SinkCandidate); WorkList.push_back(SinkCandidate); return true; }; @@ -1062,11 +1033,6 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence( } } - // We can sink all users of Phi. Update the mapping. 
- for (Instruction *I : InstrsToSink) { - SinkAfter[I] = Previous; - Previous = I; - } return true; } @@ -1101,7 +1067,7 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp, return ConstantFP::get(Tp, 0.0L); return ConstantFP::get(Tp, -0.0L); case RecurKind::UMin: - return ConstantInt::get(Tp, -1); + return ConstantInt::get(Tp, -1, true); case RecurKind::UMax: return ConstantInt::get(Tp, 0); case RecurKind::SMin: @@ -1118,6 +1084,10 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp, assert((FMF.noNaNs() && FMF.noSignedZeros()) && "nnan, nsz is expected to be set for FP max reduction."); return ConstantFP::getInfinity(Tp, true /*Negative*/); + case RecurKind::FMinimum: + return ConstantFP::getInfinity(Tp, false /*Negative*/); + case RecurKind::FMaximum: + return ConstantFP::getInfinity(Tp, true /*Negative*/); case RecurKind::SelectICmp: case RecurKind::SelectFCmp: return getRecurrenceStartValue(); @@ -1152,6 +1122,8 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) { return Instruction::ICmp; case RecurKind::FMax: case RecurKind::FMin: + case RecurKind::FMaximum: + case RecurKind::FMinimum: case RecurKind::SelectFCmp: return Instruction::FCmp; default: @@ -1264,10 +1236,8 @@ RecurrenceDescriptor::getReductionOpChain(PHINode *Phi, Loop *L) const { InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, const SCEV *Step, BinaryOperator *BOp, - Type *ElementType, SmallVectorImpl<Instruction *> *Casts) - : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp), - ElementType(ElementType) { + : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) { assert(IK != IK_NoInduction && "Not an induction"); // Start value type should match the induction kind and the value @@ -1282,8 +1252,6 @@ InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, assert((!getConstIntStepValue() || !getConstIntStepValue()->isZero()) && "Step value is zero"); - assert((IK != IK_PtrInduction || getConstIntStepValue()) && - "Step value should be constant for pointer induction"); assert((IK == IK_FpInduction || Step->getType()->isIntegerTy()) && "StepValue is not an integer"); @@ -1295,11 +1263,6 @@ InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, InductionBinOp->getOpcode() == Instruction::FSub))) && "Binary opcode should be specified for FP induction"); - if (IK == IK_PtrInduction) - assert(ElementType && "Pointer induction must have element type"); - else - assert(!ElementType && "Non-pointer induction cannot have element type"); - if (Casts) { for (auto &Inst : *Casts) { RedundantCasts.push_back(Inst); @@ -1541,6 +1504,12 @@ bool InductionDescriptor::isInductionPHI( return false; } + // This function assumes that InductionPhi is called only on Phi nodes + // present inside loop headers. Check for the same, and throw an assert if + // the current Phi is not present inside the loop header. + assert(Phi->getParent() == AR->getLoop()->getHeader() + && "Invalid Phi node, not present in loop header"); + Value *StartValue = Phi->getIncomingValueForBlock(AR->getLoop()->getLoopPreheader()); @@ -1559,39 +1528,13 @@ bool InductionDescriptor::isInductionPHI( BinaryOperator *BOp = dyn_cast<BinaryOperator>(Phi->getIncomingValueForBlock(Latch)); D = InductionDescriptor(StartValue, IK_IntInduction, Step, BOp, - /* ElementType */ nullptr, CastsToIgnore); + CastsToIgnore); return true; } assert(PhiTy->isPointerTy() && "The PHI must be a pointer"); - // Pointer induction should be a constant. 
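In the getRecurrenceIdentity hunk above, each identity is the value that leaves every other input unchanged under its reduction: UMin uses the all-ones value (written as ConstantInt::get(Tp, -1, true), i.e. -1 sign-extended to the full bit width), FMinimum uses +infinity and FMaximum uses -infinity. Unlike FMin/FMax, the minimum/maximum based kinds assert no fast-math flags because the intrinsics already propagate NaN and order signed zeros themselves. A small sanity check of the integer cases:

    #include <algorithm>
    #include <cstdint>
    #include <limits>

    // Folding the identity into a running reduction never changes the result:
    static_assert(std::min<uint32_t>(42u, std::numeric_limits<uint32_t>::max()) == 42u,
                  "UMin identity is the all-ones value");
    static_assert(std::max<uint32_t>(42u, 0u) == 42u, "UMax identity is 0");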
- if (!ConstStep) - return false; - - // Always use i8 element type for opaque pointer inductions. - PointerType *PtrTy = cast<PointerType>(PhiTy); - Type *ElementType = PtrTy->isOpaque() - ? Type::getInt8Ty(PtrTy->getContext()) - : PtrTy->getNonOpaquePointerElementType(); - if (!ElementType->isSized()) - return false; - - ConstantInt *CV = ConstStep->getValue(); - const DataLayout &DL = Phi->getModule()->getDataLayout(); - TypeSize TySize = DL.getTypeAllocSize(ElementType); - // TODO: We could potentially support this for scalable vectors if we can - // prove at compile time that the constant step is always a multiple of - // the scalable type. - if (TySize.isZero() || TySize.isScalable()) - return false; - int64_t Size = static_cast<int64_t>(TySize.getFixedValue()); - int64_t CVSize = CV->getSExtValue(); - if (CVSize % Size) - return false; - auto *StepValue = - SE->getConstant(CV->getType(), CVSize / Size, true /* signed */); - D = InductionDescriptor(StartValue, IK_PtrInduction, StepValue, - /* BinOp */ nullptr, ElementType); + // This allows induction variables w/non-constant steps. + D = InductionDescriptor(StartValue, IK_PtrInduction, Step); return true; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/IVUsers.cpp b/contrib/llvm-project/llvm/lib/Analysis/IVUsers.cpp index 830211658353..5c7883fb3b37 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/IVUsers.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/IVUsers.cpp @@ -334,8 +334,8 @@ const SCEV *IVUsers::getReplacementExpr(const IVStrideUse &IU) const { /// getExpr - Return the expression for the use. const SCEV *IVUsers::getExpr(const IVStrideUse &IU) const { - return normalizeForPostIncUse(getReplacementExpr(IU), IU.getPostIncLoops(), - *SE); + const SCEV *Replacement = getReplacementExpr(IU); + return normalizeForPostIncUse(Replacement, IU.getPostIncLoops(), *SE); } static const SCEVAddRecExpr *findAddRecForLoop(const SCEV *S, const Loop *L) { @@ -356,7 +356,10 @@ static const SCEVAddRecExpr *findAddRecForLoop(const SCEV *S, const Loop *L) { } const SCEV *IVUsers::getStride(const IVStrideUse &IU, const Loop *L) const { - if (const SCEVAddRecExpr *AR = findAddRecForLoop(getExpr(IU), L)) + const SCEV *Expr = getExpr(IU); + if (!Expr) + return nullptr; + if (const SCEVAddRecExpr *AR = findAddRecForLoop(Expr, L)) return AR->getStepRecurrence(*SE); return nullptr; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/InlineAdvisor.cpp b/contrib/llvm-project/llvm/lib/Analysis/InlineAdvisor.cpp index 540aad7ee0c0..e2480d51d372 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/InlineAdvisor.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/InlineAdvisor.cpp @@ -13,6 +13,7 @@ #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" @@ -208,6 +209,10 @@ bool InlineAdvisorAnalysis::Result::tryCreate( Advisor.reset(DA.Factory(M, FAM, Params, IC)); return !!Advisor; } + auto GetDefaultAdvice = [&FAM, Params](CallBase &CB) { + auto OIC = getDefaultInlineAdvice(CB, FAM, Params); + return OIC.has_value(); + }; switch (Mode) { case InliningAdvisorMode::Default: LLVM_DEBUG(dbgs() << "Using default inliner heuristic.\n"); @@ -223,18 +228,12 @@ bool InlineAdvisorAnalysis::Result::tryCreate( case InliningAdvisorMode::Development: #ifdef LLVM_HAVE_TFLITE LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n"); - Advisor = - 
llvm::getDevelopmentModeAdvisor(M, MAM, [&FAM, Params](CallBase &CB) { - auto OIC = getDefaultInlineAdvice(CB, FAM, Params); - return OIC.has_value(); - }); + Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice); #endif break; case InliningAdvisorMode::Release: -#ifdef LLVM_HAVE_TF_AOT LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n"); - Advisor = llvm::getReleaseModeAdvisor(M, MAM); -#endif + Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice); break; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/InlineCost.cpp b/contrib/llvm-project/llvm/lib/Analysis/InlineCost.cpp index 5bcc8a2f384a..9ff277f5334e 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/InlineCost.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/InlineCost.cpp @@ -142,11 +142,11 @@ static cl::opt<size_t> cl::desc("Do not inline functions with a stack size " "that exceeds the specified limit")); -static cl::opt<size_t> - RecurStackSizeThreshold("recursive-inline-max-stacksize", cl::Hidden, - cl::init(InlineConstants::TotalAllocaSizeRecursiveCaller), - cl::desc("Do not inline recursive functions with a stack " - "size that exceeds the specified limit")); +static cl::opt<size_t> RecurStackSizeThreshold( + "recursive-inline-max-stacksize", cl::Hidden, + cl::init(InlineConstants::TotalAllocaSizeRecursiveCaller), + cl::desc("Do not inline recursive functions with a stack " + "size that exceeds the specified limit")); static cl::opt<bool> OptComputeFullInlineCost( "inline-cost-full", cl::Hidden, @@ -493,7 +493,7 @@ public: InlineResult analyze(); std::optional<Constant *> getSimplifiedValue(Instruction *I) { - if (SimplifiedValues.find(I) != SimplifiedValues.end()) + if (SimplifiedValues.contains(I)) return SimplifiedValues[I]; return std::nullopt; } @@ -717,7 +717,9 @@ class InlineCostCallAnalyzer final : public CallAnalyzer { void onInitializeSROAArg(AllocaInst *Arg) override { assert(Arg != nullptr && "Should not initialize SROA costs for null value."); - SROAArgCosts[Arg] = 0; + auto SROAArgCost = TTI.getCallerAllocaCost(&CandidateCall, Arg); + SROACostSavings += SROAArgCost; + SROAArgCosts[Arg] = SROAArgCost; } void onAggregateSROAUse(AllocaInst *SROAArg) override { @@ -1054,7 +1056,7 @@ public: void print(raw_ostream &OS); std::optional<InstructionCostDetail> getCostDetails(const Instruction *I) { - if (InstructionCostDetailMap.find(I) != InstructionCostDetailMap.end()) + if (InstructionCostDetailMap.contains(I)) return InstructionCostDetailMap[I]; return std::nullopt; } @@ -1108,31 +1110,31 @@ private: if (CostIt == SROACosts.end()) return; - increment(InlineCostFeatureIndex::SROALosses, CostIt->second); + increment(InlineCostFeatureIndex::sroa_losses, CostIt->second); SROACostSavingOpportunities -= CostIt->second; SROACosts.erase(CostIt); } void onDisableLoadElimination() override { - set(InlineCostFeatureIndex::LoadElimination, 1); + set(InlineCostFeatureIndex::load_elimination, 1); } void onCallPenalty() override { - increment(InlineCostFeatureIndex::CallPenalty, CallPenalty); + increment(InlineCostFeatureIndex::call_penalty, CallPenalty); } void onCallArgumentSetup(const CallBase &Call) override { - increment(InlineCostFeatureIndex::CallArgumentSetup, + increment(InlineCostFeatureIndex::call_argument_setup, Call.arg_size() * InstrCost); } void onLoadRelativeIntrinsic() override { - increment(InlineCostFeatureIndex::LoadRelativeIntrinsic, 3 * InstrCost); + increment(InlineCostFeatureIndex::load_relative_intrinsic, 3 * InstrCost); } void onLoweredCall(Function *F, CallBase 
&Call, bool IsIndirectCall) override { - increment(InlineCostFeatureIndex::LoweredCallArgSetup, + increment(InlineCostFeatureIndex::lowered_call_arg_setup, Call.arg_size() * InstrCost); if (IsIndirectCall) { @@ -1153,9 +1155,9 @@ private: GetAssumptionCache, GetBFI, PSI, ORE, false, true); if (CA.analyze().isSuccess()) { - increment(InlineCostFeatureIndex::NestedInlineCostEstimate, + increment(InlineCostFeatureIndex::nested_inline_cost_estimate, CA.getCost()); - increment(InlineCostFeatureIndex::NestedInlines, 1); + increment(InlineCostFeatureIndex::nested_inlines, 1); } } else { onCallPenalty(); @@ -1168,12 +1170,12 @@ private: if (JumpTableSize) { int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost + JTCostMultiplier * InstrCost; - increment(InlineCostFeatureIndex::JumpTablePenalty, JTCost); + increment(InlineCostFeatureIndex::jump_table_penalty, JTCost); return; } if (NumCaseCluster <= 3) { - increment(InlineCostFeatureIndex::CaseClusterPenalty, + increment(InlineCostFeatureIndex::case_cluster_penalty, NumCaseCluster * CaseClusterCostMultiplier * InstrCost); return; } @@ -1183,15 +1185,20 @@ private: int64_t SwitchCost = ExpectedNumberOfCompare * SwitchCostMultiplier * InstrCost; - increment(InlineCostFeatureIndex::SwitchPenalty, SwitchCost); + increment(InlineCostFeatureIndex::switch_penalty, SwitchCost); } void onMissedSimplification() override { - increment(InlineCostFeatureIndex::UnsimplifiedCommonInstructions, + increment(InlineCostFeatureIndex::unsimplified_common_instructions, InstrCost); } - void onInitializeSROAArg(AllocaInst *Arg) override { SROACosts[Arg] = 0; } + void onInitializeSROAArg(AllocaInst *Arg) override { + auto SROAArgCost = TTI.getCallerAllocaCost(&CandidateCall, Arg); + SROACosts[Arg] = SROAArgCost; + SROACostSavingOpportunities += SROAArgCost; + } + void onAggregateSROAUse(AllocaInst *Arg) override { SROACosts.find(Arg)->second += InstrCost; SROACostSavingOpportunities += InstrCost; @@ -1199,7 +1206,7 @@ private: void onBlockAnalyzed(const BasicBlock *BB) override { if (BB->getTerminator()->getNumSuccessors() > 1) - set(InlineCostFeatureIndex::IsMultipleBlocks, 1); + set(InlineCostFeatureIndex::is_multiple_blocks, 1); Threshold -= SingleBBBonus; } @@ -1212,24 +1219,24 @@ private: // Ignore loops that will not be executed if (DeadBlocks.count(L->getHeader())) continue; - increment(InlineCostFeatureIndex::NumLoops, + increment(InlineCostFeatureIndex::num_loops, InlineConstants::LoopPenalty); } } - set(InlineCostFeatureIndex::DeadBlocks, DeadBlocks.size()); - set(InlineCostFeatureIndex::SimplifiedInstructions, + set(InlineCostFeatureIndex::dead_blocks, DeadBlocks.size()); + set(InlineCostFeatureIndex::simplified_instructions, NumInstructionsSimplified); - set(InlineCostFeatureIndex::ConstantArgs, NumConstantArgs); - set(InlineCostFeatureIndex::ConstantOffsetPtrArgs, + set(InlineCostFeatureIndex::constant_args, NumConstantArgs); + set(InlineCostFeatureIndex::constant_offset_ptr_args, NumConstantOffsetPtrArgs); - set(InlineCostFeatureIndex::SROASavings, SROACostSavingOpportunities); + set(InlineCostFeatureIndex::sroa_savings, SROACostSavingOpportunities); if (NumVectorInstructions <= NumInstructions / 10) Threshold -= VectorBonus; else if (NumVectorInstructions <= NumInstructions / 2) Threshold -= VectorBonus / 2; - set(InlineCostFeatureIndex::Threshold, Threshold); + set(InlineCostFeatureIndex::threshold, Threshold); return InlineResult::success(); } @@ -1237,17 +1244,17 @@ private: bool shouldStop() override { return false; } void 
onLoadEliminationOpportunity() override { - increment(InlineCostFeatureIndex::LoadElimination, 1); + increment(InlineCostFeatureIndex::load_elimination, 1); } InlineResult onAnalysisStart() override { - increment(InlineCostFeatureIndex::CallSiteCost, + increment(InlineCostFeatureIndex::callsite_cost, -1 * getCallsiteCost(this->CandidateCall, DL)); - set(InlineCostFeatureIndex::ColdCcPenalty, + set(InlineCostFeatureIndex::cold_cc_penalty, (F.getCallingConv() == CallingConv::Cold)); - set(InlineCostFeatureIndex::LastCallToStaticBonus, + set(InlineCostFeatureIndex::last_call_to_static_bonus, isSoleCallToLocalFunction(CandidateCall, F)); // FIXME: we shouldn't repeat this logic in both the Features and Cost @@ -1607,7 +1614,7 @@ bool CallAnalyzer::simplifyIntrinsicCallIsConstant(CallBase &CB) { bool CallAnalyzer::simplifyIntrinsicCallObjectSize(CallBase &CB) { // As per the langref, "The fourth argument to llvm.objectsize determines if // the value should be evaluated at runtime." - if(cast<ConstantInt>(CB.getArgOperand(3))->isOne()) + if (cast<ConstantInt>(CB.getArgOperand(3))->isOne()) return false; Value *V = lowerObjectSizeCall(&cast<IntrinsicInst>(CB), DL, nullptr, @@ -1976,14 +1983,27 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) { } } + auto isImplicitNullCheckCmp = [](const CmpInst &I) { + for (auto *User : I.users()) + if (auto *Instr = dyn_cast<Instruction>(User)) + if (!Instr->getMetadata(LLVMContext::MD_make_implicit)) + return false; + return true; + }; + // If the comparison is an equality comparison with null, we can simplify it // if we know the value (argument) can't be null - if (I.isEquality() && isa<ConstantPointerNull>(I.getOperand(1)) && - isKnownNonNullInCallee(I.getOperand(0))) { - bool IsNotEqual = I.getPredicate() == CmpInst::ICMP_NE; - SimplifiedValues[&I] = IsNotEqual ? ConstantInt::getTrue(I.getType()) - : ConstantInt::getFalse(I.getType()); - return true; + if (I.isEquality() && isa<ConstantPointerNull>(I.getOperand(1))) { + if (isKnownNonNullInCallee(I.getOperand(0))) { + bool IsNotEqual = I.getPredicate() == CmpInst::ICMP_NE; + SimplifiedValues[&I] = IsNotEqual ? ConstantInt::getTrue(I.getType()) + : ConstantInt::getFalse(I.getType()); + return true; + } + // Implicit null checks act as unconditional branches and their comparisons + // should be treated as simplified and free of cost. + if (isImplicitNullCheckCmp(I)) + return true; } return handleSROA(I.getOperand(0), isa<ConstantPointerNull>(I.getOperand(1))); } @@ -2265,6 +2285,7 @@ bool CallAnalyzer::visitBranchInst(BranchInst &BI) { // inliner more regular and predictable. Interestingly, conditional branches // which will fold away are also free. return BI.isUnconditional() || isa<ConstantInt>(BI.getCondition()) || + BI.getMetadata(LLVMContext::MD_make_implicit) || isa_and_nonnull<ConstantInt>( SimplifiedValues.lookup(BI.getCondition())); } @@ -2314,10 +2335,10 @@ bool CallAnalyzer::visitSelectInst(SelectInst &SI) { : nullptr; if (!SelectedV) { // Condition is a vector constant that is not all 1s or all 0s. If all - // operands are constants, ConstantExpr::getSelect() can handle the cases - // such as select vectors. + // operands are constants, ConstantFoldSelectInstruction() can handle the + // cases such as select vectors. 
if (TrueC && FalseC) { - if (auto *C = ConstantExpr::getSelect(CondC, TrueC, FalseC)) { + if (auto *C = ConstantFoldSelectInstruction(CondC, TrueC, FalseC)) { SimplifiedValues[&SI] = C; return true; } @@ -2666,9 +2687,7 @@ InlineResult CallAnalyzer::analyze() { // basic blocks in a breadth-first order as we insert live successors. To // accomplish this, prioritizing for small iterations because we exit after // crossing our threshold, we use a small-size optimized SetVector. - typedef SetVector<BasicBlock *, SmallVector<BasicBlock *, 16>, - SmallPtrSet<BasicBlock *, 16>> - BBSetVector; + typedef SmallSetVector<BasicBlock *, 16> BBSetVector; BBSetVector BBWorklist; BBWorklist.insert(&F.getEntryBlock()); @@ -2787,16 +2806,14 @@ LLVM_DUMP_METHOD void InlineCostCallAnalyzer::dump() { print(dbgs()); } /// Test that there are no attribute conflicts between Caller and Callee /// that prevent inlining. static bool functionsHaveCompatibleAttributes( - Function *Caller, Function *Callee, TargetTransformInfo &TTI, + Function *Caller, Function *Callee, function_ref<const TargetLibraryInfo &(Function &)> &GetTLI) { // Note that CalleeTLI must be a copy not a reference. The legacy pass manager // caches the most recently created TLI in the TargetLibraryInfoWrapperPass // object, and always returns the same object (which is overwritten on each // GetTLI call). Therefore we copy the first result. auto CalleeTLI = GetTLI(*Callee); - return (IgnoreTTIInlineCompatible || - TTI.areInlineCompatible(Caller, Callee)) && - GetTLI(*Caller).areInlineCompatible(CalleeTLI, + return GetTLI(*Caller).areInlineCompatible(CalleeTLI, InlineCallerSupersetNoBuiltin) && AttributeFuncs::areInlineCompatible(*Caller, *Callee); } @@ -2912,6 +2929,12 @@ std::optional<InlineResult> llvm::getAttributeBasedInliningDecision( " address space"); } + // Never inline functions with conflicting target attributes. + Function *Caller = Call.getCaller(); + if (!IgnoreTTIInlineCompatible && + !CalleeTTI.areInlineCompatible(Caller, Callee)) + return InlineResult::failure("conflicting target attributes"); + // Calls to functions with always-inline attributes should be inlined // whenever possible. if (Call.hasFnAttr(Attribute::AlwaysInline)) { @@ -2926,8 +2949,12 @@ std::optional<InlineResult> llvm::getAttributeBasedInliningDecision( // Never inline functions with conflicting attributes (unless callee has // always-inline attribute). - Function *Caller = Call.getCaller(); - if (!functionsHaveCompatibleAttributes(Caller, Callee, CalleeTTI, GetTLI)) + // FIXME: functionsHaveCompatibleAttributes below checks for compatibilities + // of different kinds of function attributes -- sanitizer-related ones, + // checkDenormMode, no-builtin-memcpy, etc. It's unclear if we really want + // the always-inline attribute to take precedence over these different types + // of function attributes. + if (!functionsHaveCompatibleAttributes(Caller, Callee, GetTLI)) return InlineResult::failure("conflicting attributes"); // Don't inline this call if the caller has the optnone attribute. 
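Earlier in this hunk, the spelled-out SetVector instantiation for the block worklist is replaced by the SmallSetVector convenience alias; the behavior is unchanged, it is still a deduplicating worklist that iterates in insertion order and keeps a small amount of inline storage. A minimal usage sketch:

    #include "llvm/ADT/SetVector.h"

    void worklistExample() {
      llvm::SmallSetVector<int, 16> Worklist;
      Worklist.insert(1);
      Worklist.insert(2);
      Worklist.insert(1); // duplicate: ignored, size stays 2
      for (int V : Worklist)
        (void)V;          // visits 1 then 2, in insertion order
    }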
diff --git a/contrib/llvm-project/llvm/lib/Analysis/InlineOrder.cpp b/contrib/llvm-project/llvm/lib/Analysis/InlineOrder.cpp index 8d0e49936901..3b85820d7b8f 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/InlineOrder.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/InlineOrder.cpp @@ -33,8 +33,7 @@ static cl::opt<InlinePriorityMode> UseInlinePriority( "Use inline cost priority."), clEnumValN(InlinePriorityMode::CostBenefit, "cost-benefit", "Use cost-benefit ratio."), - clEnumValN(InlinePriorityMode::ML, "ml", - "Use ML."))); + clEnumValN(InlinePriorityMode::ML, "ml", "Use ML."))); static cl::opt<int> ModuleInlinerTopPriorityThreshold( "moudle-inliner-top-priority-threshold", cl::Hidden, cl::init(0), @@ -281,8 +280,13 @@ private: } // namespace +AnalysisKey llvm::PluginInlineOrderAnalysis::Key; +bool llvm::PluginInlineOrderAnalysis::HasBeenRegistered; + std::unique_ptr<InlineOrder<std::pair<CallBase *, int>>> -llvm::getInlineOrder(FunctionAnalysisManager &FAM, const InlineParams &Params) { +llvm::getDefaultInlineOrder(FunctionAnalysisManager &FAM, + const InlineParams &Params, + ModuleAnalysisManager &MAM, Module &M) { switch (UseInlinePriority) { case InlinePriorityMode::Size: LLVM_DEBUG(dbgs() << " Current used priority: Size priority ---- \n"); @@ -295,11 +299,22 @@ llvm::getInlineOrder(FunctionAnalysisManager &FAM, const InlineParams &Params) { case InlinePriorityMode::CostBenefit: LLVM_DEBUG( dbgs() << " Current used priority: cost-benefit priority ---- \n"); - return std::make_unique<PriorityInlineOrder<CostBenefitPriority>>(FAM, Params); + return std::make_unique<PriorityInlineOrder<CostBenefitPriority>>(FAM, + Params); case InlinePriorityMode::ML: - LLVM_DEBUG( - dbgs() << " Current used priority: ML priority ---- \n"); + LLVM_DEBUG(dbgs() << " Current used priority: ML priority ---- \n"); return std::make_unique<PriorityInlineOrder<MLPriority>>(FAM, Params); } return nullptr; } + +std::unique_ptr<InlineOrder<std::pair<CallBase *, int>>> +llvm::getInlineOrder(FunctionAnalysisManager &FAM, const InlineParams &Params, + ModuleAnalysisManager &MAM, Module &M) { + if (llvm::PluginInlineOrderAnalysis::isRegistered()) { + LLVM_DEBUG(dbgs() << " Current used priority: plugin ---- \n"); + return MAM.getResult<PluginInlineOrderAnalysis>(M).Factory(FAM, Params, MAM, + M); + } + return getDefaultInlineOrder(FAM, Params, MAM, M); +}
\ No newline at end of file diff --git a/contrib/llvm-project/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp b/contrib/llvm-project/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp index 78e7f456ebc6..fba5859b74ce 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/InstructionPrecedenceTracking.cpp @@ -47,9 +47,9 @@ const Instruction *InstructionPrecedenceTracking::getFirstSpecialInstruction( validate(BB); #endif - if (FirstSpecialInsts.find(BB) == FirstSpecialInsts.end()) { + if (!FirstSpecialInsts.contains(BB)) { fill(BB); - assert(FirstSpecialInsts.find(BB) != FirstSpecialInsts.end() && "Must be!"); + assert(FirstSpecialInsts.contains(BB) && "Must be!"); } return FirstSpecialInsts[BB]; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp b/contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp index c83eb96bbc69..0bfea6140ab5 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp @@ -74,6 +74,10 @@ static Value *simplifyGEPInst(Type *, Value *, ArrayRef<Value *>, bool, const SimplifyQuery &, unsigned); static Value *simplifySelectInst(Value *, Value *, Value *, const SimplifyQuery &, unsigned); +static Value *simplifyInstructionWithOperands(Instruction *I, + ArrayRef<Value *> NewOps, + const SimplifyQuery &SQ, + unsigned MaxRecurse); static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal, Value *FalseVal) { @@ -214,12 +218,6 @@ static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { // Arguments and constants dominate all instructions. return true; - // If we are processing instructions (and/or basic blocks) that have not been - // fully added to a function, the parent nodes may still be null. Simply - // return the conservative answer in these cases. - if (!I->getParent() || !P->getParent() || !I->getFunction()) - return false; - // If we have a DominatorTree then do a precise test. if (DT) return DT->dominates(I, P); @@ -539,12 +537,16 @@ static Value *threadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, // Evaluate the BinOp on the incoming phi values. Value *CommonValue = nullptr; - for (Value *Incoming : PI->incoming_values()) { + for (Use &Incoming : PI->incoming_values()) { // If the incoming value is the phi node itself, it can safely be skipped. if (Incoming == PI) continue; - Value *V = PI == LHS ? simplifyBinOp(Opcode, Incoming, RHS, Q, MaxRecurse) - : simplifyBinOp(Opcode, LHS, Incoming, Q, MaxRecurse); + Instruction *InTI = PI->getIncomingBlock(Incoming)->getTerminator(); + Value *V = PI == LHS + ? simplifyBinOp(Opcode, Incoming, RHS, + Q.getWithInstruction(InTI), MaxRecurse) + : simplifyBinOp(Opcode, LHS, Incoming, + Q.getWithInstruction(InTI), MaxRecurse); // If the operation failed to simplify, or simplified to a different value // to previously, then give up. if (!V || (CommonValue && V != CommonValue)) @@ -992,6 +994,82 @@ Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, return ::simplifyMulInst(Op0, Op1, IsNSW, IsNUW, Q, RecursionLimit); } +/// Given a predicate and two operands, return true if the comparison is true. +/// This is a helper for div/rem simplification where we return some other value +/// when we can prove a relationship between the operands. 
+static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + Value *V = simplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse); + Constant *C = dyn_cast_or_null<Constant>(V); + return (C && C->isAllOnesValue()); +} + +/// Return true if we can simplify X / Y to 0. Remainder can adapt that answer +/// to simplify X % Y to X. +static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q, + unsigned MaxRecurse, bool IsSigned) { + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return false; + + if (IsSigned) { + // (X srem Y) sdiv Y --> 0 + if (match(X, m_SRem(m_Value(), m_Specific(Y)))) + return true; + + // |X| / |Y| --> 0 + // + // We require that 1 operand is a simple constant. That could be extended to + // 2 variables if we computed the sign bit for each. + // + // Make sure that a constant is not the minimum signed value because taking + // the abs() of that is undefined. + Type *Ty = X->getType(); + const APInt *C; + if (match(X, m_APInt(C)) && !C->isMinSignedValue()) { + // Is the variable divisor magnitude always greater than the constant + // dividend magnitude? + // |Y| > |C| --> Y < -abs(C) or Y > abs(C) + Constant *PosDividendC = ConstantInt::get(Ty, C->abs()); + Constant *NegDividendC = ConstantInt::get(Ty, -C->abs()); + if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) || + isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse)) + return true; + } + if (match(Y, m_APInt(C))) { + // Special-case: we can't take the abs() of a minimum signed value. If + // that's the divisor, then all we have to do is prove that the dividend + // is also not the minimum signed value. + if (C->isMinSignedValue()) + return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse); + + // Is the variable dividend magnitude always less than the constant + // divisor magnitude? + // |X| < |C| --> X > -abs(C) and X < abs(C) + Constant *PosDivisorC = ConstantInt::get(Ty, C->abs()); + Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs()); + if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) && + isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse)) + return true; + } + return false; + } + + // IsSigned == false. + + // Is the unsigned dividend known to be less than a constant divisor? + // TODO: Convert this (and above) to range analysis + // ("computeConstantRangeIncludingKnownBits")? + const APInt *C; + if (match(Y, m_APInt(C)) && + computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C)) + return true; + + // Try again for any divisor: + // Is the dividend unsigned less than the divisor? + return isICmpTrue(ICmpInst::ICMP_ULT, X, Y, Q, MaxRecurse); +} + /// Check for common or similar folds of integer division or integer remainder. /// This applies to all 4 opcodes (sdiv/udiv/srem/urem). static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0, @@ -1046,19 +1124,28 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0, if (Op0 == Op1) return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty); + + KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + // X / 0 -> poison + // X % 0 -> poison + // If the divisor is known to be zero, just return poison. This can happen in + // some cases where its provable indirectly the denominator is zero but it's + // not trivially simplifiable (i.e known zero through a phi node). 
+ if (Known.isZero()) + return PoisonValue::get(Ty); + // X / 1 -> X // X % 1 -> 0 - // If this is a boolean op (single-bit element type), we can't have - // division-by-zero or remainder-by-zero, so assume the divisor is 1. - // Similarly, if we're zero-extending a boolean divisor, then assume it's a 1. - Value *X; - if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1) || - (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) + // If the divisor can only be zero or one, we can't have division-by-zero + // or remainder-by-zero, so assume the divisor is 1. + // e.g. 1, zext (i8 X), sdiv X (Y and 1) + if (Known.countMinLeadingZeros() == Known.getBitWidth() - 1) return IsDiv ? Op0 : Constant::getNullValue(Ty); // If X * Y does not overflow, then: // X * Y / Y -> X // X * Y % Y -> 0 + Value *X; if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) { auto *Mul = cast<OverflowingBinaryOperator>(Op0); // The multiplication can't overflow if it is defined not to, or if @@ -1071,82 +1158,25 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0, } } + if (isDivZero(Op0, Op1, Q, MaxRecurse, IsSigned)) + return IsDiv ? Constant::getNullValue(Op0->getType()) : Op0; + if (Value *V = simplifyByDomEq(Opcode, Op0, Op1, Q, MaxRecurse)) return V; - return nullptr; -} - -/// Given a predicate and two operands, return true if the comparison is true. -/// This is a helper for div/rem simplification where we return some other value -/// when we can prove a relationship between the operands. -static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS, - const SimplifyQuery &Q, unsigned MaxRecurse) { - Value *V = simplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse); - Constant *C = dyn_cast_or_null<Constant>(V); - return (C && C->isAllOnesValue()); -} - -/// Return true if we can simplify X / Y to 0. Remainder can adapt that answer -/// to simplify X % Y to X. -static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q, - unsigned MaxRecurse, bool IsSigned) { - // Recursion is always used, so bail out at once if we already hit the limit. - if (!MaxRecurse--) - return false; - - if (IsSigned) { - // |X| / |Y| --> 0 - // - // We require that 1 operand is a simple constant. That could be extended to - // 2 variables if we computed the sign bit for each. - // - // Make sure that a constant is not the minimum signed value because taking - // the abs() of that is undefined. - Type *Ty = X->getType(); - const APInt *C; - if (match(X, m_APInt(C)) && !C->isMinSignedValue()) { - // Is the variable divisor magnitude always greater than the constant - // dividend magnitude? - // |Y| > |C| --> Y < -abs(C) or Y > abs(C) - Constant *PosDividendC = ConstantInt::get(Ty, C->abs()); - Constant *NegDividendC = ConstantInt::get(Ty, -C->abs()); - if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) || - isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse)) - return true; - } - if (match(Y, m_APInt(C))) { - // Special-case: we can't take the abs() of a minimum signed value. If - // that's the divisor, then all we have to do is prove that the dividend - // is also not the minimum signed value. - if (C->isMinSignedValue()) - return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse); - - // Is the variable dividend magnitude always less than the constant - // divisor magnitude? 
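The reorganized simplifyDivRem above now does the known-bits reasoning for both division and remainder in one place: a divisor that is provably zero yields poison, a divisor whose bits above the lowest one are all known zero behaves like a divisor of one, and isDivZero handles the case where the dividend is provably smaller in magnitude than the divisor. Loosely, at the C level (function names are purely illustrative):

    unsigned known_small_div(unsigned I) {
      unsigned X = I & 7; // known bits: X is in [0, 7]
      return X / 8;       // dividend < divisor, so the udiv folds to 0
    }
    unsigned known_small_rem(unsigned I) {
      unsigned X = I & 7;
      return X % 8;       // same reasoning, the urem folds to X
    }
    unsigned bool_like_divisor(unsigned X, unsigned Y) {
      return X / (Y & 1); // divisor is 0 or 1; dividing by 0 is poison, so this folds to X
    }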
- // |X| < |C| --> X > -abs(C) and X < abs(C) - Constant *PosDivisorC = ConstantInt::get(Ty, C->abs()); - Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs()); - if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) && - isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse)) - return true; - } - return false; - } - - // IsSigned == false. + // If the operation is with the result of a select instruction, check whether + // operating on either branch of the select always yields the same value. + if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) + if (Value *V = threadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) + return V; - // Is the unsigned dividend known to be less than a constant divisor? - // TODO: Convert this (and above) to range analysis - // ("computeConstantRangeIncludingKnownBits")? - const APInt *C; - if (match(Y, m_APInt(C)) && - computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C)) - return true; + // If the operation is with the result of a phi instruction, check whether + // operating on all incoming values of the phi always yields the same value. + if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) + if (Value *V = threadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) + return V; - // Try again for any divisor: - // Is the dividend unsigned less than the divisor? - return isICmpTrue(ICmpInst::ICMP_ULT, X, Y, Q, MaxRecurse); + return nullptr; } /// These are simplifications common to SDiv and UDiv. @@ -1163,44 +1193,12 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, // at least as many trailing zeros as the divisor to divide evenly. If it has // less trailing zeros, then the result must be poison. const APInt *DivC; - if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countTrailingZeros()) { + if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) { KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (KnownOp0.countMaxTrailingZeros() < DivC->countTrailingZeros()) + if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero()) return PoisonValue::get(Op0->getType()); } - bool IsSigned = Opcode == Instruction::SDiv; - - // (X rem Y) / Y -> 0 - if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || - (!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1))))) - return Constant::getNullValue(Op0->getType()); - - // (X /u C1) /u C2 -> 0 if C1 * C2 overflow - ConstantInt *C1, *C2; - if (!IsSigned && match(Op0, m_UDiv(m_Value(), m_ConstantInt(C1))) && - match(Op1, m_ConstantInt(C2))) { - bool Overflow; - (void)C1->getValue().umul_ov(C2->getValue(), Overflow); - if (Overflow) - return Constant::getNullValue(Op0->getType()); - } - - // If the operation is with the result of a select instruction, check whether - // operating on either branch of the select always yields the same value. - if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) - if (Value *V = threadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) - return V; - - // If the operation is with the result of a phi instruction, check whether - // operating on all incoming values of the phi always yields the same value. 
- if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) - if (Value *V = threadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) - return V; - - if (isDivZero(Op0, Op1, Q, MaxRecurse, IsSigned)) - return Constant::getNullValue(Op0->getType()); - return nullptr; } @@ -1213,13 +1211,6 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q, MaxRecurse)) return V; - // (X % Y) % Y -> X % Y - if ((Opcode == Instruction::SRem && - match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || - (Opcode == Instruction::URem && - match(Op0, m_URem(m_Value(), m_Specific(Op1))))) - return Op0; - // (X << Y) % X -> 0 if (Q.IIQ.UseInstrInfo && ((Opcode == Instruction::SRem && @@ -1228,22 +1219,6 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, match(Op0, m_NUWShl(m_Specific(Op1), m_Value()))))) return Constant::getNullValue(Op0->getType()); - // If the operation is with the result of a select instruction, check whether - // operating on either branch of the select always yields the same value. - if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) - if (Value *V = threadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) - return V; - - // If the operation is with the result of a phi instruction, check whether - // operating on all incoming values of the phi always yields the same value. - if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) - if (Value *V = threadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) - return V; - - // If X / Y == 0, then X % Y == X. - if (isDivZero(Op0, Op1, Q, MaxRecurse, Opcode == Instruction::SRem)) - return Op0; - return nullptr; } @@ -1407,8 +1382,8 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0, return nullptr; } -/// Given operands for an Shl, LShr or AShr, see if we can -/// fold the result. If not, this returns null. +/// Given operands for an LShr or AShr, see if we can fold the result. If not, +/// this returns null. static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, bool IsExact, const SimplifyQuery &Q, unsigned MaxRecurse) { @@ -1445,10 +1420,11 @@ static Value *simplifyShlInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, simplifyShift(Instruction::Shl, Op0, Op1, IsNSW, Q, MaxRecurse)) return V; + Type *Ty = Op0->getType(); // undef << X -> 0 // undef << X -> undef if (if it's NSW/NUW) if (Q.isUndefValue(Op0)) - return IsNSW || IsNUW ? Op0 : Constant::getNullValue(Op0->getType()); + return IsNSW || IsNUW ? Op0 : Constant::getNullValue(Ty); // (X >> A) << A -> X Value *X; @@ -1462,6 +1438,13 @@ static Value *simplifyShlInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, // NOTE: could use computeKnownBits() / LazyValueInfo, // but the cost-benefit analysis suggests it isn't worth it. + // "nuw" guarantees that only zeros are shifted out, and "nsw" guarantees + // that the sign-bit does not change, so the only input that does not + // produce poison is 0, and "0 << (bitwidth-1) --> 0". 
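  // Spelled out for i32 with a shift amount of 31:
  //   nuw requires that no set bits are shifted out, which limits X to 0 or 1;
  //   nsw requires that the sign bit does not change, which rules out X == 1,
  //   because 1 << 31 turns a non-negative value into the minimum signed value;
  //   that leaves X == 0, and 0 << 31 == 0, so the fold below returns a constant
  //   zero (any other X would make the shift poison rather than change the result).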
+ if (IsNSW && IsNUW && + match(Op1, m_SpecificInt(Ty->getScalarSizeInBits() - 1))) + return Constant::getNullValue(Ty); + return nullptr; } @@ -1960,13 +1943,16 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1, return nullptr; } -static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, FCmpInst *LHS, +static Value *simplifyAndOrOfFCmps(const SimplifyQuery &Q, FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); if (LHS0->getType() != RHS0->getType()) return nullptr; + const DataLayout &DL = Q.DL; + const TargetLibraryInfo *TLI = Q.TLI; + FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); if ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) || (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && !IsAnd)) { @@ -1978,8 +1964,10 @@ static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, FCmpInst *LHS, // (fcmp uno NNAN, X) | (fcmp uno Y, X) --> fcmp uno Y, X // (fcmp uno X, NNAN) | (fcmp uno X, Y) --> fcmp uno X, Y // (fcmp uno X, NNAN) | (fcmp uno Y, X) --> fcmp uno Y, X - if ((isKnownNeverNaN(LHS0, TLI) && (LHS1 == RHS0 || LHS1 == RHS1)) || - (isKnownNeverNaN(LHS1, TLI) && (LHS0 == RHS0 || LHS0 == RHS1))) + if (((LHS1 == RHS0 || LHS1 == RHS1) && + isKnownNeverNaN(LHS0, DL, TLI, 0, Q.AC, Q.CxtI, Q.DT)) || + ((LHS0 == RHS0 || LHS0 == RHS1) && + isKnownNeverNaN(LHS1, DL, TLI, 0, Q.AC, Q.CxtI, Q.DT))) return RHS; // (fcmp ord X, Y) & (fcmp ord NNAN, X) --> fcmp ord X, Y @@ -1990,8 +1978,10 @@ static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, FCmpInst *LHS, // (fcmp uno Y, X) | (fcmp uno NNAN, X) --> fcmp uno Y, X // (fcmp uno X, Y) | (fcmp uno X, NNAN) --> fcmp uno X, Y // (fcmp uno Y, X) | (fcmp uno X, NNAN) --> fcmp uno Y, X - if ((isKnownNeverNaN(RHS0, TLI) && (RHS1 == LHS0 || RHS1 == LHS1)) || - (isKnownNeverNaN(RHS1, TLI) && (RHS0 == LHS0 || RHS0 == LHS1))) + if (((RHS1 == LHS0 || RHS1 == LHS1) && + isKnownNeverNaN(RHS0, DL, TLI, 0, Q.AC, Q.CxtI, Q.DT)) || + ((RHS0 == LHS0 || RHS0 == LHS1) && + isKnownNeverNaN(RHS1, DL, TLI, 0, Q.AC, Q.CxtI, Q.DT))) return LHS; } @@ -2019,7 +2009,7 @@ static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, Value *Op0, auto *FCmp0 = dyn_cast<FCmpInst>(Op0); auto *FCmp1 = dyn_cast<FCmpInst>(Op1); if (FCmp0 && FCmp1) - V = simplifyAndOrOfFCmps(Q.TLI, FCmp0, FCmp1, IsAnd); + V = simplifyAndOrOfFCmps(Q, FCmp0, FCmp1, IsAnd); if (!V) return nullptr; @@ -2642,7 +2632,7 @@ static bool isAllocDisjoint(const Value *V) { // that might be resolve lazily to symbols in another dynamically-loaded // library (and, thus, could be malloc'ed by the implementation). if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) - return AI->getParent() && AI->getFunction() && AI->isStaticAlloca(); + return AI->isStaticAlloca(); if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) return (GV->hasLocalLinkage() || GV->hasHiddenVisibility() || GV->hasProtectedVisibility() || GV->hasGlobalUnnamedAddr()) && @@ -2727,16 +2717,13 @@ static bool haveNonOverlappingStorage(const Value *V1, const Value *V2) { // this optimization. 
static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q) { + assert(LHS->getType() == RHS->getType() && "Must have same types"); const DataLayout &DL = Q.DL; const TargetLibraryInfo *TLI = Q.TLI; const DominatorTree *DT = Q.DT; const Instruction *CxtI = Q.CxtI; const InstrInfoQuery &IIQ = Q.IIQ; - // First, skip past any trivial no-ops. - LHS = LHS->stripPointerCasts(); - RHS = RHS->stripPointerCasts(); - // A non-null pointer is not equal to a null pointer. if (isa<ConstantPointerNull>(RHS) && ICmpInst::isEquality(Pred) && llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr, @@ -2775,8 +2762,10 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS, // Even if an non-inbounds GEP occurs along the path we can still optimize // equality comparisons concerning the result. bool AllowNonInbounds = ICmpInst::isEquality(Pred); - APInt LHSOffset = stripAndComputeConstantOffsets(DL, LHS, AllowNonInbounds); - APInt RHSOffset = stripAndComputeConstantOffsets(DL, RHS, AllowNonInbounds); + unsigned IndexSize = DL.getIndexTypeSizeInBits(LHS->getType()); + APInt LHSOffset(IndexSize, 0), RHSOffset(IndexSize, 0); + LHS = LHS->stripAndAccumulateConstantOffsets(DL, LHSOffset, AllowNonInbounds); + RHS = RHS->stripAndAccumulateConstantOffsets(DL, RHSOffset, AllowNonInbounds); // If LHS and RHS are related via constant offsets to the same base // value, we can replace it with an icmp which just compares the offsets. @@ -2804,11 +2793,11 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS, }(LHS); Opts.NullIsUnknownSize = F ? NullPointerIsDefined(F) : true; if (getObjectSize(LHS, LHSSize, DL, TLI, Opts) && - getObjectSize(RHS, RHSSize, DL, TLI, Opts) && - !LHSOffset.isNegative() && !RHSOffset.isNegative() && - LHSOffset.ult(LHSSize) && RHSOffset.ult(RHSSize)) { - return ConstantInt::get(getCompareTy(LHS), - !CmpInst::isTrueWhenEqual(Pred)); + getObjectSize(RHS, RHSSize, DL, TLI, Opts)) { + APInt Dist = LHSOffset - RHSOffset; + if (Dist.isNonNegative() ? Dist.ult(LHSSize) : (-Dist).ult(RHSSize)) + return ConstantInt::get(getCompareTy(LHS), + !CmpInst::isTrueWhenEqual(Pred)); } } @@ -2850,11 +2839,35 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS, else if (isAllocLikeFn(RHS, TLI) && llvm::isKnownNonZero(LHS, DL, 0, nullptr, CxtI, DT)) MI = RHS; - // FIXME: We should also fold the compare when the pointer escapes, but the - // compare dominates the pointer escape - if (MI && !PointerMayBeCaptured(MI, true, true)) - return ConstantInt::get(getCompareTy(LHS), - CmpInst::isFalseWhenEqual(Pred)); + if (MI) { + // FIXME: This is incorrect, see PR54002. While we can assume that the + // allocation is at an address that makes the comparison false, this + // requires that *all* comparisons to that address be false, which + // InstSimplify cannot guarantee. + struct CustomCaptureTracker : public CaptureTracker { + bool Captured = false; + void tooManyUses() override { Captured = true; } + bool captured(const Use *U) override { + if (auto *ICmp = dyn_cast<ICmpInst>(U->getUser())) { + // Comparison against value stored in global variable. Given the + // pointer does not escape, its value cannot be guessed and stored + // separately in a global variable. 
+ unsigned OtherIdx = 1 - U->getOperandNo(); + auto *LI = dyn_cast<LoadInst>(ICmp->getOperand(OtherIdx)); + if (LI && isa<GlobalVariable>(LI->getPointerOperand())) + return false; + } + + Captured = true; + return true; + } + }; + CustomCaptureTracker Tracker; + PointerMayBeCaptured(MI, &Tracker); + if (!Tracker.Captured) + return ConstantInt::get(getCompareTy(LHS), + CmpInst::isFalseWhenEqual(Pred)); + } } // Otherwise, fail. @@ -3394,8 +3407,26 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return ConstantInt::getTrue(getCompareTy(RHS)); } - if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && - LBO->getOperand(1) == RBO->getOperand(1)) { + if (!MaxRecurse || !LBO || !RBO || LBO->getOpcode() != RBO->getOpcode()) + return nullptr; + + if (LBO->getOperand(0) == RBO->getOperand(0)) { + switch (LBO->getOpcode()) { + default: + break; + case Instruction::Shl: + bool NUW = Q.IIQ.hasNoUnsignedWrap(LBO) && Q.IIQ.hasNoUnsignedWrap(RBO); + bool NSW = Q.IIQ.hasNoSignedWrap(LBO) && Q.IIQ.hasNoSignedWrap(RBO); + if (!NUW || (ICmpInst::isSigned(Pred) && !NSW) || + !isKnownNonZero(LBO->getOperand(0), Q.DL)) + break; + if (Value *V = simplifyICmpInst(Pred, LBO->getOperand(1), + RBO->getOperand(1), Q, MaxRecurse - 1)) + return V; + } + } + + if (LBO->getOperand(1) == RBO->getOperand(1)) { switch (LBO->getOpcode()) { default: break; @@ -3631,7 +3662,7 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q) { // Gracefully handle instructions that have not been inserted yet. - if (!Q.AC || !Q.CxtI || !Q.CxtI->getParent()) + if (!Q.AC || !Q.CxtI) return nullptr; for (Value *AssumeBaseOp : {LHS, RHS}) { @@ -3650,6 +3681,36 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate, return nullptr; } +static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred, + Value *LHS, Value *RHS) { + auto *II = dyn_cast<IntrinsicInst>(LHS); + if (!II) + return nullptr; + + switch (II->getIntrinsicID()) { + case Intrinsic::uadd_sat: + // uadd.sat(X, Y) uge X, uadd.sat(X, Y) uge Y + if (II->getArgOperand(0) == RHS || II->getArgOperand(1) == RHS) { + if (Pred == ICmpInst::ICMP_UGE) + return ConstantInt::getTrue(getCompareTy(II)); + if (Pred == ICmpInst::ICMP_ULT) + return ConstantInt::getFalse(getCompareTy(II)); + } + return nullptr; + case Intrinsic::usub_sat: + // usub.sat(X, Y) ule X + if (II->getArgOperand(0) == RHS) { + if (Pred == ICmpInst::ICMP_ULE) + return ConstantInt::getTrue(getCompareTy(II)); + if (Pred == ICmpInst::ICMP_UGT) + return ConstantInt::getFalse(getCompareTy(II)); + } + return nullptr; + default: + return nullptr; + } +} + /// Given operands for an ICmpInst, see if we can fold the result. /// If not, this returns null. static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, @@ -3764,22 +3825,27 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended // too. If not, then try to deduce the result of the comparison. - else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + else if (match(RHS, m_ImmConstant())) { + Constant *C = dyn_cast<Constant>(RHS); + assert(C != nullptr); + // Compute the constant that would happen if we truncated to SrcTy then // reextended to DstTy. 
- Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); + Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy); Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy); + Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C); - // If the re-extended constant didn't change then this is effectively - // also a case of comparing two zero-extended values. - if (RExt == CI && MaxRecurse) + // If the re-extended constant didn't change any of the elements then + // this is effectively also a case of comparing two zero-extended + // values. + if (AnyEq->isAllOnesValue() && MaxRecurse) if (Value *V = simplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), SrcOp, Trunc, Q, MaxRecurse - 1)) return V; // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit // there. Use this to work out the result of the comparison. - if (RExt != CI) { + if (AnyEq->isNullValue()) { switch (Pred) { default: llvm_unreachable("Unknown ICmp predicate!"); @@ -3787,26 +3853,23 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - return ConstantInt::getFalse(CI->getContext()); + return Constant::getNullValue(ITy); case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - return ConstantInt::getTrue(CI->getContext()); + return Constant::getAllOnesValue(ITy); // LHS is non-negative. If RHS is negative then LHS >s LHS. If RHS // is non-negative then LHS <s RHS. case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() - ? ConstantInt::getTrue(CI->getContext()) - : ConstantInt::getFalse(CI->getContext()); - + return ConstantExpr::getICmp(ICmpInst::ICMP_SLT, C, + Constant::getNullValue(C->getType())); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() - ? ConstantInt::getFalse(CI->getContext()) - : ConstantInt::getTrue(CI->getContext()); + return ConstantExpr::getICmp(ICmpInst::ICMP_SGE, C, + Constant::getNullValue(C->getType())); } } } @@ -3833,42 +3896,44 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended // too. If not, then try to deduce the result of the comparison. - else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + else if (match(RHS, m_ImmConstant())) { + Constant *C = dyn_cast<Constant>(RHS); + assert(C != nullptr); + // Compute the constant that would happen if we truncated to SrcTy then // reextended to DstTy. - Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); + Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy); Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy); + Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C); // If the re-extended constant didn't change then this is effectively // also a case of comparing two sign-extended values. - if (RExt == CI && MaxRecurse) + if (AnyEq->isAllOnesValue() && MaxRecurse) if (Value *V = simplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse - 1)) return V; // Otherwise the upper bits of LHS are all equal, while RHS has varying // bits there. Use this to work out the result of the comparison. 
- if (RExt != CI) { + if (AnyEq->isNullValue()) { switch (Pred) { default: llvm_unreachable("Unknown ICmp predicate!"); case ICmpInst::ICMP_EQ: - return ConstantInt::getFalse(CI->getContext()); + return Constant::getNullValue(ITy); case ICmpInst::ICMP_NE: - return ConstantInt::getTrue(CI->getContext()); + return Constant::getAllOnesValue(ITy); // If RHS is non-negative then LHS <s RHS. If RHS is negative then // LHS >s RHS. case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() - ? ConstantInt::getTrue(CI->getContext()) - : ConstantInt::getFalse(CI->getContext()); + return ConstantExpr::getICmp(ICmpInst::ICMP_SLT, C, + Constant::getNullValue(C->getType())); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() - ? ConstantInt::getFalse(CI->getContext()) - : ConstantInt::getTrue(CI->getContext()); + return ConstantExpr::getICmp(ICmpInst::ICMP_SGE, C, + Constant::getNullValue(C->getType())); // If LHS is non-negative then LHS <u RHS. If LHS is negative then // LHS >u RHS. @@ -3910,9 +3975,19 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse)) return V; + if (Value *V = simplifyICmpWithIntrinsicOnLHS(Pred, LHS, RHS)) + return V; + if (Value *V = simplifyICmpWithIntrinsicOnLHS( + ICmpInst::getSwappedPredicate(Pred), RHS, LHS)) + return V; + if (Value *V = simplifyICmpWithDominatingAssume(Pred, LHS, RHS, Q)) return V; + if (std::optional<bool> Res = + isImpliedByDomCondition(Pred, LHS, RHS, Q.CxtI, Q.DL)) + return ConstantInt::getBool(ITy, *Res); + // Simplify comparisons of related pointers using a powerful, recursive // GEP-walk when we have target data available.. if (LHS->getType()->isPointerTy()) @@ -3920,10 +3995,9 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return C; if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS)) if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS)) - if (Q.DL.getTypeSizeInBits(CLHS->getPointerOperandType()) == - Q.DL.getTypeSizeInBits(CLHS->getType()) && - Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) == - Q.DL.getTypeSizeInBits(CRHS->getType())) + if (CLHS->getPointerOperandType() == CRHS->getPointerOperandType() && + Q.DL.getTypeSizeInBits(CLHS->getPointerOperandType()) == + Q.DL.getTypeSizeInBits(CLHS->getType())) if (auto *C = computePointerICmp(Pred, CLHS->getPointerOperand(), CRHS->getPointerOperand(), Q)) return C; @@ -3976,7 +4050,8 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, // Fold (un)ordered comparison if we can determine there are no NaNs. if (Pred == FCmpInst::FCMP_UNO || Pred == FCmpInst::FCMP_ORD) if (FMF.noNaNs() || - (isKnownNeverNaN(LHS, Q.TLI) && isKnownNeverNaN(RHS, Q.TLI))) + (isKnownNeverNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT) && + isKnownNeverNaN(RHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT))) return ConstantInt::get(RetTy, Pred == FCmpInst::FCMP_ORD); // NaN is unordered; NaN is not ordered. 
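Aside: the simplifyICmpWithIntrinsicOnLHS helper added a few hunks above relies on two algebraic facts about the saturating intrinsics: an unsigned saturating add can never be below either of its operands, and an unsigned saturating subtract can never exceed its first operand. A minimal standalone sketch (plain C++ with illustrative names, not LLVM code) that checks both identities exhaustively for 8-bit values:

#include <cassert>
#include <cstdint>

// Mirrors the semantics of llvm.uadd.sat / llvm.usub.sat on i8 (illustrative).
static uint8_t uadd_sat(uint8_t x, uint8_t y) {
  unsigned s = unsigned(x) + unsigned(y);
  return s > 0xFF ? 0xFF : uint8_t(s);   // clamp on overflow
}
static uint8_t usub_sat(uint8_t x, uint8_t y) {
  return x > y ? uint8_t(x - y) : 0;     // clamp on underflow
}

int main() {
  for (unsigned x = 0; x <= 0xFF; ++x)
    for (unsigned y = 0; y <= 0xFF; ++y) {
      // uadd.sat(X, Y) uge X (and uge Y): icmp uge folds to true, ult to false.
      assert(uadd_sat(x, y) >= x && uadd_sat(x, y) >= y);
      // usub.sat(X, Y) ule X: icmp ule folds to true, ugt to false.
      assert(usub_sat(x, y) <= x);
    }
  return 0;
}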
@@ -4038,18 +4113,20 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // LHS == Inf - if (Pred == FCmpInst::FCMP_OEQ && isKnownNeverInfinity(LHS, Q.TLI)) + if (Pred == FCmpInst::FCMP_OEQ && + isKnownNeverInfinity(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) return getFalse(RetTy); // LHS != Inf - if (Pred == FCmpInst::FCMP_UNE && isKnownNeverInfinity(LHS, Q.TLI)) + if (Pred == FCmpInst::FCMP_UNE && + isKnownNeverInfinity(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(RetTy); // LHS == Inf || LHS == NaN - if (Pred == FCmpInst::FCMP_UEQ && isKnownNeverInfinity(LHS, Q.TLI) && - isKnownNeverNaN(LHS, Q.TLI)) + if (Pred == FCmpInst::FCMP_UEQ && + isKnownNeverInfOrNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) return getFalse(RetTy); // LHS != Inf && LHS != NaN - if (Pred == FCmpInst::FCMP_ONE && isKnownNeverInfinity(LHS, Q.TLI) && - isKnownNeverNaN(LHS, Q.TLI)) + if (Pred == FCmpInst::FCMP_ONE && + isKnownNeverInfOrNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(RetTy); } if (C->isNegative() && !C->isNegZero()) { @@ -4061,14 +4138,16 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, case FCmpInst::FCMP_UGT: case FCmpInst::FCMP_UNE: // (X >= 0) implies (X > C) when (C < 0) - if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) + if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0, + Q.AC, Q.CxtI, Q.DT)) return getTrue(RetTy); break; case FCmpInst::FCMP_OEQ: case FCmpInst::FCMP_OLE: case FCmpInst::FCMP_OLT: // (X >= 0) implies !(X < C) when (C < 0) - if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) + if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, + Q.DT)) return getFalse(RetTy); break; default: @@ -4125,18 +4204,23 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (match(RHS, m_AnyZeroFP())) { switch (Pred) { case FCmpInst::FCMP_OGE: - case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_ULT: { + FPClassTest Interested = FMF.noNaNs() ? fcNegative : fcNegative | fcNan; + KnownFPClass Known = computeKnownFPClass(LHS, Q.DL, Interested, 0, + Q.TLI, Q.AC, Q.CxtI, Q.DT); + // Positive or zero X >= 0.0 --> true // Positive or zero X < 0.0 --> false - if ((FMF.noNaNs() || isKnownNeverNaN(LHS, Q.TLI)) && - CannotBeOrderedLessThanZero(LHS, Q.TLI)) + if ((FMF.noNaNs() || Known.isKnownNeverNaN()) && + Known.cannotBeOrderedLessThanZero()) return Pred == FCmpInst::FCMP_OGE ? getTrue(RetTy) : getFalse(RetTy); break; + } case FCmpInst::FCMP_UGE: case FCmpInst::FCMP_OLT: // Positive or zero or nan X >= 0.0 --> true // Positive or zero or nan X < 0.0 --> false - if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) + if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) return Pred == FCmpInst::FCMP_UGE ? getTrue(RetTy) : getFalse(RetTy); break; default: @@ -4172,26 +4256,45 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (V == Op) return RepOp; + if (!MaxRecurse--) + return nullptr; + // We cannot replace a constant, and shouldn't even try. if (isa<Constant>(Op)) return nullptr; auto *I = dyn_cast<Instruction>(V); - if (!I || !is_contained(I->operands(), Op)) + if (!I) + return nullptr; + + // The arguments of a phi node might refer to a value from a previous + // cycle iteration. + if (isa<PHINode>(I)) return nullptr; if (Op->getType()->isVectorTy()) { // For vector types, the simplification must hold per-lane, so forbid // potentially cross-lane operations like shufflevector. 
- assert(I->getType()->isVectorTy() && "Vector type mismatch"); - if (isa<ShuffleVectorInst>(I) || isa<CallBase>(I)) + if (!I->getType()->isVectorTy() || isa<ShuffleVectorInst>(I) || + isa<CallBase>(I)) return nullptr; } // Replace Op with RepOp in instruction operands. - SmallVector<Value *, 8> NewOps(I->getNumOperands()); - transform(I->operands(), NewOps.begin(), - [&](Value *V) { return V == Op ? RepOp : V; }); + SmallVector<Value *, 8> NewOps; + bool AnyReplaced = false; + for (Value *InstOp : I->operands()) { + if (Value *NewInstOp = simplifyWithOpReplaced( + InstOp, Op, RepOp, Q, AllowRefinement, MaxRecurse)) { + NewOps.push_back(NewInstOp); + AnyReplaced = InstOp != NewInstOp; + } else { + NewOps.push_back(InstOp); + } + } + + if (!AnyReplaced) + return nullptr; if (!AllowRefinement) { // General InstSimplify functions may refine the result, e.g. by returning @@ -4211,15 +4314,35 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if ((Opcode == Instruction::And || Opcode == Instruction::Or) && NewOps[0] == NewOps[1]) return NewOps[0]; + + // x - x -> 0, x ^ x -> 0. This is non-refining, because x is non-poison + // by assumption and this case never wraps, so nowrap flags can be + // ignored. + if ((Opcode == Instruction::Sub || Opcode == Instruction::Xor) && + NewOps[0] == RepOp && NewOps[1] == RepOp) + return Constant::getNullValue(I->getType()); + + // If we are substituting an absorber constant into a binop and extra + // poison can't leak if we remove the select -- because both operands of + // the binop are based on the same value -- then it may be safe to replace + // the value with the absorber constant. Examples: + // (Op == 0) ? 0 : (Op & -Op) --> Op & -Op + // (Op == 0) ? 0 : (Op * (binop Op, C)) --> Op * (binop Op, C) + // (Op == -1) ? -1 : (Op | (binop C, Op) --> Op | (binop C, Op) + Constant *Absorber = + ConstantExpr::getBinOpAbsorber(Opcode, I->getType()); + if ((NewOps[0] == Absorber || NewOps[1] == Absorber) && + impliesPoison(BO, Op)) + return Absorber; } - if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { - // getelementptr x, 0 -> x - if (NewOps.size() == 2 && match(NewOps[1], m_Zero()) && - !GEP->isInBounds()) + if (isa<GetElementPtrInst>(I)) { + // getelementptr x, 0 -> x. + // This never returns poison, even if inbounds is set. + if (NewOps.size() == 2 && match(NewOps[1], m_Zero())) return NewOps[0]; } - } else if (MaxRecurse) { + } else { // The simplification queries below may return the original value. Consider: // %div = udiv i32 %arg, %arg2 // %mul = mul nsw i32 %div, %arg2 @@ -4233,23 +4356,8 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, return Simplified != V ? Simplified : nullptr; }; - if (auto *B = dyn_cast<BinaryOperator>(I)) - return PreventSelfSimplify(simplifyBinOp(B->getOpcode(), NewOps[0], - NewOps[1], Q, MaxRecurse - 1)); - - if (CmpInst *C = dyn_cast<CmpInst>(I)) - return PreventSelfSimplify(simplifyCmpInst(C->getPredicate(), NewOps[0], - NewOps[1], Q, MaxRecurse - 1)); - - if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) - return PreventSelfSimplify(simplifyGEPInst( - GEP->getSourceElementType(), NewOps[0], ArrayRef(NewOps).slice(1), - GEP->isInBounds(), Q, MaxRecurse - 1)); - - if (isa<SelectInst>(I)) - return PreventSelfSimplify(simplifySelectInst( - NewOps[0], NewOps[1], NewOps[2], Q, MaxRecurse - 1)); - // TODO: We could hand off more cases to instsimplify here. 
+ return PreventSelfSimplify( + ::simplifyInstructionWithOperands(I, NewOps, Q, MaxRecurse)); } // If all operands are constant after substituting Op for RepOp then we can @@ -4406,6 +4514,24 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS, } /// Try to simplify a select instruction when its condition operand is an +/// integer equality comparison. +static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS, + Value *TrueVal, Value *FalseVal, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, + /* AllowRefinement */ false, + MaxRecurse) == TrueVal) + return FalseVal; + if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, + /* AllowRefinement */ true, + MaxRecurse) == FalseVal) + return FalseVal; + + return nullptr; +} + +/// Try to simplify a select instruction when its condition operand is an /// integer comparison. static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, Value *FalseVal, @@ -4493,20 +4619,38 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, // the arms of the select. See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. if (Pred == ICmpInst::ICMP_EQ) { - if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, - /* AllowRefinement */ false, - MaxRecurse) == TrueVal || - simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, - /* AllowRefinement */ false, - MaxRecurse) == TrueVal) - return FalseVal; - if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, - /* AllowRefinement */ true, - MaxRecurse) == FalseVal || - simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, - /* AllowRefinement */ true, - MaxRecurse) == FalseVal) - return FalseVal; + if (Value *V = simplifySelectWithICmpEq(CmpLHS, CmpRHS, TrueVal, FalseVal, + Q, MaxRecurse)) + return V; + if (Value *V = simplifySelectWithICmpEq(CmpRHS, CmpLHS, TrueVal, FalseVal, + Q, MaxRecurse)) + return V; + + Value *X; + Value *Y; + // select((X | Y) == 0 ? X : 0) --> 0 (commuted 2 ways) + if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) && + match(CmpRHS, m_Zero())) { + // (X | Y) == 0 implies X == 0 and Y == 0. + if (Value *V = simplifySelectWithICmpEq(X, CmpRHS, TrueVal, FalseVal, Q, + MaxRecurse)) + return V; + if (Value *V = simplifySelectWithICmpEq(Y, CmpRHS, TrueVal, FalseVal, Q, + MaxRecurse)) + return V; + } + + // select((X & Y) == -1 ? X : -1) --> -1 (commuted 2 ways) + if (match(CmpLHS, m_And(m_Value(X), m_Value(Y))) && + match(CmpRHS, m_AllOnes())) { + // (X & Y) == -1 implies X == -1 and Y == -1. 
+ if (Value *V = simplifySelectWithICmpEq(X, CmpRHS, TrueVal, FalseVal, Q, + MaxRecurse)) + return V; + if (Value *V = simplifySelectWithICmpEq(Y, CmpRHS, TrueVal, FalseVal, Q, + MaxRecurse)) + return V; + } } return nullptr; @@ -4550,7 +4694,8 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, if (auto *CondC = dyn_cast<Constant>(Cond)) { if (auto *TrueC = dyn_cast<Constant>(TrueVal)) if (auto *FalseC = dyn_cast<Constant>(FalseVal)) - return ConstantFoldSelectInstruction(CondC, TrueC, FalseC); + if (Constant *C = ConstantFoldSelectInstruction(CondC, TrueC, FalseC)) + return C; // select poison, X, Y -> poison if (isa<PoisonValue>(CondC)) @@ -4598,6 +4743,9 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, // !(X || Y) && X --> false (commuted 2 ways) if (match(Cond, m_Not(m_c_LogicalOr(m_Specific(TrueVal), m_Value())))) return ConstantInt::getFalse(Cond->getType()); + // X && !(X || Y) --> false (commuted 2 ways) + if (match(TrueVal, m_Not(m_c_LogicalOr(m_Specific(Cond), m_Value())))) + return ConstantInt::getFalse(Cond->getType()); // (X || Y) && Y --> Y (commuted 2 ways) if (match(Cond, m_c_LogicalOr(m_Specific(TrueVal), m_Value()))) @@ -4618,6 +4766,13 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, // Match patterns that end in logical-or. if (match(TrueVal, m_One())) { + // !(X && Y) || X --> true (commuted 2 ways) + if (match(Cond, m_Not(m_c_LogicalAnd(m_Specific(FalseVal), m_Value())))) + return ConstantInt::getTrue(Cond->getType()); + // X || !(X && Y) --> true (commuted 2 ways) + if (match(FalseVal, m_Not(m_c_LogicalAnd(m_Specific(Cond), m_Value())))) + return ConstantInt::getTrue(Cond->getType()); + // (X && Y) || Y --> Y (commuted 2 ways) if (match(Cond, m_c_LogicalAnd(m_Specific(FalseVal), m_Value()))) return FalseVal; @@ -4747,10 +4902,8 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, } } - // For opaque pointers an all-zero GEP is a no-op. For typed pointers, - // it may be equivalent to a bitcast. - if (Ptr->getType()->getScalarType()->isOpaquePointerTy() && - Ptr->getType() == GEPTy && + // All-zero GEP is a no-op, unless it performs a vector splat. + if (Ptr->getType() == GEPTy && all_of(Indices, [](const auto *V) { return match(V, m_Zero()); })) return Ptr; @@ -4760,9 +4913,9 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, any_of(Indices, [](const auto *V) { return isa<PoisonValue>(V); })) return PoisonValue::get(GEPTy); + // getelementptr undef, idx -> undef if (Q.isUndefValue(Ptr)) - // If inbounds, we can choose an out-of-bounds pointer as a base pointer. - return InBounds ? 
PoisonValue::get(GEPTy) : UndefValue::get(GEPTy); + return UndefValue::get(GEPTy); bool IsScalableVec = isa<ScalableVectorType>(SrcTy) || any_of(Indices, [](const Value *V) { @@ -4853,6 +5006,10 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, !all_of(Indices, [](Value *V) { return isa<Constant>(V); })) return nullptr; + if (!ConstantExpr::isSupportedGetElementPtr(SrcTy)) + return ConstantFoldGetElementPtr(SrcTy, cast<Constant>(Ptr), InBounds, + std::nullopt, Indices); + auto *CE = ConstantExpr::getGetElementPtr(SrcTy, cast<Constant>(Ptr), Indices, InBounds); return ConstantFoldConstant(CE, Q.DL); @@ -4882,8 +5039,11 @@ static Value *simplifyInsertValueInst(Value *Agg, Value *Val, if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(Val)) if (EV->getAggregateOperand()->getType() == Agg->getType() && EV->getIndices() == Idxs) { - // insertvalue undef, (extractvalue y, n), n -> y - if (Q.isUndefValue(Agg)) + // insertvalue poison, (extractvalue y, n), n -> y + // insertvalue undef, (extractvalue y, n), n -> y if y cannot be poison + if (isa<PoisonValue>(Agg) || + (Q.isUndefValue(Agg) && + isGuaranteedNotToBePoison(EV->getAggregateOperand()))) return EV->getAggregateOperand(); // insertvalue y, (extractvalue y, n), n -> y @@ -5151,8 +5311,8 @@ static Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask, Type *RetTy, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (all_of(Mask, [](int Elem) { return Elem == UndefMaskElem; })) - return UndefValue::get(RetTy); + if (all_of(Mask, [](int Elem) { return Elem == PoisonMaskElem; })) + return PoisonValue::get(RetTy); auto *InVecTy = cast<VectorType>(Op0->getType()); unsigned MaskNumElts = Mask.size(); @@ -5217,11 +5377,11 @@ static Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, })) { assert(isa<UndefValue>(Op1) && "Expected undef operand 1 for splat"); - // Shuffle mask undefs become undefined constant result elements. + // Shuffle mask poisons become poison constant result elements. SmallVector<Constant *, 16> VecC(MaskNumElts, C); for (unsigned i = 0; i != MaskNumElts; ++i) if (Indices[i] == -1) - VecC[i] = UndefValue::get(C->getType()); + VecC[i] = PoisonValue::get(C->getType()); return ConstantVector::get(VecC); } } @@ -5299,28 +5459,42 @@ Value *llvm::simplifyFNegInst(Value *Op, FastMathFlags FMF, /// Try to propagate existing NaN values when possible. If not, replace the /// constant or elements in the constant with a canonical NaN. static Constant *propagateNaN(Constant *In) { - if (auto *VecTy = dyn_cast<FixedVectorType>(In->getType())) { + Type *Ty = In->getType(); + if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) { unsigned NumElts = VecTy->getNumElements(); SmallVector<Constant *, 32> NewC(NumElts); for (unsigned i = 0; i != NumElts; ++i) { Constant *EltC = In->getAggregateElement(i); - // Poison and existing NaN elements propagate. + // Poison elements propagate. NaN propagates except signaling is quieted. // Replace unknown or undef elements with canonical NaN. - if (EltC && (isa<PoisonValue>(EltC) || EltC->isNaN())) + if (EltC && isa<PoisonValue>(EltC)) NewC[i] = EltC; + else if (EltC && EltC->isNaN()) + NewC[i] = ConstantFP::get( + EltC->getType(), cast<ConstantFP>(EltC)->getValue().makeQuiet()); else - NewC[i] = (ConstantFP::getNaN(VecTy->getElementType())); + NewC[i] = ConstantFP::getNaN(VecTy->getElementType()); } return ConstantVector::get(NewC); } - // It is not a fixed vector, but not a simple NaN either? 
+ // If it is not a fixed vector, but not a simple NaN either, return a + // canonical NaN. if (!In->isNaN()) - return ConstantFP::getNaN(In->getType()); + return ConstantFP::getNaN(Ty); + + // If we known this is a NaN, and it's scalable vector, we must have a splat + // on our hands. Grab that before splatting a QNaN constant. + if (isa<ScalableVectorType>(Ty)) { + auto *Splat = In->getSplatValue(); + assert(Splat && Splat->isNaN() && + "Found a scalable-vector NaN but not a splat"); + In = Splat; + } - // Propagate the existing NaN constant when possible. - // TODO: Should we quiet a signaling NaN? - return In; + // Propagate an existing QNaN constant. If it is an SNaN, make it quiet, but + // preserve the sign/payload. + return ConstantFP::get(Ty, cast<ConstantFP>(In)->getValue().makeQuiet()); } /// Perform folds that are common to any floating-point operation. This implies @@ -5393,7 +5567,7 @@ simplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, // fadd X, 0 ==> X, when we know X is not -0 if (canIgnoreSNaN(ExBehavior, FMF)) if (match(Op1, m_PosZeroFP()) && - (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) + (FMF.noSignedZeros() || cannotBeNegativeZero(Op0, Q.DL, Q.TLI))) return Op0; if (!isDefaultFPEnvironment(ExBehavior, Rounding)) @@ -5413,11 +5587,11 @@ simplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0 if (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)))) - return ConstantFP::getNullValue(Op0->getType()); + return ConstantFP::getZero(Op0->getType()); if (match(Op0, m_FNeg(m_Specific(Op1))) || match(Op1, m_FNeg(m_Specific(Op0)))) - return ConstantFP::getNullValue(Op0->getType()); + return ConstantFP::getZero(Op0->getType()); } // (X - Y) + Y --> X @@ -5455,7 +5629,7 @@ simplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, // fsub X, -0 ==> X, when we know X is not -0 if (canIgnoreSNaN(ExBehavior, FMF)) if (match(Op1, m_NegZeroFP()) && - (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) + (FMF.noSignedZeros() || cannotBeNegativeZero(Op0, Q.DL, Q.TLI))) return Op0; // fsub -0.0, (fsub -0.0, X) ==> X @@ -5521,11 +5695,12 @@ static Value *simplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, if (match(Op1, m_AnyZeroFP())) { // X * 0.0 --> 0.0 (with nnan and nsz) if (FMF.noNaNs() && FMF.noSignedZeros()) - return ConstantFP::getNullValue(Op0->getType()); + return ConstantFP::getZero(Op0->getType()); // +normal number * (-)0.0 --> (-)0.0 - if (isKnownNeverInfinity(Op0, Q.TLI) && isKnownNeverNaN(Op0, Q.TLI) && - SignBitMustBeZero(Op0, Q.TLI)) + if (isKnownNeverInfOrNaN(Op0, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT) && + // TODO: Check SignBit from computeKnownFPClass when it's more complete. + SignBitMustBeZero(Op0, Q.DL, Q.TLI)) return Op1; } @@ -5610,7 +5785,7 @@ simplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, // Requires that NaNs are off (X could be zero) and signed zeroes are // ignored (X could be positive or negative, so the output sign is unknown). if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP())) - return ConstantFP::getNullValue(Op0->getType()); + return ConstantFP::getZero(Op0->getType()); if (FMF.noNaNs()) { // X / X -> 1.0 is legal when NaNs are ignored. 
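Aside: the fadd/fsub zero-identity folds in the hunks above ("fadd X, 0 ==> X" and "fsub X, -0 ==> X") are guarded by noSignedZeros or the new cannotBeNegativeZero query because the identity breaks for X == -0.0: under IEEE-754 default rounding, -0.0 + (+0.0) yields +0.0, so dropping the add would flip the sign of the zero. A minimal sketch (plain C++, illustrative, not part of the patch):

#include <cassert>
#include <cmath>

int main() {
  double x = -0.0;
  double sum = x + 0.0;                  // fadd X, +0.0 with X == -0.0
  assert(std::signbit(x));               // X is negative zero
  assert(!std::signbit(sum));            // X + 0.0 is positive zero, not X
  // When X cannot be -0.0 (or signed zeros may be ignored), the identity holds.
  double y = 1.5;
  assert(y + 0.0 == y);
  return 0;
}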
@@ -5667,7 +5842,7 @@ simplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (FMF.noNaNs()) { // +0 % X -> 0 if (match(Op0, m_PosZeroFP())) - return ConstantFP::getNullValue(Op0->getType()); + return ConstantFP::getZero(Op0->getType()); // -0 % X -> -0 if (match(Op0, m_NegZeroFP())) return ConstantFP::getNegativeZero(Op0->getType()); @@ -5932,7 +6107,7 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, Value *X; switch (IID) { case Intrinsic::fabs: - if (SignBitMustBeZero(Op0, Q.TLI)) + if (SignBitMustBeZero(Op0, Q.DL, Q.TLI)) return Op0; break; case Intrinsic::bswap: @@ -5998,6 +6173,15 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, if (isSplatValue(Op0)) return Op0; break; + case Intrinsic::frexp: { + // Frexp is idempotent with the added complication of the struct return. + if (match(Op0, m_ExtractValue<0>(m_Value(X)))) { + if (match(X, m_Intrinsic<Intrinsic::frexp>(m_Value()))) + return X; + } + + break; + } default: break; } @@ -6030,6 +6214,51 @@ static Value *foldMinMaxSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) { return nullptr; } +/// Given a min/max intrinsic, see if it can be removed based on having an +/// operand that is another min/max intrinsic with shared operand(s). The caller +/// is expected to swap the operand arguments to handle commutation. +static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0, + Value *Op1) { + assert((IID == Intrinsic::maxnum || IID == Intrinsic::minnum || + IID == Intrinsic::maximum || IID == Intrinsic::minimum) && + "Unsupported intrinsic"); + + auto *M0 = dyn_cast<IntrinsicInst>(Op0); + // If Op0 is not the same intrinsic as IID, do not process. + // This is a difference with integer min/max handling. We do not process the + // case like max(min(X,Y),min(X,Y)) => min(X,Y). But it can be handled by GVN. + if (!M0 || M0->getIntrinsicID() != IID) + return nullptr; + Value *X0 = M0->getOperand(0); + Value *Y0 = M0->getOperand(1); + // Simple case, m(m(X,Y), X) => m(X, Y) + // m(m(X,Y), Y) => m(X, Y) + // For minimum/maximum, X is NaN => m(NaN, Y) == NaN and m(NaN, NaN) == NaN. + // For minimum/maximum, Y is NaN => m(X, NaN) == NaN and m(NaN, NaN) == NaN. + // For minnum/maxnum, X is NaN => m(NaN, Y) == Y and m(Y, Y) == Y. + // For minnum/maxnum, Y is NaN => m(X, NaN) == X and m(X, NaN) == X. + if (X0 == Op1 || Y0 == Op1) + return M0; + + auto *M1 = dyn_cast<IntrinsicInst>(Op1); + if (!M1) + return nullptr; + Value *X1 = M1->getOperand(0); + Value *Y1 = M1->getOperand(1); + Intrinsic::ID IID1 = M1->getIntrinsicID(); + // we have a case m(m(X,Y),m'(X,Y)) taking into account m' is commutative. + // if m' is m or inversion of m => m(m(X,Y),m'(X,Y)) == m(X,Y). + // For minimum/maximum, X is NaN => m(NaN,Y) == m'(NaN, Y) == NaN. + // For minimum/maximum, Y is NaN => m(X,NaN) == m'(X, NaN) == NaN. + // For minnum/maxnum, X is NaN => m(NaN,Y) == m'(NaN, Y) == Y. + // For minnum/maxnum, Y is NaN => m(X,NaN) == m'(X, NaN) == X. 
+ if ((X0 == X1 && Y0 == Y1) || (X0 == Y1 && Y0 == X1)) + if (IID1 == IID || getInverseMinMaxIntrinsic(IID1) == IID) + return M0; + + return nullptr; +} + static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, const SimplifyQuery &Q) { Intrinsic::ID IID = F->getIntrinsicID(); @@ -6116,13 +6345,6 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, if (isICmpTrue(Pred, Op1, Op0, Q.getWithoutUndef(), RecursionLimit)) return Op1; - if (std::optional<bool> Imp = - isImpliedByDomCondition(Pred, Op0, Op1, Q.CxtI, Q.DL)) - return *Imp ? Op0 : Op1; - if (std::optional<bool> Imp = - isImpliedByDomCondition(Pred, Op1, Op0, Q.CxtI, Q.DL)) - return *Imp ? Op1 : Op0; - break; } case Intrinsic::usub_with_overflow: @@ -6276,14 +6498,10 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, // Min/max of the same operation with common operand: // m(m(X, Y)), X --> m(X, Y) (4 commuted variants) - if (auto *M0 = dyn_cast<IntrinsicInst>(Op0)) - if (M0->getIntrinsicID() == IID && - (M0->getOperand(0) == Op1 || M0->getOperand(1) == Op1)) - return Op0; - if (auto *M1 = dyn_cast<IntrinsicInst>(Op1)) - if (M1->getIntrinsicID() == IID && - (M1->getOperand(0) == Op0 || M1->getOperand(1) == Op0)) - return Op1; + if (Value *V = foldMinimumMaximumSharedOp(IID, Op0, Op1)) + return V; + if (Value *V = foldMinimumMaximumSharedOp(IID, Op1, Op0)) + return V; break; } @@ -6307,10 +6525,13 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, return nullptr; } -static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { - - unsigned NumOperands = Call->arg_size(); - Function *F = cast<Function>(Call->getCalledFunction()); +static Value *simplifyIntrinsic(CallBase *Call, Value *Callee, + ArrayRef<Value *> Args, + const SimplifyQuery &Q) { + // Operand bundles should not be in Args. + assert(Call->arg_size() == Args.size()); + unsigned NumOperands = Args.size(); + Function *F = cast<Function>(Callee); Intrinsic::ID IID = F->getIntrinsicID(); // Most of the intrinsics with no operands have some kind of side effect. @@ -6318,9 +6539,6 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { if (!NumOperands) { switch (IID) { case Intrinsic::vscale: { - // Call may not be inserted into the IR yet at point of calling simplify. - if (!Call->getParent() || !Call->getParent()->getParent()) - return nullptr; auto Attr = Call->getFunction()->getFnAttribute(Attribute::VScaleRange); if (!Attr.isValid()) return nullptr; @@ -6336,18 +6554,17 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } if (NumOperands == 1) - return simplifyUnaryIntrinsic(F, Call->getArgOperand(0), Q); + return simplifyUnaryIntrinsic(F, Args[0], Q); if (NumOperands == 2) - return simplifyBinaryIntrinsic(F, Call->getArgOperand(0), - Call->getArgOperand(1), Q); + return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q); // Handle intrinsics with 3 or more arguments. switch (IID) { case Intrinsic::masked_load: case Intrinsic::masked_gather: { - Value *MaskArg = Call->getArgOperand(2); - Value *PassthruArg = Call->getArgOperand(3); + Value *MaskArg = Args[2]; + Value *PassthruArg = Args[3]; // If the mask is all zeros or undef, the "passthru" argument is the result. 
if (maskIsAllZeroOrUndef(MaskArg)) return PassthruArg; @@ -6355,8 +6572,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } case Intrinsic::fshl: case Intrinsic::fshr: { - Value *Op0 = Call->getArgOperand(0), *Op1 = Call->getArgOperand(1), - *ShAmtArg = Call->getArgOperand(2); + Value *Op0 = Args[0], *Op1 = Args[1], *ShAmtArg = Args[2]; // If both operands are undef, the result is undef. if (Q.isUndefValue(Op0) && Q.isUndefValue(Op1)) @@ -6364,14 +6580,14 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { // If shift amount is undef, assume it is zero. if (Q.isUndefValue(ShAmtArg)) - return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); + return Args[IID == Intrinsic::fshl ? 0 : 1]; const APInt *ShAmtC; if (match(ShAmtArg, m_APInt(ShAmtC))) { // If there's effectively no shift, return the 1st arg or 2nd arg. APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth()); if (ShAmtC->urem(BitWidth).isZero()) - return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); + return Args[IID == Intrinsic::fshl ? 0 : 1]; } // Rotating zero by anything is zero. @@ -6385,31 +6601,24 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { return nullptr; } case Intrinsic::experimental_constrained_fma: { - Value *Op0 = Call->getArgOperand(0); - Value *Op1 = Call->getArgOperand(1); - Value *Op2 = Call->getArgOperand(2); auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - if (Value *V = - simplifyFPOp({Op0, Op1, Op2}, {}, Q, *FPI->getExceptionBehavior(), - *FPI->getRoundingMode())) + if (Value *V = simplifyFPOp(Args, {}, Q, *FPI->getExceptionBehavior(), + *FPI->getRoundingMode())) return V; return nullptr; } case Intrinsic::fma: case Intrinsic::fmuladd: { - Value *Op0 = Call->getArgOperand(0); - Value *Op1 = Call->getArgOperand(1); - Value *Op2 = Call->getArgOperand(2); - if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q, fp::ebIgnore, + if (Value *V = simplifyFPOp(Args, {}, Q, fp::ebIgnore, RoundingMode::NearestTiesToEven)) return V; return nullptr; } case Intrinsic::smul_fix: case Intrinsic::smul_fix_sat: { - Value *Op0 = Call->getArgOperand(0); - Value *Op1 = Call->getArgOperand(1); - Value *Op2 = Call->getArgOperand(2); + Value *Op0 = Args[0]; + Value *Op1 = Args[1]; + Value *Op2 = Args[2]; Type *ReturnType = F->getReturnType(); // Canonicalize constant operand as Op1 (ConstantFolding handles the case @@ -6436,9 +6645,9 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { return nullptr; } case Intrinsic::vector_insert: { - Value *Vec = Call->getArgOperand(0); - Value *SubVec = Call->getArgOperand(1); - Value *Idx = Call->getArgOperand(2); + Value *Vec = Args[0]; + Value *SubVec = Args[1]; + Value *Idx = Args[2]; Type *ReturnType = F->getReturnType(); // (insert_vector Y, (extract_vector X, 0), 0) -> X @@ -6455,51 +6664,52 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } case Intrinsic::experimental_constrained_fadd: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return simplifyFAddInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFAddInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } case Intrinsic::experimental_constrained_fsub: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return simplifyFSubInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, 
*FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFSubInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } case Intrinsic::experimental_constrained_fmul: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return simplifyFMulInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFMulInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } case Intrinsic::experimental_constrained_fdiv: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return simplifyFDivInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFDivInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } case Intrinsic::experimental_constrained_frem: { auto *FPI = cast<ConstrainedFPIntrinsic>(Call); - return simplifyFRemInst( - FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(), - Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); + return simplifyFRemInst(Args[0], Args[1], FPI->getFastMathFlags(), Q, + *FPI->getExceptionBehavior(), + *FPI->getRoundingMode()); } default: return nullptr; } } -static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) { - auto *F = dyn_cast<Function>(Call->getCalledOperand()); +static Value *tryConstantFoldCall(CallBase *Call, Value *Callee, + ArrayRef<Value *> Args, + const SimplifyQuery &Q) { + auto *F = dyn_cast<Function>(Callee); if (!F || !canConstantFoldCallTo(Call, F)) return nullptr; SmallVector<Constant *, 4> ConstantArgs; - unsigned NumArgs = Call->arg_size(); - ConstantArgs.reserve(NumArgs); - for (auto &Arg : Call->args()) { - Constant *C = dyn_cast<Constant>(&Arg); + ConstantArgs.reserve(Args.size()); + for (Value *Arg : Args) { + Constant *C = dyn_cast<Constant>(Arg); if (!C) { - if (isa<MetadataAsValue>(Arg.get())) + if (isa<MetadataAsValue>(Arg)) continue; return nullptr; } @@ -6509,7 +6719,11 @@ static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) { return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI); } -Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) { +Value *llvm::simplifyCall(CallBase *Call, Value *Callee, ArrayRef<Value *> Args, + const SimplifyQuery &Q) { + // Args should not contain operand bundle operands. + assert(Call->arg_size() == Args.size()); + // musttail calls can only be simplified if they are also DCEd. // As we can't guarantee this here, don't simplify them. 
if (Call->isMustTailCall()) @@ -6517,16 +6731,15 @@ Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) { // call undef -> poison // call null -> poison - Value *Callee = Call->getCalledOperand(); if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee)) return PoisonValue::get(Call->getType()); - if (Value *V = tryConstantFoldCall(Call, Q)) + if (Value *V = tryConstantFoldCall(Call, Callee, Args, Q)) return V; auto *F = dyn_cast<Function>(Callee); if (F && F->isIntrinsic()) - if (Value *Ret = simplifyIntrinsic(Call, Q)) + if (Value *Ret = simplifyIntrinsic(Call, Callee, Args, Q)) return Ret; return nullptr; @@ -6534,9 +6747,10 @@ Value *llvm::simplifyCall(CallBase *Call, const SimplifyQuery &Q) { Value *llvm::simplifyConstrainedFPCall(CallBase *Call, const SimplifyQuery &Q) { assert(isa<ConstrainedFPIntrinsic>(Call)); - if (Value *V = tryConstantFoldCall(Call, Q)) + SmallVector<Value *, 4> Args(Call->args()); + if (Value *V = tryConstantFoldCall(Call, Call->getCalledOperand(), Args, Q)) return V; - if (Value *Ret = simplifyIntrinsic(Call, Q)) + if (Value *Ret = simplifyIntrinsic(Call, Call->getCalledOperand(), Args, Q)) return Ret; return nullptr; } @@ -6554,27 +6768,38 @@ Value *llvm::simplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) { return ::simplifyFreezeInst(Op0, Q); } -static Value *simplifyLoadInst(LoadInst *LI, Value *PtrOp, - const SimplifyQuery &Q) { +Value *llvm::simplifyLoadInst(LoadInst *LI, Value *PtrOp, + const SimplifyQuery &Q) { if (LI->isVolatile()) return nullptr; - APInt Offset(Q.DL.getIndexTypeSizeInBits(PtrOp->getType()), 0); - auto *PtrOpC = dyn_cast<Constant>(PtrOp); + if (auto *PtrOpC = dyn_cast<Constant>(PtrOp)) + return ConstantFoldLoadFromConstPtr(PtrOpC, LI->getType(), Q.DL); + + // We can only fold the load if it is from a constant global with definitive + // initializer. Skip expensive logic if this is not the case. + auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp)); + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) + return nullptr; + + // If GlobalVariable's initializer is uniform, then return the constant + // regardless of its offset. + if (Constant *C = + ConstantFoldLoadFromUniformValue(GV->getInitializer(), LI->getType())) + return C; + // Try to convert operand into a constant by stripping offsets while looking - // through invariant.group intrinsics. Don't bother if the underlying object - // is not constant, as calculating GEP offsets is expensive. - if (!PtrOpC && isa<Constant>(getUnderlyingObject(PtrOp))) { - PtrOp = PtrOp->stripAndAccumulateConstantOffsets( - Q.DL, Offset, /* AllowNonInbounts */ true, - /* AllowInvariantGroup */ true); + // through invariant.group intrinsics. + APInt Offset(Q.DL.getIndexTypeSizeInBits(PtrOp->getType()), 0); + PtrOp = PtrOp->stripAndAccumulateConstantOffsets( + Q.DL, Offset, /* AllowNonInbounts */ true, + /* AllowInvariantGroup */ true); + if (PtrOp == GV) { // Index size may have changed due to address space casts. 
Offset = Offset.sextOrTrunc(Q.DL.getIndexTypeSizeInBits(PtrOp->getType())); - PtrOpC = dyn_cast<Constant>(PtrOp); + return ConstantFoldLoadFromConstPtr(GV, LI->getType(), Offset, Q.DL); } - if (PtrOpC) - return ConstantFoldLoadFromConstPtr(PtrOpC, LI->getType(), Offset, Q.DL); return nullptr; } @@ -6584,7 +6809,8 @@ static Value *simplifyLoadInst(LoadInst *LI, Value *PtrOp, static Value *simplifyInstructionWithOperands(Instruction *I, ArrayRef<Value *> NewOps, const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { + unsigned MaxRecurse) { + assert(I->getFunction() && "instruction should be inserted in a function"); const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); switch (I->getOpcode()) { @@ -6597,97 +6823,112 @@ static Value *simplifyInstructionWithOperands(Instruction *I, } return nullptr; case Instruction::FNeg: - return simplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q); + return simplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q, MaxRecurse); case Instruction::FAdd: - return simplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + return simplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q, + MaxRecurse); case Instruction::Add: - return simplifyAddInst(NewOps[0], NewOps[1], - Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), - Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + return simplifyAddInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q, MaxRecurse); case Instruction::FSub: - return simplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + return simplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q, + MaxRecurse); case Instruction::Sub: - return simplifySubInst(NewOps[0], NewOps[1], - Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), - Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + return simplifySubInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q, MaxRecurse); case Instruction::FMul: - return simplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + return simplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q, + MaxRecurse); case Instruction::Mul: - return simplifyMulInst(NewOps[0], NewOps[1], - Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), - Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + return simplifyMulInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q, MaxRecurse); case Instruction::SDiv: return simplifySDivInst(NewOps[0], NewOps[1], - Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q, + MaxRecurse); case Instruction::UDiv: return simplifyUDivInst(NewOps[0], NewOps[1], - Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q, + MaxRecurse); case Instruction::FDiv: - return simplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + return simplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q, + MaxRecurse); case Instruction::SRem: - return simplifySRemInst(NewOps[0], NewOps[1], Q); + return simplifySRemInst(NewOps[0], NewOps[1], Q, MaxRecurse); case Instruction::URem: - return simplifyURemInst(NewOps[0], NewOps[1], Q); + return simplifyURemInst(NewOps[0], NewOps[1], Q, MaxRecurse); case Instruction::FRem: - return simplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + return simplifyFRemInst(NewOps[0], NewOps[1], 
I->getFastMathFlags(), Q, + MaxRecurse); case Instruction::Shl: - return simplifyShlInst(NewOps[0], NewOps[1], - Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), - Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + return simplifyShlInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q, MaxRecurse); case Instruction::LShr: return simplifyLShrInst(NewOps[0], NewOps[1], - Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q, + MaxRecurse); case Instruction::AShr: return simplifyAShrInst(NewOps[0], NewOps[1], - Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q, + MaxRecurse); case Instruction::And: - return simplifyAndInst(NewOps[0], NewOps[1], Q); + return simplifyAndInst(NewOps[0], NewOps[1], Q, MaxRecurse); case Instruction::Or: - return simplifyOrInst(NewOps[0], NewOps[1], Q); + return simplifyOrInst(NewOps[0], NewOps[1], Q, MaxRecurse); case Instruction::Xor: - return simplifyXorInst(NewOps[0], NewOps[1], Q); + return simplifyXorInst(NewOps[0], NewOps[1], Q, MaxRecurse); case Instruction::ICmp: return simplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), NewOps[0], - NewOps[1], Q); + NewOps[1], Q, MaxRecurse); case Instruction::FCmp: return simplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), NewOps[0], - NewOps[1], I->getFastMathFlags(), Q); + NewOps[1], I->getFastMathFlags(), Q, MaxRecurse); case Instruction::Select: - return simplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q); + return simplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q, MaxRecurse); break; case Instruction::GetElementPtr: { auto *GEPI = cast<GetElementPtrInst>(I); return simplifyGEPInst(GEPI->getSourceElementType(), NewOps[0], - ArrayRef(NewOps).slice(1), GEPI->isInBounds(), Q); + ArrayRef(NewOps).slice(1), GEPI->isInBounds(), Q, + MaxRecurse); } case Instruction::InsertValue: { InsertValueInst *IV = cast<InsertValueInst>(I); - return simplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q); + return simplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q, + MaxRecurse); } case Instruction::InsertElement: return simplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q); case Instruction::ExtractValue: { auto *EVI = cast<ExtractValueInst>(I); - return simplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q); + return simplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q, + MaxRecurse); } case Instruction::ExtractElement: - return simplifyExtractElementInst(NewOps[0], NewOps[1], Q); + return simplifyExtractElementInst(NewOps[0], NewOps[1], Q, MaxRecurse); case Instruction::ShuffleVector: { auto *SVI = cast<ShuffleVectorInst>(I); return simplifyShuffleVectorInst(NewOps[0], NewOps[1], - SVI->getShuffleMask(), SVI->getType(), Q); + SVI->getShuffleMask(), SVI->getType(), Q, + MaxRecurse); } case Instruction::PHI: return simplifyPHINode(cast<PHINode>(I), NewOps, Q); case Instruction::Call: - // TODO: Use NewOps - return simplifyCall(cast<CallInst>(I), Q); + return simplifyCall( + cast<CallInst>(I), NewOps.back(), + NewOps.drop_back(1 + cast<CallInst>(I)->getNumTotalBundleOperands()), Q); case Instruction::Freeze: return llvm::simplifyFreezeInst(NewOps[0], Q); #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - return simplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q); + return simplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q, + MaxRecurse); case 
Instruction::Alloca: // No simplifications for Alloca and it can't be constant folded. return nullptr; @@ -6698,17 +6939,15 @@ static Value *simplifyInstructionWithOperands(Instruction *I, Value *llvm::simplifyInstructionWithOperands(Instruction *I, ArrayRef<Value *> NewOps, - const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { + const SimplifyQuery &SQ) { assert(NewOps.size() == I->getNumOperands() && "Number of operands should match the instruction!"); - return ::simplifyInstructionWithOperands(I, NewOps, SQ, ORE); + return ::simplifyInstructionWithOperands(I, NewOps, SQ, RecursionLimit); } -Value *llvm::simplifyInstruction(Instruction *I, const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { +Value *llvm::simplifyInstruction(Instruction *I, const SimplifyQuery &SQ) { SmallVector<Value *, 8> Ops(I->operands()); - Value *Result = ::simplifyInstructionWithOperands(I, Ops, SQ, ORE); + Value *Result = ::simplifyInstructionWithOperands(I, Ops, SQ, RecursionLimit); /// If called on unreachable code, the instruction may simplify to itself. /// Make life easier for users by detecting that case here, and returning a @@ -6747,10 +6986,7 @@ static bool replaceAndRecursivelySimplifyImpl( // Replace the instruction with its simplified value. I->replaceAllUsesWith(SimpleV); - // Gracefully handle edge cases where the instruction is not wired into any - // parent block. - if (I->getParent() && !I->isEHPad() && !I->isTerminator() && - !I->mayHaveSideEffects()) + if (!I->isEHPad() && !I->isTerminator() && !I->mayHaveSideEffects()) I->eraseFromParent(); } else { Worklist.insert(I); @@ -6779,10 +7015,7 @@ static bool replaceAndRecursivelySimplifyImpl( // Replace the instruction with its simplified value. I->replaceAllUsesWith(SimpleV); - // Gracefully handle edge cases where the instruction is not wired into any - // parent block. - if (I->getParent() && !I->isEHPad() && !I->isTerminator() && - !I->mayHaveSideEffects()) + if (!I->isEHPad() && !I->isTerminator() && !I->mayHaveSideEffects()) I->eraseFromParent(); } return Simplified; diff --git a/contrib/llvm-project/llvm/lib/Analysis/InteractiveModelRunner.cpp b/contrib/llvm-project/llvm/lib/Analysis/InteractiveModelRunner.cpp new file mode 100644 index 000000000000..99b009b6616f --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Analysis/InteractiveModelRunner.cpp @@ -0,0 +1,82 @@ +//===- InteractiveModelRunner.cpp - noop ML model runner ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A runner that communicates with an external agent via 2 file descriptors. 
+//===----------------------------------------------------------------------===// +#include "llvm/Analysis/InteractiveModelRunner.h" +#include "llvm/Analysis/MLModelRunner.h" +#include "llvm/Analysis/TensorSpec.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +static cl::opt<bool> DebugReply( + "interactive-model-runner-echo-reply", cl::init(false), cl::Hidden, + cl::desc("The InteractiveModelRunner will echo back to stderr " + "the data received from the host (for debugging purposes).")); + +InteractiveModelRunner::InteractiveModelRunner( + LLVMContext &Ctx, const std::vector<TensorSpec> &Inputs, + const TensorSpec &Advice, StringRef OutboundName, StringRef InboundName) + : MLModelRunner(Ctx, MLModelRunner::Kind::Interactive, Inputs.size()), + InputSpecs(Inputs), OutputSpec(Advice), + InEC(sys::fs::openFileForRead(InboundName, Inbound)), + OutputBuffer(OutputSpec.getTotalTensorBufferSize()) { + if (InEC) { + Ctx.emitError("Cannot open inbound file: " + InEC.message()); + return; + } + { + auto OutStream = std::make_unique<raw_fd_ostream>(OutboundName, OutEC); + if (OutEC) { + Ctx.emitError("Cannot open outbound file: " + OutEC.message()); + return; + } + Log = std::make_unique<Logger>(std::move(OutStream), InputSpecs, Advice, + /*IncludeReward=*/false, Advice); + } + // Just like in the no inference case, this will allocate an appropriately + // sized buffer. + for (size_t I = 0; I < InputSpecs.size(); ++I) + setUpBufferForTensor(I, InputSpecs[I], nullptr); + Log->flush(); +} + +InteractiveModelRunner::~InteractiveModelRunner() { + sys::fs::file_t FDAsOSHandle = sys::fs::convertFDToNativeFile(Inbound); + sys::fs::closeFile(FDAsOSHandle); +} + +void *InteractiveModelRunner::evaluateUntyped() { + Log->startObservation(); + for (size_t I = 0; I < InputSpecs.size(); ++I) + Log->logTensorValue(I, reinterpret_cast<const char *>(getTensorUntyped(I))); + Log->endObservation(); + Log->flush(); + + size_t InsPoint = 0; + char *Buff = OutputBuffer.data(); + const size_t Limit = OutputBuffer.size(); + while (InsPoint < Limit) { + auto ReadOrErr = ::sys::fs::readNativeFile( + sys::fs::convertFDToNativeFile(Inbound), + {Buff + InsPoint, OutputBuffer.size() - InsPoint}); + if (ReadOrErr.takeError()) { + Ctx.emitError("Failed reading from inbound file"); + break; + } + InsPoint += *ReadOrErr; + } + if (DebugReply) + dbgs() << OutputSpec.name() << ": " + << tensorValueToString(OutputBuffer.data(), OutputSpec) << "\n"; + return OutputBuffer.data(); +} diff --git a/contrib/llvm-project/llvm/lib/Analysis/LazyValueInfo.cpp b/contrib/llvm-project/llvm/lib/Analysis/LazyValueInfo.cpp index 0e9fab667e6e..33651783cb17 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/LazyValueInfo.cpp @@ -162,7 +162,7 @@ namespace { struct BlockCacheEntry { SmallDenseMap<AssertingVH<Value>, ValueLatticeElement, 4> LatticeElements; SmallDenseSet<AssertingVH<Value>, 4> OverDefined; - // None indicates that the nonnull pointers for this basic block + // std::nullopt indicates that the nonnull pointers for this basic block // block have not been computed yet. std::optional<NonNullPointerSet> NonNullPointers; }; @@ -876,10 +876,14 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) { // condition itself? This shows up with idioms like e.g. select(a > 5, a, 5). 
// TODO: We could potentially refine an overdefined true value above. Value *Cond = SI->getCondition(); - TrueVal = intersect(TrueVal, - getValueFromCondition(SI->getTrueValue(), Cond, true)); - FalseVal = intersect(FalseVal, - getValueFromCondition(SI->getFalseValue(), Cond, false)); + // If the value is undef, a different value may be chosen in + // the select condition. + if (isGuaranteedNotToBeUndefOrPoison(Cond, AC)) { + TrueVal = intersect(TrueVal, + getValueFromCondition(SI->getTrueValue(), Cond, true)); + FalseVal = intersect( + FalseVal, getValueFromCondition(SI->getFalseValue(), Cond, false)); + } ValueLatticeElement Result = TrueVal; Result.mergeIn(FalseVal); @@ -990,10 +994,11 @@ LazyValueInfoImpl::solveBlockValueOverflowIntrinsic(WithOverflowInst *WO, std::optional<ValueLatticeElement> LazyValueInfoImpl::solveBlockValueIntrinsic(IntrinsicInst *II, BasicBlock *BB) { + ValueLatticeElement MetadataVal = getFromRangeMetadata(II); if (!ConstantRange::isIntrinsicSupported(II->getIntrinsicID())) { LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() << "' - unknown intrinsic.\n"); - return getFromRangeMetadata(II); + return MetadataVal; } SmallVector<ConstantRange, 2> OpRanges; @@ -1004,8 +1009,9 @@ LazyValueInfoImpl::solveBlockValueIntrinsic(IntrinsicInst *II, BasicBlock *BB) { OpRanges.push_back(*Range); } - return ValueLatticeElement::getRange( - ConstantRange::intrinsic(II->getIntrinsicID(), OpRanges)); + return intersect(ValueLatticeElement::getRange(ConstantRange::intrinsic( + II->getIntrinsicID(), OpRanges)), + MetadataVal); } std::optional<ValueLatticeElement> @@ -1123,7 +1129,7 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI, // bit of Mask. if (EdgePred == ICmpInst::ICMP_NE && !Mask->isZero() && C->isZero()) { return ValueLatticeElement::getRange(ConstantRange::getNonEmpty( - APInt::getOneBitSet(BitWidth, Mask->countTrailingZeros()), + APInt::getOneBitSet(BitWidth, Mask->countr_zero()), APInt::getZero(BitWidth))); } } @@ -1665,6 +1671,10 @@ ConstantRange LazyValueInfo::getConstantRangeAtUse(const Use &U, std::optional<ValueLatticeElement> CondVal; auto *CurrI = cast<Instruction>(CurrU->getUser()); if (auto *SI = dyn_cast<SelectInst>(CurrI)) { + // If the value is undef, a different value may be chosen in + // the select condition and at use. + if (!isGuaranteedNotToBeUndefOrPoison(SI->getCondition(), AC)) + break; if (CurrU->getOperandNo() == 1) CondVal = getValueFromCondition(V, SI->getCondition(), true); else if (CurrU->getOperandNo() == 2) @@ -1739,7 +1749,7 @@ getPredicateResult(unsigned Pred, Constant *C, const ValueLatticeElement &Val, Constant *Res = nullptr; if (Val.isConstant()) { Res = ConstantFoldCompareInstOperands(Pred, Val.getConstant(), C, DL, TLI); - if (ConstantInt *ResCI = dyn_cast<ConstantInt>(Res)) + if (ConstantInt *ResCI = dyn_cast_or_null<ConstantInt>(Res)) return ResCI->isZero() ? LazyValueInfo::False : LazyValueInfo::True; return LazyValueInfo::Unknown; } @@ -1781,14 +1791,14 @@ getPredicateResult(unsigned Pred, Constant *C, const ValueLatticeElement &Val, Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE, Val.getNotConstant(), C, DL, TLI); - if (Res->isNullValue()) + if (Res && Res->isNullValue()) return LazyValueInfo::False; } else if (Pred == ICmpInst::ICMP_NE) { // !C1 != C -> true iff C1 == C. 
Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE, Val.getNotConstant(), C, DL, TLI); - if (Res->isNullValue()) + if (Res && Res->isNullValue()) return LazyValueInfo::True; } return LazyValueInfo::Unknown; diff --git a/contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp deleted file mode 100644 index baa7e9daa0ae..000000000000 --- a/contrib/llvm-project/llvm/lib/Analysis/LegacyDivergenceAnalysis.cpp +++ /dev/null @@ -1,435 +0,0 @@ -//===- LegacyDivergenceAnalysis.cpp --------- Legacy Divergence Analysis -//Implementation -==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements divergence analysis which determines whether a branch -// in a GPU program is divergent.It can help branch optimizations such as jump -// threading and loop unswitching to make better decisions. -// -// GPU programs typically use the SIMD execution model, where multiple threads -// in the same execution group have to execute in lock-step. Therefore, if the -// code contains divergent branches (i.e., threads in a group do not agree on -// which path of the branch to take), the group of threads has to execute all -// the paths from that branch with different subsets of threads enabled until -// they converge at the immediately post-dominating BB of the paths. -// -// Due to this execution model, some optimizations such as jump -// threading and loop unswitching can be unfortunately harmful when performed on -// divergent branches. Therefore, an analysis that computes which branches in a -// GPU program are divergent can help the compiler to selectively run these -// optimizations. -// -// This file defines divergence analysis which computes a conservative but -// non-trivial approximation of all divergent branches in a GPU program. It -// partially implements the approach described in -// -// Divergence Analysis -// Sampaio, Souza, Collange, Pereira -// TOPLAS '13 -// -// The divergence analysis identifies the sources of divergence (e.g., special -// variables that hold the thread ID), and recursively marks variables that are -// data or sync dependent on a source of divergence as divergent. -// -// While data dependency is a well-known concept, the notion of sync dependency -// is worth more explanation. Sync dependence characterizes the control flow -// aspect of the propagation of branch divergence. For example, -// -// %cond = icmp slt i32 %tid, 10 -// br i1 %cond, label %then, label %else -// then: -// br label %merge -// else: -// br label %merge -// merge: -// %a = phi i32 [ 0, %then ], [ 1, %else ] -// -// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid -// because %tid is not on its use-def chains, %a is sync dependent on %tid -// because the branch "br i1 %cond" depends on %tid and affects which value %a -// is assigned to. -// -// The current implementation has the following limitations: -// 1. intra-procedural. It conservatively considers the arguments of a -// non-kernel-entry function and the return value of a function call as -// divergent. -// 2. memory as black box. It conservatively considers values loaded from -// generic or local address as divergent. This can be improved by leveraging -// pointer analysis. 
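[Editor's aside — illustrative only, not part of the upstream diff.] The header comment of the removed LegacyDivergenceAnalysis.cpp above describes the pass's core algorithm: seed a worklist with the target-reported sources of divergence, then propagate divergence along data dependences (def-use chains) and sync dependences (PHIs at a divergent branch's immediate post-dominator, and values escaping its influence region). A minimal sketch of just the seeding and data-dependence propagation, modeled loosely on the removed DivergencePropagator and omitting the influence-region machinery, might look like this (helper name and simplified set types are assumptions):

#include "llvm/ADT/DenseSet.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/InstIterator.h"
#include <vector>
using namespace llvm;

// Sketch: mark values transitively data-dependent on a source of divergence.
static void propagateDivergenceSketch(Function &F, TargetTransformInfo &TTI,
                                      DenseSet<const Value *> &Divergent) {
  std::vector<Value *> Worklist;
  // Seed with sources of divergence (e.g. thread-ID reads, divergent
  // arguments) as reported by the target.
  for (Instruction &I : instructions(F))
    if (TTI.isSourceOfDivergence(&I) && Divergent.insert(&I).second)
      Worklist.push_back(&I);
  for (Argument &Arg : F.args())
    if (TTI.isSourceOfDivergence(&Arg) && Divergent.insert(&Arg).second)
      Worklist.push_back(&Arg);
  // DFS over def-use chains; the full pass additionally marks PHIs at the
  // immediate post-dominator of divergent branches (sync dependence).
  while (!Worklist.empty()) {
    Value *V = Worklist.back();
    Worklist.pop_back();
    for (User *U : V->users())
      if (!TTI.isAlwaysUniform(U) && Divergent.insert(U).second)
        Worklist.push_back(U);
  }
}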
-// -//===----------------------------------------------------------------------===// - -#include "llvm/Analysis/LegacyDivergenceAnalysis.h" -#include "llvm/ADT/PostOrderIterator.h" -#include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/DivergenceAnalysis.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/Passes.h" -#include "llvm/Analysis/PostDominators.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include <vector> -using namespace llvm; - -#define DEBUG_TYPE "divergence" - -// transparently use the GPUDivergenceAnalysis -static cl::opt<bool> UseGPUDA("use-gpu-divergence-analysis", cl::init(false), - cl::Hidden, - cl::desc("turn the LegacyDivergenceAnalysis into " - "a wrapper for GPUDivergenceAnalysis")); - -namespace { - -class DivergencePropagator { -public: - DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT, - PostDominatorTree &PDT, DenseSet<const Value *> &DV, - DenseSet<const Use *> &DU) - : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV), DU(DU) {} - void populateWithSourcesOfDivergence(); - void propagate(); - -private: - // A helper function that explores data dependents of V. - void exploreDataDependency(Value *V); - // A helper function that explores sync dependents of TI. - void exploreSyncDependency(Instruction *TI); - // Computes the influence region from Start to End. This region includes all - // basic blocks on any simple path from Start to End. - void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End, - DenseSet<BasicBlock *> &InfluenceRegion); - // Finds all users of I that are outside the influence region, and add these - // users to Worklist. - void findUsersOutsideInfluenceRegion( - Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion); - - Function &F; - TargetTransformInfo &TTI; - DominatorTree &DT; - PostDominatorTree &PDT; - std::vector<Value *> Worklist; // Stack for DFS. - DenseSet<const Value *> &DV; // Stores all divergent values. - DenseSet<const Use *> &DU; // Stores divergent uses of possibly uniform - // values. -}; - -void DivergencePropagator::populateWithSourcesOfDivergence() { - Worklist.clear(); - DV.clear(); - DU.clear(); - for (auto &I : instructions(F)) { - if (TTI.isSourceOfDivergence(&I)) { - Worklist.push_back(&I); - DV.insert(&I); - } - } - for (auto &Arg : F.args()) { - if (TTI.isSourceOfDivergence(&Arg)) { - Worklist.push_back(&Arg); - DV.insert(&Arg); - } - } -} - -void DivergencePropagator::exploreSyncDependency(Instruction *TI) { - // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's - // immediate post dominator are divergent. This rule handles if-then-else - // patterns. For example, - // - // if (tid < 5) - // a1 = 1; - // else - // a2 = 2; - // a = phi(a1, a2); // sync dependent on (tid < 5) - BasicBlock *ThisBB = TI->getParent(); - - // Unreachable blocks may not be in the dominator tree. - if (!DT.isReachableFromEntry(ThisBB)) - return; - - // If the function has no exit blocks or doesn't reach any exit blocks, the - // post dominator may be null. 
- DomTreeNode *ThisNode = PDT.getNode(ThisBB); - if (!ThisNode) - return; - - BasicBlock *IPostDom = ThisNode->getIDom()->getBlock(); - if (IPostDom == nullptr) - return; - - for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) { - // A PHINode is uniform if it returns the same value no matter which path is - // taken. - if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second) - Worklist.push_back(&*I); - } - - // Propagation rule 2: if a value defined in a loop is used outside, the user - // is sync dependent on the condition of the loop exits that dominate the - // user. For example, - // - // int i = 0; - // do { - // i++; - // if (foo(i)) ... // uniform - // } while (i < tid); - // if (bar(i)) ... // divergent - // - // A program may contain unstructured loops. Therefore, we cannot leverage - // LoopInfo, which only recognizes natural loops. - // - // The algorithm used here handles both natural and unstructured loops. Given - // a branch TI, we first compute its influence region, the union of all simple - // paths from TI to its immediate post dominator (IPostDom). Then, we search - // for all the values defined in the influence region but used outside. All - // these users are sync dependent on TI. - DenseSet<BasicBlock *> InfluenceRegion; - computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion); - // An insight that can speed up the search process is that all the in-region - // values that are used outside must dominate TI. Therefore, instead of - // searching every basic blocks in the influence region, we search all the - // dominators of TI until it is outside the influence region. - BasicBlock *InfluencedBB = ThisBB; - while (InfluenceRegion.count(InfluencedBB)) { - for (auto &I : *InfluencedBB) { - if (!DV.count(&I)) - findUsersOutsideInfluenceRegion(I, InfluenceRegion); - } - DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom(); - if (IDomNode == nullptr) - break; - InfluencedBB = IDomNode->getBlock(); - } -} - -void DivergencePropagator::findUsersOutsideInfluenceRegion( - Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) { - for (Use &Use : I.uses()) { - Instruction *UserInst = cast<Instruction>(Use.getUser()); - if (!InfluenceRegion.count(UserInst->getParent())) { - DU.insert(&Use); - if (DV.insert(UserInst).second) - Worklist.push_back(UserInst); - } - } -} - -// A helper function for computeInfluenceRegion that adds successors of "ThisBB" -// to the influence region. -static void -addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End, - DenseSet<BasicBlock *> &InfluenceRegion, - std::vector<BasicBlock *> &InfluenceStack) { - for (BasicBlock *Succ : successors(ThisBB)) { - if (Succ != End && InfluenceRegion.insert(Succ).second) - InfluenceStack.push_back(Succ); - } -} - -void DivergencePropagator::computeInfluenceRegion( - BasicBlock *Start, BasicBlock *End, - DenseSet<BasicBlock *> &InfluenceRegion) { - assert(PDT.properlyDominates(End, Start) && - "End does not properly dominate Start"); - - // The influence region starts from the end of "Start" to the beginning of - // "End". Therefore, "Start" should not be in the region unless "Start" is in - // a loop that doesn't contain "End". 
- std::vector<BasicBlock *> InfluenceStack; - addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack); - while (!InfluenceStack.empty()) { - BasicBlock *BB = InfluenceStack.back(); - InfluenceStack.pop_back(); - addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack); - } -} - -void DivergencePropagator::exploreDataDependency(Value *V) { - // Follow def-use chains of V. - for (User *U : V->users()) { - if (!TTI.isAlwaysUniform(U) && DV.insert(U).second) - Worklist.push_back(U); - } -} - -void DivergencePropagator::propagate() { - // Traverse the dependency graph using DFS. - while (!Worklist.empty()) { - Value *V = Worklist.back(); - Worklist.pop_back(); - if (Instruction *I = dyn_cast<Instruction>(V)) { - // Terminators with less than two successors won't introduce sync - // dependency. Ignore them. - if (I->isTerminator() && I->getNumSuccessors() > 1) - exploreSyncDependency(I); - } - exploreDataDependency(V); - } -} - -} // namespace - -// Register this pass. -char LegacyDivergenceAnalysis::ID = 0; -LegacyDivergenceAnalysis::LegacyDivergenceAnalysis() : FunctionPass(ID) { - initializeLegacyDivergenceAnalysisPass(*PassRegistry::getPassRegistry()); -} -INITIALIZE_PASS_BEGIN(LegacyDivergenceAnalysis, "divergence", - "Legacy Divergence Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(LegacyDivergenceAnalysis, "divergence", - "Legacy Divergence Analysis", false, true) - -FunctionPass *llvm::createLegacyDivergenceAnalysisPass() { - return new LegacyDivergenceAnalysis(); -} - -bool LegacyDivergenceAnalysisImpl::shouldUseGPUDivergenceAnalysis( - const Function &F, const TargetTransformInfo &TTI, const LoopInfo &LI) { - if (!(UseGPUDA || TTI.useGPUDivergenceAnalysis())) - return false; - - // GPUDivergenceAnalysis requires a reducible CFG. 
- using RPOTraversal = ReversePostOrderTraversal<const Function *>; - RPOTraversal FuncRPOT(&F); - return !containsIrreducibleCFG<const BasicBlock *, const RPOTraversal, - const LoopInfo>(FuncRPOT, LI); -} - -void LegacyDivergenceAnalysisImpl::run(Function &F, - llvm::TargetTransformInfo &TTI, - llvm::DominatorTree &DT, - llvm::PostDominatorTree &PDT, - const llvm::LoopInfo &LI) { - if (shouldUseGPUDivergenceAnalysis(F, TTI, LI)) { - // run the new GPU divergence analysis - gpuDA = std::make_unique<DivergenceInfo>(F, DT, PDT, LI, TTI, - /* KnownReducible = */ true); - - } else { - // run LLVM's existing DivergenceAnalysis - DivergencePropagator DP(F, TTI, DT, PDT, DivergentValues, DivergentUses); - DP.populateWithSourcesOfDivergence(); - DP.propagate(); - } -} - -bool LegacyDivergenceAnalysisImpl::isDivergent(const Value *V) const { - if (gpuDA) { - return gpuDA->isDivergent(*V); - } - return DivergentValues.count(V); -} - -bool LegacyDivergenceAnalysisImpl::isDivergentUse(const Use *U) const { - if (gpuDA) { - return gpuDA->isDivergentUse(*U); - } - return DivergentValues.count(U->get()) || DivergentUses.count(U); -} - -void LegacyDivergenceAnalysisImpl::print(raw_ostream &OS, - const Module *) const { - if ((!gpuDA || !gpuDA->hasDivergence()) && DivergentValues.empty()) - return; - - const Function *F = nullptr; - if (!DivergentValues.empty()) { - const Value *FirstDivergentValue = *DivergentValues.begin(); - if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) { - F = Arg->getParent(); - } else if (const Instruction *I = - dyn_cast<Instruction>(FirstDivergentValue)) { - F = I->getParent()->getParent(); - } else { - llvm_unreachable("Only arguments and instructions can be divergent"); - } - } else if (gpuDA) { - F = &gpuDA->getFunction(); - } - if (!F) - return; - - // Dumps all divergent values in F, arguments and then instructions. - for (const auto &Arg : F->args()) { - OS << (isDivergent(&Arg) ? "DIVERGENT: " : " "); - OS << Arg << "\n"; - } - // Iterate instructions using instructions() to ensure a deterministic order. - for (const BasicBlock &BB : *F) { - OS << "\n " << BB.getName() << ":\n"; - for (const auto &I : BB.instructionsWithoutDebug()) { - OS << (isDivergent(&I) ? "DIVERGENT: " : " "); - OS << I << "\n"; - } - } - OS << "\n"; -} - -void LegacyDivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequiredTransitive<DominatorTreeWrapperPass>(); - AU.addRequiredTransitive<PostDominatorTreeWrapperPass>(); - AU.addRequiredTransitive<LoopInfoWrapperPass>(); - AU.setPreservesAll(); -} - -bool LegacyDivergenceAnalysis::runOnFunction(Function &F) { - auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); - if (TTIWP == nullptr) - return false; - - TargetTransformInfo &TTI = TTIWP->getTTI(F); - // Fast path: if the target does not have branch divergence, we do not mark - // any branch as divergent. 
- if (!TTI.hasBranchDivergence()) - return false; - - DivergentValues.clear(); - DivergentUses.clear(); - gpuDA = nullptr; - - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LegacyDivergenceAnalysisImpl::run(F, TTI, DT, PDT, LI); - LLVM_DEBUG(dbgs() << "\nAfter divergence analysis on " << F.getName() - << ":\n"; - LegacyDivergenceAnalysisImpl::print(dbgs(), F.getParent())); - - return false; -} - -PreservedAnalyses -LegacyDivergenceAnalysisPass::run(Function &F, FunctionAnalysisManager &AM) { - auto &TTI = AM.getResult<TargetIRAnalysis>(F); - if (!TTI.hasBranchDivergence()) - return PreservedAnalyses::all(); - - DivergentValues.clear(); - DivergentUses.clear(); - gpuDA = nullptr; - - auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); - auto &LI = AM.getResult<LoopAnalysis>(F); - LegacyDivergenceAnalysisImpl::run(F, TTI, DT, PDT, LI); - LLVM_DEBUG(dbgs() << "\nAfter divergence analysis on " << F.getName() - << ":\n"; - LegacyDivergenceAnalysisImpl::print(dbgs(), F.getParent())); - return PreservedAnalyses::all(); -} diff --git a/contrib/llvm-project/llvm/lib/Analysis/Lint.cpp b/contrib/llvm-project/llvm/lib/Analysis/Lint.cpp index d3120a41ac27..ff022006df65 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/Lint.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/Lint.cpp @@ -40,11 +40,14 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/ScopedNoAliasAA.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TypeBasedAliasAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -60,13 +63,10 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" @@ -93,8 +93,6 @@ class Lint : public InstVisitor<Lint> { void visitCallBase(CallBase &CB); void visitMemoryReference(Instruction &I, const MemoryLocation &Loc, MaybeAlign Alignment, Type *Ty, unsigned Flags); - void visitEHBeginCatch(IntrinsicInst *II); - void visitEHEndCatch(IntrinsicInst *II); void visitReturnInst(ReturnInst &I); void visitLoadInst(LoadInst &I); @@ -715,73 +713,35 @@ PreservedAnalyses LintPass::run(Function &F, FunctionAnalysisManager &AM) { return PreservedAnalyses::all(); } -namespace { -class LintLegacyPass : public FunctionPass { -public: - static char ID; // Pass identification, replacement for typeid - LintLegacyPass() : FunctionPass(ID) { - initializeLintLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - 
AU.addRequired<DominatorTreeWrapperPass>(); - } - void print(raw_ostream &O, const Module *M) const override {} -}; -} // namespace - -char LintLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LintLegacyPass, "lint", "Statically lint-checks LLVM IR", - false, true) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(LintLegacyPass, "lint", "Statically lint-checks LLVM IR", - false, true) - -bool LintLegacyPass::runOnFunction(Function &F) { - auto *Mod = F.getParent(); - auto *DL = &F.getParent()->getDataLayout(); - auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - Lint L(Mod, DL, AA, AC, DT, TLI); - L.visit(F); - dbgs() << L.MessagesStr.str(); - return false; -} - //===----------------------------------------------------------------------===// // Implement the public interfaces to this file... //===----------------------------------------------------------------------===// -FunctionPass *llvm::createLintLegacyPassPass() { return new LintLegacyPass(); } - /// lintFunction - Check a function for errors, printing messages on stderr. /// void llvm::lintFunction(const Function &f) { Function &F = const_cast<Function &>(f); assert(!F.isDeclaration() && "Cannot lint external functions"); - legacy::FunctionPassManager FPM(F.getParent()); - auto *V = new LintLegacyPass(); - FPM.add(V); - FPM.run(F); + FunctionAnalysisManager FAM; + FAM.registerPass([&] { return TargetLibraryAnalysis(); }); + FAM.registerPass([&] { return DominatorTreeAnalysis(); }); + FAM.registerPass([&] { return AssumptionAnalysis(); }); + FAM.registerPass([&] { + AAManager AA; + AA.registerFunctionAnalysis<BasicAA>(); + AA.registerFunctionAnalysis<ScopedNoAliasAA>(); + AA.registerFunctionAnalysis<TypeBasedAA>(); + return AA; + }); + LintPass().run(F, FAM); } /// lintModule - Check a module for errors, printing messages on stderr. /// void llvm::lintModule(const Module &M) { - legacy::PassManager PM; - auto *V = new LintLegacyPass(); - PM.add(V); - PM.run(const_cast<Module &>(M)); + for (const Function &F : M) { + if (!F.isDeclaration()) + lintFunction(F); + } } diff --git a/contrib/llvm-project/llvm/lib/Analysis/Loads.cpp b/contrib/llvm-project/llvm/lib/Analysis/Loads.cpp index f55333303f8d..97d21db86abf 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/Loads.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/Loads.cpp @@ -29,9 +29,7 @@ using namespace llvm; static bool isAligned(const Value *Base, const APInt &Offset, Align Alignment, const DataLayout &DL) { Align BA = Base->getPointerAlignment(DL); - const APInt APAlign(Offset.getBitWidth(), Alignment.value()); - assert(APAlign.isPowerOf2() && "must be a power of 2!"); - return BA >= Alignment && !(Offset & (APAlign - 1)); + return BA >= Alignment && Offset.isAligned(BA); } /// Test if V is always a pointer to allocated and suitably aligned memory for @@ -204,7 +202,7 @@ bool llvm::isDereferenceableAndAlignedPointer( const TargetLibraryInfo *TLI) { // For unsized types or scalable vectors we don't know exactly how many bytes // are dereferenced, so bail out. 
- if (!Ty->isSized() || isa<ScalableVectorType>(Ty)) + if (!Ty->isSized() || Ty->isScalableTy()) return false; // When dereferenceability information is provided by a dereferenceable @@ -286,21 +284,48 @@ bool llvm::isDereferenceableAndAlignedInLoop(LoadInst *LI, Loop *L, auto* Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(SE)); if (!Step) return false; - // TODO: generalize to access patterns which have gaps - if (Step->getAPInt() != EltSize) - return false; auto TC = SE.getSmallConstantMaxTripCount(L); if (!TC) return false; - const APInt AccessSize = TC * EltSize; + // TODO: Handle overlapping accesses. + // We should be computing AccessSize as (TC - 1) * Step + EltSize. + if (EltSize.sgt(Step->getAPInt())) + return false; + + // Compute the total access size for access patterns with unit stride and + // patterns with gaps. For patterns with unit stride, Step and EltSize are the + // same. + // For patterns with gaps (i.e. non unit stride), we are + // accessing EltSize bytes at every Step. + APInt AccessSize = TC * Step->getAPInt(); + + assert(SE.isLoopInvariant(AddRec->getStart(), L) && + "implied by addrec definition"); + Value *Base = nullptr; + if (auto *StartS = dyn_cast<SCEVUnknown>(AddRec->getStart())) { + Base = StartS->getValue(); + } else if (auto *StartS = dyn_cast<SCEVAddExpr>(AddRec->getStart())) { + // Handle (NewBase + offset) as start value. + const auto *Offset = dyn_cast<SCEVConstant>(StartS->getOperand(0)); + const auto *NewBase = dyn_cast<SCEVUnknown>(StartS->getOperand(1)); + if (StartS->getNumOperands() == 2 && Offset && NewBase) { + // For the moment, restrict ourselves to the case where the offset is a + // multiple of the requested alignment and the base is aligned. + // TODO: generalize if a case found which warrants + if (Offset->getAPInt().urem(Alignment.value()) != 0) + return false; + Base = NewBase->getValue(); + bool Overflow = false; + AccessSize = AccessSize.uadd_ov(Offset->getAPInt(), Overflow); + if (Overflow) + return false; + } + } - auto *StartS = dyn_cast<SCEVUnknown>(AddRec->getStart()); - if (!StartS) + if (!Base) return false; - assert(SE.isLoopInvariant(StartS, L) && "implied by addrec definition"); - Value *Base = StartS->getValue(); // For the moment, restrict ourselves to the case where the access size is a // multiple of the requested alignment and the base is aligned. @@ -653,7 +678,7 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, AAResults &AA, // Try to find an available value first, and delay expensive alias analysis // queries until later. 
- Value *Available = nullptr;; + Value *Available = nullptr; SmallVector<Instruction *> MustNotAliasInsts; for (Instruction &Inst : make_range(++Load->getReverseIterator(), ScanBB->rend())) { diff --git a/contrib/llvm-project/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 9e110567e98e..fd0e81c51ac8 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -43,6 +43,7 @@ #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" @@ -136,37 +137,37 @@ static cl::opt<unsigned> MaxForkedSCEVDepth( cl::desc("Maximum recursion depth when finding forked SCEVs (default = 5)"), cl::init(5)); +static cl::opt<bool> SpeculateUnitStride( + "laa-speculate-unit-stride", cl::Hidden, + cl::desc("Speculate that non-constant strides are unit in LAA"), + cl::init(true)); + bool VectorizerParams::isInterleaveForced() { return ::VectorizationInterleave.getNumOccurrences() > 0; } -Value *llvm::stripIntegerCast(Value *V) { - if (auto *CI = dyn_cast<CastInst>(V)) - if (CI->getOperand(0)->getType()->isIntegerTy()) - return CI->getOperand(0); - return V; -} - const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE, - const ValueToValueMap &PtrToStride, + const DenseMap<Value *, const SCEV *> &PtrToStride, Value *Ptr) { const SCEV *OrigSCEV = PSE.getSCEV(Ptr); // If there is an entry in the map return the SCEV of the pointer with the // symbolic stride replaced by one. - ValueToValueMap::const_iterator SI = PtrToStride.find(Ptr); + DenseMap<Value *, const SCEV *>::const_iterator SI = PtrToStride.find(Ptr); if (SI == PtrToStride.end()) // For a non-symbolic stride, just return the original expression. return OrigSCEV; - Value *StrideVal = stripIntegerCast(SI->second); + const SCEV *StrideSCEV = SI->second; + // Note: This assert is both overly strong and overly weak. The actual + // invariant here is that StrideSCEV should be loop invariant. The only + // such invariant strides we happen to speculate right now are unknowns + // and thus this is a reasonable proxy of the actual invariant. + assert(isa<SCEVUnknown>(StrideSCEV) && "shouldn't be in map"); ScalarEvolution *SE = PSE.getSE(); - const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal)); - const auto *CT = - static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType())); - - PSE.addPredicate(*SE->getEqualPredicate(U, CT)); + const auto *CT = SE->getOne(StrideSCEV->getType()); + PSE.addPredicate(*SE->getEqualPredicate(StrideSCEV, CT)); auto *Expr = PSE.getSCEV(Ptr); LLVM_DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV @@ -231,6 +232,9 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr, ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd); } } + assert(SE->isLoopInvariant(ScStart, Lp) && "ScStart needs to be invariant"); + assert(SE->isLoopInvariant(ScEnd, Lp)&& "ScEnd needs to be invariant"); + // Add the size of the pointed element to ScEnd. auto &DL = Lp->getHeader()->getModule()->getDataLayout(); Type *IdxTy = DL.getIndexType(Ptr->getType()); @@ -652,7 +656,7 @@ public: /// the bounds of the pointer. 
bool createCheckForAccess(RuntimePointerChecking &RtCheck, MemAccessInfo Access, Type *AccessTy, - const ValueToValueMap &Strides, + const DenseMap<Value *, const SCEV *> &Strides, DenseMap<Value *, unsigned> &DepSetId, Loop *TheLoop, unsigned &RunningDepId, unsigned ASId, bool ShouldCheckStride, bool Assume); @@ -663,7 +667,7 @@ public: /// Returns true if we need no check or if we do and we can generate them /// (i.e. the pointers have computable bounds). bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, - Loop *TheLoop, const ValueToValueMap &Strides, + Loop *TheLoop, const DenseMap<Value *, const SCEV *> &Strides, Value *&UncomputablePtr, bool ShouldCheckWrap = false); /// Goes over all memory accesses, checks whether a RT check is needed @@ -758,7 +762,7 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr, /// Check whether a pointer address cannot wrap. static bool isNoWrap(PredicatedScalarEvolution &PSE, - const ValueToValueMap &Strides, Value *Ptr, Type *AccessTy, + const DenseMap<Value *, const SCEV *> &Strides, Value *Ptr, Type *AccessTy, Loop *L) { const SCEV *PtrScev = PSE.getSCEV(Ptr); if (PSE.getSE()->isLoopInvariant(PtrScev, L)) @@ -951,7 +955,7 @@ static void findForkedSCEVs( static SmallVector<PointerIntPair<const SCEV *, 1, bool>> findForkedPointer(PredicatedScalarEvolution &PSE, - const ValueToValueMap &StridesMap, Value *Ptr, + const DenseMap<Value *, const SCEV *> &StridesMap, Value *Ptr, const Loop *L) { ScalarEvolution *SE = PSE.getSE(); assert(SE->isSCEVable(Ptr->getType()) && "Value is not SCEVable!"); @@ -976,7 +980,7 @@ findForkedPointer(PredicatedScalarEvolution &PSE, bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck, MemAccessInfo Access, Type *AccessTy, - const ValueToValueMap &StridesMap, + const DenseMap<Value *, const SCEV *> &StridesMap, DenseMap<Value *, unsigned> &DepSetId, Loop *TheLoop, unsigned &RunningDepId, unsigned ASId, bool ShouldCheckWrap, @@ -1037,7 +1041,7 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck, bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, Loop *TheLoop, - const ValueToValueMap &StridesMap, + const DenseMap<Value *, const SCEV *> &StridesMap, Value *&UncomputablePtr, bool ShouldCheckWrap) { // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check. @@ -1311,20 +1315,18 @@ void AccessAnalysis::processMemAccesses() { } } -static bool isInBoundsGep(Value *Ptr) { - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr)) - return GEP->isInBounds(); - return false; -} - /// Return true if an AddRec pointer \p Ptr is unsigned non-wrapping, /// i.e. monotonically increasing/decreasing. static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, PredicatedScalarEvolution &PSE, const Loop *L) { + // FIXME: This should probably only return true for NUW. if (AR->getNoWrapFlags(SCEV::NoWrapMask)) return true; + if (PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW)) + return true; + // Scalar evolution does not propagate the non-wrapping flags to values that // are derived from a non-wrapping induction variable because non-wrapping // could be flow-sensitive. 
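[Editor's aside — illustrative only, not part of the upstream diff.] For context on the stride plumbing being reworked in this file: the symbolic-stride map (now keyed by Value* to a const SCEV* rather than a stripped Value*) feeds getPtrStride, which returns the access stride in units of the access type when the pointer is a non-wrapping affine AddRec in the loop. A hedged sketch of how a consumer might interpret the result; the helper name and the surrounding PSE/loop/map objects are assumed from context:

#include "llvm/Analysis/LoopAccessAnalysis.h"
using namespace llvm;

// Hypothetical helper: classify one pointer's access pattern in TheLoop.
static void classifyAccess(PredicatedScalarEvolution &PSE, Type *AccessTy,
                           Value *Ptr, const Loop *TheLoop,
                           const DenseMap<Value *, const SCEV *> &SymbolicStrides) {
  std::optional<int64_t> Stride =
      getPtrStride(PSE, AccessTy, Ptr, TheLoop, SymbolicStrides,
                   /*Assume=*/false, /*ShouldCheckWrap=*/true);
  if (!Stride) {
    // Not a provably strided, non-wrapping AddRec in TheLoop; a vectorizer
    // would fall back to runtime pointer checks or gather/scatter.
  } else if (*Stride == 1 || *Stride == -1) {
    // Consecutive (forward or reverse) element accesses per iteration.
  } else {
    // Constant-strided access; the gap is measured in elements of AccessTy.
  }
}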
@@ -1369,7 +1371,7 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, std::optional<int64_t> llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap, + const DenseMap<Value *, const SCEV *> &StridesMap, bool Assume, bool ShouldCheckWrap) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); @@ -1399,35 +1401,6 @@ std::optional<int64_t> llvm::getPtrStride(PredicatedScalarEvolution &PSE, return std::nullopt; } - // The address calculation must not wrap. Otherwise, a dependence could be - // inverted. - // An inbounds getelementptr that is a AddRec with a unit stride - // cannot wrap per definition. The unit stride requirement is checked later. - // An getelementptr without an inbounds attribute and unit stride would have - // to access the pointer value "0" which is undefined behavior in address - // space 0, therefore we can also vectorize this case. - unsigned AddrSpace = Ty->getPointerAddressSpace(); - bool IsInBoundsGEP = isInBoundsGep(Ptr); - bool IsNoWrapAddRec = !ShouldCheckWrap || - PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) || - isNoWrapAddRec(Ptr, AR, PSE, Lp); - if (!IsNoWrapAddRec && !IsInBoundsGEP && - NullPointerIsDefined(Lp->getHeader()->getParent(), AddrSpace)) { - if (Assume) { - PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); - IsNoWrapAddRec = true; - LLVM_DEBUG(dbgs() << "LAA: Pointer may wrap in the address space:\n" - << "LAA: Pointer: " << *Ptr << "\n" - << "LAA: SCEV: " << *AR << "\n" - << "LAA: Added an overflow assumption\n"); - } else { - LLVM_DEBUG( - dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " - << *Ptr << " SCEV: " << *AR << "\n"); - return std::nullopt; - } - } - // Check the step is constant. const SCEV *Step = AR->getStepRecurrence(*PSE.getSE()); @@ -1456,25 +1429,42 @@ std::optional<int64_t> llvm::getPtrStride(PredicatedScalarEvolution &PSE, if (Rem) return std::nullopt; - // If the SCEV could wrap but we have an inbounds gep with a unit stride we - // know we can't "wrap around the address space". In case of address space - // zero we know that this won't happen without triggering undefined behavior. - if (!IsNoWrapAddRec && Stride != 1 && Stride != -1 && - (IsInBoundsGEP || !NullPointerIsDefined(Lp->getHeader()->getParent(), - AddrSpace))) { - if (Assume) { - // We can avoid this case by adding a run-time check. - LLVM_DEBUG(dbgs() << "LAA: Non unit strided pointer which is not either " - << "inbounds or in address space 0 may wrap:\n" - << "LAA: Pointer: " << *Ptr << "\n" - << "LAA: SCEV: " << *AR << "\n" - << "LAA: Added an overflow assumption\n"); - PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); - } else - return std::nullopt; - } + if (!ShouldCheckWrap) + return Stride; + + // The address calculation must not wrap. Otherwise, a dependence could be + // inverted. + if (isNoWrapAddRec(Ptr, AR, PSE, Lp)) + return Stride; - return Stride; + // An inbounds getelementptr that is a AddRec with a unit stride + // cannot wrap per definition. If it did, the result would be poison + // and any memory access dependent on it would be immediate UB + // when executed. + if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr); + GEP && GEP->isInBounds() && (Stride == 1 || Stride == -1)) + return Stride; + + // If the null pointer is undefined, then a access sequence which would + // otherwise access it can be assumed not to unsigned wrap. 
Note that this + // assumes the object in memory is aligned to the natural alignment. + unsigned AddrSpace = Ty->getPointerAddressSpace(); + if (!NullPointerIsDefined(Lp->getHeader()->getParent(), AddrSpace) && + (Stride == 1 || Stride == -1)) + return Stride; + + if (Assume) { + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + LLVM_DEBUG(dbgs() << "LAA: Pointer may wrap:\n" + << "LAA: Pointer: " << *Ptr << "\n" + << "LAA: SCEV: " << *AR << "\n" + << "LAA: Added an overflow assumption\n"); + return Stride; + } + LLVM_DEBUG( + dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " + << *Ptr << " SCEV: " << *AR << "\n"); + return std::nullopt; } std::optional<int> llvm::getPointersDiff(Type *ElemTyA, Value *PtrA, @@ -1483,10 +1473,6 @@ std::optional<int> llvm::getPointersDiff(Type *ElemTyA, Value *PtrA, ScalarEvolution &SE, bool StrictCheck, bool CheckType) { assert(PtrA && PtrB && "Expected non-nullptr pointers."); - assert(cast<PointerType>(PtrA->getType()) - ->isOpaqueOrPointeeTypeMatches(ElemTyA) && "Wrong PtrA type"); - assert(cast<PointerType>(PtrB->getType()) - ->isOpaqueOrPointeeTypeMatches(ElemTyB) && "Wrong PtrB type"); // Make sure that A and B are different pointers. if (PtrA == PtrB) @@ -1830,7 +1816,7 @@ static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride, MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, const MemAccessInfo &B, unsigned BIdx, - const ValueToValueMap &Strides) { + const DenseMap<Value *, const SCEV *> &Strides) { assert (AIdx < BIdx && "Must pass arguments in program order"); auto [APtr, AIsWrite] = A; @@ -2024,7 +2010,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, MemAccessInfoList &CheckDeps, - const ValueToValueMap &Strides) { + const DenseMap<Value *, const SCEV *> &Strides) { MaxSafeDepDistBytes = -1; SmallPtrSet<MemAccessInfo, 8> Visited; @@ -2303,7 +2289,7 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI, for (StoreInst *ST : Stores) { Value *Ptr = ST->getPointerOperand(); - if (isUniform(Ptr)) { + if (isInvariant(Ptr)) { // Record store instructions to loop invariant addresses StoresToInvariantAddresses.push_back(ST); HasDependenceInvolvingLoopInvariantAddress |= @@ -2545,15 +2531,151 @@ OptimizationRemarkAnalysis &LoopAccessInfo::recordAnalysis(StringRef RemarkName, return *Report; } -bool LoopAccessInfo::isUniform(Value *V) const { +bool LoopAccessInfo::isInvariant(Value *V) const { auto *SE = PSE->getSE(); - // Since we rely on SCEV for uniformity, if the type is not SCEVable, it is - // never considered uniform. // TODO: Is this really what we want? Even without FP SCEV, we may want some - // trivially loop-invariant FP values to be considered uniform. + // trivially loop-invariant FP values to be considered invariant. if (!SE->isSCEVable(V->getType())) return false; - return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop)); + const SCEV *S = SE->getSCEV(V); + return SE->isLoopInvariant(S, TheLoop); +} + +/// Find the operand of the GEP that should be checked for consecutive +/// stores. This ignores trailing indices that have no effect on the final +/// pointer. 
+static unsigned getGEPInductionOperand(const GetElementPtrInst *Gep) { + const DataLayout &DL = Gep->getModule()->getDataLayout(); + unsigned LastOperand = Gep->getNumOperands() - 1; + TypeSize GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType()); + + // Walk backwards and try to peel off zeros. + while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) { + // Find the type we're currently indexing into. + gep_type_iterator GEPTI = gep_type_begin(Gep); + std::advance(GEPTI, LastOperand - 2); + + // If it's a type with the same allocation size as the result of the GEP we + // can peel off the zero index. + if (DL.getTypeAllocSize(GEPTI.getIndexedType()) != GEPAllocSize) + break; + --LastOperand; + } + + return LastOperand; +} + +/// If the argument is a GEP, then returns the operand identified by +/// getGEPInductionOperand. However, if there is some other non-loop-invariant +/// operand, it returns that instead. +static Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); + if (!GEP) + return Ptr; + + unsigned InductionOperand = getGEPInductionOperand(GEP); + + // Check that all of the gep indices are uniform except for our induction + // operand. + for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) + if (i != InductionOperand && + !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp)) + return Ptr; + return GEP->getOperand(InductionOperand); +} + +/// If a value has only one user that is a CastInst, return it. +static Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) { + Value *UniqueCast = nullptr; + for (User *U : Ptr->users()) { + CastInst *CI = dyn_cast<CastInst>(U); + if (CI && CI->getType() == Ty) { + if (!UniqueCast) + UniqueCast = CI; + else + return nullptr; + } + } + return UniqueCast; +} + +/// Get the stride of a pointer access in a loop. Looks for symbolic +/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. +static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { + auto *PtrTy = dyn_cast<PointerType>(Ptr->getType()); + if (!PtrTy || PtrTy->isAggregateType()) + return nullptr; + + // Try to remove a gep instruction to make the pointer (actually index at this + // point) easier analyzable. If OrigPtr is equal to Ptr we are analyzing the + // pointer, otherwise, we are analyzing the index. + Value *OrigPtr = Ptr; + + // The size of the pointer access. + int64_t PtrAccessSize = 1; + + Ptr = stripGetElementPtr(Ptr, SE, Lp); + const SCEV *V = SE->getSCEV(Ptr); + + if (Ptr != OrigPtr) + // Strip off casts. + while (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V)) + V = C->getOperand(); + + const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V); + if (!S) + return nullptr; + + // If the pointer is invariant then there is no stride and it makes no + // sense to add it here. + if (Lp != S->getLoop()) + return nullptr; + + V = S->getStepRecurrence(*SE); + if (!V) + return nullptr; + + // Strip off the size of access multiplication if we are still analyzing the + // pointer. + if (OrigPtr == Ptr) { + if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { + if (M->getOperand(0)->getSCEVType() != scConstant) + return nullptr; + + const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt(); + + // Huge step value - give up. 
+ if (APStepVal.getBitWidth() > 64) + return nullptr; + + int64_t StepVal = APStepVal.getSExtValue(); + if (PtrAccessSize != StepVal) + return nullptr; + V = M->getOperand(1); + } + } + + // Note that the restriction after this loop invariant check are only + // profitability restrictions. + if (!SE->isLoopInvariant(V, Lp)) + return nullptr; + + // Look for the loop invariant symbolic value. + const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V); + if (!U) { + const auto *C = dyn_cast<SCEVIntegralCastExpr>(V); + if (!C) + return nullptr; + U = dyn_cast<SCEVUnknown>(C->getOperand()); + if (!U) + return nullptr; + + // Match legacy behavior - this is not needed for correctness + if (!getUniqueCastUse(U->getValue(), Lp, V->getType())) + return nullptr; + } + + return V; } void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { @@ -2561,13 +2683,24 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { if (!Ptr) return; - Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop); - if (!Stride) + // Note: getStrideFromPointer is a *profitability* heuristic. We + // could broaden the scope of values returned here - to anything + // which happens to be loop invariant and contributes to the + // computation of an interesting IV - but we chose not to as we + // don't have a cost model here, and broadening the scope exposes + // far too many unprofitable cases. + const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop); + if (!StrideExpr) return; LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for " "versioning:"); - LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n"); + LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *StrideExpr << "\n"); + + if (!SpeculateUnitStride) { + LLVM_DEBUG(dbgs() << " Chose not to due to -laa-speculate-unit-stride\n"); + return; + } // Avoid adding the "Stride == 1" predicate when we know that // Stride >= Trip-Count. Such a predicate will effectively optimize a single @@ -2582,7 +2715,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { // of various possible stride specializations, considering the alternatives // of using gather/scatters (if available). - const SCEV *StrideExpr = PSE->getSCEV(Stride); const SCEV *BETakenCount = PSE->getBackedgeTakenCount(); // Match the types so we can compare the stride and the BETakenCount. @@ -2611,8 +2743,12 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { } LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n"); - SymbolicStrides[Ptr] = Stride; - StrideSet.insert(Stride); + // Strip back off the integer cast, and check that our result is a + // SCEVUnknown as we expect. + const SCEV *StrideBase = StrideExpr; + if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase)) + StrideBase = C->getOperand(); + SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase); } LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, @@ -2680,55 +2816,32 @@ const LoopAccessInfo &LoopAccessInfoManager::getInfo(Loop &L) { return *I.first->second; } -LoopAccessLegacyAnalysis::LoopAccessLegacyAnalysis() : FunctionPass(ID) { - initializeLoopAccessLegacyAnalysisPass(*PassRegistry::getPassRegistry()); -} - -bool LoopAccessLegacyAnalysis::runOnFunction(Function &F) { - auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - auto *TLI = TLIP ? 
&TLIP->getTLI(F) : nullptr; - auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LAIs = std::make_unique<LoopAccessInfoManager>(SE, AA, DT, LI, TLI); - return false; -} - -void LoopAccessLegacyAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequiredTransitive<ScalarEvolutionWrapperPass>(); - AU.addRequiredTransitive<AAResultsWrapperPass>(); - AU.addRequiredTransitive<DominatorTreeWrapperPass>(); - AU.addRequiredTransitive<LoopInfoWrapperPass>(); +bool LoopAccessInfoManager::invalidate( + Function &F, const PreservedAnalyses &PA, + FunctionAnalysisManager::Invalidator &Inv) { + // Check whether our analysis is preserved. + auto PAC = PA.getChecker<LoopAccessAnalysis>(); + if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<Function>>()) + // If not, give up now. + return true; - AU.setPreservesAll(); + // Check whether the analyses we depend on became invalid for any reason. + // Skip checking TargetLibraryAnalysis as it is immutable and can't become + // invalid. + return Inv.invalidate<AAManager>(F, PA) || + Inv.invalidate<ScalarEvolutionAnalysis>(F, PA) || + Inv.invalidate<LoopAnalysis>(F, PA) || + Inv.invalidate<DominatorTreeAnalysis>(F, PA); } LoopAccessInfoManager LoopAccessAnalysis::run(Function &F, - FunctionAnalysisManager &AM) { - return LoopAccessInfoManager( - AM.getResult<ScalarEvolutionAnalysis>(F), AM.getResult<AAManager>(F), - AM.getResult<DominatorTreeAnalysis>(F), AM.getResult<LoopAnalysis>(F), - &AM.getResult<TargetLibraryAnalysis>(F)); + FunctionAnalysisManager &FAM) { + auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F); + auto &AA = FAM.getResult<AAManager>(F); + auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); + auto &LI = FAM.getResult<LoopAnalysis>(F); + auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + return LoopAccessInfoManager(SE, AA, DT, LI, &TLI); } -char LoopAccessLegacyAnalysis::ID = 0; -static const char laa_name[] = "Loop Access Analysis"; -#define LAA_NAME "loop-accesses" - -INITIALIZE_PASS_BEGIN(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) - AnalysisKey LoopAccessAnalysis::Key; - -namespace llvm { - - Pass *createLAAPass() { - return new LoopAccessLegacyAnalysis(); - } - -} // end namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/LoopCacheAnalysis.cpp index 46198f78b643..c3a56639b5c8 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -297,7 +297,7 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, Type *WiderType = SE.getWiderType(Stride->getType(), TripCount->getType()); const SCEV *CacheLineSize = SE.getConstant(WiderType, CLS); Stride = SE.getNoopOrAnyExtend(Stride, WiderType); - TripCount = SE.getNoopOrAnyExtend(TripCount, WiderType); + TripCount = SE.getNoopOrZeroExtend(TripCount, WiderType); const SCEV *Numerator = SE.getMulExpr(Stride, TripCount); RefCost = SE.getUDivExpr(Numerator, CacheLineSize); @@ -323,8 +323,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, const SCEV *TripCount 
= computeTripCount(*AR->getLoop(), *Sizes.back(), SE); Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType()); - RefCost = SE.getMulExpr(SE.getNoopOrAnyExtend(RefCost, WiderType), - SE.getNoopOrAnyExtend(TripCount, WiderType)); + RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType), + SE.getNoopOrZeroExtend(TripCount, WiderType)); } LLVM_DEBUG(dbgs().indent(4) @@ -334,7 +334,7 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, // Attempt to fold RefCost into a constant. if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost)) - return ConstantCost->getValue()->getSExtValue(); + return ConstantCost->getValue()->getZExtValue(); LLVM_DEBUG(dbgs().indent(4) << "RefCost is not a constant! Setting to RefCost=InvalidCost " diff --git a/contrib/llvm-project/llvm/lib/Analysis/LoopInfo.cpp b/contrib/llvm-project/llvm/lib/Analysis/LoopInfo.cpp index 69bcbcb11203..60a72079e864 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/LoopInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/LoopInfo.cpp @@ -17,7 +17,6 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/IVDescriptors.h" -#include "llvm/Analysis/LoopInfoImpl.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopNestAnalysis.h" #include "llvm/Analysis/MemorySSA.h" @@ -36,6 +35,7 @@ #include "llvm/IR/PrintPasses.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/GenericLoopInfoImpl.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -737,7 +737,7 @@ void UnloopUpdater::updateBlockParents() { bool Changed = FoundIB; for (unsigned NIters = 0; Changed; ++NIters) { assert(NIters < Unloop.getNumBlocks() && "runaway iterative algorithm"); - (void) NIters; + (void)NIters; // Iterate over the postorder list of blocks, propagating the nearest loop // from successors to predecessors as before. @@ -929,9 +929,8 @@ void LoopInfo::erase(Loop *Unloop) { } } -bool -LoopInfo::wouldBeOutOfLoopUseRequiringLCSSA(const Value *V, - const BasicBlock *ExitBB) const { +bool LoopInfo::wouldBeOutOfLoopUseRequiringLCSSA( + const Value *V, const BasicBlock *ExitBB) const { if (V->getType()->isTokenTy()) // We can't form PHIs of token type, so the definition of LCSSA excludes // values of that type. diff --git a/contrib/llvm-project/llvm/lib/Analysis/MLInlineAdvisor.cpp b/contrib/llvm-project/llvm/lib/Analysis/MLInlineAdvisor.cpp index a20c05243b77..0660a9993b6d 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/MLInlineAdvisor.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/MLInlineAdvisor.cpp @@ -18,10 +18,12 @@ #include "llvm/Analysis/FunctionPropertiesAnalysis.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InlineModelFeatureMaps.h" +#include "llvm/Analysis/InteractiveModelRunner.h" #include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MLModelRunner.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ReleaseModeModelRunner.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" @@ -30,19 +32,50 @@ using namespace llvm; +static cl::opt<std::string> InteractiveChannelBaseName( + "inliner-interactive-channel-base", cl::Hidden, + cl::desc( + "Base file path for the interactive mode. 
The incoming filename should " + "have the name <inliner-interactive-channel-base>.in, while the " + "outgoing name should be <inliner-interactive-channel-base>.out")); +static const std::string InclDefaultMsg = + (Twine("In interactive mode, also send the default policy decision: ") + + DefaultDecisionName + ".") + .str(); +static cl::opt<bool> + InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden, + cl::desc(InclDefaultMsg)); + #if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL) -#include "llvm/Analysis/ReleaseModeModelRunner.h" // codegen-ed file #include "InlinerSizeModel.h" // NOLINT +using CompiledModelType = llvm::InlinerSizeModel; +#else +using CompiledModelType = NoopSavedModelImpl; +#endif std::unique_ptr<InlineAdvisor> -llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM) { - auto AOTRunner = - std::make_unique<ReleaseModeModelRunner<llvm::InlinerSizeModel>>( - M.getContext(), FeatureMap, DecisionName); - return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner)); +llvm::getReleaseModeAdvisor(Module &M, ModuleAnalysisManager &MAM, + std::function<bool(CallBase &)> GetDefaultAdvice) { + if (!llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() && + InteractiveChannelBaseName.empty()) + return nullptr; + std::unique_ptr<MLModelRunner> AOTRunner; + if (InteractiveChannelBaseName.empty()) + AOTRunner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>( + M.getContext(), FeatureMap, DecisionName); + else { + auto Features = FeatureMap; + if (InteractiveIncludeDefault) + Features.push_back(DefaultDecisionSpec); + AOTRunner = std::make_unique<InteractiveModelRunner>( + M.getContext(), Features, InlineDecisionSpec, + InteractiveChannelBaseName + ".out", + InteractiveChannelBaseName + ".in"); + } + return std::make_unique<MLInlineAdvisor>(M, MAM, std::move(AOTRunner), + GetDefaultAdvice); } -#endif #define DEBUG_TYPE "inline-ml" @@ -59,21 +92,23 @@ static cl::opt<bool> KeepFPICache( cl::init(false)); // clang-format off -const std::array<TensorSpec, NumberOfFeatures> llvm::FeatureMap{ -#define POPULATE_NAMES(_, NAME) TensorSpec::createSpec<int64_t>(NAME, {1} ), +const std::vector<TensorSpec> llvm::FeatureMap{ +#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE), // InlineCost features - these must come first INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES) -#undef POPULATE_NAMES // Non-cost features -#define POPULATE_NAMES(_, NAME, __) TensorSpec::createSpec<int64_t>(NAME, {1} ), INLINE_FEATURE_ITERATOR(POPULATE_NAMES) #undef POPULATE_NAMES }; // clang-format on const char *const llvm::DecisionName = "inlining_decision"; +const TensorSpec llvm::InlineDecisionSpec = + TensorSpec::createSpec<int64_t>(DecisionName, {1}); const char *const llvm::DefaultDecisionName = "inlining_default"; +const TensorSpec llvm::DefaultDecisionSpec = + TensorSpec::createSpec<int64_t>(DefaultDecisionName, {1}); const char *const llvm::RewardName = "delta_size"; CallBase *getInlinableCS(Instruction &I) { @@ -86,15 +121,17 @@ CallBase *getInlinableCS(Instruction &I) { return nullptr; } -MLInlineAdvisor::MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, - std::unique_ptr<MLModelRunner> Runner) +MLInlineAdvisor::MLInlineAdvisor( + Module &M, ModuleAnalysisManager &MAM, + std::unique_ptr<MLModelRunner> Runner, + std::function<bool(CallBase &)> GetDefaultAdvice) : InlineAdvisor( M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()), - ModelRunner(std::move(Runner)), + ModelRunner(std::move(Runner)), 
GetDefaultAdvice(GetDefaultAdvice), CG(MAM.getResult<LazyCallGraphAnalysis>(M)), InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize) { assert(ModelRunner); - + ModelRunner->switchContext(""); // Extract the 'call site height' feature - the position of a call site // relative to the farthest statically reachable SCC node. We don't mutate // this value while inlining happens. Empirically, this feature proved @@ -344,26 +381,27 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) { auto &CallerBefore = getCachedFPI(Caller); auto &CalleeBefore = getCachedFPI(Callee); - *ModelRunner->getTensor<int64_t>(FeatureIndex::CalleeBasicBlockCount) = + *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) = CalleeBefore.BasicBlockCount; - *ModelRunner->getTensor<int64_t>(FeatureIndex::CallSiteHeight) = + *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) = getInitialFunctionLevel(Caller); - *ModelRunner->getTensor<int64_t>(FeatureIndex::NodeCount) = NodeCount; - *ModelRunner->getTensor<int64_t>(FeatureIndex::NrCtantParams) = NrCtantParams; - *ModelRunner->getTensor<int64_t>(FeatureIndex::EdgeCount) = EdgeCount; - *ModelRunner->getTensor<int64_t>(FeatureIndex::CallerUsers) = + *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount; + *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) = + NrCtantParams; + *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount; + *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) = CallerBefore.Uses; *ModelRunner->getTensor<int64_t>( - FeatureIndex::CallerConditionallyExecutedBlocks) = + FeatureIndex::caller_conditionally_executed_blocks) = CallerBefore.BlocksReachedFromConditionalInstruction; - *ModelRunner->getTensor<int64_t>(FeatureIndex::CallerBasicBlockCount) = + *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) = CallerBefore.BasicBlockCount; *ModelRunner->getTensor<int64_t>( - FeatureIndex::CalleeConditionallyExecutedBlocks) = + FeatureIndex::callee_conditionally_executed_blocks) = CalleeBefore.BlocksReachedFromConditionalInstruction; - *ModelRunner->getTensor<int64_t>(FeatureIndex::CalleeUsers) = + *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) = CalleeBefore.Uses; - *ModelRunner->getTensor<int64_t>(FeatureIndex::CostEstimate) = CostEstimate; + *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate; // Add the cost features for (size_t I = 0; @@ -371,7 +409,10 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) { *ModelRunner->getTensor<int64_t>(inlineCostFeatureToMlFeature( static_cast<InlineCostFeatureIndex>(I))) = CostFeatures->at(I); } - + // This one would have been set up to be right at the end. + if (!InteractiveChannelBaseName.empty() && InteractiveIncludeDefault) + *ModelRunner->getTensor<int64_t>(InlineCostFeatureIndex::NumberOfFeatures) = + GetDefaultAdvice(CB); return getAdviceFromModel(CB, ORE); } diff --git a/contrib/llvm-project/llvm/lib/Analysis/MemDepPrinter.cpp b/contrib/llvm-project/llvm/lib/Analysis/MemDepPrinter.cpp deleted file mode 100644 index 305ae3e2a992..000000000000 --- a/contrib/llvm-project/llvm/lib/Analysis/MemDepPrinter.cpp +++ /dev/null @@ -1,164 +0,0 @@ -//===- MemDepPrinter.cpp - Printer for MemoryDependenceAnalysis -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/SetVector.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/MemoryDependenceAnalysis.h" -#include "llvm/Analysis/Passes.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" -#include "llvm/InitializePasses.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/raw_ostream.h" - -using namespace llvm; - -namespace { - struct MemDepPrinter : public FunctionPass { - const Function *F; - - enum DepType { - Clobber = 0, - Def, - NonFuncLocal, - Unknown - }; - - static const char *const DepTypeStr[]; - - typedef PointerIntPair<const Instruction *, 2, DepType> InstTypePair; - typedef std::pair<InstTypePair, const BasicBlock *> Dep; - typedef SmallSetVector<Dep, 4> DepSet; - typedef DenseMap<const Instruction *, DepSet> DepSetMap; - DepSetMap Deps; - - static char ID; // Pass identifcation, replacement for typeid - MemDepPrinter() : FunctionPass(ID) { - initializeMemDepPrinterPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void print(raw_ostream &OS, const Module * = nullptr) const override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequiredTransitive<AAResultsWrapperPass>(); - AU.addRequiredTransitive<MemoryDependenceWrapperPass>(); - AU.setPreservesAll(); - } - - void releaseMemory() override { - Deps.clear(); - F = nullptr; - } - - private: - static InstTypePair getInstTypePair(MemDepResult dep) { - if (dep.isClobber()) - return InstTypePair(dep.getInst(), Clobber); - if (dep.isDef()) - return InstTypePair(dep.getInst(), Def); - if (dep.isNonFuncLocal()) - return InstTypePair(dep.getInst(), NonFuncLocal); - assert(dep.isUnknown() && "unexpected dependence type"); - return InstTypePair(dep.getInst(), Unknown); - } - }; -} - -char MemDepPrinter::ID = 0; -INITIALIZE_PASS_BEGIN(MemDepPrinter, "print-memdeps", - "Print MemDeps of function", false, true) -INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) -INITIALIZE_PASS_END(MemDepPrinter, "print-memdeps", - "Print MemDeps of function", false, true) - -FunctionPass *llvm::createMemDepPrinter() { - return new MemDepPrinter(); -} - -const char *const MemDepPrinter::DepTypeStr[] - = {"Clobber", "Def", "NonFuncLocal", "Unknown"}; - -bool MemDepPrinter::runOnFunction(Function &F) { - this->F = &F; - MemoryDependenceResults &MDA = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); - - // All this code uses non-const interfaces because MemDep is not - // const-friendly, though nothing is actually modified. 
- for (auto &I : instructions(F)) { - Instruction *Inst = &I; - - if (!Inst->mayReadFromMemory() && !Inst->mayWriteToMemory()) - continue; - - MemDepResult Res = MDA.getDependency(Inst); - if (!Res.isNonLocal()) { - Deps[Inst].insert(std::make_pair(getInstTypePair(Res), - static_cast<BasicBlock *>(nullptr))); - } else if (auto *Call = dyn_cast<CallBase>(Inst)) { - const MemoryDependenceResults::NonLocalDepInfo &NLDI = - MDA.getNonLocalCallDependency(Call); - - DepSet &InstDeps = Deps[Inst]; - for (const NonLocalDepEntry &I : NLDI) { - const MemDepResult &Res = I.getResult(); - InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB())); - } - } else { - SmallVector<NonLocalDepResult, 4> NLDI; - assert( (isa<LoadInst>(Inst) || isa<StoreInst>(Inst) || - isa<VAArgInst>(Inst)) && "Unknown memory instruction!"); - MDA.getNonLocalPointerDependency(Inst, NLDI); - - DepSet &InstDeps = Deps[Inst]; - for (const NonLocalDepResult &I : NLDI) { - const MemDepResult &Res = I.getResult(); - InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB())); - } - } - } - - return false; -} - -void MemDepPrinter::print(raw_ostream &OS, const Module *M) const { - for (const auto &I : instructions(*F)) { - const Instruction *Inst = &I; - - DepSetMap::const_iterator DI = Deps.find(Inst); - if (DI == Deps.end()) - continue; - - const DepSet &InstDeps = DI->second; - - for (const auto &I : InstDeps) { - const Instruction *DepInst = I.first.getPointer(); - DepType type = I.first.getInt(); - const BasicBlock *DepBB = I.second; - - OS << " "; - OS << DepTypeStr[type]; - if (DepBB) { - OS << " in block "; - DepBB->printAsOperand(OS, /*PrintType=*/false, M); - } - if (DepInst) { - OS << " from: "; - DepInst->print(OS); - } - OS << "\n"; - } - - Inst->print(OS); - OS << "\n\n"; - } -} diff --git a/contrib/llvm-project/llvm/lib/Analysis/MemDerefPrinter.cpp b/contrib/llvm-project/llvm/lib/Analysis/MemDerefPrinter.cpp index 4dd5c76cc604..2632bc50d6e6 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/MemDerefPrinter.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/MemDerefPrinter.cpp @@ -18,65 +18,6 @@ using namespace llvm; -namespace { - struct MemDerefPrinter : public FunctionPass { - SmallVector<Value *, 4> Deref; - SmallPtrSet<Value *, 4> DerefAndAligned; - - static char ID; // Pass identification, replacement for typeid - MemDerefPrinter() : FunctionPass(ID) { - initializeMemDerefPrinterPass(*PassRegistry::getPassRegistry()); - } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - bool runOnFunction(Function &F) override; - void print(raw_ostream &OS, const Module * = nullptr) const override; - void releaseMemory() override { - Deref.clear(); - DerefAndAligned.clear(); - } - }; -} - -char MemDerefPrinter::ID = 0; -INITIALIZE_PASS_BEGIN(MemDerefPrinter, "print-memderefs", - "Memory Dereferenciblity of pointers in function", false, true) -INITIALIZE_PASS_END(MemDerefPrinter, "print-memderefs", - "Memory Dereferenciblity of pointers in function", false, true) - -FunctionPass *llvm::createMemDerefPrinter() { - return new MemDerefPrinter(); -} - -bool MemDerefPrinter::runOnFunction(Function &F) { - const DataLayout &DL = F.getParent()->getDataLayout(); - for (auto &I: instructions(F)) { - if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { - Value *PO = LI->getPointerOperand(); - if (isDereferenceablePointer(PO, LI->getType(), DL)) - Deref.push_back(PO); - if (isDereferenceableAndAlignedPointer(PO, LI->getType(), LI->getAlign(), - DL)) - DerefAndAligned.insert(PO); - } - } - return 
false; -} - -void MemDerefPrinter::print(raw_ostream &OS, const Module *M) const { - OS << "The following are dereferenceable:\n"; - for (Value *V: Deref) { - OS << " "; - V->print(OS); - if (DerefAndAligned.count(V)) - OS << "\t(aligned)"; - else - OS << "\t(unaligned)"; - OS << "\n"; - } -} - PreservedAnalyses MemDerefPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { OS << "Memory Dereferencibility of pointers in function '" << F.getName() diff --git a/contrib/llvm-project/llvm/lib/Analysis/MemoryBuiltins.cpp b/contrib/llvm-project/llvm/lib/Analysis/MemoryBuiltins.cpp index 0edad0557369..53e089ba1fea 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/MemoryBuiltins.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/MemoryBuiltins.cpp @@ -115,17 +115,25 @@ static const std::pair<LibFunc, AllocFnsTy> AllocationFnData[] = { {LibFunc_ZnwjSt11align_val_t, {OpNewLike, 2, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned int, align_val_t) {LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned int, align_val_t, nothrow) {LibFunc_Znwm, {OpNewLike, 1, 0, -1, -1, MallocFamily::CPPNew}}, // new(unsigned long) + {LibFunc_Znwm12__hot_cold_t, {OpNewLike, 2, 0, -1, -1, MallocFamily::CPPNew}}, // new(unsigned long, __hot_cold_t) {LibFunc_ZnwmRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1, MallocFamily::CPPNew}}, // new(unsigned long, nothrow) + {LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, {MallocLike, 3, 0, -1, -1, MallocFamily::CPPNew}}, // new(unsigned long, nothrow, __hot_cold_t) {LibFunc_ZnwmSt11align_val_t, {OpNewLike, 2, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned long, align_val_t) + {LibFunc_ZnwmSt11align_val_t12__hot_cold_t, {OpNewLike, 3, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned long, align_val_t, __hot_cold_t) {LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned long, align_val_t, nothrow) + {LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t, {MallocLike, 4, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new(unsigned long, align_val_t, nothrow, __hot_cold_t) {LibFunc_Znaj, {OpNewLike, 1, 0, -1, -1, MallocFamily::CPPNewArray}}, // new[](unsigned int) {LibFunc_ZnajRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1, MallocFamily::CPPNewArray}}, // new[](unsigned int, nothrow) {LibFunc_ZnajSt11align_val_t, {OpNewLike, 2, 0, -1, 1, MallocFamily::CPPNewArrayAligned}}, // new[](unsigned int, align_val_t) {LibFunc_ZnajSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1, MallocFamily::CPPNewArrayAligned}}, // new[](unsigned int, align_val_t, nothrow) {LibFunc_Znam, {OpNewLike, 1, 0, -1, -1, MallocFamily::CPPNewArray}}, // new[](unsigned long) + {LibFunc_Znam12__hot_cold_t, {OpNewLike, 2, 0, -1, -1, MallocFamily::CPPNew}}, // new[](unsigned long, __hot_cold_t) {LibFunc_ZnamRKSt9nothrow_t, {MallocLike, 2, 0, -1, -1, MallocFamily::CPPNewArray}}, // new[](unsigned long, nothrow) + {LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, {MallocLike, 3, 0, -1, -1, MallocFamily::CPPNew}}, // new[](unsigned long, nothrow, __hot_cold_t) {LibFunc_ZnamSt11align_val_t, {OpNewLike, 2, 0, -1, 1, MallocFamily::CPPNewArrayAligned}}, // new[](unsigned long, align_val_t) + {LibFunc_ZnamSt11align_val_t12__hot_cold_t, {OpNewLike, 3, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new[](unsigned long, align_val_t, __hot_cold_t) {LibFunc_ZnamSt11align_val_tRKSt9nothrow_t, {MallocLike, 3, 0, -1, 1, MallocFamily::CPPNewArrayAligned}}, // new[](unsigned long, align_val_t, nothrow) + 
{LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t, {MallocLike, 4, 0, -1, 1, MallocFamily::CPPNewAligned}}, // new[](unsigned long, align_val_t, nothrow, __hot_cold_t) {LibFunc_msvc_new_int, {OpNewLike, 1, 0, -1, -1, MallocFamily::MSVCNew}}, // new(unsigned int) {LibFunc_msvc_new_int_nothrow, {MallocLike, 2, 0, -1, -1, MallocFamily::MSVCNew}}, // new(unsigned int, nothrow) {LibFunc_msvc_new_longlong, {OpNewLike, 1, 0, -1, -1, MallocFamily::MSVCNew}}, // new(unsigned long long) @@ -594,10 +602,10 @@ Value *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize, MustSucceed); } -Value *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize, - const DataLayout &DL, - const TargetLibraryInfo *TLI, AAResults *AA, - bool MustSucceed) { +Value *llvm::lowerObjectSizeCall( + IntrinsicInst *ObjectSize, const DataLayout &DL, + const TargetLibraryInfo *TLI, AAResults *AA, bool MustSucceed, + SmallVectorImpl<Instruction *> *InsertedInstructions) { assert(ObjectSize->getIntrinsicID() == Intrinsic::objectsize && "ObjectSize must be a call to llvm.objectsize!"); @@ -632,7 +640,11 @@ Value *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize, Eval.compute(ObjectSize->getArgOperand(0)); if (SizeOffsetPair != ObjectSizeOffsetEvaluator::unknown()) { - IRBuilder<TargetFolder> Builder(Ctx, TargetFolder(DL)); + IRBuilder<TargetFolder, IRBuilderCallbackInserter> Builder( + Ctx, TargetFolder(DL), IRBuilderCallbackInserter([&](Instruction *I) { + if (InsertedInstructions) + InsertedInstructions->push_back(I); + })); Builder.SetInsertPoint(ObjectSize); // If we've outside the end of the object, then we can always access @@ -818,7 +830,9 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalAlias(GlobalAlias &GA) { } SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalVariable(GlobalVariable &GV){ - if (!GV.hasDefinitiveInitializer()) + if (!GV.getValueType()->isSized() || GV.hasExternalWeakLinkage() || + ((!GV.hasInitializer() || GV.isInterposable()) && + Options.EvalMode != ObjectSizeOpts::Mode::Min)) return unknown(); APInt Size(IntTyBits, DL.getTypeAllocSize(GV.getValueType())); @@ -976,6 +990,8 @@ SizeOffsetType ObjectSizeOffsetVisitor::combineSizeOffset(SizeOffsetType LHS, } SizeOffsetType ObjectSizeOffsetVisitor::visitPHINode(PHINode &PN) { + if (PN.getNumIncomingValues() == 0) + return unknown(); auto IncomingValues = PN.incoming_values(); return std::accumulate(IncomingValues.begin() + 1, IncomingValues.end(), compute(*IncomingValues.begin()), @@ -1099,12 +1115,13 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitAllocaInst(AllocaInst &I) { // must be a VLA assert(I.isArrayAllocation()); - // If needed, adjust the alloca's operand size to match the pointer size. - // Subsequent math operations expect the types to match. + // If needed, adjust the alloca's operand size to match the pointer indexing + // size. Subsequent math operations expect the types to match. 
Value *ArraySize = Builder.CreateZExtOrTrunc( - I.getArraySize(), DL.getIntPtrType(I.getContext())); + I.getArraySize(), + DL.getIndexType(I.getContext(), DL.getAllocaAddrSpace())); assert(ArraySize->getType() == Zero->getType() && - "Expected zero constant to have pointer type"); + "Expected zero constant to have pointer index type"); Value *Size = ConstantInt::get(ArraySize->getType(), DL.getTypeAllocSize(I.getAllocatedType())); diff --git a/contrib/llvm-project/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp index 93c388abb0fd..071ecdba8a54 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -1238,7 +1238,7 @@ bool MemoryDependenceResults::getNonLocalPointerDepFromBB( // phi translation to change it into a value live in the predecessor block. // If not, we just add the predecessors to the worklist and scan them with // the same Pointer. - if (!Pointer.NeedsPHITranslationFromBlock(BB)) { + if (!Pointer.needsPHITranslationFromBlock(BB)) { SkipFirstBlock = false; SmallVector<BasicBlock *, 16> NewBlocks; for (BasicBlock *Pred : PredCache.get(BB)) { @@ -1277,7 +1277,7 @@ bool MemoryDependenceResults::getNonLocalPointerDepFromBB( // We do need to do phi translation, if we know ahead of time we can't phi // translate this value, don't even try. - if (!Pointer.IsPotentiallyPHITranslatable()) + if (!Pointer.isPotentiallyPHITranslatable()) goto PredTranslationFailure; // We may have added values to the cache list before this PHI translation. @@ -1298,8 +1298,8 @@ bool MemoryDependenceResults::getNonLocalPointerDepFromBB( // Get the PHI translated pointer in this predecessor. This can fail if // not translatable, in which case the getAddr() returns null. PHITransAddr &PredPointer = PredList.back().second; - PredPointer.PHITranslateValue(BB, Pred, &DT, /*MustDominate=*/false); - Value *PredPtrVal = PredPointer.getAddr(); + Value *PredPtrVal = + PredPointer.translateValue(BB, Pred, &DT, /*MustDominate=*/false); // Check to see if we have already visited this pred block with another // pointer. If so, we can't do this lookup. This failure can occur diff --git a/contrib/llvm-project/llvm/lib/Analysis/MemoryLocation.cpp b/contrib/llvm-project/llvm/lib/Analysis/MemoryLocation.cpp index e839f9e0dfb2..0404b32be848 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/MemoryLocation.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/MemoryLocation.cpp @@ -257,7 +257,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call, case LibFunc_memset_chk: assert(ArgIdx == 0 && "Invalid argument index for memset_chk"); - LLVM_FALLTHROUGH; + [[fallthrough]]; case LibFunc_memcpy_chk: { assert((ArgIdx == 0 || ArgIdx == 1) && "Invalid argument index for memcpy_chk"); diff --git a/contrib/llvm-project/llvm/lib/Analysis/MemoryProfileInfo.cpp b/contrib/llvm-project/llvm/lib/Analysis/MemoryProfileInfo.cpp index 8ced1d2fd140..7fbcffc6489d 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/MemoryProfileInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/MemoryProfileInfo.cpp @@ -18,26 +18,47 @@ using namespace llvm::memprof; #define DEBUG_TYPE "memory-profile-info" -// Upper bound on accesses per byte for marking an allocation cold. 
-cl::opt<float> MemProfAccessesPerByteColdThreshold( - "memprof-accesses-per-byte-cold-threshold", cl::init(10.0), cl::Hidden, - cl::desc("The threshold the accesses per byte must be under to consider " - "an allocation cold")); +// Upper bound on lifetime access density (accesses per byte per lifetime sec) +// for marking an allocation cold. +cl::opt<float> MemProfLifetimeAccessDensityColdThreshold( + "memprof-lifetime-access-density-cold-threshold", cl::init(0.05), + cl::Hidden, + cl::desc("The threshold the lifetime access density (accesses per byte per " + "lifetime sec) must be under to consider an allocation cold")); // Lower bound on lifetime to mark an allocation cold (in addition to accesses -// per byte above). This is to avoid pessimizing short lived objects. -cl::opt<unsigned> MemProfMinLifetimeColdThreshold( - "memprof-min-lifetime-cold-threshold", cl::init(200), cl::Hidden, - cl::desc("The minimum lifetime (s) for an allocation to be considered " +// per byte per sec above). This is to avoid pessimizing short lived objects. +cl::opt<unsigned> MemProfAveLifetimeColdThreshold( + "memprof-ave-lifetime-cold-threshold", cl::init(200), cl::Hidden, + cl::desc("The average lifetime (s) for an allocation to be considered " "cold")); -AllocationType llvm::memprof::getAllocType(uint64_t MaxAccessCount, - uint64_t MinSize, - uint64_t MinLifetime) { - if (((float)MaxAccessCount) / MinSize < MemProfAccessesPerByteColdThreshold && - // MinLifetime is expected to be in ms, so convert the threshold to ms. - MinLifetime >= MemProfMinLifetimeColdThreshold * 1000) +// Lower bound on average lifetime accesses density (total life time access +// density / alloc count) for marking an allocation hot. +cl::opt<unsigned> MemProfMinAveLifetimeAccessDensityHotThreshold( + "memprof-min-ave-lifetime-access-density-hot-threshold", cl::init(1000), + cl::Hidden, + cl::desc("The minimum TotalLifetimeAccessDensity / AllocCount for an " + "allocation to be considered hot")); + +AllocationType llvm::memprof::getAllocType(uint64_t TotalLifetimeAccessDensity, + uint64_t AllocCount, + uint64_t TotalLifetime) { + // The access densities are multiplied by 100 to hold 2 decimal places of + // precision, so need to divide by 100. + if (((float)TotalLifetimeAccessDensity) / AllocCount / 100 < + MemProfLifetimeAccessDensityColdThreshold + // Lifetime is expected to be in ms, so convert the threshold to ms. + && ((float)TotalLifetime) / AllocCount >= + MemProfAveLifetimeColdThreshold * 1000) return AllocationType::Cold; + + // The access densities are multiplied by 100 to hold 2 decimal places of + // precision, so need to divide by 100. + if (((float)TotalLifetimeAccessDensity) / AllocCount / 100 > + MemProfMinAveLifetimeAccessDensityHotThreshold) + return AllocationType::Hot; + return AllocationType::NotCold; } @@ -65,12 +86,15 @@ AllocationType llvm::memprof::getMIBAllocType(const MDNode *MIB) { // types that can be applied based on the allocation profile data. 
auto *MDS = dyn_cast<MDString>(MIB->getOperand(1)); assert(MDS); - if (MDS->getString().equals("cold")) + if (MDS->getString().equals("cold")) { return AllocationType::Cold; + } else if (MDS->getString().equals("hot")) { + return AllocationType::Hot; + } return AllocationType::NotCold; } -static std::string getAllocTypeAttributeString(AllocationType Type) { +std::string llvm::memprof::getAllocTypeAttributeString(AllocationType Type) { switch (Type) { case AllocationType::NotCold: return "notcold"; @@ -78,6 +102,9 @@ static std::string getAllocTypeAttributeString(AllocationType Type) { case AllocationType::Cold: return "cold"; break; + case AllocationType::Hot: + return "hot"; + break; default: assert(false && "Unexpected alloc type"); } @@ -91,7 +118,7 @@ static void addAllocTypeAttribute(LLVMContext &Ctx, CallBase *CI, CI->addFnAttr(A); } -static bool hasSingleAllocType(uint8_t AllocTypes) { +bool llvm::memprof::hasSingleAllocType(uint8_t AllocTypes) { const unsigned NumAllocTypes = llvm::popcount(AllocTypes); assert(NumAllocTypes != 0); return NumAllocTypes == 1; @@ -242,3 +269,9 @@ CallStack<MDNode, MDNode::op_iterator>::CallStackIterator::operator*() { assert(StackIdCInt); return StackIdCInt->getZExtValue(); } + +template <> uint64_t CallStack<MDNode, MDNode::op_iterator>::back() const { + assert(N); + return mdconst::dyn_extract<ConstantInt>(N->operands().back()) + ->getZExtValue(); +} diff --git a/contrib/llvm-project/llvm/lib/Analysis/MemorySSA.cpp b/contrib/llvm-project/llvm/lib/Analysis/MemorySSA.cpp index aefb66863b8f..d16658028266 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/MemorySSA.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/MemorySSA.cpp @@ -71,12 +71,6 @@ INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, true) -INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa", - "Memory SSA Printer", false, false) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa", - "Memory SSA Printer", false, false) - static cl::opt<unsigned> MaxCheckLimit( "memssa-check-limit", cl::Hidden, cl::init(100), cl::desc("The maximum number of stores/phis MemorySSA" @@ -304,7 +298,6 @@ instructionClobbersQuery(const MemoryDef *MD, const MemoryLocation &UseLoc, case Intrinsic::experimental_noalias_scope_decl: case Intrinsic::pseudoprobe: return false; - case Intrinsic::dbg_addr: case Intrinsic::dbg_declare: case Intrinsic::dbg_label: case Intrinsic::dbg_value: @@ -371,7 +364,8 @@ struct UpwardsMemoryQuery { } // end anonymous namespace -static bool isUseTriviallyOptimizableToLiveOnEntry(BatchAAResults &AA, +template <typename AliasAnalysisType> +static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysisType &AA, const Instruction *I) { // If the memory can't be changed, then loads of the memory can't be // clobbered. 
@@ -1368,11 +1362,6 @@ void MemorySSA::OptimizeUses::optimizeUsesInBlock( if (MU->isOptimized()) continue; - if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) { - MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true); - continue; - } - MemoryLocOrCall UseMLOC(MU); auto &LocInfo = LocStackInfo[UseMLOC]; // If the pop epoch changed, it means we've removed stuff from top of @@ -1788,10 +1777,15 @@ MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I, return nullptr; MemoryUseOrDef *MUD; - if (Def) + if (Def) { MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++); - else + } else { MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent()); + if (isUseTriviallyOptimizableToLiveOnEntry(*AAP, I)) { + MemoryAccess *LiveOnEntry = getLiveOnEntryDef(); + MUD->setOptimized(LiveOnEntry); + } + } ValueToMemoryAccess[I] = MUD; return MUD; } @@ -2220,17 +2214,6 @@ void MemoryAccess::dump() const { #endif } -char MemorySSAPrinterLegacyPass::ID = 0; - -MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) { - initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); -} - -void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesAll(); - AU.addRequired<MemorySSAWrapperPass>(); -} - class DOTFuncMSSAInfo { private: const Function &F; @@ -2315,20 +2298,6 @@ struct DOTGraphTraits<DOTFuncMSSAInfo *> : public DefaultDOTGraphTraits { } // namespace llvm -bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) { - auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSA.ensureOptimizedUses(); - if (DotCFGMSSA != "") { - DOTFuncMSSAInfo CFGInfo(F, MSSA); - WriteGraph(&CFGInfo, "", false, "MSSA", DotCFGMSSA); - } else - MSSA.print(dbgs()); - - if (VerifyMemorySSA) - MSSA.verifyMemorySSA(); - return false; -} - AnalysisKey MemorySSAAnalysis::Key; MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F, @@ -2350,7 +2319,8 @@ bool MemorySSAAnalysis::Result::invalidate( PreservedAnalyses MemorySSAPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); - MSSA.ensureOptimizedUses(); + if (EnsureOptimizedUses) + MSSA.ensureOptimizedUses(); if (DotCFGMSSA != "") { DOTFuncMSSAInfo CFGInfo(F, MSSA); WriteGraph(&CFGInfo, "", false, "MSSA", DotCFGMSSA); diff --git a/contrib/llvm-project/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp b/contrib/llvm-project/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp index 756f92e1aac4..919f8f5c01d6 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp @@ -25,39 +25,6 @@ #include "llvm/Support/raw_ostream.h" using namespace llvm; -namespace { -class ModuleDebugInfoLegacyPrinter : public ModulePass { - DebugInfoFinder Finder; - -public: - static char ID; // Pass identification, replacement for typeid - ModuleDebugInfoLegacyPrinter() : ModulePass(ID) { - initializeModuleDebugInfoLegacyPrinterPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - void print(raw_ostream &O, const Module *M) const override; -}; -} - -char ModuleDebugInfoLegacyPrinter::ID = 0; -INITIALIZE_PASS(ModuleDebugInfoLegacyPrinter, "module-debuginfo", - "Decodes module-level debug info", false, true) - -ModulePass *llvm::createModuleDebugInfoPrinterPass() { - return new 
ModuleDebugInfoLegacyPrinter(); -} - -bool ModuleDebugInfoLegacyPrinter::runOnModule(Module &M) { - Finder.processModule(M); - return false; -} - static void printFile(raw_ostream &O, StringRef Filename, StringRef Directory, unsigned Line = 0) { if (Filename.empty()) @@ -132,11 +99,6 @@ static void printModuleDebugInfo(raw_ostream &O, const Module *M, } } -void ModuleDebugInfoLegacyPrinter::print(raw_ostream &O, - const Module *M) const { - printModuleDebugInfo(O, M, Finder); -} - ModuleDebugInfoPrinterPass::ModuleDebugInfoPrinterPass(raw_ostream &OS) : OS(OS) {} diff --git a/contrib/llvm-project/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp index 3dfa2d821e83..2076ed48ea34 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp @@ -80,6 +80,8 @@ static cl::opt<std::string> ModuleSummaryDotFile( "module-summary-dot-file", cl::Hidden, cl::value_desc("filename"), cl::desc("File to emit dot graph of new summary into")); +extern cl::opt<bool> ScalePartialSampleProfileWorkingSetSize; + // Walk through the operands of a given User via worklist iteration and populate // the set of GlobalValue references encountered. Invoked either on an // Instruction or a GlobalVariable (which walks its initializer). @@ -196,6 +198,7 @@ static void addIntrinsicToSummary( break; } + case Intrinsic::type_checked_load_relative: case Intrinsic::type_checked_load: { auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(2)); auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata()); @@ -263,7 +266,9 @@ static void computeFunctionSummary( unsigned NumInsts = 0; // Map from callee ValueId to profile count. Used to accumulate profile // counts for all static calls to a given callee. - MapVector<ValueInfo, CalleeInfo> CallGraphEdges; + MapVector<ValueInfo, CalleeInfo, DenseMap<ValueInfo, unsigned>, + std::vector<std::pair<ValueInfo, CalleeInfo>>> + CallGraphEdges; SetVector<ValueInfo> RefEdges, LoadRefEdges, StoreRefEdges; SetVector<GlobalValue::GUID> TypeTests; SetVector<FunctionSummary::VFuncId> TypeTestAssumeVCalls, @@ -282,6 +287,10 @@ static void computeFunctionSummary( std::vector<CallsiteInfo> Callsites; std::vector<AllocInfo> Allocs; +#ifndef NDEBUG + DenseSet<const CallBase *> CallsThatMayHaveMemprofSummary; +#endif + bool HasInlineAsmMaybeReferencingInternal = false; bool HasIndirBranchToBlockAddress = false; bool HasUnknownCall = false; @@ -425,6 +434,10 @@ static void computeFunctionSummary( .updateHotness(getHotness(Candidate.Count, PSI)); } + // Summarize memprof related metadata. This is only needed for ThinLTO. + if (!IsThinLTO) + continue; + // TODO: Skip indirect calls for now. Need to handle these better, likely // by creating multiple Callsites, one per target, then speculatively // devirtualize while applying clone info in the ThinLTO backends. This @@ -435,6 +448,14 @@ static void computeFunctionSummary( if (!CalledFunction) continue; + // Ensure we keep this analysis in sync with the handling in the ThinLTO + // backend (see MemProfContextDisambiguation::applyImport). Save this call + // so that we can skip it in checking the reverse case later. + assert(mayHaveMemprofSummary(CB)); +#ifndef NDEBUG + CallsThatMayHaveMemprofSummary.insert(CB); +#endif + // Compute the list of stack ids first (so we can trim them from the stack // ids on any MIBs). 
CallStack<MDNode, MDNode::op_iterator> InstCallsite( @@ -477,7 +498,9 @@ static void computeFunctionSummary( } } } - Index.addBlockCount(F.size()); + + if (PSI->hasPartialSampleProfile() && ScalePartialSampleProfileWorkingSetSize) + Index.addBlockCount(F.size()); std::vector<ValueInfo> Refs; if (IsThinLTO) { @@ -542,6 +565,25 @@ static void computeFunctionSummary( ? CalleeInfo::HotnessType::Cold : CalleeInfo::HotnessType::Critical); +#ifndef NDEBUG + // Make sure that all calls we decided could not have memprof summaries get a + // false value for mayHaveMemprofSummary, to ensure that this handling remains + // in sync with the ThinLTO backend handling. + if (IsThinLTO) { + for (const BasicBlock &BB : F) { + for (const Instruction &I : BB) { + const auto *CB = dyn_cast<CallBase>(&I); + if (!CB) + continue; + // We already checked these above. + if (CallsThatMayHaveMemprofSummary.count(CB)) + continue; + assert(!mayHaveMemprofSummary(CB)); + } + } + } +#endif + bool NonRenamableLocal = isNonRenamableLocal(F); bool NotEligibleForImport = NonRenamableLocal || HasInlineAsmMaybeReferencingInternal || @@ -583,12 +625,17 @@ static void findFuncPointers(const Constant *I, uint64_t StartingOffset, VTableFuncList &VTableFuncs) { // First check if this is a function pointer. if (I->getType()->isPointerTy()) { - auto Fn = dyn_cast<Function>(I->stripPointerCasts()); - // We can disregard __cxa_pure_virtual as a possible call target, as - // calls to pure virtuals are UB. - if (Fn && Fn->getName() != "__cxa_pure_virtual") - VTableFuncs.push_back({Index.getOrInsertValueInfo(Fn), StartingOffset}); - return; + auto C = I->stripPointerCasts(); + auto A = dyn_cast<GlobalAlias>(C); + if (isa<Function>(C) || (A && isa<Function>(A->getAliasee()))) { + auto GV = dyn_cast<GlobalValue>(C); + assert(GV); + // We can disregard __cxa_pure_virtual as a possible call target, as + // calls to pure virtuals are UB. + if (GV && GV->getName() != "__cxa_pure_virtual") + VTableFuncs.push_back({Index.getOrInsertValueInfo(GV), StartingOffset}); + return; + } } // Walk through the elements in the constant struct or array and recursively @@ -741,10 +788,14 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( std::function<const StackSafetyInfo *(const Function &F)> GetSSICallback) { assert(PSI); bool EnableSplitLTOUnit = false; + bool UnifiedLTO = false; if (auto *MD = mdconst::extract_or_null<ConstantInt>( M.getModuleFlag("EnableSplitLTOUnit"))) EnableSplitLTOUnit = MD->getZExtValue(); - ModuleSummaryIndex Index(/*HaveGVs=*/true, EnableSplitLTOUnit); + if (auto *MD = + mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("UnifiedLTO"))) + UnifiedLTO = MD->getZExtValue(); + ModuleSummaryIndex Index(/*HaveGVs=*/true, EnableSplitLTOUnit, UnifiedLTO); // Identify the local values in the llvm.used and llvm.compiler.used sets, // which should not be exported as they would then require renaming and @@ -1033,3 +1084,36 @@ ImmutablePass *llvm::createImmutableModuleSummaryIndexWrapperPass( INITIALIZE_PASS(ImmutableModuleSummaryIndexWrapperPass, "module-summary-info", "Module summary info", false, true) + +bool llvm::mayHaveMemprofSummary(const CallBase *CB) { + if (!CB) + return false; + if (CB->isDebugOrPseudoInst()) + return false; + auto *CI = dyn_cast<CallInst>(CB); + auto *CalledValue = CB->getCalledOperand(); + auto *CalledFunction = CB->getCalledFunction(); + if (CalledValue && !CalledFunction) { + CalledValue = CalledValue->stripPointerCasts(); + // Stripping pointer casts can reveal a called function. 
+ CalledFunction = dyn_cast<Function>(CalledValue); + } + // Check if this is an alias to a function. If so, get the + // called aliasee for the checks below. + if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) { + assert(!CalledFunction && + "Expected null called function in callsite for alias"); + CalledFunction = dyn_cast<Function>(GA->getAliaseeObject()); + } + // Check if this is a direct call to a known function or a known + // intrinsic, or an indirect call with profile data. + if (CalledFunction) { + if (CI && CalledFunction->isIntrinsic()) + return false; + } else { + // TODO: For now skip indirect calls. See comments in + // computeFunctionSummary for what is needed to handle this. + return false; + } + return true; +} diff --git a/contrib/llvm-project/llvm/lib/Analysis/MustExecute.cpp b/contrib/llvm-project/llvm/lib/Analysis/MustExecute.cpp index 2f68996e1c60..d4b31f2b0018 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/MustExecute.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/MustExecute.cpp @@ -309,101 +309,6 @@ bool ICFLoopSafetyInfo::doesNotWriteMemoryBefore(const Instruction &I, doesNotWriteMemoryBefore(BB, CurLoop); } -namespace { -struct MustExecutePrinter : public FunctionPass { - - static char ID; // Pass identification, replacement for typeid - MustExecutePrinter() : FunctionPass(ID) { - initializeMustExecutePrinterPass(*PassRegistry::getPassRegistry()); - } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - } - bool runOnFunction(Function &F) override; -}; -struct MustBeExecutedContextPrinter : public ModulePass { - static char ID; - - MustBeExecutedContextPrinter() : ModulePass(ID) { - initializeMustBeExecutedContextPrinterPass( - *PassRegistry::getPassRegistry()); - } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - bool runOnModule(Module &M) override; -}; -} - -char MustExecutePrinter::ID = 0; -INITIALIZE_PASS_BEGIN(MustExecutePrinter, "print-mustexecute", - "Instructions which execute on loop entry", false, true) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(MustExecutePrinter, "print-mustexecute", - "Instructions which execute on loop entry", false, true) - -FunctionPass *llvm::createMustExecutePrinter() { - return new MustExecutePrinter(); -} - -char MustBeExecutedContextPrinter::ID = 0; -INITIALIZE_PASS_BEGIN(MustBeExecutedContextPrinter, - "print-must-be-executed-contexts", - "print the must-be-executed-context for all instructions", - false, true) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(MustBeExecutedContextPrinter, - "print-must-be-executed-contexts", - "print the must-be-executed-context for all instructions", - false, true) - -ModulePass *llvm::createMustBeExecutedContextPrinter() { - return new MustBeExecutedContextPrinter(); -} - -bool MustBeExecutedContextPrinter::runOnModule(Module &M) { - // We provide non-PM analysis here because the old PM doesn't like to query - // function passes from a module pass. 
- SmallVector<std::unique_ptr<PostDominatorTree>, 8> PDTs; - SmallVector<std::unique_ptr<DominatorTree>, 8> DTs; - SmallVector<std::unique_ptr<LoopInfo>, 8> LIs; - - GetterTy<LoopInfo> LIGetter = [&](const Function &F) { - DTs.push_back(std::make_unique<DominatorTree>(const_cast<Function &>(F))); - LIs.push_back(std::make_unique<LoopInfo>(*DTs.back())); - return LIs.back().get(); - }; - GetterTy<DominatorTree> DTGetter = [&](const Function &F) { - DTs.push_back(std::make_unique<DominatorTree>(const_cast<Function&>(F))); - return DTs.back().get(); - }; - GetterTy<PostDominatorTree> PDTGetter = [&](const Function &F) { - PDTs.push_back( - std::make_unique<PostDominatorTree>(const_cast<Function &>(F))); - return PDTs.back().get(); - }; - MustBeExecutedContextExplorer Explorer( - /* ExploreInterBlock */ true, - /* ExploreCFGForward */ true, - /* ExploreCFGBackward */ true, LIGetter, DTGetter, PDTGetter); - - for (Function &F : M) { - for (Instruction &I : instructions(F)) { - dbgs() << "-- Explore context of: " << I << "\n"; - for (const Instruction *CI : Explorer.range(&I)) - dbgs() << " [F: " << CI->getFunction()->getName() << "] " << *CI - << "\n"; - } - } - - return false; -} - static bool isMustExecuteIn(const Instruction &I, Loop *L, DominatorTree *DT) { // TODO: merge these two routines. For the moment, we display the best // result obtained by *either* implementation. This is a bit unfair since no @@ -467,16 +372,6 @@ public: }; } // namespace -bool MustExecutePrinter::runOnFunction(Function &F) { - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - - MustExecuteAnnotatedWriter Writer(F, DT, LI); - F.print(dbgs(), &Writer); - - return false; -} - /// Return true if \p L might be an endless loop. static bool maybeEndlessLoop(const Loop &L) { if (L.getHeader()->getParent()->hasFnAttribute(Attribute::WillReturn)) diff --git a/contrib/llvm-project/llvm/lib/Analysis/PHITransAddr.cpp b/contrib/llvm-project/llvm/lib/Analysis/PHITransAddr.cpp index 1262530ae642..5700fd664a4c 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/PHITransAddr.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/PHITransAddr.cpp @@ -25,13 +25,8 @@ static cl::opt<bool> EnableAddPhiTranslation( "gvn-add-phi-translation", cl::init(false), cl::Hidden, cl::desc("Enable phi-translation of add instructions")); -static bool CanPHITrans(Instruction *Inst) { - if (isa<PHINode>(Inst) || - isa<GetElementPtrInst>(Inst)) - return true; - - if (isa<CastInst>(Inst) && - isSafeToSpeculativelyExecute(Inst)) +static bool canPHITrans(Instruction *Inst) { + if (isa<PHINode>(Inst) || isa<GetElementPtrInst>(Inst) || isa<CastInst>(Inst)) return true; if (Inst->getOpcode() == Instruction::Add && @@ -53,47 +48,42 @@ LLVM_DUMP_METHOD void PHITransAddr::dump() const { } #endif - -static bool VerifySubExpr(Value *Expr, - SmallVectorImpl<Instruction*> &InstInputs) { +static bool verifySubExpr(Value *Expr, + SmallVectorImpl<Instruction *> &InstInputs) { // If this is a non-instruction value, there is nothing to do. Instruction *I = dyn_cast<Instruction>(Expr); if (!I) return true; // If it's an instruction, it is either in Tmp or its operands recursively // are. - SmallVectorImpl<Instruction *>::iterator Entry = find(InstInputs, I); - if (Entry != InstInputs.end()) { + if (auto Entry = find(InstInputs, I); Entry != InstInputs.end()) { InstInputs.erase(Entry); return true; } // If it isn't in the InstInputs list it is a subexpr incorporated into the // address. 
Validate that it is phi translatable. - if (!CanPHITrans(I)) { + if (!canPHITrans(I)) { errs() << "Instruction in PHITransAddr is not phi-translatable:\n"; errs() << *I << '\n'; llvm_unreachable("Either something is missing from InstInputs or " - "CanPHITrans is wrong."); + "canPHITrans is wrong."); } // Validate the operands of the instruction. - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) - if (!VerifySubExpr(I->getOperand(i), InstInputs)) - return false; - - return true; + return all_of(I->operands(), + [&](Value *Op) { return verifySubExpr(Op, InstInputs); }); } -/// Verify - Check internal consistency of this data structure. If the +/// verify - Check internal consistency of this data structure. If the /// structure is valid, it returns true. If invalid, it prints errors and /// returns false. -bool PHITransAddr::Verify() const { +bool PHITransAddr::verify() const { if (!Addr) return true; SmallVector<Instruction*, 8> Tmp(InstInputs.begin(), InstInputs.end()); - if (!VerifySubExpr(Addr, Tmp)) + if (!verifySubExpr(Addr, Tmp)) return false; if (!Tmp.empty()) { @@ -107,26 +97,23 @@ bool PHITransAddr::Verify() const { return true; } - -/// IsPotentiallyPHITranslatable - If this needs PHI translation, return true +/// isPotentiallyPHITranslatable - If this needs PHI translation, return true /// if we have some hope of doing it. This should be used as a filter to /// avoid calling PHITranslateValue in hopeless situations. -bool PHITransAddr::IsPotentiallyPHITranslatable() const { +bool PHITransAddr::isPotentiallyPHITranslatable() const { // If the input value is not an instruction, or if it is not defined in CurBB, // then we don't need to phi translate it. Instruction *Inst = dyn_cast<Instruction>(Addr); - return !Inst || CanPHITrans(Inst); + return !Inst || canPHITrans(Inst); } - static void RemoveInstInputs(Value *V, SmallVectorImpl<Instruction*> &InstInputs) { Instruction *I = dyn_cast<Instruction>(V); if (!I) return; // If the instruction is in the InstInputs list, remove it. - SmallVectorImpl<Instruction *>::iterator Entry = find(InstInputs, I); - if (Entry != InstInputs.end()) { + if (auto Entry = find(InstInputs, I); Entry != InstInputs.end()) { InstInputs.erase(Entry); return; } @@ -134,15 +121,14 @@ static void RemoveInstInputs(Value *V, assert(!isa<PHINode>(I) && "Error, removing something that isn't an input"); // Otherwise, it must have instruction inputs itself. Zap them recursively. - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - if (Instruction *Op = dyn_cast<Instruction>(I->getOperand(i))) - RemoveInstInputs(Op, InstInputs); - } + for (Value *Op : I->operands()) + if (Instruction *OpInst = dyn_cast<Instruction>(Op)) + RemoveInstInputs(OpInst, InstInputs); } -Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, - BasicBlock *PredBB, - const DominatorTree *DT) { +Value *PHITransAddr::translateSubExpr(Value *V, BasicBlock *CurBB, + BasicBlock *PredBB, + const DominatorTree *DT) { // If this is a non-instruction value, it can't require PHI translation. Instruction *Inst = dyn_cast<Instruction>(V); if (!Inst) return V; @@ -166,18 +152,17 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, // If this is a PHI, go ahead and translate it. 
if (PHINode *PN = dyn_cast<PHINode>(Inst)) - return AddAsInput(PN->getIncomingValueForBlock(PredBB)); + return addAsInput(PN->getIncomingValueForBlock(PredBB)); // If this is a non-phi value, and it is analyzable, we can incorporate it // into the expression by making all instruction operands be inputs. - if (!CanPHITrans(Inst)) + if (!canPHITrans(Inst)) return nullptr; // All instruction operands are now inputs (and of course, they may also be // defined in this block, so they may need to be phi translated themselves. - for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i) - if (Instruction *Op = dyn_cast<Instruction>(Inst->getOperand(i))) - InstInputs.push_back(Op); + for (Value *Op : Inst->operands()) + addAsInput(Op); } // Ok, it must be an intermediate result (either because it started that way @@ -185,18 +170,19 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, // operands need to be phi translated, and if so, reconstruct it. if (CastInst *Cast = dyn_cast<CastInst>(Inst)) { - if (!isSafeToSpeculativelyExecute(Cast)) return nullptr; - Value *PHIIn = PHITranslateSubExpr(Cast->getOperand(0), CurBB, PredBB, DT); + Value *PHIIn = translateSubExpr(Cast->getOperand(0), CurBB, PredBB, DT); if (!PHIIn) return nullptr; if (PHIIn == Cast->getOperand(0)) return Cast; // Find an available version of this cast. - // Constants are trivial to find. - if (Constant *C = dyn_cast<Constant>(PHIIn)) - return AddAsInput(ConstantExpr::getCast(Cast->getOpcode(), - C, Cast->getType())); + // Try to simplify cast first. + if (Value *V = simplifyCastInst(Cast->getOpcode(), PHIIn, Cast->getType(), + {DL, TLI, DT, AC})) { + RemoveInstInputs(PHIIn, InstInputs); + return addAsInput(V); + } // Otherwise we have to see if a casted version of the incoming pointer // is available. If so, we can use it, otherwise we have to fail. @@ -214,11 +200,11 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) { SmallVector<Value*, 8> GEPOps; bool AnyChanged = false; - for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) { - Value *GEPOp = PHITranslateSubExpr(GEP->getOperand(i), CurBB, PredBB, DT); + for (Value *Op : GEP->operands()) { + Value *GEPOp = translateSubExpr(Op, CurBB, PredBB, DT); if (!GEPOp) return nullptr; - AnyChanged |= GEPOp != GEP->getOperand(i); + AnyChanged |= GEPOp != Op; GEPOps.push_back(GEPOp); } @@ -232,7 +218,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, for (unsigned i = 0, e = GEPOps.size(); i != e; ++i) RemoveInstInputs(GEPOps[i], InstInputs); - return AddAsInput(V); + return addAsInput(V); } // Scan to see if we have this GEP available. @@ -259,7 +245,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, bool isNSW = cast<BinaryOperator>(Inst)->hasNoSignedWrap(); bool isNUW = cast<BinaryOperator>(Inst)->hasNoUnsignedWrap(); - Value *LHS = PHITranslateSubExpr(Inst->getOperand(0), CurBB, PredBB, DT); + Value *LHS = translateSubExpr(Inst->getOperand(0), CurBB, PredBB, DT); if (!LHS) return nullptr; // If the PHI translated LHS is an add of a constant, fold the immediates. @@ -273,7 +259,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, // If the old 'LHS' was an input, add the new 'LHS' as an input. 
if (is_contained(InstInputs, BOp)) { RemoveInstInputs(BOp, InstInputs); - AddAsInput(LHS); + addAsInput(LHS); } } @@ -282,7 +268,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, // If we simplified the operands, the LHS is no longer an input, but Res // is. RemoveInstInputs(LHS, InstInputs); - return AddAsInput(Res); + return addAsInput(Res); } // If we didn't modify the add, just return it. @@ -306,21 +292,19 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, return nullptr; } - /// PHITranslateValue - PHI translate the current address up the CFG from /// CurBB to Pred, updating our state to reflect any needed changes. If -/// 'MustDominate' is true, the translated value must dominate -/// PredBB. This returns true on failure and sets Addr to null. -bool PHITransAddr::PHITranslateValue(BasicBlock *CurBB, BasicBlock *PredBB, - const DominatorTree *DT, - bool MustDominate) { +/// 'MustDominate' is true, the translated value must dominate PredBB. +Value *PHITransAddr::translateValue(BasicBlock *CurBB, BasicBlock *PredBB, + const DominatorTree *DT, + bool MustDominate) { assert(DT || !MustDominate); - assert(Verify() && "Invalid PHITransAddr!"); + assert(verify() && "Invalid PHITransAddr!"); if (DT && DT->isReachableFromEntry(PredBB)) - Addr = PHITranslateSubExpr(Addr, CurBB, PredBB, DT); + Addr = translateSubExpr(Addr, CurBB, PredBB, DT); else Addr = nullptr; - assert(Verify() && "Invalid PHITransAddr!"); + assert(verify() && "Invalid PHITransAddr!"); if (MustDominate) // Make sure the value is live in the predecessor. @@ -328,7 +312,7 @@ bool PHITransAddr::PHITranslateValue(BasicBlock *CurBB, BasicBlock *PredBB, if (!DT->dominates(Inst->getParent(), PredBB)) Addr = nullptr; - return Addr == nullptr; + return Addr; } /// PHITranslateWithInsertion - PHI translate this value into the specified @@ -338,14 +322,14 @@ bool PHITransAddr::PHITranslateValue(BasicBlock *CurBB, BasicBlock *PredBB, /// All newly created instructions are added to the NewInsts list. This /// returns null on failure. /// -Value *PHITransAddr:: -PHITranslateWithInsertion(BasicBlock *CurBB, BasicBlock *PredBB, - const DominatorTree &DT, - SmallVectorImpl<Instruction*> &NewInsts) { +Value * +PHITransAddr::translateWithInsertion(BasicBlock *CurBB, BasicBlock *PredBB, + const DominatorTree &DT, + SmallVectorImpl<Instruction *> &NewInsts) { unsigned NISize = NewInsts.size(); // Attempt to PHI translate with insertion. - Addr = InsertPHITranslatedSubExpr(Addr, CurBB, PredBB, DT, NewInsts); + Addr = insertTranslatedSubExpr(Addr, CurBB, PredBB, DT, NewInsts); // If successful, return the new value. if (Addr) return Addr; @@ -356,21 +340,20 @@ PHITranslateWithInsertion(BasicBlock *CurBB, BasicBlock *PredBB, return nullptr; } - -/// InsertPHITranslatedPointer - Insert a computation of the PHI translated +/// insertTranslatedSubExpr - Insert a computation of the PHI translated /// version of 'V' for the edge PredBB->CurBB into the end of the PredBB /// block. All newly created instructions are added to the NewInsts list. /// This returns null on failure. /// -Value *PHITransAddr:: -InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, - BasicBlock *PredBB, const DominatorTree &DT, - SmallVectorImpl<Instruction*> &NewInsts) { +Value *PHITransAddr::insertTranslatedSubExpr( + Value *InVal, BasicBlock *CurBB, BasicBlock *PredBB, + const DominatorTree &DT, SmallVectorImpl<Instruction *> &NewInsts) { // See if we have a version of this value already available and dominating // PredBB. 
If so, there is no need to insert a new instance of it. PHITransAddr Tmp(InVal, DL, AC); - if (!Tmp.PHITranslateValue(CurBB, PredBB, &DT, /*MustDominate=*/true)) - return Tmp.getAddr(); + if (Value *Addr = + Tmp.translateValue(CurBB, PredBB, &DT, /*MustDominate=*/true)) + return Addr; // We don't need to PHI translate values which aren't instructions. auto *Inst = dyn_cast<Instruction>(InVal); @@ -379,9 +362,8 @@ InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, // Handle cast of PHI translatable value. if (CastInst *Cast = dyn_cast<CastInst>(Inst)) { - if (!isSafeToSpeculativelyExecute(Cast)) return nullptr; - Value *OpVal = InsertPHITranslatedSubExpr(Cast->getOperand(0), - CurBB, PredBB, DT, NewInsts); + Value *OpVal = insertTranslatedSubExpr(Cast->getOperand(0), CurBB, PredBB, + DT, NewInsts); if (!OpVal) return nullptr; // Otherwise insert a cast at the end of PredBB. @@ -397,9 +379,8 @@ InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) { SmallVector<Value*, 8> GEPOps; BasicBlock *CurBB = GEP->getParent(); - for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) { - Value *OpVal = InsertPHITranslatedSubExpr(GEP->getOperand(i), - CurBB, PredBB, DT, NewInsts); + for (Value *Op : GEP->operands()) { + Value *OpVal = insertTranslatedSubExpr(Op, CurBB, PredBB, DT, NewInsts); if (!OpVal) return nullptr; GEPOps.push_back(OpVal); } @@ -422,8 +403,8 @@ InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, // This needs to be evaluated carefully to consider its cost trade offs. // PHI translate the LHS. - Value *OpVal = InsertPHITranslatedSubExpr(Inst->getOperand(0), - CurBB, PredBB, DT, NewInsts); + Value *OpVal = insertTranslatedSubExpr(Inst->getOperand(0), CurBB, PredBB, + DT, NewInsts); if (OpVal == nullptr) return nullptr; diff --git a/contrib/llvm-project/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/contrib/llvm-project/llvm/lib/Analysis/ProfileSummaryInfo.cpp index 6b9f15bf2f64..203f1e42733f 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ProfileSummaryInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ProfileSummaryInfo.cpp @@ -95,129 +95,11 @@ std::optional<uint64_t> ProfileSummaryInfo::getProfileCount( return std::nullopt; } -/// Returns true if the function's entry is hot. If it returns false, it -/// either means it is not hot or it is unknown whether it is hot or not (for -/// example, no profile data is available). -bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) const { - if (!F || !hasProfileSummary()) - return false; - auto FunctionCount = F->getEntryCount(); - // FIXME: The heuristic used below for determining hotness is based on - // preliminary SPEC tuning for inliner. This will eventually be a - // convenience method that calls isHotCount. - return FunctionCount && isHotCount(FunctionCount->getCount()); -} - -/// Returns true if the function contains hot code. This can include a hot -/// function entry count, hot basic block, or (in the case of Sample PGO) -/// hot total call edge count. -/// If it returns false, it either means it is not hot or it is unknown -/// (for example, no profile data is available). 
-bool ProfileSummaryInfo::isFunctionHotInCallGraph( - const Function *F, BlockFrequencyInfo &BFI) const { - if (!F || !hasProfileSummary()) - return false; - if (auto FunctionCount = F->getEntryCount()) - if (isHotCount(FunctionCount->getCount())) - return true; - - if (hasSampleProfile()) { - uint64_t TotalCallCount = 0; - for (const auto &BB : *F) - for (const auto &I : BB) - if (isa<CallInst>(I) || isa<InvokeInst>(I)) - if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr)) - TotalCallCount += *CallCount; - if (isHotCount(TotalCallCount)) - return true; - } - for (const auto &BB : *F) - if (isHotBlock(&BB, &BFI)) - return true; - return false; -} - -/// Returns true if the function only contains cold code. This means that -/// the function entry and blocks are all cold, and (in the case of Sample PGO) -/// the total call edge count is cold. -/// If it returns false, it either means it is not cold or it is unknown -/// (for example, no profile data is available). -bool ProfileSummaryInfo::isFunctionColdInCallGraph( - const Function *F, BlockFrequencyInfo &BFI) const { - if (!F || !hasProfileSummary()) - return false; - if (auto FunctionCount = F->getEntryCount()) - if (!isColdCount(FunctionCount->getCount())) - return false; - - if (hasSampleProfile()) { - uint64_t TotalCallCount = 0; - for (const auto &BB : *F) - for (const auto &I : BB) - if (isa<CallInst>(I) || isa<InvokeInst>(I)) - if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr)) - TotalCallCount += *CallCount; - if (!isColdCount(TotalCallCount)) - return false; - } - for (const auto &BB : *F) - if (!isColdBlock(&BB, &BFI)) - return false; - return true; -} - bool ProfileSummaryInfo::isFunctionHotnessUnknown(const Function &F) const { assert(hasPartialSampleProfile() && "Expect partial sample profile"); return !F.getEntryCount(); } -template <bool isHot> -bool ProfileSummaryInfo::isFunctionHotOrColdInCallGraphNthPercentile( - int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) const { - if (!F || !hasProfileSummary()) - return false; - if (auto FunctionCount = F->getEntryCount()) { - if (isHot && - isHotCountNthPercentile(PercentileCutoff, FunctionCount->getCount())) - return true; - if (!isHot && - !isColdCountNthPercentile(PercentileCutoff, FunctionCount->getCount())) - return false; - } - if (hasSampleProfile()) { - uint64_t TotalCallCount = 0; - for (const auto &BB : *F) - for (const auto &I : BB) - if (isa<CallInst>(I) || isa<InvokeInst>(I)) - if (auto CallCount = getProfileCount(cast<CallBase>(I), nullptr)) - TotalCallCount += *CallCount; - if (isHot && isHotCountNthPercentile(PercentileCutoff, TotalCallCount)) - return true; - if (!isHot && !isColdCountNthPercentile(PercentileCutoff, TotalCallCount)) - return false; - } - for (const auto &BB : *F) { - if (isHot && isHotBlockNthPercentile(PercentileCutoff, &BB, &BFI)) - return true; - if (!isHot && !isColdBlockNthPercentile(PercentileCutoff, &BB, &BFI)) - return false; - } - return !isHot; -} - -// Like isFunctionHotInCallGraph but for a given cutoff. 
-bool ProfileSummaryInfo::isFunctionHotInCallGraphNthPercentile( - int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) const { - return isFunctionHotOrColdInCallGraphNthPercentile<true>( - PercentileCutoff, F, BFI); -} - -bool ProfileSummaryInfo::isFunctionColdInCallGraphNthPercentile( - int PercentileCutoff, const Function *F, BlockFrequencyInfo &BFI) const { - return isFunctionHotOrColdInCallGraphNthPercentile<false>( - PercentileCutoff, F, BFI); -} - /// Returns true if the function's entry is a cold. If it returns false, it /// either means it is not cold or it is unknown whether it is cold or not (for /// example, no profile data is available). @@ -325,38 +207,6 @@ uint64_t ProfileSummaryInfo::getOrCompColdCountThreshold() const { return ColdCountThreshold.value_or(0); } -bool ProfileSummaryInfo::isHotBlock(const BasicBlock *BB, - BlockFrequencyInfo *BFI) const { - auto Count = BFI->getBlockProfileCount(BB); - return Count && isHotCount(*Count); -} - -bool ProfileSummaryInfo::isColdBlock(const BasicBlock *BB, - BlockFrequencyInfo *BFI) const { - auto Count = BFI->getBlockProfileCount(BB); - return Count && isColdCount(*Count); -} - -template <bool isHot> -bool ProfileSummaryInfo::isHotOrColdBlockNthPercentile( - int PercentileCutoff, const BasicBlock *BB, BlockFrequencyInfo *BFI) const { - auto Count = BFI->getBlockProfileCount(BB); - if (isHot) - return Count && isHotCountNthPercentile(PercentileCutoff, *Count); - else - return Count && isColdCountNthPercentile(PercentileCutoff, *Count); -} - -bool ProfileSummaryInfo::isHotBlockNthPercentile( - int PercentileCutoff, const BasicBlock *BB, BlockFrequencyInfo *BFI) const { - return isHotOrColdBlockNthPercentile<true>(PercentileCutoff, BB, BFI); -} - -bool ProfileSummaryInfo::isColdBlockNthPercentile( - int PercentileCutoff, const BasicBlock *BB, BlockFrequencyInfo *BFI) const { - return isHotOrColdBlockNthPercentile<false>(PercentileCutoff, BB, BFI); -} - bool ProfileSummaryInfo::isHotCallSite(const CallBase &CB, BlockFrequencyInfo *BFI) const { auto C = getProfileCount(CB, BFI); diff --git a/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolution.cpp index 8c62fc37c4a3..111d4d30aab9 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolution.cpp @@ -71,11 +71,13 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -134,10 +136,10 @@ using namespace PatternMatch; #define DEBUG_TYPE "scalar-evolution" -STATISTIC(NumTripCountsComputed, - "Number of loops with predictable loop counts"); -STATISTIC(NumTripCountsNotComputed, - "Number of loops without predictable loop counts"); +STATISTIC(NumExitCountsComputed, + "Number of loop exits with predictable exit counts"); +STATISTIC(NumExitCountsNotComputed, + "Number of loop exits without predictable exit counts"); STATISTIC(NumBruteForceTripCountsComputed, "Number of loops with trip counts computed by force"); @@ -160,10 +162,6 @@ static cl::opt<bool, true> VerifySCEVOpt( static 
cl::opt<bool> VerifySCEVStrict( "verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed")); -static cl::opt<bool> - VerifySCEVMap("verify-scev-maps", cl::Hidden, - cl::desc("Verify no dangling value in ScalarEvolution's " - "ExprValueMap (slow)")); static cl::opt<bool> VerifyIR( "scev-verify-ir", cl::Hidden, @@ -271,6 +269,9 @@ void SCEV::print(raw_ostream &OS) const { case scConstant: cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false); return; + case scVScale: + OS << "vscale"; + return; case scPtrToInt: { const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this); const SCEV *Op = PtrToInt->getOperand(); @@ -366,31 +367,9 @@ void SCEV::print(raw_ostream &OS) const { OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")"; return; } - case scUnknown: { - const SCEVUnknown *U = cast<SCEVUnknown>(this); - Type *AllocTy; - if (U->isSizeOf(AllocTy)) { - OS << "sizeof(" << *AllocTy << ")"; - return; - } - if (U->isAlignOf(AllocTy)) { - OS << "alignof(" << *AllocTy << ")"; - return; - } - - Type *CTy; - Constant *FieldNo; - if (U->isOffsetOf(CTy, FieldNo)) { - OS << "offsetof(" << *CTy << ", "; - FieldNo->printAsOperand(OS, false); - OS << ")"; - return; - } - - // Otherwise just print it normally. - U->getValue()->printAsOperand(OS, false); + case scUnknown: + cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false); return; - } case scCouldNotCompute: OS << "***COULDNOTCOMPUTE***"; return; @@ -402,6 +381,8 @@ Type *SCEV::getType() const { switch (getSCEVType()) { case scConstant: return cast<SCEVConstant>(this)->getType(); + case scVScale: + return cast<SCEVVScale>(this)->getType(); case scPtrToInt: case scTruncate: case scZeroExtend: @@ -433,6 +414,7 @@ Type *SCEV::getType() const { ArrayRef<const SCEV *> SCEV::operands() const { switch (getSCEVType()) { case scConstant: + case scVScale: case scUnknown: return {}; case scPtrToInt: @@ -515,6 +497,18 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { return getConstant(ConstantInt::get(ITy, V, isSigned)); } +const SCEV *ScalarEvolution::getVScale(Type *Ty) { + FoldingSetNodeID ID; + ID.AddInteger(scVScale); + ID.AddPointer(Ty); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; + SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; +} + SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} @@ -574,67 +568,6 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) { setValPtr(New); } -bool SCEVUnknown::isSizeOf(Type *&AllocTy) const { - if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) - if (VCE->getOpcode() == Instruction::PtrToInt) - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) - if (CE->getOpcode() == Instruction::GetElementPtr && - CE->getOperand(0)->isNullValue() && - CE->getNumOperands() == 2) - if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1))) - if (CI->isOne()) { - AllocTy = cast<GEPOperator>(CE)->getSourceElementType(); - return true; - } - - return false; -} - -bool SCEVUnknown::isAlignOf(Type *&AllocTy) const { - if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) - if (VCE->getOpcode() == Instruction::PtrToInt) - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) - if (CE->getOpcode() == Instruction::GetElementPtr && - CE->getOperand(0)->isNullValue()) { 
- Type *Ty = cast<GEPOperator>(CE)->getSourceElementType(); - if (StructType *STy = dyn_cast<StructType>(Ty)) - if (!STy->isPacked() && - CE->getNumOperands() == 3 && - CE->getOperand(1)->isNullValue()) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2))) - if (CI->isOne() && - STy->getNumElements() == 2 && - STy->getElementType(0)->isIntegerTy(1)) { - AllocTy = STy->getElementType(1); - return true; - } - } - } - - return false; -} - -bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { - if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) - if (VCE->getOpcode() == Instruction::PtrToInt) - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) - if (CE->getOpcode() == Instruction::GetElementPtr && - CE->getNumOperands() == 3 && - CE->getOperand(0)->isNullValue() && - CE->getOperand(1)->isNullValue()) { - Type *Ty = cast<GEPOperator>(CE)->getSourceElementType(); - // Ignore vector types here so that ScalarEvolutionExpander doesn't - // emit getelementptrs that index into vectors. - if (Ty->isStructTy() || Ty->isArrayTy()) { - CTy = Ty; - FieldNo = CE->getOperand(2); - return true; - } - } - - return false; -} - //===----------------------------------------------------------------------===// // SCEV Utilities //===----------------------------------------------------------------------===// @@ -785,6 +718,12 @@ CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV, return LA.ult(RA) ? -1 : 1; } + case scVScale: { + const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType()); + const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType()); + return LTy->getBitWidth() - RTy->getBitWidth(); + } + case scAddRecExpr: { const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); @@ -798,9 +737,8 @@ CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV, assert(LHead != RHead && "Two loops share the same header?"); if (DT.dominates(LHead, RHead)) return 1; - else - assert(DT.dominates(RHead, LHead) && - "No dominance between recurrences used by one SCEV?"); + assert(DT.dominates(RHead, LHead) && + "No dominance between recurrences used by one SCEV?"); return -1; } @@ -984,7 +922,7 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, unsigned T = 1; for (unsigned i = 3; i <= K; ++i) { APInt Mult(W, i); - unsigned TwoFactors = Mult.countTrailingZeros(); + unsigned TwoFactors = Mult.countr_zero(); T += TwoFactors; Mult.lshrInPlace(TwoFactors); OddFactorial *= Mult; @@ -1252,10 +1190,9 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, if (numTruncs < 2) { if (isa<SCEVAddExpr>(Op)) return getAddExpr(Operands); - else if (isa<SCEVMulExpr>(Op)) + if (isa<SCEVMulExpr>(Op)) return getMulExpr(Operands); - else - llvm_unreachable("Unexpected SCEV type for Op."); + llvm_unreachable("Unexpected SCEV type for Op."); } // Although we checked in the beginning that ID is not in the cache, it is // possible that during recursion and different modification ID was inserted @@ -1273,7 +1210,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, } // Return zero if truncating to known zeros. - uint32_t MinTrailingZeros = GetMinTrailingZeros(Op); + uint32_t MinTrailingZeros = getMinTrailingZeros(Op); if (MinTrailingZeros >= getTypeSizeInBits(Ty)) return getZero(Ty); @@ -1558,7 +1495,7 @@ static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, // Find number of trailing zeros of (x + y + ...) 
w/o the C first: uint32_t TZ = BitWidth; for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I) - TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I))); + TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I))); if (TZ) { // Set D to be as many least significant bits of C as possible while still // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap: @@ -1575,7 +1512,7 @@ static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const APInt &ConstantStart, const SCEV *Step) { const unsigned BitWidth = ConstantStart.getBitWidth(); - const uint32_t TZ = SE.GetMinTrailingZeros(Step); + const uint32_t TZ = SE.getMinTrailingZeros(Step); if (TZ) return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth) : ConstantStart; @@ -1614,10 +1551,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { assert(!Op->getType()->isPointerTy() && "Can't extend pointer!"); Ty = getEffectiveSCEVType(Ty); - FoldID ID; - ID.addInteger(scZeroExtend); - ID.addPointer(Op); - ID.addPointer(Ty); + FoldID ID(scZeroExtend, Op, Ty); auto Iter = FoldCache.find(ID); if (Iter != FoldCache.end()) return Iter->second; @@ -1684,11 +1618,6 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); - if (!AR->hasNoUnsignedWrap()) { - auto NewFlags = proveNoWrapViaConstantRanges(AR); - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); - } - // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->hasNoUnsignedWrap()) { @@ -1771,7 +1700,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // these to compute max backedge taken counts, but can still use // these to prove lack of overflow. Use this fact to avoid // doing extra work that may not pay off. - if (!isa<SCEVCouldNotCompute>(MaxBECount) || !AC.assumptions().empty()) { + if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || + !AC.assumptions().empty()) { auto NewFlags = proveNoUnsignedWrapViaInduction(AR); setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); @@ -1917,6 +1847,27 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, } } + // zext(umin(x, y)) -> umin(zext(x), zext(y)) + // zext(umax(x, y)) -> umax(zext(x), zext(y)) + if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) { + auto *MinMax = cast<SCEVMinMaxExpr>(Op); + SmallVector<const SCEV *, 4> Operands; + for (auto *Operand : MinMax->operands()) + Operands.push_back(getZeroExtendExpr(Operand, Ty)); + if (isa<SCEVUMinExpr>(MinMax)) + return getUMinExpr(Operands); + return getUMaxExpr(Operands); + } + + // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y)) + if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) { + assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!"); + SmallVector<const SCEV *, 4> Operands; + for (auto *Operand : MinMax->operands()) + Operands.push_back(getZeroExtendExpr(Operand, Ty)); + return getUMinExpr(Operands, /*Sequential*/ true); + } + // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. 
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; @@ -1936,10 +1887,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { assert(!Op->getType()->isPointerTy() && "Can't extend pointer!"); Ty = getEffectiveSCEVType(Ty); - FoldID ID; - ID.addInteger(scSignExtend); - ID.addPointer(Op); - ID.addPointer(Ty); + FoldID ID(scSignExtend, Op, Ty); auto Iter = FoldCache.find(ID); if (Iter != FoldCache.end()) return Iter->second; @@ -2045,11 +1993,6 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); - if (!AR->hasNoSignedWrap()) { - auto NewFlags = proveNoWrapViaConstantRanges(AR); - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); - } - // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. if (AR->hasNoSignedWrap()) { @@ -2177,6 +2120,18 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, if (isKnownNonNegative(Op)) return getZeroExtendExpr(Op, Ty, Depth + 1); + // sext(smin(x, y)) -> smin(sext(x), sext(y)) + // sext(smax(x, y)) -> smax(sext(x), sext(y)) + if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) { + auto *MinMax = cast<SCEVMinMaxExpr>(Op); + SmallVector<const SCEV *, 4> Operands; + for (auto *Operand : MinMax->operands()) + Operands.push_back(getSignExtendExpr(Operand, Ty)); + if (isa<SCEVSMinExpr>(MinMax)) + return getSMinExpr(Operands); + return getSMaxExpr(Operands); + } + // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; @@ -2377,25 +2332,42 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, // Can we use context to prove the fact we need? if (!CtxI) return false; - // We can prove that add(x, constant) doesn't wrap if isKnownPredicateAt can - // guarantee that x <= max_int - constant at the given context. - // TODO: Support other operations. - if (BinOp != Instruction::Add) + // TODO: Support mul. + if (BinOp == Instruction::Mul) return false; auto *RHSC = dyn_cast<SCEVConstant>(RHS); // TODO: Lift this limitation. if (!RHSC) return false; APInt C = RHSC->getAPInt(); - // TODO: Also lift this limitation. - if (Signed && C.isNegative()) - return false; unsigned NumBits = C.getBitWidth(); - APInt Max = - Signed ? APInt::getSignedMaxValue(NumBits) : APInt::getMaxValue(NumBits); - APInt Limit = Max - C; + bool IsSub = (BinOp == Instruction::Sub); + bool IsNegativeConst = (Signed && C.isNegative()); + // Compute the direction and magnitude by which we need to check overflow. + bool OverflowDown = IsSub ^ IsNegativeConst; + APInt Magnitude = C; + if (IsNegativeConst) { + if (C == APInt::getSignedMinValue(NumBits)) + // TODO: SINT_MIN on inversion gives the same negative value, we don't + // want to deal with that. + return false; + Magnitude = -C; + } + ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; - return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI); + if (OverflowDown) { + // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS. + APInt Min = Signed ? 
APInt::getSignedMinValue(NumBits) + : APInt::getMinValue(NumBits); + APInt Limit = Min + Magnitude; + return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI); + } else { + // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude. + APInt Max = Signed ? APInt::getSignedMaxValue(NumBits) + : APInt::getMaxValue(NumBits); + APInt Limit = Max - Magnitude; + return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI); + } } std::optional<SCEV::NoWrapFlags> @@ -3229,9 +3201,20 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, for (const SCEV *AddRecOp : AddRec->operands()) Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap, Depth + 1)); - + // Let M be the minimum representable signed value. AddRec with nsw + // multiplied by -1 can have signed overflow if and only if it takes a + // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the + // maximum signed value. In all other cases signed overflow is + // impossible. + auto FlagsMask = SCEV::FlagNW; + if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) { + auto MinInt = + APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType())); + if (getSignedRangeMin(AddRec) != MinInt) + FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW); + } return getAddRecExpr(Operands, AddRec->getLoop(), - AddRec->getNoWrapFlags(SCEV::FlagNW)); + AddRec->getNoWrapFlags(FlagsMask)); } } } @@ -3273,9 +3256,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // if they are loop invariant w.r.t. the recurrence. SmallVector<const SCEV *, 8> LIOps; const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); - const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { + if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) { LIOps.push_back(Ops[i]); Ops.erase(Ops.begin()+i); --i; --e; @@ -3298,7 +3280,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // will be inferred if either NUW or NSW is true. SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec}); const SCEV *NewRec = getAddRecExpr( - NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags)); + NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(Flags)); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -3332,7 +3314,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, ++OtherIdx) { const SCEVAddRecExpr *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]); - if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) + if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop()) continue; // Limit max number of arguments to avoid creation of unreasonably big @@ -3371,7 +3353,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1)); } if (!Overflow) { - const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop, + const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), SCEV::FlagAnyWrap); if (Ops.size() == 2) return NewAddRec; Ops[Idx] = NewAddRec; @@ -3455,7 +3437,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // its operands. // TODO: Generalize this to non-constants by using known-bits information. 
Type *Ty = LHS->getType(); - unsigned LZ = RHSC->getAPInt().countLeadingZeros(); + unsigned LZ = RHSC->getAPInt().countl_zero(); unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1; // For non-power-of-two values, effectively round the value up to the // nearest power of two. @@ -3867,15 +3849,18 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, ++Idx; assert(Idx < Ops.size()); auto FoldOp = [&](const APInt &LHS, const APInt &RHS) { - if (Kind == scSMaxExpr) + switch (Kind) { + case scSMaxExpr: return APIntOps::smax(LHS, RHS); - else if (Kind == scSMinExpr) + case scSMinExpr: return APIntOps::smin(LHS, RHS); - else if (Kind == scUMaxExpr) + case scUMaxExpr: return APIntOps::umax(LHS, RHS); - else if (Kind == scUMinExpr) + case scUMinExpr: return APIntOps::umin(LHS, RHS); - llvm_unreachable("Unknown SCEV min/max opcode"); + default: + llvm_unreachable("Unknown SCEV min/max opcode"); + } }; while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { @@ -4050,6 +4035,8 @@ public: RetVal visitConstant(const SCEVConstant *Constant) { return Constant; } + RetVal visitVScale(const SCEVVScale *VScale) { return VScale; } + RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; } RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; } @@ -4096,6 +4083,7 @@ public: static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) { switch (Kind) { case scConstant: + case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -4131,38 +4119,15 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { // with the notable exception of umin_seq, where only poison from the first // operand is (unconditionally) propagated. struct SCEVPoisonCollector { - bool LookThroughSeq; + bool LookThroughMaybePoisonBlocking; SmallPtrSet<const SCEV *, 4> MaybePoison; - SCEVPoisonCollector(bool LookThroughSeq) : LookThroughSeq(LookThroughSeq) {} + SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking) + : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {} bool follow(const SCEV *S) { - if (!scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) { - switch (S->getSCEVType()) { - case scConstant: - case scTruncate: - case scZeroExtend: - case scSignExtend: - case scPtrToInt: - case scAddExpr: - case scMulExpr: - case scUDivExpr: - case scAddRecExpr: - case scUMaxExpr: - case scSMaxExpr: - case scUMinExpr: - case scSMinExpr: - case scUnknown: - llvm_unreachable("These all unconditionally propagate poison."); - case scSequentialUMinExpr: - // TODO: We can always follow the first operand, - // but the SCEVTraversal API doesn't support this. - if (!LookThroughSeq) - return false; - break; - case scCouldNotCompute: - llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); - } - } + if (!LookThroughMaybePoisonBlocking && + !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) + return false; if (auto *SU = dyn_cast<SCEVUnknown>(S)) { if (!isGuaranteedNotToBePoison(SU->getValue())) @@ -4174,9 +4139,10 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { }; // First collect all SCEVs that might result in AssumedPoison to be poison. - // We need to look through umin_seq here, because we want to find all SCEVs - // that *might* result in poison, not only those that are *required* to. 
- SCEVPoisonCollector PC1(/* LookThroughSeq */ true); + // We need to look through potentially poison-blocking operations here, + // because we want to find all SCEVs that *might* result in poison, not only + // those that are *required* to. + SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true); visitAll(AssumedPoison, PC1); // AssumedPoison is never poison. As the assumption is false, the implication @@ -4185,9 +4151,9 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { return true; // Collect all SCEVs in S that, if poison, *will* result in S being poison - // as well. We cannot look through umin_seq here, as its argument only *may* - // make the result poison. - SCEVPoisonCollector PC2(/* LookThroughSeq */ false); + // as well. We cannot look through potentially poison-blocking operations + // here, as their arguments only *may* make the result poison. + SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false); visitAll(S, PC2); // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison, @@ -4348,33 +4314,19 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops, } const SCEV * -ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy, - ScalableVectorType *ScalableTy) { - Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo()); - Constant *One = ConstantInt::get(IntTy, 1); - Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One); - // Note that the expression we created is the final expression, we don't - // want to simplify it any further Also, if we call a normal getSCEV(), - // we'll end up in an endless recursion. So just create an SCEVUnknown. - return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy)); +ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) { + const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue()); + if (Size.isScalable()) + Res = getMulExpr(Res, getVScale(IntTy)); + return Res; } const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { - if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy)) - return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy); - // We can bypass creating a target-independent constant expression and then - // folding it back into a ConstantInt. This is just a compile-time - // optimization. - return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); + return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) { - if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy)) - return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy); - // We can bypass creating a target-independent constant expression and then - // folding it back into a ConstantInt. This is just a compile-time - // optimization. - return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy)); + return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy)); } const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, @@ -4383,8 +4335,10 @@ const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, // We can bypass creating a target-independent constant expression and then // folding it back into a ConstantInt. This is just a compile-time // optimization. 
- return getConstant( - IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); + const StructLayout *SL = getDataLayout().getStructLayout(STy); + assert(!SL->getSizeInBits().isScalable() && + "Cannot get offset for structure containing scalable vector types"); + return getConstant(IntTy, SL->getElementOffset(FieldNo)); } const SCEV *ScalarEvolution::getUnknown(Value *V) { @@ -4494,13 +4448,6 @@ ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) { ExprValueMapType::iterator SI = ExprValueMap.find_as(S); if (SI == ExprValueMap.end()) return std::nullopt; -#ifndef NDEBUG - if (VerifySCEVMap) { - // Check there is no dangling Value in the set returned. - for (Value *V : SI->second) - assert(ValueExprMap.count(V)); - } -#endif return SI->second.getArrayRef(); } @@ -4529,6 +4476,18 @@ void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) { } } +/// Determine whether this instruction is either not SCEVable or will always +/// produce a SCEVUnknown. We do not have to walk past such instructions when +/// invalidating. +static bool isAlwaysUnknown(const Instruction *I) { + switch (I->getOpcode()) { + case Instruction::Load: + return true; + default: + return false; + } +} + /// Return an existing SCEV if it exists, otherwise analyze the expression and /// create a new one. const SCEV *ScalarEvolution::getSCEV(Value *V) { @@ -4536,7 +4495,11 @@ const SCEV *ScalarEvolution::getSCEV(Value *V) { if (const SCEV *S = getExistingSCEV(V)) return S; - return createSCEVIter(V); + const SCEV *S = createSCEVIter(V); + assert((!isa<Instruction>(V) || !isAlwaysUnknown(cast<Instruction>(V)) || + isa<SCEVUnknown>(S)) && + "isAlwaysUnknown() instruction is not SCEVUnknown"); + return S; } const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { @@ -4837,6 +4800,8 @@ static void PushDefUseChildren(Instruction *I, // Push the def-use children onto the Worklist stack. for (User *U : I->users()) { auto *UserInsn = cast<Instruction>(U); + if (isAlwaysUnknown(UserInsn)) + continue; if (Visited.insert(UserInsn).second) Worklist.push_back(UserInsn); } @@ -5054,6 +5019,18 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; + if (!AR->hasNoSelfWrap()) { + const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop()); + if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) { + ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this)); + const APInt &BECountAP = BECountMax->getAPInt(); + unsigned NoOverflowBitWidth = + BECountAP.getActiveBits() + StepCR.getMinSignedBits(); + if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType())) + Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW); + } + } + if (!AR->hasNoSignedWrap()) { ConstantRange AddRecRange = getSignedRange(AR); ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this)); @@ -5112,7 +5089,8 @@ ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) { // these to prove lack of overflow. Use this fact to avoid // doing extra work that may not pay off. - if (isa<SCEVCouldNotCompute>(MaxBECount) && AC.assumptions().empty()) + if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards && + AC.assumptions().empty()) return Result; // If the backedge is guarded by a comparison with the pre-inc value the @@ -5165,7 +5143,8 @@ ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) { // these to prove lack of overflow. Use this fact to avoid // doing extra work that may not pay off. 
- if (isa<SCEVCouldNotCompute>(MaxBECount) && AC.assumptions().empty()) + if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards && + AC.assumptions().empty()) return Result; // If the backedge is guarded by a comparison with the pre-inc value the @@ -5733,6 +5712,12 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); insertValueToMap(PN, PHISCEV); + if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) { + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), + (SCEV::NoWrapFlags)(AR->getNoWrapFlags() | + proveNoWrapViaConstantRanges(AR))); + } + // We can add Flags to the post-inc expression only if we // know that it is *undefined behavior* for BEValueV to // overflow. @@ -5838,9 +5823,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { // indices form a positive value. if (GEP->isInBounds() && GEP->getOperand(0) == PN) { Flags = setFlags(Flags, SCEV::FlagNW); - - const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); - if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) + if (isKnownPositive(Accum)) Flags = setFlags(Flags, SCEV::FlagNUW); } @@ -5858,6 +5841,12 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { forgetMemoizedResults(SymbolicName); insertValueToMap(PN, PHISCEV); + if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) { + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), + (SCEV::NoWrapFlags)(AR->getNoWrapFlags() | + proveNoWrapViaConstantRanges(AR))); + } + // We can add Flags to the post-inc expression only if we // know that it is *undefined behavior* for BEValueV to // overflow. @@ -5903,89 +5892,6 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { return nullptr; } -// Checks if the SCEV S is available at BB. S is considered available at BB -// if S can be materialized at BB without introducing a fault. -static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, - BasicBlock *BB) { - struct CheckAvailable { - bool TraversalDone = false; - bool Available = true; - - const Loop *L = nullptr; // The loop BB is in (can be nullptr) - BasicBlock *BB = nullptr; - DominatorTree &DT; - - CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT) - : L(L), BB(BB), DT(DT) {} - - bool setUnavailable() { - TraversalDone = true; - Available = false; - return false; - } - - bool follow(const SCEV *S) { - switch (S->getSCEVType()) { - case scConstant: - case scPtrToInt: - case scTruncate: - case scZeroExtend: - case scSignExtend: - case scAddExpr: - case scMulExpr: - case scUMaxExpr: - case scSMaxExpr: - case scUMinExpr: - case scSMinExpr: - case scSequentialUMinExpr: - // These expressions are available if their operand(s) is/are. - return true; - - case scAddRecExpr: { - // We allow add recurrences that are on the loop BB is in, or some - // outer loop. This guarantees availability because the value of the - // add recurrence at BB is simply the "current" value of the induction - // variable. We can relax this in the future; for instance an add - // recurrence on a sibling dominating loop is also available at BB. - const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop(); - if (L && (ARLoop == L || ARLoop->contains(L))) - return true; - - return setUnavailable(); - } - - case scUnknown: { - // For SCEVUnknown, we check for simple dominance. 
- const auto *SU = cast<SCEVUnknown>(S); - Value *V = SU->getValue(); - - if (isa<Argument>(V)) - return false; - - if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB)) - return false; - - return setUnavailable(); - } - - case scUDivExpr: - case scCouldNotCompute: - // We do not try to smart about these at all. - return setUnavailable(); - } - llvm_unreachable("Unknown SCEV kind!"); - } - - bool isDone() { return TraversalDone; } - }; - - CheckAvailable CA(L, BB, DT); - SCEVTraversal<CheckAvailable> ST(CA); - - ST.visitAll(S); - return CA.Available; -} - // Try to match a control flow sequence that branches out at BI and merges back // at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful // match. @@ -6023,13 +5929,6 @@ const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { auto IsReachable = [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); }; if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) { - const Loop *L = LI.getLoopFor(PN->getParent()); - - // We don't want to break LCSSA, even in a SCEV expression tree. - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (LI.getLoopFor(PN->getIncomingBlock(i)) != L) - return nullptr; - // Try to match // // br %cond, label %left, label %right @@ -6050,8 +5949,8 @@ const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { if (BI && BI->isConditional() && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) && - IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) && - IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent())) + properlyDominates(getSCEV(LHS), PN->getParent()) && + properlyDominates(getSCEV(RHS), PN->getParent())) return createNodeForSelectOrPHI(PN, Cond, LHS, RHS); } @@ -6062,12 +5961,12 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { if (const SCEV *S = createAddRecFromPHI(PN)) return S; - if (const SCEV *S = createNodeFromSelectLikePHI(PN)) - return S; - if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC})) return getSCEV(V); + if (const SCEV *S = createNodeFromSelectLikePHI(PN)) + return S; + // If it's not a loop phi, we can't handle it yet. return getUnknown(PN); } @@ -6310,63 +6209,85 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { return getGEPExpr(GEP, IndexExprs); } -uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { +APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { + uint64_t BitWidth = getTypeSizeInBits(S->getType()); + auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) { + return TrailingZeros >= BitWidth + ? APInt::getZero(BitWidth) + : APInt::getOneBitSet(BitWidth, TrailingZeros); + }; + auto GetGCDMultiple = [this](const SCEVNAryExpr *N) { + // The result is GCD of all operands results. + APInt Res = getConstantMultiple(N->getOperand(0)); + for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I) + Res = APIntOps::GreatestCommonDivisor( + Res, getConstantMultiple(N->getOperand(I))); + return Res; + }; + switch (S->getSCEVType()) { case scConstant: - return cast<SCEVConstant>(S)->getAPInt().countTrailingZeros(); + return cast<SCEVConstant>(S)->getAPInt(); + case scPtrToInt: + return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand()); + case scUDivExpr: + case scVScale: + return APInt(BitWidth, 1); case scTruncate: { + // Only multiples that are a power of 2 will hold after truncation. 
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S); - return std::min(GetMinTrailingZeros(T->getOperand()), - (uint32_t)getTypeSizeInBits(T->getType())); + uint32_t TZ = getMinTrailingZeros(T->getOperand()); + return GetShiftedByZeros(TZ); } case scZeroExtend: { - const SCEVZeroExtendExpr *E = cast<SCEVZeroExtendExpr>(S); - uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) - ? getTypeSizeInBits(E->getType()) - : OpRes; + const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S); + return getConstantMultiple(Z->getOperand()).zext(BitWidth); } case scSignExtend: { const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S); - uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) - ? getTypeSizeInBits(E->getType()) - : OpRes; + return getConstantMultiple(E->getOperand()).sext(BitWidth); } case scMulExpr: { const SCEVMulExpr *M = cast<SCEVMulExpr>(S); - // The result is the sum of all operands results. - uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); - uint32_t BitWidth = getTypeSizeInBits(M->getType()); - for (unsigned i = 1, e = M->getNumOperands(); - SumOpRes != BitWidth && i != e; ++i) - SumOpRes = - std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); - return SumOpRes; + if (M->hasNoUnsignedWrap()) { + // The result is the product of all operand results. + APInt Res = getConstantMultiple(M->getOperand(0)); + for (const SCEV *Operand : M->operands().drop_front()) + Res = Res * getConstantMultiple(Operand); + return Res; + } + + // If there are no wrap guarentees, find the trailing zeros, which is the + // sum of trailing zeros for all its operands. + uint32_t TZ = 0; + for (const SCEV *Operand : M->operands()) + TZ += getMinTrailingZeros(Operand); + return GetShiftedByZeros(TZ); } - case scUDivExpr: - return 0; - case scPtrToInt: case scAddExpr: - case scAddRecExpr: + case scAddRecExpr: { + const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S); + if (N->hasNoUnsignedWrap()) + return GetGCDMultiple(N); + // Find the trailing bits, which is the minimum of its operands. + uint32_t TZ = getMinTrailingZeros(N->getOperand(0)); + for (const SCEV *Operand : N->operands().drop_front()) + TZ = std::min(TZ, getMinTrailingZeros(Operand)); + return GetShiftedByZeros(TZ); + } case scUMaxExpr: case scSMaxExpr: case scUMinExpr: case scSMinExpr: - case scSequentialUMinExpr: { - // The result is the min of all operands results. - ArrayRef<const SCEV *> Ops = S->operands(); - uint32_t MinOpRes = GetMinTrailingZeros(Ops[0]); - for (unsigned I = 1, E = Ops.size(); MinOpRes && I != E; ++I) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(Ops[I])); - return MinOpRes; - } + case scSequentialUMinExpr: + return GetGCDMultiple(cast<SCEVNAryExpr>(S)); case scUnknown: { + // ask ValueTracking for known bits const SCEVUnknown *U = cast<SCEVUnknown>(S); - // For a SCEVUnknown, ask ValueTracking. 
- KnownBits Known = - computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT); - return Known.countMinTrailingZeros(); + unsigned Known = + computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT) + .countMinTrailingZeros(); + return GetShiftedByZeros(Known); } case scCouldNotCompute: llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); @@ -6374,17 +6295,27 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { llvm_unreachable("Unknown SCEV kind!"); } -uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { - auto I = MinTrailingZerosCache.find(S); - if (I != MinTrailingZerosCache.end()) +APInt ScalarEvolution::getConstantMultiple(const SCEV *S) { + auto I = ConstantMultipleCache.find(S); + if (I != ConstantMultipleCache.end()) return I->second; - uint32_t Result = GetMinTrailingZerosImpl(S); - auto InsertPair = MinTrailingZerosCache.insert({S, Result}); + APInt Result = getConstantMultipleImpl(S); + auto InsertPair = ConstantMultipleCache.insert({S, Result}); assert(InsertPair.second && "Should insert a new key"); return InsertPair.first->second; } +APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) { + APInt Multiple = getConstantMultiple(S); + return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple; +} + +uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) { + return std::min(getConstantMultiple(S).countTrailingZeros(), + (unsigned)getTypeSizeInBits(S->getType())); +} + /// Helper method to assign a range to V from metadata present in the IR. static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) @@ -6400,6 +6331,7 @@ void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec, AddRec->setNoWrapFlags(Flags); UnsignedRanges.erase(AddRec); SignedRanges.erase(AddRec); + ConstantMultipleCache.erase(AddRec); } } @@ -6536,7 +6468,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) { if (!Seen.insert(Expr).second) return; - if (Cache.find(Expr) != Cache.end()) + if (Cache.contains(Expr)) return; switch (Expr->getSCEVType()) { case scUnknown: @@ -6544,6 +6476,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, break; [[fallthrough]]; case scConstant: + case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -6632,21 +6565,28 @@ const ConstantRange &ScalarEvolution::getRangeRef( // If the value has known zeros, the maximum value will have those known zeros // as well. 
- uint32_t TZ = GetMinTrailingZeros(S); - if (TZ != 0) { - if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) + if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { + APInt Multiple = getNonZeroConstantMultiple(S); + APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple); + if (!Remainder.isZero()) ConservativeResult = ConstantRange(APInt::getMinValue(BitWidth), - APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); - else + APInt::getMaxValue(BitWidth) - Remainder + 1); + } + else { + uint32_t TZ = getMinTrailingZeros(S); + if (TZ != 0) { ConservativeResult = ConstantRange( APInt::getSignedMinValue(BitWidth), APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); + } } switch (S->getSCEVType()) { case scConstant: llvm_unreachable("Already handled above."); + case scVScale: + return setRange(S, SignHint, getVScaleRange(&F, BitWidth)); case scTruncate: { const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S); ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1); @@ -6742,21 +6682,30 @@ const ConstantRange &ScalarEvolution::getRangeRef( // TODO: non-affine addrec if (AddRec->isAffine()) { - const SCEV *MaxBECount = + const SCEV *MaxBEScev = getConstantMaxBackedgeTakenCount(AddRec->getLoop()); - if (!isa<SCEVCouldNotCompute>(MaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { - auto RangeFromAffine = getRangeForAffineAR( - AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, - BitWidth); - ConservativeResult = - ConservativeResult.intersectWith(RangeFromAffine, RangeType); + if (!isa<SCEVCouldNotCompute>(MaxBEScev)) { + APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt(); + + // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if + // MaxBECount's active bits are all <= AddRec's bit width. + if (MaxBECount.getBitWidth() > BitWidth && + MaxBECount.getActiveBits() <= BitWidth) + MaxBECount = MaxBECount.trunc(BitWidth); + else if (MaxBECount.getBitWidth() < BitWidth) + MaxBECount = MaxBECount.zext(BitWidth); + + if (MaxBECount.getBitWidth() == BitWidth) { + auto RangeFromAffine = getRangeForAffineAR( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount); + ConservativeResult = + ConservativeResult.intersectWith(RangeFromAffine, RangeType); - auto RangeFromFactoring = getRangeViaFactoring( - AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, - BitWidth); - ConservativeResult = - ConservativeResult.intersectWith(RangeFromFactoring, RangeType); + auto RangeFromFactoring = getRangeViaFactoring( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount); + ConservativeResult = + ConservativeResult.intersectWith(RangeFromFactoring, RangeType); + } } // Now try symbolic BE count and more powerful methods. @@ -6764,7 +6713,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( const SCEV *SymbolicMaxBECount = getSymbolicMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && + getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth && AddRec->hasNoSelfWrap()) { auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR( AddRec, SymbolicMaxBECount, BitWidth, SignHint); @@ -6810,9 +6759,10 @@ const ConstantRange &ScalarEvolution::getRangeRef( } case scUnknown: { const SCEVUnknown *U = cast<SCEVUnknown>(S); + Value *V = U->getValue(); // Check if the IR explicitly contains !range metadata. 
- std::optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); + std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V); if (MDRange) ConservativeResult = ConservativeResult.intersectWith(*MDRange, RangeType); @@ -6825,13 +6775,13 @@ const ConstantRange &ScalarEvolution::getRangeRef( // See if ValueTracking can give us a useful range. const DataLayout &DL = getDataLayout(); - KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT); + KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT); if (Known.getBitWidth() != BitWidth) Known = Known.zextOrTrunc(BitWidth); // ValueTracking may be able to compute a tighter result for the number of // sign bits than for the value of those sign bits. - unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT); + unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT); if (U->getType()->isPointerTy()) { // If the pointer size is larger than the index size type, this can cause // NS to be larger than BitWidth. So compensate for this. @@ -6859,8 +6809,36 @@ const ConstantRange &ScalarEvolution::getRangeRef( APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1), RangeType); + if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) { + // Strengthen the range if the underlying IR value is a + // global/alloca/heap allocation using the size of the object. + ObjectSizeOpts Opts; + Opts.RoundToAlign = false; + Opts.NullIsUnknownSize = true; + uint64_t ObjSize; + if ((isa<GlobalVariable>(V) || isa<AllocaInst>(V) || + isAllocationFn(V, &TLI)) && + getObjectSize(V, ObjSize, DL, &TLI, Opts) && ObjSize > 1) { + // The highest address the object can start is ObjSize bytes before the + // end (unsigned max value). If this value is not a multiple of the + // alignment, the last possible start value is the next lowest multiple + // of the alignment. Note: The computations below cannot overflow, + // because if they would there's no possible start address for the + // object. + APInt MaxVal = APInt::getMaxValue(BitWidth) - APInt(BitWidth, ObjSize); + uint64_t Align = U->getValue()->getPointerAlignment(DL).value(); + uint64_t Rem = MaxVal.urem(Align); + MaxVal -= APInt(BitWidth, Rem); + APInt MinVal = APInt::getZero(BitWidth); + if (llvm::isKnownNonZero(V, DL)) + MinVal = Align; + ConservativeResult = ConservativeResult.intersectWith( + {MinVal, MaxVal + 1}, RangeType); + } + } + // A range of Phi is a subset of union of all ranges of its input. - if (PHINode *Phi = dyn_cast<PHINode>(U->getValue())) { + if (PHINode *Phi = dyn_cast<PHINode>(V)) { // Make sure that we do not run over cycled Phis. 
if (PendingPhiRanges.insert(Phi).second) { ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false); @@ -6881,7 +6859,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( } // vscale can't be equal to zero - if (const auto *II = dyn_cast<IntrinsicInst>(U->getValue())) + if (const auto *II = dyn_cast<IntrinsicInst>(V)) if (II->getIntrinsicID() == Intrinsic::vscale) { ConstantRange Disallowed = APInt::getZero(BitWidth); ConservativeResult = ConservativeResult.difference(Disallowed); @@ -6903,7 +6881,10 @@ const ConstantRange &ScalarEvolution::getRangeRef( static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, - unsigned BitWidth, bool Signed) { + bool Signed) { + unsigned BitWidth = Step.getBitWidth(); + assert(BitWidth == StartRange.getBitWidth() && + BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths"); // If either Step or MaxBECount is 0, then the expression won't change, and we // just need to return the initial range. if (Step == 0 || MaxBECount == 0) @@ -6962,14 +6943,11 @@ static ConstantRange getRangeForAffineARHelper(APInt Step, ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, const SCEV *Step, - const SCEV *MaxBECount, - unsigned BitWidth) { - assert(!isa<SCEVCouldNotCompute>(MaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && - "Precondition!"); - - MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); - APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount); + const APInt &MaxBECount) { + assert(getTypeSizeInBits(Start->getType()) == + getTypeSizeInBits(Step->getType()) && + getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() && + "mismatched bit widths"); // First, consider step signed. ConstantRange StartSRange = getSignedRange(Start); @@ -6977,17 +6955,16 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, // If Step can be both positive and negative, we need to find ranges for the // maximum absolute step values in both directions and union them. - ConstantRange SR = - getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange, - MaxBECountValue, BitWidth, /* Signed = */ true); + ConstantRange SR = getRangeForAffineARHelper( + StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true); SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(), - StartSRange, MaxBECountValue, - BitWidth, /* Signed = */ true)); + StartSRange, MaxBECount, + /* Signed = */ true)); // Next, consider step unsigned. ConstantRange UR = getRangeForAffineARHelper( - getUnsignedRangeMax(Step), getUnsignedRange(Start), - MaxBECountValue, BitWidth, /* Signed = */ false); + getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount, + /* Signed = */ false); // Finally, intersect signed and unsigned ranges. return SR.intersectWith(UR, ConstantRange::Smallest); @@ -7038,7 +7015,7 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that // knowledge, let's try to prove that we are dealing with Case 1. It is so if // Start <= End and step is positive, or Start >= End and step is negative. 
- const SCEV *Start = AddRec->getStart(); + const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop()); ConstantRange StartRange = getRangeRef(Start, SignHint); ConstantRange EndRange = getRangeRef(End, SignHint); ConstantRange RangeBetween = StartRange.unionWith(EndRange); @@ -7055,7 +7032,7 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( if (isKnownPositive(Step) && isKnownPredicateViaConstantRanges(LEPred, Start, End)) return RangeBetween; - else if (isKnownNegative(Step) && + if (isKnownNegative(Step) && isKnownPredicateViaConstantRanges(GEPred, Start, End)) return RangeBetween; return ConstantRange::getFull(BitWidth); @@ -7063,11 +7040,15 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, const SCEV *Step, - const SCEV *MaxBECount, - unsigned BitWidth) { + const APInt &MaxBECount) { // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) + unsigned BitWidth = MaxBECount.getBitWidth(); + assert(getTypeSizeInBits(Start->getType()) == BitWidth && + getTypeSizeInBits(Step->getType()) == BitWidth && + "mismatched bit widths"); + struct SelectPattern { Value *Condition = nullptr; APInt TrueValue; @@ -7169,9 +7150,9 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); ConstantRange TrueRange = - this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth); + this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount); ConstantRange FalseRange = - this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth); + this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount); return TrueRange.unionWith(FalseRange); } @@ -7294,62 +7275,43 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { if (isSCEVExprNeverPoison(I)) return true; - // For an add recurrence specifically, we assume that infinite loops without - // side effects are undefined behavior, and then reason as follows: + // If the loop only has one exit, then we know that, if the loop is entered, + // any instruction dominating that exit will be executed. If any such + // instruction would result in UB, the addrec cannot be poison. // - // If the add recurrence is poison in any iteration, it is poison on all - // future iterations (since incrementing poison yields poison). If the result - // of the add recurrence is fed into the loop latch condition and the loop - // does not contain any throws or exiting blocks other than the latch, we now - // have the ability to "choose" whether the backedge is taken or not (by - // choosing a sufficiently evil value for the poison feeding into the branch) - // for every iteration including and after the one in which \p I first became - // poison. There are two possibilities (let's call the iteration in which \p - // I first became poison as K): - // - // 1. In the set of iterations including and after K, the loop body executes - // no side effects. In this case executing the backege an infinte number - // of times will yield undefined behavior. - // - // 2. In the set of iterations including and after K, the loop body executes - // at least one side effect. In this case, that specific instance of side - // effect is control dependent on poison, which also yields undefined - // behavior. 
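getRangeViaFactoring, rewritten above to take the maximum backedge count as an APInt, relies on the factoring identity quoted in its leading comment: a recurrence whose start and step are selects on the same condition is covered by the union of the two unfolded recurrences. A small sketch of that identity on plain unsigned ranges (illustrative only; RangeOfAffineAR stands in for the affine range computation):

#include <algorithm>
#include <cstdint>

struct URange {
  uint64_t Min, Max; // inclusive
};

URange unionOf(URange A, URange B) {
  return {std::min(A.Min, B.Min), std::max(A.Max, B.Max)};
}

// Range of {C ? TrueStart : FalseStart, +, C ? TrueStep : FalseStep}: every
// value it takes comes from one of the two unfolded recurrences, so the union
// of their ranges is a sound over-approximation.
template <typename RangeFn>
URange rangeViaFactoring(uint64_t TrueStart, uint64_t TrueStep,
                         uint64_t FalseStart, uint64_t FalseStep,
                         RangeFn RangeOfAffineAR) {
  return unionOf(RangeOfAffineAR(TrueStart, TrueStep),
                 RangeOfAffineAR(FalseStart, FalseStep));
}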
+ // This is basically the same reasoning as in isSCEVExprNeverPoison(), but + // also handles uses outside the loop header (they just need to dominate the + // single exit). auto *ExitingBB = L->getExitingBlock(); - auto *LatchBB = L->getLoopLatch(); - if (!ExitingBB || !LatchBB || ExitingBB != LatchBB) + if (!ExitingBB || !loopHasNoAbnormalExits(L)) return false; - SmallPtrSet<const Instruction *, 16> Pushed; - SmallVector<const Instruction *, 8> PoisonStack; + SmallPtrSet<const Value *, 16> KnownPoison; + SmallVector<const Instruction *, 8> Worklist; // We start by assuming \c I, the post-inc add recurrence, is poison. Only // things that are known to be poison under that assumption go on the - // PoisonStack. - Pushed.insert(I); - PoisonStack.push_back(I); + // Worklist. + KnownPoison.insert(I); + Worklist.push_back(I); - bool LatchControlDependentOnPoison = false; - while (!PoisonStack.empty() && !LatchControlDependentOnPoison) { - const Instruction *Poison = PoisonStack.pop_back_val(); + while (!Worklist.empty()) { + const Instruction *Poison = Worklist.pop_back_val(); for (const Use &U : Poison->uses()) { - const User *PoisonUser = U.getUser(); - if (propagatesPoison(U)) { - if (Pushed.insert(cast<Instruction>(PoisonUser)).second) - PoisonStack.push_back(cast<Instruction>(PoisonUser)); - } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) { - assert(BI->isConditional() && "Only possibility!"); - if (BI->getParent() == LatchBB) { - LatchControlDependentOnPoison = true; - break; - } - } + const Instruction *PoisonUser = cast<Instruction>(U.getUser()); + if (mustTriggerUB(PoisonUser, KnownPoison) && + DT.dominates(PoisonUser->getParent(), ExitingBB)) + return true; + + if (propagatesPoison(U) && L->contains(PoisonUser)) + if (KnownPoison.insert(PoisonUser).second) + Worklist.push_back(PoisonUser); } } - return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L); + return false; } ScalarEvolution::LoopProperties @@ -7448,13 +7410,9 @@ ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) { return getUnknown(PoisonValue::get(V->getType())); } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) return getConstant(CI); - else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (!GA->isInterposable()) { - Ops.push_back(GA->getAliasee()); - return nullptr; - } + else if (isa<GlobalAlias>(V)) return getUnknown(V); - } else if (!isa<ConstantExpr>(V)) + else if (!isa<ConstantExpr>(V)) return getUnknown(V); Operator *U = cast<Operator>(V); @@ -7478,18 +7436,18 @@ ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) { auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT, dyn_cast<Instruction>(V)); if (!NewBO || - (U->getOpcode() == Instruction::Add && + (BO->Opcode == Instruction::Add && (NewBO->Opcode != Instruction::Add && NewBO->Opcode != Instruction::Sub)) || - (U->getOpcode() == Instruction::Mul && + (BO->Opcode == Instruction::Mul && NewBO->Opcode != Instruction::Mul)) { Ops.push_back(BO->LHS); break; } // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions // requires a SCEV for the LHS. 
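The rewritten isAddRecNeverPoison above assumes the post-increment addrec is poison and walks its users: if some user must trigger UB on poison and dominates the single exit, the addrec can never actually be poison. A toy, self-contained version of that worklist, with the IR queries passed in as callbacks (the real code additionally only propagates through users contained in the loop):

#include <functional>
#include <set>
#include <vector>

using Node = int; // stand-in for an instruction

bool addRecNeverPoison(Node Root,
                       std::function<std::vector<Node>(Node)> usersOf,
                       std::function<bool(Node)> propagatesPoison,
                       std::function<bool(Node)> mustTriggerUB,
                       std::function<bool(Node)> dominatesExit) {
  std::set<Node> KnownPoison{Root}; // assume the addrec itself is poison
  std::vector<Node> Worklist{Root};
  while (!Worklist.empty()) {
    Node Poison = Worklist.back();
    Worklist.pop_back();
    for (Node User : usersOf(Poison)) {
      // A user that is UB on poison and runs whenever the loop exits proves
      // the assumption impossible.
      if (mustTriggerUB(User) && dominatesExit(User))
        return true;
      // Otherwise keep following users through which poison propagates.
      if (propagatesPoison(User) && KnownPoison.insert(User).second)
        Worklist.push_back(User);
    }
  }
  return false;
}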
- if (NewBO->Op && (NewBO->IsNSW || NewBO->IsNUW)) { - auto *I = dyn_cast<Instruction>(NewBO->Op); + if (BO->Op && (BO->IsNSW || BO->IsNUW)) { + auto *I = dyn_cast<Instruction>(BO->Op); if (I && programUndefinedIfPoison(I)) { Ops.push_back(BO->LHS); break; @@ -7511,7 +7469,7 @@ ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) { break; case Instruction::And: case Instruction::Or: - if (!IsConstArg && BO->LHS->getType()->isIntegerTy(1)) + if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1)) return nullptr; break; case Instruction::LShr: @@ -7638,8 +7596,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getUnknown(PoisonValue::get(V->getType())); } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) return getConstant(CI); - else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) - return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee()); + else if (isa<GlobalAlias>(V)) + return getUnknown(V); else if (!isa<ConstantExpr>(V)) return getUnknown(V); @@ -7762,8 +7720,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // constants, obscuring what would otherwise be a low-bits mask. // Use computeKnownBits to compute what ShrinkDemandedConstant // knew about to reconstruct a low-bits mask value. - unsigned LZ = A.countLeadingZeros(); - unsigned TZ = A.countTrailingZeros(); + unsigned LZ = A.countl_zero(); + unsigned TZ = A.countr_zero(); unsigned BitWidth = A.getBitWidth(); KnownBits Known(BitWidth); computeKnownBits(BO->LHS, Known, getDataLayout(), @@ -7778,7 +7736,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) { if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) { // For an expression like (x * 8) & 8, simplify the multiply. - unsigned MulZeros = OpC->getAPInt().countTrailingZeros(); + unsigned MulZeros = OpC->getAPInt().countr_zero(); unsigned GCD = std::min(MulZeros, TZ); APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); SmallVector<const SCEV*, 4> MulOps; @@ -8057,6 +8015,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is // just eqivalent to the first operand for SCEV purposes. 
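The countl_zero/countr_zero calls above feed the "low-bits mask" reconstruction mentioned in the comment: an AND mask that InstCombine has already folded known bits into can still behave like zext(trunc(X)) if every bit it clears between its leading and trailing zeros is known zero in X anyway. Roughly how such a check can look on 64-bit values (a hypothetical standalone helper, not code from the diff):

#include <cstdint>

bool behavesLikeLowBitsMask(uint64_t A, uint64_t KnownZeroOfX) {
  if (A == 0)
    return false;
  unsigned LZ = __builtin_clzll(A); // leading zeros of the mask
  unsigned TZ = __builtin_ctzll(A); // trailing zeros of the mask
  // Bits between the leading and trailing zero regions of A.
  uint64_t Effective = (~0ull >> (LZ + TZ)) << TZ;
  // Within that region, every bit A clears must already be known zero in X.
  return (LZ != 0 || TZ != 0) && ((~A & ~KnownZeroOfX) & Effective) == 0;
}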
return getSCEV(II->getArgOperand(0)); + case Intrinsic::vscale: + return getVScale(II->getType()); default: break; } @@ -8071,21 +8031,45 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // -const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount, - bool Extend) { +const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) { if (isa<SCEVCouldNotCompute>(ExitCount)) return getCouldNotCompute(); auto *ExitCountType = ExitCount->getType(); assert(ExitCountType->isIntegerTy()); + auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(), + 1 + ExitCountType->getScalarSizeInBits()); + return getTripCountFromExitCount(ExitCount, EvalTy, nullptr); +} + +const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount, + Type *EvalTy, + const Loop *L) { + if (isa<SCEVCouldNotCompute>(ExitCount)) + return getCouldNotCompute(); + + unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType()); + unsigned EvalSize = EvalTy->getPrimitiveSizeInBits(); - if (!Extend) - return getAddExpr(ExitCount, getOne(ExitCountType)); + auto CanAddOneWithoutOverflow = [&]() { + ConstantRange ExitCountRange = + getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED); + if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize))) + return true; + + return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount, + getMinusOne(ExitCount->getType())); + }; + + // If we need to zero extend the backedge count, check if we can add one to + // it prior to zero extending without overflow. Provided this is safe, it + // allows better simplification of the +1. + if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow()) + return getZeroExtendExpr( + getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy); - auto *WiderType = Type::getIntNTy(ExitCountType->getContext(), - 1 + ExitCountType->getScalarSizeInBits()); - return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType), - getOne(WiderType)); + // Get the total trip count from the count by adding 1. This may wrap. + return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy)); } static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { @@ -8124,126 +8108,6 @@ unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) { return getConstantTripCount(MaxExitCount); } -const SCEV *ScalarEvolution::getConstantMaxTripCountFromArray(const Loop *L) { - // We can't infer from Array in Irregular Loop. - // FIXME: It's hard to infer loop bound from array operated in Nested Loop. - if (!L->isLoopSimplifyForm() || !L->isInnermost()) - return getCouldNotCompute(); - - // FIXME: To make the scene more typical, we only analysis loops that have - // one exiting block and that block must be the latch. To make it easier to - // capture loops that have memory access and memory access will be executed - // in each iteration. - const BasicBlock *LoopLatch = L->getLoopLatch(); - assert(LoopLatch && "See defination of simplify form loop."); - if (L->getExitingBlock() != LoopLatch) - return getCouldNotCompute(); - - const DataLayout &DL = getDataLayout(); - SmallVector<const SCEV *> InferCountColl; - for (auto *BB : L->getBlocks()) { - // Go here, we can know that Loop is a single exiting and simplified form - // loop. Make sure that infer from Memory Operation in those BBs must be - // executed in loop. First step, we can make sure that max execution time - // of MemAccessBB in loop represents latch max excution time. 
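For the reworked getTripCountFromExitCount earlier in this hunk, the order of the +1 and the zero-extension matters exactly when the narrow exit count can be the all-ones value. A worked example in plain C++ (illustrative only):

#include <cstdint>
#include <iostream>

int main() {
  uint32_t ExitCount = 0xFFFFFFFFu; // 2^32 - 1 backedges taken
  uint64_t AddThenExtend = uint64_t(uint32_t(ExitCount + 1)); // wraps to 0
  uint64_t ExtendThenAdd = uint64_t(ExitCount) + 1;           // 2^32, correct
  std::cout << AddThenExtend << " vs " << ExtendThenAdd << "\n";
  // The new code therefore only folds the +1 into the narrow type when the
  // exit count's range (or a loop guard) proves it cannot be all-ones.
}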
- // If MemAccessBB does not dom Latch, skip. - // Entry - // │ - // ┌─────▼─────┐ - // │Loop Header◄─────┐ - // └──┬──────┬─┘ │ - // │ │ │ - // ┌────────▼──┐ ┌─▼─────┐ │ - // │MemAccessBB│ │OtherBB│ │ - // └────────┬──┘ └─┬─────┘ │ - // │ │ │ - // ┌─▼──────▼─┐ │ - // │Loop Latch├─────┘ - // └────┬─────┘ - // ▼ - // Exit - if (!DT.dominates(BB, LoopLatch)) - continue; - - for (Instruction &Inst : *BB) { - // Find Memory Operation Instruction. - auto *GEP = getLoadStorePointerOperand(&Inst); - if (!GEP) - continue; - - auto *ElemSize = dyn_cast<SCEVConstant>(getElementSize(&Inst)); - // Do not infer from scalar type, eg."ElemSize = sizeof()". - if (!ElemSize) - continue; - - // Use a existing polynomial recurrence on the trip count. - auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(GEP)); - if (!AddRec) - continue; - auto *ArrBase = dyn_cast<SCEVUnknown>(getPointerBase(AddRec)); - auto *Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(*this)); - if (!ArrBase || !Step) - continue; - assert(isLoopInvariant(ArrBase, L) && "See addrec definition"); - - // Only handle { %array + step }, - // FIXME: {(SCEVAddRecExpr) + step } could not be analysed here. - if (AddRec->getStart() != ArrBase) - continue; - - // Memory operation pattern which have gaps. - // Or repeat memory opreation. - // And index of GEP wraps arround. - if (Step->getAPInt().getActiveBits() > 32 || - Step->getAPInt().getZExtValue() != - ElemSize->getAPInt().getZExtValue() || - Step->isZero() || Step->getAPInt().isNegative()) - continue; - - // Only infer from stack array which has certain size. - // Make sure alloca instruction is not excuted in loop. - AllocaInst *AllocateInst = dyn_cast<AllocaInst>(ArrBase->getValue()); - if (!AllocateInst || L->contains(AllocateInst->getParent())) - continue; - - // Make sure only handle normal array. - auto *Ty = dyn_cast<ArrayType>(AllocateInst->getAllocatedType()); - auto *ArrSize = dyn_cast<ConstantInt>(AllocateInst->getArraySize()); - if (!Ty || !ArrSize || !ArrSize->isOne()) - continue; - - // FIXME: Since gep indices are silently zext to the indexing type, - // we will have a narrow gep index which wraps around rather than - // increasing strictly, we shoule ensure that step is increasing - // strictly by the loop iteration. - // Now we can infer a max execution time by MemLength/StepLength. - const SCEV *MemSize = - getConstant(Step->getType(), DL.getTypeAllocSize(Ty)); - auto *MaxExeCount = - dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step)); - if (!MaxExeCount || MaxExeCount->getAPInt().getActiveBits() > 32) - continue; - - // If the loop reaches the maximum number of executions, we can not - // access bytes starting outside the statically allocated size without - // being immediate UB. But it is allowed to enter loop header one more - // time. - auto *InferCount = dyn_cast<SCEVConstant>( - getAddExpr(MaxExeCount, getOne(MaxExeCount->getType()))); - // Discard the maximum number of execution times under 32bits. 
- if (!InferCount || InferCount->getAPInt().getActiveBits() > 32) - continue; - - InferCountColl.push_back(InferCount); - } - } - - if (InferCountColl.size() == 0) - return getCouldNotCompute(); - - return getUMinFromMismatchedTypes(InferCountColl); -} - unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { SmallVector<BasicBlock *, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); @@ -8264,26 +8128,14 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, return 1; // Get the trip count - const SCEV *TCExpr = getTripCountFromExitCount(ExitCount); - - const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr); - if (!TC) - // Attempt to factor more general cases. Returns the greatest power of - // two divisor. If overflow happens, the trip count expression is still - // divisible by the greatest power of 2 divisor returned. - return 1U << std::min((uint32_t)31, - GetMinTrailingZeros(applyLoopGuards(TCExpr, L))); - - ConstantInt *Result = TC->getValue(); - - // Guard against huge trip counts (this requires checking - // for zero to handle the case where the trip count == -1 and the - // addition wraps). - if (!Result || Result->getValue().getActiveBits() > 32 || - Result->getValue().getActiveBits() == 0) - return 1; + const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L)); - return (unsigned)Result->getZExtValue(); + APInt Multiple = getNonZeroConstantMultiple(TCExpr); + // If a trip multiple is huge (>=2^32), the trip count is still divisible by + // the greatest power of 2 divisor less than 2^32. + return Multiple.getActiveBits() > 32 + ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros()) + : (unsigned)Multiple.zextOrTrunc(32).getZExtValue(); } /// Returns the largest constant divisor of the trip count of this loop as a @@ -8391,23 +8243,6 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // must be cleared in this scope. BackedgeTakenInfo Result = computeBackedgeTakenCount(L); - // In product build, there are no usage of statistic. - (void)NumTripCountsComputed; - (void)NumTripCountsNotComputed; -#if LLVM_ENABLE_STATS || !defined(NDEBUG) - const SCEV *BEExact = Result.getExact(L, this); - if (BEExact != getCouldNotCompute()) { - assert(isLoopInvariant(BEExact, L) && - isLoopInvariant(Result.getConstantMax(this), L) && - "Computed backedge-taken count isn't loop invariant for loop!"); - ++NumTripCountsComputed; - } else if (Result.getConstantMax(this) == getCouldNotCompute() && - isa<PHINode>(L->getHeader()->begin())) { - // Only count loops that have phi nodes as not being computable. 
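The new getSmallConstantTripMultiple above turns a (possibly huge) constant multiple of the trip count into an unsigned result: multiples of 2^32 or more fall back to the largest power of two below 2^32 that still divides them. A standalone sketch of that clamping on a 64-bit multiple (the real code works on arbitrary-width APInt):

#include <cstdint>

unsigned smallTripMultiple(uint64_t Multiple) {
  if (Multiple == 0)
    return 1; // the real query never hands back zero; guard anyway
  if (Multiple >> 32) { // needs more than 32 bits
    unsigned TZ = __builtin_ctzll(Multiple);
    return 1u << (TZ < 31 ? TZ : 31); // greatest power-of-2 divisor < 2^32
  }
  return static_cast<unsigned>(Multiple);
}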
- ++NumTripCountsNotComputed; - } -#endif // LLVM_ENABLE_STATS || !defined(NDEBUG) - // Now that we know more about the trip count for this loop, forget any // existing SCEV values for PHI nodes in this loop since they are only // conservative estimates made without the benefit of trip count @@ -8454,11 +8289,32 @@ void ScalarEvolution::forgetAllLoops() { SignedRanges.clear(); ExprValueMap.clear(); HasRecMap.clear(); - MinTrailingZerosCache.clear(); + ConstantMultipleCache.clear(); PredicatedSCEVRewrites.clear(); FoldCache.clear(); FoldCacheUser.clear(); } +void ScalarEvolution::visitAndClearUsers( + SmallVectorImpl<Instruction *> &Worklist, + SmallPtrSetImpl<Instruction *> &Visited, + SmallVectorImpl<const SCEV *> &ToForget) { + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + if (!isSCEVable(I->getType())) + continue; + + ValueExprMapType::iterator It = + ValueExprMap.find_as(static_cast<Value *>(I)); + if (It != ValueExprMap.end()) { + eraseValueFromMap(It->first); + ToForget.push_back(It->second); + if (PHINode *PN = dyn_cast<PHINode>(I)) + ConstantEvolutionLoopExitValue.erase(PN); + } + + PushDefUseChildren(I, Worklist, Visited); + } +} void ScalarEvolution::forgetLoop(const Loop *L) { SmallVector<const Loop *, 16> LoopWorklist(1, L); @@ -8492,21 +8348,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) { // Drop information about expressions based on loop-header PHIs. PushLoopPHIs(CurrL, Worklist, Visited); - - while (!Worklist.empty()) { - Instruction *I = Worklist.pop_back_val(); - - ValueExprMapType::iterator It = - ValueExprMap.find_as(static_cast<Value *>(I)); - if (It != ValueExprMap.end()) { - eraseValueFromMap(It->first); - ToForget.push_back(It->second); - if (PHINode *PN = dyn_cast<PHINode>(I)) - ConstantEvolutionLoopExitValue.erase(PN); - } - - PushDefUseChildren(I, Worklist, Visited); - } + visitAndClearUsers(Worklist, Visited, ToForget); LoopPropertiesCache.erase(CurrL); // Forget all contained loops too, to avoid dangling entries in the @@ -8530,20 +8372,8 @@ void ScalarEvolution::forgetValue(Value *V) { SmallVector<const SCEV *, 8> ToForget; Worklist.push_back(I); Visited.insert(I); + visitAndClearUsers(Worklist, Visited, ToForget); - while (!Worklist.empty()) { - I = Worklist.pop_back_val(); - ValueExprMapType::iterator It = - ValueExprMap.find_as(static_cast<Value *>(I)); - if (It != ValueExprMap.end()) { - eraseValueFromMap(It->first); - ToForget.push_back(It->second); - if (PHINode *PN = dyn_cast<PHINode>(I)) - ConstantEvolutionLoopExitValue.erase(PN); - } - - PushDefUseChildren(I, Worklist, Visited); - } forgetMemoizedResults(ToForget); } @@ -8798,7 +8628,9 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, // 1. For each exit that can be computed, add an entry to ExitCounts. // CouldComputeBECount is true only if all exits can be computed. - if (EL.ExactNotTaken == getCouldNotCompute()) + if (EL.ExactNotTaken != getCouldNotCompute()) + ++NumExitCountsComputed; + else // We couldn't compute an exact value for this exit, so // we won't be able to compute an exact value for the loop. CouldComputeBECount = false; @@ -8806,9 +8638,11 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, // Exact always implies symbolic, only check symbolic. if (EL.SymbolicMaxNotTaken != getCouldNotCompute()) ExitCounts.emplace_back(ExitBB, EL); - else + else { assert(EL.ExactNotTaken == getCouldNotCompute() && "Exact is known but symbolic isn't?"); + ++NumExitCountsNotComputed; + } // 2. 
Derive the loop's MaxBECount from each exit's max number of // non-exiting iterations. Partition the loop exits into two kinds: @@ -8878,9 +8712,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) && "It should have one successor in loop and one exit block!"); // Proceed to the next level to examine the exit condition expression. - return computeExitLimitFromCond( - L, BI->getCondition(), ExitIfTrue, - /*ControlsExit=*/IsOnlyExit, AllowPredicates); + return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue, + /*ControlsOnlyExit=*/IsOnlyExit, + AllowPredicates); } if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) { @@ -8893,24 +8727,25 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, Exit = SBB; } assert(Exit && "Exiting block must have at least one exit"); - return computeExitLimitFromSingleExitSwitch(L, SI, Exit, - /*ControlsExit=*/IsOnlyExit); + return computeExitLimitFromSingleExitSwitch( + L, SI, Exit, + /*ControlsOnlyExit=*/IsOnlyExit); } return getCouldNotCompute(); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond( - const Loop *L, Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates) { + const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, + bool AllowPredicates) { ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates); return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue, - ControlsExit, AllowPredicates); + ControlsOnlyExit, AllowPredicates); } std::optional<ScalarEvolution::ExitLimit> ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond, - bool ExitIfTrue, bool ControlsExit, + bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { (void)this->L; (void)this->ExitIfTrue; @@ -8919,7 +8754,7 @@ ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond, assert(this->L == L && this->ExitIfTrue == ExitIfTrue && this->AllowPredicates == AllowPredicates && "Variance in assumed invariant key components!"); - auto Itr = TripCountMap.find({ExitCond, ControlsExit}); + auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit}); if (Itr == TripCountMap.end()) return std::nullopt; return Itr->second; @@ -8927,14 +8762,14 @@ ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond, void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, + bool ControlsOnlyExit, bool AllowPredicates, const ExitLimit &EL) { assert(this->L == L && this->ExitIfTrue == ExitIfTrue && this->AllowPredicates == AllowPredicates && "Variance in assumed invariant key components!"); - auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL}); + auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL}); assert(InsertResult.second && "Expected successful insertion!"); (void)InsertResult; (void)ExitIfTrue; @@ -8942,36 +8777,37 @@ void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond, ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates) { + bool ControlsOnlyExit, bool AllowPredicates) { - if (auto MaybeEL = - Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates)) + if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit, + AllowPredicates)) return *MaybeEL; - ExitLimit EL = 
computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue, - ControlsExit, AllowPredicates); - Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL); + ExitLimit EL = computeExitLimitFromCondImpl( + Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates); + Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL); return EL; } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates) { + bool ControlsOnlyExit, bool AllowPredicates) { // Handle BinOp conditions (And, Or). if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( - Cache, L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates)) + Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) return *LimitFromBinOp; // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) { ExitLimit EL = - computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit); + computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit); if (EL.hasFullInfo() || !AllowPredicates) return EL; // Try again, but use SCEV predicates this time. - return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit, + return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, + ControlsOnlyExit, /*AllowPredicates=*/true); } @@ -8983,9 +8819,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( if (ExitIfTrue == !CI->getZExtValue()) // The backedge is always taken. return getCouldNotCompute(); - else - // The backedge is never taken. - return getZero(CI->getType()); + // The backedge is never taken. + return getZero(CI->getType()); } // If we're exiting based on the overflow flag of an x.with.overflow intrinsic @@ -9007,8 +8842,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( if (Offset != 0) LHS = getAddExpr(LHS, getConstant(Offset)); auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC), - ControlsExit, AllowPredicates); - if (EL.hasAnyInfo()) return EL; + ControlsOnlyExit, AllowPredicates); + if (EL.hasAnyInfo()) + return EL; } // If it's not an integer or pointer comparison then compute it the hard way. @@ -9018,7 +8854,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( std::optional<ScalarEvolution::ExitLimit> ScalarEvolution::computeExitLimitFromCondFromBinOp( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates) { + bool ControlsOnlyExit, bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. 
Value *Op0, *Op1; bool IsAnd = false; @@ -9033,12 +8869,12 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( // br (and Op0 Op1), loop, exit // br (or Op0 Op1), exit, loop bool EitherMayExit = IsAnd ^ ExitIfTrue; - ExitLimit EL0 = computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue, - ControlsExit && !EitherMayExit, - AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCondCached(Cache, L, Op1, ExitIfTrue, - ControlsExit && !EitherMayExit, - AllowPredicates); + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit, + AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit, + AllowPredicates); // Be robust against unsimplified IR for the form "op i1 X, NeutralElement" const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd); @@ -9096,12 +8932,9 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( { &EL0.Predicates, &EL1.Predicates }); } -ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimitFromICmp(const Loop *L, - ICmpInst *ExitCond, - bool ExitIfTrue, - bool ControlsExit, - bool AllowPredicates) { +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( + const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, + bool AllowPredicates) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Pred; if (!ExitIfTrue) @@ -9113,9 +8946,10 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); - ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsExit, + ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, AllowPredicates); - if (EL.hasAnyInfo()) return EL; + if (EL.hasAnyInfo()) + return EL; auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue); @@ -9126,12 +8960,9 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, return computeShiftCompareExitLimit(ExitCond->getOperand(0), ExitCond->getOperand(1), L, OriginalPred); } -ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimitFromICmp(const Loop *L, - ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - bool ControlsExit, - bool AllowPredicates) { +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( + const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + bool ControlsOnlyExit, bool AllowPredicates) { // Try to evaluate any dependencies out of the loop. LHS = getSCEVAtScope(LHS, L); @@ -9145,12 +8976,10 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, Pred = ICmpInst::getSwappedPredicate(Pred); } - bool ControllingFiniteLoop = - ControlsExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L); + bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) && + loopIsFiniteByAssumption(L); // Simplify the operands before analyzing them. - (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0, - (EnableFiniteLoopControl ? ControllingFiniteLoop - : false)); + (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0); // If we have a comparison of a chrec against a constant, try to use value // ranges to answer this query. 
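In the And/Or handling earlier in this hunk, EitherMayExit captures the interesting case: "br (and Op0 Op1), loop, exit" keeps looping only while both conditions hold, so either operand on its own can end the loop, and dually for the Or form. When that holds, the loop exits at the earlier of the two per-operand limits; a simplified sketch on plain integer counts (the other case is handled much more conservatively in the real code):

#include <algorithm>
#include <cstdint>
#include <optional>

std::optional<uint64_t> combineExitLimits(uint64_t EL0, uint64_t EL1,
                                          bool IsAnd, bool ExitIfTrue) {
  bool EitherMayExit = IsAnd ^ ExitIfTrue;
  if (EitherMayExit)
    return std::min(EL0, EL1); // whichever condition trips first exits
  return std::nullopt;         // both must agree; left out of this sketch
}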
@@ -9202,9 +9031,10 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, if (isa<SCEVCouldNotCompute>(RHS)) return RHS; } - ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit, AllowPredicates); - if (EL.hasAnyInfo()) return EL; + if (EL.hasAnyInfo()) + return EL; break; } case ICmpInst::ICMP_EQ: { // while (X == Y) @@ -9223,21 +9053,40 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, if (EL.hasAnyInfo()) return EL; break; } + case ICmpInst::ICMP_SLE: + case ICmpInst::ICMP_ULE: + // Since the loop is finite, an invariant RHS cannot include the boundary + // value, otherwise it would loop forever. + if (!EnableFiniteLoopControl || !ControllingFiniteLoop || + !isLoopInvariant(RHS, L)) + break; + RHS = getAddExpr(getOne(RHS->getType()), RHS); + [[fallthrough]]; case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_ULT: { // while (X < Y) - bool IsSigned = Pred == ICmpInst::ICMP_SLT; - ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, + case ICmpInst::ICMP_ULT: { // while (X < Y) + bool IsSigned = ICmpInst::isSigned(Pred); + ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, AllowPredicates); - if (EL.hasAnyInfo()) return EL; + if (EL.hasAnyInfo()) + return EL; break; } + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: + // Since the loop is finite, an invariant RHS cannot include the boundary + // value, otherwise it would loop forever. + if (!EnableFiniteLoopControl || !ControllingFiniteLoop || + !isLoopInvariant(RHS, L)) + break; + RHS = getAddExpr(getMinusOne(RHS->getType()), RHS); + [[fallthrough]]; case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_UGT: { // while (X > Y) - bool IsSigned = Pred == ICmpInst::ICMP_SGT; - ExitLimit EL = - howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, - AllowPredicates); - if (EL.hasAnyInfo()) return EL; + case ICmpInst::ICMP_UGT: { // while (X > Y) + bool IsSigned = ICmpInst::isSigned(Pred); + ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, + AllowPredicates); + if (EL.hasAnyInfo()) + return EL; break; } default: @@ -9251,7 +9100,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, - bool ControlsExit) { + bool ControlsOnlyExit) { assert(!L->contains(ExitingBlock) && "Not an exiting block!"); // Give up if the exit is the default dest of a switch. 
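The new ICMP_SLE/ULE and SGE/UGE cases above rely on the loop being finite: a loop-invariant bound then cannot be the extreme value of the type (the exit test could never fail otherwise), so "i <= N" may be rewritten as "i < N + 1" without the +1 wrapping. A small illustration for the unsigned case (plain integers, not SCEVs):

#include <cstdint>

// Number of times the body of "for (i = Start; i <= N; ++i)" runs; the same
// count as "for (i = Start; i < N + 1; ++i)" whenever N + 1 does not wrap.
uint64_t iterationsULE(uint64_t Start, uint64_t N) {
  return Start <= N ? N + 1 - Start : 0;
}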
@@ -9264,7 +9113,7 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit); if (EL.hasAnyInfo()) return EL; @@ -9762,6 +9611,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { switch (V->getSCEVType()) { case scCouldNotCompute: case scAddRecExpr: + case scVScale: return nullptr; case scConstant: return cast<SCEVConstant>(V)->getValue(); @@ -9842,9 +9692,46 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { llvm_unreachable("Unknown SCEV kind!"); } +const SCEV * +ScalarEvolution::getWithOperands(const SCEV *S, + SmallVectorImpl<const SCEV *> &NewOps) { + switch (S->getSCEVType()) { + case scTruncate: + case scZeroExtend: + case scSignExtend: + case scPtrToInt: + return getCastExpr(S->getSCEVType(), NewOps[0], S->getType()); + case scAddRecExpr: { + auto *AddRec = cast<SCEVAddRecExpr>(S); + return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags()); + } + case scAddExpr: + return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags()); + case scMulExpr: + return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags()); + case scUDivExpr: + return getUDivExpr(NewOps[0], NewOps[1]); + case scUMaxExpr: + case scSMaxExpr: + case scUMinExpr: + case scSMinExpr: + return getMinMaxExpr(S->getSCEVType(), NewOps); + case scSequentialUMinExpr: + return getSequentialMinMaxExpr(S->getSCEVType(), NewOps); + case scConstant: + case scVScale: + case scUnknown: + return S; + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + llvm_unreachable("Unknown SCEV kind!"); +} + const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { switch (V->getSCEVType()) { case scConstant: + case scVScale: return V; case scAddRecExpr: { // If this is a loop recurrence for a loop that does not contain L, then we @@ -9923,32 +9810,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { NewOps.push_back(OpAtScope); } - switch (V->getSCEVType()) { - case scTruncate: - case scZeroExtend: - case scSignExtend: - case scPtrToInt: - return getCastExpr(V->getSCEVType(), NewOps[0], V->getType()); - case scAddExpr: - return getAddExpr(NewOps, cast<SCEVAddExpr>(V)->getNoWrapFlags()); - case scMulExpr: - return getMulExpr(NewOps, cast<SCEVMulExpr>(V)->getNoWrapFlags()); - case scUDivExpr: - return getUDivExpr(NewOps[0], NewOps[1]); - case scUMaxExpr: - case scSMaxExpr: - case scUMinExpr: - case scSMinExpr: - return getMinMaxExpr(V->getSCEVType(), NewOps); - case scSequentialUMinExpr: - return getSequentialMinMaxExpr(V->getSCEVType(), NewOps); - case scConstant: - case scAddRecExpr: - case scUnknown: - case scCouldNotCompute: - llvm_unreachable("Can not get those expressions here."); - } - llvm_unreachable("Unknown n-ary-like SCEV type!"); + return getWithOperands(V, NewOps); } } // If we got here, all operands are loop invariant. @@ -10012,17 +9874,6 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { return getSCEV(RV); } } - - // If there is a single-input Phi, evaluate it at our scope. If we can - // prove that this replacement does not break LCSSA form, use new value. 
- if (PN->getNumOperands() == 1) { - const SCEV *Input = getSCEV(PN->getOperand(0)); - const SCEV *InputAtScope = getSCEVAtScope(Input, L); - // TODO: We can generalize it using LI.replacementPreservesLCSSAForm, - // for the simplest case just support constants. - if (isa<SCEVConstant>(InputAtScope)) - return InputAtScope; - } } // Okay, this is an expression that we cannot symbolically evaluate @@ -10108,14 +9959,14 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, // // The gcd of A and N may have only one prime factor: 2. The number of // trailing zeros in A is its multiplicity - uint32_t Mult2 = A.countTrailingZeros(); + uint32_t Mult2 = A.countr_zero(); // D = 2^Mult2 // 2. Check if B is divisible by D. // // B is divisible by D if and only if the multiplicity of prime factor 2 for B // is not less than multiplicity of this prime factor for D. - if (SE.GetMinTrailingZeros(B) < Mult2) + if (SE.getMinTrailingZeros(B) < Mult2) return SE.getCouldNotCompute(); // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic @@ -10410,9 +10261,10 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth); } -ScalarEvolution::ExitLimit -ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, - bool AllowPredicates) { +ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, + const Loop *L, + bool ControlsOnlyExit, + bool AllowPredicates) { // This is only used for loops with a "x != y" exit test. The exit condition // is now expressed as a single expression, V = x-y. So the exit test is @@ -10521,7 +10373,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // compute the backedge count. In this case, the step may not divide the // distance, but we don't care because if the condition is "missed" the loop // will have undefined behavior due to wrapping. - if (ControlsExit && AddRec->hasNoSelfWrap() && + if (ControlsOnlyExit && AddRec->hasNoSelfWrap() && loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); @@ -10616,8 +10468,7 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, - unsigned Depth, - bool ControllingFiniteLoop) { + unsigned Depth) { bool Changed = false; // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or // '0 != 0'. @@ -10638,8 +10489,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, LHSC->getValue(), RHSC->getValue())->isNullValue()) return TrivialCase(false); - else - return TrivialCase(true); + return TrivialCase(true); } // Otherwise swap the operands to put the constant on the right. std::swap(LHS, RHS); @@ -10670,7 +10520,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA); if (ExactCR.isFullSet()) return TrivialCase(true); - else if (ExactCR.isEmptySet()) + if (ExactCR.isEmptySet()) return TrivialCase(false); APInt NewRHS; @@ -10746,15 +10596,10 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, } // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by - // adding or subtracting 1 from one of the operands. 
This can be done for - // one of two reasons: - // 1) The range of the RHS does not include the (signed/unsigned) boundaries - // 2) The loop is finite, with this comparison controlling the exit. Since the - // loop is finite, the bound cannot include the corresponding boundary - // (otherwise it would loop forever). + // adding or subtracting 1 from one of the operands. switch (Pred) { case ICmpInst::ICMP_SLE: - if (ControllingFiniteLoop || !getSignedRangeMax(RHS).isMaxSignedValue()) { + if (!getSignedRangeMax(RHS).isMaxSignedValue()) { RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SLT; @@ -10767,7 +10612,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, } break; case ICmpInst::ICMP_SGE: - if (ControllingFiniteLoop || !getSignedRangeMin(RHS).isMinSignedValue()) { + if (!getSignedRangeMin(RHS).isMinSignedValue()) { RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SGT; @@ -10780,7 +10625,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, } break; case ICmpInst::ICMP_ULE: - if (ControllingFiniteLoop || !getUnsignedRangeMax(RHS).isMaxValue()) { + if (!getUnsignedRangeMax(RHS).isMaxValue()) { RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, SCEV::FlagNUW); Pred = ICmpInst::ICMP_ULT; @@ -10792,7 +10637,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, } break; case ICmpInst::ICMP_UGE: - if (ControllingFiniteLoop || !getUnsignedRangeMin(RHS).isMinValue()) { + if (!getUnsignedRangeMin(RHS).isMinValue()) { RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); Pred = ICmpInst::ICMP_UGT; Changed = true; @@ -10812,8 +10657,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // Recursively simplify until we either hit a recursion limit or nothing // changes. if (Changed) - return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1, - ControllingFiniteLoop); + return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1); return Changed; } @@ -10921,7 +10765,7 @@ std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred, const SCEV *RHS) { if (isKnownPredicate(Pred, LHS, RHS)) return true; - else if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS)) + if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS)) return false; return std::nullopt; } @@ -10943,7 +10787,7 @@ ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS)) return true; - else if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), + if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), ICmpInst::getInversePredicate(Pred), LHS, RHS)) return false; @@ -11004,22 +10848,21 @@ ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS, if (!LHS->hasNoUnsignedWrap()) return std::nullopt; return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; - } else { - assert(ICmpInst::isSigned(Pred) && - "Relational predicate is either signed or unsigned!"); - if (!LHS->hasNoSignedWrap()) - return std::nullopt; + } + assert(ICmpInst::isSigned(Pred) && + "Relational predicate is either signed or unsigned!"); + if (!LHS->hasNoSignedWrap()) + return std::nullopt; - const SCEV *Step = LHS->getStepRecurrence(*this); + const SCEV *Step = LHS->getStepRecurrence(*this); - if (isKnownNonNegative(Step)) - return IsGreater ? 
MonotonicallyIncreasing : MonotonicallyDecreasing; + if (isKnownNonNegative(Step)) + return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; - if (isKnownNonPositive(Step)) - return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; + if (isKnownNonPositive(Step)) + return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; - return std::nullopt; - } + return std::nullopt; } std::optional<ScalarEvolution::LoopInvariantPredicate> @@ -11353,7 +11196,7 @@ bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { // No need to even try if we know the module has no guards. - if (AC.assumptions().empty()) + if (!HasGuards) return false; return any_of(*BB, [&](const Instruction &I) { @@ -11563,6 +11406,15 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, return true; } + // Check conditions due to any @llvm.experimental.guard intrinsics. + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + if (GuardDecl) + for (const auto *GU : GuardDecl->users()) + if (const auto *Guard = dyn_cast<IntrinsicInst>(GU)) + if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB)) + if (ProveViaCond(Guard->getArgOperand(0), false)) + return true; return false; } @@ -12731,7 +12583,7 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, ScalarEvolution::ExitLimit ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool ControlsExit, bool AllowPredicates) { + bool ControlsOnlyExit, bool AllowPredicates) { SmallPtrSet<const SCEVPredicate *, 4> Predicates; const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); @@ -12759,7 +12611,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!StrideC || !StrideC->getAPInt().isPowerOf2()) return false; - if (!ControlsExit || !loopHasNoAbnormalExits(L)) + if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L)) return false; return loopIsFiniteByAssumption(L); @@ -12834,7 +12686,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // implicit/exceptional) which causes the loop to execute before the // exiting instruction we're analyzing would trigger UB. auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW; - bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType); + bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType); ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; const SCEV *Stride = IV->getStepRecurrence(*this); @@ -13154,10 +13006,9 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, Predicates); } -ScalarEvolution::ExitLimit -ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, - const Loop *L, bool IsSigned, - bool ControlsExit, bool AllowPredicates) { +ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans( + const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, + bool ControlsOnlyExit, bool AllowPredicates) { SmallPtrSet<const SCEVPredicate *, 4> Predicates; // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) @@ -13175,7 +13026,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, return getCouldNotCompute(); auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW; - bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType); + bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType); ICmpInst::Predicate Cond = IsSigned ? 
ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); @@ -13435,16 +13286,30 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, LoopInfo &LI) : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64), - LoopDispositions(64), BlockDispositions(64) {} + LoopDispositions(64), BlockDispositions(64) { + // To use guards for proving predicates, we need to scan every instruction in + // relevant basic blocks, and not just terminators. Doing this is a waste of + // time if the IR does not actually contain any calls to + // @llvm.experimental.guard, so do a quick check and remember this beforehand. + // + // This pessimizes the case where a pass that preserves ScalarEvolution wants + // to _add_ guards to the module when there weren't any before, and wants + // ScalarEvolution to optimize based on those guards. For now we prefer to be + // efficient in lieu of being smart in that rather obscure case. + + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + HasGuards = GuardDecl && !GuardDecl->use_empty(); +} ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) - : F(Arg.F), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI), - CouldNotCompute(std::move(Arg.CouldNotCompute)), + : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), + LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), ValueExprMap(std::move(Arg.ValueExprMap)), PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), PendingPhiRanges(std::move(Arg.PendingPhiRanges)), PendingMerges(std::move(Arg.PendingMerges)), - MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)), + ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)), BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), PredicatedBackedgeTakenCounts( std::move(Arg.PredicatedBackedgeTakenCounts)), @@ -13580,16 +13445,36 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, } } -static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { +namespace llvm { +raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) { switch (LD) { case ScalarEvolution::LoopVariant: - return "Variant"; + OS << "Variant"; + break; case ScalarEvolution::LoopInvariant: - return "Invariant"; + OS << "Invariant"; + break; case ScalarEvolution::LoopComputable: - return "Computable"; + OS << "Computable"; + break; } - llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!"); + return OS; +} + +raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) { + switch (BD) { + case ScalarEvolution::DoesNotDominateBlock: + OS << "DoesNotDominate"; + break; + case ScalarEvolution::DominatesBlock: + OS << "Dominates"; + break; + case ScalarEvolution::ProperlyDominatesBlock: + OS << "ProperlyDominates"; + break; + } + return OS; +} } void ScalarEvolution::print(raw_ostream &OS) const { @@ -13651,7 +13536,7 @@ void ScalarEvolution::print(raw_ostream &OS) const { } Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false); - OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter)); + OS << ": " << SE.getLoopDisposition(SV, Iter); } for (const auto *InnerL : depth_first(L)) { @@ -13665,7 +13550,7 @@ void ScalarEvolution::print(raw_ostream &OS) const { } InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false); - OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL)); + OS << ": " << 
SE.getLoopDisposition(SV, InnerL); } OS << " }"; @@ -13705,6 +13590,7 @@ ScalarEvolution::LoopDisposition ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { switch (S->getSCEVType()) { case scConstant: + case scVScale: return LoopInvariant; case scAddRecExpr: { const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); @@ -13803,6 +13689,7 @@ ScalarEvolution::BlockDisposition ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { switch (S->getSCEVType()) { case scConstant: + case scVScale: return ProperlyDominatesBlock; case scAddRecExpr: { // This uses a "dominates" query instead of "properly dominates" query @@ -13917,7 +13804,7 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) { UnsignedRanges.erase(S); SignedRanges.erase(S); HasRecMap.erase(S); - MinTrailingZerosCache.erase(S); + ConstantMultipleCache.erase(S); if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) { UnsignedWrapViaInductionTried.erase(AR); @@ -14249,9 +14136,8 @@ void ScalarEvolution::verify() const { const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop); if (CachedDisposition != RecomputedDisposition) { dbgs() << "Cached disposition of " << *S << " for loop " << *Loop - << " is incorrect: cached " - << loopDispositionToStr(CachedDisposition) << ", actual " - << loopDispositionToStr(RecomputedDisposition) << "\n"; + << " is incorrect: cached " << CachedDisposition << ", actual " + << RecomputedDisposition << "\n"; std::abort(); } } @@ -14263,7 +14149,8 @@ void ScalarEvolution::verify() const { const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB); if (CachedDisposition != RecomputedDisposition) { dbgs() << "Cached disposition of " << *S << " for block %" - << BB->getName() << " is incorrect! \n"; + << BB->getName() << " is incorrect: cached " << CachedDisposition + << ", actual " << RecomputedDisposition << "\n"; std::abort(); } } @@ -14297,6 +14184,23 @@ void ScalarEvolution::verify() const { } } } + + // Verify that ConstantMultipleCache computations are correct. We check that + // cached multiples and recomputed multiples are multiples of each other to + // verify correctness. It is possible that a recomputed multiple is different + // from the cached multiple due to strengthened no wrap flags or changes in + // KnownBits computations. 
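The comment above describes the consistency test that the lines directly below implement for ConstantMultipleCache: a cached and a recomputed multiple are accepted as long as one divides the other, since they may legitimately differ once wrap flags or known bits have been strengthened. On plain integers the test amounts to:

#include <cstdint>

bool multiplesConsistent(uint64_t Cached, uint64_t Recomputed) {
  if (Cached == 0 || Recomputed == 0)
    return true;
  return Cached % Recomputed == 0 || Recomputed % Cached == 0;
}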
+ for (auto [S, Multiple] : ConstantMultipleCache) { + APInt RecomputedMultiple = SE2.getConstantMultiple(S); + if ((Multiple != 0 && RecomputedMultiple != 0 && + Multiple.urem(RecomputedMultiple) != 0 && + RecomputedMultiple.urem(Multiple) != 0)) { + dbgs() << "Incorrect cached computation in ConstantMultipleCache for " + << *S << " : Computed " << RecomputedMultiple + << " but cache contains " << Multiple << "!\n"; + std::abort(); + } + } } bool ScalarEvolution::invalidate( @@ -14315,10 +14219,11 @@ AnalysisKey ScalarEvolutionAnalysis::Key; ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, FunctionAnalysisManager &AM) { - return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F), - AM.getResult<AssumptionAnalysis>(F), - AM.getResult<DominatorTreeAnalysis>(F), - AM.getResult<LoopAnalysis>(F)); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + return ScalarEvolution(F, TLI, AC, DT, LI); } PreservedAnalyses @@ -14603,8 +14508,7 @@ void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const { if (Pred == ICmpInst::ICMP_EQ) OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; else - OS.indent(Depth) << "Compare predicate: " << *LHS - << " " << CmpInst::getPredicateName(Pred) << ") " + OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") " << *RHS << "\n"; } @@ -14933,9 +14837,6 @@ ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) { /// A rewriter to replace SCEV expressions in Map with the corresponding entry /// in the map. It skips AddRecExpr because we cannot guarantee that the /// replacement is loop invariant in the loop of the AddRec. -/// -/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is -/// supported. class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> { const DenseMap<const SCEV *, const SCEV *> ⤅ @@ -14955,9 +14856,47 @@ public: const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { auto I = Map.find(Expr); - if (I == Map.end()) + if (I == Map.end()) { + // If we didn't find the extact ZExt expr in the map, check if there's an + // entry for a smaller ZExt we can use instead. 
+ Type *Ty = Expr->getType(); + const SCEV *Op = Expr->getOperand(0); + unsigned Bitwidth = Ty->getScalarSizeInBits() / 2; + while (Bitwidth % 8 == 0 && Bitwidth >= 8 && + Bitwidth > Op->getType()->getScalarSizeInBits()) { + Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth); + auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy); + auto I = Map.find(NarrowExt); + if (I != Map.end()) + return SE.getZeroExtendExpr(I->second, Ty); + Bitwidth = Bitwidth / 2; + } + return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr( Expr); + } + return I->second; + } + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + auto I = Map.find(Expr); + if (I == Map.end()) + return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr( + Expr); + return I->second; + } + + const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { + auto I = Map.find(Expr); + if (I == Map.end()) + return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr); + return I->second; + } + + const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { + auto I = Map.find(Expr); + if (I == Map.end()) + return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr); return I->second; } }; @@ -15012,6 +14951,93 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { if (MatchRangeCheckIdiom()) return; + // Return true if \p Expr is a MinMax SCEV expression with a non-negative + // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS + // the non-constant operand and in \p LHS the constant operand. + auto IsMinMaxSCEVWithNonNegativeConstant = + [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, + const SCEV *&RHS) { + if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) { + if (MinMax->getNumOperands() != 2) + return false; + if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) { + if (C->getAPInt().isNegative()) + return false; + SCTy = MinMax->getSCEVType(); + LHS = MinMax->getOperand(0); + RHS = MinMax->getOperand(1); + return true; + } + } + return false; + }; + + // Checks whether Expr is a non-negative constant, and Divisor is a positive + // constant, and returns their APInt in ExprVal and in DivisorVal. + auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor, + APInt &ExprVal, APInt &DivisorVal) { + auto *ConstExpr = dyn_cast<SCEVConstant>(Expr); + auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor); + if (!ConstExpr || !ConstDivisor) + return false; + ExprVal = ConstExpr->getAPInt(); + DivisorVal = ConstDivisor->getAPInt(); + return ExprVal.isNonNegative() && !DivisorVal.isNonPositive(); + }; + + // Return a new SCEV that modifies \p Expr to the closest number divides by + // \p Divisor and greater or equal than Expr. + // For now, only handle constant Expr and Divisor. + auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, + const SCEV *Divisor) { + APInt ExprVal; + APInt DivisorVal; + if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) + return Expr; + APInt Rem = ExprVal.urem(DivisorVal); + if (!Rem.isZero()) + // return the SCEV: Expr + Divisor - Expr % Divisor + return getConstant(ExprVal + DivisorVal - Rem); + return Expr; + }; + + // Return a new SCEV that modifies \p Expr to the closest number divides by + // \p Divisor and less or equal than Expr. + // For now, only handle constant Expr and Divisor. 
+ auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, + const SCEV *Divisor) { + APInt ExprVal; + APInt DivisorVal; + if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) + return Expr; + APInt Rem = ExprVal.urem(DivisorVal); + // return the SCEV: Expr - Expr % Divisor + return getConstant(ExprVal - Rem); + }; + + // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, + // recursively. This is done by aligning up/down the constant value to the + // Divisor. + std::function<const SCEV *(const SCEV *, const SCEV *)> + ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr, + const SCEV *Divisor) { + const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; + SCEVTypes SCTy; + if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS, + MinMaxRHS)) + return MinMaxExpr; + auto IsMin = + isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr); + assert(isKnownNonNegative(MinMaxLHS) && + "Expected non-negative operand!"); + auto *DivisibleExpr = + IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor) + : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor); + SmallVector<const SCEV *> Ops = { + ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; + return getMinMaxExpr(SCTy, Ops); + }; + // If we have LHS == 0, check if LHS is computing a property of some unknown // SCEV %v which we can rewrite %v to express explicitly. const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS); @@ -15023,7 +15049,12 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { const SCEV *URemRHS = nullptr; if (matchURem(LHS, URemLHS, URemRHS)) { if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) { - const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS); + auto I = RewriteMap.find(LHSUnknown); + const SCEV *RewrittenLHS = + I != RewriteMap.end() ? I->second : LHSUnknown; + RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); + const auto *Multiple = + getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; ExprsToRewrite.push_back(LHSUnknown); return; @@ -15041,62 +15072,170 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { Predicate = CmpInst::getSwappedPredicate(Predicate); } - // Limit to expressions that can be rewritten. - if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS)) - return; + // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From + // and \p FromRewritten are the same (i.e. there has been no rewrite + // registered for \p From), then puts this value in the list of rewritten + // expressions. + auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten, + const SCEV *To) { + if (From == FromRewritten) + ExprsToRewrite.push_back(From); + RewriteMap[From] = To; + }; + + // Checks whether \p S has already been rewritten. In that case returns the + // existing rewrite because we want to chain further rewrites onto the + // already rewritten value. Otherwise returns \p S. + auto GetMaybeRewritten = [&](const SCEV *S) { + auto I = RewriteMap.find(S); + return I != RewriteMap.end() ? I->second : S; + }; - // Check whether LHS has already been rewritten. In that case we want to - // chain further rewrites onto the already rewritten value. - auto I = RewriteMap.find(LHS); - const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS; + // Check for the SCEV expression (A /u B) * B while B is a constant, inside + // \p Expr. 
The check is done recuresively on \p Expr, which is assumed to + // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A + // /u B) * B was found, and return the divisor B in \p DividesBy. For + // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since + // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p + // DividesBy. + std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo = + [&](const SCEV *Expr, const SCEV *&DividesBy) { + if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) { + if (Mul->getNumOperands() != 2) + return false; + auto *MulLHS = Mul->getOperand(0); + auto *MulRHS = Mul->getOperand(1); + if (isa<SCEVConstant>(MulLHS)) + std::swap(MulLHS, MulRHS); + if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS)) + if (Div->getOperand(1) == MulRHS) { + DividesBy = MulRHS; + return true; + } + } + if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) + return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) || + HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy); + return false; + }; - const SCEV *RewrittenRHS = nullptr; + // Return true if Expr known to divide by \p DividesBy. + std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy = + [&](const SCEV *Expr, const SCEV *DividesBy) { + if (getURemExpr(Expr, DividesBy)->isZero()) + return true; + if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) + return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) && + IsKnownToDivideBy(MinMax->getOperand(1), DividesBy); + return false; + }; + + const SCEV *RewrittenLHS = GetMaybeRewritten(LHS); + const SCEV *DividesBy = nullptr; + if (HasDivisibiltyInfo(RewrittenLHS, DividesBy)) + // Check that the whole expression is divided by DividesBy + DividesBy = + IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr; + + // Collect rewrites for LHS and its transitive operands based on the + // condition. + // For min/max expressions, also apply the guard to its operands: + // 'min(a, b) >= c' -> '(a >= c) and (b >= c)', + // 'min(a, b) > c' -> '(a > c) and (b > c)', + // 'max(a, b) <= c' -> '(a <= c) and (b <= c)', + // 'max(a, b) < c' -> '(a < c) and (b < c)'. + + // We cannot express strict predicates in SCEV, so instead we replace them + // with non-strict ones against plus or minus one of RHS depending on the + // predicate. 
+ const SCEV *One = getOne(RHS->getType()); switch (Predicate) { - case CmpInst::ICMP_ULT: - RewrittenRHS = - getUMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType()))); - break; - case CmpInst::ICMP_SLT: - RewrittenRHS = - getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType()))); - break; - case CmpInst::ICMP_ULE: - RewrittenRHS = getUMinExpr(RewrittenLHS, RHS); - break; - case CmpInst::ICMP_SLE: - RewrittenRHS = getSMinExpr(RewrittenLHS, RHS); - break; - case CmpInst::ICMP_UGT: - RewrittenRHS = - getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); - break; - case CmpInst::ICMP_SGT: - RewrittenRHS = - getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); - break; - case CmpInst::ICMP_UGE: - RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS); - break; - case CmpInst::ICMP_SGE: - RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS); - break; - case CmpInst::ICMP_EQ: - if (isa<SCEVConstant>(RHS)) - RewrittenRHS = RHS; - break; - case CmpInst::ICMP_NE: - if (isa<SCEVConstant>(RHS) && - cast<SCEVConstant>(RHS)->getValue()->isNullValue()) - RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType())); - break; - default: - break; + case CmpInst::ICMP_ULT: + if (RHS->getType()->isPointerTy()) + return; + RHS = getUMaxExpr(RHS, One); + [[fallthrough]]; + case CmpInst::ICMP_SLT: { + RHS = getMinusSCEV(RHS, One); + RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + break; + } + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_SGT: + RHS = getAddExpr(RHS, One); + RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + break; + case CmpInst::ICMP_ULE: + case CmpInst::ICMP_SLE: + RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + break; + case CmpInst::ICMP_UGE: + case CmpInst::ICMP_SGE: + RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + break; + default: + break; } - if (RewrittenRHS) { - RewriteMap[LHS] = RewrittenRHS; - if (LHS == RewrittenLHS) - ExprsToRewrite.push_back(LHS); + SmallVector<const SCEV *, 16> Worklist(1, LHS); + SmallPtrSet<const SCEV *, 16> Visited; + + auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) { + append_range(Worklist, S->operands()); + }; + + while (!Worklist.empty()) { + const SCEV *From = Worklist.pop_back_val(); + if (isa<SCEVConstant>(From)) + continue; + if (!Visited.insert(From).second) + continue; + const SCEV *FromRewritten = GetMaybeRewritten(From); + const SCEV *To = nullptr; + + switch (Predicate) { + case CmpInst::ICMP_ULT: + case CmpInst::ICMP_ULE: + To = getUMinExpr(FromRewritten, RHS); + if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten)) + EnqueueOperands(UMax); + break; + case CmpInst::ICMP_SLT: + case CmpInst::ICMP_SLE: + To = getSMinExpr(FromRewritten, RHS); + if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten)) + EnqueueOperands(SMax); + break; + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_UGE: + To = getUMaxExpr(FromRewritten, RHS); + if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten)) + EnqueueOperands(UMin); + break; + case CmpInst::ICMP_SGT: + case CmpInst::ICMP_SGE: + To = getSMaxExpr(FromRewritten, RHS); + if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten)) + EnqueueOperands(SMin); + break; + case CmpInst::ICMP_EQ: + if (isa<SCEVConstant>(RHS)) + To = RHS; + break; + case CmpInst::ICMP_NE: + if (isa<SCEVConstant>(RHS) && + cast<SCEVConstant>(RHS)->getValue()->isNullValue()) { + const SCEV *OneAlignedUp = + DividesBy ? 
GetNextSCEVDividesByDivisor(One, DividesBy) : One; + To = getUMaxExpr(FromRewritten, OneAlignedUp); + } + break; + default: + break; + } + + if (To) + AddRewrite(From, FromRewritten, To); } }; @@ -15112,7 +15251,16 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { Terms.emplace_back(AssumeI->getOperand(0), true); } - // Second, collect conditions from dominating branches. Starting at the loop + // Second, collect information from llvm.experimental.guards dominating the loop. + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + if (GuardDecl) + for (const auto *GU : GuardDecl->users()) + if (const auto *Guard = dyn_cast<IntrinsicInst>(GU)) + if (Guard->getFunction() == Header->getParent() && DT.dominates(Guard, Header)) + Terms.emplace_back(Guard->getArgOperand(0), true); + + // Third, collect conditions from dominating branches. Starting at the loop // predecessor, climb up the predecessor chain, as long as there are // predecessors that can be found that have unique successors leading to the // original header. diff --git a/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionDivision.cpp b/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionDivision.cpp index 0619569bf816..e1dd834cfb10 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionDivision.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionDivision.cpp @@ -126,6 +126,10 @@ void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { } } +void SCEVDivision::visitVScale(const SCEVVScale *Numerator) { + return cannotDivide(Numerator); +} + void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { const SCEV *StartQ, *StartR, *StepQ, *StepR; if (!Numerator->isAffine()) diff --git a/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp b/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp index 22dff5efec5c..cfc5b8455454 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -96,11 +96,20 @@ NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { const SCEV *llvm::normalizeForPostIncUse(const SCEV *S, const PostIncLoopSet &Loops, - ScalarEvolution &SE) { + ScalarEvolution &SE, + bool CheckInvertible) { + if (Loops.empty()) + return S; auto Pred = [&](const SCEVAddRecExpr *AR) { return Loops.count(AR->getLoop()); }; - return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); + const SCEV *Normalized = + NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); + const SCEV *Denormalized = denormalizeForPostIncUse(Normalized, Loops, SE); + // If the normalized expression isn't invertible. 
+ if (CheckInvertible && Denormalized != S) + return nullptr; + return Normalized; } const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred, @@ -111,6 +120,8 @@ const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred, const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S, const PostIncLoopSet &Loops, ScalarEvolution &SE) { + if (Loops.empty()) + return S; auto Pred = [&](const SCEVAddRecExpr *AR) { return Loops.count(AR->getLoop()); }; diff --git a/contrib/llvm-project/llvm/lib/Analysis/StackLifetime.cpp b/contrib/llvm-project/llvm/lib/Analysis/StackLifetime.cpp index ee77e81fc978..3e1b5dea6f6c 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/StackLifetime.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/StackLifetime.cpp @@ -39,7 +39,7 @@ StackLifetime::getLiveRange(const AllocaInst *AI) const { } bool StackLifetime::isReachable(const Instruction *I) const { - return BlockInstRange.find(I->getParent()) != BlockInstRange.end(); + return BlockInstRange.contains(I->getParent()); } bool StackLifetime::isAliveAfter(const AllocaInst *AI, @@ -414,7 +414,7 @@ void StackLifetimePrinterPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<StackLifetimePrinterPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; switch (Type) { case StackLifetime::LivenessType::May: OS << "may"; @@ -423,5 +423,5 @@ void StackLifetimePrinterPass::printPipeline( OS << "must"; break; } - OS << ">"; + OS << '>'; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/StratifiedSets.h b/contrib/llvm-project/llvm/lib/Analysis/StratifiedSets.h deleted file mode 100644 index 193e4a461e66..000000000000 --- a/contrib/llvm-project/llvm/lib/Analysis/StratifiedSets.h +++ /dev/null @@ -1,595 +0,0 @@ -//===- StratifiedSets.h - Abstract stratified sets implementation. --------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_ADT_STRATIFIEDSETS_H -#define LLVM_ADT_STRATIFIEDSETS_H - -#include "AliasAnalysisSummary.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/SmallVector.h" -#include <bitset> -#include <cassert> -#include <cmath> -#include <type_traits> -#include <utility> -#include <vector> - -namespace llvm { -namespace cflaa { -/// An index into Stratified Sets. -typedef unsigned StratifiedIndex; -/// NOTE: ^ This can't be a short -- bootstrapping clang has a case where -/// ~1M sets exist. - -// Container of information related to a value in a StratifiedSet. -struct StratifiedInfo { - StratifiedIndex Index; - /// For field sensitivity, etc. we can tack fields on here. -}; - -/// A "link" between two StratifiedSets. -struct StratifiedLink { - /// This is a value used to signify "does not exist" where the - /// StratifiedIndex type is used. - /// - /// This is used instead of std::optional<StratifiedIndex> because - /// std::optional<StratifiedIndex> would eat up a considerable amount of extra - /// memory, after struct padding/alignment is taken into account. 
- static const StratifiedIndex SetSentinel; - - /// The index for the set "above" current - StratifiedIndex Above; - - /// The link for the set "below" current - StratifiedIndex Below; - - /// Attributes for these StratifiedSets. - AliasAttrs Attrs; - - StratifiedLink() : Above(SetSentinel), Below(SetSentinel) {} - - bool hasBelow() const { return Below != SetSentinel; } - bool hasAbove() const { return Above != SetSentinel; } - - void clearBelow() { Below = SetSentinel; } - void clearAbove() { Above = SetSentinel; } -}; - -/// These are stratified sets, as described in "Fast algorithms for -/// Dyck-CFL-reachability with applications to Alias Analysis" by Zhang Q, Lyu M -/// R, Yuan H, and Su Z. -- in short, this is meant to represent different sets -/// of Value*s. If two Value*s are in the same set, or if both sets have -/// overlapping attributes, then the Value*s are said to alias. -/// -/// Sets may be related by position, meaning that one set may be considered as -/// above or below another. In CFL Alias Analysis, this gives us an indication -/// of how two variables are related; if the set of variable A is below a set -/// containing variable B, then at some point, a variable that has interacted -/// with B (or B itself) was either used in order to extract the variable A, or -/// was used as storage of variable A. -/// -/// Sets may also have attributes (as noted above). These attributes are -/// generally used for noting whether a variable in the set has interacted with -/// a variable whose origins we don't quite know (i.e. globals/arguments), or if -/// the variable may have had operations performed on it (modified in a function -/// call). All attributes that exist in a set A must exist in all sets marked as -/// below set A. -template <typename T> class StratifiedSets { -public: - StratifiedSets() = default; - StratifiedSets(StratifiedSets &&) = default; - StratifiedSets &operator=(StratifiedSets &&) = default; - - StratifiedSets(DenseMap<T, StratifiedInfo> Map, - std::vector<StratifiedLink> Links) - : Values(std::move(Map)), Links(std::move(Links)) {} - - std::optional<StratifiedInfo> find(const T &Elem) const { - auto Iter = Values.find(Elem); - if (Iter == Values.end()) - return std::nullopt; - return Iter->second; - } - - const StratifiedLink &getLink(StratifiedIndex Index) const { - assert(inbounds(Index)); - return Links[Index]; - } - -private: - DenseMap<T, StratifiedInfo> Values; - std::vector<StratifiedLink> Links; - - bool inbounds(StratifiedIndex Idx) const { return Idx < Links.size(); } -}; - -/// Generic Builder class that produces StratifiedSets instances. -/// -/// The goal of this builder is to efficiently produce correct StratifiedSets -/// instances. To this end, we use a few tricks: -/// > Set chains (A method for linking sets together) -/// > Set remaps (A method for marking a set as an alias [irony?] of another) -/// -/// ==== Set chains ==== -/// This builder has a notion of some value A being above, below, or with some -/// other value B: -/// > The `A above B` relationship implies that there is a reference edge -/// going from A to B. Namely, it notes that A can store anything in B's set. -/// > The `A below B` relationship is the opposite of `A above B`. It implies -/// that there's a dereference edge going from A to B. -/// > The `A with B` relationship states that there's an assignment edge going -/// from A to B, and that A and B should be treated as equals. 
-/// -/// As an example, take the following code snippet: -/// -/// %a = alloca i32, align 4 -/// %ap = alloca i32*, align 8 -/// %app = alloca i32**, align 8 -/// store %a, %ap -/// store %ap, %app -/// %aw = getelementptr %ap, i32 0 -/// -/// Given this, the following relations exist: -/// - %a below %ap & %ap above %a -/// - %ap below %app & %app above %ap -/// - %aw with %ap & %ap with %aw -/// -/// These relations produce the following sets: -/// [{%a}, {%ap, %aw}, {%app}] -/// -/// ...Which state that the only MayAlias relationship in the above program is -/// between %ap and %aw. -/// -/// Because LLVM allows arbitrary casts, code like the following needs to be -/// supported: -/// %ip = alloca i64, align 8 -/// %ipp = alloca i64*, align 8 -/// %i = bitcast i64** ipp to i64 -/// store i64* %ip, i64** %ipp -/// store i64 %i, i64* %ip -/// -/// Which, because %ipp ends up *both* above and below %ip, is fun. -/// -/// This is solved by merging %i and %ipp into a single set (...which is the -/// only way to solve this, since their bit patterns are equivalent). Any sets -/// that ended up in between %i and %ipp at the time of merging (in this case, -/// the set containing %ip) also get conservatively merged into the set of %i -/// and %ipp. In short, the resulting StratifiedSet from the above code would be -/// {%ip, %ipp, %i}. -/// -/// ==== Set remaps ==== -/// More of an implementation detail than anything -- when merging sets, we need -/// to update the numbers of all of the elements mapped to those sets. Rather -/// than doing this at each merge, we note in the BuilderLink structure that a -/// remap has occurred, and use this information so we can defer renumbering set -/// elements until build time. -template <typename T> class StratifiedSetsBuilder { - /// Represents a Stratified Set, with information about the Stratified - /// Set above it, the set below it, and whether the current set has been - /// remapped to another. - struct BuilderLink { - const StratifiedIndex Number; - - BuilderLink(StratifiedIndex N) : Number(N) { - Remap = StratifiedLink::SetSentinel; - } - - bool hasAbove() const { - assert(!isRemapped()); - return Link.hasAbove(); - } - - bool hasBelow() const { - assert(!isRemapped()); - return Link.hasBelow(); - } - - void setBelow(StratifiedIndex I) { - assert(!isRemapped()); - Link.Below = I; - } - - void setAbove(StratifiedIndex I) { - assert(!isRemapped()); - Link.Above = I; - } - - void clearBelow() { - assert(!isRemapped()); - Link.clearBelow(); - } - - void clearAbove() { - assert(!isRemapped()); - Link.clearAbove(); - } - - StratifiedIndex getBelow() const { - assert(!isRemapped()); - assert(hasBelow()); - return Link.Below; - } - - StratifiedIndex getAbove() const { - assert(!isRemapped()); - assert(hasAbove()); - return Link.Above; - } - - AliasAttrs getAttrs() { - assert(!isRemapped()); - return Link.Attrs; - } - - void setAttrs(AliasAttrs Other) { - assert(!isRemapped()); - Link.Attrs |= Other; - } - - bool isRemapped() const { return Remap != StratifiedLink::SetSentinel; } - - /// For initial remapping to another set - void remapTo(StratifiedIndex Other) { - assert(!isRemapped()); - Remap = Other; - } - - StratifiedIndex getRemapIndex() const { - assert(isRemapped()); - return Remap; - } - - /// Should only be called when we're already remapped. 
- void updateRemap(StratifiedIndex Other) { - assert(isRemapped()); - Remap = Other; - } - - /// Prefer the above functions to calling things directly on what's returned - /// from this -- they guard against unexpected calls when the current - /// BuilderLink is remapped. - const StratifiedLink &getLink() const { return Link; } - - private: - StratifiedLink Link; - StratifiedIndex Remap; - }; - - /// This function performs all of the set unioning/value renumbering - /// that we've been putting off, and generates a vector<StratifiedLink> that - /// may be placed in a StratifiedSets instance. - void finalizeSets(std::vector<StratifiedLink> &StratLinks) { - DenseMap<StratifiedIndex, StratifiedIndex> Remaps; - for (auto &Link : Links) { - if (Link.isRemapped()) - continue; - - StratifiedIndex Number = StratLinks.size(); - Remaps.insert(std::make_pair(Link.Number, Number)); - StratLinks.push_back(Link.getLink()); - } - - for (auto &Link : StratLinks) { - if (Link.hasAbove()) { - auto &Above = linksAt(Link.Above); - auto Iter = Remaps.find(Above.Number); - assert(Iter != Remaps.end()); - Link.Above = Iter->second; - } - - if (Link.hasBelow()) { - auto &Below = linksAt(Link.Below); - auto Iter = Remaps.find(Below.Number); - assert(Iter != Remaps.end()); - Link.Below = Iter->second; - } - } - - for (auto &Pair : Values) { - auto &Info = Pair.second; - auto &Link = linksAt(Info.Index); - auto Iter = Remaps.find(Link.Number); - assert(Iter != Remaps.end()); - Info.Index = Iter->second; - } - } - - /// There's a guarantee in StratifiedLink where all bits set in a - /// Link.externals will be set in all Link.externals "below" it. - static void propagateAttrs(std::vector<StratifiedLink> &Links) { - const auto getHighestParentAbove = [&Links](StratifiedIndex Idx) { - const auto *Link = &Links[Idx]; - while (Link->hasAbove()) { - Idx = Link->Above; - Link = &Links[Idx]; - } - return Idx; - }; - - SmallSet<StratifiedIndex, 16> Visited; - for (unsigned I = 0, E = Links.size(); I < E; ++I) { - auto CurrentIndex = getHighestParentAbove(I); - if (!Visited.insert(CurrentIndex).second) - continue; - - while (Links[CurrentIndex].hasBelow()) { - auto &CurrentBits = Links[CurrentIndex].Attrs; - auto NextIndex = Links[CurrentIndex].Below; - auto &NextBits = Links[NextIndex].Attrs; - NextBits |= CurrentBits; - CurrentIndex = NextIndex; - } - } - } - -public: - /// Builds a StratifiedSet from the information we've been given since either - /// construction or the prior build() call. - StratifiedSets<T> build() { - std::vector<StratifiedLink> StratLinks; - finalizeSets(StratLinks); - propagateAttrs(StratLinks); - Links.clear(); - return StratifiedSets<T>(std::move(Values), std::move(StratLinks)); - } - - bool has(const T &Elem) const { return get(Elem).has_value(); } - - bool add(const T &Main) { - if (get(Main)) - return false; - - auto NewIndex = getNewUnlinkedIndex(); - return addAtMerging(Main, NewIndex); - } - - /// Restructures the stratified sets as necessary to make "ToAdd" in a - /// set above "Main". There are some cases where this is not possible (see - /// above), so we merge them such that ToAdd and Main are in the same set. - bool addAbove(const T &Main, const T &ToAdd) { - assert(has(Main)); - auto Index = *indexOf(Main); - if (!linksAt(Index).hasAbove()) - addLinkAbove(Index); - - auto Above = linksAt(Index).getAbove(); - return addAtMerging(ToAdd, Above); - } - - /// Restructures the stratified sets as necessary to make "ToAdd" in a - /// set below "Main". 
There are some cases where this is not possible (see - /// above), so we merge them such that ToAdd and Main are in the same set. - bool addBelow(const T &Main, const T &ToAdd) { - assert(has(Main)); - auto Index = *indexOf(Main); - if (!linksAt(Index).hasBelow()) - addLinkBelow(Index); - - auto Below = linksAt(Index).getBelow(); - return addAtMerging(ToAdd, Below); - } - - bool addWith(const T &Main, const T &ToAdd) { - assert(has(Main)); - auto MainIndex = *indexOf(Main); - return addAtMerging(ToAdd, MainIndex); - } - - void noteAttributes(const T &Main, AliasAttrs NewAttrs) { - assert(has(Main)); - auto *Info = *get(Main); - auto &Link = linksAt(Info->Index); - Link.setAttrs(NewAttrs); - } - -private: - DenseMap<T, StratifiedInfo> Values; - std::vector<BuilderLink> Links; - - /// Adds the given element at the given index, merging sets if necessary. - bool addAtMerging(const T &ToAdd, StratifiedIndex Index) { - StratifiedInfo Info = {Index}; - auto Pair = Values.insert(std::make_pair(ToAdd, Info)); - if (Pair.second) - return true; - - auto &Iter = Pair.first; - auto &IterSet = linksAt(Iter->second.Index); - auto &ReqSet = linksAt(Index); - - // Failed to add where we wanted to. Merge the sets. - if (&IterSet != &ReqSet) - merge(IterSet.Number, ReqSet.Number); - - return false; - } - - /// Gets the BuilderLink at the given index, taking set remapping into - /// account. - BuilderLink &linksAt(StratifiedIndex Index) { - auto *Start = &Links[Index]; - if (!Start->isRemapped()) - return *Start; - - auto *Current = Start; - while (Current->isRemapped()) - Current = &Links[Current->getRemapIndex()]; - - auto NewRemap = Current->Number; - - // Run through everything that has yet to be updated, and update them to - // remap to NewRemap - Current = Start; - while (Current->isRemapped()) { - auto *Next = &Links[Current->getRemapIndex()]; - Current->updateRemap(NewRemap); - Current = Next; - } - - return *Current; - } - - /// Merges two sets into one another. Assumes that these sets are not - /// already one in the same. - void merge(StratifiedIndex Idx1, StratifiedIndex Idx2) { - assert(inbounds(Idx1) && inbounds(Idx2)); - assert(&linksAt(Idx1) != &linksAt(Idx2) && - "Merging a set into itself is not allowed"); - - // CASE 1: If the set at `Idx1` is above or below `Idx2`, we need to merge - // both the - // given sets, and all sets between them, into one. - if (tryMergeUpwards(Idx1, Idx2)) - return; - - if (tryMergeUpwards(Idx2, Idx1)) - return; - - // CASE 2: The set at `Idx1` is not in the same chain as the set at `Idx2`. - // We therefore need to merge the two chains together. - mergeDirect(Idx1, Idx2); - } - - /// Merges two sets assuming that the set at `Idx1` is unreachable from - /// traversing above or below the set at `Idx2`. - void mergeDirect(StratifiedIndex Idx1, StratifiedIndex Idx2) { - assert(inbounds(Idx1) && inbounds(Idx2)); - - auto *LinksInto = &linksAt(Idx1); - auto *LinksFrom = &linksAt(Idx2); - // Merging everything above LinksInto then proceeding to merge everything - // below LinksInto becomes problematic, so we go as far "up" as possible! - while (LinksInto->hasAbove() && LinksFrom->hasAbove()) { - LinksInto = &linksAt(LinksInto->getAbove()); - LinksFrom = &linksAt(LinksFrom->getAbove()); - } - - if (LinksFrom->hasAbove()) { - LinksInto->setAbove(LinksFrom->getAbove()); - auto &NewAbove = linksAt(LinksInto->getAbove()); - NewAbove.setBelow(LinksInto->Number); - } - - // Merging strategy: - // > If neither has links below, stop. 
- // > If only `LinksInto` has links below, stop. - // > If only `LinksFrom` has links below, reset `LinksInto.Below` to - // match `LinksFrom.Below` - // > If both have links above, deal with those next. - while (LinksInto->hasBelow() && LinksFrom->hasBelow()) { - auto FromAttrs = LinksFrom->getAttrs(); - LinksInto->setAttrs(FromAttrs); - - // Remap needs to happen after getBelow(), but before - // assignment of LinksFrom - auto *NewLinksFrom = &linksAt(LinksFrom->getBelow()); - LinksFrom->remapTo(LinksInto->Number); - LinksFrom = NewLinksFrom; - LinksInto = &linksAt(LinksInto->getBelow()); - } - - if (LinksFrom->hasBelow()) { - LinksInto->setBelow(LinksFrom->getBelow()); - auto &NewBelow = linksAt(LinksInto->getBelow()); - NewBelow.setAbove(LinksInto->Number); - } - - LinksInto->setAttrs(LinksFrom->getAttrs()); - LinksFrom->remapTo(LinksInto->Number); - } - - /// Checks to see if lowerIndex is at a level lower than upperIndex. If so, it - /// will merge lowerIndex with upperIndex (and all of the sets between) and - /// return true. Otherwise, it will return false. - bool tryMergeUpwards(StratifiedIndex LowerIndex, StratifiedIndex UpperIndex) { - assert(inbounds(LowerIndex) && inbounds(UpperIndex)); - auto *Lower = &linksAt(LowerIndex); - auto *Upper = &linksAt(UpperIndex); - if (Lower == Upper) - return true; - - SmallVector<BuilderLink *, 8> Found; - auto *Current = Lower; - auto Attrs = Current->getAttrs(); - while (Current->hasAbove() && Current != Upper) { - Found.push_back(Current); - Attrs |= Current->getAttrs(); - Current = &linksAt(Current->getAbove()); - } - - if (Current != Upper) - return false; - - Upper->setAttrs(Attrs); - - if (Lower->hasBelow()) { - auto NewBelowIndex = Lower->getBelow(); - Upper->setBelow(NewBelowIndex); - auto &NewBelow = linksAt(NewBelowIndex); - NewBelow.setAbove(UpperIndex); - } else { - Upper->clearBelow(); - } - - for (const auto &Ptr : Found) - Ptr->remapTo(Upper->Number); - - return true; - } - - std::optional<const StratifiedInfo *> get(const T &Val) const { - auto Result = Values.find(Val); - if (Result == Values.end()) - return std::nullopt; - return &Result->second; - } - - std::optional<StratifiedInfo *> get(const T &Val) { - auto Result = Values.find(Val); - if (Result == Values.end()) - return std::nullopt; - return &Result->second; - } - - std::optional<StratifiedIndex> indexOf(const T &Val) { - auto MaybeVal = get(Val); - if (!MaybeVal) - return std::nullopt; - auto *Info = *MaybeVal; - auto &Link = linksAt(Info->Index); - return Link.Number; - } - - StratifiedIndex addLinkBelow(StratifiedIndex Set) { - auto At = addLinks(); - Links[Set].setBelow(At); - Links[At].setAbove(Set); - return At; - } - - StratifiedIndex addLinkAbove(StratifiedIndex Set) { - auto At = addLinks(); - Links[At].setBelow(Set); - Links[Set].setAbove(At); - return At; - } - - StratifiedIndex getNewUnlinkedIndex() { return addLinks(); } - - StratifiedIndex addLinks() { - auto Link = Links.size(); - Links.push_back(BuilderLink(Link)); - return Link; - } - - bool inbounds(StratifiedIndex N) const { return N < Links.size(); } -}; -} -} -#endif // LLVM_ADT_STRATIFIEDSETS_H diff --git a/contrib/llvm-project/llvm/lib/Analysis/SyncDependenceAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/SyncDependenceAnalysis.cpp deleted file mode 100644 index 17d7676024a5..000000000000 --- a/contrib/llvm-project/llvm/lib/Analysis/SyncDependenceAnalysis.cpp +++ /dev/null @@ -1,478 +0,0 @@ -//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===// -// -// Part of 
the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements an algorithm that returns for a divergent branch -// the set of basic blocks whose phi nodes become divergent due to divergent -// control. These are the blocks that are reachable by two disjoint paths from -// the branch or loop exits that have a reaching path that is disjoint from a -// path to the loop latch. -// -// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model -// control-induced divergence in phi nodes. -// -// -// -- Reference -- -// The algorithm is presented in Section 5 of -// -// An abstract interpretation for SPMD divergence -// on reducible control flow graphs. -// Julian Rosemann, Simon Moll and Sebastian Hack -// POPL '21 -// -// -// -- Sync dependence -- -// Sync dependence characterizes the control flow aspect of the -// propagation of branch divergence. For example, -// -// %cond = icmp slt i32 %tid, 10 -// br i1 %cond, label %then, label %else -// then: -// br label %merge -// else: -// br label %merge -// merge: -// %a = phi i32 [ 0, %then ], [ 1, %else ] -// -// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid -// because %tid is not on its use-def chains, %a is sync dependent on %tid -// because the branch "br i1 %cond" depends on %tid and affects which value %a -// is assigned to. -// -// -// -- Reduction to SSA construction -- -// There are two disjoint paths from A to X, if a certain variant of SSA -// construction places a phi node in X under the following set-up scheme. -// -// This variant of SSA construction ignores incoming undef values. -// That is paths from the entry without a definition do not result in -// phi nodes. -// -// entry -// / \ -// A \ -// / \ Y -// B C / -// \ / \ / -// D E -// \ / -// F -// -// Assume that A contains a divergent branch. We are interested -// in the set of all blocks where each block is reachable from A -// via two disjoint paths. This would be the set {D, F} in this -// case. -// To generally reduce this query to SSA construction we introduce -// a virtual variable x and assign to x different values in each -// successor block of A. -// -// entry -// / \ -// A \ -// / \ Y -// x = 0 x = 1 / -// \ / \ / -// D E -// \ / -// F -// -// Our flavor of SSA construction for x will construct the following -// -// entry -// / \ -// A \ -// / \ Y -// x0 = 0 x1 = 1 / -// \ / \ / -// x2 = phi E -// \ / -// x3 = phi -// -// The blocks D and F contain phi nodes and are thus each reachable -// by two disjoins paths from A. -// -// -- Remarks -- -// * In case of loop exits we need to check the disjoint path criterion for loops. -// To this end, we check whether the definition of x differs between the -// loop exit and the loop header (_after_ SSA construction). -// -// -- Known Limitations & Future Work -- -// * The algorithm requires reducible loops because the implementation -// implicitly performs a single iteration of the underlying data flow analysis. -// This was done for pragmatism, simplicity and speed. -// -// Relevant related work for extending the algorithm to irreducible control: -// A simple algorithm for global data flow analysis problems. -// Matthew S. Hecht and Jeffrey D. Ullman. -// SIAM Journal on Computing, 4(4):519–532, December 1975. 
-// -// * Another reason for requiring reducible loops is that points of -// synchronization in irreducible loops aren't 'obvious' - there is no unique -// header where threads 'should' synchronize when entering or coming back -// around from the latch. -// -//===----------------------------------------------------------------------===// - -#include "llvm/Analysis/SyncDependenceAnalysis.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" - -#include <functional> - -#define DEBUG_TYPE "sync-dependence" - -// The SDA algorithm operates on a modified CFG - we modify the edges leaving -// loop headers as follows: -// -// * We remove all edges leaving all loop headers. -// * We add additional edges from the loop headers to their exit blocks. -// -// The modification is virtual, that is whenever we visit a loop header we -// pretend it had different successors. -namespace { -using namespace llvm; - -// Custom Post-Order Traveral -// -// We cannot use the vanilla (R)PO computation of LLVM because: -// * We (virtually) modify the CFG. -// * We want a loop-compact block enumeration, that is the numbers assigned to -// blocks of a loop form an interval -// -using POCB = std::function<void(const BasicBlock &)>; -using VisitedSet = std::set<const BasicBlock *>; -using BlockStack = std::vector<const BasicBlock *>; - -// forward -static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, - VisitedSet &Finalized); - -// for a nested region (top-level loop or nested loop) -static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, - POCB CallBack, VisitedSet &Finalized) { - const auto *LoopHeader = Loop ? Loop->getHeader() : nullptr; - while (!Stack.empty()) { - const auto *NextBB = Stack.back(); - - auto *NestedLoop = LI.getLoopFor(NextBB); - bool IsNestedLoop = NestedLoop != Loop; - - // Treat the loop as a node - if (IsNestedLoop) { - SmallVector<BasicBlock *, 3> NestedExits; - NestedLoop->getUniqueExitBlocks(NestedExits); - bool PushedNodes = false; - for (const auto *NestedExitBB : NestedExits) { - if (NestedExitBB == LoopHeader) - continue; - if (Loop && !Loop->contains(NestedExitBB)) - continue; - if (Finalized.count(NestedExitBB)) - continue; - PushedNodes = true; - Stack.push_back(NestedExitBB); - } - if (!PushedNodes) { - // All loop exits finalized -> finish this node - Stack.pop_back(); - computeLoopPO(LI, *NestedLoop, CallBack, Finalized); - } - continue; - } - - // DAG-style - bool PushedNodes = false; - for (const auto *SuccBB : successors(NextBB)) { - if (SuccBB == LoopHeader) - continue; - if (Loop && !Loop->contains(SuccBB)) - continue; - if (Finalized.count(SuccBB)) - continue; - PushedNodes = true; - Stack.push_back(SuccBB); - } - if (!PushedNodes) { - // Never push nodes twice - Stack.pop_back(); - if (!Finalized.insert(NextBB).second) - continue; - CallBack(*NextBB); - } - } -} - -static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { - VisitedSet Finalized; - BlockStack Stack; - Stack.reserve(24); // FIXME made-up number - Stack.push_back(&F.getEntryBlock()); - computeStackPO(Stack, LI, nullptr, CallBack, Finalized); -} - -static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, - VisitedSet &Finalized) { - /// Call CallBack on all loop blocks. 
- std::vector<const BasicBlock *> Stack; - const auto *LoopHeader = Loop.getHeader(); - - // Visit the header last - Finalized.insert(LoopHeader); - CallBack(*LoopHeader); - - // Initialize with immediate successors - for (const auto *BB : successors(LoopHeader)) { - if (!Loop.contains(BB)) - continue; - if (BB == LoopHeader) - continue; - Stack.push_back(BB); - } - - // Compute PO inside region - computeStackPO(Stack, LI, &Loop, CallBack, Finalized); -} - -} // namespace - -namespace llvm { - -ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; - -SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, - const PostDominatorTree &PDT, - const LoopInfo &LI) - : DT(DT), PDT(PDT), LI(LI) { - computeTopLevelPO(*DT.getRoot()->getParent(), LI, - [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); -} - -SyncDependenceAnalysis::~SyncDependenceAnalysis() = default; - -namespace { -// divergence propagator for reducible CFGs -struct DivergencePropagator { - const ModifiedPO &LoopPOT; - const DominatorTree &DT; - const PostDominatorTree &PDT; - const LoopInfo &LI; - const BasicBlock &DivTermBlock; - - // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at - // block B - // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet - // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths - // from X or B is an immediate successor of X (initial value). - using BlockLabelVec = std::vector<const BasicBlock *>; - BlockLabelVec BlockLabels; - // divergent join and loop exit descriptor. - std::unique_ptr<ControlDivergenceDesc> DivDesc; - - DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, - const PostDominatorTree &PDT, const LoopInfo &LI, - const BasicBlock &DivTermBlock) - : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), - BlockLabels(LoopPOT.size(), nullptr), - DivDesc(new ControlDivergenceDesc) {} - - void printDefs(raw_ostream &Out) { - Out << "Propagator::BlockLabels {\n"; - for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { - const auto *Label = BlockLabels[BlockIdx]; - Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx - << ") : "; - if (!Label) { - Out << "<null>\n"; - } else { - Out << Label->getName() << "\n"; - } - } - Out << "}\n"; - } - - // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this - // causes a divergent join. - bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { - auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); - - // unset or same reaching label - const auto *OldLabel = BlockLabels[SuccIdx]; - if (!OldLabel || (OldLabel == &PushedLabel)) { - BlockLabels[SuccIdx] = &PushedLabel; - return false; - } - - // Update the definition - BlockLabels[SuccIdx] = &SuccBlock; - return true; - } - - // visiting a virtual loop exit edge from the loop header --> temporal - // divergence on join - bool visitLoopExitEdge(const BasicBlock &ExitBlock, - const BasicBlock &DefBlock, bool FromParentLoop) { - // Pushing from a non-parent loop cannot cause temporal divergence. 
- if (!FromParentLoop) - return visitEdge(ExitBlock, DefBlock); - - if (!computeJoin(ExitBlock, DefBlock)) - return false; - - // Identified a divergent loop exit - DivDesc->LoopDivBlocks.insert(&ExitBlock); - LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() - << "\n"); - return true; - } - - // process \p SuccBlock with reaching definition \p DefBlock - bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { - if (!computeJoin(SuccBlock, DefBlock)) - return false; - - // Divergent, disjoint paths join. - DivDesc->JoinDivBlocks.insert(&SuccBlock); - LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); - return true; - } - - std::unique_ptr<ControlDivergenceDesc> computeJoinPoints() { - assert(DivDesc); - - LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() - << "\n"); - - const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); - - // Early stopping criterion - int FloorIdx = LoopPOT.size() - 1; - const BasicBlock *FloorLabel = nullptr; - - // bootstrap with branch targets - int BlockIdx = 0; - - for (const auto *SuccBlock : successors(&DivTermBlock)) { - auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); - BlockLabels[SuccIdx] = SuccBlock; - - // Find the successor with the highest index to start with - BlockIdx = std::max<int>(BlockIdx, SuccIdx); - FloorIdx = std::min<int>(FloorIdx, SuccIdx); - - // Identify immediate divergent loop exits - if (!DivBlockLoop) - continue; - - const auto *BlockLoop = LI.getLoopFor(SuccBlock); - if (BlockLoop && DivBlockLoop->contains(BlockLoop)) - continue; - DivDesc->LoopDivBlocks.insert(SuccBlock); - LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " - << SuccBlock->getName() << "\n"); - } - - // propagate definitions at the immediate successors of the node in RPO - for (; BlockIdx >= FloorIdx; --BlockIdx) { - LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs())); - - // Any label available here - const auto *Label = BlockLabels[BlockIdx]; - if (!Label) - continue; - - // Ok. Get the block - const auto *Block = LoopPOT.getBlockAt(BlockIdx); - LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); - - auto *BlockLoop = LI.getLoopFor(Block); - bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; - bool CausedJoin = false; - int LoweredFloorIdx = FloorIdx; - if (IsLoopHeader) { - // Disconnect from immediate successors and propagate directly to loop - // exits. - SmallVector<BasicBlock *, 4> BlockLoopExits; - BlockLoop->getExitBlocks(BlockLoopExits); - - bool IsParentLoop = BlockLoop->contains(&DivTermBlock); - for (const auto *BlockLoopExit : BlockLoopExits) { - CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); - LoweredFloorIdx = std::min<int>(LoweredFloorIdx, - LoopPOT.getIndexOf(*BlockLoopExit)); - } - } else { - // Acyclic successor case - for (const auto *SuccBlock : successors(Block)) { - CausedJoin |= visitEdge(*SuccBlock, *Label); - LoweredFloorIdx = - std::min<int>(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); - } - } - - // Floor update - if (CausedJoin) { - // 1. Different labels pushed to successors - FloorIdx = LoweredFloorIdx; - } else if (FloorLabel != Label) { - // 2. No join caused BUT we pushed a label that is different than the - // last pushed label - FloorIdx = LoweredFloorIdx; - FloorLabel = Label; - } - } - - LLVM_DEBUG(dbgs() << "SDA::joins. 
After propagation:\n"; printDefs(dbgs())); - - return std::move(DivDesc); - } -}; -} // end anonymous namespace - -#ifndef NDEBUG -static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { - Out << "["; - ListSeparator LS; - for (const auto *BB : Blocks) - Out << LS << BB->getName(); - Out << "]"; -} -#endif - -const ControlDivergenceDesc & -SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { - // trivial case - if (Term.getNumSuccessors() <= 1) { - return EmptyDivergenceDesc; - } - - // already available in cache? - auto ItCached = CachedControlDivDescs.find(&Term); - if (ItCached != CachedControlDivDescs.end()) - return *ItCached->second; - - // compute all join points - // Special handling of divergent loop exits is not needed for LCSSA - const auto &TermBlock = *Term.getParent(); - DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); - auto DivDesc = Propagator.computeJoinPoints(); - - LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; - dbgs() << "JoinDivBlocks: "; - printBlockSet(DivDesc->JoinDivBlocks, dbgs()); - dbgs() << "\nLoopDivBlocks: "; - printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); - - auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); - assert(ItInserted.second); - return *ItInserted.first->second; -} - -} // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Analysis/TargetLibraryInfo.cpp b/contrib/llvm-project/llvm/lib/Analysis/TargetLibraryInfo.cpp index 31cc0e7ec30e..05fa67d0bbf1 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/TargetLibraryInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/TargetLibraryInfo.cpp @@ -11,10 +11,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/ADT/Triple.h" #include "llvm/IR/Constants.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" +#include "llvm/TargetParser/Triple.h" using namespace llvm; static cl::opt<TargetLibraryInfoImpl::VectorLibrary> ClVectorLibrary( @@ -33,7 +33,9 @@ static cl::opt<TargetLibraryInfoImpl::VectorLibrary> ClVectorLibrary( clEnumValN(TargetLibraryInfoImpl::SVML, "SVML", "Intel SVML library"), clEnumValN(TargetLibraryInfoImpl::SLEEFGNUABI, "sleefgnuabi", - "SIMD Library for Evaluating Elementary Functions"))); + "SIMD Library for Evaluating Elementary Functions"), + clEnumValN(TargetLibraryInfoImpl::ArmPL, "ArmPL", + "Arm Performance Libraries"))); StringLiteral const TargetLibraryInfoImpl::StandardNames[LibFunc::NumLibFuncs] = { @@ -474,6 +476,7 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, TLI.setUnavailable(LibFunc_ZnajSt11align_val_tRKSt9nothrow_t); TLI.setUnavailable(LibFunc_Znam); TLI.setUnavailable(LibFunc_ZnamRKSt9nothrow_t); + TLI.setUnavailable(LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t); TLI.setUnavailable(LibFunc_ZnamSt11align_val_t); TLI.setUnavailable(LibFunc_ZnamSt11align_val_tRKSt9nothrow_t); TLI.setUnavailable(LibFunc_Znwj); @@ -482,8 +485,15 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, TLI.setUnavailable(LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t); TLI.setUnavailable(LibFunc_Znwm); TLI.setUnavailable(LibFunc_ZnwmRKSt9nothrow_t); + TLI.setUnavailable(LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t); TLI.setUnavailable(LibFunc_ZnwmSt11align_val_t); TLI.setUnavailable(LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t); + TLI.setUnavailable(LibFunc_Znwm12__hot_cold_t); + 
TLI.setUnavailable(LibFunc_ZnwmSt11align_val_t12__hot_cold_t); + TLI.setUnavailable(LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t); + TLI.setUnavailable(LibFunc_Znam12__hot_cold_t); + TLI.setUnavailable(LibFunc_ZnamSt11align_val_t12__hot_cold_t); + TLI.setUnavailable(LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t); } else { // Not MSVC, assume it's Itanium. TLI.setUnavailable(LibFunc_msvc_new_int); @@ -1181,10 +1191,17 @@ void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib( case SLEEFGNUABI: { const VecDesc VecFuncs_VF2[] = { #define TLI_DEFINE_SLEEFGNUABI_VF2_VECFUNCS +#define TLI_DEFINE_VECFUNC(SCAL, VEC, VF) {SCAL, VEC, VF, /* MASK = */ false}, #include "llvm/Analysis/VecFuncs.def" }; const VecDesc VecFuncs_VF4[] = { #define TLI_DEFINE_SLEEFGNUABI_VF4_VECFUNCS +#define TLI_DEFINE_VECFUNC(SCAL, VEC, VF) {SCAL, VEC, VF, /* MASK = */ false}, +#include "llvm/Analysis/VecFuncs.def" + }; + const VecDesc VecFuncs_VFScalable[] = { +#define TLI_DEFINE_SLEEFGNUABI_SCALABLE_VECFUNCS +#define TLI_DEFINE_VECFUNC(SCAL, VEC, VF, MASK) {SCAL, VEC, VF, MASK}, #include "llvm/Analysis/VecFuncs.def" }; @@ -1195,6 +1212,24 @@ void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib( case llvm::Triple::aarch64_be: addVectorizableFunctions(VecFuncs_VF2); addVectorizableFunctions(VecFuncs_VF4); + addVectorizableFunctions(VecFuncs_VFScalable); + break; + } + break; + } + case ArmPL: { + const VecDesc VecFuncs[] = { +#define TLI_DEFINE_ARMPL_VECFUNCS +#define TLI_DEFINE_VECFUNC(SCAL, VEC, VF, MASK) {SCAL, VEC, VF, MASK}, +#include "llvm/Analysis/VecFuncs.def" + }; + + switch (TargetTriple.getArch()) { + default: + break; + case llvm::Triple::aarch64: + case llvm::Triple::aarch64_be: + addVectorizableFunctions(VecFuncs); break; } break; @@ -1214,16 +1249,16 @@ bool TargetLibraryInfoImpl::isFunctionVectorizable(StringRef funcName) const { return I != VectorDescs.end() && StringRef(I->ScalarFnName) == funcName; } -StringRef -TargetLibraryInfoImpl::getVectorizedFunction(StringRef F, - const ElementCount &VF) const { +StringRef TargetLibraryInfoImpl::getVectorizedFunction(StringRef F, + const ElementCount &VF, + bool Masked) const { F = sanitizeFunctionName(F); if (F.empty()) return F; std::vector<VecDesc>::const_iterator I = llvm::lower_bound(VectorDescs, F, compareWithScalarFnName); while (I != VectorDescs.end() && StringRef(I->ScalarFnName) == F) { - if (I->VectorizationFactor == VF) + if ((I->VectorizationFactor == VF) && (I->Masked == Masked)) return I->VectorFnName; ++I; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/TargetTransformInfo.cpp b/contrib/llvm-project/llvm/lib/Analysis/TargetTransformInfo.cpp index ad7e5432d4c5..c751d174a48a 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -37,6 +37,11 @@ static cl::opt<unsigned> CacheLineSize( cl::desc("Use this to override the target cache line size when " "specified by the user.")); +static cl::opt<unsigned> PredictableBranchThreshold( + "predictable-branch-threshold", cl::init(99), cl::Hidden, + cl::desc( + "Use this to override the target's predictable branch threshold (%).")); + namespace { /// No-op implementation of the TTI interface using the utility base /// classes. 
@@ -103,6 +108,14 @@ IntrinsicCostAttributes::IntrinsicCostAttributes(Intrinsic::ID Id, Type *RTy, Arguments.insert(Arguments.begin(), Args.begin(), Args.end()); } +HardwareLoopInfo::HardwareLoopInfo(Loop *L) : L(L) { + // Match default options: + // - hardware-loop-counter-bitwidth = 32 + // - hardware-loop-decrement = 1 + CountType = Type::getInt32Ty(L->getHeader()->getContext()); + LoopDecrement = ConstantInt::get(CountType, 1); +} + bool HardwareLoopInfo::isHardwareLoopCandidate(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, bool ForceNestedLoop, @@ -204,15 +217,28 @@ TargetTransformInfo::adjustInliningThreshold(const CallBase *CB) const { return TTIImpl->adjustInliningThreshold(CB); } +unsigned TargetTransformInfo::getCallerAllocaCost(const CallBase *CB, + const AllocaInst *AI) const { + return TTIImpl->getCallerAllocaCost(CB, AI); +} + int TargetTransformInfo::getInlinerVectorBonusPercent() const { return TTIImpl->getInlinerVectorBonusPercent(); } -InstructionCost -TargetTransformInfo::getGEPCost(Type *PointeeType, const Value *Ptr, - ArrayRef<const Value *> Operands, - TTI::TargetCostKind CostKind) const { - return TTIImpl->getGEPCost(PointeeType, Ptr, Operands, CostKind); +InstructionCost TargetTransformInfo::getGEPCost( + Type *PointeeType, const Value *Ptr, ArrayRef<const Value *> Operands, + Type *AccessType, TTI::TargetCostKind CostKind) const { + return TTIImpl->getGEPCost(PointeeType, Ptr, Operands, AccessType, CostKind); +} + +InstructionCost TargetTransformInfo::getPointersChainCost( + ArrayRef<const Value *> Ptrs, const Value *Base, + const TTI::PointersChainInfo &Info, Type *AccessTy, + TTI::TargetCostKind CostKind) const { + assert((Base || !Info.isSameBase()) && + "If pointers have same base address it has to be provided."); + return TTIImpl->getPointersChainCost(Ptrs, Base, Info, AccessTy, CostKind); } unsigned TargetTransformInfo::getEstimatedNumberOfCaseClusters( @@ -232,15 +258,13 @@ TargetTransformInfo::getInstructionCost(const User *U, } BranchProbability TargetTransformInfo::getPredictableBranchThreshold() const { - return TTIImpl->getPredictableBranchThreshold(); + return PredictableBranchThreshold.getNumOccurrences() > 0 + ? 
BranchProbability(PredictableBranchThreshold, 100) + : TTIImpl->getPredictableBranchThreshold(); } -bool TargetTransformInfo::hasBranchDivergence() const { - return TTIImpl->hasBranchDivergence(); -} - -bool TargetTransformInfo::useGPUDivergenceAnalysis() const { - return TTIImpl->useGPUDivergenceAnalysis(); +bool TargetTransformInfo::hasBranchDivergence(const Function *F) const { + return TTIImpl->hasBranchDivergence(F); } bool TargetTransformInfo::isSourceOfDivergence(const Value *V) const { @@ -251,6 +275,16 @@ bool llvm::TargetTransformInfo::isAlwaysUniform(const Value *V) const { return TTIImpl->isAlwaysUniform(V); } +bool llvm::TargetTransformInfo::isValidAddrSpaceCast(unsigned FromAS, + unsigned ToAS) const { + return TTIImpl->isValidAddrSpaceCast(FromAS, ToAS); +} + +bool llvm::TargetTransformInfo::addrspacesMayAlias(unsigned FromAS, + unsigned ToAS) const { + return TTIImpl->addrspacesMayAlias(FromAS, ToAS); +} + unsigned TargetTransformInfo::getFlatAddressSpace() const { return TTIImpl->getFlatAddressSpace(); } @@ -299,14 +333,13 @@ bool TargetTransformInfo::isHardwareLoopProfitable( } bool TargetTransformInfo::preferPredicateOverEpilogue( - Loop *L, LoopInfo *LI, ScalarEvolution &SE, AssumptionCache &AC, - TargetLibraryInfo *TLI, DominatorTree *DT, LoopVectorizationLegality *LVL, - InterleavedAccessInfo *IAI) const { - return TTIImpl->preferPredicateOverEpilogue(L, LI, SE, AC, TLI, DT, LVL, IAI); + TailFoldingInfo *TFI) const { + return TTIImpl->preferPredicateOverEpilogue(TFI); } -PredicationStyle TargetTransformInfo::emitGetActiveLaneMask() const { - return TTIImpl->emitGetActiveLaneMask(); +TailFoldingStyle TargetTransformInfo::getPreferredTailFoldingStyle( + bool IVUpdateMayOverflow) const { + return TTIImpl->getPreferredTailFoldingStyle(IVUpdateMayOverflow); } std::optional<Instruction *> @@ -664,6 +697,10 @@ std::optional<unsigned> TargetTransformInfo::getVScaleForTuning() const { return TTIImpl->getVScaleForTuning(); } +bool TargetTransformInfo::isVScaleKnownToBeAPowerOfTwo() const { + return TTIImpl->isVScaleKnownToBeAPowerOfTwo(); +} + bool TargetTransformInfo::shouldMaximizeVectorBandwidth( TargetTransformInfo::RegisterKind K) const { return TTIImpl->shouldMaximizeVectorBandwidth(K); @@ -728,7 +765,7 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const { return TTIImpl->shouldPrefetchAddressSpace(AS); } -unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const { +unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const { return TTIImpl->getMaxInterleaveFactor(VF); } @@ -1007,6 +1044,10 @@ InstructionCost TargetTransformInfo::getMemcpyCost(const Instruction *I) const { return Cost; } +uint64_t TargetTransformInfo::getMaxMemIntrinsicInlineSizeThreshold() const { + return TTIImpl->getMaxMemIntrinsicInlineSizeThreshold(); +} + InstructionCost TargetTransformInfo::getArithmeticReductionCost( unsigned Opcode, VectorType *Ty, std::optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) const { @@ -1017,17 +1058,17 @@ InstructionCost TargetTransformInfo::getArithmeticReductionCost( } InstructionCost TargetTransformInfo::getMinMaxReductionCost( - VectorType *Ty, VectorType *CondTy, bool IsUnsigned, + Intrinsic::ID IID, VectorType *Ty, FastMathFlags FMF, TTI::TargetCostKind CostKind) const { InstructionCost Cost = - TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind); + TTIImpl->getMinMaxReductionCost(IID, Ty, FMF, CostKind); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } 
InstructionCost TargetTransformInfo::getExtendedReductionCost( unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty, - std::optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) const { + FastMathFlags FMF, TTI::TargetCostKind CostKind) const { return TTIImpl->getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, CostKind); } @@ -1163,6 +1204,14 @@ TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const { return TTIImpl->getVPLegalizationStrategy(VPI); } +bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const { + return TTIImpl->hasArmWideBranch(Thumb); +} + +unsigned TargetTransformInfo::getMaxNumArgs() const { + return TTIImpl->getMaxNumArgs(); +} + bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const { return TTIImpl->shouldExpandReduction(II); } diff --git a/contrib/llvm-project/llvm/lib/Analysis/TensorSpec.cpp b/contrib/llvm-project/llvm/lib/Analysis/TensorSpec.cpp index 4f7428ded85e..8dd1a054af88 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/TensorSpec.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/TensorSpec.cpp @@ -10,8 +10,10 @@ // utils. // //===----------------------------------------------------------------------===// +#include "llvm/ADT/STLExtras.h" #include "llvm/Config/config.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/TensorSpec.h" #include "llvm/Support/CommandLine.h" @@ -102,4 +104,23 @@ std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, return std::nullopt; } +std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) { + switch (Spec.type()) { +#define _IMR_DBG_PRINTER(T, N) \ + case TensorType::N: { \ + const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \ + auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \ + return llvm::join( \ + llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \ + } + SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER) +#undef _IMR_DBG_PRINTER + case TensorType::Total: + case TensorType::Invalid: + llvm_unreachable("invalid tensor type"); + } + // To appease warnings about not all control paths returning a value. 
+ return ""; +} + } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp b/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp index dcee8d40c53d..e236890aa2bc 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp @@ -32,7 +32,7 @@ static cl::opt<bool> UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden, cl::desc("Output simple (non-protobuf) log.")); -void Logger::writeHeader() { +void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) { json::OStream JOS(*OS); JOS.object([&]() { JOS.attributeArray("features", [&]() { @@ -44,6 +44,11 @@ void Logger::writeHeader() { RewardSpec.toJSON(JOS); JOS.attributeEnd(); } + if (AdviceSpec.has_value()) { + JOS.attributeBegin("advice"); + AdviceSpec->toJSON(JOS); + JOS.attributeEnd(); + } }); *OS << "\n"; } @@ -81,8 +86,9 @@ void Logger::logRewardImpl(const char *RawData) { Logger::Logger(std::unique_ptr<raw_ostream> OS, const std::vector<TensorSpec> &FeatureSpecs, - const TensorSpec &RewardSpec, bool IncludeReward) + const TensorSpec &RewardSpec, bool IncludeReward, + std::optional<TensorSpec> AdviceSpec) : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), IncludeReward(IncludeReward) { - writeHeader(); + writeHeader(AdviceSpec); } diff --git a/contrib/llvm-project/llvm/lib/Analysis/TypeMetadataUtils.cpp b/contrib/llvm-project/llvm/lib/Analysis/TypeMetadataUtils.cpp index 1c9354fbe01f..bbaee06ed8a5 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/TypeMetadataUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/TypeMetadataUtils.cpp @@ -99,7 +99,9 @@ void llvm::findDevirtualizableCallsForTypeCheckedLoad( SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, const CallInst *CI, DominatorTree &DT) { assert(CI->getCalledFunction()->getIntrinsicID() == - Intrinsic::type_checked_load); + Intrinsic::type_checked_load || + CI->getCalledFunction()->getIntrinsicID() == + Intrinsic::type_checked_load_relative); auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1)); if (!Offset) { @@ -161,7 +163,7 @@ Constant *llvm::getPointerAtOffset(Constant *I, uint64_t Offset, Module &M, // (Swift-specific) relative-pointer support starts here. if (auto *CI = dyn_cast<ConstantInt>(I)) { - if (Offset == 0 && CI->getZExtValue() == 0) { + if (Offset == 0 && CI->isZero()) { return I; } } diff --git a/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp b/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp index 8ed5af8a8d1c..bf0b194dcd70 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp @@ -1,4 +1,4 @@ -//===- ConvergenceUtils.cpp -----------------------------------------------===// +//===- UniformityAnalysis.cpp ---------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -26,18 +26,16 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs( template <> bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent( - const Instruction &Instr, bool AllDefsDivergent) { - return markDivergent(&Instr); + const Instruction &Instr) { + return markDivergent(cast<Value>(&Instr)); } template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() { for (auto &I : instructions(F)) { - if (TTI->isSourceOfDivergence(&I)) { - assert(!I.isTerminator()); + if (TTI->isSourceOfDivergence(&I)) markDivergent(I); - } else if (TTI->isAlwaysUniform(&I)) { + else if (TTI->isAlwaysUniform(&I)) addUniformOverride(I); - } } for (auto &Arg : F.args()) { if (TTI->isSourceOfDivergence(&Arg)) { @@ -50,13 +48,8 @@ template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( const Value *V) { for (const auto *User : V->users()) { - const auto *UserInstr = dyn_cast<const Instruction>(User); - if (!UserInstr) - continue; - if (isAlwaysUniform(*UserInstr)) - continue; - if (markDivergent(*UserInstr)) { - Worklist.push_back(UserInstr); + if (const auto *UserInstr = dyn_cast<const Instruction>(User)) { + markDivergent(*UserInstr); } } } @@ -73,8 +66,7 @@ void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( template <> bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle( const Instruction &I, const Cycle &DefCycle) const { - if (isAlwaysUniform(I)) - return false; + assert(!isAlwaysUniform(I)); for (const Use &U : I.operands()) { if (auto *I = dyn_cast<Instruction>(&U)) { if (DefCycle.contains(I->getParent())) @@ -84,6 +76,33 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle( return false; } +template <> +void llvm::GenericUniformityAnalysisImpl< + SSAContext>::propagateTemporalDivergence(const Instruction &I, + const Cycle &DefCycle) { + if (isDivergent(I)) + return; + for (auto *User : I.users()) { + auto *UserInstr = cast<Instruction>(User); + if (DefCycle.contains(UserInstr->getParent())) + continue; + markDivergent(*UserInstr); + } +} + +template <> +bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse( + const Use &U) const { + const auto *V = U.get(); + if (isDivergent(V)) + return true; + if (const auto *DefInstr = dyn_cast<Instruction>(V)) { + const auto *UseInstr = cast<Instruction>(U.getUser()); + return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); + } + return false; +} + // This ensures explicit instantiation of // GenericUniformityAnalysisImpl::ImplDeleter::operator() template class llvm::GenericUniformityInfo<SSAContext>; @@ -99,7 +118,12 @@ llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); auto &TTI = FAM.getResult<TargetIRAnalysis>(F); auto &CI = FAM.getResult<CycleAnalysis>(F); - return UniformityInfo{F, DT, CI, &TTI}; + UniformityInfo UI{F, DT, CI, &TTI}; + // Skip computation if we can assume everything is uniform. 
+ if (TTI.hasBranchDivergence(&F)) + UI.compute(); + + return UI; } AnalysisKey UniformityInfoAnalysis::Key; @@ -125,17 +149,18 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) { initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry()); } -INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo", - "Uniform Info Analysis", true, true) +INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity", + "Uniformity Analysis", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo", - "Uniform Info Analysis", true, true) +INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", + "Uniformity Analysis", true, true) void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<CycleInfoWrapperPass>(); + AU.addRequiredTransitive<CycleInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); } @@ -148,6 +173,11 @@ bool UniformityInfoWrapperPass::runOnFunction(Function &F) { m_function = &F; m_uniformityInfo = UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo}; + + // Skip computation if we can assume everything is uniform. + if (targetTransformInfo.hasBranchDivergence(m_function)) + m_uniformityInfo.compute(); + return false; } diff --git a/contrib/llvm-project/llvm/lib/Analysis/ValueTracking.cpp b/contrib/llvm-project/llvm/lib/Analysis/ValueTracking.cpp index a13bdade320f..5d526858e00e 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/ValueTracking.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/ValueTracking.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" @@ -25,7 +26,6 @@ #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" @@ -42,6 +42,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" @@ -53,6 +54,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsAArch64.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsRISCV.h" #include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/LLVMContext.h" @@ -93,33 +95,6 @@ static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { return DL.getPointerTypeSizeInBits(Ty); } -namespace { - -// Simplifying using an assume can only be done in a particular control-flow -// context (the context instruction provides that context). If an assume and -// the context instruction are not in the same block then the DT helps in -// figuring out if we can use it. -struct Query { - const DataLayout &DL; - AssumptionCache *AC; - const Instruction *CxtI; - const DominatorTree *DT; - - // Unlike the other analyses, this may be a nullptr because not all clients - // provide it currently. 
- OptimizationRemarkEmitter *ORE; - - /// If true, it is safe to use metadata during simplification. - InstrInfoQuery IIQ; - - Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT, bool UseInstrInfo, - OptimizationRemarkEmitter *ORE = nullptr) - : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {} -}; - -} // end anonymous namespace - // Given the provided Value and, potentially, a context instruction, return // the preferred context instruction (if any). static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { @@ -170,10 +145,11 @@ static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf, } static void computeKnownBits(const Value *V, const APInt &DemandedElts, - KnownBits &Known, unsigned Depth, const Query &Q); + KnownBits &Known, unsigned Depth, + const SimplifyQuery &Q); static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { // Since the number of lanes in a scalable vector is unknown at compile time, // we track one bit which is implicitly broadcast to all lanes. This means // that all lanes in a scalable vector are considered demanded. @@ -186,46 +162,44 @@ static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, void llvm::computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT, - OptimizationRemarkEmitter *ORE, bool UseInstrInfo) { + const DominatorTree *DT, bool UseInstrInfo) { ::computeKnownBits(V, Known, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V, CxtI), UseInstrInfo)); } void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts, KnownBits &Known, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, - OptimizationRemarkEmitter *ORE, bool UseInstrInfo) { + bool UseInstrInfo) { ::computeKnownBits(V, DemandedElts, Known, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V, CxtI), UseInstrInfo)); } static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts, - unsigned Depth, const Query &Q); + unsigned Depth, const SimplifyQuery &Q); static KnownBits computeKnownBits(const Value *V, unsigned Depth, - const Query &Q); + const SimplifyQuery &Q); KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT, - OptimizationRemarkEmitter *ORE, - bool UseInstrInfo) { - return ::computeKnownBits( - V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); + const DominatorTree *DT, bool UseInstrInfo) { + return ::computeKnownBits(V, Depth, + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V, CxtI), UseInstrInfo)); } KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT, - OptimizationRemarkEmitter *ORE, - bool UseInstrInfo) { - return ::computeKnownBits( - V, DemandedElts, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); + const DominatorTree *DT, bool UseInstrInfo) { + return ::computeKnownBits(V, DemandedElts, Depth, + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V, CxtI), UseInstrInfo)); } bool 
llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, @@ -282,11 +256,18 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType()); KnownBits LHSKnown(IT->getBitWidth()); KnownBits RHSKnown(IT->getBitWidth()); - computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo); - computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo); + computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT, UseInstrInfo); + computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT, UseInstrInfo); return KnownBits::haveNoCommonBitsSet(LHSKnown, RHSKnown); } +bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) { + return !I->user_empty() && all_of(I->users(), [](const User *U) { + ICmpInst::Predicate P; + return match(U, m_ICmp(P, m_Value(), m_Zero())); + }); +} + bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) { return !I->user_empty() && all_of(I->users(), [](const User *U) { ICmpInst::Predicate P; @@ -295,34 +276,37 @@ bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) { } static bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, - const Query &Q); + const SimplifyQuery &Q); bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL, bool OrZero, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { - return ::isKnownToBeAPowerOfTwo( - V, OrZero, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); + return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth, + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V, CxtI), + UseInstrInfo)); } static bool isKnownNonZero(const Value *V, const APInt &DemandedElts, - unsigned Depth, const Query &Q); + unsigned Depth, const SimplifyQuery &Q); -static bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q); +static bool isKnownNonZero(const Value *V, unsigned Depth, + const SimplifyQuery &Q); bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { return ::isKnownNonZero(V, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V, CxtI), UseInstrInfo)); } bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { - KnownBits Known = - computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo); + KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT, UseInstrInfo); return Known.isNonNegative(); } @@ -341,39 +325,39 @@ bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth, bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { - KnownBits Known = - computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo); + KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT, UseInstrInfo); return Known.isNegative(); } static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth, - const Query &Q); + const SimplifyQuery &Q); bool llvm::isKnownNonEqual(const Value *V1, const Value *V2, const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { return ::isKnownNonEqual(V1, V2, 0, - Query(DL, 
AC, safeCxtI(V2, V1, CxtI), DT, - UseInstrInfo, /*ORE=*/nullptr)); + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V2, V1, CxtI), UseInstrInfo)); } static bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, - const Query &Q); + const SimplifyQuery &Q); bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { - return ::MaskedValueIsZero( - V, Mask, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); + return ::MaskedValueIsZero(V, Mask, Depth, + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V, CxtI), UseInstrInfo)); } static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts, - unsigned Depth, const Query &Q); + unsigned Depth, const SimplifyQuery &Q); static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { auto *FVTy = dyn_cast<FixedVectorType>(V->getType()); APInt DemandedElts = FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1); @@ -384,8 +368,9 @@ unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { - return ::ComputeNumSignBits( - V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); + return ::ComputeNumSignBits(V, Depth, + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(V, CxtI), UseInstrInfo)); } unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL, @@ -399,7 +384,7 @@ unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL, static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, bool NSW, const APInt &DemandedElts, KnownBits &KnownOut, KnownBits &Known2, - unsigned Depth, const Query &Q) { + unsigned Depth, const SimplifyQuery &Q) { computeKnownBits(Op1, DemandedElts, KnownOut, Depth + 1, Q); // If one operand is unknown and we have no nowrap information, @@ -414,7 +399,7 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, const APInt &DemandedElts, KnownBits &Known, KnownBits &Known2, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q); computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q); @@ -479,7 +464,7 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, // The first CommonPrefixBits of all values in Range are equal. unsigned CommonPrefixBits = - (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros(); + (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero(); APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits); APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(BitWidth); Known.One &= UnsignedMax & Mask; @@ -579,6 +564,11 @@ bool llvm::isValidAssumeForContext(const Instruction *Inv, return false; } +// TODO: cmpExcludesZero misses many cases where `RHS` is non-constant but +// we still have enough information about `RHS` to conclude non-zero. For +// example Pred=EQ, RHS=isKnownNonZero. cmpExcludesZero is called in loops +// so the extra compile time may not be worth it, but possibly a second API +// should be created for use outside of loops. static bool cmpExcludesZero(CmpInst::Predicate Pred, const Value *RHS) { // v u> y implies v != 0. 
if (Pred == ICmpInst::ICMP_UGT) @@ -597,7 +587,7 @@ static bool cmpExcludesZero(CmpInst::Predicate Pred, const Value *RHS) { return !TrueValues.contains(APInt::getZero(C->getBitWidth())); } -static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) { +static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) { // Use of assumptions is context-sensitive. If we don't have a context, we // cannot use them! if (!Q.AC || !Q.CxtI) @@ -616,7 +606,7 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) { for (auto &AssumeVH : Q.AC->assumptionsFor(V)) { if (!AssumeVH) continue; - CondGuardInst *I = cast<CondGuardInst>(AssumeVH); + CallInst *I = cast<CallInst>(AssumeVH); assert(I->getFunction() == Q.CxtI->getFunction() && "Got assumption for the wrong function!"); @@ -624,6 +614,9 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) { // We're running this loop for once for each value queried resulting in a // runtime of ~O(#assumes * #values). + assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume && + "must be an assume intrinsic"); + Value *RHS; CmpInst::Predicate Pred; auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V))); @@ -637,8 +630,167 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) { return false; } -static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, - unsigned Depth, const Query &Q) { +static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp, + KnownBits &Known, unsigned Depth, + const SimplifyQuery &Q) { + unsigned BitWidth = Known.getBitWidth(); + // We are attempting to compute known bits for the operands of an assume. + // Do not try to use other assumptions for those recursive calls because + // that can lead to mutual recursion and a compile-time explosion. + // An example of the mutual recursion: computeKnownBits can call + // isKnownNonZero which calls computeKnownBitsFromAssume (this function) + // and so on. + SimplifyQuery QueryNoAC = Q; + QueryNoAC.AC = nullptr; + + // Note that ptrtoint may change the bitwidth. + Value *A, *B; + auto m_V = + m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V))); + + CmpInst::Predicate Pred; + uint64_t C; + switch (Cmp->getPredicate()) { + case ICmpInst::ICMP_EQ: + // assume(v = a) + if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + Known = Known.unionWith(RHSKnown); + // assume(v & b = a) + } else if (match(Cmp, + m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC); + + // For those bits in the mask that are known to be one, we can propagate + // known bits from the RHS to V. + Known.Zero |= RHSKnown.Zero & MaskKnown.One; + Known.One |= RHSKnown.One & MaskKnown.One; + // assume(~(v & b) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), + m_Value(A)))) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC); + + // For those bits in the mask that are known to be one, we can propagate + // inverted known bits from the RHS to V. 
+ Known.Zero |= RHSKnown.One & MaskKnown.One; + Known.One |= RHSKnown.Zero & MaskKnown.One; + // assume(v | b = a) + } else if (match(Cmp, + m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC); + + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. + Known.Zero |= RHSKnown.Zero & BKnown.Zero; + Known.One |= RHSKnown.One & BKnown.Zero; + // assume(~(v | b) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), + m_Value(A)))) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC); + + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. + Known.Zero |= RHSKnown.One & BKnown.Zero; + Known.One |= RHSKnown.Zero & BKnown.Zero; + // assume(v ^ b = a) + } else if (match(Cmp, + m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC); + + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. For those bits in B that are known to be one, + // we can propagate inverted known bits from the RHS to V. + Known.Zero |= RHSKnown.Zero & BKnown.Zero; + Known.One |= RHSKnown.One & BKnown.Zero; + Known.Zero |= RHSKnown.One & BKnown.One; + Known.One |= RHSKnown.Zero & BKnown.One; + // assume(~(v ^ b) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), + m_Value(A)))) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC); + + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. For those bits in B that are + // known to be one, we can propagate known bits from the RHS to V. + Known.Zero |= RHSKnown.One & BKnown.Zero; + Known.One |= RHSKnown.Zero & BKnown.Zero; + Known.Zero |= RHSKnown.Zero & BKnown.One; + Known.One |= RHSKnown.One & BKnown.One; + // assume(v << c = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), + m_Value(A))) && + C < BitWidth) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. + RHSKnown.Zero.lshrInPlace(C); + RHSKnown.One.lshrInPlace(C); + Known = Known.unionWith(RHSKnown); + // assume(~(v << c) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), + m_Value(A))) && + C < BitWidth) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + RHSKnown.One.lshrInPlace(C); + Known.Zero |= RHSKnown.One; + RHSKnown.Zero.lshrInPlace(C); + Known.One |= RHSKnown.Zero; + // assume(v >> c = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)), + m_Value(A))) && + C < BitWidth) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. 
+ Known.Zero |= RHSKnown.Zero << C; + Known.One |= RHSKnown.One << C; + // assume(~(v >> c) = a) + } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))), + m_Value(A))) && + C < BitWidth) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + Known.Zero |= RHSKnown.One << C; + Known.One |= RHSKnown.Zero << C; + } + break; + case ICmpInst::ICMP_NE: { + // assume (v & b != 0) where b is a power of 2 + const APInt *BPow2; + if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) { + Known.One |= *BPow2; + } + break; + } + default: + const APInt *Offset = nullptr; + if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))), + m_Value(A)))) { + KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC); + ConstantRange RHSRange = + ConstantRange::fromKnownBits(RHSKnown, Cmp->isSigned()); + ConstantRange LHSRange = + ConstantRange::makeAllowedICmpRegion(Pred, RHSRange); + if (Offset) + LHSRange = LHSRange.sub(*Offset); + Known = Known.unionWith(LHSRange.toKnownBits()); + } + break; + } +} + +void llvm::computeKnownBitsFromAssume(const Value *V, KnownBits &Known, + unsigned Depth, const SimplifyQuery &Q) { // Use of assumptions is context-sensitive. If we don't have a context, we // cannot use them! if (!Q.AC || !Q.CxtI) @@ -649,7 +801,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, // Refine Known set if the pointer alignment is set by assume bundles. if (V->getType()->isPointerTy()) { if (RetainedKnowledge RK = getKnowledgeValidInContext( - V, {Attribute::Alignment}, Q.CxtI, Q.DT, Q.AC)) { + V, { Attribute::Alignment }, Q.CxtI, Q.DT, Q.AC)) { if (isPowerOf2_64(RK.ArgValue)) Known.Zero.setLowBits(Log2_64(RK.ArgValue)); } @@ -661,7 +813,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, for (auto &AssumeVH : Q.AC->assumptionsFor(V)) { if (!AssumeVH) continue; - CondGuardInst *I = cast<CondGuardInst>(AssumeVH); + CallInst *I = cast<CallInst>(AssumeVH); assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() && "Got assumption for the wrong function!"); @@ -669,16 +821,21 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, // We're running this loop for once for each value queried resulting in a // runtime of ~O(#assumes * #values). + assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume && + "must be an assume intrinsic"); + Value *Arg = I->getArgOperand(0); if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { assert(BitWidth == 1 && "assume operand is not i1?"); + (void)BitWidth; Known.setAllOnes(); return; } if (match(Arg, m_Not(m_Specific(V))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { assert(BitWidth == 1 && "assume operand is not i1?"); + (void)BitWidth; Known.setAllZero(); return; } @@ -691,278 +848,16 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, if (!Cmp) continue; - // We are attempting to compute known bits for the operands of an assume. - // Do not try to use other assumptions for those recursive calls because - // that can lead to mutual recursion and a compile-time explosion. - // An example of the mutual recursion: computeKnownBits can call - // isKnownNonZero which calls computeKnownBitsFromAssume (this function) - // and so on. - Query QueryNoAC = Q; - QueryNoAC.AC = nullptr; - - // Note that ptrtoint may change the bitwidth. 
- Value *A, *B; - auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V))); - - CmpInst::Predicate Pred; - uint64_t C; - switch (Cmp->getPredicate()) { - default: - break; - case ICmpInst::ICMP_EQ: - // assume(v = a) - if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - Known.Zero |= RHSKnown.Zero; - Known.One |= RHSKnown.One; - // assume(v & b = a) - } else if (match(Cmp, - m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - KnownBits MaskKnown = - computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // For those bits in the mask that are known to be one, we can propagate - // known bits from the RHS to V. - Known.Zero |= RHSKnown.Zero & MaskKnown.One; - Known.One |= RHSKnown.One & MaskKnown.One; - // assume(~(v & b) = a) - } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), - m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - KnownBits MaskKnown = - computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // For those bits in the mask that are known to be one, we can propagate - // inverted known bits from the RHS to V. - Known.Zero |= RHSKnown.One & MaskKnown.One; - Known.One |= RHSKnown.Zero & MaskKnown.One; - // assume(v | b = a) - } else if (match(Cmp, - m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - KnownBits BKnown = - computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // For those bits in B that are known to be zero, we can propagate known - // bits from the RHS to V. - Known.Zero |= RHSKnown.Zero & BKnown.Zero; - Known.One |= RHSKnown.One & BKnown.Zero; - // assume(~(v | b) = a) - } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), - m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - KnownBits BKnown = - computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // For those bits in B that are known to be zero, we can propagate - // inverted known bits from the RHS to V. - Known.Zero |= RHSKnown.One & BKnown.Zero; - Known.One |= RHSKnown.Zero & BKnown.Zero; - // assume(v ^ b = a) - } else if (match(Cmp, - m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - KnownBits BKnown = - computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // For those bits in B that are known to be zero, we can propagate known - // bits from the RHS to V. For those bits in B that are known to be one, - // we can propagate inverted known bits from the RHS to V. 
- Known.Zero |= RHSKnown.Zero & BKnown.Zero; - Known.One |= RHSKnown.One & BKnown.Zero; - Known.Zero |= RHSKnown.One & BKnown.One; - Known.One |= RHSKnown.Zero & BKnown.One; - // assume(~(v ^ b) = a) - } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), - m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - KnownBits BKnown = - computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // For those bits in B that are known to be zero, we can propagate - // inverted known bits from the RHS to V. For those bits in B that are - // known to be one, we can propagate known bits from the RHS to V. - Known.Zero |= RHSKnown.One & BKnown.Zero; - Known.One |= RHSKnown.Zero & BKnown.Zero; - Known.Zero |= RHSKnown.Zero & BKnown.One; - Known.One |= RHSKnown.One & BKnown.One; - // assume(v << c = a) - } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), - m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // For those bits in RHS that are known, we can propagate them to known - // bits in V shifted to the right by C. - RHSKnown.Zero.lshrInPlace(C); - Known.Zero |= RHSKnown.Zero; - RHSKnown.One.lshrInPlace(C); - Known.One |= RHSKnown.One; - // assume(~(v << c) = a) - } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), - m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - // For those bits in RHS that are known, we can propagate them inverted - // to known bits in V shifted to the right by C. - RHSKnown.One.lshrInPlace(C); - Known.Zero |= RHSKnown.One; - RHSKnown.Zero.lshrInPlace(C); - Known.One |= RHSKnown.Zero; - // assume(v >> c = a) - } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)), - m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - // For those bits in RHS that are known, we can propagate them to known - // bits in V shifted to the right by C. - Known.Zero |= RHSKnown.Zero << C; - Known.One |= RHSKnown.One << C; - // assume(~(v >> c) = a) - } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))), - m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - // For those bits in RHS that are known, we can propagate them inverted - // to known bits in V shifted to the right by C. - Known.Zero |= RHSKnown.One << C; - Known.One |= RHSKnown.Zero << C; - } - break; - case ICmpInst::ICMP_SGE: - // assume(v >=_s c) where c is non-negative - if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth); - - if (RHSKnown.isNonNegative()) { - // We know that the sign bit is zero. - Known.makeNonNegative(); - } - } - break; - case ICmpInst::ICMP_SGT: - // assume(v >_s c) where c is at least -1. 
- if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth); - - if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) { - // We know that the sign bit is zero. - Known.makeNonNegative(); - } - } - break; - case ICmpInst::ICMP_SLE: - // assume(v <=_s c) where c is negative - if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth); - - if (RHSKnown.isNegative()) { - // We know that the sign bit is one. - Known.makeNegative(); - } - } - break; - case ICmpInst::ICMP_SLT: - // assume(v <_s c) where c is non-positive - if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - if (RHSKnown.isZero() || RHSKnown.isNegative()) { - // We know that the sign bit is one. - Known.makeNegative(); - } - } - break; - case ICmpInst::ICMP_ULE: - // assume(v <=_u c) - if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // Whatever high bits in c are zero are known to be zero. - Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); - } - break; - case ICmpInst::ICMP_ULT: - // assume(v <_u c) - if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown = - computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth); - - // If the RHS is known zero, then this assumption must be wrong (nothing - // is unsigned less than zero). Signal a conflict and get out of here. - if (RHSKnown.isZero()) { - Known.Zero.setAllBits(); - Known.One.setAllBits(); - break; - } + if (!isValidAssumeForContext(I, Q.CxtI, Q.DT)) + continue; - // Whatever high bits in c are zero are known to be zero (if c is a power - // of 2, then one more). - if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, QueryNoAC)) - Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1); - else - Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); - } - break; - case ICmpInst::ICMP_NE: { - // assume (v & b != 0) where b is a power of 2 - const APInt *BPow2; - if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero())) && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - Known.One |= BPow2->zextOrTrunc(BitWidth); - } - } break; - } + computeKnownBitsFromCmp(V, Cmp, Known, Depth, Q); } - // If assumptions conflict with each other or previous known bits, then we - // have a logical fallacy. It's possible that the assumption is not reachable, - // so this isn't a real bug. On the other hand, the program may have undefined - // behavior, or we might have a bug in the compiler. We can't assert/crash, so - // clear out the known bits, try to warn the user, and hope for the best. - if (Known.Zero.intersects(Known.One)) { + // Conflicting assumption: Undefined behavior will occur on this execution + // path. + if (Known.hasConflict()) Known.resetAll(); - - if (Q.ORE) - Q.ORE->emit([&]() { - auto *CxtI = const_cast<Instruction *>(Q.CxtI); - return OptimizationRemarkAnalysis("value-tracking", "BadAssumption", - CxtI) - << "Detected conflicting code assumptions. 
Program may " - "have undefined behavior, or compiler may have " - "internal error."; - }); - } } /// Compute known bits from a shift operator, including those with a @@ -975,93 +870,128 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, /// combined for all permitted shift amounts. static void computeKnownBitsFromShiftOperator( const Operator *I, const APInt &DemandedElts, KnownBits &Known, - KnownBits &Known2, unsigned Depth, const Query &Q, - function_ref<KnownBits(const KnownBits &, const KnownBits &)> KF) { - unsigned BitWidth = Known.getBitWidth(); + KnownBits &Known2, unsigned Depth, const SimplifyQuery &Q, + function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF) { computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q); computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q); + // To limit compile-time impact, only query isKnownNonZero() if we know at + // least something about the shift amount. + bool ShAmtNonZero = + Known.isNonZero() || + (Known.getMaxValue().ult(Known.getBitWidth()) && + isKnownNonZero(I->getOperand(1), DemandedElts, Depth + 1, Q)); + Known = KF(Known2, Known, ShAmtNonZero); +} + +static KnownBits +getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts, + const KnownBits &KnownLHS, const KnownBits &KnownRHS, + unsigned Depth, const SimplifyQuery &Q) { + unsigned BitWidth = KnownLHS.getBitWidth(); + KnownBits KnownOut(BitWidth); + bool IsAnd = false; + bool HasKnownOne = !KnownLHS.One.isZero() || !KnownRHS.One.isZero(); + Value *X = nullptr, *Y = nullptr; - // Note: We cannot use Known.Zero.getLimitedValue() here, because if - // BitWidth > 64 and any upper bits are known, we'll end up returning the - // limit value (which implies all bits are known). - uint64_t ShiftAmtKZ = Known.Zero.zextOrTrunc(64).getZExtValue(); - uint64_t ShiftAmtKO = Known.One.zextOrTrunc(64).getZExtValue(); - bool ShiftAmtIsConstant = Known.isConstant(); - bool MaxShiftAmtIsOutOfRange = Known.getMaxValue().uge(BitWidth); - - if (ShiftAmtIsConstant) { - Known = KF(Known2, Known); - - // If the known bits conflict, this must be an overflowing left shift, so - // the shift result is poison. We can return anything we want. Choose 0 for - // the best folding opportunity. - if (Known.hasConflict()) - Known.setAllZero(); - - return; + switch (I->getOpcode()) { + case Instruction::And: + KnownOut = KnownLHS & KnownRHS; + IsAnd = true; + // and(x, -x) is common idioms that will clear all but lowest set + // bit. If we have a single known bit in x, we can clear all bits + // above it. + // TODO: instcombine often reassociates independent `and` which can hide + // this pattern. Try to match and(x, and(-x, y)) / and(and(x, y), -x). + if (HasKnownOne && match(I, m_c_And(m_Value(X), m_Neg(m_Deferred(X))))) { + // -(-x) == x so using whichever (LHS/RHS) gets us a better result. + if (KnownLHS.countMaxTrailingZeros() <= KnownRHS.countMaxTrailingZeros()) + KnownOut = KnownLHS.blsi(); + else + KnownOut = KnownRHS.blsi(); + } + break; + case Instruction::Or: + KnownOut = KnownLHS | KnownRHS; + break; + case Instruction::Xor: + KnownOut = KnownLHS ^ KnownRHS; + // xor(x, x-1) is common idioms that will clear all but lowest set + // bit. If we have a single known bit in x, we can clear all bits + // above it. + // TODO: xor(x, x-1) is often rewritting as xor(x, x-C) where C != + // -1 but for the purpose of demanded bits (xor(x, x-C) & + // Demanded) == (xor(x, x-1) & Demanded). 
Extend the xor pattern + // to use arbitrary C if xor(x, x-C) as the same as xor(x, x-1). + if (HasKnownOne && + match(I, m_c_Xor(m_Value(X), m_c_Add(m_Deferred(X), m_AllOnes())))) { + const KnownBits &XBits = I->getOperand(0) == X ? KnownLHS : KnownRHS; + KnownOut = XBits.blsmsk(); + } + break; + default: + llvm_unreachable("Invalid Op used in 'analyzeKnownBitsFromAndXorOr'"); + } + + // and(x, add (x, -1)) is a common idiom that always clears the low bit; + // xor/or(x, add (x, -1)) is an idiom that will always set the low bit. + // here we handle the more general case of adding any odd number by + // matching the form and/xor/or(x, add(x, y)) where y is odd. + // TODO: This could be generalized to clearing any bit set in y where the + // following bit is known to be unset in y. + if (!KnownOut.Zero[0] && !KnownOut.One[0] && + (match(I, m_c_BinOp(m_Value(X), m_c_Add(m_Deferred(X), m_Value(Y)))) || + match(I, m_c_BinOp(m_Value(X), m_Sub(m_Deferred(X), m_Value(Y)))) || + match(I, m_c_BinOp(m_Value(X), m_Sub(m_Value(Y), m_Deferred(X)))))) { + KnownBits KnownY(BitWidth); + computeKnownBits(Y, DemandedElts, KnownY, Depth + 1, Q); + if (KnownY.countMinTrailingOnes() > 0) { + if (IsAnd) + KnownOut.Zero.setBit(0); + else + KnownOut.One.setBit(0); + } } + return KnownOut; +} - // If the shift amount could be greater than or equal to the bit-width of the - // LHS, the value could be poison, but bail out because the check below is - // expensive. - // TODO: Should we just carry on? - if (MaxShiftAmtIsOutOfRange) { - Known.resetAll(); - return; - } +// Public so this can be used in `SimplifyDemandedUseBits`. +KnownBits llvm::analyzeKnownBitsFromAndXorOr( + const Operator *I, const KnownBits &KnownLHS, const KnownBits &KnownRHS, + unsigned Depth, const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { + auto *FVTy = dyn_cast<FixedVectorType>(I->getType()); + APInt DemandedElts = + FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1); - // It would be more-clearly correct to use the two temporaries for this - // calculation. Reusing the APInts here to prevent unnecessary allocations. - Known.resetAll(); + return getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS, KnownRHS, Depth, + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, + safeCxtI(I, CxtI), + UseInstrInfo)); +} - // If we know the shifter operand is nonzero, we can sometimes infer more - // known bits. However this is expensive to compute, so be lazy about it and - // only compute it when absolutely necessary. - std::optional<bool> ShifterOperandIsNonZero; - - // Early exit if we can't constrain any well-defined shift amount. - if (!(ShiftAmtKZ & (PowerOf2Ceil(BitWidth) - 1)) && - !(ShiftAmtKO & (PowerOf2Ceil(BitWidth) - 1))) { - ShifterOperandIsNonZero = - isKnownNonZero(I->getOperand(1), DemandedElts, Depth + 1, Q); - if (!*ShifterOperandIsNonZero) - return; - } +ConstantRange llvm::getVScaleRange(const Function *F, unsigned BitWidth) { + Attribute Attr = F->getFnAttribute(Attribute::VScaleRange); + // Without vscale_range, we only know that vscale is non-zero. + if (!Attr.isValid()) + return ConstantRange(APInt(BitWidth, 1), APInt::getZero(BitWidth)); - Known.Zero.setAllBits(); - Known.One.setAllBits(); - for (unsigned ShiftAmt = 0; ShiftAmt < BitWidth; ++ShiftAmt) { - // Combine the shifted known input bits only for those shift amounts - // compatible with its known constraints. 
- if ((ShiftAmt & ~ShiftAmtKZ) != ShiftAmt) - continue; - if ((ShiftAmt | ShiftAmtKO) != ShiftAmt) - continue; - // If we know the shifter is nonzero, we may be able to infer more known - // bits. This check is sunk down as far as possible to avoid the expensive - // call to isKnownNonZero if the cheaper checks above fail. - if (ShiftAmt == 0) { - if (!ShifterOperandIsNonZero) - ShifterOperandIsNonZero = - isKnownNonZero(I->getOperand(1), DemandedElts, Depth + 1, Q); - if (*ShifterOperandIsNonZero) - continue; - } + unsigned AttrMin = Attr.getVScaleRangeMin(); + // Minimum is larger than vscale width, result is always poison. + if ((unsigned)llvm::bit_width(AttrMin) > BitWidth) + return ConstantRange::getEmpty(BitWidth); - Known = KnownBits::commonBits( - Known, KF(Known2, KnownBits::makeConstant(APInt(32, ShiftAmt)))); - } + APInt Min(BitWidth, AttrMin); + std::optional<unsigned> AttrMax = Attr.getVScaleRangeMax(); + if (!AttrMax || (unsigned)llvm::bit_width(*AttrMax) > BitWidth) + return ConstantRange(Min, APInt::getZero(BitWidth)); - // If the known bits conflict, the result is poison. Return a 0 and hope the - // caller can further optimize that. - if (Known.hasConflict()) - Known.setAllZero(); + return ConstantRange(Min, APInt(BitWidth, *AttrMax) + 1); } static void computeKnownBitsFromOperator(const Operator *I, const APInt &DemandedElts, KnownBits &Known, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { unsigned BitWidth = Known.getBitWidth(); KnownBits Known2(BitWidth); @@ -1072,39 +1002,23 @@ static void computeKnownBitsFromOperator(const Operator *I, Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range)) computeKnownBitsFromRangeMetadata(*MD, Known); break; - case Instruction::And: { - // If either the LHS or the RHS are Zero, the result is zero. + case Instruction::And: computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q); computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q); - Known &= Known2; - - // and(x, add (x, -1)) is a common idiom that always clears the low bit; - // here we handle the more general case of adding any odd number by - // matching the form add(x, add(x, y)) where y is odd. - // TODO: This could be generalized to clearing any bit set in y where the - // following bit is known to be unset in y. 
- Value *X = nullptr, *Y = nullptr; - if (!Known.Zero[0] && !Known.One[0] && - match(I, m_c_BinOp(m_Value(X), m_Add(m_Deferred(X), m_Value(Y))))) { - Known2.resetAll(); - computeKnownBits(Y, DemandedElts, Known2, Depth + 1, Q); - if (Known2.countMinTrailingOnes() > 0) - Known.Zero.setBit(0); - } + Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Depth, Q); break; - } case Instruction::Or: computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q); computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q); - Known |= Known2; + Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Depth, Q); break; case Instruction::Xor: computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q); computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q); - Known ^= Known2; + Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Depth, Q); break; case Instruction::Mul: { bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); @@ -1115,7 +1029,15 @@ static void computeKnownBitsFromOperator(const Operator *I, case Instruction::UDiv: { computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); - Known = KnownBits::udiv(Known, Known2); + Known = + KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I))); + break; + } + case Instruction::SDiv: { + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + Known = + KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I))); break; } case Instruction::Select: { @@ -1147,7 +1069,7 @@ static void computeKnownBitsFromOperator(const Operator *I, computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); // Only known if known in both the LHS and RHS. - Known = KnownBits::commonBits(Known, Known2); + Known = Known.intersectWith(Known2); if (SPF == SPF_ABS) { // RHS from matchSelectPattern returns the negation part of abs pattern. @@ -1254,42 +1176,37 @@ static void computeKnownBitsFromOperator(const Operator *I, break; } case Instruction::Shl: { + bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I)); bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); - auto KF = [NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) { - KnownBits Result = KnownBits::shl(KnownVal, KnownAmt); - // If this shift has "nsw" keyword, then the result is either a poison - // value or has the same sign bit as the first operand. - if (NSW) { - if (KnownVal.Zero.isSignBitSet()) - Result.Zero.setSignBit(); - if (KnownVal.One.isSignBitSet()) - Result.One.setSignBit(); - } - return Result; + auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt, + bool ShAmtNonZero) { + return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW, ShAmtNonZero); }; computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q, KF); // Trailing zeros of a right-shifted constant never decrease. 
const APInt *C; if (match(I->getOperand(0), m_APInt(C))) - Known.Zero.setLowBits(C->countTrailingZeros()); + Known.Zero.setLowBits(C->countr_zero()); break; } case Instruction::LShr: { - auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt) { - return KnownBits::lshr(KnownVal, KnownAmt); + auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt, + bool ShAmtNonZero) { + return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero); }; computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q, KF); // Leading zeros of a left-shifted constant never decrease. const APInt *C; if (match(I->getOperand(0), m_APInt(C))) - Known.Zero.setHighBits(C->countLeadingZeros()); + Known.Zero.setHighBits(C->countl_zero()); break; } case Instruction::AShr: { - auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt) { - return KnownBits::ashr(KnownVal, KnownAmt); + auto KF = [](const KnownBits &KnownVal, const KnownBits &KnownAmt, + bool ShAmtNonZero) { + return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero); }; computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q, KF); @@ -1376,7 +1293,7 @@ static void computeKnownBitsFromOperator(const Operator *I, if (IndexTypeSize.isScalable()) { // For scalable types the only thing we know about sizeof is // that this is a multiple of the minimum size. - ScalingFactor.Zero.setLowBits(countTrailingZeros(TypeSizeInBytes)); + ScalingFactor.Zero.setLowBits(llvm::countr_zero(TypeSizeInBytes)); } else if (IndexBits.isConstant()) { APInt IndexConst = IndexBits.getConstant(); APInt ScalingFactor(IndexBitWidth, TypeSizeInBytes); @@ -1431,7 +1348,7 @@ static void computeKnownBitsFromOperator(const Operator *I, // inferred hold at original context instruction. TODO: It may be // correct to use the original context. IF warranted, explore and // add sufficient tests to cover. - Query RecQ = Q; + SimplifyQuery RecQ = Q; RecQ.CxtI = P; computeKnownBits(R, DemandedElts, Known2, Depth + 1, RecQ); switch (Opcode) { @@ -1464,7 +1381,7 @@ static void computeKnownBitsFromOperator(const Operator *I, // phi. This is important because that is where the value is actually // "evaluated" even though it is used later somewhere else. (see also // D69571). - Query RecQ = Q; + SimplifyQuery RecQ = Q; unsigned OpNum = P->getOperand(0) == R ? 0 : 1; Instruction *RInst = P->getIncomingBlock(OpNum)->getTerminator(); @@ -1526,7 +1443,7 @@ static void computeKnownBitsFromOperator(const Operator *I, // Otherwise take the unions of the known bit sets of the operands, // taking conservative care to avoid excessive recursion. - if (Depth < MaxAnalysisRecursionDepth - 1 && !Known.Zero && !Known.One) { + if (Depth < MaxAnalysisRecursionDepth - 1 && Known.isUnknown()) { // Skip if every incoming value references to ourself. if (isa_and_nonnull<UndefValue>(P->hasConstantValue())) break; @@ -1542,7 +1459,7 @@ static void computeKnownBitsFromOperator(const Operator *I, // phi. This is important because that is where the value is actually // "evaluated" even though it is used later somewhere else. (see also // D69571). 
- Query RecQ = Q; + SimplifyQuery RecQ = Q; RecQ.CxtI = P->getIncomingBlock(u)->getTerminator(); Known2 = KnownBits(BitWidth); @@ -1572,10 +1489,10 @@ static void computeKnownBitsFromOperator(const Operator *I, Known2 = KnownBits::makeConstant(*RHSC); break; case CmpInst::Predicate::ICMP_ULE: - Known2.Zero.setHighBits(RHSC->countLeadingZeros()); + Known2.Zero.setHighBits(RHSC->countl_zero()); break; case CmpInst::Predicate::ICMP_ULT: - Known2.Zero.setHighBits((*RHSC - 1).countLeadingZeros()); + Known2.Zero.setHighBits((*RHSC - 1).countl_zero()); break; default: // TODO - add additional integer predicate handling. @@ -1585,7 +1502,7 @@ static void computeKnownBitsFromOperator(const Operator *I, } } - Known = KnownBits::commonBits(Known, Known2); + Known = Known.intersectWith(Known2); // If all bits have been ruled out, there's no need to check // more operands. if (Known.isUnknown()) @@ -1604,8 +1521,7 @@ static void computeKnownBitsFromOperator(const Operator *I, computeKnownBitsFromRangeMetadata(*MD, Known); if (const Value *RV = cast<CallBase>(I)->getReturnedArgOperand()) { computeKnownBits(RV, Known2, Depth + 1, Q); - Known.Zero |= Known2.Zero; - Known.One |= Known2.One; + Known = Known.unionWith(Known2); } if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { @@ -1681,36 +1597,25 @@ static void computeKnownBitsFromOperator(const Operator *I, break; } case Intrinsic::uadd_sat: - case Intrinsic::usub_sat: { - bool IsAdd = II->getIntrinsicID() == Intrinsic::uadd_sat; computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); - - // Add: Leading ones of either operand are preserved. - // Sub: Leading zeros of LHS and leading ones of RHS are preserved - // as leading zeros in the result. - unsigned LeadingKnown; - if (IsAdd) - LeadingKnown = std::max(Known.countMinLeadingOnes(), - Known2.countMinLeadingOnes()); - else - LeadingKnown = std::max(Known.countMinLeadingZeros(), - Known2.countMinLeadingOnes()); - - Known = KnownBits::computeForAddSub( - IsAdd, /* NSW */ false, Known, Known2); - - // We select between the operation result and all-ones/zero - // respectively, so we can preserve known ones/zeros. 
- if (IsAdd) { - Known.One.setHighBits(LeadingKnown); - Known.Zero.clearAllBits(); - } else { - Known.Zero.setHighBits(LeadingKnown); - Known.One.clearAllBits(); - } + Known = KnownBits::uadd_sat(Known, Known2); + break; + case Intrinsic::usub_sat: + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + Known = KnownBits::usub_sat(Known, Known2); + break; + case Intrinsic::sadd_sat: + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + Known = KnownBits::sadd_sat(Known, Known2); + break; + case Intrinsic::ssub_sat: + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + Known = KnownBits::ssub_sat(Known, Known2); break; - } case Intrinsic::umin: computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); @@ -1731,42 +1636,31 @@ static void computeKnownBitsFromOperator(const Operator *I, computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); Known = KnownBits::smax(Known, Known2); break; + case Intrinsic::ptrmask: { + computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); + + const Value *Mask = I->getOperand(1); + Known2 = KnownBits(Mask->getType()->getScalarSizeInBits()); + computeKnownBits(Mask, Known2, Depth + 1, Q); + // This is basically a pointer typed and. + Known &= Known2.zextOrTrunc(Known.getBitWidth()); + break; + } case Intrinsic::x86_sse42_crc32_64_64: Known.Zero.setBitsFrom(32); break; case Intrinsic::riscv_vsetvli: case Intrinsic::riscv_vsetvlimax: - // Assume that VL output is positive and would fit in an int32_t. - // TODO: VLEN might be capped at 16 bits in a future V spec update. - if (BitWidth >= 32) - Known.Zero.setBitsFrom(31); + // Assume that VL output is >= 65536. + // TODO: Take SEW and LMUL into account. + if (BitWidth > 17) + Known.Zero.setBitsFrom(17); break; case Intrinsic::vscale: { - if (!II->getParent() || !II->getFunction() || - !II->getFunction()->hasFnAttribute(Attribute::VScaleRange)) + if (!II->getParent() || !II->getFunction()) break; - auto Attr = II->getFunction()->getFnAttribute(Attribute::VScaleRange); - std::optional<unsigned> VScaleMax = Attr.getVScaleRangeMax(); - - if (!VScaleMax) - break; - - unsigned VScaleMin = Attr.getVScaleRangeMin(); - - // If vscale min = max then we know the exact value at compile time - // and hence we know the exact bits. 
- if (VScaleMin == VScaleMax) { - Known.One = VScaleMin; - Known.Zero = VScaleMin; - Known.Zero.flipAllBits(); - break; - } - - unsigned FirstZeroHighBit = llvm::bit_width(*VScaleMax); - if (FirstZeroHighBit < BitWidth) - Known.Zero.setBitsFrom(FirstZeroHighBit); - + Known = getVScaleRange(II->getFunction(), BitWidth).toKnownBits(); break; } } @@ -1798,7 +1692,7 @@ static void computeKnownBitsFromOperator(const Operator *I, if (!!DemandedRHS) { const Value *RHS = Shuf->getOperand(1); computeKnownBits(RHS, DemandedRHS, Known2, Depth + 1, Q); - Known = KnownBits::commonBits(Known, Known2); + Known = Known.intersectWith(Known2); } break; } @@ -1831,7 +1725,7 @@ static void computeKnownBitsFromOperator(const Operator *I, DemandedVecElts.clearBit(EltIdx); if (!!DemandedVecElts) { computeKnownBits(Vec, DemandedVecElts, Known2, Depth + 1, Q); - Known = KnownBits::commonBits(Known, Known2); + Known = Known.intersectWith(Known2); } break; } @@ -1892,7 +1786,7 @@ static void computeKnownBitsFromOperator(const Operator *I, /// Determine which bits of V are known to be either zero or one and return /// them. KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts, - unsigned Depth, const Query &Q) { + unsigned Depth, const SimplifyQuery &Q) { KnownBits Known(getBitWidth(V->getType(), Q.DL)); computeKnownBits(V, DemandedElts, Known, Depth, Q); return Known; @@ -1900,7 +1794,8 @@ KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts, /// Determine which bits of V are known to be either zero or one and return /// them. -KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) { +KnownBits computeKnownBits(const Value *V, unsigned Depth, + const SimplifyQuery &Q) { KnownBits Known(getBitWidth(V->getType(), Q.DL)); computeKnownBits(V, Known, Depth, Q); return Known; @@ -1922,7 +1817,8 @@ KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) { /// same width as the vector element, and the bit is set only if it is true /// for all of the demanded elements in the vector specified by DemandedElts. void computeKnownBits(const Value *V, const APInt &DemandedElts, - KnownBits &Known, unsigned Depth, const Query &Q) { + KnownBits &Known, unsigned Depth, + const SimplifyQuery &Q) { if (!DemandedElts) { // No demanded elts, better to assume we don't know anything. Known.resetAll(); @@ -2032,6 +1928,10 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts, if (const Operator *I = dyn_cast<Operator>(V)) computeKnownBitsFromOperator(I, DemandedElts, Known, Depth, Q); + else if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) { + if (std::optional<ConstantRange> CR = GV->getAbsoluteSymbolRange()) + Known = CR->toKnownBits(); + } // Aligned pointers have trailing zeros - refine Known.Zero set if (isa<PointerType>(V->getType())) { @@ -2051,7 +1951,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts, /// Try to detect a recurrence that the value of the induction variable is /// always a power of two (or zero). static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero, - unsigned Depth, Query &Q) { + unsigned Depth, SimplifyQuery &Q) { BinaryOperator *BO = nullptr; Value *Start = nullptr, *Step = nullptr; if (!matchSimpleRecurrence(PN, BO, Start, Step)) @@ -2110,7 +2010,7 @@ static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero, /// be a power of two when defined. Supports values with integer or pointer /// types and vectors of integers. 
bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); // Attempt to match against constants. @@ -2118,6 +2018,11 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, return true; if (match(V, m_Power2())) return true; + if (Q.CxtI && match(V, m_VScale())) { + const Function *F = Q.CxtI->getFunction(); + // The vscale_range indicates vscale is a power-of-two. + return F->hasFnAttribute(Attribute::VScaleRange); + } // 1 << X is clearly a power of two if the one is not shifted off the end. If // it is shifted off the end then the result is undefined. @@ -2199,7 +2104,7 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, // A PHI node is power of two if all incoming values are power of two, or if // it is an induction variable where in each step its value is a power of two. if (const PHINode *PN = dyn_cast<PHINode>(V)) { - Query RecQ = Q; + SimplifyQuery RecQ = Q; // Check if it is an induction variable and always power of two. if (isPowerOfTwoRecurrence(PN, OrZero, Depth, RecQ)) @@ -2239,7 +2144,7 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, /// /// Currently this routine does not support vector GEPs. static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { const Function *F = nullptr; if (const Instruction *I = dyn_cast<Instruction>(GEP)) F = I->getFunction(); @@ -2302,8 +2207,7 @@ static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth, static bool isKnownNonNullFromDominatingCondition(const Value *V, const Instruction *CtxI, const DominatorTree *DT) { - if (isa<Constant>(V)) - return false; + assert(!isa<Constant>(V) && "Called for constant?"); if (!CtxI || !DT) return false; @@ -2437,131 +2341,156 @@ static bool isNonZeroRecurrence(const PHINode *PN) { } } -/// Return true if the given value is known to be non-zero when defined. For -/// vectors, return true if every demanded element is known to be non-zero when -/// defined. For pointers, if the context instruction and dominator tree are -/// specified, perform context-sensitive analysis and return true if the -/// pointer couldn't possibly be null at the specified instruction. -/// Supports values with integer or pointer type and vectors of integers. -bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, - const Query &Q) { +static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth, + const SimplifyQuery &Q, unsigned BitWidth, Value *X, + Value *Y, bool NSW) { + KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q); + KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q); -#ifndef NDEBUG - Type *Ty = V->getType(); - assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); - - if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) { - assert( - FVTy->getNumElements() == DemandedElts.getBitWidth() && - "DemandedElt width should equal the fixed vector number of elements"); - } else { - assert(DemandedElts == APInt(1, 1) && - "DemandedElt width should be 1 for scalars"); - } -#endif - - if (auto *C = dyn_cast<Constant>(V)) { - if (C->isNullValue()) - return false; - if (isa<ConstantInt>(C)) - // Must be non-zero due to null test above. + // If X and Y are both non-negative (as signed values) then their sum is not + // zero unless both X and Y are zero. 
+ if (XKnown.isNonNegative() && YKnown.isNonNegative()) + if (isKnownNonZero(Y, DemandedElts, Depth, Q) || + isKnownNonZero(X, DemandedElts, Depth, Q)) return true; - // For constant vectors, check that all elements are undefined or known - // non-zero to determine that the whole vector is known non-zero. - if (auto *VecTy = dyn_cast<FixedVectorType>(C->getType())) { - for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) { - if (!DemandedElts[i]) - continue; - Constant *Elt = C->getAggregateElement(i); - if (!Elt || Elt->isNullValue()) - return false; - if (!isa<UndefValue>(Elt) && !isa<ConstantInt>(Elt)) - return false; - } + // If X and Y are both negative (as signed values) then their sum is not + // zero unless both X and Y equal INT_MIN. + if (XKnown.isNegative() && YKnown.isNegative()) { + APInt Mask = APInt::getSignedMaxValue(BitWidth); + // The sign bit of X is set. If some other bit is set then X is not equal + // to INT_MIN. + if (XKnown.One.intersects(Mask)) + return true; + // The sign bit of Y is set. If some other bit is set then Y is not equal + // to INT_MIN. + if (YKnown.One.intersects(Mask)) return true; - } - - // A global variable in address space 0 is non null unless extern weak - // or an absolute symbol reference. Other address spaces may have null as a - // valid address for a global, so we can't assume anything. - if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) { - if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() && - GV->getType()->getAddressSpace() == 0) - return true; - } - - // For constant expressions, fall through to the Operator code below. - if (!isa<ConstantExpr>(V)) - return false; - } - - if (auto *I = dyn_cast<Instruction>(V)) { - if (MDNode *Ranges = Q.IIQ.getMetadata(I, LLVMContext::MD_range)) { - // If the possible ranges don't contain zero, then the value is - // definitely non-zero. - if (auto *Ty = dyn_cast<IntegerType>(V->getType())) { - const APInt ZeroValue(Ty->getBitWidth(), 0); - if (rangeMetadataExcludesValue(Ranges, ZeroValue)) - return true; - } - } } - if (!isa<Constant>(V) && isKnownNonZeroFromAssume(V, Q)) + // The sum of a non-negative number and a power of two is not zero. + if (XKnown.isNonNegative() && + isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q)) + return true; + if (YKnown.isNonNegative() && + isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q)) return true; - // Some of the tests below are recursive, so bail out if we hit the limit. - if (Depth++ >= MaxAnalysisRecursionDepth) - return false; - - // Check for pointer simplifications. + return KnownBits::computeForAddSub(/*Add*/ true, NSW, XKnown, YKnown) + .isNonZero(); +} - if (PointerType *PtrTy = dyn_cast<PointerType>(V->getType())) { - // Alloca never returns null, malloc might. - if (isa<AllocaInst>(V) && Q.DL.getAllocaAddrSpace() == 0) +static bool isNonZeroSub(const APInt &DemandedElts, unsigned Depth, + const SimplifyQuery &Q, unsigned BitWidth, Value *X, + Value *Y) { + if (auto *C = dyn_cast<Constant>(X)) + if (C->isNullValue() && isKnownNonZero(Y, DemandedElts, Depth, Q)) return true; - // A byval, inalloca may not be null in a non-default addres space. A - // nonnull argument is assumed never 0. 
- if (const Argument *A = dyn_cast<Argument>(V)) { - if (((A->hasPassPointeeByValueCopyAttr() && - !NullPointerIsDefined(A->getParent(), PtrTy->getAddressSpace())) || - A->hasNonNullAttr())) - return true; + KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q); + if (XKnown.isUnknown()) + return false; + KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q); + // If X != Y then X - Y is non zero. + std::optional<bool> ne = KnownBits::ne(XKnown, YKnown); + // If we are unable to compute if X != Y, we won't be able to do anything + // computing the knownbits of the sub expression so just return here. + return ne && *ne; +} + +static bool isNonZeroShift(const Operator *I, const APInt &DemandedElts, + unsigned Depth, const SimplifyQuery &Q, + const KnownBits &KnownVal) { + auto ShiftOp = [&](const APInt &Lhs, const APInt &Rhs) { + switch (I->getOpcode()) { + case Instruction::Shl: + return Lhs.shl(Rhs); + case Instruction::LShr: + return Lhs.lshr(Rhs); + case Instruction::AShr: + return Lhs.ashr(Rhs); + default: + llvm_unreachable("Unknown Shift Opcode"); } + }; - // A Load tagged with nonnull metadata is never null. - if (const LoadInst *LI = dyn_cast<LoadInst>(V)) - if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull)) - return true; - - if (const auto *Call = dyn_cast<CallBase>(V)) { - if (Call->isReturnNonNull()) - return true; - if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, true)) - return isKnownNonZero(RP, Depth, Q); + auto InvShiftOp = [&](const APInt &Lhs, const APInt &Rhs) { + switch (I->getOpcode()) { + case Instruction::Shl: + return Lhs.lshr(Rhs); + case Instruction::LShr: + case Instruction::AShr: + return Lhs.shl(Rhs); + default: + llvm_unreachable("Unknown Shift Opcode"); } - } + }; - if (!isa<Constant>(V) && - isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT)) - return true; + if (KnownVal.isUnknown()) + return false; - const Operator *I = dyn_cast<Operator>(V); - if (!I) + KnownBits KnownCnt = + computeKnownBits(I->getOperand(1), DemandedElts, Depth, Q); + APInt MaxShift = KnownCnt.getMaxValue(); + unsigned NumBits = KnownVal.getBitWidth(); + if (MaxShift.uge(NumBits)) return false; - unsigned BitWidth = getBitWidth(V->getType()->getScalarType(), Q.DL); + if (!ShiftOp(KnownVal.One, MaxShift).isZero()) + return true; + + // If all of the bits shifted out are known to be zero, and Val is known + // non-zero then at least one non-zero bit must remain. + if (InvShiftOp(KnownVal.Zero, NumBits - MaxShift) + .eq(InvShiftOp(APInt::getAllOnes(NumBits), NumBits - MaxShift)) && + isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q)) + return true; + + return false; +} + +static bool isKnownNonZeroFromOperator(const Operator *I, + const APInt &DemandedElts, + unsigned Depth, const SimplifyQuery &Q) { + unsigned BitWidth = getBitWidth(I->getType()->getScalarType(), Q.DL); switch (I->getOpcode()) { case Instruction::GetElementPtr: if (I->getType()->isPointerTy()) return isGEPKnownNonNull(cast<GEPOperator>(I), Depth, Q); break; - case Instruction::BitCast: - if (I->getType()->isPointerTy()) + case Instruction::BitCast: { + // We need to be a bit careful here. We can only peek through the bitcast + // if the scalar size of elements in the operand are smaller than and a + // multiple of the size they are casting too. 
Take three cases: + // + // 1) Unsafe: + // bitcast <2 x i16> %NonZero to <4 x i8> + // + // %NonZero can have 2 non-zero i16 elements, but isKnownNonZero on a + // <4 x i8> requires that all 4 i8 elements be non-zero which isn't + // guranteed (imagine just sign bit set in the 2 i16 elements). + // + // 2) Unsafe: + // bitcast <4 x i3> %NonZero to <3 x i4> + // + // Even though the scalar size of the src (`i3`) is smaller than the + // scalar size of the dst `i4`, because `i3` is not a multiple of `i4` + // its possible for the `3 x i4` elements to be zero because there are + // some elements in the destination that don't contain any full src + // element. + // + // 3) Safe: + // bitcast <4 x i8> %NonZero to <2 x i16> + // + // This is always safe as non-zero in the 4 i8 elements implies + // non-zero in the combination of any two adjacent ones. Since i8 is a + // multiple of i16, each i16 is guranteed to have 2 full i8 elements. + // This all implies the 2 i16 elements are non-zero. + Type *FromTy = I->getOperand(0)->getType(); + if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) && + (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0) return isKnownNonZero(I->getOperand(0), Depth, Q); - break; + } break; case Instruction::IntToPtr: // Note that we have to take special care to avoid looking through // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well @@ -2579,19 +2508,22 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, Q.DL.getTypeSizeInBits(I->getType()).getFixedValue()) return isKnownNonZero(I->getOperand(0), Depth, Q); break; + case Instruction::Sub: + return isNonZeroSub(DemandedElts, Depth, Q, BitWidth, I->getOperand(0), + I->getOperand(1)); case Instruction::Or: // X | Y != 0 if X != 0 or Y != 0. - return isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) || - isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q); + return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) || + isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q); case Instruction::SExt: case Instruction::ZExt: // ext X != 0 if X != 0. return isKnownNonZero(I->getOperand(0), Depth, Q); case Instruction::Shl: { - // shl nuw can't remove any non-zero bits. - const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); - if (Q.IIQ.hasNoUnsignedWrap(BO)) + // shl nsw/nuw can't remove any non-zero bits. + const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I); + if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO)) return isKnownNonZero(I->getOperand(0), Depth, Q); // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined @@ -2600,12 +2532,13 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth, Q); if (Known.One[0]) return true; - break; + + return isNonZeroShift(I, DemandedElts, Depth, Q, Known); } case Instruction::LShr: case Instruction::AShr: { // shr exact can only shift out zero bits. 
- const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V); + const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I); if (BO->isExact()) return isKnownNonZero(I->getOperand(0), Depth, Q); @@ -2616,86 +2549,110 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, if (Known.isNegative()) return true; - // If the shifter operand is a constant, and all of the bits shifted - // out are known to be zero, and X is known non-zero then at least one - // non-zero bit must remain. - if (ConstantInt *Shift = dyn_cast<ConstantInt>(I->getOperand(1))) { - auto ShiftVal = Shift->getLimitedValue(BitWidth - 1); - // Is there a known one in the portion not shifted out? - if (Known.countMaxLeadingZeros() < BitWidth - ShiftVal) - return true; - // Are all the bits to be shifted out known zero? - if (Known.countMinTrailingZeros() >= ShiftVal) - return isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q); - } - break; + return isNonZeroShift(I, DemandedElts, Depth, Q, Known); } case Instruction::UDiv: case Instruction::SDiv: + // X / Y // div exact can only produce a zero if the dividend is zero. if (cast<PossiblyExactOperator>(I)->isExact()) return isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q); + if (I->getOpcode() == Instruction::UDiv) { + std::optional<bool> XUgeY; + KnownBits XKnown = + computeKnownBits(I->getOperand(0), DemandedElts, Depth, Q); + if (!XKnown.isUnknown()) { + KnownBits YKnown = + computeKnownBits(I->getOperand(1), DemandedElts, Depth, Q); + // If X u>= Y then div is non zero (0/0 is UB). + XUgeY = KnownBits::uge(XKnown, YKnown); + } + // If X is total unknown or X u< Y we won't be able to prove non-zero + // with compute known bits so just return early. + return XUgeY && *XUgeY; + } break; case Instruction::Add: { // X + Y. + + // If Add has nuw wrap flag, then if either X or Y is non-zero the result is + // non-zero. + auto *BO = cast<OverflowingBinaryOperator>(I); + if (Q.IIQ.hasNoUnsignedWrap(BO)) + return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) || + isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q); + + return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, I->getOperand(0), + I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO)); + } + case Instruction::Mul: { + // If X and Y are non-zero then so is X * Y as long as the multiplication + // does not overflow. + const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I); + if (Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) + return isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) && + isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q); + + // If either X or Y is odd, then if the other is non-zero the result can't + // be zero. KnownBits XKnown = computeKnownBits(I->getOperand(0), DemandedElts, Depth, Q); + if (XKnown.One[0]) + return isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q); + KnownBits YKnown = computeKnownBits(I->getOperand(1), DemandedElts, Depth, Q); + if (YKnown.One[0]) + return XKnown.isNonZero() || + isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q); + + // If there exists any subset of X (sX) and subset of Y (sY) s.t sX * sY is + // non-zero, then X * Y is non-zero. We can find sX and sY by just taking + // the lowest known One of X and Y. If they are non-zero, the result + // must be non-zero. We can check if LSB(X) * LSB(Y) != 0 by doing + // X.CountLeadingZeros + Y.CountLeadingZeros < BitWidth. 
+ return (XKnown.countMaxTrailingZeros() + YKnown.countMaxTrailingZeros()) < + BitWidth; + } + case Instruction::Select: { + // (C ? X : Y) != 0 if X != 0 and Y != 0. - // If X and Y are both non-negative (as signed values) then their sum is not - // zero unless both X and Y are zero. - if (XKnown.isNonNegative() && YKnown.isNonNegative()) - if (isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) || - isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q)) + // First check if the arm is non-zero using `isKnownNonZero`. If that fails, + // then see if the select condition implies the arm is non-zero. For example + // (X != 0 ? X : Y), we know the true arm is non-zero as the `X` "return" is + // dominated by `X != 0`. + auto SelectArmIsNonZero = [&](bool IsTrueArm) { + Value *Op; + Op = IsTrueArm ? I->getOperand(1) : I->getOperand(2); + // Op is trivially non-zero. + if (isKnownNonZero(Op, DemandedElts, Depth, Q)) return true; - // If X and Y are both negative (as signed values) then their sum is not - // zero unless both X and Y equal INT_MIN. - if (XKnown.isNegative() && YKnown.isNegative()) { - APInt Mask = APInt::getSignedMaxValue(BitWidth); - // The sign bit of X is set. If some other bit is set then X is not equal - // to INT_MIN. - if (XKnown.One.intersects(Mask)) - return true; - // The sign bit of Y is set. If some other bit is set then Y is not equal - // to INT_MIN. - if (YKnown.One.intersects(Mask)) - return true; - } + // The condition of the select dominates the true/false arm. Check if the + // condition implies that a given arm is non-zero. + Value *X; + CmpInst::Predicate Pred; + if (!match(I->getOperand(0), m_c_ICmp(Pred, m_Specific(Op), m_Value(X)))) + return false; - // The sum of a non-negative number and a power of two is not zero. - if (XKnown.isNonNegative() && - isKnownToBeAPowerOfTwo(I->getOperand(1), /*OrZero*/ false, Depth, Q)) - return true; - if (YKnown.isNonNegative() && - isKnownToBeAPowerOfTwo(I->getOperand(0), /*OrZero*/ false, Depth, Q)) - return true; - break; - } - case Instruction::Mul: { - // If X and Y are non-zero then so is X * Y as long as the multiplication - // does not overflow. - const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); - if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) && - isKnownNonZero(I->getOperand(0), DemandedElts, Depth, Q) && - isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q)) + if (!IsTrueArm) + Pred = ICmpInst::getInversePredicate(Pred); + + return cmpExcludesZero(Pred, X); + }; + + if (SelectArmIsNonZero(/* IsTrueArm */ true) && + SelectArmIsNonZero(/* IsTrueArm */ false)) return true; break; } - case Instruction::Select: - // (C ? X : Y) != 0 if X != 0 and Y != 0. - if (isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) && - isKnownNonZero(I->getOperand(2), DemandedElts, Depth, Q)) - return true; - break; case Instruction::PHI: { auto *PN = cast<PHINode>(I); if (Q.IIQ.UseInstrInfo && isNonZeroRecurrence(PN)) return true; // Check if all incoming values are non-zero using recursion. 
- Query RecQ = Q; + SimplifyQuery RecQ = Q; unsigned NewDepth = std::max(Depth, MaxAnalysisRecursionDepth - 1); return llvm::all_of(PN->operands(), [&](const Use &U) { if (U.get() == PN) @@ -2705,7 +2662,7 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, }); } case Instruction::ExtractElement: - if (const auto *EEI = dyn_cast<ExtractElementInst>(V)) { + if (const auto *EEI = dyn_cast<ExtractElementInst>(I)) { const Value *Vec = EEI->getVectorOperand(); const Value *Idx = EEI->getIndexOperand(); auto *CIdx = dyn_cast<ConstantInt>(Idx); @@ -2722,18 +2679,198 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, return isKnownNonZero(I->getOperand(0), Depth, Q) && isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT, Depth); - case Instruction::Call: - if (cast<CallInst>(I)->getIntrinsicID() == Intrinsic::vscale) + case Instruction::Load: + // A Load tagged with nonnull metadata is never null. + if (Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_nonnull)) return true; + + // No need to fall through to computeKnownBits as range metadata is already + // handled in isKnownNonZero. + return false; + case Instruction::Call: + if (auto *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + case Intrinsic::sshl_sat: + case Intrinsic::ushl_sat: + case Intrinsic::abs: + case Intrinsic::bitreverse: + case Intrinsic::bswap: + case Intrinsic::ctpop: + return isKnownNonZero(II->getArgOperand(0), DemandedElts, Depth, Q); + case Intrinsic::ssub_sat: + return isNonZeroSub(DemandedElts, Depth, Q, BitWidth, + II->getArgOperand(0), II->getArgOperand(1)); + case Intrinsic::sadd_sat: + return isNonZeroAdd(DemandedElts, Depth, Q, BitWidth, + II->getArgOperand(0), II->getArgOperand(1), + /*NSW*/ true); + case Intrinsic::umax: + case Intrinsic::uadd_sat: + return isKnownNonZero(II->getArgOperand(1), DemandedElts, Depth, Q) || + isKnownNonZero(II->getArgOperand(0), DemandedElts, Depth, Q); + case Intrinsic::smin: + case Intrinsic::smax: { + auto KnownOpImpliesNonZero = [&](const KnownBits &K) { + return II->getIntrinsicID() == Intrinsic::smin + ? K.isNegative() + : K.isStrictlyPositive(); + }; + KnownBits XKnown = + computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q); + if (KnownOpImpliesNonZero(XKnown)) + return true; + KnownBits YKnown = + computeKnownBits(II->getArgOperand(1), DemandedElts, Depth, Q); + if (KnownOpImpliesNonZero(YKnown)) + return true; + + if (XKnown.isNonZero() && YKnown.isNonZero()) + return true; + } + [[fallthrough]]; + case Intrinsic::umin: + return isKnownNonZero(II->getArgOperand(0), DemandedElts, Depth, Q) && + isKnownNonZero(II->getArgOperand(1), DemandedElts, Depth, Q); + case Intrinsic::cttz: + return computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q) + .Zero[0]; + case Intrinsic::ctlz: + return computeKnownBits(II->getArgOperand(0), DemandedElts, Depth, Q) + .isNonNegative(); + case Intrinsic::fshr: + case Intrinsic::fshl: + // If Op0 == Op1, this is a rotate. rotate(x, y) != 0 iff x != 0. 
+ if (II->getArgOperand(0) == II->getArgOperand(1)) + return isKnownNonZero(II->getArgOperand(0), DemandedElts, Depth, Q); + break; + case Intrinsic::vscale: + return true; + default: + break; + } + } break; } KnownBits Known(BitWidth); - computeKnownBits(V, DemandedElts, Known, Depth, Q); + computeKnownBits(I, DemandedElts, Known, Depth, Q); return Known.One != 0; } -bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) { +/// Return true if the given value is known to be non-zero when defined. For +/// vectors, return true if every demanded element is known to be non-zero when +/// defined. For pointers, if the context instruction and dominator tree are +/// specified, perform context-sensitive analysis and return true if the +/// pointer couldn't possibly be null at the specified instruction. +/// Supports values with integer or pointer type and vectors of integers. +bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth, + const SimplifyQuery &Q) { + +#ifndef NDEBUG + Type *Ty = V->getType(); + assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); + + if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) { + assert( + FVTy->getNumElements() == DemandedElts.getBitWidth() && + "DemandedElt width should equal the fixed vector number of elements"); + } else { + assert(DemandedElts == APInt(1, 1) && + "DemandedElt width should be 1 for scalars"); + } +#endif + + if (auto *C = dyn_cast<Constant>(V)) { + if (C->isNullValue()) + return false; + if (isa<ConstantInt>(C)) + // Must be non-zero due to null test above. + return true; + + // For constant vectors, check that all elements are undefined or known + // non-zero to determine that the whole vector is known non-zero. + if (auto *VecTy = dyn_cast<FixedVectorType>(C->getType())) { + for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) { + if (!DemandedElts[i]) + continue; + Constant *Elt = C->getAggregateElement(i); + if (!Elt || Elt->isNullValue()) + return false; + if (!isa<UndefValue>(Elt) && !isa<ConstantInt>(Elt)) + return false; + } + return true; + } + + // A global variable in address space 0 is non null unless extern weak + // or an absolute symbol reference. Other address spaces may have null as a + // valid address for a global, so we can't assume anything. + if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) { + if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() && + GV->getType()->getAddressSpace() == 0) + return true; + } + + // For constant expressions, fall through to the Operator code below. + if (!isa<ConstantExpr>(V)) + return false; + } + + if (auto *I = dyn_cast<Instruction>(V)) { + if (MDNode *Ranges = Q.IIQ.getMetadata(I, LLVMContext::MD_range)) { + // If the possible ranges don't contain zero, then the value is + // definitely non-zero. + if (auto *Ty = dyn_cast<IntegerType>(V->getType())) { + const APInt ZeroValue(Ty->getBitWidth(), 0); + if (rangeMetadataExcludesValue(Ranges, ZeroValue)) + return true; + } + } + } + + if (!isa<Constant>(V) && isKnownNonZeroFromAssume(V, Q)) + return true; + + // Some of the tests below are recursive, so bail out if we hit the limit. + if (Depth++ >= MaxAnalysisRecursionDepth) + return false; + + // Check for pointer simplifications. + + if (PointerType *PtrTy = dyn_cast<PointerType>(V->getType())) { + // Alloca never returns null, malloc might. + if (isa<AllocaInst>(V) && PtrTy->getAddressSpace() == 0) + return true; + + // A byval, inalloca may not be null in a non-default addres space. 
A + // nonnull argument is assumed never 0. + if (const Argument *A = dyn_cast<Argument>(V)) { + if (((A->hasPassPointeeByValueCopyAttr() && + !NullPointerIsDefined(A->getParent(), PtrTy->getAddressSpace())) || + A->hasNonNullAttr())) + return true; + } + + if (const auto *Call = dyn_cast<CallBase>(V)) { + if (Call->isReturnNonNull()) + return true; + if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, true)) + return isKnownNonZero(RP, Depth, Q); + } + } + + if (const auto *I = dyn_cast<Operator>(V)) + if (isKnownNonZeroFromOperator(I, DemandedElts, Depth, Q)) + return true; + + if (!isa<Constant>(V) && + isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT)) + return true; + + return false; +} + +bool isKnownNonZero(const Value *V, unsigned Depth, const SimplifyQuery &Q) { auto *FVTy = dyn_cast<FixedVectorType>(V->getType()); APInt DemandedElts = FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1); @@ -2849,7 +2986,7 @@ getInvertibleOperands(const Operator *Op1, /// Return true if V2 == V1 + X, where X is known non-zero. static bool isAddOfNonZero(const Value *V1, const Value *V2, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1); if (!BO || BO->getOpcode() != Instruction::Add) return false; @@ -2866,7 +3003,7 @@ static bool isAddOfNonZero(const Value *V1, const Value *V2, unsigned Depth, /// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and /// the multiplication is nuw or nsw. static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) { const APInt *C; return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) && @@ -2879,7 +3016,7 @@ static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth, /// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and /// the shift is nuw or nsw. static bool isNonEqualShl(const Value *V1, const Value *V2, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) { const APInt *C; return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) && @@ -2890,7 +3027,7 @@ static bool isNonEqualShl(const Value *V1, const Value *V2, unsigned Depth, } static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2, - unsigned Depth, const Query &Q) { + unsigned Depth, const SimplifyQuery &Q) { // Check two PHIs are in same block. if (PN1->getParent() != PN2->getParent()) return false; @@ -2910,7 +3047,7 @@ static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2, if (UsedFullRecursion) return false; - Query RecQ = Q; + SimplifyQuery RecQ = Q; RecQ.CxtI = IncomBB->getTerminator(); if (!isKnownNonEqual(IV1, IV2, Depth + 1, RecQ)) return false; @@ -2921,7 +3058,7 @@ static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2, /// Return true if it is known that V1 != V2. static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { if (V1 == V2) return false; if (V1->getType() != V2->getType()) @@ -2981,7 +3118,7 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth, /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. 
bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, - const Query &Q) { + const SimplifyQuery &Q) { KnownBits Known(Mask.getBitWidth()); computeKnownBits(V, Known, Depth, Q); return Mask.isSubsetOf(Known.Zero); @@ -3065,10 +3202,10 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V, static unsigned ComputeNumSignBitsImpl(const Value *V, const APInt &DemandedElts, - unsigned Depth, const Query &Q); + unsigned Depth, const SimplifyQuery &Q); static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts, - unsigned Depth, const Query &Q) { + unsigned Depth, const SimplifyQuery &Q) { unsigned Result = ComputeNumSignBitsImpl(V, DemandedElts, Depth, Q); assert(Result > 0 && "At least one sign bit needs to be present!"); return Result; @@ -3083,7 +3220,7 @@ static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts, /// elements in the vector specified by DemandedElts. static unsigned ComputeNumSignBitsImpl(const Value *V, const APInt &DemandedElts, - unsigned Depth, const Query &Q) { + unsigned Depth, const SimplifyQuery &Q) { Type *Ty = V->getType(); #ifndef NDEBUG assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); @@ -3303,7 +3440,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, // Take the minimum of all incoming values. This can't infinitely loop // because of our depth threshold. - Query RecQ = Q; + SimplifyQuery RecQ = Q; Tmp = TyBits; for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) { if (Tmp == 1) return Tmp; @@ -3511,68 +3648,13 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB, return Intrinsic::not_intrinsic; } -/// Return true if we can prove that the specified FP value is never equal to -/// -0.0. -/// NOTE: Do not check 'nsz' here because that fast-math-flag does not guarantee -/// that a value is not -0.0. It only guarantees that -0.0 may be treated -/// the same as +0.0 in floating-point ops. -bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI, - unsigned Depth) { - if (auto *CFP = dyn_cast<ConstantFP>(V)) - return !CFP->getValueAPF().isNegZero(); - - if (Depth == MaxAnalysisRecursionDepth) - return false; - - auto *Op = dyn_cast<Operator>(V); - if (!Op) - return false; - - // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0. - if (match(Op, m_FAdd(m_Value(), m_PosZeroFP()))) - return true; - - // sitofp and uitofp turn into +0.0 for zero. - if (isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) - return true; - - if (auto *Call = dyn_cast<CallInst>(Op)) { - Intrinsic::ID IID = getIntrinsicForCallSite(*Call, TLI); - switch (IID) { - default: - break; - // sqrt(-0.0) = -0.0, no other negative results are possible. - case Intrinsic::sqrt: - case Intrinsic::canonicalize: - return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1); - case Intrinsic::experimental_constrained_sqrt: { - // NOTE: This rounding mode restriction may be too strict. - const auto *CI = cast<ConstrainedFPIntrinsic>(Call); - if (CI->getRoundingMode() == RoundingMode::NearestTiesToEven) - return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1); - else - return false; - } - // fabs(x) != -0.0 - case Intrinsic::fabs: - return true; - // sitofp and uitofp turn into +0.0 for zero. - case Intrinsic::experimental_constrained_sitofp: - case Intrinsic::experimental_constrained_uitofp: - return true; - } - } - - return false; -} - /// If \p SignBitOnly is true, test for a known 0 sign bit rather than a /// standard ordered compare. e.g. 
make -0.0 olt 0.0 be true because of the sign /// bit despite comparing equal. static bool cannotBeOrderedLessThanZeroImpl(const Value *V, + const DataLayout &DL, const TargetLibraryInfo *TLI, - bool SignBitOnly, - unsigned Depth) { + bool SignBitOnly, unsigned Depth) { // TODO: This function does not do the right thing when SignBitOnly is true // and we're lowering to a hypothetical IEEE 754-compliant-but-evil platform // which flips the sign bits of NaNs. See @@ -3621,9 +3703,9 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, return true; // Set SignBitOnly for RHS, because X / -0.0 is -Inf (or NaN). - return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, - Depth + 1) && - cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), DL, TLI, + SignBitOnly, Depth + 1) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), DL, TLI, /*SignBitOnly*/ true, Depth + 1); case Instruction::FMul: // X * X is always non-negative or a NaN. @@ -3634,26 +3716,26 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, [[fallthrough]]; case Instruction::FAdd: case Instruction::FRem: - return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, - Depth + 1) && - cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, - Depth + 1); + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), DL, TLI, + SignBitOnly, Depth + 1) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), DL, TLI, + SignBitOnly, Depth + 1); case Instruction::Select: - return cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, - Depth + 1) && - cannotBeOrderedLessThanZeroImpl(I->getOperand(2), TLI, SignBitOnly, - Depth + 1); + return cannotBeOrderedLessThanZeroImpl(I->getOperand(1), DL, TLI, + SignBitOnly, Depth + 1) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(2), DL, TLI, + SignBitOnly, Depth + 1); case Instruction::FPExt: case Instruction::FPTrunc: // Widening/narrowing never change sign. - return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, - Depth + 1); + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), DL, TLI, + SignBitOnly, Depth + 1); case Instruction::ExtractElement: // Look through extract element. At the moment we keep this simple and skip // tracking the specific element. But at least we might find information // valid for all elements of the vector. - return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, - Depth + 1); + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), DL, TLI, + SignBitOnly, Depth + 1); case Instruction::Call: const auto *CI = cast<CallInst>(I); Intrinsic::ID IID = getIntrinsicForCallSite(*CI, TLI); @@ -3670,7 +3752,8 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, case Intrinsic::round: case Intrinsic::roundeven: case Intrinsic::fptrunc_round: - return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1); + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), DL, TLI, + SignBitOnly, Depth + 1); case Intrinsic::maxnum: { Value *V0 = I->getOperand(0), *V1 = I->getOperand(1); auto isPositiveNum = [&](Value *V) { @@ -3685,8 +3768,8 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, // -0.0 compares equal to 0.0, so if this operand is at least -0.0, // maxnum can't be ordered-less-than-zero. 
- return isKnownNeverNaN(V, TLI) && - cannotBeOrderedLessThanZeroImpl(V, TLI, false, Depth + 1); + return isKnownNeverNaN(V, DL, TLI) && + cannotBeOrderedLessThanZeroImpl(V, DL, TLI, false, Depth + 1); }; // TODO: This could be improved. We could also check that neither operand @@ -3695,30 +3778,31 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, } case Intrinsic::maximum: - return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, - Depth + 1) || - cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, - Depth + 1); + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), DL, TLI, + SignBitOnly, Depth + 1) || + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), DL, TLI, + SignBitOnly, Depth + 1); case Intrinsic::minnum: case Intrinsic::minimum: - return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, - Depth + 1) && - cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, - Depth + 1); + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), DL, TLI, + SignBitOnly, Depth + 1) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), DL, TLI, + SignBitOnly, Depth + 1); case Intrinsic::exp: case Intrinsic::exp2: case Intrinsic::fabs: return true; case Intrinsic::copysign: // Only the sign operand matters. - return cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, true, + return cannotBeOrderedLessThanZeroImpl(I->getOperand(1), DL, TLI, true, Depth + 1); case Intrinsic::sqrt: // sqrt(x) is always >= -0 or NaN. Moreover, sqrt(x) == -0 iff x == -0. if (!SignBitOnly) return true; - return CI->hasNoNaNs() && (CI->hasNoSignedZeros() || - CannotBeNegativeZero(CI->getOperand(0), TLI)); + return CI->hasNoNaNs() && + (CI->hasNoSignedZeros() || + cannotBeNegativeZero(CI->getOperand(0), DL, TLI)); case Intrinsic::powi: if (ConstantInt *Exponent = dyn_cast<ConstantInt>(I->getOperand(1))) { @@ -3739,264 +3823,1423 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, // but we must return false if x == -0. Unfortunately we do not currently // have a way of expressing this constraint. See details in // https://llvm.org/bugs/show_bug.cgi?id=31702. - return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, - Depth + 1); + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), DL, TLI, + SignBitOnly, Depth + 1); case Intrinsic::fma: case Intrinsic::fmuladd: // x*x+y is non-negative if y is non-negative. 
return I->getOperand(0) == I->getOperand(1) && (!SignBitOnly || cast<FPMathOperator>(I)->hasNoNaNs()) && - cannotBeOrderedLessThanZeroImpl(I->getOperand(2), TLI, SignBitOnly, - Depth + 1); + cannotBeOrderedLessThanZeroImpl(I->getOperand(2), DL, TLI, + SignBitOnly, Depth + 1); } break; } return false; } -bool llvm::CannotBeOrderedLessThanZero(const Value *V, +bool llvm::CannotBeOrderedLessThanZero(const Value *V, const DataLayout &DL, const TargetLibraryInfo *TLI) { - return cannotBeOrderedLessThanZeroImpl(V, TLI, false, 0); + return cannotBeOrderedLessThanZeroImpl(V, DL, TLI, false, 0); } -bool llvm::SignBitMustBeZero(const Value *V, const TargetLibraryInfo *TLI) { - return cannotBeOrderedLessThanZeroImpl(V, TLI, true, 0); +bool llvm::SignBitMustBeZero(const Value *V, const DataLayout &DL, + const TargetLibraryInfo *TLI) { + return cannotBeOrderedLessThanZeroImpl(V, DL, TLI, true, 0); } -bool llvm::isKnownNeverInfinity(const Value *V, const TargetLibraryInfo *TLI, - unsigned Depth) { - assert(V->getType()->isFPOrFPVectorTy() && "Querying for Inf on non-FP type"); +/// Return true if it's possible to assume IEEE treatment of input denormals in +/// \p F for \p Val. +static bool inputDenormalIsIEEE(const Function &F, const Type *Ty) { + Ty = Ty->getScalarType(); + return F.getDenormalMode(Ty->getFltSemantics()).Input == DenormalMode::IEEE; +} - // If we're told that infinities won't happen, assume they won't. - if (auto *FPMathOp = dyn_cast<FPMathOperator>(V)) - if (FPMathOp->hasNoInfs()) - return true; +static bool inputDenormalIsIEEEOrPosZero(const Function &F, const Type *Ty) { + Ty = Ty->getScalarType(); + DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics()); + return Mode.Input == DenormalMode::IEEE || + Mode.Input == DenormalMode::PositiveZero; +} - // Handle scalar constants. - if (auto *CFP = dyn_cast<ConstantFP>(V)) - return !CFP->isInfinity(); +static bool outputDenormalIsIEEEOrPosZero(const Function &F, const Type *Ty) { + Ty = Ty->getScalarType(); + DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics()); + return Mode.Output == DenormalMode::IEEE || + Mode.Output == DenormalMode::PositiveZero; +} - if (Depth == MaxAnalysisRecursionDepth) +bool KnownFPClass::isKnownNeverLogicalZero(const Function &F, Type *Ty) const { + return isKnownNeverZero() && + (isKnownNeverSubnormal() || inputDenormalIsIEEE(F, Ty)); +} + +bool KnownFPClass::isKnownNeverLogicalNegZero(const Function &F, + Type *Ty) const { + return isKnownNeverNegZero() && + (isKnownNeverNegSubnormal() || inputDenormalIsIEEEOrPosZero(F, Ty)); +} + +bool KnownFPClass::isKnownNeverLogicalPosZero(const Function &F, + Type *Ty) const { + if (!isKnownNeverPosZero()) return false; - if (auto *Inst = dyn_cast<Instruction>(V)) { - switch (Inst->getOpcode()) { - case Instruction::Select: { - return isKnownNeverInfinity(Inst->getOperand(1), TLI, Depth + 1) && - isKnownNeverInfinity(Inst->getOperand(2), TLI, Depth + 1); + // If we know there are no denormals, nothing can be flushed to zero. 
+ if (isKnownNeverSubnormal()) + return true; + + DenormalMode Mode = F.getDenormalMode(Ty->getScalarType()->getFltSemantics()); + switch (Mode.Input) { + case DenormalMode::IEEE: + return true; + case DenormalMode::PreserveSign: + // Negative subnormal won't flush to +0 + return isKnownNeverPosSubnormal(); + case DenormalMode::PositiveZero: + default: + // Both positive and negative subnormal could flush to +0 + return false; + } + + llvm_unreachable("covered switch over denormal mode"); +} + +void KnownFPClass::propagateDenormal(const KnownFPClass &Src, const Function &F, + Type *Ty) { + KnownFPClasses = Src.KnownFPClasses; + // If we aren't assuming the source can't be a zero, we don't have to check if + // a denormal input could be flushed. + if (!Src.isKnownNeverPosZero() && !Src.isKnownNeverNegZero()) + return; + + // If we know the input can't be a denormal, it can't be flushed to 0. + if (Src.isKnownNeverSubnormal()) + return; + + DenormalMode Mode = F.getDenormalMode(Ty->getScalarType()->getFltSemantics()); + + if (!Src.isKnownNeverPosSubnormal() && Mode != DenormalMode::getIEEE()) + KnownFPClasses |= fcPosZero; + + if (!Src.isKnownNeverNegSubnormal() && Mode != DenormalMode::getIEEE()) { + if (Mode != DenormalMode::getPositiveZero()) + KnownFPClasses |= fcNegZero; + + if (Mode.Input == DenormalMode::PositiveZero || + Mode.Output == DenormalMode::PositiveZero || + Mode.Input == DenormalMode::Dynamic || + Mode.Output == DenormalMode::Dynamic) + KnownFPClasses |= fcPosZero; + } +} + +void KnownFPClass::propagateCanonicalizingSrc(const KnownFPClass &Src, + const Function &F, Type *Ty) { + propagateDenormal(Src, F, Ty); + propagateNaN(Src, /*PreserveSign=*/true); +} + +/// Returns a pair of values, which if passed to llvm.is.fpclass, returns the +/// same result as an fcmp with the given operands. +std::pair<Value *, FPClassTest> llvm::fcmpToClassTest(FCmpInst::Predicate Pred, + const Function &F, + Value *LHS, Value *RHS, + bool LookThroughSrc) { + const APFloat *ConstRHS; + if (!match(RHS, m_APFloat(ConstRHS))) + return {nullptr, fcNone}; + + // fcmp ord x, zero|normal|subnormal|inf -> ~fcNan + if (Pred == FCmpInst::FCMP_ORD && !ConstRHS->isNaN()) + return {LHS, ~fcNan}; + + // fcmp uno x, zero|normal|subnormal|inf -> fcNan + if (Pred == FCmpInst::FCMP_UNO && !ConstRHS->isNaN()) + return {LHS, fcNan}; + + if (ConstRHS->isZero()) { + // Compares with fcNone are only exactly equal to fcZero if input denormals + // are not flushed. + // TODO: Handle DAZ by expanding masks to cover subnormal cases. + if (Pred != FCmpInst::FCMP_ORD && Pred != FCmpInst::FCMP_UNO && + !inputDenormalIsIEEE(F, LHS->getType())) + return {nullptr, fcNone}; + + switch (Pred) { + case FCmpInst::FCMP_OEQ: // Match x == 0.0 + return {LHS, fcZero}; + case FCmpInst::FCMP_UEQ: // Match isnan(x) || (x == 0.0) + return {LHS, fcZero | fcNan}; + case FCmpInst::FCMP_UNE: // Match (x != 0.0) + return {LHS, ~fcZero}; + case FCmpInst::FCMP_ONE: // Match !isnan(x) && x != 0.0 + return {LHS, ~fcNan & ~fcZero}; + case FCmpInst::FCMP_ORD: + // Canonical form of ord/uno is with a zero. We could also handle + // non-canonical other non-NaN constants or LHS == RHS. 
+ return {LHS, ~fcNan}; + case FCmpInst::FCMP_UNO: + return {LHS, fcNan}; + case FCmpInst::FCMP_OGT: // x > 0 + return {LHS, fcPosSubnormal | fcPosNormal | fcPosInf}; + case FCmpInst::FCMP_UGT: // isnan(x) || x > 0 + return {LHS, fcPosSubnormal | fcPosNormal | fcPosInf | fcNan}; + case FCmpInst::FCMP_OGE: // x >= 0 + return {LHS, fcPositive | fcNegZero}; + case FCmpInst::FCMP_UGE: // isnan(x) || x >= 0 + return {LHS, fcPositive | fcNegZero | fcNan}; + case FCmpInst::FCMP_OLT: // x < 0 + return {LHS, fcNegSubnormal | fcNegNormal | fcNegInf}; + case FCmpInst::FCMP_ULT: // isnan(x) || x < 0 + return {LHS, fcNegSubnormal | fcNegNormal | fcNegInf | fcNan}; + case FCmpInst::FCMP_OLE: // x <= 0 + return {LHS, fcNegative | fcPosZero}; + case FCmpInst::FCMP_ULE: // isnan(x) || x <= 0 + return {LHS, fcNegative | fcPosZero | fcNan}; + default: + break; } - case Instruction::SIToFP: - case Instruction::UIToFP: { - // Get width of largest magnitude integer (remove a bit if signed). - // This still works for a signed minimum value because the largest FP - // value is scaled by some fraction close to 2.0 (1.0 + 0.xxxx). - int IntSize = Inst->getOperand(0)->getType()->getScalarSizeInBits(); - if (Inst->getOpcode() == Instruction::SIToFP) - --IntSize; - // If the exponent of the largest finite FP value can hold the largest - // integer, the result of the cast must be finite. - Type *FPTy = Inst->getType()->getScalarType(); - return ilogb(APFloat::getLargest(FPTy->getFltSemantics())) >= IntSize; + return {nullptr, fcNone}; + } + + Value *Src = LHS; + const bool IsFabs = LookThroughSrc && match(LHS, m_FAbs(m_Value(Src))); + + // Compute the test mask that would return true for the ordered comparisons. + FPClassTest Mask; + + if (ConstRHS->isInfinity()) { + switch (Pred) { + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UNE: { + // Match __builtin_isinf patterns + // + // fcmp oeq x, +inf -> is_fpclass x, fcPosInf + // fcmp oeq fabs(x), +inf -> is_fpclass x, fcInf + // fcmp oeq x, -inf -> is_fpclass x, fcNegInf + // fcmp oeq fabs(x), -inf -> is_fpclass x, 0 -> false + // + // fcmp une x, +inf -> is_fpclass x, ~fcPosInf + // fcmp une fabs(x), +inf -> is_fpclass x, ~fcInf + // fcmp une x, -inf -> is_fpclass x, ~fcNegInf + // fcmp une fabs(x), -inf -> is_fpclass x, fcAllFlags -> true + + if (ConstRHS->isNegative()) { + Mask = fcNegInf; + if (IsFabs) + Mask = fcNone; + } else { + Mask = fcPosInf; + if (IsFabs) + Mask |= fcNegInf; + } + + break; } - case Instruction::FNeg: - case Instruction::FPExt: { - // Peek through to source op. If it is not infinity, this is not infinity. - return isKnownNeverInfinity(Inst->getOperand(0), TLI, Depth + 1); + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UEQ: { + // Match __builtin_isinf patterns + // fcmp one x, -inf -> is_fpclass x, fcNegInf + // fcmp one fabs(x), -inf -> is_fpclass x, ~fcNegInf & ~fcNan + // fcmp one x, +inf -> is_fpclass x, ~fcNegInf & ~fcNan + // fcmp one fabs(x), +inf -> is_fpclass x, ~fcInf & fcNan + // + // fcmp ueq x, +inf -> is_fpclass x, fcPosInf|fcNan + // fcmp ueq (fabs x), +inf -> is_fpclass x, fcInf|fcNan + // fcmp ueq x, -inf -> is_fpclass x, fcNegInf|fcNan + // fcmp ueq fabs(x), -inf -> is_fpclass x, fcNan + if (ConstRHS->isNegative()) { + Mask = ~fcNegInf & ~fcNan; + if (IsFabs) + Mask = ~fcNan; + } else { + Mask = ~fcPosInf & ~fcNan; + if (IsFabs) + Mask &= ~fcNegInf; + } + + break; } - case Instruction::FPTrunc: { - // Need a range check. 
- return false; + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_UGE: { + if (ConstRHS->isNegative()) { + // No value is ordered and less than negative infinity. + // All values are unordered with or at least negative infinity. + // fcmp olt x, -inf -> false + // fcmp uge x, -inf -> true + Mask = fcNone; + break; + } + + // fcmp olt fabs(x), +inf -> fcFinite + // fcmp uge fabs(x), +inf -> ~fcFinite + // fcmp olt x, +inf -> fcFinite|fcNegInf + // fcmp uge x, +inf -> ~(fcFinite|fcNegInf) + Mask = fcFinite; + if (!IsFabs) + Mask |= fcNegInf; + break; + } + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_ULT: { + if (ConstRHS->isNegative()) // TODO + return {nullptr, fcNone}; + + // fcmp oge fabs(x), +inf -> fcInf + // fcmp oge x, +inf -> fcPosInf + // fcmp ult fabs(x), +inf -> ~fcInf + // fcmp ult x, +inf -> ~fcPosInf + Mask = fcPosInf; + if (IsFabs) + Mask |= fcNegInf; + break; + } + case FCmpInst::FCMP_OGT: + case FCmpInst::FCMP_ULE: { + if (ConstRHS->isNegative()) + return {nullptr, fcNone}; + + // No value is ordered and greater than infinity. + Mask = fcNone; + break; } default: + return {nullptr, fcNone}; + } + } else if (ConstRHS->isSmallestNormalized() && !ConstRHS->isNegative()) { + // Match pattern that's used in __builtin_isnormal. + switch (Pred) { + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_UGE: { + // fcmp olt x, smallest_normal -> fcNegInf|fcNegNormal|fcSubnormal|fcZero + // fcmp olt fabs(x), smallest_normal -> fcSubnormal|fcZero + // fcmp uge x, smallest_normal -> fcNan|fcPosNormal|fcPosInf + // fcmp uge fabs(x), smallest_normal -> ~(fcSubnormal|fcZero) + Mask = fcZero | fcSubnormal; + if (!IsFabs) + Mask |= fcNegNormal | fcNegInf; + break; } + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_ULT: { + // fcmp oge x, smallest_normal -> fcPosNormal | fcPosInf + // fcmp oge fabs(x), smallest_normal -> fcInf | fcNormal + // fcmp ult x, smallest_normal -> ~(fcPosNormal | fcPosInf) + // fcmp ult fabs(x), smallest_normal -> ~(fcInf | fcNormal) + Mask = fcPosInf | fcPosNormal; + if (IsFabs) + Mask |= fcNegInf | fcNegNormal; + break; + } + default: + return {nullptr, fcNone}; + } + } else if (ConstRHS->isNaN()) { + // fcmp o__ x, nan -> false + // fcmp u__ x, nan -> true + Mask = fcNone; + } else + return {nullptr, fcNone}; - if (const auto *II = dyn_cast<IntrinsicInst>(V)) { - switch (II->getIntrinsicID()) { + // Invert the comparison for the unordered cases. + if (FCmpInst::isUnordered(Pred)) + Mask = ~Mask; + + return {Src, Mask}; +} + +static FPClassTest computeKnownFPClassFromAssumes(const Value *V, + const SimplifyQuery &Q) { + FPClassTest KnownFromAssume = fcAllFlags; + + // Try to restrict the floating-point classes based on information from + // assumptions. + for (auto &AssumeVH : Q.AC->assumptionsFor(V)) { + if (!AssumeVH) + continue; + CallInst *I = cast<CallInst>(AssumeVH); + const Function *F = I->getFunction(); + + assert(F == Q.CxtI->getParent()->getParent() && + "Got assumption for the wrong function!"); + assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume && + "must be an assume intrinsic"); + + if (!isValidAssumeForContext(I, Q.CxtI, Q.DT)) + continue; + + CmpInst::Predicate Pred; + Value *LHS, *RHS; + uint64_t ClassVal = 0; + if (match(I->getArgOperand(0), m_FCmp(Pred, m_Value(LHS), m_Value(RHS)))) { + auto [TestedValue, TestedMask] = + fcmpToClassTest(Pred, *F, LHS, RHS, true); + // First see if we can fold in fabs/fneg into the test. 
+ if (TestedValue == V) + KnownFromAssume &= TestedMask; + else { + // Try again without the lookthrough if we found a different source + // value. + auto [TestedValue, TestedMask] = + fcmpToClassTest(Pred, *F, LHS, RHS, false); + if (TestedValue == V) + KnownFromAssume &= TestedMask; + } + } else if (match(I->getArgOperand(0), + m_Intrinsic<Intrinsic::is_fpclass>( + m_Value(LHS), m_ConstantInt(ClassVal)))) { + KnownFromAssume &= static_cast<FPClassTest>(ClassVal); + } + } + + return KnownFromAssume; +} + +void computeKnownFPClass(const Value *V, const APInt &DemandedElts, + FPClassTest InterestedClasses, KnownFPClass &Known, + unsigned Depth, const SimplifyQuery &Q); + +static void computeKnownFPClass(const Value *V, KnownFPClass &Known, + FPClassTest InterestedClasses, unsigned Depth, + const SimplifyQuery &Q) { + auto *FVTy = dyn_cast<FixedVectorType>(V->getType()); + APInt DemandedElts = + FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1); + computeKnownFPClass(V, DemandedElts, InterestedClasses, Known, Depth, Q); +} + +static void computeKnownFPClassForFPTrunc(const Operator *Op, + const APInt &DemandedElts, + FPClassTest InterestedClasses, + KnownFPClass &Known, unsigned Depth, + const SimplifyQuery &Q) { + if ((InterestedClasses & + (KnownFPClass::OrderedLessThanZeroMask | fcNan)) == fcNone) + return; + + KnownFPClass KnownSrc; + computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses, + KnownSrc, Depth + 1, Q); + + // Sign should be preserved + // TODO: Handle cannot be ordered greater than zero + if (KnownSrc.cannotBeOrderedLessThanZero()) + Known.knownNot(KnownFPClass::OrderedLessThanZeroMask); + + Known.propagateNaN(KnownSrc, true); + + // Infinity needs a range check. +} + +// TODO: Merge implementation of cannotBeOrderedLessThanZero into here. +void computeKnownFPClass(const Value *V, const APInt &DemandedElts, + FPClassTest InterestedClasses, KnownFPClass &Known, + unsigned Depth, const SimplifyQuery &Q) { + assert(Known.isUnknown() && "should not be called with known information"); + + if (!DemandedElts) { + // No demanded elts, better to assume we don't know anything. + Known.resetAll(); + return; + } + + assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); + + if (auto *CFP = dyn_cast_or_null<ConstantFP>(V)) { + Known.KnownFPClasses = CFP->getValueAPF().classify(); + Known.SignBit = CFP->isNegative(); + return; + } + + // Try to handle fixed width vector constants + auto *VFVTy = dyn_cast<FixedVectorType>(V->getType()); + const Constant *CV = dyn_cast<Constant>(V); + if (VFVTy && CV) { + Known.KnownFPClasses = fcNone; + + // For vectors, verify that each element is not NaN. 
+ unsigned NumElts = VFVTy->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = CV->getAggregateElement(i); + if (!Elt) { + Known = KnownFPClass(); + return; + } + if (isa<UndefValue>(Elt)) + continue; + auto *CElt = dyn_cast<ConstantFP>(Elt); + if (!CElt) { + Known = KnownFPClass(); + return; + } + + KnownFPClass KnownElt{CElt->getValueAPF().classify(), CElt->isNegative()}; + Known |= KnownElt; + } + + return; + } + + FPClassTest KnownNotFromFlags = fcNone; + if (const auto *CB = dyn_cast<CallBase>(V)) + KnownNotFromFlags |= CB->getRetNoFPClass(); + else if (const auto *Arg = dyn_cast<Argument>(V)) + KnownNotFromFlags |= Arg->getNoFPClass(); + + const Operator *Op = dyn_cast<Operator>(V); + if (const FPMathOperator *FPOp = dyn_cast_or_null<FPMathOperator>(Op)) { + if (FPOp->hasNoNaNs()) + KnownNotFromFlags |= fcNan; + if (FPOp->hasNoInfs()) + KnownNotFromFlags |= fcInf; + } + + if (Q.AC) { + FPClassTest AssumedClasses = computeKnownFPClassFromAssumes(V, Q); + KnownNotFromFlags |= ~AssumedClasses; + } + + // We no longer need to find out about these bits from inputs if we can + // assume this from flags/attributes. + InterestedClasses &= ~KnownNotFromFlags; + + auto ClearClassesFromFlags = make_scope_exit([=, &Known] { + Known.knownNot(KnownNotFromFlags); + }); + + if (!Op) + return; + + // All recursive calls that increase depth must come after this. + if (Depth == MaxAnalysisRecursionDepth) + return; + + const unsigned Opc = Op->getOpcode(); + switch (Opc) { + case Instruction::FNeg: { + computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses, + Known, Depth + 1, Q); + Known.fneg(); + break; + } + case Instruction::Select: { + Value *Cond = Op->getOperand(0); + Value *LHS = Op->getOperand(1); + Value *RHS = Op->getOperand(2); + + FPClassTest FilterLHS = fcAllFlags; + FPClassTest FilterRHS = fcAllFlags; + + Value *TestedValue = nullptr; + FPClassTest TestedMask = fcNone; + uint64_t ClassVal = 0; + const Function *F = cast<Instruction>(Op)->getFunction(); + CmpInst::Predicate Pred; + Value *CmpLHS, *CmpRHS; + if (F && match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) { + // If the select filters out a value based on the class, it no longer + // participates in the class of the result + + // TODO: In some degenerate cases we can infer something if we try again + // without looking through sign operations. + bool LookThroughFAbsFNeg = CmpLHS != LHS && CmpLHS != RHS; + std::tie(TestedValue, TestedMask) = + fcmpToClassTest(Pred, *F, CmpLHS, CmpRHS, LookThroughFAbsFNeg); + } else if (match(Cond, + m_Intrinsic<Intrinsic::is_fpclass>( + m_Value(TestedValue), m_ConstantInt(ClassVal)))) { + TestedMask = static_cast<FPClassTest>(ClassVal); + } + + if (TestedValue == LHS) { + // match !isnan(x) ? x : y + FilterLHS = TestedMask; + } else if (TestedValue == RHS) { + // match !isnan(x) ? 
y : x + FilterRHS = ~TestedMask; + } + + KnownFPClass Known2; + computeKnownFPClass(LHS, DemandedElts, InterestedClasses & FilterLHS, Known, + Depth + 1, Q); + Known.KnownFPClasses &= FilterLHS; + + computeKnownFPClass(RHS, DemandedElts, InterestedClasses & FilterRHS, + Known2, Depth + 1, Q); + Known2.KnownFPClasses &= FilterRHS; + + Known |= Known2; + break; + } + case Instruction::Call: { + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op)) { + const Intrinsic::ID IID = II->getIntrinsicID(); + switch (IID) { + case Intrinsic::fabs: { + if ((InterestedClasses & (fcNan | fcPositive)) != fcNone) { + // If we only care about the sign bit we don't need to inspect the + // operand. + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, Known, Depth + 1, Q); + } + + Known.fabs(); + break; + } + case Intrinsic::copysign: { + KnownFPClass KnownSign; + + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, Known, Depth + 1, Q); + computeKnownFPClass(II->getArgOperand(1), DemandedElts, + InterestedClasses, KnownSign, Depth + 1, Q); + Known.copysign(KnownSign); + break; + } + case Intrinsic::fma: + case Intrinsic::fmuladd: { + if ((InterestedClasses & fcNegative) == fcNone) + break; + + if (II->getArgOperand(0) != II->getArgOperand(1)) + break; + + // The multiply cannot be -0 and therefore the add can't be -0 + Known.knownNot(fcNegZero); + + // x * x + y is non-negative if y is non-negative. + KnownFPClass KnownAddend; + computeKnownFPClass(II->getArgOperand(2), DemandedElts, + InterestedClasses, KnownAddend, Depth + 1, Q); + + // TODO: Known sign bit with no nans + if (KnownAddend.cannotBeOrderedLessThanZero()) + Known.knownNot(fcNegative); + break; + } + case Intrinsic::sqrt: + case Intrinsic::experimental_constrained_sqrt: { + KnownFPClass KnownSrc; + FPClassTest InterestedSrcs = InterestedClasses; + if (InterestedClasses & fcNan) + InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask; + + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedSrcs, KnownSrc, Depth + 1, Q); + + if (KnownSrc.isKnownNeverPosInfinity()) + Known.knownNot(fcPosInf); + if (KnownSrc.isKnownNever(fcSNan)) + Known.knownNot(fcSNan); + + // Any negative value besides -0 returns a nan. + if (KnownSrc.isKnownNeverNaN() && + KnownSrc.cannotBeOrderedLessThanZero()) + Known.knownNot(fcNan); + + // The only negative value that can be returned is -0 for -0 inputs. + Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal); + + // If the input denormal mode could be PreserveSign, a negative + // subnormal input could produce a negative zero output. + const Function *F = II->getFunction(); + if (Q.IIQ.hasNoSignedZeros(II) || + (F && KnownSrc.isKnownNeverLogicalNegZero(*F, II->getType()))) { + Known.knownNot(fcNegZero); + if (KnownSrc.isKnownNeverNaN()) + Known.SignBit = false; + } + + break; + } case Intrinsic::sin: - case Intrinsic::cos: + case Intrinsic::cos: { // Return NaN on infinite inputs. 
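      // For example, sin(+inf) and cos(+inf) are NaN, while any finite input
      // gives a result in [-1, 1]; so the result is never an infinity, and it
      // can only be NaN when the source may be NaN or infinite.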
- return true; - case Intrinsic::fabs: - case Intrinsic::sqrt: - case Intrinsic::canonicalize: - case Intrinsic::copysign: - case Intrinsic::arithmetic_fence: + KnownFPClass KnownSrc; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, KnownSrc, Depth + 1, Q); + Known.knownNot(fcInf); + if (KnownSrc.isKnownNeverNaN() && KnownSrc.isKnownNeverInfinity()) + Known.knownNot(fcNan); + break; + } + + case Intrinsic::maxnum: + case Intrinsic::minnum: + case Intrinsic::minimum: + case Intrinsic::maximum: { + KnownFPClass KnownLHS, KnownRHS; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, KnownLHS, Depth + 1, Q); + computeKnownFPClass(II->getArgOperand(1), DemandedElts, + InterestedClasses, KnownRHS, Depth + 1, Q); + + bool NeverNaN = + KnownLHS.isKnownNeverNaN() || KnownRHS.isKnownNeverNaN(); + Known = KnownLHS | KnownRHS; + + // If either operand is not NaN, the result is not NaN. + if (NeverNaN && (IID == Intrinsic::minnum || IID == Intrinsic::maxnum)) + Known.knownNot(fcNan); + + if (IID == Intrinsic::maxnum) { + // If at least one operand is known to be positive, the result must be + // positive. + if ((KnownLHS.cannotBeOrderedLessThanZero() && + KnownLHS.isKnownNeverNaN()) || + (KnownRHS.cannotBeOrderedLessThanZero() && + KnownRHS.isKnownNeverNaN())) + Known.knownNot(KnownFPClass::OrderedLessThanZeroMask); + } else if (IID == Intrinsic::maximum) { + // If at least one operand is known to be positive, the result must be + // positive. + if (KnownLHS.cannotBeOrderedLessThanZero() || + KnownRHS.cannotBeOrderedLessThanZero()) + Known.knownNot(KnownFPClass::OrderedLessThanZeroMask); + } else if (IID == Intrinsic::minnum) { + // If at least one operand is known to be negative, the result must be + // negative. + if ((KnownLHS.cannotBeOrderedGreaterThanZero() && + KnownLHS.isKnownNeverNaN()) || + (KnownRHS.cannotBeOrderedGreaterThanZero() && + KnownRHS.isKnownNeverNaN())) + Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask); + } else { + // If at least one operand is known to be negative, the result must be + // negative. + if (KnownLHS.cannotBeOrderedGreaterThanZero() || + KnownRHS.cannotBeOrderedGreaterThanZero()) + Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask); + } + + // Fixup zero handling if denormals could be returned as a zero. + // + // As there's no spec for denormal flushing, be conservative with the + // treatment of denormals that could be flushed to zero. For older + // subtargets on AMDGPU the min/max instructions would not flush the + // output and return the original value. + // + // TODO: This could be refined based on the sign + if ((Known.KnownFPClasses & fcZero) != fcNone && + !Known.isKnownNeverSubnormal()) { + const Function *Parent = II->getFunction(); + if (!Parent) + break; + + DenormalMode Mode = Parent->getDenormalMode( + II->getType()->getScalarType()->getFltSemantics()); + if (Mode != DenormalMode::getIEEE()) + Known.KnownFPClasses |= fcZero; + } + + break; + } + case Intrinsic::canonicalize: { + KnownFPClass KnownSrc; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, KnownSrc, Depth + 1, Q); + + // This is essentially a stronger form of + // propagateCanonicalizingSrc. Other "canonicalizing" operations don't + // actually have an IR canonicalization guarantee. + + // Canonicalize may flush denormals to zero, so we have to consider the + // denormal mode to preserve known-not-0 knowledge. 
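      // For example, with a preserve-sign denormal mode on f32, canonicalizing
      // the smallest subnormal 0x1p-149 may return a zero of the same sign, so
      // a source known not to be zero can still produce a zero result unless
      // it is also known not to be subnormal.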
+ Known.KnownFPClasses = KnownSrc.KnownFPClasses | fcZero | fcQNan; + + // Stronger version of propagateNaN + // Canonicalize is guaranteed to quiet signaling nans. + if (KnownSrc.isKnownNeverNaN()) + Known.knownNot(fcNan); + else + Known.knownNot(fcSNan); + + const Function *F = II->getFunction(); + if (!F) + break; + + // If the parent function flushes denormals, the canonical output cannot + // be a denormal. + const fltSemantics &FPType = + II->getType()->getScalarType()->getFltSemantics(); + DenormalMode DenormMode = F->getDenormalMode(FPType); + if (DenormMode == DenormalMode::getIEEE()) { + if (KnownSrc.isKnownNever(fcPosZero)) + Known.knownNot(fcPosZero); + if (KnownSrc.isKnownNever(fcNegZero)) + Known.knownNot(fcNegZero); + break; + } + + if (DenormMode.inputsAreZero() || DenormMode.outputsAreZero()) + Known.knownNot(fcSubnormal); + + if (DenormMode.Input == DenormalMode::PositiveZero || + (DenormMode.Output == DenormalMode::PositiveZero && + DenormMode.Input == DenormalMode::IEEE)) + Known.knownNot(fcNegZero); + + break; + } case Intrinsic::trunc: - return isKnownNeverInfinity(Inst->getOperand(0), TLI, Depth + 1); case Intrinsic::floor: case Intrinsic::ceil: case Intrinsic::rint: case Intrinsic::nearbyint: case Intrinsic::round: - case Intrinsic::roundeven: - // PPC_FP128 is a special case. - if (V->getType()->isMultiUnitFPType()) - return false; - return isKnownNeverInfinity(Inst->getOperand(0), TLI, Depth + 1); - case Intrinsic::fptrunc_round: - // Requires knowing the value range. - return false; - case Intrinsic::minnum: - case Intrinsic::maxnum: - case Intrinsic::minimum: - case Intrinsic::maximum: - return isKnownNeverInfinity(Inst->getOperand(0), TLI, Depth + 1) && - isKnownNeverInfinity(Inst->getOperand(1), TLI, Depth + 1); + case Intrinsic::roundeven: { + KnownFPClass KnownSrc; + FPClassTest InterestedSrcs = InterestedClasses; + if (InterestedSrcs & fcPosFinite) + InterestedSrcs |= fcPosFinite; + if (InterestedSrcs & fcNegFinite) + InterestedSrcs |= fcNegFinite; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedSrcs, KnownSrc, Depth + 1, Q); + + // Integer results cannot be subnormal. + Known.knownNot(fcSubnormal); + + Known.propagateNaN(KnownSrc, true); + + // Pass through infinities, except PPC_FP128 is a special case for + // intrinsics other than trunc. 
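      // For example, trunc(+inf) == +inf and nearbyint(-inf) == -inf, so for
      // ordinary IEEE types infinities simply pass through; and since the
      // smallest nonzero integral magnitude is 1.0, the rounded result is
      // never subnormal.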
+ if (IID == Intrinsic::trunc || !V->getType()->isMultiUnitFPType()) { + if (KnownSrc.isKnownNeverPosInfinity()) + Known.knownNot(fcPosInf); + if (KnownSrc.isKnownNeverNegInfinity()) + Known.knownNot(fcNegInf); + } + + // Negative round ups to 0 produce -0 + if (KnownSrc.isKnownNever(fcPosFinite)) + Known.knownNot(fcPosFinite); + if (KnownSrc.isKnownNever(fcNegFinite)) + Known.knownNot(fcNegFinite); + + break; + } + case Intrinsic::exp: + case Intrinsic::exp2: { + Known.knownNot(fcNegative); + if ((InterestedClasses & fcNan) == fcNone) + break; + + KnownFPClass KnownSrc; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, KnownSrc, Depth + 1, Q); + if (KnownSrc.isKnownNeverNaN()) { + Known.knownNot(fcNan); + Known.SignBit = false; + } + + break; + } + case Intrinsic::fptrunc_round: { + computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, + Known, Depth, Q); + break; + } case Intrinsic::log: case Intrinsic::log10: case Intrinsic::log2: + case Intrinsic::experimental_constrained_log: + case Intrinsic::experimental_constrained_log10: + case Intrinsic::experimental_constrained_log2: { // log(+inf) -> +inf // log([+-]0.0) -> -inf // log(-inf) -> nan // log(-x) -> nan - // TODO: We lack API to check the == 0 case. - return false; - case Intrinsic::exp: - case Intrinsic::exp2: - case Intrinsic::pow: - case Intrinsic::powi: - case Intrinsic::fma: - case Intrinsic::fmuladd: - // These can return infinities on overflow cases, so it's hard to prove - // anything about it. - return false; + if ((InterestedClasses & (fcNan | fcInf)) == fcNone) + break; + + FPClassTest InterestedSrcs = InterestedClasses; + if ((InterestedClasses & fcNegInf) != fcNone) + InterestedSrcs |= fcZero | fcSubnormal; + if ((InterestedClasses & fcNan) != fcNone) + InterestedSrcs |= fcNan | (fcNegative & ~fcNan); + + KnownFPClass KnownSrc; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs, + KnownSrc, Depth + 1, Q); + + if (KnownSrc.isKnownNeverPosInfinity()) + Known.knownNot(fcPosInf); + + if (KnownSrc.isKnownNeverNaN() && + KnownSrc.cannotBeOrderedLessThanZero()) + Known.knownNot(fcNan); + + const Function *F = II->getFunction(); + if (F && KnownSrc.isKnownNeverLogicalZero(*F, II->getType())) + Known.knownNot(fcNegInf); + + break; + } + case Intrinsic::powi: { + if ((InterestedClasses & fcNegative) == fcNone) + break; + + const Value *Exp = II->getArgOperand(1); + Type *ExpTy = Exp->getType(); + unsigned BitWidth = ExpTy->getScalarType()->getIntegerBitWidth(); + KnownBits ExponentKnownBits(BitWidth); + computeKnownBits(Exp, + isa<VectorType>(ExpTy) ? DemandedElts : APInt(1, 1), + ExponentKnownBits, Depth + 1, Q); + + if (ExponentKnownBits.Zero[0]) { // Is even + Known.knownNot(fcNegative); + break; + } + + // Given that exp is an integer, here are the + // ways that pow can return a negative value: + // + // pow(-x, exp) --> negative if exp is odd and x is negative. + // pow(-0, exp) --> -inf if exp is negative odd. + // pow(-0, exp) --> -0 if exp is positive odd. + // pow(-inf, exp) --> -0 if exp is negative odd. + // pow(-inf, exp) --> -inf if exp is positive odd. 
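      // Concretely: pow(-2.0, 3) == -8.0, pow(-0.0, -3) == -inf,
      // pow(-0.0, 3) == -0.0, pow(-inf, -3) == -0.0 and pow(-inf, 3) == -inf,
      // while an even exponent makes each of these results non-negative.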
+ KnownFPClass KnownSrc; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, fcNegative, + KnownSrc, Depth + 1, Q); + if (KnownSrc.isKnownNever(fcNegative)) + Known.knownNot(fcNegative); + break; + } + case Intrinsic::ldexp: { + KnownFPClass KnownSrc; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, KnownSrc, Depth + 1, Q); + Known.propagateNaN(KnownSrc, /*PropagateSign=*/true); + + // Sign is preserved, but underflows may produce zeroes. + if (KnownSrc.isKnownNever(fcNegative)) + Known.knownNot(fcNegative); + else if (KnownSrc.cannotBeOrderedLessThanZero()) + Known.knownNot(KnownFPClass::OrderedLessThanZeroMask); + + if (KnownSrc.isKnownNever(fcPositive)) + Known.knownNot(fcPositive); + else if (KnownSrc.cannotBeOrderedGreaterThanZero()) + Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask); + + // Can refine inf/zero handling based on the exponent operand. + const FPClassTest ExpInfoMask = fcZero | fcSubnormal | fcInf; + if ((InterestedClasses & ExpInfoMask) == fcNone) + break; + if ((KnownSrc.KnownFPClasses & ExpInfoMask) == fcNone) + break; + + const fltSemantics &Flt + = II->getType()->getScalarType()->getFltSemantics(); + unsigned Precision = APFloat::semanticsPrecision(Flt); + const Value *ExpArg = II->getArgOperand(1); + ConstantRange ExpRange = computeConstantRange( + ExpArg, true, Q.IIQ.UseInstrInfo, Q.AC, Q.CxtI, Q.DT, Depth + 1); + + const int MantissaBits = Precision - 1; + if (ExpRange.getSignedMin().sge(static_cast<int64_t>(MantissaBits))) + Known.knownNot(fcSubnormal); + + const Function *F = II->getFunction(); + const APInt *ConstVal = ExpRange.getSingleElement(); + if (ConstVal && ConstVal->isZero()) { + // ldexp(x, 0) -> x, so propagate everything. + Known.propagateCanonicalizingSrc(KnownSrc, *F, + II->getType()); + } else if (ExpRange.isAllNegative()) { + // If we know the power is <= 0, can't introduce inf + if (KnownSrc.isKnownNeverPosInfinity()) + Known.knownNot(fcPosInf); + if (KnownSrc.isKnownNeverNegInfinity()) + Known.knownNot(fcNegInf); + } else if (ExpRange.isAllNonNegative()) { + // If we know the power is >= 0, can't introduce subnormal or zero + if (KnownSrc.isKnownNeverPosSubnormal()) + Known.knownNot(fcPosSubnormal); + if (KnownSrc.isKnownNeverNegSubnormal()) + Known.knownNot(fcNegSubnormal); + if (F && KnownSrc.isKnownNeverLogicalPosZero(*F, II->getType())) + Known.knownNot(fcPosZero); + if (F && KnownSrc.isKnownNeverLogicalNegZero(*F, II->getType())) + Known.knownNot(fcNegZero); + } + + break; + } + case Intrinsic::arithmetic_fence: { + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, Known, Depth + 1, Q); + break; + } + case Intrinsic::experimental_constrained_sitofp: + case Intrinsic::experimental_constrained_uitofp: + // Cannot produce nan + Known.knownNot(fcNan); + + // sitofp and uitofp turn into +0.0 for zero. 
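      // For example, converting the integer 0 yields +0.0 (never -0.0), and
      // the smallest nonzero integer converts to +/-1.0, well above the
      // subnormal range.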
+ Known.knownNot(fcNegZero); + + // Integers cannot be subnormal + Known.knownNot(fcSubnormal); + + if (IID == Intrinsic::experimental_constrained_uitofp) + Known.signBitMustBeZero(); + + // TODO: Copy inf handling from instructions + break; default: break; } } + + break; } + case Instruction::FAdd: + case Instruction::FSub: { + KnownFPClass KnownLHS, KnownRHS; + bool WantNegative = + Op->getOpcode() == Instruction::FAdd && + (InterestedClasses & KnownFPClass::OrderedLessThanZeroMask) != fcNone; + bool WantNaN = (InterestedClasses & fcNan) != fcNone; + bool WantNegZero = (InterestedClasses & fcNegZero) != fcNone; + + if (!WantNaN && !WantNegative && !WantNegZero) + break; - // try to handle fixed width vector constants - auto *VFVTy = dyn_cast<FixedVectorType>(V->getType()); - if (VFVTy && isa<Constant>(V)) { - // For vectors, verify that each element is not infinity. - unsigned NumElts = VFVTy->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = cast<Constant>(V)->getAggregateElement(i); - if (!Elt) - return false; - if (isa<UndefValue>(Elt)) - continue; - auto *CElt = dyn_cast<ConstantFP>(Elt); - if (!CElt || CElt->isInfinity()) - return false; + FPClassTest InterestedSrcs = InterestedClasses; + if (WantNegative) + InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask; + if (InterestedClasses & fcNan) + InterestedSrcs |= fcInf; + computeKnownFPClass(Op->getOperand(1), DemandedElts, InterestedSrcs, + KnownRHS, Depth + 1, Q); + + if ((WantNaN && KnownRHS.isKnownNeverNaN()) || + (WantNegative && KnownRHS.cannotBeOrderedLessThanZero()) || + WantNegZero || Opc == Instruction::FSub) { + + // RHS is canonically cheaper to compute. Skip inspecting the LHS if + // there's no point. + computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedSrcs, + KnownLHS, Depth + 1, Q); + // Adding positive and negative infinity produces NaN. + // TODO: Check sign of infinities. + if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() && + (KnownLHS.isKnownNeverInfinity() || KnownRHS.isKnownNeverInfinity())) + Known.knownNot(fcNan); + + // FIXME: Context function should always be passed in separately + const Function *F = cast<Instruction>(Op)->getFunction(); + + if (Op->getOpcode() == Instruction::FAdd) { + if (KnownLHS.cannotBeOrderedLessThanZero() && + KnownRHS.cannotBeOrderedLessThanZero()) + Known.knownNot(KnownFPClass::OrderedLessThanZeroMask); + if (!F) + break; + + // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0. + if ((KnownLHS.isKnownNeverLogicalNegZero(*F, Op->getType()) || + KnownRHS.isKnownNeverLogicalNegZero(*F, Op->getType())) && + // Make sure output negative denormal can't flush to -0 + outputDenormalIsIEEEOrPosZero(*F, Op->getType())) + Known.knownNot(fcNegZero); + } else { + if (!F) + break; + + // Only fsub -0, +0 can return -0 + if ((KnownLHS.isKnownNeverLogicalNegZero(*F, Op->getType()) || + KnownRHS.isKnownNeverLogicalPosZero(*F, Op->getType())) && + // Make sure output negative denormal can't flush to -0 + outputDenormalIsIEEEOrPosZero(*F, Op->getType())) + Known.knownNot(fcNegZero); + } } - // All elements were confirmed non-infinity or undefined. - return true; + + break; } + case Instruction::FMul: { + // X * X is always non-negative or a NaN. 
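    // For example, (-3.0) * (-3.0) == 9.0 and (-0.0) * (-0.0) == +0.0, so the
    // square is never negative; the result can only be a NaN when x itself
    // may be a NaN.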
+ if (Op->getOperand(0) == Op->getOperand(1)) + Known.knownNot(fcNegative); - // was not able to prove that V never contains infinity - return false; -} + if ((InterestedClasses & fcNan) != fcNan) + break; -bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, - unsigned Depth) { - assert(V->getType()->isFPOrFPVectorTy() && "Querying for NaN on non-FP type"); + // fcSubnormal is only needed in case of DAZ. + const FPClassTest NeedForNan = fcNan | fcInf | fcZero | fcSubnormal; - // If we're told that NaNs won't happen, assume they won't. - if (auto *FPMathOp = dyn_cast<FPMathOperator>(V)) - if (FPMathOp->hasNoNaNs()) - return true; + KnownFPClass KnownLHS, KnownRHS; + computeKnownFPClass(Op->getOperand(1), DemandedElts, NeedForNan, KnownRHS, + Depth + 1, Q); + if (!KnownRHS.isKnownNeverNaN()) + break; - // Handle scalar constants. - if (auto *CFP = dyn_cast<ConstantFP>(V)) - return !CFP->isNaN(); + computeKnownFPClass(Op->getOperand(0), DemandedElts, NeedForNan, KnownLHS, + Depth + 1, Q); + if (!KnownLHS.isKnownNeverNaN()) + break; - if (Depth == MaxAnalysisRecursionDepth) - return false; + // If 0 * +/-inf produces NaN. + if (KnownLHS.isKnownNeverInfinity() && KnownRHS.isKnownNeverInfinity()) { + Known.knownNot(fcNan); + break; + } - if (auto *Inst = dyn_cast<Instruction>(V)) { - switch (Inst->getOpcode()) { - case Instruction::FAdd: - case Instruction::FSub: - // Adding positive and negative infinity produces NaN. - return isKnownNeverNaN(Inst->getOperand(0), TLI, Depth + 1) && - isKnownNeverNaN(Inst->getOperand(1), TLI, Depth + 1) && - (isKnownNeverInfinity(Inst->getOperand(0), TLI, Depth + 1) || - isKnownNeverInfinity(Inst->getOperand(1), TLI, Depth + 1)); - - case Instruction::FMul: - // Zero multiplied with infinity produces NaN. - // FIXME: If neither side can be zero fmul never produces NaN. - return isKnownNeverNaN(Inst->getOperand(0), TLI, Depth + 1) && - isKnownNeverInfinity(Inst->getOperand(0), TLI, Depth + 1) && - isKnownNeverNaN(Inst->getOperand(1), TLI, Depth + 1) && - isKnownNeverInfinity(Inst->getOperand(1), TLI, Depth + 1); - - case Instruction::FDiv: - case Instruction::FRem: - // FIXME: Only 0/0, Inf/Inf, Inf REM x and x REM 0 produce NaN. - return false; + const Function *F = cast<Instruction>(Op)->getFunction(); + if (!F) + break; - case Instruction::Select: { - return isKnownNeverNaN(Inst->getOperand(1), TLI, Depth + 1) && - isKnownNeverNaN(Inst->getOperand(2), TLI, Depth + 1); + if ((KnownRHS.isKnownNeverInfinity() || + KnownLHS.isKnownNeverLogicalZero(*F, Op->getType())) && + (KnownLHS.isKnownNeverInfinity() || + KnownRHS.isKnownNeverLogicalZero(*F, Op->getType()))) + Known.knownNot(fcNan); + + break; + } + case Instruction::FDiv: + case Instruction::FRem: { + if (Op->getOperand(0) == Op->getOperand(1)) { + // TODO: Could filter out snan if we inspect the operand + if (Op->getOpcode() == Instruction::FDiv) { + // X / X is always exactly 1.0 or a NaN. + Known.KnownFPClasses = fcNan | fcPosNormal; + } else { + // X % X is always exactly [+-]0.0 or a NaN. 
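        // For example, 3.0 / 3.0 == 1.0 and fmod(3.0, 3.0) == +0.0, while
        // 0.0 / 0.0, inf / inf, fmod(inf, inf) and fmod(0.0, 0.0) are all NaN.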
+ Known.KnownFPClasses = fcNan | fcZero; + } + + break; } - case Instruction::SIToFP: - case Instruction::UIToFP: - return true; - case Instruction::FPTrunc: - case Instruction::FPExt: - case Instruction::FNeg: - return isKnownNeverNaN(Inst->getOperand(0), TLI, Depth + 1); - default: + + const bool WantNan = (InterestedClasses & fcNan) != fcNone; + const bool WantNegative = (InterestedClasses & fcNegative) != fcNone; + const bool WantPositive = + Opc == Instruction::FRem && (InterestedClasses & fcPositive) != fcNone; + if (!WantNan && !WantNegative && !WantPositive) break; + + KnownFPClass KnownLHS, KnownRHS; + + computeKnownFPClass(Op->getOperand(1), DemandedElts, + fcNan | fcInf | fcZero | fcNegative, KnownRHS, + Depth + 1, Q); + + bool KnowSomethingUseful = + KnownRHS.isKnownNeverNaN() || KnownRHS.isKnownNever(fcNegative); + + if (KnowSomethingUseful || WantPositive) { + const FPClassTest InterestedLHS = + WantPositive ? fcAllFlags + : fcNan | fcInf | fcZero | fcSubnormal | fcNegative; + + computeKnownFPClass(Op->getOperand(0), DemandedElts, + InterestedClasses & InterestedLHS, KnownLHS, + Depth + 1, Q); } + + const Function *F = cast<Instruction>(Op)->getFunction(); + + if (Op->getOpcode() == Instruction::FDiv) { + // Only 0/0, Inf/Inf produce NaN. + if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() && + (KnownLHS.isKnownNeverInfinity() || + KnownRHS.isKnownNeverInfinity()) && + ((F && KnownLHS.isKnownNeverLogicalZero(*F, Op->getType())) || + (F && KnownRHS.isKnownNeverLogicalZero(*F, Op->getType())))) { + Known.knownNot(fcNan); + } + + // X / -0.0 is -Inf (or NaN). + // +X / +X is +X + if (KnownLHS.isKnownNever(fcNegative) && KnownRHS.isKnownNever(fcNegative)) + Known.knownNot(fcNegative); + } else { + // Inf REM x and x REM 0 produce NaN. + if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() && + KnownLHS.isKnownNeverInfinity() && F && + KnownRHS.isKnownNeverLogicalZero(*F, Op->getType())) { + Known.knownNot(fcNan); + } + + // The sign for frem is the same as the first operand. + if (KnownLHS.cannotBeOrderedLessThanZero()) + Known.knownNot(KnownFPClass::OrderedLessThanZeroMask); + if (KnownLHS.cannotBeOrderedGreaterThanZero()) + Known.knownNot(KnownFPClass::OrderedGreaterThanZeroMask); + + // See if we can be more aggressive about the sign of 0. + if (KnownLHS.isKnownNever(fcNegative)) + Known.knownNot(fcNegative); + if (KnownLHS.isKnownNever(fcPositive)) + Known.knownNot(fcPositive); + } + + break; } + case Instruction::FPExt: { + // Infinity, nan and zero propagate from source. + computeKnownFPClass(Op->getOperand(0), DemandedElts, InterestedClasses, + Known, Depth + 1, Q); - if (const auto *II = dyn_cast<IntrinsicInst>(V)) { - switch (II->getIntrinsicID()) { - case Intrinsic::canonicalize: - case Intrinsic::fabs: - case Intrinsic::copysign: - case Intrinsic::exp: - case Intrinsic::exp2: - case Intrinsic::floor: - case Intrinsic::ceil: - case Intrinsic::trunc: - case Intrinsic::rint: - case Intrinsic::nearbyint: - case Intrinsic::round: - case Intrinsic::roundeven: - case Intrinsic::arithmetic_fence: - return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1); - case Intrinsic::sqrt: - return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) && - CannotBeOrderedLessThanZero(II->getArgOperand(0), TLI); - case Intrinsic::minnum: - case Intrinsic::maxnum: - // If either operand is not NaN, the result is not NaN. 
- return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) || - isKnownNeverNaN(II->getArgOperand(1), TLI, Depth + 1); - default: - return false; + const fltSemantics &DstTy = + Op->getType()->getScalarType()->getFltSemantics(); + const fltSemantics &SrcTy = + Op->getOperand(0)->getType()->getScalarType()->getFltSemantics(); + + // All subnormal inputs should be in the normal range in the result type. + if (APFloat::isRepresentableAsNormalIn(SrcTy, DstTy)) + Known.knownNot(fcSubnormal); + + // Sign bit of a nan isn't guaranteed. + if (!Known.isKnownNeverNaN()) + Known.SignBit = std::nullopt; + break; + } + case Instruction::FPTrunc: { + computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known, + Depth, Q); + break; + } + case Instruction::SIToFP: + case Instruction::UIToFP: { + // Cannot produce nan + Known.knownNot(fcNan); + + // Integers cannot be subnormal + Known.knownNot(fcSubnormal); + + // sitofp and uitofp turn into +0.0 for zero. + Known.knownNot(fcNegZero); + if (Op->getOpcode() == Instruction::UIToFP) + Known.signBitMustBeZero(); + + if (InterestedClasses & fcInf) { + // Get width of largest magnitude integer (remove a bit if signed). + // This still works for a signed minimum value because the largest FP + // value is scaled by some fraction close to 2.0 (1.0 + 0.xxxx). + int IntSize = Op->getOperand(0)->getType()->getScalarSizeInBits(); + if (Op->getOpcode() == Instruction::SIToFP) + --IntSize; + + // If the exponent of the largest finite FP value can hold the largest + // integer, the result of the cast must be finite. + Type *FPTy = Op->getType()->getScalarType(); + if (ilogb(APFloat::getLargest(FPTy->getFltSemantics())) >= IntSize) + Known.knownNot(fcInf); } + + break; } + case Instruction::ExtractElement: { + // Look through extract element. If the index is non-constant or + // out-of-range demand all elements, otherwise just the extracted element. + const Value *Vec = Op->getOperand(0); + const Value *Idx = Op->getOperand(1); + auto *CIdx = dyn_cast<ConstantInt>(Idx); - // Try to handle fixed width vector constants - auto *VFVTy = dyn_cast<FixedVectorType>(V->getType()); - if (VFVTy && isa<Constant>(V)) { - // For vectors, verify that each element is not NaN. - unsigned NumElts = VFVTy->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = cast<Constant>(V)->getAggregateElement(i); - if (!Elt) - return false; - if (isa<UndefValue>(Elt)) - continue; - auto *CElt = dyn_cast<ConstantFP>(Elt); - if (!CElt || CElt->isNaN()) - return false; + if (auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType())) { + unsigned NumElts = VecTy->getNumElements(); + APInt DemandedVecElts = APInt::getAllOnes(NumElts); + if (CIdx && CIdx->getValue().ult(NumElts)) + DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue()); + return computeKnownFPClass(Vec, DemandedVecElts, InterestedClasses, Known, + Depth + 1, Q); } - // All elements were confirmed not-NaN or undefined. - return true; + + break; } + case Instruction::InsertElement: { + if (isa<ScalableVectorType>(Op->getType())) + return; - // Was not able to prove that V never contains NaN - return false; + const Value *Vec = Op->getOperand(0); + const Value *Elt = Op->getOperand(1); + auto *CIdx = dyn_cast<ConstantInt>(Op->getOperand(2)); + // Early out if the index is non-constant or out-of-range. + unsigned NumElts = DemandedElts.getBitWidth(); + if (!CIdx || CIdx->getValue().uge(NumElts)) + return; + + unsigned EltIdx = CIdx->getZExtValue(); + // Do we demand the inserted element? 
+ if (DemandedElts[EltIdx]) { + computeKnownFPClass(Elt, Known, InterestedClasses, Depth + 1, Q); + // If we don't know any bits, early out. + if (Known.isUnknown()) + break; + } else { + Known.KnownFPClasses = fcNone; + } + + // We don't need the base vector element that has been inserted. + APInt DemandedVecElts = DemandedElts; + DemandedVecElts.clearBit(EltIdx); + if (!!DemandedVecElts) { + KnownFPClass Known2; + computeKnownFPClass(Vec, DemandedVecElts, InterestedClasses, Known2, + Depth + 1, Q); + Known |= Known2; + } + + break; + } + case Instruction::ShuffleVector: { + // For undef elements, we don't know anything about the common state of + // the shuffle result. + APInt DemandedLHS, DemandedRHS; + auto *Shuf = dyn_cast<ShuffleVectorInst>(Op); + if (!Shuf || !getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) + return; + + if (!!DemandedLHS) { + const Value *LHS = Shuf->getOperand(0); + computeKnownFPClass(LHS, DemandedLHS, InterestedClasses, Known, + Depth + 1, Q); + + // If we don't know any bits, early out. + if (Known.isUnknown()) + break; + } else { + Known.KnownFPClasses = fcNone; + } + + if (!!DemandedRHS) { + KnownFPClass Known2; + const Value *RHS = Shuf->getOperand(1); + computeKnownFPClass(RHS, DemandedRHS, InterestedClasses, Known2, + Depth + 1, Q); + Known |= Known2; + } + + break; + } + case Instruction::ExtractValue: { + const ExtractValueInst *Extract = cast<ExtractValueInst>(Op); + ArrayRef<unsigned> Indices = Extract->getIndices(); + const Value *Src = Extract->getAggregateOperand(); + if (isa<StructType>(Src->getType()) && Indices.size() == 1 && + Indices[0] == 0) { + if (const auto *II = dyn_cast<IntrinsicInst>(Src)) { + switch (II->getIntrinsicID()) { + case Intrinsic::frexp: { + Known.knownNot(fcSubnormal); + + KnownFPClass KnownSrc; + computeKnownFPClass(II->getArgOperand(0), DemandedElts, + InterestedClasses, KnownSrc, Depth + 1, Q); + + const Function *F = cast<Instruction>(Op)->getFunction(); + + if (KnownSrc.isKnownNever(fcNegative)) + Known.knownNot(fcNegative); + else { + if (F && KnownSrc.isKnownNeverLogicalNegZero(*F, Op->getType())) + Known.knownNot(fcNegZero); + if (KnownSrc.isKnownNever(fcNegInf)) + Known.knownNot(fcNegInf); + } + + if (KnownSrc.isKnownNever(fcPositive)) + Known.knownNot(fcPositive); + else { + if (F && KnownSrc.isKnownNeverLogicalPosZero(*F, Op->getType())) + Known.knownNot(fcPosZero); + if (KnownSrc.isKnownNever(fcPosInf)) + Known.knownNot(fcPosInf); + } + + Known.propagateNaN(KnownSrc); + return; + } + default: + break; + } + } + } + + computeKnownFPClass(Src, DemandedElts, InterestedClasses, Known, Depth + 1, + Q); + break; + } + case Instruction::PHI: { + const PHINode *P = cast<PHINode>(Op); + // Unreachable blocks may have zero-operand PHI nodes. + if (P->getNumIncomingValues() == 0) + break; + + // Otherwise take the unions of the known bit sets of the operands, + // taking conservative care to avoid excessive recursion. + const unsigned PhiRecursionLimit = MaxAnalysisRecursionDepth - 2; + + if (Depth < PhiRecursionLimit) { + // Skip if every incoming value references to ourself. + if (isa_and_nonnull<UndefValue>(P->hasConstantValue())) + break; + + bool First = true; + + for (Value *IncValue : P->incoming_values()) { + // Skip direct self references. + if (IncValue == P) + continue; + + KnownFPClass KnownSrc; + // Recurse, but cap the recursion to two levels, because we don't want + // to waste time spinning around in loops. We need at least depth 2 to + // detect known sign bits. 
+ computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc, + PhiRecursionLimit, Q); + + if (First) { + Known = KnownSrc; + First = false; + } else { + Known |= KnownSrc; + } + + if (Known.KnownFPClasses == fcAllFlags) + break; + } + } + + break; + } + default: + break; + } +} + +KnownFPClass llvm::computeKnownFPClass( + const Value *V, const APInt &DemandedElts, const DataLayout &DL, + FPClassTest InterestedClasses, unsigned Depth, const TargetLibraryInfo *TLI, + AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + KnownFPClass KnownClasses; + ::computeKnownFPClass( + V, DemandedElts, InterestedClasses, KnownClasses, Depth, + SimplifyQuery(DL, TLI, DT, AC, safeCxtI(V, CxtI), UseInstrInfo)); + return KnownClasses; +} + +KnownFPClass llvm::computeKnownFPClass( + const Value *V, const DataLayout &DL, FPClassTest InterestedClasses, + unsigned Depth, const TargetLibraryInfo *TLI, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { + KnownFPClass Known; + ::computeKnownFPClass( + V, Known, InterestedClasses, Depth, + SimplifyQuery(DL, TLI, DT, AC, safeCxtI(V, CxtI), UseInstrInfo)); + return Known; } Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) { @@ -4530,6 +5773,16 @@ bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing( case Intrinsic::strip_invariant_group: case Intrinsic::aarch64_irg: case Intrinsic::aarch64_tagp: + // The amdgcn_make_buffer_rsrc function does not alter the address of the + // input pointer (and thus preserve null-ness for the purposes of escape + // analysis, which is where the MustPreserveNullness flag comes in to play). + // However, it will not necessarily map ptr addrspace(N) null to ptr + // addrspace(8) null, aka the "null descriptor", which has "all loads return + // 0, all stores are dropped" semantics. Given the context of this intrinsic + // list, no one should be relying on such a strict interpretation of + // MustPreserveNullness (and, at time of writing, they are not), but we + // document this fact out of an abundance of caution. + case Intrinsic::amdgcn_make_buffer_rsrc: return true; case Intrinsic::ptrmask: return !MustPreserveNullness; @@ -4941,11 +6194,10 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) { static ConstantRange computeConstantRangeIncludingKnownBits( const Value *V, bool ForSigned, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, - OptimizationRemarkEmitter *ORE = nullptr, bool UseInstrInfo = true) { - KnownBits Known = computeKnownBits( - V, DL, Depth, AC, CxtI, DT, ORE, UseInstrInfo); + bool UseInstrInfo = true) { + KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT, UseInstrInfo); ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned); - ConstantRange CR2 = computeConstantRange(V, UseInstrInfo); + ConstantRange CR2 = computeConstantRange(V, ForSigned, UseInstrInfo); ConstantRange::PreferredRangeType RangeType = ForSigned ? 
ConstantRange::Signed : ConstantRange::Unsigned; return CR1.intersectWith(CR2, RangeType); @@ -4956,9 +6208,9 @@ OverflowResult llvm::computeOverflowForUnsignedMul( AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT, - nullptr, UseInstrInfo); + UseInstrInfo); KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT, - nullptr, UseInstrInfo); + UseInstrInfo); ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false); ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false); return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange)); @@ -4998,9 +6250,9 @@ llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS, // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 // For simplicity we just check if at least one side is not negative. KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT, - nullptr, UseInstrInfo); + UseInstrInfo); KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT, - nullptr, UseInstrInfo); + UseInstrInfo); if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) return OverflowResult::NeverOverflows; } @@ -5012,11 +6264,9 @@ OverflowResult llvm::computeOverflowForUnsignedAdd( AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo) { ConstantRange LHSRange = computeConstantRangeIncludingKnownBits( - LHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT, - nullptr, UseInstrInfo); + LHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT, UseInstrInfo); ConstantRange RHSRange = computeConstantRangeIncludingKnownBits( - RHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT, - nullptr, UseInstrInfo); + RHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT, UseInstrInfo); return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange)); } @@ -5074,7 +6324,8 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS, if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) { KnownBits AddKnown(LHSRange.getBitWidth()); computeKnownBitsFromAssume( - Add, AddKnown, /*Depth=*/0, Query(DL, AC, CxtI, DT, true)); + Add, AddKnown, /*Depth=*/0, + SimplifyQuery(DL, /*TLI*/ nullptr, DT, AC, CxtI, DT)); if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) || (AddKnown.isNegative() && LHSOrRHSKnownNegative)) return OverflowResult::NeverOverflows; @@ -5346,7 +6597,7 @@ static bool canCreateUndefOrPoison(const Operator *Op, bool PoisonOnly, ArrayRef<int> Mask = isa<ConstantExpr>(Op) ? 
cast<ConstantExpr>(Op)->getShuffleMask() : cast<ShuffleVectorInst>(Op)->getShuffleMask(); - return is_contained(Mask, UndefMaskElem); + return is_contained(Mask, PoisonMaskElem); } case Instruction::FNeg: case Instruction::PHI: @@ -5421,7 +6672,7 @@ static bool directlyImpliesPoison(const Value *ValAssumedPoison, static bool impliesPoison(const Value *ValAssumedPoison, const Value *V, unsigned Depth) { - if (isGuaranteedNotToBeUndefOrPoison(ValAssumedPoison)) + if (isGuaranteedNotToBePoison(ValAssumedPoison)) return true; if (directlyImpliesPoison(ValAssumedPoison, V, /* Depth */ 0)) @@ -5459,7 +6710,9 @@ static bool isGuaranteedNotToBeUndefOrPoison(const Value *V, return false; if (const auto *A = dyn_cast<Argument>(V)) { - if (A->hasAttribute(Attribute::NoUndef)) + if (A->hasAttribute(Attribute::NoUndef) || + A->hasAttribute(Attribute::Dereferenceable) || + A->hasAttribute(Attribute::DereferenceableOrNull)) return true; } @@ -5592,6 +6845,50 @@ bool llvm::isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC, return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth, true); } +/// Return true if undefined behavior would provably be executed on the path to +/// OnPathTo if Root produced a posion result. Note that this doesn't say +/// anything about whether OnPathTo is actually executed or whether Root is +/// actually poison. This can be used to assess whether a new use of Root can +/// be added at a location which is control equivalent with OnPathTo (such as +/// immediately before it) without introducing UB which didn't previously +/// exist. Note that a false result conveys no information. +bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root, + Instruction *OnPathTo, + DominatorTree *DT) { + // Basic approach is to assume Root is poison, propagate poison forward + // through all users we can easily track, and then check whether any of those + // users are provable UB and must execute before out exiting block might + // exit. + + // The set of all recursive users we've visited (which are assumed to all be + // poison because of said visit) + SmallSet<const Value *, 16> KnownPoison; + SmallVector<const Instruction*, 16> Worklist; + Worklist.push_back(Root); + while (!Worklist.empty()) { + const Instruction *I = Worklist.pop_back_val(); + + // If we know this must trigger UB on a path leading our target. + if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo)) + return true; + + // If we can't analyze propagation through this instruction, just skip it + // and transitive users. Safe as false is a conservative result. + if (I != Root && !any_of(I->operands(), [&KnownPoison](const Use &U) { + return KnownPoison.contains(U) && propagatesPoison(U); + })) + continue; + + if (KnownPoison.insert(I).second) + for (const User *User : I->users()) + Worklist.push_back(cast<Instruction>(User)); + } + + // Might be non-UB, or might have a path we couldn't prove must execute on + // way to exiting bb. 
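  // Illustrative example: if Root feeds the divisor of a 'udiv' that dominates
  // OnPathTo, poison in Root would already trigger immediate UB at that
  // division, so a caller may safely add another use of Root at OnPathTo.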
+ return false; +} + OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add, const DataLayout &DL, AssumptionCache *AC, @@ -5756,7 +7053,8 @@ void llvm::getGuaranteedWellDefinedOps( Operands.push_back(CB->getCalledOperand()); for (unsigned i = 0; i < CB->arg_size(); ++i) { if (CB->paramHasAttr(i, Attribute::NoUndef) || - CB->paramHasAttr(i, Attribute::Dereferenceable)) + CB->paramHasAttr(i, Attribute::Dereferenceable) || + CB->paramHasAttr(i, Attribute::DereferenceableOrNull)) Operands.push_back(CB->getArgOperand(i)); } break; @@ -5796,7 +7094,7 @@ void llvm::getGuaranteedNonPoisonOps(const Instruction *I, } bool llvm::mustTriggerUB(const Instruction *I, - const SmallSet<const Value *, 16>& KnownPoison) { + const SmallPtrSetImpl<const Value *> &KnownPoison) { SmallVector<const Value *, 4> NonPoisonOps; getGuaranteedNonPoisonOps(I, NonPoisonOps); @@ -5882,6 +7180,15 @@ static bool programUndefinedIfUndefOrPoison(const Value *V, break; } } + + // Special handling for select, which returns poison if its operand 0 is + // poison (handled in the loop above) *or* if both its true/false operands + // are poison (handled here). + if (I.getOpcode() == Instruction::Select && + YieldsPoison.count(I.getOperand(1)) && + YieldsPoison.count(I.getOperand(2))) { + YieldsPoison.insert(&I); + } } BB = BB->getSingleSuccessor(); @@ -6618,6 +7925,12 @@ Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) { case Intrinsic::smin: return Intrinsic::smax; case Intrinsic::umax: return Intrinsic::umin; case Intrinsic::umin: return Intrinsic::umax; + // Please note that next four intrinsics may produce the same result for + // original and inverted case even if X != Y due to NaN is handled specially. + case Intrinsic::maximum: return Intrinsic::minimum; + case Intrinsic::minimum: return Intrinsic::maximum; + case Intrinsic::maxnum: return Intrinsic::minnum; + case Intrinsic::minnum: return Intrinsic::maxnum; default: llvm_unreachable("Unexpected intrinsic"); } } @@ -6765,6 +8078,10 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, if (match(RHS, m_NUWAdd(m_Specific(LHS), m_APInt(C)))) return true; + // RHS >> V u<= RHS for any V + if (match(LHS, m_LShr(m_Specific(RHS), m_Value()))) + return true; + // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB) auto MatchNUWAddsToSameValue = [&](const Value *A, const Value *B, const Value *&X, @@ -6813,12 +8130,26 @@ isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, return true; return std::nullopt; + case CmpInst::ICMP_SGT: + case CmpInst::ICMP_SGE: + if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS, DL, Depth) && + isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS, DL, Depth)) + return true; + return std::nullopt; + case CmpInst::ICMP_ULT: case CmpInst::ICMP_ULE: if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth) && isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth)) return true; return std::nullopt; + + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_UGE: + if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS, DL, Depth) && + isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS, DL, Depth)) + return true; + return std::nullopt; } } @@ -7119,7 +8450,7 @@ static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower, } else if (match(BO.getOperand(0), m_APInt(C))) { unsigned ShiftAmount = Width - 1; if (!C->isZero() && IIQ.isExact(&BO)) - ShiftAmount = C->countTrailingZeros(); + ShiftAmount = C->countr_zero(); if (C->isNegative()) { // 'ashr C, x' produces [C, C >> (Width-1)] Lower = *C; @@ 
-7140,7 +8471,7 @@ static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower, // 'lshr C, x' produces [C >> (Width-1), C]. unsigned ShiftAmount = Width - 1; if (!C->isZero() && IIQ.isExact(&BO)) - ShiftAmount = C->countTrailingZeros(); + ShiftAmount = C->countr_zero(); Lower = C->lshr(ShiftAmount); Upper = *C + 1; } @@ -7151,16 +8482,16 @@ static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower, if (IIQ.hasNoUnsignedWrap(&BO)) { // 'shl nuw C, x' produces [C, C << CLZ(C)] Lower = *C; - Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + Upper = Lower.shl(Lower.countl_zero()) + 1; } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw? if (C->isNegative()) { // 'shl nsw C, x' produces [C << CLO(C)-1, C] - unsigned ShiftAmount = C->countLeadingOnes() - 1; + unsigned ShiftAmount = C->countl_one() - 1; Lower = C->shl(ShiftAmount); Upper = *C + 1; } else { // 'shl nsw C, x' produces [C, C << CLZ(C)-1] - unsigned ShiftAmount = C->countLeadingZeros() - 1; + unsigned ShiftAmount = C->countl_zero() - 1; Lower = *C; Upper = C->shl(ShiftAmount) + 1; } @@ -7177,7 +8508,7 @@ static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower, // where C != -1 and C != 0 and C != 1 Lower = IntMin + 1; Upper = IntMax + 1; - } else if (C->countLeadingZeros() < Width - 1) { + } else if (C->countl_zero() < Width - 1) { // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C] // where C != -1 and C != 0 and C != 1 Lower = IntMin.sdiv(*C); @@ -7229,67 +8560,67 @@ static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower, } } -static void setLimitsForIntrinsic(const IntrinsicInst &II, APInt &Lower, - APInt &Upper) { - unsigned Width = Lower.getBitWidth(); +static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II) { + unsigned Width = II.getType()->getScalarSizeInBits(); const APInt *C; switch (II.getIntrinsicID()) { case Intrinsic::ctpop: case Intrinsic::ctlz: case Intrinsic::cttz: // Maximum of set/clear bits is the bit width. - assert(Lower == 0 && "Expected lower bound to be zero"); - Upper = Width + 1; - break; + return ConstantRange::getNonEmpty(APInt::getZero(Width), + APInt(Width, Width + 1)); case Intrinsic::uadd_sat: // uadd.sat(x, C) produces [C, UINT_MAX]. if (match(II.getOperand(0), m_APInt(C)) || match(II.getOperand(1), m_APInt(C))) - Lower = *C; + return ConstantRange::getNonEmpty(*C, APInt::getZero(Width)); break; case Intrinsic::sadd_sat: if (match(II.getOperand(0), m_APInt(C)) || match(II.getOperand(1), m_APInt(C))) { - if (C->isNegative()) { + if (C->isNegative()) // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)]. - Lower = APInt::getSignedMinValue(Width); - Upper = APInt::getSignedMaxValue(Width) + *C + 1; - } else { - // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX]. - Lower = APInt::getSignedMinValue(Width) + *C; - Upper = APInt::getSignedMaxValue(Width) + 1; - } + return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width), + APInt::getSignedMaxValue(Width) + *C + + 1); + + // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX]. + return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width) + *C, + APInt::getSignedMaxValue(Width) + 1); } break; case Intrinsic::usub_sat: // usub.sat(C, x) produces [0, C]. if (match(II.getOperand(0), m_APInt(C))) - Upper = *C + 1; + return ConstantRange::getNonEmpty(APInt::getZero(Width), *C + 1); + // usub.sat(x, C) produces [0, UINT_MAX - C]. 
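    // For example, on i8: usub.sat(42, x) lies in [0, 42] and usub.sat(x, 42)
    // lies in [0, 213], since at most 255 - 42 can remain after the saturating
    // subtraction.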
- else if (match(II.getOperand(1), m_APInt(C))) - Upper = APInt::getMaxValue(Width) - *C + 1; + if (match(II.getOperand(1), m_APInt(C))) + return ConstantRange::getNonEmpty(APInt::getZero(Width), + APInt::getMaxValue(Width) - *C + 1); break; case Intrinsic::ssub_sat: if (match(II.getOperand(0), m_APInt(C))) { - if (C->isNegative()) { + if (C->isNegative()) // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)]. - Lower = APInt::getSignedMinValue(Width); - Upper = *C - APInt::getSignedMinValue(Width) + 1; - } else { - // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX]. - Lower = *C - APInt::getSignedMaxValue(Width); - Upper = APInt::getSignedMaxValue(Width) + 1; - } + return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width), + *C - APInt::getSignedMinValue(Width) + + 1); + + // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX]. + return ConstantRange::getNonEmpty(*C - APInt::getSignedMaxValue(Width), + APInt::getSignedMaxValue(Width) + 1); } else if (match(II.getOperand(1), m_APInt(C))) { - if (C->isNegative()) { + if (C->isNegative()) // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]: - Lower = APInt::getSignedMinValue(Width) - *C; - Upper = APInt::getSignedMaxValue(Width) + 1; - } else { - // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C]. - Lower = APInt::getSignedMinValue(Width); - Upper = APInt::getSignedMaxValue(Width) - *C + 1; - } + return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width) - *C, + APInt::getSignedMaxValue(Width) + 1); + + // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C]. + return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width), + APInt::getSignedMaxValue(Width) - *C + + 1); } break; case Intrinsic::umin: @@ -7302,19 +8633,15 @@ static void setLimitsForIntrinsic(const IntrinsicInst &II, APInt &Lower, switch (II.getIntrinsicID()) { case Intrinsic::umin: - Upper = *C + 1; - break; + return ConstantRange::getNonEmpty(APInt::getZero(Width), *C + 1); case Intrinsic::umax: - Lower = *C; - break; + return ConstantRange::getNonEmpty(*C, APInt::getZero(Width)); case Intrinsic::smin: - Lower = APInt::getSignedMinValue(Width); - Upper = *C + 1; - break; + return ConstantRange::getNonEmpty(APInt::getSignedMinValue(Width), + *C + 1); case Intrinsic::smax: - Lower = *C; - Upper = APInt::getSignedMaxValue(Width) + 1; - break; + return ConstantRange::getNonEmpty(*C, + APInt::getSignedMaxValue(Width) + 1); default: llvm_unreachable("Must be min/max intrinsic"); } @@ -7323,13 +8650,20 @@ static void setLimitsForIntrinsic(const IntrinsicInst &II, APInt &Lower, // If abs of SIGNED_MIN is poison, then the result is [0..SIGNED_MAX], // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN. 
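    // For example, on i8: llvm.abs(x, true) lies in [0, 127], while
    // llvm.abs(x, false) can additionally yield -128, because abs(INT8_MIN)
    // wraps back to INT8_MIN.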
if (match(II.getOperand(1), m_One())) - Upper = APInt::getSignedMaxValue(Width) + 1; - else - Upper = APInt::getSignedMinValue(Width) + 1; - break; + return ConstantRange::getNonEmpty(APInt::getZero(Width), + APInt::getSignedMaxValue(Width) + 1); + + return ConstantRange::getNonEmpty(APInt::getZero(Width), + APInt::getSignedMinValue(Width) + 1); + case Intrinsic::vscale: + if (!II.getParent() || !II.getFunction()) + break; + return getVScaleRange(II.getFunction(), Width); default: break; } + + return ConstantRange::getFull(Width); } static void setLimitsForSelectPattern(const SelectInst &SI, APInt &Lower, @@ -7418,18 +8752,28 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned, InstrInfoQuery IIQ(UseInstrInfo); unsigned BitWidth = V->getType()->getScalarSizeInBits(); - APInt Lower = APInt(BitWidth, 0); - APInt Upper = APInt(BitWidth, 0); - if (auto *BO = dyn_cast<BinaryOperator>(V)) + ConstantRange CR = ConstantRange::getFull(BitWidth); + if (auto *BO = dyn_cast<BinaryOperator>(V)) { + APInt Lower = APInt(BitWidth, 0); + APInt Upper = APInt(BitWidth, 0); + // TODO: Return ConstantRange. setLimitsForBinOp(*BO, Lower, Upper, IIQ, ForSigned); - else if (auto *II = dyn_cast<IntrinsicInst>(V)) - setLimitsForIntrinsic(*II, Lower, Upper); - else if (auto *SI = dyn_cast<SelectInst>(V)) + CR = ConstantRange::getNonEmpty(Lower, Upper); + } else if (auto *II = dyn_cast<IntrinsicInst>(V)) + CR = getRangeForIntrinsic(*II); + else if (auto *SI = dyn_cast<SelectInst>(V)) { + APInt Lower = APInt(BitWidth, 0); + APInt Upper = APInt(BitWidth, 0); + // TODO: Return ConstantRange. setLimitsForSelectPattern(*SI, Lower, Upper, IIQ); - else if (isa<FPToUIInst>(V) || isa<FPToSIInst>(V)) + CR = ConstantRange::getNonEmpty(Lower, Upper); + } else if (isa<FPToUIInst>(V) || isa<FPToSIInst>(V)) { + APInt Lower = APInt(BitWidth, 0); + APInt Upper = APInt(BitWidth, 0); + // TODO: Return ConstantRange. setLimitForFPToI(cast<Instruction>(V), Lower, Upper); - - ConstantRange CR = ConstantRange::getNonEmpty(Lower, Upper); + CR = ConstantRange::getNonEmpty(Lower, Upper); + } if (auto *I = dyn_cast<Instruction>(V)) if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range)) @@ -7440,9 +8784,11 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned, for (auto &AssumeVH : AC->assumptionsFor(V)) { if (!AssumeVH) continue; - IntrinsicInst *I = cast<IntrinsicInst>(AssumeVH); + CallInst *I = cast<CallInst>(AssumeVH); assert(I->getParent()->getParent() == CtxI->getParent()->getParent() && "Got assumption for the wrong function!"); + assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume && + "must be an assume intrinsic"); if (!isValidAssumeForContext(I, CtxI, DT)) continue; @@ -7462,74 +8808,3 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned, return CR; } - -static std::optional<int64_t> -getOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, const DataLayout &DL) { - // Skip over the first indices. - gep_type_iterator GTI = gep_type_begin(GEP); - for (unsigned i = 1; i != Idx; ++i, ++GTI) - /*skip along*/; - - // Compute the offset implied by the rest of the indices. - int64_t Offset = 0; - for (unsigned i = Idx, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { - ConstantInt *OpC = dyn_cast<ConstantInt>(GEP->getOperand(i)); - if (!OpC) - return std::nullopt; - if (OpC->isZero()) - continue; // No offset. - - // Handle struct indices, which add their field offset to the pointer. 
- if (StructType *STy = GTI.getStructTypeOrNull()) { - Offset += DL.getStructLayout(STy)->getElementOffset(OpC->getZExtValue()); - continue; - } - - // Otherwise, we have a sequential type like an array or fixed-length - // vector. Multiply the index by the ElementSize. - TypeSize Size = DL.getTypeAllocSize(GTI.getIndexedType()); - if (Size.isScalable()) - return std::nullopt; - Offset += Size.getFixedValue() * OpC->getSExtValue(); - } - - return Offset; -} - -std::optional<int64_t> llvm::isPointerOffset(const Value *Ptr1, - const Value *Ptr2, - const DataLayout &DL) { - APInt Offset1(DL.getIndexTypeSizeInBits(Ptr1->getType()), 0); - APInt Offset2(DL.getIndexTypeSizeInBits(Ptr2->getType()), 0); - Ptr1 = Ptr1->stripAndAccumulateConstantOffsets(DL, Offset1, true); - Ptr2 = Ptr2->stripAndAccumulateConstantOffsets(DL, Offset2, true); - - // Handle the trivial case first. - if (Ptr1 == Ptr2) - return Offset2.getSExtValue() - Offset1.getSExtValue(); - - const GEPOperator *GEP1 = dyn_cast<GEPOperator>(Ptr1); - const GEPOperator *GEP2 = dyn_cast<GEPOperator>(Ptr2); - - // Right now we handle the case when Ptr1/Ptr2 are both GEPs with an identical - // base. After that base, they may have some number of common (and - // potentially variable) indices. After that they handle some constant - // offset, which determines their offset from each other. At this point, we - // handle no other case. - if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0) || - GEP1->getSourceElementType() != GEP2->getSourceElementType()) - return std::nullopt; - - // Skip any common indices and track the GEP types. - unsigned Idx = 1; - for (; Idx != GEP1->getNumOperands() && Idx != GEP2->getNumOperands(); ++Idx) - if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx)) - break; - - auto IOffset1 = getOffsetFromIndex(GEP1, Idx, DL); - auto IOffset2 = getOffsetFromIndex(GEP2, Idx, DL); - if (!IOffset1 || !IOffset2) - return std::nullopt; - return *IOffset2 - *IOffset1 + Offset2.getSExtValue() - - Offset1.getSExtValue(); -} diff --git a/contrib/llvm-project/llvm/lib/Analysis/VectorUtils.cpp b/contrib/llvm-project/llvm/lib/Analysis/VectorUtils.cpp index 1e48d3e2fbca..87f0bb690477 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/VectorUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/VectorUtils.cpp @@ -12,6 +12,7 @@ #include "llvm/Analysis/VectorUtils.h" #include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/SmallString.h" #include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" @@ -20,7 +21,6 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" @@ -87,6 +87,7 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { case Intrinsic::pow: case Intrinsic::fma: case Intrinsic::fmuladd: + case Intrinsic::is_fpclass: case Intrinsic::powi: case Intrinsic::canonicalize: case Intrinsic::fptosi_sat: @@ -104,6 +105,7 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, case Intrinsic::abs: case Intrinsic::ctlz: case Intrinsic::cttz: + case Intrinsic::is_fpclass: case Intrinsic::powi: return (ScalarOpdIdx == 1); case Intrinsic::smul_fix: @@ -117,15 +119,17 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, } bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, - unsigned OpdIdx) { + int OpdIdx) { switch (ID) { 
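Not part of the upstream change — the parameter widening here (unsigned to int OpdIdx) lets callers also ask about the return type: the cases that follow treat OpdIdx == -1 as "the return type participates in the intrinsic's overload". A hedged caller-side sketch; only isVectorIntrinsicWithOverloadTypeAtArg is from VectorUtils.h, the helper and main are hypothetical:

// overload_query_demo.cpp -- illustrative only.
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

// Report which positions of an intrinsic's signature take part in its
// overloaded type; -1 stands for the return type.
static void reportOverloadPositions(Intrinsic::ID ID, unsigned NumArgs) {
  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
    outs() << "return type is overloaded\n";
  for (unsigned I = 0; I != NumArgs; ++I)
    if (isVectorIntrinsicWithOverloadTypeAtArg(ID, static_cast<int>(I)))
      outs() << "operand " << I << " is overloaded\n";
}

int main() {
  // Per the switch below, llvm.fptosi.sat overloads its return type and
  // operand 0, so both lines should print.
  reportOverloadPositions(Intrinsic::fptosi_sat, /*NumArgs=*/1);
  return 0;
}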
case Intrinsic::fptosi_sat: case Intrinsic::fptoui_sat: + return OpdIdx == -1 || OpdIdx == 0; + case Intrinsic::is_fpclass: return OpdIdx == 0; case Intrinsic::powi: - return OpdIdx == 1; + return OpdIdx == -1 || OpdIdx == 1; default: - return false; + return OpdIdx == -1; } } @@ -146,139 +150,6 @@ Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI, return Intrinsic::not_intrinsic; } -/// Find the operand of the GEP that should be checked for consecutive -/// stores. This ignores trailing indices that have no effect on the final -/// pointer. -unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) { - const DataLayout &DL = Gep->getModule()->getDataLayout(); - unsigned LastOperand = Gep->getNumOperands() - 1; - TypeSize GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType()); - - // Walk backwards and try to peel off zeros. - while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) { - // Find the type we're currently indexing into. - gep_type_iterator GEPTI = gep_type_begin(Gep); - std::advance(GEPTI, LastOperand - 2); - - // If it's a type with the same allocation size as the result of the GEP we - // can peel off the zero index. - if (DL.getTypeAllocSize(GEPTI.getIndexedType()) != GEPAllocSize) - break; - --LastOperand; - } - - return LastOperand; -} - -/// If the argument is a GEP, then returns the operand identified by -/// getGEPInductionOperand. However, if there is some other non-loop-invariant -/// operand, it returns that instead. -Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { - GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); - if (!GEP) - return Ptr; - - unsigned InductionOperand = getGEPInductionOperand(GEP); - - // Check that all of the gep indices are uniform except for our induction - // operand. - for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) - if (i != InductionOperand && - !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp)) - return Ptr; - return GEP->getOperand(InductionOperand); -} - -/// If a value has only one user that is a CastInst, return it. -Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) { - Value *UniqueCast = nullptr; - for (User *U : Ptr->users()) { - CastInst *CI = dyn_cast<CastInst>(U); - if (CI && CI->getType() == Ty) { - if (!UniqueCast) - UniqueCast = CI; - else - return nullptr; - } - } - return UniqueCast; -} - -/// Get the stride of a pointer access in a loop. Looks for symbolic -/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. -Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { - auto *PtrTy = dyn_cast<PointerType>(Ptr->getType()); - if (!PtrTy || PtrTy->isAggregateType()) - return nullptr; - - // Try to remove a gep instruction to make the pointer (actually index at this - // point) easier analyzable. If OrigPtr is equal to Ptr we are analyzing the - // pointer, otherwise, we are analyzing the index. - Value *OrigPtr = Ptr; - - // The size of the pointer access. - int64_t PtrAccessSize = 1; - - Ptr = stripGetElementPtr(Ptr, SE, Lp); - const SCEV *V = SE->getSCEV(Ptr); - - if (Ptr != OrigPtr) - // Strip off casts. - while (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V)) - V = C->getOperand(); - - const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V); - if (!S) - return nullptr; - - V = S->getStepRecurrence(*SE); - if (!V) - return nullptr; - - // Strip off the size of access multiplication if we are still analyzing the - // pointer. 
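Not part of the upstream change — for context on what the removed getStrideFromPointer recognized: a "symbolic stride" is a loop-invariant value multiplying the induction variable inside the address computation, and the helper walked the pointer's SCEV to recover it. A hypothetical source-level loop of that shape:

// Illustrative only: the "a[i * stride]" pattern with a loop-invariant,
// symbolic stride; the per-iteration step of the address SCEV is
// stride * sizeof(float), from which the symbolic value can be peeled off.
void axpy_strided(float *a, const float *b, int n, int stride) {
  for (int i = 0; i < n; ++i)
    a[i * stride] += b[i * stride];
}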
- if (OrigPtr == Ptr) { - if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { - if (M->getOperand(0)->getSCEVType() != scConstant) - return nullptr; - - const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt(); - - // Huge step value - give up. - if (APStepVal.getBitWidth() > 64) - return nullptr; - - int64_t StepVal = APStepVal.getSExtValue(); - if (PtrAccessSize != StepVal) - return nullptr; - V = M->getOperand(1); - } - } - - // Strip off casts. - Type *StripedOffRecurrenceCast = nullptr; - if (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V)) { - StripedOffRecurrenceCast = C->getType(); - V = C->getOperand(); - } - - // Look for the loop invariant symbolic value. - const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V); - if (!U) - return nullptr; - - Value *Stride = U->getValue(); - if (!Lp->isLoopInvariant(Stride)) - return nullptr; - - // If we have stripped off the recurrence cast we have to make sure that we - // return the value that is used in this loop so that we can replace it later. - if (StripedOffRecurrenceCast) - Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast); - - return Stride; -} - /// Given a vector and an element number, see if the scalar value is /// already around as a register, for example if it were inserted then extracted /// from the vector. @@ -574,13 +445,13 @@ void llvm::processShuffleMasks( int Idx = I * SzDest + K; if (Idx == Sz) break; - if (Mask[Idx] >= Sz || Mask[Idx] == UndefMaskElem) + if (Mask[Idx] >= Sz || Mask[Idx] == PoisonMaskElem) continue; int SrcRegIdx = Mask[Idx] / SzSrc; // Add a cost of PermuteTwoSrc for each new source register permute, // if we have more than one source registers. if (RegMasks[SrcRegIdx].empty()) - RegMasks[SrcRegIdx].assign(SzDest, UndefMaskElem); + RegMasks[SrcRegIdx].assign(SzDest, PoisonMaskElem); RegMasks[SrcRegIdx][K] = Mask[Idx] % SzSrc; } } @@ -612,8 +483,8 @@ void llvm::processShuffleMasks( auto &&CombineMasks = [](MutableArrayRef<int> FirstMask, ArrayRef<int> SecondMask) { for (int Idx = 0, VF = FirstMask.size(); Idx < VF; ++Idx) { - if (SecondMask[Idx] != UndefMaskElem) { - assert(FirstMask[Idx] == UndefMaskElem && + if (SecondMask[Idx] != PoisonMaskElem) { + assert(FirstMask[Idx] == PoisonMaskElem && "Expected undefined mask element."); FirstMask[Idx] = SecondMask[Idx] + VF; } @@ -621,7 +492,7 @@ void llvm::processShuffleMasks( }; auto &&NormalizeMask = [](MutableArrayRef<int> Mask) { for (int Idx = 0, VF = Mask.size(); Idx < VF; ++Idx) { - if (Mask[Idx] != UndefMaskElem) + if (Mask[Idx] != PoisonMaskElem) Mask[Idx] = Idx; } }; @@ -770,11 +641,9 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, for (Value *M : llvm::make_range(ECs.member_begin(I), ECs.member_end())) LeaderDemandedBits |= DBits[M]; - uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - - llvm::countLeadingZeros(LeaderDemandedBits); + uint64_t MinBW = llvm::bit_width(LeaderDemandedBits); // Round up to a power of 2 - if (!isPowerOf2_64((uint64_t)MinBW)) - MinBW = NextPowerOf2(MinBW); + MinBW = llvm::bit_ceil(MinBW); // We don't modify the types of PHIs. 
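Not part of the upstream change — the countLeadingZeros/NextPowerOf2 pair above is replaced by llvm::bit_width and llvm::bit_ceil from llvm/ADT/bit.h; the C++20 standard equivalents behave the same way on unsigned values, so the new MinBW computation can be sanity-checked standalone:

// minbw_demo.cpp -- illustrative only; compile with -std=c++20.
#include <bit>
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t LeaderDemandedBits = 0x16;                   // 0b10110: highest demanded bit is bit 4
  uint64_t MinBW = std::bit_width(LeaderDemandedBits);  // 5, i.e. 64 - countl_zero
  MinBW = std::bit_ceil(MinBW);                         // round up to a power of two: 8
  std::printf("MinBW = %llu\n", (unsigned long long)MinBW);
  return 0;
}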
Reductions will already have been // truncated if possible, and inductions' sizes will have been chosen by @@ -790,13 +659,32 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, continue; for (Value *M : llvm::make_range(ECs.member_begin(I), ECs.member_end())) { - if (!isa<Instruction>(M)) + auto *MI = dyn_cast<Instruction>(M); + if (!MI) continue; Type *Ty = M->getType(); if (Roots.count(M)) - Ty = cast<Instruction>(M)->getOperand(0)->getType(); - if (MinBW < Ty->getScalarSizeInBits()) - MinBWs[cast<Instruction>(M)] = MinBW; + Ty = MI->getOperand(0)->getType(); + + if (MinBW >= Ty->getScalarSizeInBits()) + continue; + + // If any of M's operands demand more bits than MinBW then M cannot be + // performed safely in MinBW. + if (any_of(MI->operands(), [&DB, MinBW](Use &U) { + auto *CI = dyn_cast<ConstantInt>(U); + // For constants shift amounts, check if the shift would result in + // poison. + if (CI && + isa<ShlOperator, LShrOperator, AShrOperator>(U.getUser()) && + U.getOperandNo() == 1) + return CI->uge(MinBW); + uint64_t BW = bit_width(DB.getDemandedBits(&U).getZExtValue()); + return bit_ceil(BW) > MinBW; + })) + continue; + + MinBWs[MI] = MinBW; } } @@ -1143,7 +1031,7 @@ bool InterleavedAccessInfo::isStrided(int Stride) { void InterleavedAccessInfo::collectConstStrideAccesses( MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, - const ValueToValueMap &Strides) { + const DenseMap<Value*, const SCEV*> &Strides) { auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); // Since it's desired that the load/store instructions be maintained in @@ -1223,7 +1111,7 @@ void InterleavedAccessInfo::collectConstStrideAccesses( void InterleavedAccessInfo::analyzeInterleaving( bool EnablePredicatedInterleavedMemAccesses) { LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); - const ValueToValueMap &Strides = LAI->getSymbolicStrides(); + const auto &Strides = LAI->getSymbolicStrides(); // Holds all accesses with a constant stride. MapVector<Instruction *, StrideDescriptor> AccessStrideInfo; @@ -1239,6 +1127,8 @@ void InterleavedAccessInfo::analyzeInterleaving( SmallSetVector<InterleaveGroup<Instruction> *, 4> StoreGroups; // Holds all interleaved load groups temporarily. SmallSetVector<InterleaveGroup<Instruction> *, 4> LoadGroups; + // Groups added to this set cannot have new members added. + SmallPtrSet<InterleaveGroup<Instruction> *, 4> CompletedLoadGroups; // Search in bottom-up program order for pairs of accesses (A and B) that can // form interleaved load or store groups. In the algorithm below, access A @@ -1260,19 +1150,22 @@ void InterleavedAccessInfo::analyzeInterleaving( // Initialize a group for B if it has an allowable stride. Even if we don't // create a group for B, we continue with the bottom-up algorithm to ensure // we don't break any of B's dependences. - InterleaveGroup<Instruction> *Group = nullptr; + InterleaveGroup<Instruction> *GroupB = nullptr; if (isStrided(DesB.Stride) && (!isPredicated(B->getParent()) || EnablePredicatedInterleavedMemAccesses)) { - Group = getInterleaveGroup(B); - if (!Group) { + GroupB = getInterleaveGroup(B); + if (!GroupB) { LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B << '\n'); - Group = createInterleaveGroup(B, DesB.Stride, DesB.Alignment); + GroupB = createInterleaveGroup(B, DesB.Stride, DesB.Alignment); + } else if (CompletedLoadGroups.contains(GroupB)) { + // Skip B if no new instructions can be added to its load group. 
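Not part of the upstream change — for orientation, the access pattern analyzeInterleaving groups, written as a hypothetical source loop (a factor-2 load group). The new CompletedLoadGroups set marks load groups that must not accept further members; the hunk below spells out why (an earlier load would otherwise have to move across an intervening store):

// Illustrative only: two loads off the same base with stride 2 form one
// interleave group; a single wide load plus shuffles can serve both members.
void deinterleave(const float *in, float *re, float *im, int n) {
  for (int i = 0; i < n; ++i) {
    re[i] = in[2 * i];      // group member at index 0
    im[i] = in[2 * i + 1];  // group member at index 1
  }
}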
+ continue; } if (B->mayWriteToMemory()) - StoreGroups.insert(Group); + StoreGroups.insert(GroupB); else - LoadGroups.insert(Group); + LoadGroups.insert(GroupB); } for (auto AI = std::next(BI); AI != E; ++AI) { @@ -1313,6 +1206,16 @@ void InterleavedAccessInfo::analyzeInterleaving( StoreGroups.remove(StoreGroup); releaseGroup(StoreGroup); } + // If B is a load and part of an interleave group, no earlier loads can + // be added to B's interleave group, because this would mean the load B + // would need to be moved across store A. Mark the interleave group as + // complete. + if (GroupB && isa<LoadInst>(B)) { + LLVM_DEBUG(dbgs() << "LV: Marking interleave group for " << *B + << " as complete.\n"); + + CompletedLoadGroups.insert(GroupB); + } // If a dependence exists and A is not already in a group (or it was // and we just released it), B might be hoisted above A (if B is a @@ -1371,18 +1274,18 @@ void InterleavedAccessInfo::analyzeInterleaving( // The index of A is the index of B plus A's distance to B in multiples // of the size. int IndexA = - Group->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size); + GroupB->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size); // Try to insert A into B's group. - if (Group->insertMember(A, IndexA, DesA.Alignment)) { + if (GroupB->insertMember(A, IndexA, DesA.Alignment)) { LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n' << " into the interleave group with" << *B << '\n'); - InterleaveGroupMap[A] = Group; + InterleaveGroupMap[A] = GroupB; // Set the first load in program order as the insert position. if (A->mayReadFromMemory()) - Group->setInsertPos(A); + GroupB->setInsertPos(A); } } // Iteration over A accesses. } // Iteration over B accesses. @@ -1531,10 +1434,10 @@ void InterleaveGroup<Instruction>::addMetadata(Instruction *NewInst) const { std::string VFABI::mangleTLIVectorName(StringRef VectorName, StringRef ScalarName, unsigned numArgs, - ElementCount VF) { + ElementCount VF, bool Masked) { SmallString<256> Buffer; llvm::raw_svector_ostream Out(Buffer); - Out << "_ZGV" << VFABI::_LLVM_ << "N"; + Out << "_ZGV" << VFABI::_LLVM_ << (Masked ? "M" : "N"); if (VF.isScalable()) Out << 'x'; else |
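Not part of the upstream change — the new Masked flag only switches the mask token in the "_ZGV" prefix ('M' instead of 'N'); everything after the vectorization factor lies outside this hunk. A hedged usage sketch, assuming the declaration in llvm/Analysis/VectorUtils.h gained the matching bool parameter:

// vfabi_mangle_demo.cpp -- illustrative only.
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  // Unmasked 4-wide mapping for sinf: the name starts with "_ZGV_LLVM_N4".
  outs() << VFABI::mangleTLIVectorName("__svml_sinf4", "sinf", /*numArgs=*/1,
                                       ElementCount::getFixed(4),
                                       /*Masked=*/false)
         << "\n";
  // Masked variant of the same mapping: the prefix becomes "_ZGV_LLVM_M4".
  outs() << VFABI::mangleTLIVectorName("__svml_sinf4", "sinf", /*numArgs=*/1,
                                       ElementCount::getFixed(4),
                                       /*Masked=*/true)
         << "\n";
  return 0;
}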