Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/IPO')
27 files changed, 3897 insertions, 1523 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 824da6395f2e..fb3fa8d23daa 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -121,19 +121,24 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
   // that we are *not* promoting. For the ones that we do promote, the parameter
   // attributes are lost
   SmallVector<AttributeSet, 8> ArgAttrVec;
+  // Mapping from old to new argument indices. -1 for promoted or removed
+  // arguments.
+  SmallVector<unsigned> NewArgIndices;
   AttributeList PAL = F->getAttributes();

   // First, determine the new argument list
-  unsigned ArgNo = 0;
+  unsigned ArgNo = 0, NewArgNo = 0;
   for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
        ++I, ++ArgNo) {
     if (!ArgsToPromote.count(&*I)) {
       // Unchanged argument
       Params.push_back(I->getType());
       ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
+      NewArgIndices.push_back(NewArgNo++);
     } else if (I->use_empty()) {
       // Dead argument (which are always marked as promotable)
       ++NumArgumentsDead;
+      NewArgIndices.push_back((unsigned)-1);
     } else {
       const auto &ArgParts = ArgsToPromote.find(&*I)->second;
       for (const auto &Pair : ArgParts) {
@@ -141,6 +146,8 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
         ArgAttrVec.push_back(AttributeSet());
       }
       ++NumArgumentsPromoted;
+      NewArgIndices.push_back((unsigned)-1);
+      NewArgNo += ArgParts.size();
     }
   }

@@ -154,6 +161,7 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
                                   F->getName());
   NF->copyAttributesFrom(F);
   NF->copyMetadata(F, 0);
+  NF->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat);

   // The new function will have the !dbg metadata copied from the original
   // function. The original function may not be deleted, and dbg metadata need
@@ -173,6 +181,19 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM,
   // the function.
   NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
                                        PAL.getRetAttrs(), ArgAttrVec));
+
+  // Remap argument indices in allocsize attribute.
+  if (auto AllocSize = NF->getAttributes().getFnAttrs().getAllocSizeArgs()) {
+    unsigned Arg1 = NewArgIndices[AllocSize->first];
+    assert(Arg1 != (unsigned)-1 && "allocsize cannot be promoted argument");
+    std::optional<unsigned> Arg2;
+    if (AllocSize->second) {
+      Arg2 = NewArgIndices[*AllocSize->second];
+      assert(Arg2 != (unsigned)-1 && "allocsize cannot be promoted argument");
+    }
+    NF->addFnAttr(Attribute::getWithAllocSizeArgs(F->getContext(), Arg1, Arg2));
+  }
+
   AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth);

   ArgAttrVec.clear();
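Aside, a worked example (not part of the patch): to see what the NewArgIndices bookkeeping does, consider a function with arguments (a, b, c) where b is promoted into two scalar parts. The sketch below is a plain-C++ simplification with illustrative names, no LLVM types:

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
      // Mirrors the loop above: unchanged arguments receive the next new
      // index, promoted/dead arguments receive -1 and advance NewArgNo by
      // the number of parts they expand into.
      std::vector<unsigned> NewArgIndices;
      unsigned NewArgNo = 0;
      NewArgIndices.push_back(NewArgNo++);   // a: unchanged -> new index 0
      NewArgIndices.push_back((unsigned)-1); // b: promoted into two parts
      NewArgNo += 2;                         //    (which take indices 1 and 2)
      NewArgIndices.push_back(NewArgNo++);   // c: unchanged -> new index 3

      unsigned OldIdx = 2; // allocsize(2) referred to c before promotion
      unsigned NewIdx = NewArgIndices[OldIdx];
      assert(NewIdx != (unsigned)-1 && "allocsize cannot name a promoted arg");
      std::printf("allocsize(%u) becomes allocsize(%u)\n", OldIdx, NewIdx);
      return 0;
    }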
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp
index 847d07a49dee..d8e290cbc8a4 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/PointerIntPair.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/CallGraph.h"
@@ -50,6 +51,7 @@
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <cstdint>
+#include <memory>

 #ifdef EXPENSIVE_CHECKS
 #include "llvm/IR/Verifier.h"
@@ -93,6 +95,13 @@ static cl::opt<unsigned>
                           cl::desc("Maximal number of fixpoint iterations."),
                           cl::init(32));

+static cl::opt<unsigned>
+    MaxSpecializationPerCB("attributor-max-specializations-per-call-base",
+                           cl::Hidden,
+                           cl::desc("Maximal number of callees specialized for "
+                                    "a call base"),
+                           cl::init(UINT32_MAX));
+
 static cl::opt<unsigned, true> MaxInitializationChainLengthX(
     "attributor-max-initialization-chain-length", cl::Hidden,
     cl::desc(
@@ -166,6 +175,10 @@ static cl::opt<bool> SimplifyAllLoads("attributor-simplify-all-loads",
                                       cl::desc("Try to simplify all loads."),
                                       cl::init(true));

+static cl::opt<bool> CloseWorldAssumption(
+    "attributor-assume-closed-world", cl::Hidden,
+    cl::desc("Should a closed world be assumed, or not. Default if not set."));
+
 /// Logic operators for the change status enum class.
 ///
 ///{
@@ -226,10 +239,10 @@ bool AA::isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA,
   return InstanceInfoAA && InstanceInfoAA->isAssumedUniqueForAnalysis();
 }

-Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty,
-                                    const TargetLibraryInfo *TLI,
-                                    const DataLayout &DL,
-                                    AA::RangeTy *RangePtr) {
+Constant *
+AA::getInitialValueForObj(Attributor &A, const AbstractAttribute &QueryingAA,
+                          Value &Obj, Type &Ty, const TargetLibraryInfo *TLI,
+                          const DataLayout &DL, AA::RangeTy *RangePtr) {
   if (isa<AllocaInst>(Obj))
     return UndefValue::get(&Ty);
   if (Constant *Init = getInitialValueOfAllocation(&Obj, TLI, &Ty))
@@ -242,12 +255,13 @@ Constant *AA::getInitialValueForObj(Attributor &A, Value &Obj, Type &Ty,
   Constant *Initializer = nullptr;
   if (A.hasGlobalVariableSimplificationCallback(*GV)) {
     auto AssumedGV = A.getAssumedInitializerFromCallBack(
-        *GV, /* const AbstractAttribute *AA */ nullptr, UsedAssumedInformation);
+        *GV, &QueryingAA, UsedAssumedInformation);
     Initializer = *AssumedGV;
     if (!Initializer)
       return nullptr;
   } else {
-    if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer()))
+    if (!GV->hasLocalLinkage() &&
+        (GV->isInterposable() || !(GV->isConstant() && GV->hasInitializer())))
       return nullptr;
     if (!GV->hasInitializer())
       return UndefValue::get(&Ty);
@@ -316,7 +330,7 @@ Value *AA::getWithType(Value &V, Type &Ty) {
       if (C->getType()->isIntegerTy() && Ty.isIntegerTy())
         return ConstantExpr::getTrunc(C, &Ty, /* OnlyIfReduced */ true);
       if (C->getType()->isFloatingPointTy() && Ty.isFloatingPointTy())
-        return ConstantExpr::getFPTrunc(C, &Ty, /* OnlyIfReduced */ true);
+        return ConstantFoldCastInstruction(Instruction::FPTrunc, C, &Ty);
     }
   }
   return nullptr;
@@ -350,7 +364,7 @@ AA::combineOptionalValuesInAAValueLatice(const std::optional<Value *> &A,
 template <bool IsLoad, typename Ty>
 static bool getPotentialCopiesOfMemoryValue(
     Attributor &A, Ty &I, SmallSetVector<Value *, 4> &PotentialCopies,
-    SmallSetVector<Instruction *, 4> &PotentialValueOrigins,
+    SmallSetVector<Instruction *, 4> *PotentialValueOrigins,
     const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation,
     bool OnlyExact) {
   LLVM_DEBUG(dbgs() << "Trying to determine the potential copies of " << I
@@ -361,8 +375,8 @@ static bool getPotentialCopiesOfMemoryValue(
   // sure that we can find all of them. If we abort we want to avoid spurious
   // dependences and potential copies in the provided container.
   SmallVector<const AAPointerInfo *> PIs;
-  SmallVector<Value *> NewCopies;
-  SmallVector<Instruction *> NewCopyOrigins;
+  SmallSetVector<Value *, 8> NewCopies;
+  SmallSetVector<Instruction *, 8> NewCopyOrigins;

   const auto *TLI =
       A.getInfoCache().getTargetLibraryInfoForFunction(*I.getFunction());
@@ -425,6 +439,30 @@ static bool getPotentialCopiesOfMemoryValue(
     return AdjV;
   };

+  auto SkipCB = [&](const AAPointerInfo::Access &Acc) {
+    if ((IsLoad && !Acc.isWriteOrAssumption()) || (!IsLoad && !Acc.isRead()))
+      return true;
+    if (IsLoad) {
+      if (Acc.isWrittenValueYetUndetermined())
+        return true;
+      if (PotentialValueOrigins && !isa<AssumeInst>(Acc.getRemoteInst()))
+        return false;
+      if (!Acc.isWrittenValueUnknown())
+        if (Value *V = AdjustWrittenValueType(Acc, *Acc.getWrittenValue()))
+          if (NewCopies.count(V)) {
+            NewCopyOrigins.insert(Acc.getRemoteInst());
+            return true;
+          }
+      if (auto *SI = dyn_cast<StoreInst>(Acc.getRemoteInst()))
+        if (Value *V = AdjustWrittenValueType(Acc, *SI->getValueOperand()))
+          if (NewCopies.count(V)) {
+            NewCopyOrigins.insert(Acc.getRemoteInst());
+            return true;
+          }
+    }
+    return false;
+  };
+
   auto CheckAccess = [&](const AAPointerInfo::Access &Acc, bool IsExact) {
     if ((IsLoad && !Acc.isWriteOrAssumption()) || (!IsLoad && !Acc.isRead()))
       return true;
@@ -449,8 +487,9 @@ static bool getPotentialCopiesOfMemoryValue(
       Value *V = AdjustWrittenValueType(Acc, *Acc.getWrittenValue());
       if (!V)
         return false;
-      NewCopies.push_back(V);
-      NewCopyOrigins.push_back(Acc.getRemoteInst());
+      NewCopies.insert(V);
+      if (PotentialValueOrigins)
+        NewCopyOrigins.insert(Acc.getRemoteInst());
       return true;
     }
     auto *SI = dyn_cast<StoreInst>(Acc.getRemoteInst());
@@ -463,8 +502,9 @@ static bool getPotentialCopiesOfMemoryValue(
       Value *V = AdjustWrittenValueType(Acc, *SI->getValueOperand());
       if (!V)
         return false;
-      NewCopies.push_back(V);
-      NewCopyOrigins.push_back(SI);
+      NewCopies.insert(V);
+      if (PotentialValueOrigins)
+        NewCopyOrigins.insert(SI);
     } else {
       assert(isa<StoreInst>(I) && "Expected load or store instruction only!");
       auto *LI = dyn_cast<LoadInst>(Acc.getRemoteInst());
@@ -474,7 +514,7 @@ static bool getPotentialCopiesOfMemoryValue(
                           << *Acc.getRemoteInst() << "\n";);
         return false;
       }
-      NewCopies.push_back(Acc.getRemoteInst());
+      NewCopies.insert(Acc.getRemoteInst());
     }
     return true;
   };
@@ -486,11 +526,11 @@ static bool getPotentialCopiesOfMemoryValue(
     AA::RangeTy Range;
     auto *PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(Obj),
                                          DepClassTy::NONE);
-    if (!PI ||
-        !PI->forallInterferingAccesses(A, QueryingAA, I,
-                                       /* FindInterferingWrites */ IsLoad,
-                                       /* FindInterferingReads */ !IsLoad,
-                                       CheckAccess, HasBeenWrittenTo, Range)) {
+    if (!PI || !PI->forallInterferingAccesses(
+                   A, QueryingAA, I,
+                   /* FindInterferingWrites */ IsLoad,
+                   /* FindInterferingReads */ !IsLoad, CheckAccess,
+                   HasBeenWrittenTo, Range, SkipCB)) {
       LLVM_DEBUG(
           dbgs()
           << "Failed to verify all interfering accesses for underlying object: "
@@ -500,8 +540,8 @@ static bool getPotentialCopiesOfMemoryValue(
     if (IsLoad && !HasBeenWrittenTo && !Range.isUnassigned()) {
       const DataLayout &DL = A.getDataLayout();
-      Value *InitialValue =
-          AA::getInitialValueForObj(A, Obj, *I.getType(), TLI, DL, &Range);
+      Value *InitialValue = AA::getInitialValueForObj(
+          A, QueryingAA, Obj, *I.getType(), TLI, DL, &Range);
       if (!InitialValue) {
         LLVM_DEBUG(dbgs() << "Could not determine required initial value of "
                              "underlying object, abort!\n");
@@ -514,8 +554,9 @@ static bool getPotentialCopiesOfMemoryValue(
         return false;
       }

-      NewCopies.push_back(InitialValue);
-      NewCopyOrigins.push_back(nullptr);
+      NewCopies.insert(InitialValue);
+      if (PotentialValueOrigins)
+        NewCopyOrigins.insert(nullptr);
     }

     PIs.push_back(PI);
@@ -540,7 +581,8 @@ static bool getPotentialCopiesOfMemoryValue(
     A.recordDependence(*PI, QueryingAA, DepClassTy::OPTIONAL);
   }
   PotentialCopies.insert(NewCopies.begin(), NewCopies.end());
-  PotentialValueOrigins.insert(NewCopyOrigins.begin(), NewCopyOrigins.end());
+  if (PotentialValueOrigins)
+    PotentialValueOrigins->insert(NewCopyOrigins.begin(), NewCopyOrigins.end());

   return true;
 }
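Aside, a worked example (not part of the patch): NewCopies and NewCopyOrigins change from SmallVector to SmallSetVector because the new SkipCB fast path may rediscover a value that CheckAccess already recorded, and a set-like container makes that idempotent. A minimal standalone illustration, assuming only LLVM's ADT headers:

    #include "llvm/ADT/SetVector.h"
    #include <cstdio>

    int main() {
      // SmallSetVector keeps insertion order but ignores re-insertions, so a
      // copy found through two different interfering accesses is recorded once.
      llvm::SmallSetVector<int, 8> NewCopies;
      NewCopies.insert(42);
      NewCopies.insert(7);
      bool Inserted = NewCopies.insert(42); // duplicate -> returns false
      std::printf("size=%zu duplicate-inserted=%d\n", NewCopies.size(),
                  (int)Inserted);
      return 0;
    }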
@@ -551,7 +593,7 @@ bool AA::getPotentiallyLoadedValues(
     const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation,
     bool OnlyExact) {
   return getPotentialCopiesOfMemoryValue</* IsLoad */ true>(
-      A, LI, PotentialValues, PotentialValueOrigins, QueryingAA,
+      A, LI, PotentialValues, &PotentialValueOrigins, QueryingAA,
       UsedAssumedInformation, OnlyExact);
 }

@@ -559,10 +601,9 @@ bool AA::getPotentialCopiesOfStoredValue(
     Attributor &A, StoreInst &SI, SmallSetVector<Value *, 4> &PotentialCopies,
     const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation,
     bool OnlyExact) {
-  SmallSetVector<Instruction *, 4> PotentialValueOrigins;
   return getPotentialCopiesOfMemoryValue</* IsLoad */ false>(
-      A, SI, PotentialCopies, PotentialValueOrigins, QueryingAA,
-      UsedAssumedInformation, OnlyExact);
+      A, SI, PotentialCopies, nullptr, QueryingAA, UsedAssumedInformation,
+      OnlyExact);
 }

 static bool isAssumedReadOnlyOrReadNone(Attributor &A, const IRPosition &IRP,
@@ -723,7 +764,7 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI,
     // Check if we can reach returns.
     bool UsedAssumedInformation = false;
-    if (A.checkForAllInstructions(ReturnInstCB, FromFn, QueryingAA,
+    if (A.checkForAllInstructions(ReturnInstCB, FromFn, &QueryingAA,
                                   {Instruction::Ret}, UsedAssumedInformation)) {
       LLVM_DEBUG(dbgs() << "[AA] No return is reachable, done\n");
       continue;
@@ -1021,6 +1062,23 @@ ChangeStatus AbstractAttribute::update(Attributor &A) {
   return HasChanged;
 }

+Attributor::Attributor(SetVector<Function *> &Functions,
+                       InformationCache &InfoCache,
+                       AttributorConfig Configuration)
+    : Allocator(InfoCache.Allocator), Functions(Functions),
+      InfoCache(InfoCache), Configuration(Configuration) {
+  if (!isClosedWorldModule())
+    return;
+  for (Function *Fn : Functions)
+    if (Fn->hasAddressTaken(/*PutOffender=*/nullptr,
+                            /*IgnoreCallbackUses=*/false,
+                            /*IgnoreAssumeLikeCalls=*/true,
+                            /*IgnoreLLVMUsed=*/true,
+                            /*IgnoreARCAttachedCall=*/false,
+                            /*IgnoreCastedDirectCall=*/true))
+      InfoCache.IndirectlyCallableFunctions.push_back(Fn);
+}
+
 bool Attributor::getAttrsFromAssumes(const IRPosition &IRP,
                                      Attribute::AttrKind AK,
                                      SmallVectorImpl<Attribute> &Attrs) {
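Aside, a hedged sketch (not part of the patch): under the closed-world assumption the new constructor can enumerate every function an indirect call could possibly reach, namely anything whose address escapes. The same collection step outside the Attributor, assuming LLVM headers; collectIndirectlyCallable is a hypothetical helper name:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Module.h"

    // All functions an indirect call could reach in a closed-world module:
    // the same address-taken filter the constructor above uses. Assume-like
    // calls, llvm.used, and direct calls through a cast do not count.
    static llvm::SmallVector<llvm::Function *>
    collectIndirectlyCallable(llvm::Module &M) {
      llvm::SmallVector<llvm::Function *> Result;
      for (llvm::Function &Fn : M)
        if (Fn.hasAddressTaken(/*PutOffender=*/nullptr,
                               /*IgnoreCallbackUses=*/false,
                               /*IgnoreAssumeLikeCalls=*/true,
                               /*IgnoreLLVMUsed=*/true,
                               /*IgnoreARCAttachedCall=*/false,
                               /*IgnoreCastedDirectCall=*/true))
          Result.push_back(&Fn);
      return Result;
    }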
@@ -1053,8 +1111,7 @@ bool Attributor::getAttrsFromAssumes(const IRPosition &IRP,

 template <typename DescTy>
 ChangeStatus
-Attributor::updateAttrMap(const IRPosition &IRP,
-                          const ArrayRef<DescTy> &AttrDescs,
+Attributor::updateAttrMap(const IRPosition &IRP, ArrayRef<DescTy> AttrDescs,
                           function_ref<bool(const DescTy &, AttributeSet,
                                             AttributeMask &, AttrBuilder &)>
                               CB) {
@@ -1161,9 +1218,8 @@ void Attributor::getAttrs(const IRPosition &IRP,
     getAttrsFromAssumes(IRP, AK, Attrs);
 }

-ChangeStatus
-Attributor::removeAttrs(const IRPosition &IRP,
-                        const ArrayRef<Attribute::AttrKind> &AttrKinds) {
+ChangeStatus Attributor::removeAttrs(const IRPosition &IRP,
+                                     ArrayRef<Attribute::AttrKind> AttrKinds) {
   auto RemoveAttrCB = [&](const Attribute::AttrKind &Kind, AttributeSet AttrSet,
                           AttributeMask &AM, AttrBuilder &) {
     if (!AttrSet.hasAttribute(Kind))
@@ -1174,8 +1230,21 @@ Attributor::removeAttrs(const IRPosition &IRP,
   return updateAttrMap<Attribute::AttrKind>(IRP, AttrKinds, RemoveAttrCB);
 }

+ChangeStatus Attributor::removeAttrs(const IRPosition &IRP,
+                                     ArrayRef<StringRef> Attrs) {
+  auto RemoveAttrCB = [&](StringRef Attr, AttributeSet AttrSet,
+                          AttributeMask &AM, AttrBuilder &) -> bool {
+    if (!AttrSet.hasAttribute(Attr))
+      return false;
+    AM.addAttribute(Attr);
+    return true;
+  };
+
+  return updateAttrMap<StringRef>(IRP, Attrs, RemoveAttrCB);
+}
+
 ChangeStatus Attributor::manifestAttrs(const IRPosition &IRP,
-                                       const ArrayRef<Attribute> &Attrs,
+                                       ArrayRef<Attribute> Attrs,
                                        bool ForceReplace) {
   LLVMContext &Ctx = IRP.getAnchorValue().getContext();
   auto AddAttrCB = [&](const Attribute &Attr, AttributeSet AttrSet,
@@ -1665,6 +1734,21 @@ bool Attributor::isAssumedDead(const BasicBlock &BB,
   return false;
 }

+bool Attributor::checkForAllCallees(
+    function_ref<bool(ArrayRef<const Function *>)> Pred,
+    const AbstractAttribute &QueryingAA, const CallBase &CB) {
+  if (const Function *Callee = dyn_cast<Function>(CB.getCalledOperand()))
+    return Pred(Callee);
+
+  const auto *CallEdgesAA = getAAFor<AACallEdges>(
+      QueryingAA, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
+  if (!CallEdgesAA || CallEdgesAA->hasUnknownCallee())
+    return false;
+
+  const auto &Callees = CallEdgesAA->getOptimisticEdges();
+  return Pred(Callees.getArrayRef());
+}
+
 bool Attributor::checkForAllUses(
     function_ref<bool(const Use &, bool &)> Pred,
     const AbstractAttribute &QueryingAA, const Value &V,
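Aside, a plain-C++ sketch of the decision structure of the new checkForAllCallees (all types hypothetical, not part of the patch): a property holds at an indirect call site only if the callee set is fully known and the predicate holds for every member.

    #include <functional>
    #include <vector>

    struct Callee { bool KnownNoUnwind; };

    static bool
    checkForAllCallees(const std::vector<const Callee *> &PossibleCallees,
                       bool HasUnknownCallee,
                       const std::function<bool(const Callee &)> &Pred) {
      if (HasUnknownCallee) // CallEdgesAA->hasUnknownCallee() in the patch.
        return false;
      for (const Callee *C : PossibleCallees)
        if (!Pred(*C))
          return false;
      return true;
    }

    int main() {
      Callee A{true}, B{true};
      std::vector<const Callee *> Callees{&A, &B};
      bool AllNoUnwind = checkForAllCallees(
          Callees, /*HasUnknownCallee=*/false,
          [](const Callee &C) { return C.KnownNoUnwind; });
      return AllNoUnwind ? 0 : 1;
    }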
@@ -1938,7 +2022,7 @@ bool Attributor::checkForAllReturnedValues(function_ref<bool(Value &)> Pred,
 static bool checkForAllInstructionsImpl(
     Attributor *A, InformationCache::OpcodeInstMapTy &OpcodeInstMap,
     function_ref<bool(Instruction &)> Pred, const AbstractAttribute *QueryingAA,
-    const AAIsDead *LivenessAA, const ArrayRef<unsigned> &Opcodes,
+    const AAIsDead *LivenessAA, ArrayRef<unsigned> Opcodes,
     bool &UsedAssumedInformation, bool CheckBBLivenessOnly = false,
     bool CheckPotentiallyDead = false) {
   for (unsigned Opcode : Opcodes) {
@@ -1967,8 +2051,8 @@ static bool checkForAllInstructionsImpl(

 bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
                                          const Function *Fn,
-                                         const AbstractAttribute &QueryingAA,
-                                         const ArrayRef<unsigned> &Opcodes,
+                                         const AbstractAttribute *QueryingAA,
+                                         ArrayRef<unsigned> Opcodes,
                                          bool &UsedAssumedInformation,
                                          bool CheckBBLivenessOnly,
                                          bool CheckPotentiallyDead) {
@@ -1978,12 +2062,12 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
   const IRPosition &QueryIRP = IRPosition::function(*Fn);
   const auto *LivenessAA =
-      CheckPotentiallyDead
-          ? nullptr
-          : (getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE));
+      CheckPotentiallyDead && QueryingAA
+          ? (getAAFor<AAIsDead>(*QueryingAA, QueryIRP, DepClassTy::NONE))
+          : nullptr;

   auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(*Fn);
-  if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, &QueryingAA,
+  if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, QueryingAA,
                                    LivenessAA, Opcodes, UsedAssumedInformation,
                                    CheckBBLivenessOnly, CheckPotentiallyDead))
     return false;
@@ -1993,13 +2077,13 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,

 bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred,
                                          const AbstractAttribute &QueryingAA,
-                                         const ArrayRef<unsigned> &Opcodes,
+                                         ArrayRef<unsigned> Opcodes,
                                          bool &UsedAssumedInformation,
                                          bool CheckBBLivenessOnly,
                                          bool CheckPotentiallyDead) {
   const IRPosition &IRP = QueryingAA.getIRPosition();
   const Function *AssociatedFunction = IRP.getAssociatedFunction();
-  return checkForAllInstructions(Pred, AssociatedFunction, QueryingAA, Opcodes,
+  return checkForAllInstructions(Pred, AssociatedFunction, &QueryingAA, Opcodes,
                                  UsedAssumedInformation, CheckBBLivenessOnly,
                                  CheckPotentiallyDead);
 }
@@ -2964,6 +3048,18 @@ ChangeStatus Attributor::rewriteFunctionSignatures(
                                                NewArgumentAttributes));
     AttributeFuncs::updateMinLegalVectorWidthAttr(*NewFn, LargestVectorWidth);

+    // Remove argmem from the memory effects if we have no more pointer
+    // arguments, or they are readnone.
+    MemoryEffects ME = NewFn->getMemoryEffects();
+    int ArgNo = -1;
+    if (ME.doesAccessArgPointees() && all_of(NewArgumentTypes, [&](Type *T) {
+          ++ArgNo;
+          return !T->isPtrOrPtrVectorTy() ||
+                 NewFn->hasParamAttribute(ArgNo, Attribute::ReadNone);
+        })) {
+      NewFn->setMemoryEffects(ME - MemoryEffects::argMemOnly());
+    }
+
     // Since we have now created the new function, splice the body of the old
     // function right into the new function, leaving the old rotting hulk of the
     // function empty.
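Aside, a hedged sketch (not part of the patch): the memory-effects cleanup above can be read as a standalone helper; once no parameter can be a live pointer, argument-pointee effects are meaningless and can be dropped. Same public Function/MemoryEffects APIs, helper name illustrative:

    #include "llvm/IR/Function.h"
    #include "llvm/Support/ModRef.h"

    // Drop the argmem component of a function's memory effects when every
    // parameter is either a non-pointer or marked readnone.
    static void dropArgMemIfNoLivePointerArgs(llvm::Function &F) {
      llvm::MemoryEffects ME = F.getMemoryEffects();
      if (!ME.doesAccessArgPointees())
        return;
      for (llvm::Argument &Arg : F.args())
        if (Arg.getType()->isPtrOrPtrVectorTy() &&
            !F.hasParamAttribute(Arg.getArgNo(), llvm::Attribute::ReadNone))
          return;
      F.setMemoryEffects(ME - llvm::MemoryEffects::argMemOnly());
    }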
@@ -3203,6 +3299,12 @@ InformationCache::FunctionInfo::~FunctionInfo() {
     It.getSecond()->~InstructionVectorTy();
 }

+const ArrayRef<Function *>
+InformationCache::getIndirectlyCallableFunctions(Attributor &A) const {
+  assert(A.isClosedWorldModule() && "Cannot see all indirect callees!");
+  return IndirectlyCallableFunctions;
+}
+
 void Attributor::recordDependence(const AbstractAttribute &FromAA,
                                   const AbstractAttribute &ToAA,
                                   DepClassTy DepClass) {
@@ -3236,9 +3338,10 @@ void Attributor::checkAndQueryIRAttr(const IRPosition &IRP,
                                      AttributeSet Attrs) {
   bool IsKnown;
   if (!Attrs.hasAttribute(AK))
-    if (!AA::hasAssumedIRAttr<AK>(*this, nullptr, IRP, DepClassTy::NONE,
-                                  IsKnown))
-      getOrCreateAAFor<AAType>(IRP);
+    if (!Configuration.Allowed || Configuration.Allowed->count(&AAType::ID))
+      if (!AA::hasAssumedIRAttr<AK>(*this, nullptr, IRP, DepClassTy::NONE,
+                                    IsKnown))
+        getOrCreateAAFor<AAType>(IRP);
 }

 void Attributor::identifyDefaultAbstractAttributes(Function &F) {
@@ -3285,6 +3388,9 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
   // Every function might be "will-return".
   checkAndQueryIRAttr<Attribute::WillReturn, AAWillReturn>(FPos, FnAttrs);

+  // Every function might be marked "nosync"
+  checkAndQueryIRAttr<Attribute::NoSync, AANoSync>(FPos, FnAttrs);
+
   // Everything that is visible from the outside (=function, argument, return
   // positions), cannot be changed if the function is not IPO amendable. We can
   // however analyse the code inside.
@@ -3293,9 +3399,6 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
   // Every function can be nounwind.
   checkAndQueryIRAttr<Attribute::NoUnwind, AANoUnwind>(FPos, FnAttrs);

-  // Every function might be marked "nosync"
-  checkAndQueryIRAttr<Attribute::NoSync, AANoSync>(FPos, FnAttrs);
-
   // Every function might be "no-return".
   checkAndQueryIRAttr<Attribute::NoReturn, AANoReturn>(FPos, FnAttrs);

@@ -3315,6 +3418,14 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
   // Every function can track active assumptions.
   getOrCreateAAFor<AAAssumptionInfo>(FPos);

+  // If we're not using a dynamic mode for float, there's nothing worthwhile
+  // to infer. This misses the edge case denormal-fp-math="dynamic" and
+  // denormal-fp-math-f32=something, but that likely has no real world use.
+  DenormalMode Mode = F.getDenormalMode(APFloat::IEEEsingle());
+  if (Mode.Input == DenormalMode::Dynamic ||
+      Mode.Output == DenormalMode::Dynamic)
+    getOrCreateAAFor<AADenormalFPMath>(FPos);
+
   // Return attributes are only appropriate if the return type is non void.
   Type *ReturnType = F.getReturnType();
   if (!ReturnType->isVoidTy()) {
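Aside, a hedged restatement (not part of the patch): seeding AADenormalFPMath is gated on the function actually running in a dynamic denormal environment, so modules with fixed denormal modes pay nothing. The same gate as a standalone helper, assuming LLVM headers; the helper name is illustrative:

    #include "llvm/ADT/APFloat.h"
    #include "llvm/ADT/FloatingPointMode.h"
    #include "llvm/IR/Function.h"

    // Mirrors the gate above: only functions whose f32 denormal mode has a
    // dynamic input or output component are worth analyzing.
    static bool worthAnalyzingDenormals(const llvm::Function &F) {
      llvm::DenormalMode Mode =
          F.getDenormalMode(llvm::APFloat::IEEEsingle());
      return Mode.Input == llvm::DenormalMode::Dynamic ||
             Mode.Output == llvm::DenormalMode::Dynamic;
    }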
@@ -3420,8 +3531,10 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
     Function *Callee = dyn_cast_if_present<Function>(CB.getCalledOperand());
     // TODO: Even if the callee is not known now we might be able to simplify
     //       the call/callee.
-    if (!Callee)
+    if (!Callee) {
+      getOrCreateAAFor<AAIndirectCallInfo>(CBFnPos);
       return true;
+    }

     // Every call site can track active assumptions.
     getOrCreateAAFor<AAAssumptionInfo>(CBFnPos);
@@ -3498,14 +3611,13 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
   };

   auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
-  bool Success;
+  [[maybe_unused]] bool Success;
   bool UsedAssumedInformation = false;
   Success = checkForAllInstructionsImpl(
       nullptr, OpcodeInstMap, CallSitePred, nullptr, nullptr,
       {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
        (unsigned)Instruction::Call},
       UsedAssumedInformation);
-  (void)Success;
   assert(Success && "Expected the check call to be successful!");

   auto LoadStorePred = [&](Instruction &I) -> bool {
@@ -3531,10 +3643,26 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) {
       nullptr, OpcodeInstMap, LoadStorePred, nullptr, nullptr,
       {(unsigned)Instruction::Load, (unsigned)Instruction::Store},
       UsedAssumedInformation);
-  (void)Success;
+  assert(Success && "Expected the check call to be successful!");
+
+  // AllocaInstPredicate
+  auto AAAllocationInfoPred = [&](Instruction &I) -> bool {
+    getOrCreateAAFor<AAAllocationInfo>(IRPosition::value(I));
+    return true;
+  };
+
+  Success = checkForAllInstructionsImpl(
+      nullptr, OpcodeInstMap, AAAllocationInfoPred, nullptr, nullptr,
+      {(unsigned)Instruction::Alloca}, UsedAssumedInformation);
   assert(Success && "Expected the check call to be successful!");
 }

+bool Attributor::isClosedWorldModule() const {
+  if (CloseWorldAssumption.getNumOccurrences())
+    return CloseWorldAssumption;
+  return isModulePass() && Configuration.IsClosedWorldModule;
+}
+
 /// Helpers to ease debugging through output streams and print calls.
 ///
 ///{
@@ -3696,6 +3824,26 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache,
   AttributorConfig AC(CGUpdater);
   AC.IsModulePass = IsModulePass;
   AC.DeleteFns = DeleteFns;
+
+  /// Tracking callback for specialization of indirect calls.
+  DenseMap<CallBase *, std::unique_ptr<SmallPtrSet<Function *, 8>>>
+      IndirectCalleeTrackingMap;
+  if (MaxSpecializationPerCB.getNumOccurrences()) {
+    AC.IndirectCalleeSpecializationCallback =
+        [&](Attributor &, const AbstractAttribute &AA, CallBase &CB,
+            Function &Callee) {
+          if (MaxSpecializationPerCB == 0)
+            return false;
+          auto &Set = IndirectCalleeTrackingMap[&CB];
+          if (!Set)
+            Set = std::make_unique<SmallPtrSet<Function *, 8>>();
+          if (Set->size() >= MaxSpecializationPerCB)
+            return Set->contains(&Callee);
+          Set->insert(&Callee);
+          return true;
+        };
+  }
+
   Attributor A(Functions, InfoCache, AC);

   // Create shallow wrappers for all functions that are not IPO amendable
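Aside, a worked example (not part of the patch): the tracking callback above implements a per-call-base budget; once MaxSpecializationPerCB callees have been admitted for a call site, only already-admitted callees keep returning true. The same policy in plain C++, with std containers standing in for DenseMap/SmallPtrSet and all names hypothetical:

    #include <cstddef>
    #include <map>
    #include <set>

    struct SpecializationBudget {
      std::size_t Max;
      std::map<const void *, std::set<const void *>> Admitted;

      bool allow(const void *CallSite, const void *Callee) {
        if (Max == 0)
          return false;
        auto &Set = Admitted[CallSite];
        if (Set.size() >= Max)
          return Set.count(Callee) != 0; // only re-admissions pass
        Set.insert(Callee);
        return true;
      }
    };

    int main() {
      SpecializationBudget B{1, {}};
      const int CS = 0, F1 = 1, F2 = 2;
      bool A1 = B.allow(&CS, &F1); // true: first admission
      bool A2 = B.allow(&CS, &F2); // false: budget exhausted
      bool A3 = B.allow(&CS, &F1); // true: already admitted
      return (A1 && !A2 && A3) ? 0 : 1;
    }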
@@ -3759,6 +3907,88 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache,
   return Changed == ChangeStatus::CHANGED;
 }

+static bool runAttributorLightOnFunctions(InformationCache &InfoCache,
+                                          SetVector<Function *> &Functions,
+                                          AnalysisGetter &AG,
+                                          CallGraphUpdater &CGUpdater,
+                                          FunctionAnalysisManager &FAM,
+                                          bool IsModulePass) {
+  if (Functions.empty())
+    return false;
+
+  LLVM_DEBUG({
+    dbgs() << "[AttributorLight] Run on module with " << Functions.size()
+           << " functions:\n";
+    for (Function *Fn : Functions)
+      dbgs() << "  - " << Fn->getName() << "\n";
+  });
+
+  // Create an Attributor and initially empty information cache that is filled
+  // while we identify default attribute opportunities.
+  AttributorConfig AC(CGUpdater);
+  AC.IsModulePass = IsModulePass;
+  AC.DeleteFns = false;
+  DenseSet<const char *> Allowed(
+      {&AAWillReturn::ID, &AANoUnwind::ID, &AANoRecurse::ID, &AANoSync::ID,
+       &AANoFree::ID, &AANoReturn::ID, &AAMemoryLocation::ID,
+       &AAMemoryBehavior::ID, &AAUnderlyingObjects::ID, &AANoCapture::ID,
+       &AAInterFnReachability::ID, &AAIntraFnReachability::ID, &AACallEdges::ID,
+       &AANoFPClass::ID, &AAMustProgress::ID, &AANonNull::ID});
+  AC.Allowed = &Allowed;
+  AC.UseLiveness = false;
+
+  Attributor A(Functions, InfoCache, AC);
+
+  for (Function *F : Functions) {
+    if (F->hasExactDefinition())
+      NumFnWithExactDefinition++;
+    else
+      NumFnWithoutExactDefinition++;
+
+    // We look at internal functions only on-demand but if any use is not a
+    // direct call or outside the current set of analyzed functions, we have
+    // to do it eagerly.
+    if (F->hasLocalLinkage()) {
+      if (llvm::all_of(F->uses(), [&Functions](const Use &U) {
+            const auto *CB = dyn_cast<CallBase>(U.getUser());
+            return CB && CB->isCallee(&U) &&
+                   Functions.count(const_cast<Function *>(CB->getCaller()));
+          }))
+        continue;
+    }
+
+    // Populate the Attributor with abstract attribute opportunities in the
+    // function and the information cache with IR information.
+    A.identifyDefaultAbstractAttributes(*F);
+  }
+
+  ChangeStatus Changed = A.run();
+
+  if (Changed == ChangeStatus::CHANGED) {
+    // Invalidate analyses for modified functions so that we don't have to
+    // invalidate all analyses for all functions in this SCC.
+    PreservedAnalyses FuncPA;
+    // We haven't changed the CFG for modified functions.
+    FuncPA.preserveSet<CFGAnalyses>();
+    for (Function *Changed : A.getModifiedFunctions()) {
+      FAM.invalidate(*Changed, FuncPA);
+      // Also invalidate any direct callers of changed functions since analyses
+      // may care about attributes of direct callees. For example, MemorySSA
+      // cares about whether or not a call's callee modifies memory and queries
+      // that through function attributes.
+      for (auto *U : Changed->users()) {
+        if (auto *Call = dyn_cast<CallBase>(U)) {
+          if (Call->getCalledFunction() == Changed)
+            FAM.invalidate(*Call->getFunction(), FuncPA);
+        }
+      }
+    }
+  }
+  LLVM_DEBUG(dbgs() << "[Attributor] Done with " << Functions.size()
+                    << " functions, result: " << Changed << ".\n");
+  return Changed == ChangeStatus::CHANGED;
+}
+
 void AADepGraph::viewGraph() { llvm::ViewGraph(this, "Dependency Graph"); }

 void AADepGraph::dumpGraph() {
@@ -3839,6 +4069,62 @@ PreservedAnalyses AttributorCGSCCPass::run(LazyCallGraph::SCC &C,
   return PreservedAnalyses::all();
 }

+PreservedAnalyses AttributorLightPass::run(Module &M,
+                                           ModuleAnalysisManager &AM) {
+  FunctionAnalysisManager &FAM =
+      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  AnalysisGetter AG(FAM, /* CachedOnly */ true);
+
+  SetVector<Function *> Functions;
+  for (Function &F : M)
+    Functions.insert(&F);
+
+  CallGraphUpdater CGUpdater;
+  BumpPtrAllocator Allocator;
+  InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr);
+  if (runAttributorLightOnFunctions(InfoCache, Functions, AG, CGUpdater, FAM,
+                                    /* IsModulePass */ true)) {
+    PreservedAnalyses PA;
+    // We have not added or removed functions.
+    PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+    // We already invalidated all relevant function analyses above.
+    PA.preserveSet<AllAnalysesOn<Function>>();
+    return PA;
+  }
+  return PreservedAnalyses::all();
+}
+
+PreservedAnalyses AttributorLightCGSCCPass::run(LazyCallGraph::SCC &C,
+                                                CGSCCAnalysisManager &AM,
+                                                LazyCallGraph &CG,
+                                                CGSCCUpdateResult &UR) {
+  FunctionAnalysisManager &FAM =
+      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
+  AnalysisGetter AG(FAM);
+
+  SetVector<Function *> Functions;
+  for (LazyCallGraph::Node &N : C)
+    Functions.insert(&N.getFunction());
+
+  if (Functions.empty())
+    return PreservedAnalyses::all();
+
+  Module &M = *Functions.back()->getParent();
+  CallGraphUpdater CGUpdater;
+  CGUpdater.initialize(CG, C, AM, UR);
+  BumpPtrAllocator Allocator;
+  InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ &Functions);
+  if (runAttributorLightOnFunctions(InfoCache, Functions, AG, CGUpdater, FAM,
+                                    /* IsModulePass */ false)) {
+    PreservedAnalyses PA;
+    // We have not added or removed functions.
+    PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+    // We already invalidated all relevant function analyses above.
+    PA.preserveSet<AllAnalysesOn<Function>>();
+    return PA;
+  }
+  return PreservedAnalyses::all();
+}

 namespace llvm {
 template <> struct GraphTraits<AADepGraphNode *> {
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 3a9a89d61355..8e1f782f7cd8 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -55,6 +55,7 @@
 #include "llvm/IR/IntrinsicsAMDGPU.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/ValueHandle.h"
@@ -64,12 +65,16 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/GraphWriter.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/CallPromotionUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include <cassert>
 #include <numeric>
 #include <optional>
+#include <string>

 using namespace llvm;

@@ -188,6 +193,10 @@ PIPE_OPERATOR(AAPointerInfo)
 PIPE_OPERATOR(AAAssumptionInfo)
 PIPE_OPERATOR(AAUnderlyingObjects)
 PIPE_OPERATOR(AAAddressSpace)
+PIPE_OPERATOR(AAAllocationInfo)
+PIPE_OPERATOR(AAIndirectCallInfo)
+PIPE_OPERATOR(AAGlobalValueInfo)
+PIPE_OPERATOR(AADenormalFPMath)

 #undef PIPE_OPERATOR

@@ -281,20 +290,19 @@ static const Value *getPointerOperand(const Instruction *I,
   return nullptr;
 }

-/// Helper function to create a pointer of type \p ResTy, based on \p Ptr, and
-/// advanced by \p Offset bytes. To aid later analysis the method tries to build
+/// Helper function to create a pointer based on \p Ptr, and advanced by \p
+/// Offset bytes. To aid later analysis the method tries to build
 /// getelement pointer instructions that traverse the natural type of \p Ptr if
 /// possible. If that fails, the remaining offset is adjusted byte-wise, hence
 /// through a cast to i8*.
 ///
 /// TODO: This could probably live somewhere more prominantly if it doesn't
 ///       already exist.
-static Value *constructPointer(Type *ResTy, Type *PtrElemTy, Value *Ptr,
-                               int64_t Offset, IRBuilder<NoFolder> &IRB,
-                               const DataLayout &DL) {
+static Value *constructPointer(Type *PtrElemTy, Value *Ptr, int64_t Offset,
+                               IRBuilder<NoFolder> &IRB, const DataLayout &DL) {
   assert(Offset >= 0 && "Negative offset not supported yet!");
   LLVM_DEBUG(dbgs() << "Construct pointer: " << *Ptr << " + " << Offset
-                    << "-bytes as " << *ResTy << "\n");
+                    << "-bytes\n");

   if (Offset) {
     Type *Ty = PtrElemTy;
@@ -313,16 +321,11 @@ static Value *constructPointer(Type *ResTy, Type *PtrElemTy, Value *Ptr,

     // If an offset is left we use byte-wise adjustment.
     if (IntOffset != 0) {
-      Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy());
       Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(IntOffset),
                           GEPName + ".b" + Twine(IntOffset.getZExtValue()));
     }
   }

-  // Ensure the result has the requested type.
-  Ptr = IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, ResTy,
-                                                Ptr->getName() + ".cast");
-
   LLVM_DEBUG(dbgs() << "Constructed pointer: " << *Ptr << "\n");
   return Ptr;
 }
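Aside, a hedged sketch (not part of the patch): with opaque pointers the two removed casts in constructPointer are dead weight, since any remaining byte offset can be applied directly with an i8 GEP and the result needs no cast to a requested type. The idiom in isolation, standard IRBuilder, helper name illustrative:

    #include "llvm/IR/IRBuilder.h"

    // Advance Ptr by Offset bytes; no bitcast to i8* before the GEP and no
    // pointer cast afterwards are needed with opaque pointers.
    static llvm::Value *gepByteOffset(llvm::IRBuilder<> &IRB, llvm::Value *Ptr,
                                      int64_t Offset) {
      if (!Offset)
        return Ptr;
      return IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt64(Offset),
                           Ptr->getName() + ".b");
    }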
@@ -377,7 +380,7 @@ getMinimalBaseOfPointer(Attributor &A, const AbstractAttribute &QueryingAA,
 /// Clamp the information known for all returned values of a function
 /// (identified by \p QueryingAA) into \p S.
 template <typename AAType, typename StateType = typename AAType::StateType,
-          Attribute::AttrKind IRAttributeKind = Attribute::None,
+          Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind,
           bool RecurseForSelectAndPHI = true>
 static void clampReturnedValueStates(
     Attributor &A, const AAType &QueryingAA, StateType &S,
@@ -400,7 +403,7 @@ static void clampReturnedValueStates(
   auto CheckReturnValue = [&](Value &RV) -> bool {
     const IRPosition &RVPos = IRPosition::value(RV, CBContext);
     // If possible, use the hasAssumedIRAttr interface.
-    if (IRAttributeKind != Attribute::None) {
+    if (Attribute::isEnumAttrKind(IRAttributeKind)) {
       bool IsKnown;
       return AA::hasAssumedIRAttr<IRAttributeKind>(
           A, &QueryingAA, RVPos, DepClassTy::REQUIRED, IsKnown);
@@ -434,7 +437,7 @@ namespace {
 template <typename AAType, typename BaseType,
           typename StateType = typename BaseType::StateType,
           bool PropagateCallBaseContext = false,
-          Attribute::AttrKind IRAttributeKind = Attribute::None,
+          Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind,
           bool RecurseForSelectAndPHI = true>
 struct AAReturnedFromReturnedValues : public BaseType {
   AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A)
@@ -455,7 +458,7 @@ struct AAReturnedFromReturnedValues : public BaseType {
 /// Clamp the information known at all call sites for a given argument
 /// (identified by \p QueryingAA) into \p S.
 template <typename AAType, typename StateType = typename AAType::StateType,
-          Attribute::AttrKind IRAttributeKind = Attribute::None>
+          Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind>
 static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
                                         StateType &S) {
   LLVM_DEBUG(dbgs() << "[Attributor] Clamp call site argument states for "
@@ -480,7 +483,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
       return false;

     // If possible, use the hasAssumedIRAttr interface.
-    if (IRAttributeKind != Attribute::None) {
+    if (Attribute::isEnumAttrKind(IRAttributeKind)) {
       bool IsKnown;
       return AA::hasAssumedIRAttr<IRAttributeKind>(
           A, &QueryingAA, ACSArgPos, DepClassTy::REQUIRED, IsKnown);
@@ -514,7 +517,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
 /// context.
 template <typename AAType, typename BaseType,
           typename StateType = typename AAType::StateType,
-          Attribute::AttrKind IRAttributeKind = Attribute::None>
+          Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind>
 bool getArgumentStateFromCallBaseContext(Attributor &A,
                                          BaseType &QueryingAttribute,
                                          IRPosition &Pos, StateType &State) {
@@ -529,7 +532,7 @@ bool getArgumentStateFromCallBaseContext(Attributor &A,
   const IRPosition CBArgPos = IRPosition::callsite_argument(*CBContext, ArgNo);

   // If possible, use the hasAssumedIRAttr interface.
-  if (IRAttributeKind != Attribute::None) {
+  if (Attribute::isEnumAttrKind(IRAttributeKind)) {
     bool IsKnown;
     return AA::hasAssumedIRAttr<IRAttributeKind>(
         A, &QueryingAttribute, CBArgPos, DepClassTy::REQUIRED, IsKnown);
@@ -555,7 +558,7 @@ bool getArgumentStateFromCallBaseContext(Attributor &A,
 template <typename AAType, typename BaseType,
           typename StateType = typename AAType::StateType,
           bool BridgeCallBaseContext = false,
-          Attribute::AttrKind IRAttributeKind = Attribute::None>
+          Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind>
 struct AAArgumentFromCallSiteArguments : public BaseType {
   AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A)
       : BaseType(IRP, A) {}
@@ -585,45 +588,55 @@ struct AAArgumentFromCallSiteArguments : public BaseType {
 template <typename AAType, typename BaseType,
           typename StateType = typename BaseType::StateType,
           bool IntroduceCallBaseContext = false,
-          Attribute::AttrKind IRAttributeKind = Attribute::None>
-struct AACallSiteReturnedFromReturned : public BaseType {
-  AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A)
-      : BaseType(IRP, A) {}
+          Attribute::AttrKind IRAttributeKind = AAType::IRAttributeKind>
+struct AACalleeToCallSite : public BaseType {
+  AACalleeToCallSite(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {}

   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
-    assert(this->getIRPosition().getPositionKind() ==
-               IRPosition::IRP_CALL_SITE_RETURNED &&
-           "Can only wrap function returned positions for call site returned "
-           "positions!");
+    auto IRPKind = this->getIRPosition().getPositionKind();
+    assert((IRPKind == IRPosition::IRP_CALL_SITE_RETURNED ||
+            IRPKind == IRPosition::IRP_CALL_SITE) &&
+           "Can only wrap function returned positions for call site "
+           "returned positions!");
     auto &S = this->getState();

-    const Function *AssociatedFunction =
-        this->getIRPosition().getAssociatedFunction();
-    if (!AssociatedFunction)
-      return S.indicatePessimisticFixpoint();
-
-    CallBase &CBContext = cast<CallBase>(this->getAnchorValue());
+    CallBase &CB = cast<CallBase>(this->getAnchorValue());
     if (IntroduceCallBaseContext)
-      LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:"
-                        << CBContext << "\n");
-
-    IRPosition FnPos = IRPosition::returned(
-        *AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr);
+      LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" << CB
+                        << "\n");

-    // If possible, use the hasAssumedIRAttr interface.
-    if (IRAttributeKind != Attribute::None) {
-      bool IsKnown;
-      if (!AA::hasAssumedIRAttr<IRAttributeKind>(A, this, FnPos,
-                                                 DepClassTy::REQUIRED, IsKnown))
-        return S.indicatePessimisticFixpoint();
-      return ChangeStatus::UNCHANGED;
-    }
+    ChangeStatus Changed = ChangeStatus::UNCHANGED;
+    auto CalleePred = [&](ArrayRef<const Function *> Callees) {
+      for (const Function *Callee : Callees) {
+        IRPosition FnPos =
+            IRPKind == llvm::IRPosition::IRP_CALL_SITE_RETURNED
+                ? IRPosition::returned(*Callee,
+                                       IntroduceCallBaseContext ? &CB : nullptr)
+                : IRPosition::function(
+                      *Callee, IntroduceCallBaseContext ? &CB : nullptr);
+        // If possible, use the hasAssumedIRAttr interface.
+        if (Attribute::isEnumAttrKind(IRAttributeKind)) {
+          bool IsKnown;
+          if (!AA::hasAssumedIRAttr<IRAttributeKind>(
+                  A, this, FnPos, DepClassTy::REQUIRED, IsKnown))
+            return false;
+          continue;
+        }

-    const AAType *AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
-    if (!AA)
+        const AAType *AA =
+            A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
+        if (!AA)
+          return false;
+        Changed |= clampStateAndIndicateChange(S, AA->getState());
+        if (S.isAtFixpoint())
+          return S.isValidState();
+      }
+      return true;
+    };
+    if (!A.checkForAllCallees(CalleePred, *this, CB))
       return S.indicatePessimisticFixpoint();
-    return clampStateAndIndicateChange(S, AA->getState());
+    return Changed;
   }
 };
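Aside, a plain-C++ sketch (not part of the patch) of what AACalleeToCallSite does: every "copy from the callee" AA becomes a meet over all possible callees, clamping each callee's state into the call-site state, with an unknown callee forcing the pessimistic fixpoint. All types below are hypothetical:

    #include <vector>

    // Two-point lattice: Assumed starts optimistic and can only be weakened,
    // like clampStateAndIndicateChange in the real code.
    struct BoolState {
      bool Assumed = true;
      void clampWith(bool CalleeAssumed) { Assumed &= CalleeAssumed; }
    };

    struct Callee { bool AssumedNoUnwind; };

    static bool updateCallSite(BoolState &S, const std::vector<Callee> &Callees,
                               bool HasUnknownCallee) {
      if (HasUnknownCallee)
        return false; // indicatePessimisticFixpoint() in the real code.
      for (const Callee &C : Callees)
        S.clampWith(C.AssumedNoUnwind);
      return true;
    }

    int main() {
      BoolState S;
      updateCallSite(S, {{true}, {true}}, /*HasUnknownCallee=*/false);
      return S.Assumed ? 0 : 1;
    }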
@@ -865,11 +878,9 @@ struct AA::PointerInfo::State : public AbstractState {
                       AAPointerInfo::AccessKind Kind, Type *Ty,
                       Instruction *RemoteI = nullptr);

-  using OffsetBinsTy = DenseMap<RangeTy, SmallSet<unsigned, 4>>;
-
-  using const_bin_iterator = OffsetBinsTy::const_iterator;
-  const_bin_iterator begin() const { return OffsetBins.begin(); }
-  const_bin_iterator end() const { return OffsetBins.end(); }
+  AAPointerInfo::const_bin_iterator begin() const { return OffsetBins.begin(); }
+  AAPointerInfo::const_bin_iterator end() const { return OffsetBins.end(); }
+  int64_t numOffsetBins() const { return OffsetBins.size(); }

   const AAPointerInfo::Access &getAccess(unsigned Index) const {
     return AccessList[Index];
@@ -889,7 +900,7 @@ protected:
   // are all combined into a single Access object. This may result in loss of
   // information in RangeTy in the Access object.
   SmallVector<AAPointerInfo::Access> AccessList;
-  OffsetBinsTy OffsetBins;
+  AAPointerInfo::OffsetBinsTy OffsetBins;
   DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap;

   /// See AAPointerInfo::forallInterferingAccesses.
@@ -1093,6 +1104,12 @@ struct AAPointerInfoImpl
     return AAPointerInfo::manifest(A);
   }

+  virtual const_bin_iterator begin() const override { return State::begin(); }
+  virtual const_bin_iterator end() const override { return State::end(); }
+  virtual int64_t numOffsetBins() const override {
+    return State::numOffsetBins();
+  }
+
   bool forallInterferingAccesses(
       AA::RangeTy Range,
       function_ref<bool(const AAPointerInfo::Access &, bool)> CB)
@@ -1104,7 +1121,8 @@ struct AAPointerInfoImpl
       Attributor &A, const AbstractAttribute &QueryingAA, Instruction &I,
       bool FindInterferingWrites, bool FindInterferingReads,
       function_ref<bool(const Access &, bool)> UserCB, bool &HasBeenWrittenTo,
-      AA::RangeTy &Range) const override {
+      AA::RangeTy &Range,
+      function_ref<bool(const Access &)> SkipCB) const override {
     HasBeenWrittenTo = false;

     SmallPtrSet<const Access *, 8> DominatingWrites;
@@ -1183,6 +1201,11 @@ struct AAPointerInfoImpl
         A, this, IRPosition::function(Scope), DepClassTy::OPTIONAL,
         IsKnownNoRecurse);

+    // TODO: Use reaching kernels from AAKernelInfo (or move it to
+    // AAExecutionDomain) such that we allow scopes other than kernels as long
+    // as the reaching kernels are disjoint.
+    bool InstInKernel = Scope.hasFnAttribute("kernel");
+    bool ObjHasKernelLifetime = false;
     const bool UseDominanceReasoning =
         FindInterferingWrites && IsKnownNoRecurse;
     const DominatorTree *DT =
@@ -1215,6 +1238,7 @@ struct AAPointerInfoImpl
       // If the alloca containing function is not recursive the alloca
       // must be dead in the callee.
       const Function *AIFn = AI->getFunction();
+      ObjHasKernelLifetime = AIFn->hasFnAttribute("kernel");
       bool IsKnownNoRecurse;
       if (AA::hasAssumedIRAttr<Attribute::NoRecurse>(
               A, this, IRPosition::function(*AIFn), DepClassTy::OPTIONAL,
@@ -1224,7 +1248,8 @@ struct AAPointerInfoImpl
     } else if (auto *GV = dyn_cast<GlobalValue>(&getAssociatedValue())) {
       // If the global has kernel lifetime we can stop if we reach a kernel
       // as it is "dead" in the (unknown) callees.
-      if (HasKernelLifetime(GV, *GV->getParent()))
+      ObjHasKernelLifetime = HasKernelLifetime(GV, *GV->getParent());
+      if (ObjHasKernelLifetime)
         IsLiveInCalleeCB = [](const Function &Fn) {
           return !Fn.hasFnAttribute("kernel");
         };
@@ -1235,6 +1260,15 @@ struct AAPointerInfoImpl
     AA::InstExclusionSetTy ExclusionSet;

     auto AccessCB = [&](const Access &Acc, bool Exact) {
+      Function *AccScope = Acc.getRemoteInst()->getFunction();
+      bool AccInSameScope = AccScope == &Scope;
+
+      // If the object has kernel lifetime we can ignore accesses only reachable
+      // by other kernels. For now we only skip accesses *in* other kernels.
+      if (InstInKernel && ObjHasKernelLifetime && !AccInSameScope &&
+          AccScope->hasFnAttribute("kernel"))
+        return true;
+
       if (Exact && Acc.isMustAccess() && Acc.getRemoteInst() != &I) {
         if (Acc.isWrite() || (isa<LoadInst>(I) && Acc.isWriteOrAssumption()))
           ExclusionSet.insert(Acc.getRemoteInst());
@@ -1245,8 +1279,7 @@ struct AAPointerInfoImpl
         return true;

       bool Dominates = FindInterferingWrites && DT && Exact &&
-                       Acc.isMustAccess() &&
-                       (Acc.getRemoteInst()->getFunction() == &Scope) &&
+                       Acc.isMustAccess() && AccInSameScope &&
                        DT->dominates(Acc.getRemoteInst(), &I);
       if (Dominates)
         DominatingWrites.insert(&Acc);
@@ -1276,6 +1309,8 @@ struct AAPointerInfoImpl

     // Helper to determine if we can skip a specific write access.
     auto CanSkipAccess = [&](const Access &Acc, bool Exact) {
+      if (SkipCB && SkipCB(Acc))
+        return true;
       if (!CanIgnoreThreading(Acc))
         return false;
@@ -1817,9 +1852,14 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
         LLVM_DEBUG(dbgs() << "[AAPointerInfo] Assumption found "
                           << *Assumption.second << ": " << *LoadI
                           << " == " << *Assumption.first << "\n");
-
+        bool UsedAssumedInformation = false;
+        std::optional<Value *> Content = nullptr;
+        if (Assumption.first)
+          Content =
+              A.getAssumedSimplified(*Assumption.first, *this,
+                                     UsedAssumedInformation, AA::Interprocedural);
         return handleAccess(
-            A, *Assumption.second, Assumption.first, AccessKind::AK_ASSUMPTION,
+            A, *Assumption.second, Content, AccessKind::AK_ASSUMPTION,
             OffsetInfoMap[CurPtr].Offsets, Changed, *LoadI->getType());
       }
@@ -2083,24 +2123,10 @@ struct AANoUnwindFunction final : public AANoUnwindImpl {
 };

 /// NoUnwind attribute deduction for a call sites.
-struct AANoUnwindCallSite final : AANoUnwindImpl {
+struct AANoUnwindCallSite final
+    : AACalleeToCallSite<AANoUnwind, AANoUnwindImpl> {
   AANoUnwindCallSite(const IRPosition &IRP, Attributor &A)
-      : AANoUnwindImpl(IRP, A) {}
-
-  /// See AbstractAttribute::updateImpl(...).
-  ChangeStatus updateImpl(Attributor &A) override {
-    // TODO: Once we have call site specific value information we can provide
-    //       call site specific liveness information and then it makes
-    //       sense to specialize attributes for call sites arguments instead of
-    //       redirecting requests to the callee argument.
-    Function *F = getAssociatedFunction();
-    const IRPosition &FnPos = IRPosition::function(*F);
-    bool IsKnownNoUnwind;
-    if (AA::hasAssumedIRAttr<Attribute::NoUnwind>(
-            A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoUnwind))
-      return ChangeStatus::UNCHANGED;
-    return indicatePessimisticFixpoint();
-  }
+      : AACalleeToCallSite<AANoUnwind, AANoUnwindImpl>(IRP, A) {}

   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); }
@@ -2200,8 +2226,15 @@ ChangeStatus AANoSyncImpl::updateImpl(Attributor &A) {
     if (I.mayReadOrWriteMemory())
       return true;

+    bool IsKnown;
+    CallBase &CB = cast<CallBase>(I);
+    if (AA::hasAssumedIRAttr<Attribute::NoSync>(
+            A, this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL,
+            IsKnown))
+      return true;
+
     // non-convergent and readnone imply nosync.
-    return !cast<CallBase>(I).isConvergent();
+    return !CB.isConvergent();
   };

   bool UsedAssumedInformation = false;
@@ -2223,24 +2256,9 @@ struct AANoSyncFunction final : public AANoSyncImpl {
 };

 /// NoSync attribute deduction for a call sites.
-struct AANoSyncCallSite final : AANoSyncImpl {
+struct AANoSyncCallSite final : AACalleeToCallSite<AANoSync, AANoSyncImpl> {
   AANoSyncCallSite(const IRPosition &IRP, Attributor &A)
-      : AANoSyncImpl(IRP, A) {}
-
-  /// See AbstractAttribute::updateImpl(...).
-  ChangeStatus updateImpl(Attributor &A) override {
-    // TODO: Once we have call site specific value information we can provide
-    //       call site specific liveness information and then it makes
-    //       sense to specialize attributes for call sites arguments instead of
-    //       redirecting requests to the callee argument.
-    Function *F = getAssociatedFunction();
-    const IRPosition &FnPos = IRPosition::function(*F);
-    bool IsKnownNoSycn;
-    if (AA::hasAssumedIRAttr<Attribute::NoSync>(
-            A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoSycn))
-      return ChangeStatus::UNCHANGED;
-    return indicatePessimisticFixpoint();
-  }
+      : AACalleeToCallSite<AANoSync, AANoSyncImpl>(IRP, A) {}

   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nosync); }
@@ -2292,24 +2310,9 @@ struct AANoFreeFunction final : public AANoFreeImpl {
 };

 /// NoFree attribute deduction for a call sites.
-struct AANoFreeCallSite final : AANoFreeImpl {
+struct AANoFreeCallSite final : AACalleeToCallSite<AANoFree, AANoFreeImpl> {
   AANoFreeCallSite(const IRPosition &IRP, Attributor &A)
-      : AANoFreeImpl(IRP, A) {}
-
-  /// See AbstractAttribute::updateImpl(...).
-  ChangeStatus updateImpl(Attributor &A) override {
-    // TODO: Once we have call site specific value information we can provide
-    //       call site specific liveness information and then it makes
-    //       sense to specialize attributes for call sites arguments instead of
-    //       redirecting requests to the callee argument.
-    Function *F = getAssociatedFunction();
-    const IRPosition &FnPos = IRPosition::function(*F);
-    bool IsKnown;
-    if (AA::hasAssumedIRAttr<Attribute::NoFree>(A, this, FnPos,
-                                                DepClassTy::REQUIRED, IsKnown))
-      return ChangeStatus::UNCHANGED;
-    return indicatePessimisticFixpoint();
-  }
+      : AACalleeToCallSite<AANoFree, AANoFreeImpl>(IRP, A) {}

   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nofree); }
@@ -2450,9 +2453,6 @@ bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP,
   if (A.hasAttr(IRP, AttrKinds, IgnoreSubsumingPositions, Attribute::NonNull))
     return true;

-  if (IRP.getPositionKind() == IRP_RETURNED)
-    return false;
-
   DominatorTree *DT = nullptr;
   AssumptionCache *AC = nullptr;
   InformationCache &InfoCache = A.getInfoCache();
@@ -2463,9 +2463,27 @@ bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP,
     }
   }

-  if (!isKnownNonZero(&IRP.getAssociatedValue(), A.getDataLayout(), 0, AC,
-                      IRP.getCtxI(), DT))
+  SmallVector<AA::ValueAndContext> Worklist;
+  if (IRP.getPositionKind() != IRP_RETURNED) {
+    Worklist.push_back({IRP.getAssociatedValue(), IRP.getCtxI()});
+  } else {
+    bool UsedAssumedInformation = false;
+    if (!A.checkForAllInstructions(
+            [&](Instruction &I) {
+              Worklist.push_back({*cast<ReturnInst>(I).getReturnValue(), &I});
+              return true;
+            },
+            IRP.getAssociatedFunction(), nullptr, {Instruction::Ret},
+            UsedAssumedInformation))
+      return false;
+  }
+
+  if (llvm::any_of(Worklist, [&](AA::ValueAndContext VAC) {
+        return !isKnownNonZero(VAC.getValue(), A.getDataLayout(), 0, AC,
+                               VAC.getCtxI(), DT);
+      }))
     return false;
+
   A.manifestAttrs(IRP, {Attribute::get(IRP.getAnchorValue().getContext(),
                                        Attribute::NonNull)});
   return true;
@@ -2529,7 +2547,8 @@ static int64_t getKnownNonNullAndDerefBytesForUse(
   }

   std::optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I);
-  if (!Loc || Loc->Ptr != UseV || !Loc->Size.isPrecise() || I->isVolatile())
+  if (!Loc || Loc->Ptr != UseV || !Loc->Size.isPrecise() ||
+      Loc->Size.isScalable() || I->isVolatile())
     return 0;

   int64_t Offset;
@@ -2610,6 +2629,23 @@ struct AANonNullFloating : public AANonNullImpl {
         Values.size() != 1 || Values.front().getValue() != AssociatedValue;

     if (!Stripped) {
+      bool IsKnown;
+      if (auto *PHI = dyn_cast<PHINode>(AssociatedValue))
+        if (llvm::all_of(PHI->incoming_values(), [&](Value *Op) {
+              return AA::hasAssumedIRAttr<Attribute::NonNull>(
+                  A, this, IRPosition::value(*Op), DepClassTy::OPTIONAL,
+                  IsKnown);
+            }))
+          return ChangeStatus::UNCHANGED;
+      if (auto *Select = dyn_cast<SelectInst>(AssociatedValue))
+        if (AA::hasAssumedIRAttr<Attribute::NonNull>(
+                A, this, IRPosition::value(*Select->getFalseValue()),
+                DepClassTy::OPTIONAL, IsKnown) &&
+            AA::hasAssumedIRAttr<Attribute::NonNull>(
+                A, this, IRPosition::value(*Select->getTrueValue()),
+                DepClassTy::OPTIONAL, IsKnown))
+          return ChangeStatus::UNCHANGED;
+
       // If we haven't stripped anything we might still be able to use a
       // different AA, but only if the IRP changes. Effectively when we
       // interpret this not as a call site value but as a floating/argument
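Aside, a hedged restatement (not part of the patch): the returned position now funnels every return value through the same isKnownNonZero check via a worklist. Stripped of the Attributor worklist and assumption machinery, the check looks roughly like this; LLVM headers assumed, helper name illustrative:

    #include "llvm/Analysis/ValueTracking.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Instructions.h"

    // True if every return value of F is known non-zero at its return site.
    static bool allReturnsKnownNonNull(const llvm::Function &F,
                                       const llvm::DataLayout &DL) {
      for (const llvm::BasicBlock &BB : F)
        if (auto *RI = llvm::dyn_cast<llvm::ReturnInst>(BB.getTerminator())) {
          llvm::Value *RV = RI->getReturnValue();
          if (!RV) // void return, nothing to prove
            continue;
          if (!llvm::isKnownNonZero(RV, DL, /*Depth=*/0, /*AC=*/nullptr,
                                    /*CxtI=*/RI, /*DT=*/nullptr))
            return false;
        }
      return true;
    }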
@@ -2634,10 +2670,11 @@ struct AANonNullFloating : public AANonNullImpl {
 /// NonNull attribute for function return value.
 struct AANonNullReturned final
     : AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType,
-                                   false, AANonNull::IRAttributeKind> {
+                                   false, AANonNull::IRAttributeKind, false> {
   AANonNullReturned(const IRPosition &IRP, Attributor &A)
       : AAReturnedFromReturnedValues<AANonNull, AANonNull, AANonNull::StateType,
-                                     false, Attribute::NonNull>(IRP, A) {}
+                                     false, Attribute::NonNull, false>(IRP, A) {
+  }

   /// See AbstractAttribute::getAsStr().
   const std::string getAsStr(Attributor *A) const override {
@@ -2650,13 +2687,9 @@ struct AANonNullReturned final

 /// NonNull attribute for function argument.
 struct AANonNullArgument final
-    : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl,
-                                      AANonNull::StateType, false,
-                                      AANonNull::IRAttributeKind> {
+    : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl> {
   AANonNullArgument(const IRPosition &IRP, Attributor &A)
-      : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl,
-                                        AANonNull::StateType, false,
-                                        AANonNull::IRAttributeKind>(IRP, A) {}
+      : AAArgumentFromCallSiteArguments<AANonNull, AANonNullImpl>(IRP, A) {}

   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(nonnull) }
@@ -2672,13 +2705,9 @@ struct AANonNullCallSiteArgument final : AANonNullFloating {

 /// NonNull attribute for a call site return position.
 struct AANonNullCallSiteReturned final
-    : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl,
-                                     AANonNull::StateType, false,
-                                     AANonNull::IRAttributeKind> {
+    : AACalleeToCallSite<AANonNull, AANonNullImpl> {
   AANonNullCallSiteReturned(const IRPosition &IRP, Attributor &A)
-      : AACallSiteReturnedFromReturned<AANonNull, AANonNullImpl,
-                                       AANonNull::StateType, false,
-                                       AANonNull::IRAttributeKind>(IRP, A) {}
+      : AACalleeToCallSite<AANonNull, AANonNullImpl>(IRP, A) {}

   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nonnull) }
@@ -2830,24 +2859,10 @@ struct AANoRecurseFunction final : AANoRecurseImpl {
 };

 /// NoRecurse attribute deduction for a call sites.
-struct AANoRecurseCallSite final : AANoRecurseImpl {
+struct AANoRecurseCallSite final
+    : AACalleeToCallSite<AANoRecurse, AANoRecurseImpl> {
   AANoRecurseCallSite(const IRPosition &IRP, Attributor &A)
-      : AANoRecurseImpl(IRP, A) {}
-
-  /// See AbstractAttribute::updateImpl(...).
-  ChangeStatus updateImpl(Attributor &A) override {
-    // TODO: Once we have call site specific value information we can provide
-    //       call site specific liveness information and then it makes
-    //       sense to specialize attributes for call sites arguments instead of
-    //       redirecting requests to the callee argument.
-    Function *F = getAssociatedFunction();
-    const IRPosition &FnPos = IRPosition::function(*F);
-    bool IsKnownNoRecurse;
-    if (!AA::hasAssumedIRAttr<Attribute::NoRecurse>(
-            A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoRecurse))
-      return indicatePessimisticFixpoint();
-    return ChangeStatus::UNCHANGED;
-  }
+      : AACalleeToCallSite<AANoRecurse, AANoRecurseImpl>(IRP, A) {}

   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(norecurse); }
@@ -3355,26 +3370,17 @@ struct AAWillReturnFunction final : AAWillReturnImpl {
 };

 /// WillReturn attribute deduction for a call sites.
-struct AAWillReturnCallSite final : AAWillReturnImpl {
+struct AAWillReturnCallSite final
+    : AACalleeToCallSite<AAWillReturn, AAWillReturnImpl> {
   AAWillReturnCallSite(const IRPosition &IRP, Attributor &A)
-      : AAWillReturnImpl(IRP, A) {}
+      : AACalleeToCallSite<AAWillReturn, AAWillReturnImpl>(IRP, A) {}

   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     if (isImpliedByMustprogressAndReadonly(A, /* KnownOnly */ false))
       return ChangeStatus::UNCHANGED;

-    // TODO: Once we have call site specific value information we can provide
-    //       call site specific liveness information and then it makes
-    //       sense to specialize attributes for call sites arguments instead of
-    //       redirecting requests to the callee argument.
-    Function *F = getAssociatedFunction();
-    const IRPosition &FnPos = IRPosition::function(*F);
-    bool IsKnown;
-    if (AA::hasAssumedIRAttr<Attribute::WillReturn>(
-            A, this, FnPos, DepClassTy::REQUIRED, IsKnown))
-      return ChangeStatus::UNCHANGED;
-    return indicatePessimisticFixpoint();
+    return AACalleeToCallSite::updateImpl(A);
   }

   /// See AbstractAttribute::trackStatistics()
@@ -3402,6 +3408,18 @@ template <typename ToTy> struct ReachabilityQueryInfo {
   /// and remember if it worked:
   Reachable Result = Reachable::No;

+  /// Precomputed hash for this RQI.
+  unsigned Hash = 0;
+
+  unsigned computeHashValue() const {
+    assert(Hash == 0 && "Computed hash twice!");
+    using InstSetDMI = DenseMapInfo<const AA::InstExclusionSetTy *>;
+    using PairDMI = DenseMapInfo<std::pair<const Instruction *, const ToTy *>>;
+    return const_cast<ReachabilityQueryInfo<ToTy> *>(this)->Hash =
+        detail::combineHashValue(PairDMI ::getHashValue({From, To}),
+                                 InstSetDMI::getHashValue(ExclusionSet));
+  }
+
   ReachabilityQueryInfo(const Instruction *From, const ToTy *To)
       : From(From), To(To) {}
@@ -3435,9 +3453,7 @@ template <typename ToTy> struct DenseMapInfo<ReachabilityQueryInfo<ToTy> *> {
     return &TombstoneKey;
   }
   static unsigned getHashValue(const ReachabilityQueryInfo<ToTy> *RQI) {
-    unsigned H = PairDMI ::getHashValue({RQI->From, RQI->To});
-    H += InstSetDMI::getHashValue(RQI->ExclusionSet);
-    return H;
+    return RQI->Hash ? RQI->Hash : RQI->computeHashValue();
   }
   static bool isEqual(const ReachabilityQueryInfo<ToTy> *LHS,
                       const ReachabilityQueryInfo<ToTy> *RHS) {
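Aside, a worked example (not part of the patch): precomputing the hash pays off because a cached reachability query is hashed many times, and the patch uses Hash == 0 as the "not yet computed" sentinel (a genuine hash of 0 merely recomputes on every lookup). The same trick in portable C++ with a hypothetical type:

    #include <cstdio>
    #include <functional>

    struct Query {
      const void *From = nullptr;
      const void *To = nullptr;
      // 0 doubles as "not yet computed", like RQI->Hash above.
      mutable std::size_t Hash = 0;

      std::size_t hash() const {
        if (!Hash)
          Hash = std::hash<const void *>{}(From) ^
                 (std::hash<const void *>{}(To) << 1);
        return Hash;
      }
    };

    int main() {
      int A = 0, B = 0;
      Query Q{&A, &B};
      std::printf("%zx %zx\n", Q.hash(), Q.hash()); // second call is cached
      return 0;
    }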
- if (!InUpdate && Result != RQITy::Reachable::Yes && UsedExclusionSet) { + if (IsTemporaryRQI && Result != RQITy::Reachable::Yes && UsedExclusionSet) { assert((!RQI.ExclusionSet || !RQI.ExclusionSet->empty()) && "Did not expect empty set!"); RQITy *RQIPtr = new (A.Allocator) @@ -3527,7 +3543,7 @@ struct CachedReachabilityAA : public BaseTy { QueryCache.insert(RQIPtr); } - if (Result == RQITy::Reachable::No && !InUpdate) + if (Result == RQITy::Reachable::No && IsTemporaryRQI) A.registerForUpdate(*this); return Result == RQITy::Reachable::Yes; } @@ -3568,7 +3584,6 @@ struct CachedReachabilityAA : public BaseTy { } private: - bool InUpdate = false; SmallVector<RQITy *> QueryVector; DenseSet<RQITy *> QueryCache; }; @@ -3577,7 +3592,10 @@ struct AAIntraFnReachabilityFunction final : public CachedReachabilityAA<AAIntraFnReachability, Instruction> { using Base = CachedReachabilityAA<AAIntraFnReachability, Instruction>; AAIntraFnReachabilityFunction(const IRPosition &IRP, Attributor &A) - : Base(IRP, A) {} + : Base(IRP, A) { + DT = A.getInfoCache().getAnalysisResultForFunction<DominatorTreeAnalysis>( + *IRP.getAssociatedFunction()); + } bool isAssumedReachable( Attributor &A, const Instruction &From, const Instruction &To, @@ -3589,7 +3607,8 @@ struct AAIntraFnReachabilityFunction final RQITy StackRQI(A, From, To, ExclusionSet, false); typename RQITy::Reachable Result; if (!NonConstThis->checkQueryCache(A, StackRQI, Result)) - return NonConstThis->isReachableImpl(A, StackRQI); + return NonConstThis->isReachableImpl(A, StackRQI, + /*IsTemporaryRQI=*/true); return Result == RQITy::Reachable::Yes; } @@ -3598,16 +3617,24 @@ struct AAIntraFnReachabilityFunction final // of them changed. auto *LivenessAA = A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); - if (LivenessAA && llvm::all_of(DeadEdges, [&](const auto &DeadEdge) { - return LivenessAA->isEdgeDead(DeadEdge.first, DeadEdge.second); + if (LivenessAA && + llvm::all_of(DeadEdges, + [&](const auto &DeadEdge) { + return LivenessAA->isEdgeDead(DeadEdge.first, + DeadEdge.second); + }) && + llvm::all_of(DeadBlocks, [&](const BasicBlock *BB) { + return LivenessAA->isAssumedDead(BB); })) { return ChangeStatus::UNCHANGED; } DeadEdges.clear(); + DeadBlocks.clear(); return Base::updateImpl(A); } - bool isReachableImpl(Attributor &A, RQITy &RQI) override { + bool isReachableImpl(Attributor &A, RQITy &RQI, + bool IsTemporaryRQI) override { const Instruction *Origin = RQI.From; bool UsedExclusionSet = false; @@ -3633,31 +3660,41 @@ struct AAIntraFnReachabilityFunction final // possible. if (FromBB == ToBB && WillReachInBlock(*RQI.From, *RQI.To, RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet, + IsTemporaryRQI); // Check if reaching the ToBB block is sufficient or if even that would not // ensure reaching the target. In the latter case we are done. if (!WillReachInBlock(ToBB->front(), *RQI.To, RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet, + IsTemporaryRQI); + const Function *Fn = FromBB->getParent(); SmallPtrSet<const BasicBlock *, 16> ExclusionBlocks; if (RQI.ExclusionSet) for (auto *I : *RQI.ExclusionSet) - ExclusionBlocks.insert(I->getParent()); + if (I->getFunction() == Fn) + ExclusionBlocks.insert(I->getParent()); // Check if we make it out of the FromBB block at all. 
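 // Exclusion instructions from other functions were filtered out above;
 // they cannot block an intraprocedural walk in this function.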
if (ExclusionBlocks.count(FromBB) && !WillReachInBlock(*RQI.From, *FromBB->getTerminator(), RQI.ExclusionSet)) - return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::No, RQI, true, IsTemporaryRQI); + + auto *LivenessAA = + A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (LivenessAA && LivenessAA->isAssumedDead(ToBB)) { + DeadBlocks.insert(ToBB); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet, + IsTemporaryRQI); + } SmallPtrSet<const BasicBlock *, 16> Visited; SmallVector<const BasicBlock *, 16> Worklist; Worklist.push_back(FromBB); DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> LocalDeadEdges; - auto *LivenessAA = - A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); while (!Worklist.empty()) { const BasicBlock *BB = Worklist.pop_back_val(); if (!Visited.insert(BB).second) @@ -3669,8 +3706,12 @@ struct AAIntraFnReachabilityFunction final } // We checked before if we just need to reach the ToBB block. if (SuccBB == ToBB) - return rememberResult(A, RQITy::Reachable::Yes, RQI, - UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet, + IsTemporaryRQI); + if (DT && ExclusionBlocks.empty() && DT->dominates(BB, ToBB)) + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet, + IsTemporaryRQI); + if (ExclusionBlocks.count(SuccBB)) { UsedExclusionSet = true; continue; @@ -3680,16 +3721,24 @@ struct AAIntraFnReachabilityFunction final } DeadEdges.insert(LocalDeadEdges.begin(), LocalDeadEdges.end()); - return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet, + IsTemporaryRQI); } /// See AbstractAttribute::trackStatistics() void trackStatistics() const override {} private: + // Set of assumed dead blocks we used in the last query. If any changes we + // update the state. + DenseSet<const BasicBlock *> DeadBlocks; + // Set of assumed dead edges we used in the last query. If any changes we // update the state. DenseSet<std::pair<const BasicBlock *, const BasicBlock *>> DeadEdges; + + /// The dominator tree of the function to short-circuit reasoning. + const DominatorTree *DT = nullptr; }; } // namespace @@ -3754,12 +3803,8 @@ struct AANoAliasFloating final : AANoAliasImpl { /// NoAlias attribute for an argument. struct AANoAliasArgument final - : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl, - AANoAlias::StateType, false, - Attribute::NoAlias> { - using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl, - AANoAlias::StateType, false, - Attribute::NoAlias>; + : AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl> { + using Base = AAArgumentFromCallSiteArguments<AANoAlias, AANoAliasImpl>; AANoAliasArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} /// See AbstractAttribute::update(...). @@ -4027,24 +4072,10 @@ struct AANoAliasReturned final : AANoAliasImpl { }; /// NoAlias attribute deduction for a call site return value. -struct AANoAliasCallSiteReturned final : AANoAliasImpl { +struct AANoAliasCallSiteReturned final + : AACalleeToCallSite<AANoAlias, AANoAliasImpl> { AANoAliasCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AANoAliasImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). 
- ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::returned(*F); - bool IsKnownNoAlias; - if (!AA::hasAssumedIRAttr<Attribute::NoAlias>( - A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoAlias)) - return indicatePessimisticFixpoint(); - return ChangeStatus::UNCHANGED; - } + : AACalleeToCallSite<AANoAlias, AANoAliasImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noalias); } @@ -4696,23 +4727,53 @@ identifyAliveSuccessors(Attributor &A, const SwitchInst &SI, AbstractAttribute &AA, SmallVectorImpl<const Instruction *> &AliveSuccessors) { bool UsedAssumedInformation = false; - std::optional<Constant *> C = - A.getAssumedConstant(*SI.getCondition(), AA, UsedAssumedInformation); - if (!C || isa_and_nonnull<UndefValue>(*C)) { - // No value yet, assume all edges are dead. - } else if (isa_and_nonnull<ConstantInt>(*C)) { - for (const auto &CaseIt : SI.cases()) { - if (CaseIt.getCaseValue() == *C) { - AliveSuccessors.push_back(&CaseIt.getCaseSuccessor()->front()); - return UsedAssumedInformation; - } - } - AliveSuccessors.push_back(&SI.getDefaultDest()->front()); + SmallVector<AA::ValueAndContext> Values; + if (!A.getAssumedSimplifiedValues(IRPosition::value(*SI.getCondition()), &AA, + Values, AA::AnyScope, + UsedAssumedInformation)) { + // Something went wrong, assume all successors are live. + for (const BasicBlock *SuccBB : successors(SI.getParent())) + AliveSuccessors.push_back(&SuccBB->front()); + return false; + } + + if (Values.empty() || + (Values.size() == 1 && + isa_and_nonnull<UndefValue>(Values.front().getValue()))) { + // No valid value yet, assume all edges are dead. return UsedAssumedInformation; - } else { + } + + Type &Ty = *SI.getCondition()->getType(); + SmallPtrSet<ConstantInt *, 8> Constants; + auto CheckForConstantInt = [&](Value *V) { + if (auto *CI = dyn_cast_if_present<ConstantInt>(AA::getWithType(*V, Ty))) { + Constants.insert(CI); + return true; + } + return false; + }; + + if (!all_of(Values, [&](AA::ValueAndContext &VAC) { + return CheckForConstantInt(VAC.getValue()); + })) { for (const BasicBlock *SuccBB : successors(SI.getParent())) AliveSuccessors.push_back(&SuccBB->front()); + return UsedAssumedInformation; } + + unsigned MatchedCases = 0; + for (const auto &CaseIt : SI.cases()) { + if (Constants.count(CaseIt.getCaseValue())) { + ++MatchedCases; + AliveSuccessors.push_back(&CaseIt.getCaseSuccessor()->front()); + } + } + + // If all potential values have been matched, we will not visit the default + // case. + if (MatchedCases < Constants.size()) + AliveSuccessors.push_back(&SI.getDefaultDest()->front()); return UsedAssumedInformation; } @@ -5103,9 +5164,8 @@ struct AADereferenceableCallSiteArgument final : AADereferenceableFloating { /// Dereferenceable attribute deduction for a call site return value. 
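A worked example for the switch-successor rework above: if the condition simplifies to the value set {1, 3} and the switch has cases for 1, 2, and 3, exactly the successors of case 1 and case 3 become alive, and the default destination stays dead because every potential value matched a case (MatchedCases equals Constants.size()). A set like {1, 42} with no case for 42 would additionally mark the default destination alive.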
struct AADereferenceableCallSiteReturned final - : AACallSiteReturnedFromReturned<AADereferenceable, AADereferenceableImpl> { - using Base = - AACallSiteReturnedFromReturned<AADereferenceable, AADereferenceableImpl>; + : AACalleeToCallSite<AADereferenceable, AADereferenceableImpl> { + using Base = AACalleeToCallSite<AADereferenceable, AADereferenceableImpl>; AADereferenceableCallSiteReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} @@ -5400,8 +5460,8 @@ struct AAAlignCallSiteArgument final : AAAlignFloating { /// Align attribute deduction for a call site return value. struct AAAlignCallSiteReturned final - : AACallSiteReturnedFromReturned<AAAlign, AAAlignImpl> { - using Base = AACallSiteReturnedFromReturned<AAAlign, AAAlignImpl>; + : AACalleeToCallSite<AAAlign, AAAlignImpl> { + using Base = AACalleeToCallSite<AAAlign, AAAlignImpl>; AAAlignCallSiteReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} @@ -5449,24 +5509,10 @@ struct AANoReturnFunction final : AANoReturnImpl { }; /// NoReturn attribute deduction for a call sites. -struct AANoReturnCallSite final : AANoReturnImpl { +struct AANoReturnCallSite final + : AACalleeToCallSite<AANoReturn, AANoReturnImpl> { AANoReturnCallSite(const IRPosition &IRP, Attributor &A) - : AANoReturnImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - bool IsKnownNoReturn; - if (!AA::hasAssumedIRAttr<Attribute::NoReturn>( - A, this, FnPos, DepClassTy::REQUIRED, IsKnownNoReturn)) - return indicatePessimisticFixpoint(); - return ChangeStatus::UNCHANGED; - } + : AACalleeToCallSite<AANoReturn, AANoReturnImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(noreturn); } @@ -5805,8 +5851,8 @@ struct AANoCaptureImpl : public AANoCapture { // For stores we already checked if we can follow them, if they make it // here we give up. if (isa<StoreInst>(UInst)) - return isCapturedIn(State, /* Memory */ true, /* Integer */ false, - /* Return */ false); + return isCapturedIn(State, /* Memory */ true, /* Integer */ true, + /* Return */ true); // Explicitly catch return instructions. if (isa<ReturnInst>(UInst)) { @@ -6476,7 +6522,7 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - return indicatePessimisticFixpoint(); + return indicatePessimisticFixpoint(); } void trackStatistics() const override { @@ -6937,13 +6983,17 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { << **DI->PotentialAllocationCalls.begin() << "\n"); return false; } - Instruction *CtxI = isa<InvokeInst>(AI.CB) ? AI.CB : AI.CB->getNextNode(); - if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) { - LLVM_DEBUG( - dbgs() - << "[H2S] unique free call might not be executed with the allocation " - << *UniqueFree << "\n"); - return false; + + // __kmpc_alloc_shared and __kmpc_alloc_free are by construction matched. + if (AI.LibraryFunctionId != LibFunc___kmpc_alloc_shared) { + Instruction *CtxI = isa<InvokeInst>(AI.CB) ? 
AI.CB : AI.CB->getNextNode(); + if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) { + LLVM_DEBUG( + dbgs() + << "[H2S] unique free call might not be executed with the allocation " + << *UniqueFree << "\n"); + return false; + } } return true; }; @@ -7437,19 +7487,16 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { if (auto *PrivStructType = dyn_cast<StructType>(PrivType)) { const StructLayout *PrivStructLayout = DL.getStructLayout(PrivStructType); for (unsigned u = 0, e = PrivStructType->getNumElements(); u < e; u++) { - Type *PointeeTy = PrivStructType->getElementType(u)->getPointerTo(); - Value *Ptr = - constructPointer(PointeeTy, PrivType, &Base, - PrivStructLayout->getElementOffset(u), IRB, DL); + Value *Ptr = constructPointer( + PrivType, &Base, PrivStructLayout->getElementOffset(u), IRB, DL); new StoreInst(F.getArg(ArgNo + u), Ptr, &IP); } } else if (auto *PrivArrayType = dyn_cast<ArrayType>(PrivType)) { Type *PointeeTy = PrivArrayType->getElementType(); - Type *PointeePtrTy = PointeeTy->getPointerTo(); uint64_t PointeeTySize = DL.getTypeStoreSize(PointeeTy); for (unsigned u = 0, e = PrivArrayType->getNumElements(); u < e; u++) { - Value *Ptr = constructPointer(PointeePtrTy, PrivType, &Base, - u * PointeeTySize, IRB, DL); + Value *Ptr = + constructPointer(PrivType, &Base, u * PointeeTySize, IRB, DL); new StoreInst(F.getArg(ArgNo + u), Ptr, &IP); } } else { @@ -7469,19 +7516,13 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { IRBuilder<NoFolder> IRB(IP); const DataLayout &DL = IP->getModule()->getDataLayout(); - Type *PrivPtrType = PrivType->getPointerTo(); - if (Base->getType() != PrivPtrType) - Base = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( - Base, PrivPtrType, "", ACS.getInstruction()); - // Traverse the type, build GEPs and loads. if (auto *PrivStructType = dyn_cast<StructType>(PrivType)) { const StructLayout *PrivStructLayout = DL.getStructLayout(PrivStructType); for (unsigned u = 0, e = PrivStructType->getNumElements(); u < e; u++) { Type *PointeeTy = PrivStructType->getElementType(u); - Value *Ptr = - constructPointer(PointeeTy->getPointerTo(), PrivType, Base, - PrivStructLayout->getElementOffset(u), IRB, DL); + Value *Ptr = constructPointer( + PrivType, Base, PrivStructLayout->getElementOffset(u), IRB, DL); LoadInst *L = new LoadInst(PointeeTy, Ptr, "", IP); L->setAlignment(Alignment); ReplacementValues.push_back(L); @@ -7489,10 +7530,9 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { } else if (auto *PrivArrayType = dyn_cast<ArrayType>(PrivType)) { Type *PointeeTy = PrivArrayType->getElementType(); uint64_t PointeeTySize = DL.getTypeStoreSize(PointeeTy); - Type *PointeePtrTy = PointeeTy->getPointerTo(); for (unsigned u = 0, e = PrivArrayType->getNumElements(); u < e; u++) { - Value *Ptr = constructPointer(PointeePtrTy, PrivType, Base, - u * PointeeTySize, IRB, DL); + Value *Ptr = + constructPointer(PrivType, Base, u * PointeeTySize, IRB, DL); LoadInst *L = new LoadInst(PointeeTy, Ptr, "", IP); L->setAlignment(Alignment); ReplacementValues.push_back(L); @@ -7796,6 +7836,9 @@ struct AAMemoryBehaviorImpl : public AAMemoryBehavior { // Clear existing attributes. A.removeAttrs(IRP, AttrKinds); + // Clear conflicting writable attribute. + if (isAssumedReadOnly()) + A.removeAttrs(IRP, Attribute::Writable); // Use the generic manifest method. 
return IRAttribute::manifest(A); @@ -7983,6 +8026,10 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl { ME = MemoryEffects::writeOnly(); A.removeAttrs(getIRPosition(), AttrKinds); + // Clear conflicting writable attribute. + if (ME.onlyReadsMemory()) + for (Argument &Arg : F.args()) + A.removeAttrs(IRPosition::argument(Arg), Attribute::Writable); return A.manifestAttrs(getIRPosition(), Attribute::getWithMemoryEffects(F.getContext(), ME)); } @@ -7999,24 +8046,10 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl { }; /// AAMemoryBehavior attribute for call sites. -struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { +struct AAMemoryBehaviorCallSite final + : AACalleeToCallSite<AAMemoryBehavior, AAMemoryBehaviorImpl> { AAMemoryBehaviorCallSite(const IRPosition &IRP, Attributor &A) - : AAMemoryBehaviorImpl(IRP, A) {} - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - // TODO: Once we have call site specific value information we can provide - // call site specific liveness liveness information and then it makes - // sense to specialize attributes for call sites arguments instead of - // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - const IRPosition &FnPos = IRPosition::function(*F); - auto *FnAA = - A.getAAFor<AAMemoryBehavior>(*this, FnPos, DepClassTy::REQUIRED); - if (!FnAA) - return indicatePessimisticFixpoint(); - return clampStateAndIndicateChange(getState(), FnAA->getState()); - } + : AACalleeToCallSite<AAMemoryBehavior, AAMemoryBehaviorImpl>(IRP, A) {} /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { @@ -8031,6 +8064,11 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { ME = MemoryEffects::writeOnly(); A.removeAttrs(getIRPosition(), AttrKinds); + // Clear conflicting writable attribute. + if (ME.onlyReadsMemory()) + for (Use &U : CB.args()) + A.removeAttrs(IRPosition::callsite_argument(CB, U.getOperandNo()), + Attribute::Writable); return A.manifestAttrs( getIRPosition(), Attribute::getWithMemoryEffects(CB.getContext(), ME)); } @@ -8821,6 +8859,108 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl { }; } // namespace +/// ------------------ denormal-fp-math Attribute ------------------------- + +namespace { +struct AADenormalFPMathImpl : public AADenormalFPMath { + AADenormalFPMathImpl(const IRPosition &IRP, Attributor &A) + : AADenormalFPMath(IRP, A) {} + + const std::string getAsStr(Attributor *A) const override { + std::string Str("AADenormalFPMath["); + raw_string_ostream OS(Str); + + DenormalState Known = getKnown(); + if (Known.Mode.isValid()) + OS << "denormal-fp-math=" << Known.Mode; + else + OS << "invalid"; + + if (Known.ModeF32.isValid()) + OS << " denormal-fp-math-f32=" << Known.ModeF32; + OS << ']'; + return OS.str(); + } +}; + +struct AADenormalFPMathFunction final : AADenormalFPMathImpl { + AADenormalFPMathFunction(const IRPosition &IRP, Attributor &A) + : AADenormalFPMathImpl(IRP, A) {} + + void initialize(Attributor &A) override { + const Function *F = getAnchorScope(); + DenormalMode Mode = F->getDenormalModeRaw(); + DenormalMode ModeF32 = F->getDenormalModeF32Raw(); + + // TODO: Handling this here prevents handling the case where a callee has a + // fixed denormal-fp-math with dynamic denormal-fp-math-f32, but called from + // a function with a fully fixed mode. 
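+ // E.g., a function carrying only "denormal-fp-math"="preserve-sign,preserve-sign"
+ // reports an invalid raw f32 mode; the fallback below lets the f32 state
+ // inherit the general mode so both widths are tracked uniformly.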
+ if (ModeF32 == DenormalMode::getInvalid()) + ModeF32 = Mode; + Known = DenormalState{Mode, ModeF32}; + if (isModeFixed()) + indicateFixpoint(); + } + + ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Change = ChangeStatus::UNCHANGED; + + auto CheckCallSite = [=, &Change, &A](AbstractCallSite CS) { + Function *Caller = CS.getInstruction()->getFunction(); + LLVM_DEBUG(dbgs() << "[AADenormalFPMath] Call " << Caller->getName() + << "->" << getAssociatedFunction()->getName() << '\n'); + + const auto *CallerInfo = A.getAAFor<AADenormalFPMath>( + *this, IRPosition::function(*Caller), DepClassTy::REQUIRED); + if (!CallerInfo) + return false; + + Change = Change | clampStateAndIndicateChange(this->getState(), + CallerInfo->getState()); + return true; + }; + + bool AllCallSitesKnown = true; + if (!A.checkForAllCallSites(CheckCallSite, *this, true, AllCallSitesKnown)) + return indicatePessimisticFixpoint(); + + if (Change == ChangeStatus::CHANGED && isModeFixed()) + indicateFixpoint(); + return Change; + } + + ChangeStatus manifest(Attributor &A) override { + LLVMContext &Ctx = getAssociatedFunction()->getContext(); + + SmallVector<Attribute, 2> AttrToAdd; + SmallVector<StringRef, 2> AttrToRemove; + if (Known.Mode == DenormalMode::getDefault()) { + AttrToRemove.push_back("denormal-fp-math"); + } else { + AttrToAdd.push_back( + Attribute::get(Ctx, "denormal-fp-math", Known.Mode.str())); + } + + if (Known.ModeF32 != Known.Mode) { + AttrToAdd.push_back( + Attribute::get(Ctx, "denormal-fp-math-f32", Known.ModeF32.str())); + } else { + AttrToRemove.push_back("denormal-fp-math-f32"); + } + + auto &IRP = getIRPosition(); + + // TODO: There should be a combined add and remove API. + return A.removeAttrs(IRP, AttrToRemove) | + A.manifestAttrs(IRP, AttrToAdd, /*ForceReplace=*/true); + } + + void trackStatistics() const override { + STATS_DECLTRACK_FN_ATTR(denormal_fp_math) + } +}; +} // namespace + /// ------------------ Value Constant Range Attribute ------------------------- namespace { @@ -8911,7 +9051,8 @@ struct AAValueConstantRangeImpl : AAValueConstantRange { if (!LVI || !CtxI) return getWorstState(getBitWidth()); return LVI->getConstantRange(&getAssociatedValue(), - const_cast<Instruction *>(CtxI)); + const_cast<Instruction *>(CtxI), + /*UndefAllowed*/ false); } /// Return true if \p CtxI is valid for querying outside analyses. @@ -9427,17 +9568,13 @@ struct AAValueConstantRangeCallSite : AAValueConstantRangeFunction { }; struct AAValueConstantRangeCallSiteReturned - : AACallSiteReturnedFromReturned<AAValueConstantRange, - AAValueConstantRangeImpl, - AAValueConstantRangeImpl::StateType, - /* IntroduceCallBaseContext */ true> { + : AACalleeToCallSite<AAValueConstantRange, AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* IntroduceCallBaseContext */ true> { AAValueConstantRangeCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AAValueConstantRange, - AAValueConstantRangeImpl, - AAValueConstantRangeImpl::StateType, - /* IntroduceCallBaseContext */ true>(IRP, - A) { - } + : AACalleeToCallSite<AAValueConstantRange, AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* IntroduceCallBaseContext */ true>(IRP, A) {} /// See AbstractAttribute::initialize(...). 
void initialize(Attributor &A) override { @@ -9956,12 +10093,12 @@ struct AAPotentialConstantValuesCallSite : AAPotentialConstantValuesFunction { }; struct AAPotentialConstantValuesCallSiteReturned - : AACallSiteReturnedFromReturned<AAPotentialConstantValues, - AAPotentialConstantValuesImpl> { + : AACalleeToCallSite<AAPotentialConstantValues, + AAPotentialConstantValuesImpl> { AAPotentialConstantValuesCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AAPotentialConstantValues, - AAPotentialConstantValuesImpl>(IRP, A) {} + : AACalleeToCallSite<AAPotentialConstantValues, + AAPotentialConstantValuesImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -10101,7 +10238,8 @@ struct AANoUndefFloating : public AANoUndefImpl { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { AANoUndefImpl::initialize(A); - if (!getState().isAtFixpoint()) + if (!getState().isAtFixpoint() && getAnchorScope() && + !getAnchorScope()->isDeclaration()) if (Instruction *CtxI = getCtxI()) followUsesInMBEC(*this, A, getState(), *CtxI); } @@ -10148,26 +10286,18 @@ struct AANoUndefFloating : public AANoUndefImpl { }; struct AANoUndefReturned final - : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef> { + : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl> { AANoUndefReturned(const IRPosition &IRP, Attributor &A) - : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef>(IRP, A) {} + : AAReturnedFromReturnedValues<AANoUndef, AANoUndefImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_FNRET_ATTR(noundef) } }; struct AANoUndefArgument final - : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef> { + : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl> { AANoUndefArgument(const IRPosition &IRP, Attributor &A) - : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef>(IRP, A) {} + : AAArgumentFromCallSiteArguments<AANoUndef, AANoUndefImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(noundef) } @@ -10182,13 +10312,9 @@ struct AANoUndefCallSiteArgument final : AANoUndefFloating { }; struct AANoUndefCallSiteReturned final - : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef> { + : AACalleeToCallSite<AANoUndef, AANoUndefImpl> { AANoUndefCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AANoUndef, AANoUndefImpl, - AANoUndef::StateType, false, - Attribute::NoUndef>(IRP, A) {} + : AACalleeToCallSite<AANoUndef, AANoUndefImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noundef) } @@ -10212,7 +10338,6 @@ struct AANoFPClassImpl : AANoFPClass { A.getAttrs(getIRPosition(), {Attribute::NoFPClass}, Attrs, false); for (const auto &Attr : Attrs) { addKnownBits(Attr.getNoFPClass()); - return; } const DataLayout &DL = A.getDataLayout(); @@ -10248,8 +10373,22 @@ struct AANoFPClassImpl : AANoFPClass { /*Depth=*/0, TLI, AC, I, DT); State.addKnownBits(~KnownFPClass.KnownFPClasses); - bool TrackUse = false; - return TrackUse; + if (auto *CI = dyn_cast<CallInst>(UseV)) { + // 
Special case FP intrinsic with struct return type. + switch (CI->getIntrinsicID()) { + case Intrinsic::frexp: + return true; + case Intrinsic::not_intrinsic: + // TODO: Could recognize math libcalls + return false; + default: + break; + } + } + + if (!UseV->getType()->isFPOrFPVectorTy()) + return false; + return !isa<LoadInst, AtomicRMWInst>(UseV); } const std::string getAsStr(Attributor *A) const override { @@ -10339,9 +10478,9 @@ struct AANoFPClassCallSiteArgument final : AANoFPClassFloating { }; struct AANoFPClassCallSiteReturned final - : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl> { + : AACalleeToCallSite<AANoFPClass, AANoFPClassImpl> { AANoFPClassCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AANoFPClass, AANoFPClassImpl>(IRP, A) {} + : AACalleeToCallSite<AANoFPClass, AANoFPClassImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -10446,15 +10585,12 @@ struct AACallEdgesCallSite : public AACallEdgesImpl { return Change; } - // Process callee metadata if available. - if (auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees)) { - for (const auto &Op : MD->operands()) { - Function *Callee = mdconst::dyn_extract_or_null<Function>(Op); - if (Callee) - addCalledFunction(Callee, Change); - } - return Change; - } + if (CB->isIndirectCall()) + if (auto *IndirectCallAA = A.getAAFor<AAIndirectCallInfo>( + *this, getIRPosition(), DepClassTy::OPTIONAL)) + if (IndirectCallAA->foreachCallee( + [&](Function *Fn) { return VisitValue(*Fn, CB); })) + return Change; // The most simple case. ProcessCalledOperand(CB->getCalledOperand(), CB); @@ -10519,28 +10655,26 @@ struct AAInterFnReachabilityFunction bool instructionCanReach( Attributor &A, const Instruction &From, const Function &To, - const AA::InstExclusionSetTy *ExclusionSet, - SmallPtrSet<const Function *, 16> *Visited) const override { + const AA::InstExclusionSetTy *ExclusionSet) const override { assert(From.getFunction() == getAnchorScope() && "Queried the wrong AA!"); auto *NonConstThis = const_cast<AAInterFnReachabilityFunction *>(this); RQITy StackRQI(A, From, To, ExclusionSet, false); typename RQITy::Reachable Result; if (!NonConstThis->checkQueryCache(A, StackRQI, Result)) - return NonConstThis->isReachableImpl(A, StackRQI); + return NonConstThis->isReachableImpl(A, StackRQI, + /*IsTemporaryRQI=*/true); return Result == RQITy::Reachable::Yes; } - bool isReachableImpl(Attributor &A, RQITy &RQI) override { - return isReachableImpl(A, RQI, nullptr); - } - bool isReachableImpl(Attributor &A, RQITy &RQI, - SmallPtrSet<const Function *, 16> *Visited) { - - SmallPtrSet<const Function *, 16> LocalVisited; - if (!Visited) - Visited = &LocalVisited; + bool IsTemporaryRQI) override { + const Instruction *EntryI = + &RQI.From->getFunction()->getEntryBlock().front(); + if (EntryI != RQI.From && + !instructionCanReach(A, *EntryI, *RQI.To, nullptr)) + return rememberResult(A, RQITy::Reachable::No, RQI, false, + IsTemporaryRQI); auto CheckReachableCallBase = [&](CallBase *CB) { auto *CBEdges = A.getAAFor<AACallEdges>( @@ -10554,8 +10688,7 @@ struct AAInterFnReachabilityFunction for (Function *Fn : CBEdges->getOptimisticEdges()) { if (Fn == RQI.To) return false; - if (!Visited->insert(Fn).second) - continue; + if (Fn->isDeclaration()) { if (Fn->hasFnAttribute(Attribute::NoCallback)) continue; @@ -10563,15 +10696,20 @@ struct AAInterFnReachabilityFunction return false; } - const AAInterFnReachability *InterFnReachability = this; - if (Fn != 
getAnchorScope()) - InterFnReachability = A.getAAFor<AAInterFnReachability>( - *this, IRPosition::function(*Fn), DepClassTy::OPTIONAL); + if (Fn == getAnchorScope()) { + if (EntryI == RQI.From) + continue; + return false; + } + + const AAInterFnReachability *InterFnReachability = + A.getAAFor<AAInterFnReachability>(*this, IRPosition::function(*Fn), + DepClassTy::OPTIONAL); const Instruction &FnFirstInst = Fn->getEntryBlock().front(); if (!InterFnReachability || InterFnReachability->instructionCanReach(A, FnFirstInst, *RQI.To, - RQI.ExclusionSet, Visited)) + RQI.ExclusionSet)) return false; } return true; @@ -10583,10 +10721,12 @@ struct AAInterFnReachabilityFunction // Determine call like instructions that we can reach from the inst. auto CheckCallBase = [&](Instruction &CBInst) { - if (!IntraFnReachability || !IntraFnReachability->isAssumedReachable( - A, *RQI.From, CBInst, RQI.ExclusionSet)) + // There are usually less nodes in the call graph, check inter function + // reachability first. + if (CheckReachableCallBase(cast<CallBase>(&CBInst))) return true; - return CheckReachableCallBase(cast<CallBase>(&CBInst)); + return IntraFnReachability && !IntraFnReachability->isAssumedReachable( + A, *RQI.From, CBInst, RQI.ExclusionSet); }; bool UsedExclusionSet = /* conservative */ true; @@ -10594,16 +10734,14 @@ struct AAInterFnReachabilityFunction if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this, UsedAssumedInformation, /* CheckBBLivenessOnly */ true)) - return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet, + IsTemporaryRQI); - return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet); + return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet, + IsTemporaryRQI); } void trackStatistics() const override {} - -private: - SmallVector<RQITy *> QueryVector; - DenseSet<RQITy *> QueryCache; }; } // namespace @@ -10880,64 +11018,104 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { // Simplify the operands first. bool UsedAssumedInformation = false; - const auto &SimplifiedLHS = A.getAssumedSimplified( - IRPosition::value(*LHS, getCallBaseContext()), *this, - UsedAssumedInformation, AA::Intraprocedural); - if (!SimplifiedLHS.has_value()) + SmallVector<AA::ValueAndContext> LHSValues, RHSValues; + auto GetSimplifiedValues = [&](Value &V, + SmallVector<AA::ValueAndContext> &Values) { + if (!A.getAssumedSimplifiedValues( + IRPosition::value(V, getCallBaseContext()), this, Values, + AA::Intraprocedural, UsedAssumedInformation)) { + Values.clear(); + Values.push_back(AA::ValueAndContext{V, II.I.getCtxI()}); + } + return Values.empty(); + }; + if (GetSimplifiedValues(*LHS, LHSValues)) return true; - if (!*SimplifiedLHS) - return false; - LHS = *SimplifiedLHS; - - const auto &SimplifiedRHS = A.getAssumedSimplified( - IRPosition::value(*RHS, getCallBaseContext()), *this, - UsedAssumedInformation, AA::Intraprocedural); - if (!SimplifiedRHS.has_value()) + if (GetSimplifiedValues(*RHS, RHSValues)) return true; - if (!*SimplifiedRHS) - return false; - RHS = *SimplifiedRHS; LLVMContext &Ctx = LHS->getContext(); - // Handle the trivial case first in which we don't even need to think about - // null or non-null. 
- if (LHS == RHS && - (CmpInst::isTrueWhenEqual(Pred) || CmpInst::isFalseWhenEqual(Pred))) { - Constant *NewV = ConstantInt::get(Type::getInt1Ty(Ctx), - CmpInst::isTrueWhenEqual(Pred)); - addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, - getAnchorScope()); - return true; - } - // From now on we only handle equalities (==, !=). - if (!CmpInst::isEquality(Pred)) - return false; + InformationCache &InfoCache = A.getInfoCache(); + Instruction *CmpI = dyn_cast<Instruction>(&Cmp); + Function *F = CmpI ? CmpI->getFunction() : nullptr; + const auto *DT = + F ? InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F) + : nullptr; + const auto *TLI = + F ? A.getInfoCache().getTargetLibraryInfoForFunction(*F) : nullptr; + auto *AC = + F ? InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F) + : nullptr; - bool LHSIsNull = isa<ConstantPointerNull>(LHS); - bool RHSIsNull = isa<ConstantPointerNull>(RHS); - if (!LHSIsNull && !RHSIsNull) - return false; + const DataLayout &DL = A.getDataLayout(); + SimplifyQuery Q(DL, TLI, DT, AC, CmpI); - // Left is the nullptr ==/!= non-nullptr case. We'll use AANonNull on the - // non-nullptr operand and if we assume it's non-null we can conclude the - // result of the comparison. - assert((LHSIsNull || RHSIsNull) && - "Expected nullptr versus non-nullptr comparison at this point"); + auto CheckPair = [&](Value &LHSV, Value &RHSV) { + if (isa<UndefValue>(LHSV) || isa<UndefValue>(RHSV)) { + addValue(A, getState(), *UndefValue::get(Cmp.getType()), + /* CtxI */ nullptr, II.S, getAnchorScope()); + return true; + } - // The index is the operand that we assume is not null. - unsigned PtrIdx = LHSIsNull; - bool IsKnownNonNull; - bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>( - A, this, IRPosition::value(*(PtrIdx ? RHS : LHS)), DepClassTy::REQUIRED, - IsKnownNonNull); - if (!IsAssumedNonNull) - return false; + // Handle the trivial case first in which we don't even need to think + // about null or non-null. + if (&LHSV == &RHSV && + (CmpInst::isTrueWhenEqual(Pred) || CmpInst::isFalseWhenEqual(Pred))) { + Constant *NewV = ConstantInt::get(Type::getInt1Ty(Ctx), + CmpInst::isTrueWhenEqual(Pred)); + addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, + getAnchorScope()); + return true; + } + + auto *TypedLHS = AA::getWithType(LHSV, *LHS->getType()); + auto *TypedRHS = AA::getWithType(RHSV, *RHS->getType()); + if (TypedLHS && TypedRHS) { + Value *NewV = simplifyCmpInst(Pred, TypedLHS, TypedRHS, Q); + if (NewV && NewV != &Cmp) { + addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, + getAnchorScope()); + return true; + } + } + + // From now on we only handle equalities (==, !=). + if (!CmpInst::isEquality(Pred)) + return false; + + bool LHSIsNull = isa<ConstantPointerNull>(LHSV); + bool RHSIsNull = isa<ConstantPointerNull>(RHSV); + if (!LHSIsNull && !RHSIsNull) + return false; + + // Left is the nullptr ==/!= non-nullptr case. We'll use AANonNull on the + // non-nullptr operand and if we assume it's non-null we can conclude the + // result of the comparison. + assert((LHSIsNull || RHSIsNull) && + "Expected nullptr versus non-nullptr comparison at this point"); - // The new value depends on the predicate, true for != and false for ==. - Constant *NewV = - ConstantInt::get(Type::getInt1Ty(Ctx), Pred == CmpInst::ICMP_NE); - addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, getAnchorScope()); + // The index is the operand that we assume is not null. 
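+ // (PtrIdx is 1 exactly when the LHS is the null constant, selecting RHSV
+ // as the candidate non-null pointer below.)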
+ unsigned PtrIdx = LHSIsNull; + bool IsKnownNonNull; + bool IsAssumedNonNull = AA::hasAssumedIRAttr<Attribute::NonNull>( + A, this, IRPosition::value(*(PtrIdx ? &RHSV : &LHSV)), + DepClassTy::REQUIRED, IsKnownNonNull); + if (!IsAssumedNonNull) + return false; + + // The new value depends on the predicate, true for != and false for ==. + Constant *NewV = + ConstantInt::get(Type::getInt1Ty(Ctx), Pred == CmpInst::ICMP_NE); + addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, + getAnchorScope()); + return true; + }; + + for (auto &LHSValue : LHSValues) + for (auto &RHSValue : RHSValues) + if (!CheckPair(*LHSValue.getValue(), *RHSValue.getValue())) + return false; return true; } @@ -11152,9 +11330,8 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { SmallVectorImpl<ItemInfo> &Worklist, SmallMapVector<const Function *, LivenessInfo, 4> &LivenessAAs) { if (auto *CI = dyn_cast<CmpInst>(&I)) - if (handleCmp(A, *CI, CI->getOperand(0), CI->getOperand(1), - CI->getPredicate(), II, Worklist)) - return true; + return handleCmp(A, *CI, CI->getOperand(0), CI->getOperand(1), + CI->getPredicate(), II, Worklist); switch (I.getOpcode()) { case Instruction::Select: @@ -11272,12 +11449,12 @@ struct AAPotentialValuesArgument final : AAPotentialValuesImpl { ChangeStatus updateImpl(Attributor &A) override { auto AssumedBefore = getAssumed(); - unsigned CSArgNo = getCallSiteArgNo(); + unsigned ArgNo = getCalleeArgNo(); bool UsedAssumedInformation = false; SmallVector<AA::ValueAndContext> Values; auto CallSitePred = [&](AbstractCallSite ACS) { - const auto CSArgIRP = IRPosition::callsite_argument(ACS, CSArgNo); + const auto CSArgIRP = IRPosition::callsite_argument(ACS, ArgNo); if (CSArgIRP.getPositionKind() == IRP_INVALID) return false; @@ -11889,6 +12066,455 @@ struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl { }; } // namespace +/// ------------------------ Global Value Info ------------------------------- +namespace { +struct AAGlobalValueInfoFloating : public AAGlobalValueInfo { + AAGlobalValueInfoFloating(const IRPosition &IRP, Attributor &A) + : AAGlobalValueInfo(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override {} + + bool checkUse(Attributor &A, const Use &U, bool &Follow, + SmallVectorImpl<const Value *> &Worklist) { + Instruction *UInst = dyn_cast<Instruction>(U.getUser()); + if (!UInst) { + Follow = true; + return true; + } + + LLVM_DEBUG(dbgs() << "[AAGlobalValueInfo] Check use: " << *U.get() << " in " + << *UInst << "\n"); + + if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) { + int Idx = &Cmp->getOperandUse(0) == &U; + if (isa<Constant>(Cmp->getOperand(Idx))) + return true; + return U == &getAnchorValue(); + } + + // Explicitly catch return instructions. + if (isa<ReturnInst>(UInst)) { + auto CallSitePred = [&](AbstractCallSite ACS) { + Worklist.push_back(ACS.getInstruction()); + return true; + }; + bool UsedAssumedInformation = false; + // TODO: We should traverse the uses or add a "non-call-site" CB. + if (!A.checkForAllCallSites(CallSitePred, *UInst->getFunction(), + /*RequireAllCallSites=*/true, this, + UsedAssumedInformation)) + return false; + return true; + } + + // For now we only use special logic for call sites. However, the tracker + // itself knows about a lot of other non-capturing cases already. + auto *CB = dyn_cast<CallBase>(UInst); + if (!CB) + return false; + // Direct calls are OK uses. + if (CB->isCallee(&U)) + return true; + // Non-argument uses are scary. 
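+ // (E.g., operand-bundle uses of the global; treat them as potential
+ // escapes.)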
+ if (!CB->isArgOperand(&U)) + return false; + // TODO: Iterate callees. + auto *Fn = dyn_cast<Function>(CB->getCalledOperand()); + if (!Fn || !A.isFunctionIPOAmendable(*Fn)) + return false; + + unsigned ArgNo = CB->getArgOperandNo(&U); + Worklist.push_back(Fn->getArg(ArgNo)); + return true; + } + + ChangeStatus updateImpl(Attributor &A) override { + unsigned NumUsesBefore = Uses.size(); + + SmallPtrSet<const Value *, 8> Visited; + SmallVector<const Value *> Worklist; + Worklist.push_back(&getAnchorValue()); + + auto UsePred = [&](const Use &U, bool &Follow) -> bool { + Uses.insert(&U); + switch (DetermineUseCaptureKind(U, nullptr)) { + case UseCaptureKind::NO_CAPTURE: + return checkUse(A, U, Follow, Worklist); + case UseCaptureKind::MAY_CAPTURE: + return checkUse(A, U, Follow, Worklist); + case UseCaptureKind::PASSTHROUGH: + Follow = true; + return true; + } + return true; + }; + auto EquivalentUseCB = [&](const Use &OldU, const Use &NewU) { + Uses.insert(&OldU); + return true; + }; + + while (!Worklist.empty()) { + const Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; + if (!A.checkForAllUses(UsePred, *this, *V, + /* CheckBBLivenessOnly */ true, + DepClassTy::OPTIONAL, + /* IgnoreDroppableUses */ true, EquivalentUseCB)) { + return indicatePessimisticFixpoint(); + } + } + + return Uses.size() == NumUsesBefore ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + bool isPotentialUse(const Use &U) const override { + return !isValidState() || Uses.contains(&U); + } + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + return ChangeStatus::UNCHANGED; + } + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr(Attributor *A) const override { + return "[" + std::to_string(Uses.size()) + " uses]"; + } + + void trackStatistics() const override { + STATS_DECLTRACK_FLOATING_ATTR(GlobalValuesTracked); + } + +private: + /// Set of (transitive) uses of this GlobalValue. + SmallPtrSet<const Use *, 8> Uses; +}; +} // namespace + +/// ------------------------ Indirect Call Info ------------------------------- +namespace { +struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo { + AAIndirectCallInfoCallSite(const IRPosition &IRP, Attributor &A) + : AAIndirectCallInfo(IRP, A) {} + + /// See AbstractAttribute::initialize(...). 
+ void initialize(Attributor &A) override { + auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees); + if (!MD && !A.isClosedWorldModule()) + return; + + if (MD) { + for (const auto &Op : MD->operands()) + if (Function *Callee = mdconst::dyn_extract_or_null<Function>(Op)) + PotentialCallees.insert(Callee); + } else if (A.isClosedWorldModule()) { + ArrayRef<Function *> IndirectlyCallableFunctions = + A.getInfoCache().getIndirectlyCallableFunctions(A); + PotentialCallees.insert(IndirectlyCallableFunctions.begin(), + IndirectlyCallableFunctions.end()); + } + + if (PotentialCallees.empty()) + indicateOptimisticFixpoint(); + } + + ChangeStatus updateImpl(Attributor &A) override { + CallBase *CB = cast<CallBase>(getCtxI()); + const Use &CalleeUse = CB->getCalledOperandUse(); + Value *FP = CB->getCalledOperand(); + + SmallSetVector<Function *, 4> AssumedCalleesNow; + bool AllCalleesKnownNow = AllCalleesKnown; + + auto CheckPotentialCalleeUse = [&](Function &PotentialCallee, + bool &UsedAssumedInformation) { + const auto *GIAA = A.getAAFor<AAGlobalValueInfo>( + *this, IRPosition::value(PotentialCallee), DepClassTy::OPTIONAL); + if (!GIAA || GIAA->isPotentialUse(CalleeUse)) + return true; + UsedAssumedInformation = !GIAA->isAtFixpoint(); + return false; + }; + + auto AddPotentialCallees = [&]() { + for (auto *PotentialCallee : PotentialCallees) { + bool UsedAssumedInformation = false; + if (CheckPotentialCalleeUse(*PotentialCallee, UsedAssumedInformation)) + AssumedCalleesNow.insert(PotentialCallee); + } + }; + + // Use simplification to find potential callees, if !callees was present, + // fallback to that set if necessary. + bool UsedAssumedInformation = false; + SmallVector<AA::ValueAndContext> Values; + if (!A.getAssumedSimplifiedValues(IRPosition::value(*FP), this, Values, + AA::ValueScope::AnyScope, + UsedAssumedInformation)) { + if (PotentialCallees.empty()) + return indicatePessimisticFixpoint(); + AddPotentialCallees(); + } + + // Try to find a reason for \p Fn not to be a potential callee. If none was + // found, add it to the assumed callees set. + auto CheckPotentialCallee = [&](Function &Fn) { + if (!PotentialCallees.empty() && !PotentialCallees.count(&Fn)) + return false; + + auto &CachedResult = FilterResults[&Fn]; + if (CachedResult.has_value()) + return CachedResult.value(); + + bool UsedAssumedInformation = false; + if (!CheckPotentialCalleeUse(Fn, UsedAssumedInformation)) { + if (!UsedAssumedInformation) + CachedResult = false; + return false; + } + + int NumFnArgs = Fn.arg_size(); + int NumCBArgs = CB->arg_size(); + + // Check if any excess argument (which we fill up with poison) is known to + // be UB on undef. + for (int I = NumCBArgs; I < NumFnArgs; ++I) { + bool IsKnown = false; + if (AA::hasAssumedIRAttr<Attribute::NoUndef>( + A, this, IRPosition::argument(*Fn.getArg(I)), + DepClassTy::OPTIONAL, IsKnown)) { + if (IsKnown) + CachedResult = false; + return false; + } + } + + CachedResult = true; + return true; + }; + + // Check simplification result, prune known UB callees, also restrict it to + // the !callees set, if present. + for (auto &VAC : Values) { + if (isa<UndefValue>(VAC.getValue())) + continue; + if (isa<ConstantPointerNull>(VAC.getValue()) && + VAC.getValue()->getType()->getPointerAddressSpace() == 0) + continue; + // TODO: Check for known UB, e.g., poison + noundef. 
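+ // Only simplified values that are functions become candidate callees; any
+ // other value forces the fallback to the !callees set, if present, or
+ // marks the callee set as incomplete.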
+ if (auto *VACFn = dyn_cast<Function>(VAC.getValue())) { + if (CheckPotentialCallee(*VACFn)) + AssumedCalleesNow.insert(VACFn); + continue; + } + if (!PotentialCallees.empty()) { + AddPotentialCallees(); + break; + } + AllCalleesKnownNow = false; + } + + if (AssumedCalleesNow == AssumedCallees && + AllCalleesKnown == AllCalleesKnownNow) + return ChangeStatus::UNCHANGED; + + std::swap(AssumedCallees, AssumedCalleesNow); + AllCalleesKnown = AllCalleesKnownNow; + return ChangeStatus::CHANGED; + } + + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + // If we can't specialize at all, give up now. + if (!AllCalleesKnown && AssumedCallees.empty()) + return ChangeStatus::UNCHANGED; + + CallBase *CB = cast<CallBase>(getCtxI()); + bool UsedAssumedInformation = false; + if (A.isAssumedDead(*CB, this, /*LivenessAA=*/nullptr, + UsedAssumedInformation)) + return ChangeStatus::UNCHANGED; + + ChangeStatus Changed = ChangeStatus::UNCHANGED; + Value *FP = CB->getCalledOperand(); + if (FP->getType()->getPointerAddressSpace()) + FP = new AddrSpaceCastInst(FP, PointerType::get(FP->getType(), 0), + FP->getName() + ".as0", CB); + + bool CBIsVoid = CB->getType()->isVoidTy(); + Instruction *IP = CB; + FunctionType *CSFT = CB->getFunctionType(); + SmallVector<Value *> CSArgs(CB->arg_begin(), CB->arg_end()); + + // If we know all callees and there are none, the call site is (effectively) + // dead (or UB). + if (AssumedCallees.empty()) { + assert(AllCalleesKnown && + "Expected all callees to be known if there are none."); + A.changeToUnreachableAfterManifest(CB); + return ChangeStatus::CHANGED; + } + + // Special handling for the single callee case. + if (AllCalleesKnown && AssumedCallees.size() == 1) { + auto *NewCallee = AssumedCallees.front(); + if (isLegalToPromote(*CB, NewCallee)) { + promoteCall(*CB, NewCallee, nullptr); + return ChangeStatus::CHANGED; + } + Instruction *NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), + CSArgs, CB->getName(), CB); + if (!CBIsVoid) + A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewCall); + A.deleteAfterManifest(*CB); + return ChangeStatus::CHANGED; + } + + // For each potential value we create a conditional + // + // ``` + // if (ptr == value) value(args); + // else ... 
+ // ```
+ //
+ bool SpecializedForAnyCallees = false;
+ bool SpecializedForAllCallees = AllCalleesKnown;
+ ICmpInst *LastCmp = nullptr;
+ SmallVector<Function *, 8> SkippedAssumedCallees;
+ SmallVector<std::pair<CallInst *, Instruction *>> NewCalls;
+ for (Function *NewCallee : AssumedCallees) {
+ if (!A.shouldSpecializeCallSiteForCallee(*this, *CB, *NewCallee)) {
+ SkippedAssumedCallees.push_back(NewCallee);
+ SpecializedForAllCallees = false;
+ continue;
+ }
+ SpecializedForAnyCallees = true;
+
+ LastCmp = new ICmpInst(IP, llvm::CmpInst::ICMP_EQ, FP, NewCallee);
+ Instruction *ThenTI =
+ SplitBlockAndInsertIfThen(LastCmp, IP, /* Unreachable */ false);
+ BasicBlock *CBBB = CB->getParent();
+ A.registerManifestAddedBasicBlock(*ThenTI->getParent());
+ A.registerManifestAddedBasicBlock(*CBBB);
+ auto *SplitTI = cast<BranchInst>(LastCmp->getNextNode());
+ BasicBlock *ElseBB;
+ if (IP == CB) {
+ ElseBB = BasicBlock::Create(ThenTI->getContext(), "",
+ ThenTI->getFunction(), CBBB);
+ A.registerManifestAddedBasicBlock(*ElseBB);
+ IP = BranchInst::Create(CBBB, ElseBB);
+ SplitTI->replaceUsesOfWith(CBBB, ElseBB);
+ } else {
+ ElseBB = IP->getParent();
+ ThenTI->replaceUsesOfWith(ElseBB, CBBB);
+ }
+ CastInst *RetBC = nullptr;
+ CallInst *NewCall = nullptr;
+ if (isLegalToPromote(*CB, NewCallee)) {
+ auto *CBClone = cast<CallBase>(CB->clone());
+ CBClone->insertBefore(ThenTI);
+ NewCall = &cast<CallInst>(promoteCall(*CBClone, NewCallee, &RetBC));
+ } else {
+ NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), CSArgs,
+ CB->getName(), ThenTI);
+ }
+ NewCalls.push_back({NewCall, RetBC});
+ }
+
+ auto AttachCalleeMetadata = [&](CallBase &IndirectCB) {
+ if (!AllCalleesKnown)
+ return ChangeStatus::UNCHANGED;
+ MDBuilder MDB(IndirectCB.getContext());
+ MDNode *Callees = MDB.createCallees(SkippedAssumedCallees);
+ IndirectCB.setMetadata(LLVMContext::MD_callees, Callees);
+ return ChangeStatus::CHANGED;
+ };
+
+ if (!SpecializedForAnyCallees)
+ return AttachCalleeMetadata(*CB);
+
+ // Check if we need the fallback indirect call still.
+ if (SpecializedForAllCallees) {
+ LastCmp->replaceAllUsesWith(ConstantInt::getTrue(LastCmp->getContext()));
+ LastCmp->eraseFromParent();
+ new UnreachableInst(IP->getContext(), IP);
+ IP->eraseFromParent();
+ } else {
+ auto *CBClone = cast<CallInst>(CB->clone());
+ CBClone->setName(CB->getName());
+ CBClone->insertBefore(IP);
+ NewCalls.push_back({CBClone, nullptr});
+ AttachCalleeMetadata(*CBClone);
+ }
+
+ // Check if we need a PHI to merge the results.
+ if (!CBIsVoid) {
+ auto *PHI = PHINode::Create(CB->getType(), NewCalls.size(),
+ CB->getName() + ".phi",
+ &*CB->getParent()->getFirstInsertionPt());
+ for (auto &It : NewCalls) {
+ CallBase *NewCall = It.first;
+ Instruction *CallRet = It.second ? It.second : It.first;
+ if (CallRet->getType() == CB->getType())
+ PHI->addIncoming(CallRet, CallRet->getParent());
+ else if (NewCall->getType()->isVoidTy())
+ PHI->addIncoming(PoisonValue::get(CB->getType()),
+ NewCall->getParent());
+ else
+ llvm_unreachable("Call return should match or be void!");
+ }
+ A.changeAfterManifest(IRPosition::callsite_returned(*CB), *PHI);
+ }
+
+ A.deleteAfterManifest(*CB);
+ Changed = ChangeStatus::CHANGED;
+
+ return Changed;
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr(Attributor *A) const override {
+ return std::string(AllCalleesKnown ? "eliminate" : "specialize") +
+ " indirect call site with " + std::to_string(AssumedCallees.size()) +
+ " functions";
+ }
+
+ void trackStatistics() const override {
+ if (AllCalleesKnown) {
+ STATS_DECLTRACK(
+ Eliminated, CallSites,
+ "Number of indirect call sites eliminated via specialization")
+ } else {
+ STATS_DECLTRACK(Specialized, CallSites,
+ "Number of indirect call sites specialized")
+ }
+ }
+
+ bool foreachCallee(function_ref<bool(Function *)> CB) const override {
+ return isValidState() && AllCalleesKnown && all_of(AssumedCallees, CB);
+ }
+
+private:
+ /// Map to remember filter results.
+ DenseMap<Function *, std::optional<bool>> FilterResults;
+
+ /// If the !callees metadata was present, this set will contain all potential
+ /// callees (superset).
+ SmallSetVector<Function *, 4> PotentialCallees;
+
+ /// This set contains all currently assumed callees, which might grow over
+ /// time.
+ SmallSetVector<Function *, 4> AssumedCallees;
+
+ /// Flag to indicate if all possible callees are in the AssumedCallees set or
+ /// if there could be others.
+ bool AllCalleesKnown = true;
+};
+} // namespace
+
 /// ------------------------ Address Space ------------------------------------
 namespace {
 struct AAAddressSpaceImpl : public AAAddressSpace {
@@ -11961,8 +12587,13 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
 // CGSCC if the AA is run on CGSCC instead of the entire module.
 if (!A.isRunOn(Inst->getFunction()))
 return true;
- if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst))
+ if (isa<LoadInst>(Inst))
 MakeChange(Inst, const_cast<Use &>(U));
+ if (isa<StoreInst>(Inst)) {
+ // We only make changes if the use is the pointer operand.
+ if (U.getOperandNo() == 1)
+ MakeChange(Inst, const_cast<Use &>(U));
+ }
 return true;
 };
@@ -12064,6 +12695,224 @@ struct AAAddressSpaceCallSiteArgument final : AAAddressSpaceImpl {
 };
 } // namespace
 
+/// ----------- Allocation Info ----------
+namespace {
+struct AAAllocationInfoImpl : public AAAllocationInfo {
+ AAAllocationInfoImpl(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfo(IRP, A) {}
+
+ std::optional<TypeSize> getAllocatedSize() const override {
+ assert(isValidState() && "the AA is invalid");
+ return AssumedAllocatedSize;
+ }
+
+ std::optional<TypeSize> findInitialAllocationSize(Instruction *I,
+ const DataLayout &DL) {
+
+ // TODO: implement case for malloc like instructions
+ switch (I->getOpcode()) {
+ case Instruction::Alloca: {
+ AllocaInst *AI = cast<AllocaInst>(I);
+ return AI->getAllocationSize(DL);
+ }
+ default:
+ return std::nullopt;
+ }
+ }
+
+ ChangeStatus updateImpl(Attributor &A) override {
+
+ const IRPosition &IRP = getIRPosition();
+ Instruction *I = IRP.getCtxI();
+
+ // TODO: update check for malloc like calls
+ if (!isa<AllocaInst>(I))
+ return indicatePessimisticFixpoint();
+
+ bool IsKnownNoCapture;
+ if (!AA::hasAssumedIRAttr<Attribute::NoCapture>(
+ A, this, IRP, DepClassTy::OPTIONAL, IsKnownNoCapture))
+ return indicatePessimisticFixpoint();
+
+ const AAPointerInfo *PI =
+ A.getOrCreateAAFor<AAPointerInfo>(IRP, *this, DepClassTy::REQUIRED);
+
+ if (!PI)
+ return indicatePessimisticFixpoint();
+
+ if (!PI->getState().isValidState())
+ return indicatePessimisticFixpoint();
+
+ const DataLayout &DL = A.getDataLayout();
+ const auto AllocationSize = findInitialAllocationSize(I, DL);
+
+ // If allocation size is nullopt, we give up.
+ if (!AllocationSize)
+ return indicatePessimisticFixpoint();
+
+ // For zero-sized allocations we give up as well, since there is nothing
+ // further to reduce.
+ if (*AllocationSize == 0)
+ return indicatePessimisticFixpoint();
+
+ int64_t BinSize = PI->numOffsetBins();
+
+ // TODO: implement for multiple bins
+ if (BinSize > 1)
+ return indicatePessimisticFixpoint();
+
+ if (BinSize == 0) {
+ auto NewAllocationSize = std::optional<TypeSize>(TypeSize(0, false));
+ if (!changeAllocationSize(NewAllocationSize))
+ return ChangeStatus::UNCHANGED;
+ return ChangeStatus::CHANGED;
+ }
+
+ // TODO: refactor this to be part of multiple bin case
+ const auto &It = PI->begin();
+
+ // TODO: handle if Offset is not zero
+ if (It->first.Offset != 0)
+ return indicatePessimisticFixpoint();
+
+ uint64_t SizeOfBin = It->first.Offset + It->first.Size;
+
+ if (SizeOfBin >= *AllocationSize)
+ return indicatePessimisticFixpoint();
+
+ auto NewAllocationSize =
+ std::optional<TypeSize>(TypeSize(SizeOfBin * 8, false));
+
+ if (!changeAllocationSize(NewAllocationSize))
+ return ChangeStatus::UNCHANGED;
+
+ return ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractAttribute::manifest(...).
+ ChangeStatus manifest(Attributor &A) override {
+
+ assert(isValidState() &&
+ "Manifest should only be called if the state is valid.");
+
+ Instruction *I = getIRPosition().getCtxI();
+
+ auto FixedAllocatedSizeInBits = getAllocatedSize()->getFixedValue();
+
+ unsigned long NumBytesToAllocate = (FixedAllocatedSizeInBits + 7) / 8;
+
+ switch (I->getOpcode()) {
+ // TODO: add case for malloc like calls
+ case Instruction::Alloca: {
+
+ AllocaInst *AI = cast<AllocaInst>(I);
+
+ Type *CharType = Type::getInt8Ty(I->getContext());
+
+ auto *NumBytesToValue =
+ ConstantInt::get(I->getContext(), APInt(32, NumBytesToAllocate));
+
+ AllocaInst *NewAllocaInst =
+ new AllocaInst(CharType, AI->getAddressSpace(), NumBytesToValue,
+ AI->getAlign(), AI->getName(), AI->getNextNode());
+
+ if (A.changeAfterManifest(IRPosition::inst(*AI), *NewAllocaInst))
+ return ChangeStatus::CHANGED;
+
+ break;
+ }
+ default:
+ break;
+ }
+
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// See AbstractAttribute::getAsStr().
+ const std::string getAsStr(Attributor *A) const override {
+ if (!isValidState())
+ return "allocationinfo(<invalid>)";
+ return "allocationinfo(" +
+ (AssumedAllocatedSize == HasNoAllocationSize
+ ? "none"
+ : std::to_string(AssumedAllocatedSize->getFixedValue())) +
+ ")";
+ }
+
+private:
+ std::optional<TypeSize> AssumedAllocatedSize = HasNoAllocationSize;
+
+ // Maintain the computed allocation size of the object.
+ // Returns (bool) whether the size of the allocation was modified or not.
+ bool changeAllocationSize(std::optional<TypeSize> Size) {
+ if (AssumedAllocatedSize == HasNoAllocationSize ||
+ AssumedAllocatedSize != Size) {
+ AssumedAllocatedSize = Size;
+ return true;
+ }
+ return false;
+ }
+};
+
+struct AAAllocationInfoFloating : AAAllocationInfoImpl {
+ AAAllocationInfoFloating(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfoImpl(IRP, A) {}
+
+ void trackStatistics() const override {
+ STATS_DECLTRACK_FLOATING_ATTR(allocationinfo);
+ }
+};
+
+struct AAAllocationInfoReturned : AAAllocationInfoImpl {
+ AAAllocationInfoReturned(const IRPosition &IRP, Attributor &A)
+ : AAAllocationInfoImpl(IRP, A) {}
+
+ /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    // TODO: we don't rewrite function arguments for now because that would
+    // require rewriting the function signature and all call sites.
+    (void)indicatePessimisticFixpoint();
+  }
+
+  void trackStatistics() const override {
+    STATS_DECLTRACK_FNRET_ATTR(allocationinfo);
+  }
+};
+
+struct AAAllocationInfoCallSiteReturned : AAAllocationInfoImpl {
+  AAAllocationInfoCallSiteReturned(const IRPosition &IRP, Attributor &A)
+      : AAAllocationInfoImpl(IRP, A) {}
+
+  void trackStatistics() const override {
+    STATS_DECLTRACK_CSRET_ATTR(allocationinfo);
+  }
+};
+
+struct AAAllocationInfoArgument : AAAllocationInfoImpl {
+  AAAllocationInfoArgument(const IRPosition &IRP, Attributor &A)
+      : AAAllocationInfoImpl(IRP, A) {}
+
+  void trackStatistics() const override {
+    STATS_DECLTRACK_ARG_ATTR(allocationinfo);
+  }
+};
+
+struct AAAllocationInfoCallSiteArgument : AAAllocationInfoImpl {
+  AAAllocationInfoCallSiteArgument(const IRPosition &IRP, Attributor &A)
+      : AAAllocationInfoImpl(IRP, A) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+
+    (void)indicatePessimisticFixpoint();
+  }
+
+  void trackStatistics() const override {
+    STATS_DECLTRACK_CSARG_ATTR(allocationinfo);
+  }
+};
+} // namespace
+
 const char AANoUnwind::ID = 0;
 const char AANoSync::ID = 0;
 const char AANoFree::ID = 0;
@@ -12097,6 +12946,10 @@ const char AAPointerInfo::ID = 0;
 const char AAAssumptionInfo::ID = 0;
 const char AAUnderlyingObjects::ID = 0;
 const char AAAddressSpace::ID = 0;
+const char AAAllocationInfo::ID = 0;
+const char AAIndirectCallInfo::ID = 0;
+const char AAGlobalValueInfo::ID = 0;
+const char AADenormalFPMath::ID = 0;

 // Macro magic to create the static generator function for attributes that
 // follow the naming scheme.
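The new AAAllocationInfo attribute above tracks the assumed allocation size in bits and, at manifest time, rounds it up to whole bytes before building the replacement i8 alloca. A standalone sketch of that size arithmetic, in plain C++ with illustrative names (not part of the patch):

#include <cassert>
#include <cstdint>
#include <iostream>

// Mirror of the manifest arithmetic above: a single access bin
// [Offset, Offset + Size) yields a new size of Offset + Size bytes, tracked
// internally in bits and rounded back up to bytes for the new alloca.
int main() {
  uint64_t BinOffset = 0;     // the pass currently bails unless Offset == 0
  uint64_t BinSize = 4;       // bytes actually accessed
  uint64_t OriginalSize = 16; // bytes originally allocated

  uint64_t NewSizeInBits = (BinOffset + BinSize) * 8;
  uint64_t NumBytesToAllocate = (NewSizeInBits + 7) / 8;

  assert(NumBytesToAllocate < OriginalSize && "only shrink, never grow");
  std::cout << "replacement alloca: [" << NumBytesToAllocate << " x i8]\n";
  return 0;
}

The updateImpl above refuses to grow an allocation (SizeOfBin >= *AllocationSize bails), so the rounding can only ever shrink the object.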
@@ -12143,6 +12996,18 @@ const char AAAddressSpace::ID = 0;
     return *AA;                                                                \
   }

+#define CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(POS, SUFFIX, CLASS)         \
+  CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) {      \
+    CLASS *AA = nullptr;                                                       \
+    switch (IRP.getPositionKind()) {                                           \
+      SWITCH_PK_CREATE(CLASS, IRP, POS, SUFFIX)                                \
+    default:                                                                   \
+      llvm_unreachable("Cannot create " #CLASS " for position other than "     \
+                       #POS " position!");                                     \
+    }                                                                          \
+    return *AA;                                                                \
+  }
+
 #define CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(CLASS)                      \
   CLASS &CLASS::createForPosition(const IRPosition &IRP, Attributor &A) {      \
     CLASS *AA = nullptr;                                                       \
@@ -12215,17 +13080,24 @@ CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUndef)
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFPClass)
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPointerInfo)
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAddressSpace)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAllocationInfo)

 CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify)
 CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead)
 CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree)
 CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects)
+CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_CALL_SITE, CallSite,
+                                           AAIndirectCallInfo)
+CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION(IRP_FLOAT, Floating,
+                                           AAGlobalValueInfo)
+
 CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack)
 CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior)
 CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonConvergent)
 CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIntraFnReachability)
 CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAInterFnReachability)
+CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AADenormalFPMath)

 CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)

@@ -12234,5 +13106,6 @@ CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior)
 #undef CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION
 #undef CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION
 #undef CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION
+#undef CREATE_ABSTRACT_ATTRIBUTE_FOR_ONE_POSITION
 #undef SWITCH_PK_CREATE
 #undef SWITCH_PK_INV
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
index 93d15f59a036..5cc8258a495a 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
@@ -85,7 +85,7 @@ void CrossDSOCFI::buildCFICheck(Module &M) {
   LLVMContext &Ctx = M.getContext();
   FunctionCallee C = M.getOrInsertFunction(
       "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx),
-      Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx));
+      PointerType::getUnqual(Ctx), PointerType::getUnqual(Ctx));
   Function *F = cast<Function>(C.getCallee());
   // Take over the existing function. The frontend emits a weak stub so that the
   // linker knows about the symbol; this pass replaces the function body.
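Both CFI runtime hooks in this file get the same treatment: under opaque pointers there is a single pointer type per address space, so the former Type::getInt8PtrTy(Ctx) parameters become unqualified pointers. A minimal sketch of the resulting declaration idiom, assuming a recent LLVM (declareCfiCheck is a hypothetical helper, not part of the patch):

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
using namespace llvm;

FunctionCallee declareCfiCheck(Module &M) {
  LLVMContext &Ctx = M.getContext();
  // void __cfi_check(i64, ptr, ptr) -- 'ptr' is the opaque pointer type, so
  // no element type (i8) needs to be named anymore.
  return M.getOrInsertFunction("__cfi_check", Type::getVoidTy(Ctx),
                               Type::getInt64Ty(Ctx),
                               PointerType::getUnqual(Ctx),
                               PointerType::getUnqual(Ctx));
}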
@@ -110,9 +110,9 @@ void CrossDSOCFI::buildCFICheck(Module &M) {

   BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F);
   IRBuilder<> IRBFail(TrapBB);
-  FunctionCallee CFICheckFailFn =
-      M.getOrInsertFunction("__cfi_check_fail", Type::getVoidTy(Ctx),
-                            Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx));
+  FunctionCallee CFICheckFailFn = M.getOrInsertFunction(
+      "__cfi_check_fail", Type::getVoidTy(Ctx), PointerType::getUnqual(Ctx),
+      PointerType::getUnqual(Ctx));
   IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr});
   IRBFail.CreateBr(ExitBB);

diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
index 01834015f3fd..4f65748c19e6 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
@@ -174,6 +174,7 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) {
   NF->setComdat(F.getComdat());
   F.getParent()->getFunctionList().insert(F.getIterator(), NF);
   NF->takeName(&F);
+  NF->IsNewDbgInfoFormat = F.IsNewDbgInfoFormat;

   // Loop over all the callers of the function, transforming the call sites
   // to pass in a smaller number of arguments into the new function.
@@ -248,7 +249,7 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) {
       NF->addMetadata(KindID, *Node);

   // Fix up any BlockAddresses that refer to the function.
-  F.replaceAllUsesWith(ConstantExpr::getBitCast(NF, F.getType()));
+  F.replaceAllUsesWith(NF);
   // Remove any remaining dead constant users so that NF does not
   // appear to be address-taken.
   NF->removeDeadConstantUsers();
@@ -877,6 +878,7 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) {
   // it again.
   F->getParent()->getFunctionList().insert(F->getIterator(), NF);
   NF->takeName(F);
+  NF->IsNewDbgInfoFormat = F->IsNewDbgInfoFormat;

   // Loop over all the callers of the function, transforming the call sites to
   // pass in a smaller number of arguments into the new function.
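Both DeadArgumentElimination hunks above follow one pattern whenever a replacement function is created: the debug-info format flag is copied over, and the old function is RAUW'd directly, since with opaque pointers F and NF both have type ptr and no ConstantExpr bitcast is needed. A condensed sketch of that pattern under those assumptions (replaceFunction is a hypothetical helper; the real pass also rewrites every call site):

#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
using namespace llvm;

void replaceFunction(Function &F, FunctionType *NFTy) {
  // Create the replacement next to the original and steal its name.
  Function *NF = Function::Create(NFTy, F.getLinkage(), F.getAddressSpace());
  F.getParent()->getFunctionList().insert(F.getIterator(), NF);
  NF->takeName(&F);
  // Keep the new function in the same debug-info format as the original.
  NF->IsNewDbgInfoFormat = F.IsNewDbgInfoFormat;
  // Opaque pointers: F and NF share the type 'ptr', so no cast is required.
  F.replaceAllUsesWith(NF);
}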
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp index fa56a5b564ae..48ef0772e800 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/EmbedBitcodePass.cpp @@ -7,8 +7,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/EmbedBitcodePass.h" -#include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/Bitcode/BitcodeWriterPass.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" #include "llvm/Support/ErrorHandling.h" @@ -16,10 +14,8 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/IPO/ThinLTOBitcodeWriter.h" -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" -#include <memory> #include <string> using namespace llvm; @@ -34,19 +30,9 @@ PreservedAnalyses EmbedBitcodePass::run(Module &M, ModuleAnalysisManager &AM) { report_fatal_error( "EmbedBitcode pass currently only supports ELF object format", /*gen_crash_diag=*/false); - - std::unique_ptr<Module> NewModule = CloneModule(M); - MPM.run(*NewModule, AM); - std::string Data; raw_string_ostream OS(Data); - if (IsThinLTO) - ThinLTOBitcodeWriterPass(OS, /*ThinLinkOS=*/nullptr).run(*NewModule, AM); - else - BitcodeWriterPass(OS, /*ShouldPreserveUseListOrder=*/false, EmitLTOSummary) - .run(*NewModule, AM); - + ThinLTOBitcodeWriterPass(OS, /*ThinLinkOS=*/nullptr).run(M, AM); embedBufferInModule(M, MemoryBufferRef(Data, "ModuleData"), ".llvm.lto"); - return PreservedAnalyses::all(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 74931e1032d1..9cf4e448c9b6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -11,38 +11,57 @@ #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LineIterator.h" +#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; #define DEBUG_TYPE "forceattrs" -static cl::list<std::string> - ForceAttributes("force-attribute", cl::Hidden, - cl::desc("Add an attribute to a function. This should be a " - "pair of 'function-name:attribute-name', for " - "example -force-attribute=foo:noinline. This " - "option can be specified multiple times.")); +static cl::list<std::string> ForceAttributes( + "force-attribute", cl::Hidden, + cl::desc( + "Add an attribute to a function. This can be a " + "pair of 'function-name:attribute-name', to apply an attribute to a " + "specific function. For " + "example -force-attribute=foo:noinline. Specifying only an attribute " + "will apply the attribute to every function in the module. This " + "option can be specified multiple times.")); static cl::list<std::string> ForceRemoveAttributes( "force-remove-attribute", cl::Hidden, - cl::desc("Remove an attribute from a function. This should be a " - "pair of 'function-name:attribute-name', for " - "example -force-remove-attribute=foo:noinline. This " + cl::desc("Remove an attribute from a function. This can be a " + "pair of 'function-name:attribute-name' to remove an attribute " + "from a specific function. For " + "example -force-remove-attribute=foo:noinline. 
Specifying only an "
+             "attribute will remove the attribute from all functions in the "
+             "module. This "
              "option can be specified multiple times."));

+static cl::opt<std::string> CSVFilePath(
+    "forceattrs-csv-path", cl::Hidden,
+    cl::desc(
+        "Path to CSV file containing lines of function names and attributes to "
+        "add to them in the form of `f1,attr1` or `f2,attr2=str`."));
+
 /// If F has any forced attributes given on the command line, add them.
 /// If F has any forced remove attributes given on the command line, remove
 /// them. When both force and force-remove are given to a function, the latter
 /// takes precedence.
 static void forceAttributes(Function &F) {
   auto ParseFunctionAndAttr = [&](StringRef S) {
-    auto Kind = Attribute::None;
-    auto KV = StringRef(S).split(':');
-    if (KV.first != F.getName())
-      return Kind;
-    Kind = Attribute::getAttrKindFromName(KV.second);
+    StringRef AttributeText;
+    if (S.contains(':')) {
+      auto KV = StringRef(S).split(':');
+      if (KV.first != F.getName())
+        return Attribute::None;
+      AttributeText = KV.second;
+    } else {
+      AttributeText = S;
+    }
+    auto Kind = Attribute::getAttrKindFromName(AttributeText);
     if (Kind == Attribute::None || !Attribute::canUseAsFnAttr(Kind)) {
-      LLVM_DEBUG(dbgs() << "ForcedAttribute: " << KV.second
+      LLVM_DEBUG(dbgs() << "ForcedAttribute: " << AttributeText
                         << " unknown or not a function attribute!\n");
     }
     return Kind;
@@ -69,12 +88,52 @@ static bool hasForceAttributes() {

 PreservedAnalyses ForceFunctionAttrsPass::run(Module &M,
                                               ModuleAnalysisManager &) {
-  if (!hasForceAttributes())
-    return PreservedAnalyses::all();
-
-  for (Function &F : M.functions())
-    forceAttributes(F);
-
-  // Just conservatively invalidate analyses, this isn't likely to be important.
-  return PreservedAnalyses::none();
+  bool Changed = false;
+  if (!CSVFilePath.empty()) {
+    auto BufferOrError = MemoryBuffer::getFileOrSTDIN(CSVFilePath);
+    if (!BufferOrError)
+      report_fatal_error("Cannot open CSV file.");
+    StringRef Buffer = BufferOrError.get()->getBuffer();
+    auto MemoryBuffer = MemoryBuffer::getMemBuffer(Buffer);
+    line_iterator It(*MemoryBuffer);
+    for (; !It.is_at_end(); ++It) {
+      auto SplitPair = It->split(',');
+      if (SplitPair.second.empty())
+        continue;
+      Function *Func = M.getFunction(SplitPair.first);
+      if (Func) {
+        if (Func->isDeclaration())
+          continue;
+        auto SecondSplitPair = SplitPair.second.split('=');
+        if (!SecondSplitPair.second.empty()) {
+          Func->addFnAttr(SecondSplitPair.first, SecondSplitPair.second);
+          Changed = true;
+        } else {
+          auto AttrKind = Attribute::getAttrKindFromName(SplitPair.second);
+          if (AttrKind != Attribute::None &&
+              Attribute::canUseAsFnAttr(AttrKind)) {
+            // TODO: There could be string attributes without a value, we
+            // should support those, too.
+            Func->addFnAttr(AttrKind);
+            Changed = true;
+          } else
+            errs() << "Cannot add " << SplitPair.second
+                   << " as an attribute name.\n";
+        }
+      } else {
+        errs() << "Function in CSV file at line " << It.line_number()
+               << " does not exist.\n";
+        // TODO: report_fatal_error at the end of the pass for missing
+        // functions.
+        continue;
+      }
+    }
+  }
+  if (hasForceAttributes()) {
+    for (Function &F : M.functions())
+      forceAttributes(F);
+    Changed = true;
+  }
+  // Just conservatively invalidate analyses if we've made any changes; this
+  // isn't likely to be important.
+  return Changed ?
PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 34299f9dbb23..7c277518b21d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -110,6 +110,39 @@ using SCCNodeSet = SmallSetVector<Function *, 8>; } // end anonymous namespace +static void addLocAccess(MemoryEffects &ME, const MemoryLocation &Loc, + ModRefInfo MR, AAResults &AAR) { + // Ignore accesses to known-invariant or local memory. + MR &= AAR.getModRefInfoMask(Loc, /*IgnoreLocal=*/true); + if (isNoModRef(MR)) + return; + + const Value *UO = getUnderlyingObject(Loc.Ptr); + assert(!isa<AllocaInst>(UO) && + "Should have been handled by getModRefInfoMask()"); + if (isa<Argument>(UO)) { + ME |= MemoryEffects::argMemOnly(MR); + return; + } + + // If it's not an identified object, it might be an argument. + if (!isIdentifiedObject(UO)) + ME |= MemoryEffects::argMemOnly(MR); + ME |= MemoryEffects(IRMemLocation::Other, MR); +} + +static void addArgLocs(MemoryEffects &ME, const CallBase *Call, + ModRefInfo ArgMR, AAResults &AAR) { + for (const Value *Arg : Call->args()) { + if (!Arg->getType()->isPtrOrPtrVectorTy()) + continue; + + addLocAccess(ME, + MemoryLocation::getBeforeOrAfter(Arg, Call->getAAMetadata()), + ArgMR, AAR); + } +} + /// Returns the memory access attribute for function F using AAR for AA results, /// where SCCNodes is the current SCC. /// @@ -118,54 +151,48 @@ using SCCNodeSet = SmallSetVector<Function *, 8>; /// result will be based only on AA results for the function declaration; it /// will be assumed that some other (perhaps less optimized) version of the /// function may be selected at link time. -static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, - AAResults &AAR, - const SCCNodeSet &SCCNodes) { +/// +/// The return value is split into two parts: Memory effects that always apply, +/// and additional memory effects that apply if any of the functions in the SCC +/// can access argmem. +static std::pair<MemoryEffects, MemoryEffects> +checkFunctionMemoryAccess(Function &F, bool ThisBody, AAResults &AAR, + const SCCNodeSet &SCCNodes) { MemoryEffects OrigME = AAR.getMemoryEffects(&F); if (OrigME.doesNotAccessMemory()) // Already perfect! - return OrigME; + return {OrigME, MemoryEffects::none()}; if (!ThisBody) - return OrigME; + return {OrigME, MemoryEffects::none()}; MemoryEffects ME = MemoryEffects::none(); + // Additional locations accessed if the SCC accesses argmem. + MemoryEffects RecursiveArgME = MemoryEffects::none(); + // Inalloca and preallocated arguments are always clobbered by the call. if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca) || F.getAttributes().hasAttrSomewhere(Attribute::Preallocated)) ME |= MemoryEffects::argMemOnly(ModRefInfo::ModRef); - auto AddLocAccess = [&](const MemoryLocation &Loc, ModRefInfo MR) { - // Ignore accesses to known-invariant or local memory. - MR &= AAR.getModRefInfoMask(Loc, /*IgnoreLocal=*/true); - if (isNoModRef(MR)) - return; - - const Value *UO = getUnderlyingObject(Loc.Ptr); - assert(!isa<AllocaInst>(UO) && - "Should have been handled by getModRefInfoMask()"); - if (isa<Argument>(UO)) { - ME |= MemoryEffects::argMemOnly(MR); - return; - } - - // If it's not an identified object, it might be an argument. 
- if (!isIdentifiedObject(UO)) - ME |= MemoryEffects::argMemOnly(MR); - ME |= MemoryEffects(IRMemLocation::Other, MR); - }; // Scan the function body for instructions that may read or write memory. for (Instruction &I : instructions(F)) { // Some instructions can be ignored even if they read or write memory. // Detect these now, skipping to the next instruction if one is found. if (auto *Call = dyn_cast<CallBase>(&I)) { - // Ignore calls to functions in the same SCC, as long as the call sites - // don't have operand bundles. Calls with operand bundles are allowed to - // have memory effects not described by the memory effects of the call - // target. + // We can optimistically ignore calls to functions in the same SCC, with + // two caveats: + // * Calls with operand bundles may have additional effects. + // * Argument memory accesses may imply additional effects depending on + // what the argument location is. if (!Call->hasOperandBundles() && Call->getCalledFunction() && - SCCNodes.count(Call->getCalledFunction())) + SCCNodes.count(Call->getCalledFunction())) { + // Keep track of which additional locations are accessed if the SCC + // turns out to access argmem. + addArgLocs(RecursiveArgME, Call, ModRefInfo::ModRef, AAR); continue; + } + MemoryEffects CallME = AAR.getMemoryEffects(Call); // If the call doesn't access memory, we're done. @@ -190,15 +217,8 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, // Check whether all pointer arguments point to local memory, and // ignore calls that only access local memory. ModRefInfo ArgMR = CallME.getModRef(IRMemLocation::ArgMem); - if (ArgMR != ModRefInfo::NoModRef) { - for (const Use &U : Call->args()) { - const Value *Arg = U; - if (!Arg->getType()->isPtrOrPtrVectorTy()) - continue; - - AddLocAccess(MemoryLocation::getBeforeOrAfter(Arg, I.getAAMetadata()), ArgMR); - } - } + if (ArgMR != ModRefInfo::NoModRef) + addArgLocs(ME, Call, ArgMR, AAR); continue; } @@ -222,15 +242,15 @@ static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, if (I.isVolatile()) ME |= MemoryEffects::inaccessibleMemOnly(MR); - AddLocAccess(*Loc, MR); + addLocAccess(ME, *Loc, MR, AAR); } - return OrigME & ME; + return {OrigME & ME, RecursiveArgME}; } MemoryEffects llvm::computeFunctionBodyMemoryAccess(Function &F, AAResults &AAR) { - return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}); + return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}).first; } /// Deduce readonly/readnone/writeonly attributes for the SCC. @@ -238,24 +258,37 @@ template <typename AARGetterT> static void addMemoryAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter, SmallSet<Function *, 8> &Changed) { MemoryEffects ME = MemoryEffects::none(); + MemoryEffects RecursiveArgME = MemoryEffects::none(); for (Function *F : SCCNodes) { // Call the callable parameter to look up AA results for this function. AAResults &AAR = AARGetter(*F); // Non-exact function definitions may not be selected at link time, and an // alternative version that writes to memory may be selected. See the // comment on GlobalValue::isDefinitionExact for more details. - ME |= checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes); + auto [FnME, FnRecursiveArgME] = + checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes); + ME |= FnME; + RecursiveArgME |= FnRecursiveArgME; // Reached bottom of the lattice, we will not be able to improve the result. 
    if (ME == MemoryEffects::unknown())
       return;
   }

+  // If the SCC accesses argmem, add recursive accesses resulting from that.
+  ModRefInfo ArgMR = ME.getModRef(IRMemLocation::ArgMem);
+  if (ArgMR != ModRefInfo::NoModRef)
+    ME |= RecursiveArgME & MemoryEffects(ArgMR);
+
   for (Function *F : SCCNodes) {
     MemoryEffects OldME = F->getMemoryEffects();
     MemoryEffects NewME = ME & OldME;
     if (NewME != OldME) {
       ++NumMemoryAttr;
       F->setMemoryEffects(NewME);
+      // Remove conflicting writable attributes.
+      if (!isModSet(NewME.getModRef(IRMemLocation::ArgMem)))
+        for (Argument &A : F->args())
+          A.removeAttr(Attribute::Writable);
       Changed.insert(F);
     }
   }
@@ -625,7 +658,15 @@ determinePointerAccessAttrs(Argument *A,
       // must be a data operand (e.g. argument or operand bundle)
       const unsigned UseIndex = CB.getDataOperandNo(U);

-      if (!CB.doesNotCapture(UseIndex)) {
+      // Some intrinsics (for instance ptrmask) do not capture their results,
+      // but return results that alias their pointer argument, and thus should
+      // be handled like GEP or addrspacecast above.
+      if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
+              &CB, /*MustPreserveNullness=*/false)) {
+        for (Use &UU : CB.uses())
+          if (Visited.insert(&UU).second)
+            Worklist.push_back(&UU);
+      } else if (!CB.doesNotCapture(UseIndex)) {
         if (!CB.onlyReadsMemory())
           // If the callee can save a copy into other memory, then simply
           // scanning uses of the call is insufficient.  We have no way
@@ -639,7 +680,8 @@ determinePointerAccessAttrs(Argument *A,
             Worklist.push_back(&UU);
       }

-      if (CB.doesNotAccessMemory())
+      ModRefInfo ArgMR = CB.getMemoryEffects().getModRef(IRMemLocation::ArgMem);
+      if (isNoModRef(ArgMR))
         continue;

       if (Function *F = CB.getCalledFunction())
@@ -654,9 +696,9 @@ determinePointerAccessAttrs(Argument *A,
       // invokes with operand bundles.
       if (CB.doesNotAccessMemory(UseIndex)) {
         /* nop */
-      } else if (CB.onlyReadsMemory() || CB.onlyReadsMemory(UseIndex)) {
+      } else if (!isModSet(ArgMR) || CB.onlyReadsMemory(UseIndex)) {
         IsRead = true;
-      } else if (CB.hasFnAttr(Attribute::WriteOnly) ||
+      } else if (!isRefSet(ArgMR) ||
                  CB.dataOperandHasImpliedAttr(UseIndex, Attribute::WriteOnly)) {
         IsWrite = true;
       } else {
@@ -810,6 +852,9 @@ static bool addAccessAttr(Argument *A, Attribute::AttrKind R) {
   A->removeAttr(Attribute::WriteOnly);
   A->removeAttr(Attribute::ReadOnly);
   A->removeAttr(Attribute::ReadNone);
+  // Remove conflicting writable attribute.
+  if (R == Attribute::ReadNone || R == Attribute::ReadOnly)
+    A->removeAttr(Attribute::Writable);
   A->addAttr(R);
   if (R == Attribute::ReadOnly)
     ++NumReadOnlyArg;
@@ -1720,7 +1765,8 @@ static SCCNodesResult createSCCNodeSet(ArrayRef<Function *> Functions) {

 template <typename AARGetterT>
 static SmallSet<Function *, 8>
-deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) {
+deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter,
+                       bool ArgAttrsOnly) {
   SCCNodesResult Nodes = createSCCNodeSet(Functions);

   // Bail if the SCC only contains optnone functions.
@@ -1728,6 +1774,10 @@ deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) {
     return {};

   SmallSet<Function *, 8> Changed;
+  if (ArgAttrsOnly) {
+    addArgumentAttrs(Nodes.SCCNodes, Changed);
+    return Changed;
+  }

   addArgumentReturnedAttrs(Nodes.SCCNodes, Changed);
   addMemoryAttrs(Nodes.SCCNodes, AARGetter, Changed);
@@ -1762,10 +1812,13 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C,
                                                   LazyCallGraph &CG,
                                                   CGSCCUpdateResult &) {
   // Skip non-recursive functions if requested.
+  // Only infer argument attributes for non-recursive functions, because
+  // it can affect optimization behavior in conjunction with noalias.
+  bool ArgAttrsOnly = false;
   if (C.size() == 1 && SkipNonRecursive) {
     LazyCallGraph::Node &N = *C.begin();
     if (!N->lookup(N))
-      return PreservedAnalyses::all();
+      ArgAttrsOnly = true;
   }

   FunctionAnalysisManager &FAM =
@@ -1782,7 +1835,8 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C,
     Functions.push_back(&N.getFunction());
   }

-  auto ChangedFunctions = deriveAttrsInPostOrder(Functions, AARGetter);
+  auto ChangedFunctions =
+      deriveAttrsInPostOrder(Functions, AARGetter, ArgAttrsOnly);
   if (ChangedFunctions.empty())
     return PreservedAnalyses::all();

@@ -1818,7 +1872,7 @@ void PostOrderFunctionAttrsPass::printPipeline(
   static_cast<PassInfoMixin<PostOrderFunctionAttrsPass> *>(this)->printPipeline(
       OS, MapClassName2PassName);
   if (SkipNonRecursive)
-    OS << "<skip-non-recursive>";
+    OS << "<skip-non-recursive-function-attrs>";
 }

 template <typename AARGetterT>
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp
index f635b14cd2a9..49b3f2b085e1 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp
@@ -16,7 +16,6 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/IR/AutoUpgrade.h"
@@ -38,6 +37,7 @@
 #include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FileSystem.h"
+#include "llvm/Support/JSON.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/IPO/Internalize.h"
@@ -139,6 +139,29 @@ static cl::opt<bool>
     ImportAllIndex("import-all-index",
                    cl::desc("Import all external functions in index."));

+/// Pass a workload description file - an example of a workload would be the
+/// functions executed to satisfy an RPC request. A workload is defined by a
+/// root function and the list of functions that are (frequently) needed to
+/// satisfy it. The module that defines the root will have all those functions
+/// imported. The file contains a JSON dictionary. The keys are root functions,
+/// the values are lists of functions to import in the module defining the
+/// root. It is assumed -funique-internal-linkage-names was used, thus ensuring
+/// function names are unique even for local linkage ones.
+static cl::opt<std::string> WorkloadDefinitions(
+    "thinlto-workload-def",
+    cl::desc("Pass a workload definition. This is a file containing a JSON "
+             "dictionary. The keys are root functions, the values are lists of "
+             "functions to import in the module defining the root. It is "
+             "assumed -funique-internal-linkage-names was used, to ensure "
+             "local linkage functions have unique names. For example: \n"
+             "{\n"
+             "  \"rootFunction_1\": [\"function_to_import_1\", "
+             "\"function_to_import_2\"], \n"
+             "  \"rootFunction_2\": [\"function_to_import_3\", "
+             "\"function_to_import_4\"] \n"
+             "}"),
+    cl::Hidden);
+
 // Lazily load a module from \p FileName in \p Context.
static std::unique_ptr<Module> loadFile(const std::string &FileName, LLVMContext &Context) { @@ -272,7 +295,7 @@ class GlobalsImporter final { function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> IsPrevailing; FunctionImporter::ImportMapTy &ImportList; - StringMap<FunctionImporter::ExportSetTy> *const ExportLists; + DenseMap<StringRef, FunctionImporter::ExportSetTy> *const ExportLists; bool shouldImportGlobal(const ValueInfo &VI) { const auto &GVS = DefinedGVSummaries.find(VI.getGUID()); @@ -357,7 +380,7 @@ public: function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> IsPrevailing, FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> *ExportLists) + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists) : Index(Index), DefinedGVSummaries(DefinedGVSummaries), IsPrevailing(IsPrevailing), ImportList(ImportList), ExportLists(ExportLists) {} @@ -370,6 +393,264 @@ public: } }; +static const char *getFailureName(FunctionImporter::ImportFailureReason Reason); + +/// Determine the list of imports and exports for each module. +class ModuleImportsManager { +protected: + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing; + const ModuleSummaryIndex &Index; + DenseMap<StringRef, FunctionImporter::ExportSetTy> *const ExportLists; + + ModuleImportsManager( + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing, + const ModuleSummaryIndex &Index, + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists = nullptr) + : IsPrevailing(IsPrevailing), Index(Index), ExportLists(ExportLists) {} + +public: + virtual ~ModuleImportsManager() = default; + + /// Given the list of globals defined in a module, compute the list of imports + /// as well as the list of "exports", i.e. the list of symbols referenced from + /// another module (that may require promotion). + virtual void + computeImportForModule(const GVSummaryMapTy &DefinedGVSummaries, + StringRef ModName, + FunctionImporter::ImportMapTy &ImportList); + + static std::unique_ptr<ModuleImportsManager> + create(function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing, + const ModuleSummaryIndex &Index, + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists = + nullptr); +}; + +/// A ModuleImportsManager that operates based on a workload definition (see +/// -thinlto-workload-def). For modules that do not define workload roots, it +/// applies the base ModuleImportsManager import policy. +class WorkloadImportsManager : public ModuleImportsManager { + // Keep a module name -> value infos to import association. We use it to + // determine if a module's import list should be done by the base + // ModuleImportsManager or by us. 
+  StringMap<DenseSet<ValueInfo>> Workloads;
+
+  void
+  computeImportForModule(const GVSummaryMapTy &DefinedGVSummaries,
+                         StringRef ModName,
+                         FunctionImporter::ImportMapTy &ImportList) override {
+    auto SetIter = Workloads.find(ModName);
+    if (SetIter == Workloads.end()) {
+      LLVM_DEBUG(dbgs() << "[Workload] " << ModName
+                        << " does not contain the root of any context.\n");
+      return ModuleImportsManager::computeImportForModule(DefinedGVSummaries,
+                                                          ModName, ImportList);
+    }
+    LLVM_DEBUG(dbgs() << "[Workload] " << ModName
+                      << " contains the root(s) of context(s).\n");
+
+    GlobalsImporter GVI(Index, DefinedGVSummaries, IsPrevailing, ImportList,
+                        ExportLists);
+    auto &ValueInfos = SetIter->second;
+    SmallVector<EdgeInfo, 128> GlobWorklist;
+    for (auto &VI : llvm::make_early_inc_range(ValueInfos)) {
+      auto It = DefinedGVSummaries.find(VI.getGUID());
+      if (It != DefinedGVSummaries.end() &&
+          IsPrevailing(VI.getGUID(), It->second)) {
+        LLVM_DEBUG(
+            dbgs() << "[Workload] " << VI.name()
+                   << " has the prevailing variant already in the module "
+                   << ModName << ". No need to import\n");
+        continue;
+      }
+      auto Candidates =
+          qualifyCalleeCandidates(Index, VI.getSummaryList(), ModName);
+
+      const GlobalValueSummary *GVS = nullptr;
+      auto PotentialCandidates = llvm::map_range(
+          llvm::make_filter_range(
+              Candidates,
+              [&](const auto &Candidate) {
+                LLVM_DEBUG(dbgs() << "[Workload] Candidate for " << VI.name()
+                                  << " from " << Candidate.second->modulePath()
+                                  << " ImportFailureReason: "
+                                  << getFailureName(Candidate.first) << "\n");
+                return Candidate.first ==
+                       FunctionImporter::ImportFailureReason::None;
+              }),
+          [](const auto &Candidate) { return Candidate.second; });
+      if (PotentialCandidates.empty()) {
+        LLVM_DEBUG(dbgs() << "[Workload] Not importing " << VI.name()
+                          << " because we can't find an eligible callee. "
+                             "Guid is: "
+                          << Function::getGUID(VI.name()) << "\n");
+        continue;
+      }
+      /// We prefer importing the prevailing candidate; if there is none, we
+      /// still pick the first available candidate. The reason we want to make
+      /// sure we do import the prevailing candidate is because the goal of
+      /// workload-awareness is to enable optimizations specializing the call
+      /// graph of that workload. Suppose a function is already defined in the
+      /// module, but it's not the prevailing variant. Suppose also we do not
+      /// inline it (in fact, if it were interposable, we can't inline it),
+      /// but we could specialize it to the workload in other ways. However,
+      /// the linker would drop it in favor of the prevailing copy.
+      /// Instead, by importing the prevailing variant (assuming also the use
+      /// of `-avail-extern-to-local`), we keep the specialization. We could
+      /// alternatively make the non-prevailing variant local, but the
+      /// prevailing one is also the one for which we would have previously
+      /// collected profiles, making it preferable.
+      auto PrevailingCandidates = llvm::make_filter_range(
+          PotentialCandidates, [&](const auto *Candidate) {
+            return IsPrevailing(VI.getGUID(), Candidate);
+          });
+      if (PrevailingCandidates.empty()) {
+        GVS = *PotentialCandidates.begin();
+        if (!llvm::hasSingleElement(PotentialCandidates) &&
+            GlobalValue::isLocalLinkage(GVS->linkage()))
+          LLVM_DEBUG(
+              dbgs()
+              << "[Workload] Found multiple non-prevailing candidates for "
+              << VI.name()
+              << ". This is unexpected. 
Are module paths passed to the " + "compiler unique for the modules passed to the linker?"); + // We could in theory have multiple (interposable) copies of a symbol + // when there is no prevailing candidate, if say the prevailing copy was + // in a native object being linked in. However, we should in theory be + // marking all of these non-prevailing IR copies dead in that case, in + // which case they won't be candidates. + assert(GVS->isLive()); + } else { + assert(llvm::hasSingleElement(PrevailingCandidates)); + GVS = *PrevailingCandidates.begin(); + } + + auto ExportingModule = GVS->modulePath(); + // We checked that for the prevailing case, but if we happen to have for + // example an internal that's defined in this module, it'd have no + // PrevailingCandidates. + if (ExportingModule == ModName) { + LLVM_DEBUG(dbgs() << "[Workload] Not importing " << VI.name() + << " because its defining module is the same as the " + "current module\n"); + continue; + } + LLVM_DEBUG(dbgs() << "[Workload][Including]" << VI.name() << " from " + << ExportingModule << " : " + << Function::getGUID(VI.name()) << "\n"); + ImportList[ExportingModule].insert(VI.getGUID()); + GVI.onImportingSummary(*GVS); + if (ExportLists) + (*ExportLists)[ExportingModule].insert(VI); + } + LLVM_DEBUG(dbgs() << "[Workload] Done\n"); + } + +public: + WorkloadImportsManager( + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing, + const ModuleSummaryIndex &Index, + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists) + : ModuleImportsManager(IsPrevailing, Index, ExportLists) { + // Since the workload def uses names, we need a quick lookup + // name->ValueInfo. + StringMap<ValueInfo> NameToValueInfo; + StringSet<> AmbiguousNames; + for (auto &I : Index) { + ValueInfo VI = Index.getValueInfo(I); + if (!NameToValueInfo.insert(std::make_pair(VI.name(), VI)).second) + LLVM_DEBUG(AmbiguousNames.insert(VI.name())); + } + auto DbgReportIfAmbiguous = [&](StringRef Name) { + LLVM_DEBUG(if (AmbiguousNames.count(Name) > 0) { + dbgs() << "[Workload] Function name " << Name + << " present in the workload definition is ambiguous. Consider " + "compiling with -funique-internal-linkage-names."; + }); + }; + std::error_code EC; + auto BufferOrErr = MemoryBuffer::getFileOrSTDIN(WorkloadDefinitions); + if (std::error_code EC = BufferOrErr.getError()) { + report_fatal_error("Failed to open context file"); + return; + } + auto Buffer = std::move(BufferOrErr.get()); + std::map<std::string, std::vector<std::string>> WorkloadDefs; + json::Path::Root NullRoot; + // The JSON is supposed to contain a dictionary matching the type of + // WorkloadDefs. 
For example: + // { + // "rootFunction_1": ["function_to_import_1", "function_to_import_2"], + // "rootFunction_2": ["function_to_import_3", "function_to_import_4"] + // } + auto Parsed = json::parse(Buffer->getBuffer()); + if (!Parsed) + report_fatal_error(Parsed.takeError()); + if (!json::fromJSON(*Parsed, WorkloadDefs, NullRoot)) + report_fatal_error("Invalid thinlto contextual profile format."); + for (const auto &Workload : WorkloadDefs) { + const auto &Root = Workload.first; + DbgReportIfAmbiguous(Root); + LLVM_DEBUG(dbgs() << "[Workload] Root: " << Root << "\n"); + const auto &AllCallees = Workload.second; + auto RootIt = NameToValueInfo.find(Root); + if (RootIt == NameToValueInfo.end()) { + LLVM_DEBUG(dbgs() << "[Workload] Root " << Root + << " not found in this linkage unit.\n"); + continue; + } + auto RootVI = RootIt->second; + if (RootVI.getSummaryList().size() != 1) { + LLVM_DEBUG(dbgs() << "[Workload] Root " << Root + << " should have exactly one summary, but has " + << RootVI.getSummaryList().size() << ". Skipping.\n"); + continue; + } + StringRef RootDefiningModule = + RootVI.getSummaryList().front()->modulePath(); + LLVM_DEBUG(dbgs() << "[Workload] Root defining module for " << Root + << " is : " << RootDefiningModule << "\n"); + auto &Set = Workloads[RootDefiningModule]; + for (const auto &Callee : AllCallees) { + LLVM_DEBUG(dbgs() << "[Workload] " << Callee << "\n"); + DbgReportIfAmbiguous(Callee); + auto ElemIt = NameToValueInfo.find(Callee); + if (ElemIt == NameToValueInfo.end()) { + LLVM_DEBUG(dbgs() << "[Workload] " << Callee << " not found\n"); + continue; + } + Set.insert(ElemIt->second); + } + LLVM_DEBUG({ + dbgs() << "[Workload] Root: " << Root << " we have " << Set.size() + << " distinct callees.\n"; + for (const auto &VI : Set) { + dbgs() << "[Workload] Root: " << Root + << " Would include: " << VI.getGUID() << "\n"; + } + }); + } + } +}; + +std::unique_ptr<ModuleImportsManager> ModuleImportsManager::create( + function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> + IsPrevailing, + const ModuleSummaryIndex &Index, + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists) { + if (WorkloadDefinitions.empty()) { + LLVM_DEBUG(dbgs() << "[Workload] Using the regular imports manager.\n"); + return std::unique_ptr<ModuleImportsManager>( + new ModuleImportsManager(IsPrevailing, Index, ExportLists)); + } + LLVM_DEBUG(dbgs() << "[Workload] Using the contextual imports manager.\n"); + return std::make_unique<WorkloadImportsManager>(IsPrevailing, Index, + ExportLists); +} + static const char * getFailureName(FunctionImporter::ImportFailureReason Reason) { switch (Reason) { @@ -403,7 +684,7 @@ static void computeImportForFunction( isPrevailing, SmallVectorImpl<EdgeInfo> &Worklist, GlobalsImporter &GVImporter, FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> *ExportLists, + DenseMap<StringRef, FunctionImporter::ExportSetTy> *ExportLists, FunctionImporter::ImportThresholdsTy &ImportThresholds) { GVImporter.onImportingSummary(Summary); static int ImportCount = 0; @@ -482,7 +763,7 @@ static void computeImportForFunction( continue; } - FunctionImporter::ImportFailureReason Reason; + FunctionImporter::ImportFailureReason Reason{}; CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold, Summary.modulePath(), Reason); if (!CalleeSummary) { @@ -567,20 +848,13 @@ static void computeImportForFunction( } } -/// Given the list of globals defined in a module, compute the list of imports -/// as well as the list of 
"exports", i.e. the list of symbols referenced from -/// another module (that may require promotion). -static void ComputeImportForModule( - const GVSummaryMapTy &DefinedGVSummaries, - function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> - isPrevailing, - const ModuleSummaryIndex &Index, StringRef ModName, - FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { +void ModuleImportsManager::computeImportForModule( + const GVSummaryMapTy &DefinedGVSummaries, StringRef ModName, + FunctionImporter::ImportMapTy &ImportList) { // Worklist contains the list of function imported in this module, for which // we will analyse the callees and may import further down the callgraph. SmallVector<EdgeInfo, 128> Worklist; - GlobalsImporter GVI(Index, DefinedGVSummaries, isPrevailing, ImportList, + GlobalsImporter GVI(Index, DefinedGVSummaries, IsPrevailing, ImportList, ExportLists); FunctionImporter::ImportThresholdsTy ImportThresholds; @@ -603,7 +877,7 @@ static void ComputeImportForModule( continue; LLVM_DEBUG(dbgs() << "Initialize import for " << VI << "\n"); computeImportForFunction(*FuncSummary, Index, ImportInstrLimit, - DefinedGVSummaries, isPrevailing, Worklist, GVI, + DefinedGVSummaries, IsPrevailing, Worklist, GVI, ImportList, ExportLists, ImportThresholds); } @@ -615,7 +889,7 @@ static void ComputeImportForModule( if (auto *FS = dyn_cast<FunctionSummary>(Summary)) computeImportForFunction(*FS, Index, Threshold, DefinedGVSummaries, - isPrevailing, Worklist, GVI, ImportList, + IsPrevailing, Worklist, GVI, ImportList, ExportLists, ImportThresholds); } @@ -671,10 +945,10 @@ static unsigned numGlobalVarSummaries(const ModuleSummaryIndex &Index, #endif #ifndef NDEBUG -static bool -checkVariableImport(const ModuleSummaryIndex &Index, - StringMap<FunctionImporter::ImportMapTy> &ImportLists, - StringMap<FunctionImporter::ExportSetTy> &ExportLists) { +static bool checkVariableImport( + const ModuleSummaryIndex &Index, + DenseMap<StringRef, FunctionImporter::ImportMapTy> &ImportLists, + DenseMap<StringRef, FunctionImporter::ExportSetTy> &ExportLists) { DenseSet<GlobalValue::GUID> FlattenedImports; @@ -702,7 +976,7 @@ checkVariableImport(const ModuleSummaryIndex &Index, for (auto &ExportPerModule : ExportLists) for (auto &VI : ExportPerModule.second) if (!FlattenedImports.count(VI.getGUID()) && - IsReadOrWriteOnlyVarNeedingImporting(ExportPerModule.first(), VI)) + IsReadOrWriteOnlyVarNeedingImporting(ExportPerModule.first, VI)) return false; return true; @@ -712,19 +986,19 @@ checkVariableImport(const ModuleSummaryIndex &Index, /// Compute all the import and export for every module using the Index. void llvm::ComputeCrossModuleImport( const ModuleSummaryIndex &Index, - const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, + const DenseMap<StringRef, GVSummaryMapTy> &ModuleToDefinedGVSummaries, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> isPrevailing, - StringMap<FunctionImporter::ImportMapTy> &ImportLists, - StringMap<FunctionImporter::ExportSetTy> &ExportLists) { + DenseMap<StringRef, FunctionImporter::ImportMapTy> &ImportLists, + DenseMap<StringRef, FunctionImporter::ExportSetTy> &ExportLists) { + auto MIS = ModuleImportsManager::create(isPrevailing, Index, &ExportLists); // For each module that has function defined, compute the import/export lists. 
for (const auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) { - auto &ImportList = ImportLists[DefinedGVSummaries.first()]; + auto &ImportList = ImportLists[DefinedGVSummaries.first]; LLVM_DEBUG(dbgs() << "Computing import for Module '" - << DefinedGVSummaries.first() << "'\n"); - ComputeImportForModule(DefinedGVSummaries.second, isPrevailing, Index, - DefinedGVSummaries.first(), ImportList, - &ExportLists); + << DefinedGVSummaries.first << "'\n"); + MIS->computeImportForModule(DefinedGVSummaries.second, + DefinedGVSummaries.first, ImportList); } // When computing imports we only added the variables and functions being @@ -735,7 +1009,7 @@ void llvm::ComputeCrossModuleImport( for (auto &ELI : ExportLists) { FunctionImporter::ExportSetTy NewExports; const auto &DefinedGVSummaries = - ModuleToDefinedGVSummaries.lookup(ELI.first()); + ModuleToDefinedGVSummaries.lookup(ELI.first); for (auto &EI : ELI.second) { // Find the copy defined in the exporting module so that we can mark the // values it references in that specific definition as exported. @@ -783,7 +1057,7 @@ void llvm::ComputeCrossModuleImport( LLVM_DEBUG(dbgs() << "Import/Export lists for " << ImportLists.size() << " modules:\n"); for (auto &ModuleImports : ImportLists) { - auto ModName = ModuleImports.first(); + auto ModName = ModuleImports.first; auto &Exports = ExportLists[ModName]; unsigned NumGVS = numGlobalVarSummaries(Index, Exports); LLVM_DEBUG(dbgs() << "* Module " << ModName << " exports " @@ -791,7 +1065,7 @@ void llvm::ComputeCrossModuleImport( << " vars. Imports from " << ModuleImports.second.size() << " modules.\n"); for (auto &Src : ModuleImports.second) { - auto SrcModName = Src.first(); + auto SrcModName = Src.first; unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second); LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod << " functions imported from " << SrcModName << "\n"); @@ -809,7 +1083,7 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index, LLVM_DEBUG(dbgs() << "* Module " << ModulePath << " imports from " << ImportList.size() << " modules.\n"); for (auto &Src : ImportList) { - auto SrcModName = Src.first(); + auto SrcModName = Src.first; unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second); LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod << " functions imported from " << SrcModName << "\n"); @@ -819,8 +1093,15 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index, } #endif -/// Compute all the imports for the given module in the Index. -void llvm::ComputeCrossModuleImportForModule( +/// Compute all the imports for the given module using the Index. +/// +/// \p isPrevailing is a callback that will be called with a global value's GUID +/// and summary and should return whether the module corresponding to the +/// summary contains the linker-prevailing copy of that value. +/// +/// \p ImportList will be populated with a map that can be passed to +/// FunctionImporter::importFunctions() above (see description there). +static void ComputeCrossModuleImportForModuleForTest( StringRef ModulePath, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> isPrevailing, @@ -833,17 +1114,20 @@ void llvm::ComputeCrossModuleImportForModule( // Compute the import list for this module. 
LLVM_DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n"); - ComputeImportForModule(FunctionSummaryMap, isPrevailing, Index, ModulePath, - ImportList); + auto MIS = ModuleImportsManager::create(isPrevailing, Index); + MIS->computeImportForModule(FunctionSummaryMap, ModulePath, ImportList); #ifndef NDEBUG dumpImportListForModule(Index, ModulePath, ImportList); #endif } -// Mark all external summaries in Index for import into the given module. -// Used for distributed builds using a distributed index. -void llvm::ComputeCrossModuleImportForModuleFromIndex( +/// Mark all external summaries in \p Index for import into the given module. +/// Used for testing the case of distributed builds using a distributed index. +/// +/// \p ImportList will be populated with a map that can be passed to +/// FunctionImporter::importFunctions() above (see description there). +static void ComputeCrossModuleImportForModuleFromIndexForTest( StringRef ModulePath, const ModuleSummaryIndex &Index, FunctionImporter::ImportMapTy &ImportList) { for (const auto &GlobalList : Index) { @@ -1041,7 +1325,7 @@ void llvm::computeDeadSymbolsWithConstProp( /// \p ModulePath. void llvm::gatherImportedSummariesForModule( StringRef ModulePath, - const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, + const DenseMap<StringRef, GVSummaryMapTy> &ModuleToDefinedGVSummaries, const FunctionImporter::ImportMapTy &ImportList, std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) { // Include all summaries from the importing module. @@ -1049,10 +1333,9 @@ void llvm::gatherImportedSummariesForModule( ModuleToDefinedGVSummaries.lookup(ModulePath); // Include summaries for imports. for (const auto &ILI : ImportList) { - auto &SummariesForIndex = - ModuleToSummariesForIndex[std::string(ILI.first())]; + auto &SummariesForIndex = ModuleToSummariesForIndex[std::string(ILI.first)]; const auto &DefinedGVSummaries = - ModuleToDefinedGVSummaries.lookup(ILI.first()); + ModuleToDefinedGVSummaries.lookup(ILI.first); for (const auto &GI : ILI.second) { const auto &DS = DefinedGVSummaries.find(GI); assert(DS != DefinedGVSummaries.end() && @@ -1298,7 +1581,7 @@ static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) { // ensure all uses of alias instead use the new clone (casted if necessary). NewFn->setLinkage(GA->getLinkage()); NewFn->setVisibility(GA->getVisibility()); - GA->replaceAllUsesWith(ConstantExpr::getBitCast(NewFn, GA->getType())); + GA->replaceAllUsesWith(NewFn); NewFn->takeName(GA); return NewFn; } @@ -1327,7 +1610,7 @@ Expected<bool> FunctionImporter::importFunctions( // Do the actual import of functions now, one Module at a time std::set<StringRef> ModuleNameOrderedList; for (const auto &FunctionsToImportPerModule : ImportList) { - ModuleNameOrderedList.insert(FunctionsToImportPerModule.first()); + ModuleNameOrderedList.insert(FunctionsToImportPerModule.first); } for (const auto &Name : ModuleNameOrderedList) { // Get the module for the import @@ -1461,7 +1744,7 @@ Expected<bool> FunctionImporter::importFunctions( return ImportedCount; } -static bool doImportingForModule( +static bool doImportingForModuleForTest( Module &M, function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)> isPrevailing) { if (SummaryFile.empty()) @@ -1481,11 +1764,11 @@ static bool doImportingForModule( // when testing distributed backend handling via the opt tool, when // we have distributed indexes containing exactly the summaries to import. 
if (ImportAllIndex) - ComputeCrossModuleImportForModuleFromIndex(M.getModuleIdentifier(), *Index, - ImportList); + ComputeCrossModuleImportForModuleFromIndexForTest(M.getModuleIdentifier(), + *Index, ImportList); else - ComputeCrossModuleImportForModule(M.getModuleIdentifier(), isPrevailing, - *Index, ImportList); + ComputeCrossModuleImportForModuleForTest(M.getModuleIdentifier(), + isPrevailing, *Index, ImportList); // Conservatively mark all internal values as promoted. This interface is // only used when doing importing via the function importing pass. The pass @@ -1533,7 +1816,7 @@ PreservedAnalyses FunctionImportPass::run(Module &M, auto isPrevailing = [](GlobalValue::GUID, const GlobalValueSummary *) { return true; }; - if (!doImportingForModule(M, isPrevailing)) + if (!doImportingForModuleForTest(M, isPrevailing)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index ac5dbc7cfb2a..a4c12006ee24 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -5,45 +5,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// This specialises functions with constant parameters. Constant parameters -// like function pointers and constant globals are propagated to the callee by -// specializing the function. The main benefit of this pass at the moment is -// that indirect calls are transformed into direct calls, which provides inline -// opportunities that the inliner would not have been able to achieve. That's -// why function specialisation is run before the inliner in the optimisation -// pipeline; that is by design. Otherwise, we would only benefit from constant -// passing, which is a valid use-case too, but hasn't been explored much in -// terms of performance uplifts, cost-model and compile-time impact. -// -// Current limitations: -// - It does not yet handle integer ranges. We do support "literal constants", -// but that's off by default under an option. -// - The cost-model could be further looked into (it mainly focuses on inlining -// benefits), -// -// Ideas: -// - With a function specialization attribute for arguments, we could have -// a direct way to steer function specialization, avoiding the cost-model, -// and thus control compile-times / code-size. -// -// Todos: -// - Specializing recursive functions relies on running the transformation a -// number of times, which is controlled by option -// `func-specialization-max-iters`. Thus, increasing this value and the -// number of iterations, will linearly increase the number of times recursive -// functions get specialized, see also the discussion in -// https://reviews.llvm.org/D106426 for details. Perhaps there is a -// compile-time friendlier way to control/limit the number of specialisations -// for recursive functions. -// - Don't transform the function if function specialization does not trigger; -// the SCCPSolver may make IR changes. 
-//
-// References:
-// - 2021 LLVM Dev Mtg “Introducing function specialisation, and can we enable
-// it by default?”, https://www.youtube.com/watch?v=zJiCjeXgV5Q
-//
-//===----------------------------------------------------------------------===//

 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
 #include "llvm/ADT/Statistic.h"
@@ -78,11 +39,47 @@ static cl::opt<unsigned> MaxClones(
     "The maximum number of clones allowed for a single function "
     "specialization"));

+static cl::opt<unsigned>
+    MaxDiscoveryIterations("funcspec-max-discovery-iterations", cl::init(100),
+                           cl::Hidden,
+                           cl::desc("The maximum number of iterations allowed "
+                                    "when searching for transitive "
+                                    "phis"));
+
+static cl::opt<unsigned> MaxIncomingPhiValues(
+    "funcspec-max-incoming-phi-values", cl::init(8), cl::Hidden,
+    cl::desc("The maximum number of incoming values a PHI node can have to be "
+             "considered during the specialization bonus estimation"));
+
+static cl::opt<unsigned> MaxBlockPredecessors(
+    "funcspec-max-block-predecessors", cl::init(2), cl::Hidden, cl::desc(
+    "The maximum number of predecessors a basic block can have to be "
+    "considered during the estimation of dead code"));
+
 static cl::opt<unsigned> MinFunctionSize(
-    "funcspec-min-function-size", cl::init(100), cl::Hidden, cl::desc(
+    "funcspec-min-function-size", cl::init(300), cl::Hidden, cl::desc(
    "Don't specialize functions that have fewer than this number of "
    "instructions"));

+static cl::opt<unsigned> MaxCodeSizeGrowth(
+    "funcspec-max-codesize-growth", cl::init(3), cl::Hidden, cl::desc(
+    "Maximum codesize growth allowed per function"));
+
+static cl::opt<unsigned> MinCodeSizeSavings(
+    "funcspec-min-codesize-savings", cl::init(20), cl::Hidden, cl::desc(
+    "Reject specializations whose codesize savings are less than this "
+    "much percent of the original function size"));
+
+static cl::opt<unsigned> MinLatencySavings(
+    "funcspec-min-latency-savings", cl::init(40), cl::Hidden,
+    cl::desc("Reject specializations whose latency savings are less than this "
+             "much percent of the original function size"));
+
+static cl::opt<unsigned> MinInliningBonus(
+    "funcspec-min-inlining-bonus", cl::init(300), cl::Hidden, cl::desc(
+    "Reject specializations whose inlining bonus is less than this "
+    "much percent of the original function size"));
+
 static cl::opt<bool> SpecializeOnAddress(
     "funcspec-on-address", cl::init(false), cl::Hidden, cl::desc(
     "Enable function specialization on the address of global values"));
@@ -96,26 +93,33 @@ static cl::opt<bool> SpecializeLiteralConstant(
     "Enable specialization of functions that take a literal constant as an "
     "argument"));

-// Estimates the instruction cost of all the basic blocks in \p WorkList.
-// The successors of such blocks are added to the list as long as they are
-// executable and they have a unique predecessor. \p WorkList represents
-// the basic blocks of a specialization which become dead once we replace
-// instructions that are known to be constants. The aim here is to estimate
-// the combination of size and latency savings in comparison to the non
-// specialized version of the function.
-static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
-                                ConstMap &KnownConstants, SCCPSolver &Solver,
-                                BlockFrequencyInfo &BFI,
-                                TargetTransformInfo &TTI) {
-  Cost Bonus = 0;
+bool InstCostVisitor::canEliminateSuccessor(BasicBlock *BB, BasicBlock *Succ,
+                                            DenseSet<BasicBlock *> &DeadBlocks) {
+  unsigned I = 0;
+  return all_of(predecessors(Succ),
+                [&I, BB, Succ, &DeadBlocks] (BasicBlock *Pred) {
+    return I++ < MaxBlockPredecessors &&
+           (Pred == BB || Pred == Succ || DeadBlocks.contains(Pred));
+  });
+}

+// Estimates the codesize savings due to dead code after constant propagation.
+// \p WorkList represents the basic blocks of a specialization which will
+// eventually become dead once we replace instructions that are known to be
+// constants. The successors of such blocks are added to the list as long as
+// the \p Solver found they were executable prior to specialization, and only
+// if all their predecessors are dead.
+Cost InstCostVisitor::estimateBasicBlocks(
+    SmallVectorImpl<BasicBlock *> &WorkList) {
+  Cost CodeSize = 0;
   // Accumulate the instruction cost of each basic block weighted by frequency.
   while (!WorkList.empty()) {
     BasicBlock *BB = WorkList.pop_back_val();

-    uint64_t Weight = BFI.getBlockFreq(BB).getFrequency() /
-                      BFI.getEntryFreq();
-    if (!Weight)
+    // These blocks are considered dead as far as the InstCostVisitor
+    // is concerned. They haven't been proven dead yet by the Solver,
+    // but may become so if we propagate the specialization arguments.
+    if (!DeadBlocks.insert(BB).second)
       continue;

     for (Instruction &I : *BB) {
@@ -127,67 +131,105 @@ static Cost estimateBasicBlocks(SmallVectorImpl<BasicBlock *> &WorkList,
       if (KnownConstants.contains(&I))
         continue;

-      Bonus += Weight *
-          TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
+      Cost C = TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);

-      LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus
-                        << " after user " << I << "\n");
+      LLVM_DEBUG(dbgs() << "FnSpecialization: CodeSize " << C
+                        << " for user " << I << "\n");
+      CodeSize += C;
     }

     // Keep adding dead successors to the list as long as they are
-    // executable and they have a unique predecessor.
+    // executable and only reachable from dead blocks.
     for (BasicBlock *SuccBB : successors(BB))
-      if (Solver.isBlockExecutable(SuccBB) &&
-          SuccBB->getUniquePredecessor() == BB)
+      if (isBlockExecutable(SuccBB) &&
+          canEliminateSuccessor(BB, SuccBB, DeadBlocks))
         WorkList.push_back(SuccBB);
   }
-  return Bonus;
+  return CodeSize;
 }

 static Constant *findConstantFor(Value *V, ConstMap &KnownConstants) {
   if (auto *C = dyn_cast<Constant>(V))
     return C;
-  if (auto It = KnownConstants.find(V); It != KnownConstants.end())
-    return It->second;
-  return nullptr;
+  return KnownConstants.lookup(V);
 }

-Cost InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) {
-  // Cache the iterator before visiting.
-  LastVisited = KnownConstants.insert({Use, C}).first;
+Bonus InstCostVisitor::getBonusFromPendingPHIs() {
+  Bonus B;
+  while (!PendingPHIs.empty()) {
+    Instruction *Phi = PendingPHIs.pop_back_val();
+    // The pending PHIs could have been proven dead by now.
+    if (isBlockExecutable(Phi->getParent()))
+      B += getUserBonus(Phi);
+  }
+  return B;
+}
+
+/// Compute a bonus for replacing argument \p A with constant \p C.
+Bonus InstCostVisitor::getSpecializationBonus(Argument *A, Constant *C) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " + << C->getNameOrAsOperand() << "\n"); + Bonus B; + for (auto *U : A->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + if (isBlockExecutable(UI->getParent())) + B += getUserBonus(UI, A, C); - if (auto *I = dyn_cast<SwitchInst>(User)) - return estimateSwitchInst(*I); + LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated bonus {CodeSize = " + << B.CodeSize << ", Latency = " << B.Latency + << "} for argument " << *A << "\n"); + return B; +} - if (auto *I = dyn_cast<BranchInst>(User)) - return estimateBranchInst(*I); +Bonus InstCostVisitor::getUserBonus(Instruction *User, Value *Use, Constant *C) { + // We have already propagated a constant for this user. + if (KnownConstants.contains(User)) + return {0, 0}; - C = visit(*User); - if (!C) - return 0; + // Cache the iterator before visiting. + LastVisited = Use ? KnownConstants.insert({Use, C}).first + : KnownConstants.end(); + + Cost CodeSize = 0; + if (auto *I = dyn_cast<SwitchInst>(User)) { + CodeSize = estimateSwitchInst(*I); + } else if (auto *I = dyn_cast<BranchInst>(User)) { + CodeSize = estimateBranchInst(*I); + } else { + C = visit(*User); + if (!C) + return {0, 0}; + } + // Even though it doesn't make sense to bind switch and branch instructions + // with a constant, unlike any other instruction type, it prevents estimating + // their bonus multiple times. KnownConstants.insert({User, C}); + CodeSize += TTI.getInstructionCost(User, TargetTransformInfo::TCK_CodeSize); + uint64_t Weight = BFI.getBlockFreq(User->getParent()).getFrequency() / - BFI.getEntryFreq(); - if (!Weight) - return 0; + BFI.getEntryFreq().getFrequency(); - Cost Bonus = Weight * - TTI.getInstructionCost(User, TargetTransformInfo::TCK_SizeAndLatency); + Cost Latency = Weight * + TTI.getInstructionCost(User, TargetTransformInfo::TCK_Latency); - LLVM_DEBUG(dbgs() << "FnSpecialization: Bonus " << Bonus - << " for user " << *User << "\n"); + LLVM_DEBUG(dbgs() << "FnSpecialization: {CodeSize = " << CodeSize + << ", Latency = " << Latency << "} for user " + << *User << "\n"); + Bonus B(CodeSize, Latency); for (auto *U : User->users()) if (auto *UI = dyn_cast<Instruction>(U)) - if (Solver.isBlockExecutable(UI->getParent())) - Bonus += getUserBonus(UI, User, C); + if (UI != User && isBlockExecutable(UI->getParent())) + B += getUserBonus(UI, User, C); - return Bonus; + return B; } Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (I.getCondition() != LastVisited->first) return 0; @@ -202,16 +244,17 @@ Cost InstCostVisitor::estimateSwitchInst(SwitchInst &I) { SmallVector<BasicBlock *> WorkList; for (const auto &Case : I.cases()) { BasicBlock *BB = Case.getCaseSuccessor(); - if (BB == Succ || !Solver.isBlockExecutable(BB) || - BB->getUniquePredecessor() != I.getParent()) - continue; - WorkList.push_back(BB); + if (BB != Succ && isBlockExecutable(BB) && + canEliminateSuccessor(I.getParent(), BB, DeadBlocks)) + WorkList.push_back(BB); } - return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI); + return estimateBasicBlocks(WorkList); } Cost InstCostVisitor::estimateBranchInst(BranchInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (I.getCondition() != LastVisited->first) return 0; @@ -219,14 +262,115 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) { // Initialize the worklist with the 
dead successor as long as // it is executable and has a unique predecessor. SmallVector<BasicBlock *> WorkList; - if (Solver.isBlockExecutable(Succ) && - Succ->getUniquePredecessor() == I.getParent()) + if (isBlockExecutable(Succ) && + canEliminateSuccessor(I.getParent(), Succ, DeadBlocks)) WorkList.push_back(Succ); - return estimateBasicBlocks(WorkList, KnownConstants, Solver, BFI, TTI); + return estimateBasicBlocks(WorkList); +} + +bool InstCostVisitor::discoverTransitivelyIncomingValues( + Constant *Const, PHINode *Root, DenseSet<PHINode *> &TransitivePHIs) { + + SmallVector<PHINode *, 64> WorkList; + WorkList.push_back(Root); + unsigned Iter = 0; + + while (!WorkList.empty()) { + PHINode *PN = WorkList.pop_back_val(); + + if (++Iter > MaxDiscoveryIterations || + PN->getNumIncomingValues() > MaxIncomingPhiValues) + return false; + + if (!TransitivePHIs.insert(PN).second) + continue; + + for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) { + Value *V = PN->getIncomingValue(I); + + // Disregard self-references and dead incoming values. + if (auto *Inst = dyn_cast<Instruction>(V)) + if (Inst == PN || DeadBlocks.contains(PN->getIncomingBlock(I))) + continue; + + if (Constant *C = findConstantFor(V, KnownConstants)) { + // Not all incoming values are the same constant. Bail immediately. + if (C != Const) + return false; + continue; + } + + if (auto *Phi = dyn_cast<PHINode>(V)) { + WorkList.push_back(Phi); + continue; + } + + // We can't reason about anything else. + return false; + } + } + return true; +} + +Constant *InstCostVisitor::visitPHINode(PHINode &I) { + if (I.getNumIncomingValues() > MaxIncomingPhiValues) + return nullptr; + + bool Inserted = VisitedPHIs.insert(&I).second; + Constant *Const = nullptr; + bool HaveSeenIncomingPHI = false; + + for (unsigned Idx = 0, E = I.getNumIncomingValues(); Idx != E; ++Idx) { + Value *V = I.getIncomingValue(Idx); + + // Disregard self-references and dead incoming values. + if (auto *Inst = dyn_cast<Instruction>(V)) + if (Inst == &I || DeadBlocks.contains(I.getIncomingBlock(Idx))) + continue; + + if (Constant *C = findConstantFor(V, KnownConstants)) { + if (!Const) + Const = C; + // Not all incoming values are the same constant. Bail immediately. + if (C != Const) + return nullptr; + continue; + } + + if (Inserted) { + // First time we are seeing this phi. We will retry later, after + // all the constant arguments have been propagated. Bail for now. + PendingPHIs.push_back(&I); + return nullptr; + } + + if (isa<PHINode>(V)) { + // Perhaps it is a Transitive Phi. We will confirm later. + HaveSeenIncomingPHI = true; + continue; + } + + // We can't reason about anything else. 
+ return nullptr; + } + + if (!Const) + return nullptr; + + if (!HaveSeenIncomingPHI) + return Const; + + DenseSet<PHINode *> TransitivePHIs; + if (!discoverTransitivelyIncomingValues(Const, &I, TransitivePHIs)) + return nullptr; + + return Const; } Constant *InstCostVisitor::visitFreezeInst(FreezeInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (isGuaranteedNotToBeUndefOrPoison(LastVisited->second)) return LastVisited->second; return nullptr; @@ -253,6 +397,8 @@ Constant *InstCostVisitor::visitCallBase(CallBase &I) { } Constant *InstCostVisitor::visitLoadInst(LoadInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (isa<ConstantPointerNull>(LastVisited->second)) return nullptr; return ConstantFoldLoadFromConstPtr(LastVisited->second, I.getType(), DL); @@ -275,6 +421,8 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { } Constant *InstCostVisitor::visitSelectInst(SelectInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + if (I.getCondition() != LastVisited->first) return nullptr; @@ -290,6 +438,8 @@ Constant *InstCostVisitor::visitCastInst(CastInst &I) { } Constant *InstCostVisitor::visitCmpInst(CmpInst &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + bool Swap = I.getOperand(1) == LastVisited->first; Value *V = Swap ? I.getOperand(0) : I.getOperand(1); Constant *Other = findConstantFor(V, KnownConstants); @@ -303,10 +453,14 @@ Constant *InstCostVisitor::visitCmpInst(CmpInst &I) { } Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + return ConstantFoldUnaryOpOperand(I.getOpcode(), LastVisited->second, DL); } Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) { + assert(LastVisited != KnownConstants.end() && "Invalid iterator!"); + bool Swap = I.getOperand(1) == LastVisited->first; Value *V = Swap ? I.getOperand(0) : I.getOperand(1); Constant *Other = findConstantFor(V, KnownConstants); @@ -413,10 +567,7 @@ void FunctionSpecializer::promoteConstantStackValues(Function *F) { Value *GV = new GlobalVariable(M, ConstVal->getType(), true, GlobalValue::InternalLinkage, ConstVal, - "funcspec.arg"); - if (ArgOpType != ConstVal->getType()) - GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType); - + "specialized.arg." + Twine(++NGlobals)); Call->setArgOperand(Idx, GV); } } @@ -506,13 +657,18 @@ bool FunctionSpecializer::run() { if (!Inserted && !Metrics.isRecursive && !SpecializeLiteralConstant) continue; + int64_t Sz = *Metrics.NumInsts.getValue(); + assert(Sz > 0 && "CodeSize should be positive"); + // It is safe to down cast from int64_t, NumInsts is always positive. + unsigned FuncSize = static_cast<unsigned>(Sz); + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for " - << F.getName() << " is " << Metrics.NumInsts << "\n"); + << F.getName() << " is " << FuncSize << "\n"); if (Inserted && Metrics.isRecursive) promoteConstantStackValues(&F); - if (!findSpecializations(&F, Metrics.NumInsts, AllSpecs, SM)) { + if (!findSpecializations(&F, FuncSize, AllSpecs, SM)) { LLVM_DEBUG( dbgs() << "FnSpecialization: No possible specializations found for " << F.getName() << "\n"); @@ -640,14 +796,15 @@ void FunctionSpecializer::removeDeadFunctions() { /// Clone the function \p F and remove the ssa_copy intrinsics added by /// the SCCPSolver in the cloned version. 
-static Function *cloneCandidateFunction(Function *F) { +static Function *cloneCandidateFunction(Function *F, unsigned NSpecs) { ValueToValueMapTy Mappings; Function *Clone = CloneFunction(F, Mappings); + Clone->setName(F->getName() + ".specialized." + Twine(NSpecs)); removeSSACopy(*Clone); return Clone; } -bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost, +bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize, SmallVectorImpl<Spec> &AllSpecs, SpecMap &SM) { // A mapping from a specialisation signature to the index of the respective @@ -713,16 +870,48 @@ bool FunctionSpecializer::findSpecializations(Function *F, Cost SpecCost, AllSpecs[Index].CallSites.push_back(&CS); } else { // Calculate the specialisation gain. - Cost Score = 0 - SpecCost; + Bonus B; + unsigned Score = 0; InstCostVisitor Visitor = getInstCostVisitorFor(F); - for (ArgInfo &A : S.Args) - Score += getSpecializationBonus(A.Formal, A.Actual, Visitor); + for (ArgInfo &A : S.Args) { + B += Visitor.getSpecializationBonus(A.Formal, A.Actual); + Score += getInliningBonus(A.Formal, A.Actual); + } + B += Visitor.getBonusFromPendingPHIs(); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization bonus {CodeSize = " + << B.CodeSize << ", Latency = " << B.Latency + << ", Inlining = " << Score << "}\n"); + + FunctionGrowth[F] += FuncSize - B.CodeSize; + + auto IsProfitable = [](Bonus &B, unsigned Score, unsigned FuncSize, + unsigned FuncGrowth) -> bool { + // No check required. + if (ForceSpecialization) + return true; + // Minimum inlining bonus. + if (Score > MinInliningBonus * FuncSize / 100) + return true; + // Minimum codesize savings. + if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100) + return false; + // Minimum latency savings. + if (B.Latency < MinLatencySavings * FuncSize / 100) + return false; + // Maximum codesize growth. + if (FuncGrowth / FuncSize > MaxCodeSizeGrowth) + return false; + return true; + }; // Discard unprofitable specialisations. - if (!ForceSpecialization && Score <= 0) + if (!IsProfitable(B, Score, FuncSize, FunctionGrowth[F])) continue; // Create a new specialisation entry. + Score += std::max(B.CodeSize, B.Latency); auto &Spec = AllSpecs.emplace_back(F, S, Score); if (CS.getFunction() != F) Spec.CallSites.push_back(&CS); @@ -768,7 +957,7 @@ bool FunctionSpecializer::isCandidateFunction(Function *F) { Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &S) { - Function *Clone = cloneCandidateFunction(F); + Function *Clone = cloneCandidateFunction(F, Specializations.size() + 1); // The original function does not necessarily have internal linkage, but the // clone must. @@ -789,30 +978,14 @@ Function *FunctionSpecializer::createSpecialization(Function *F, return Clone; } -/// Compute a bonus for replacing argument \p A with constant \p C. -Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, - InstCostVisitor &Visitor) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " - << C->getNameOrAsOperand() << "\n"); - - Cost TotalCost = 0; - for (auto *U : A->users()) - if (auto *UI = dyn_cast<Instruction>(U)) - if (Solver.isBlockExecutable(UI->getParent())) - TotalCost += Visitor.getUserBonus(UI, A, C); - - LLVM_DEBUG(dbgs() << "FnSpecialization: Accumulated user bonus " - << TotalCost << " for argument " << *A << "\n"); - - // The below heuristic is only concerned with exposing inlining - // opportunities via indirect call promotion.
If the argument is not a - // (potentially casted) function pointer, give up. - // - // TODO: Perhaps we should consider checking such inlining opportunities - // while traversing the users of the specialization arguments ? +/// Compute the inlining bonus for replacing argument \p A with constant \p C. +/// The below heuristic is only concerned with exposing inlining +/// opportunities via indirect call promotion. If the argument is not a +/// (potentially casted) function pointer, give up. +unsigned FunctionSpecializer::getInliningBonus(Argument *A, Constant *C) { Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts()); if (!CalledFunction) - return TotalCost; + return 0; // Get TTI for the called function (used for the inline cost). auto &CalleeTTI = (GetTTI)(*CalledFunction); @@ -822,7 +995,7 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, // calls to be promoted to direct calls. If the indirect call promotion // would likely enable the called function to be inlined, specializing is a // good idea. - int Bonus = 0; + int InliningBonus = 0; for (User *U : A->users()) { if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) continue; @@ -849,15 +1022,15 @@ Cost FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, // We clamp the bonus for this call to be between zero and the default // threshold. if (IC.isAlways()) - Bonus += Params.DefaultThreshold; + InliningBonus += Params.DefaultThreshold; else if (IC.isVariable() && IC.getCostDelta() > 0) - Bonus += IC.getCostDelta(); + InliningBonus += IC.getCostDelta(); - LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << Bonus + LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << InliningBonus << " for user " << *U << "\n"); } - return TotalCost + Bonus; + return InliningBonus > 0 ? static_cast<unsigned>(InliningBonus) : 0; } /// Determine if it is possible to specialise the function for constant values diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 8012e1e650a0..951372adcfa9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -17,7 +17,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/iterator_range.h" @@ -390,7 +389,7 @@ static bool collectSRATypes(DenseMap<uint64_t, GlobalPart> &Parts, } // Scalable types not currently supported. - if (isa<ScalableVectorType>(Ty)) + if (Ty->isScalableTy()) return false; auto IsStored = [](Value *V, Constant *Initializer) { @@ -930,25 +929,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, } // Update users of the allocation to use the new global instead. 
- BitCastInst *TheBC = nullptr; - while (!CI->use_empty()) { - Instruction *User = cast<Instruction>(CI->user_back()); - if (BitCastInst *BCI = dyn_cast<BitCastInst>(User)) { - if (BCI->getType() == NewGV->getType()) { - BCI->replaceAllUsesWith(NewGV); - BCI->eraseFromParent(); - } else { - BCI->setOperand(0, NewGV); - } - } else { - if (!TheBC) - TheBC = new BitCastInst(NewGV, CI->getType(), "newgv", CI); - User->replaceUsesOfWith(CI, TheBC); - } - } - - SmallSetVector<Constant *, 1> RepValues; - RepValues.insert(NewGV); + CI->replaceAllUsesWith(NewGV); // If there is a comparison against null, we will insert a global bool to // keep track of whether the global was initialized yet or not. @@ -980,9 +961,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, Use &LoadUse = *LI->use_begin(); ICmpInst *ICI = dyn_cast<ICmpInst>(LoadUse.getUser()); if (!ICI) { - auto *CE = ConstantExpr::getBitCast(NewGV, LI->getType()); - RepValues.insert(CE); - LoadUse.set(CE); + LoadUse.set(NewGV); continue; } @@ -1028,8 +1007,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, // To further other optimizations, loop over all users of NewGV and try to // constant prop them. This will promote GEP instructions with constant // indices into GEP constant-exprs, which will allow global-opt to hack on it. - for (auto *CE : RepValues) - ConstantPropUsersOf(CE, DL, TLI); + ConstantPropUsersOf(NewGV, DL, TLI); return NewGV; } @@ -1474,7 +1452,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, if (!GS.HasMultipleAccessingFunctions && GS.AccessingFunction && GV->getValueType()->isSingleValueType() && - GV->getType()->getAddressSpace() == 0 && + GV->getType()->getAddressSpace() == DL.getAllocaAddrSpace() && !GV->isExternallyInitialized() && GS.AccessingFunction->doesNotRecurse() && isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV, @@ -1584,7 +1562,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, GV->getAddressSpace()); NGV->takeName(GV); NGV->copyAttributesFrom(GV); - GV->replaceAllUsesWith(ConstantExpr::getBitCast(NGV, GV->getType())); + GV->replaceAllUsesWith(NGV); GV->eraseFromParent(); GV = NGV; } @@ -1635,7 +1613,7 @@ processGlobal(GlobalValue &GV, function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<DominatorTree &(Function &)> LookupDomTree) { - if (GV.getName().startswith("llvm.")) + if (GV.getName().starts_with("llvm.")) return false; GlobalStatus GS; @@ -1885,12 +1863,9 @@ static void RemovePreallocated(Function *F) { CB->eraseFromParent(); Builder.SetInsertPoint(PreallocatedSetup); - auto *StackSave = - Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::stacksave)); - + auto *StackSave = Builder.CreateStackSave(); Builder.SetInsertPoint(NewCB->getNextNonDebugInstruction()); - Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::stackrestore), - StackSave); + Builder.CreateStackRestore(StackSave); // Replace @llvm.call.preallocated.arg() with alloca. // Cannot modify users() while iterating over it, so make a copy. 
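The hunk above switches from hand-rolled Intrinsic::getDeclaration calls to the IRBuilder stack-save helpers. A minimal stand-alone sketch of those helpers (assuming LLVM 17+ headers; not part of the patch):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("demo", Ctx);
  IRBuilder<> B(Ctx);

  Function *F = Function::Create(FunctionType::get(B.getVoidTy(), false),
                                 Function::ExternalLinkage, "demo", M);
  B.SetInsertPoint(BasicBlock::Create(Ctx, "entry", F));

  // CreateStackSave/CreateStackRestore emit the stacksave/stackrestore
  // intrinsics directly, with no Intrinsic::getDeclaration boilerplate and,
  // under opaque pointers, no bitcast of the saved pointer.
  Value *SP = B.CreateStackSave("sp");
  B.CreateAlloca(B.getInt8Ty(), B.getInt32(64), "scratch"); // 64-byte scratch
  B.CreateStackRestore(SP);
  B.CreateRetVoid();

  verifyModule(M, &errs());
  M.print(outs(), nullptr);
}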
@@ -1917,10 +1892,8 @@ static void RemovePreallocated(Function *F) { Builder.SetInsertPoint(InsertBefore); auto *Alloca = Builder.CreateAlloca(ArgType, AddressSpace, nullptr, "paarg"); - auto *BitCast = Builder.CreateBitCast( - Alloca, Type::getInt8PtrTy(M->getContext()), UseCall->getName()); - ArgAllocas[AllocArgIndex] = BitCast; - AllocaReplacement = BitCast; + ArgAllocas[AllocArgIndex] = Alloca; + AllocaReplacement = Alloca; } UseCall->replaceAllUsesWith(AllocaReplacement); @@ -2131,19 +2104,18 @@ static void setUsedInitializer(GlobalVariable &V, const auto *VEPT = cast<PointerType>(VAT->getArrayElementType()); // Type of pointer to the array of pointers. - PointerType *Int8PtrTy = - Type::getInt8PtrTy(V.getContext(), VEPT->getAddressSpace()); + PointerType *PtrTy = + PointerType::get(V.getContext(), VEPT->getAddressSpace()); SmallVector<Constant *, 8> UsedArray; for (GlobalValue *GV : Init) { - Constant *Cast = - ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy); + Constant *Cast = ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, PtrTy); UsedArray.push_back(Cast); } // Sort to get deterministic order. array_pod_sort(UsedArray.begin(), UsedArray.end(), compareNames); - ArrayType *ATy = ArrayType::get(Int8PtrTy, UsedArray.size()); + ArrayType *ATy = ArrayType::get(PtrTy, UsedArray.size()); Module *M = V.getParent(); V.removeFromParent(); @@ -2313,7 +2285,7 @@ OptimizeGlobalAliases(Module &M, if (!hasUsesToReplace(J, Used, RenameTarget)) continue; - J.replaceAllUsesWith(ConstantExpr::getBitCast(Aliasee, J.getType())); + J.replaceAllUsesWith(Aliasee); ++NumAliasesResolved; Changed = true; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp index 599ace9ca79f..fabb3c5fb921 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp @@ -44,6 +44,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/Support/CommandLine.h" @@ -86,6 +87,11 @@ static cl::opt<int> MaxParametersForSplit( "hotcoldsplit-max-params", cl::init(4), cl::Hidden, cl::desc("Maximum number of parameters for a split function")); +static cl::opt<int> ColdBranchProbDenom( + "hotcoldsplit-cold-probability-denom", cl::init(100), cl::Hidden, + cl::desc("Divisor of cold branch probability. " "BranchProbability = 1/ColdBranchProbDenom")); + namespace { // Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. Do not modify // this function unless you modify the MBB version as well. @@ -102,6 +108,32 @@ bool blockEndsInUnreachable(const BasicBlock &BB) { return !(isa<ReturnInst>(I) || isa<IndirectBrInst>(I)); } +void analyzeProfMetadata(BasicBlock *BB, + BranchProbability ColdProbThresh, + SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks) { + // TODO: Handle branches with > 2 successors.
+ BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator()); + if (!CondBr) + return; + + uint64_t TrueWt, FalseWt; + if (!extractBranchWeights(*CondBr, TrueWt, FalseWt)) + return; + + auto SumWt = TrueWt + FalseWt; + if (SumWt == 0) + return; + + auto TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt); + auto FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt); + + if (TrueProb <= ColdProbThresh) + AnnotatedColdBlocks.insert(CondBr->getSuccessor(0)); + + if (FalseProb <= ColdProbThresh) + AnnotatedColdBlocks.insert(CondBr->getSuccessor(1)); +} + bool unlikelyExecuted(BasicBlock &BB) { // Exception handling blocks are unlikely executed. if (BB.isEHPad() || isa<ResumeInst>(BB.getTerminator())) @@ -183,6 +215,34 @@ bool HotColdSplitting::isFunctionCold(const Function &F) const { return false; } +bool HotColdSplitting::isBasicBlockCold(BasicBlock *BB, + BranchProbability ColdProbThresh, + SmallPtrSetImpl<BasicBlock *> &ColdBlocks, + SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks, + BlockFrequencyInfo *BFI) const { + // This block is already part of some outlining region. + if (ColdBlocks.count(BB)) + return true; + + if (BFI) { + if (PSI->isColdBlock(BB, BFI)) + return true; + } else { + // Find cold blocks of successors of BB during a reverse postorder traversal. + analyzeProfMetadata(BB, ColdProbThresh, AnnotatedColdBlocks); + + // A statically cold BB would be known before it is visited + // because the prof-data of incoming edges are 'analyzed' as part of RPOT. + if (AnnotatedColdBlocks.count(BB)) + return true; + } + + if (EnableStaticAnalysis && unlikelyExecuted(*BB)) + return true; + + return false; +} + // Returns false if the function should not be considered for hot-cold split // optimization. bool HotColdSplitting::shouldOutlineFrom(const Function &F) const { @@ -565,6 +625,9 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { // The set of cold blocks. SmallPtrSet<BasicBlock *, 4> ColdBlocks; + // Set of cold blocks obtained with RPOT. + SmallPtrSet<BasicBlock *, 4> AnnotatedColdBlocks; + // The worklist of non-intersecting regions left to outline. SmallVector<OutliningRegion, 2> OutliningWorklist; @@ -587,16 +650,15 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { TargetTransformInfo &TTI = GetTTI(F); OptimizationRemarkEmitter &ORE = (*GetORE)(F); AssumptionCache *AC = LookupAC(F); + auto ColdProbThresh = TTI.getPredictableBranchThreshold().getCompl(); + + if (ColdBranchProbDenom.getNumOccurrences()) + ColdProbThresh = BranchProbability(1, ColdBranchProbDenom.getValue()); // Find all cold regions. for (BasicBlock *BB : RPOT) { - // This block is already part of some outlining region. - if (ColdBlocks.count(BB)) - continue; - - bool Cold = (BFI && PSI->isColdBlock(BB, BFI)) || - (EnableStaticAnalysis && unlikelyExecuted(*BB)); - if (!Cold) + if (!isBasicBlockCold(BB, ColdProbThresh, ColdBlocks, AnnotatedColdBlocks, + BFI)) continue; LLVM_DEBUG({ diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp index e258299c6a4c..a6e19df7c5f1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -155,7 +155,7 @@ struct OutlinableGroup { /// \param TargetBB - the BasicBlock to put Instruction into. 
static void moveBBContents(BasicBlock &SourceBB, BasicBlock &TargetBB) { for (Instruction &I : llvm::make_early_inc_range(SourceBB)) - I.moveBefore(TargetBB, TargetBB.end()); + I.moveBeforePreserving(TargetBB, TargetBB.end()); } /// A function to sort the keys of \p Map, which must be a mapping of constant @@ -198,7 +198,7 @@ Value *OutlinableRegion::findCorrespondingValueIn(const OutlinableRegion &Other, BasicBlock * OutlinableRegion::findCorrespondingBlockIn(const OutlinableRegion &Other, BasicBlock *BB) { - Instruction *FirstNonPHI = BB->getFirstNonPHI(); + Instruction *FirstNonPHI = BB->getFirstNonPHIOrDbg(); assert(FirstNonPHI && "block is empty?"); Value *CorrespondingVal = findCorrespondingValueIn(Other, FirstNonPHI); if (!CorrespondingVal) @@ -557,7 +557,7 @@ collectRegionsConstants(OutlinableRegion &Region, // Iterate over the operands in an instruction. If the global value number, // assigned by the IRSimilarityCandidate, has been seen before, we check if - // the the number has been found to be not the same value in each instance. + // the number has been found to be not the same value in each instance. for (Value *V : ID.OperVals) { std::optional<unsigned> GVNOpt = C.getGVN(V); assert(GVNOpt && "Expected a GVN for operand?"); @@ -766,7 +766,7 @@ static void moveFunctionData(Function &Old, Function &New, } } -/// Find the the constants that will need to be lifted into arguments +/// Find the constants that will need to be lifted into arguments /// as they are not the same in each instance of the region. /// /// \param [in] C - The IRSimilarityCandidate containing the region we are @@ -1346,7 +1346,7 @@ findExtractedOutputToOverallOutputMapping(Module &M, OutlinableRegion &Region, // the output, so we add a pointer type to the argument types of the overall // function to handle this output and create a mapping to it. if (!TypeFound) { - Group.ArgumentTypes.push_back(Output->getType()->getPointerTo( + Group.ArgumentTypes.push_back(PointerType::get(Output->getContext(), M.getDataLayout().getAllocaAddrSpace())); // Mark the new pointer type as the last value in the aggregate argument // list. 
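The IROutliner change above replaces Output->getType()->getPointerTo(...) with PointerType::get(Context, AS), the opaque-pointer idiom used throughout this import. A tiny stand-alone illustration (not part of the patch) of why the pointee type can be dropped:

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include <cassert>

using namespace llvm;

int main() {
  LLVMContext Ctx;
  // Under opaque pointers there is exactly one pointer type per address
  // space, so the pointee no longer participates in type identity.
  PointerType *P0a = PointerType::get(Ctx, /*AddressSpace=*/0);
  PointerType *P0b = PointerType::get(Ctx, /*AddressSpace=*/0);
  PointerType *P5 = PointerType::get(Ctx, /*AddressSpace=*/5);
  assert(P0a == P0b && "types are uniqued per context");
  assert(P0a != P5 && "only the address space distinguishes pointers");
  (void)P0a; (void)P0b; (void)P5;
  return 0;
}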
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp index 3e00aebce372..a9747aebf67b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp @@ -13,7 +13,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/Inliner.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" @@ -63,7 +62,6 @@ #include <cassert> #include <functional> #include <utility> -#include <vector> using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index 9b4b3efd7283..733f290b1bc9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -381,8 +381,7 @@ struct ScopedSaveAliaseesAndUsed { appendToCompilerUsed(M, CompilerUsed); for (auto P : FunctionAliases) - P.first->setAliasee( - ConstantExpr::getBitCast(P.second, P.first->getType())); + P.first->setAliasee(P.second); for (auto P : ResolverIFuncs) { // This does not preserve pointer casts that may have been stripped by the @@ -411,16 +410,19 @@ class LowerTypeTestsModule { // selectJumpTableArmEncoding may decide to use Thumb in either case. bool CanUseArmJumpTable = false, CanUseThumbBWJumpTable = false; + // Cache variable used by hasBranchTargetEnforcement(). + int HasBranchTargetEnforcement = -1; + // The jump table type we ended up deciding on. (Usually the same as // Arch, except that 'arm' and 'thumb' are often interchangeable.) Triple::ArchType JumpTableArch = Triple::UnknownArch; IntegerType *Int1Ty = Type::getInt1Ty(M.getContext()); IntegerType *Int8Ty = Type::getInt8Ty(M.getContext()); - PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext()); + PointerType *Int8PtrTy = PointerType::getUnqual(M.getContext()); ArrayType *Int8Arr0Ty = ArrayType::get(Type::getInt8Ty(M.getContext()), 0); IntegerType *Int32Ty = Type::getInt32Ty(M.getContext()); - PointerType *Int32PtrTy = PointerType::getUnqual(Int32Ty); + PointerType *Int32PtrTy = PointerType::getUnqual(M.getContext()); IntegerType *Int64Ty = Type::getInt64Ty(M.getContext()); IntegerType *IntPtrTy = M.getDataLayout().getIntPtrType(M.getContext(), 0); @@ -492,6 +494,7 @@ class LowerTypeTestsModule { ArrayRef<GlobalTypeMember *> Globals); Triple::ArchType selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions); + bool hasBranchTargetEnforcement(); unsigned getJumpTableEntrySize(); Type *getJumpTableEntryType(); void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS, @@ -755,9 +758,9 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI, // also conveniently gives us a bit offset to use during the load from // the bitset. 
Value *OffsetSHR = - B.CreateLShr(PtrOffset, ConstantExpr::getZExt(TIL.AlignLog2, IntPtrTy)); + B.CreateLShr(PtrOffset, B.CreateZExt(TIL.AlignLog2, IntPtrTy)); Value *OffsetSHL = B.CreateShl( - PtrOffset, ConstantExpr::getZExt( + PtrOffset, B.CreateZExt( ConstantExpr::getSub( ConstantInt::get(Int8Ty, DL.getPointerSizeInBits(0)), TIL.AlignLog2), @@ -962,7 +965,6 @@ LowerTypeTestsModule::importTypeId(StringRef TypeId) { Int8Arr0Ty); if (auto *GV = dyn_cast<GlobalVariable>(C)) GV->setVisibility(GlobalValue::HiddenVisibility); - C = ConstantExpr::getBitCast(C, Int8PtrTy); return C; }; @@ -1100,15 +1102,13 @@ void LowerTypeTestsModule::importFunction( replaceCfiUses(F, FDecl, isJumpTableCanonical); // Set visibility late because it's used in replaceCfiUses() to determine - // whether uses need to to be replaced. + // whether uses need to be replaced. F->setVisibility(Visibility); } void LowerTypeTestsModule::lowerTypeTestCalls( ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) { - CombinedGlobalAddr = ConstantExpr::getBitCast(CombinedGlobalAddr, Int8PtrTy); - // For each type identifier in this disjoint set... for (Metadata *TypeId : TypeIds) { // Build the bitset. @@ -1196,6 +1196,20 @@ static const unsigned kARMJumpTableEntrySize = 4; static const unsigned kARMBTIJumpTableEntrySize = 8; static const unsigned kARMv6MJumpTableEntrySize = 16; static const unsigned kRISCVJumpTableEntrySize = 8; +static const unsigned kLOONGARCH64JumpTableEntrySize = 8; + +bool LowerTypeTestsModule::hasBranchTargetEnforcement() { + if (HasBranchTargetEnforcement == -1) { + // First time this query has been called. Find out the answer by checking + // the module flags. + if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("branch-target-enforcement"))) + HasBranchTargetEnforcement = (BTE->getZExtValue() != 0); + else + HasBranchTargetEnforcement = 0; + } + return HasBranchTargetEnforcement; +} unsigned LowerTypeTestsModule::getJumpTableEntrySize() { switch (JumpTableArch) { @@ -1209,19 +1223,22 @@ unsigned LowerTypeTestsModule::getJumpTableEntrySize() { case Triple::arm: return kARMJumpTableEntrySize; case Triple::thumb: - if (CanUseThumbBWJumpTable) + if (CanUseThumbBWJumpTable) { + if (hasBranchTargetEnforcement()) + return kARMBTIJumpTableEntrySize; return kARMJumpTableEntrySize; - else + } else { return kARMv6MJumpTableEntrySize; + } case Triple::aarch64: - if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( - M.getModuleFlag("branch-target-enforcement"))) - if (BTE->getZExtValue()) - return kARMBTIJumpTableEntrySize; + if (hasBranchTargetEnforcement()) + return kARMBTIJumpTableEntrySize; return kARMJumpTableEntrySize; case Triple::riscv32: case Triple::riscv64: return kRISCVJumpTableEntrySize; + case Triple::loongarch64: + return kLOONGARCH64JumpTableEntrySize; default: report_fatal_error("Unsupported architecture for jump tables"); } @@ -1251,10 +1268,8 @@ void LowerTypeTestsModule::createJumpTableEntry( } else if (JumpTableArch == Triple::arm) { AsmOS << "b $" << ArgIndex << "\n"; } else if (JumpTableArch == Triple::aarch64) { - if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( - Dest->getParent()->getModuleFlag("branch-target-enforcement"))) - if (BTE->getZExtValue()) - AsmOS << "bti c\n"; + if (hasBranchTargetEnforcement()) + AsmOS << "bti c\n"; AsmOS << "b $" << ArgIndex << "\n"; } else if (JumpTableArch == Triple::thumb) { if (!CanUseThumbBWJumpTable) { @@ -1281,11 +1296,16 @@ void 
LowerTypeTestsModule::createJumpTableEntry( << ".balign 4\n" << "1: .word $" << ArgIndex << " - (0b + 4)\n"; } else { + if (hasBranchTargetEnforcement()) + AsmOS << "bti\n"; AsmOS << "b.w $" << ArgIndex << "\n"; } } else if (JumpTableArch == Triple::riscv32 || JumpTableArch == Triple::riscv64) { AsmOS << "tail $" << ArgIndex << "@plt\n"; + } else if (JumpTableArch == Triple::loongarch64) { + AsmOS << "pcalau12i $$t0, %pc_hi20($" << ArgIndex << ")\n" + << "jirl $$r0, $$t0, %pc_lo12($" << ArgIndex << ")\n"; } else { report_fatal_error("Unsupported architecture for jump tables"); } @@ -1304,7 +1324,8 @@ void LowerTypeTestsModule::buildBitSetsFromFunctions( ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) { if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm || Arch == Triple::thumb || Arch == Triple::aarch64 || - Arch == Triple::riscv32 || Arch == Triple::riscv64) + Arch == Triple::riscv32 || Arch == Triple::riscv64 || + Arch == Triple::loongarch64) buildBitSetsFromFunctionsNative(TypeIds, Functions); else if (Arch == Triple::wasm32 || Arch == Triple::wasm64) buildBitSetsFromFunctionsWASM(TypeIds, Functions); @@ -1446,9 +1467,19 @@ void LowerTypeTestsModule::createJumpTable( SmallVector<Value *, 16> AsmArgs; AsmArgs.reserve(Functions.size() * 2); - for (GlobalTypeMember *GTM : Functions) + // Check if all entries have the NoUnwind attribute. + // If all entries have it, we can safely mark the + // cfi.jumptable as NoUnwind; otherwise, direct calls + // to the jump table will not handle exceptions properly. + bool areAllEntriesNounwind = true; + for (GlobalTypeMember *GTM : Functions) { + if (!llvm::cast<llvm::Function>(GTM->getGlobal()) ->hasFnAttribute(llvm::Attribute::NoUnwind)) { + areAllEntriesNounwind = false; + } createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs, cast<Function>(GTM->getGlobal())); + } // Align the whole table by entry size. F->setAlignment(Align(getJumpTableEntrySize())); @@ -1461,17 +1492,23 @@ void LowerTypeTestsModule::createJumpTable( if (JumpTableArch == Triple::arm) F->addFnAttr("target-features", "-thumb-mode"); if (JumpTableArch == Triple::thumb) { - F->addFnAttr("target-features", "+thumb-mode"); - if (CanUseThumbBWJumpTable) { - // Thumb jump table assembly needs Thumb2. The following attribute is - // added by Clang for -march=armv7. - F->addFnAttr("target-cpu", "cortex-a8"); + if (hasBranchTargetEnforcement()) { + // If we're generating a Thumb jump table with BTI, add a target-features + // setting to ensure BTI can be assembled. + F->addFnAttr("target-features", "+thumb-mode,+pacbti"); + } else { + F->addFnAttr("target-features", "+thumb-mode"); + if (CanUseThumbBWJumpTable) { + // Thumb jump table assembly needs Thumb2. The following attribute is + // added by Clang for -march=armv7. + F->addFnAttr("target-cpu", "cortex-a8"); + } } } // When -mbranch-protection= is used, the inline asm adds a BTI. Suppress BTI // for the function to avoid double BTI. This is a no-op without // -mbranch-protection=. - if (JumpTableArch == Triple::aarch64) { + if (JumpTableArch == Triple::aarch64 || JumpTableArch == Triple::thumb) { F->addFnAttr("branch-target-enforcement", "false"); F->addFnAttr("sign-return-address", "none"); } @@ -1485,8 +1522,13 @@ void LowerTypeTestsModule::createJumpTable( // -fcf-protection=. if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) F->addFnAttr(Attribute::NoCfCheck); - // Make sure we don't emit .eh_frame for this function.
- F->addFnAttr(Attribute::NoUnwind); + + // Make sure we don't emit .eh_frame for this function if it isn't needed. + if (areAllEntriesNounwind) + F->addFnAttr(Attribute::NoUnwind); + + // Make sure we do not inline any calls to the cfi.jumptable. + F->addFnAttr(Attribute::NoInline); BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F); IRBuilder<> IRB(BB); @@ -1618,12 +1660,10 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( Function *F = cast<Function>(Functions[I]->getGlobal()); bool IsJumpTableCanonical = Functions[I]->isJumpTableCanonical(); - Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast( - ConstantExpr::getInBoundsGetElementPtr( - JumpTableType, JumpTable, - ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), - ConstantInt::get(IntPtrTy, I)}), - F->getType()); + Constant *CombinedGlobalElemPtr = ConstantExpr::getInBoundsGetElementPtr( + JumpTableType, JumpTable, + ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), + ConstantInt::get(IntPtrTy, I)}); const bool IsExported = Functions[I]->isExported(); if (!IsJumpTableCanonical) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp index f835fb26fcb8..70a3f3067d9d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp @@ -104,11 +104,13 @@ static cl::opt<std::string> MemProfImportSummary( cl::desc("Import summary to use for testing the ThinLTO backend via opt"), cl::Hidden); +namespace llvm { // Indicate we are linking with an allocator that supports hot/cold operator // new interfaces. cl::opt<bool> SupportsHotColdNew( "supports-hot-cold-new", cl::init(false), cl::Hidden, cl::desc("Linking with hot/cold operator new interfaces")); +} // namespace llvm namespace { /// CRTP base for graphs built from either IR or ThinLTO summary index. @@ -791,11 +793,10 @@ CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: eraseCalleeEdge(const ContextEdge *Edge) { - auto EI = - std::find_if(CalleeEdges.begin(), CalleeEdges.end(), - [Edge](const std::shared_ptr<ContextEdge> &CalleeEdge) { - return CalleeEdge.get() == Edge; - }); + auto EI = llvm::find_if( + CalleeEdges, [Edge](const std::shared_ptr<ContextEdge> &CalleeEdge) { + return CalleeEdge.get() == Edge; + }); assert(EI != CalleeEdges.end()); CalleeEdges.erase(EI); } @@ -803,11 +804,10 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: eraseCallerEdge(const ContextEdge *Edge) { - auto EI = - std::find_if(CallerEdges.begin(), CallerEdges.end(), - [Edge](const std::shared_ptr<ContextEdge> &CallerEdge) { - return CallerEdge.get() == Edge; - }); + auto EI = llvm::find_if( + CallerEdges, [Edge](const std::shared_ptr<ContextEdge> &CallerEdge) { + return CallerEdge.get() == Edge; + }); assert(EI != CallerEdges.end()); CallerEdges.erase(EI); } @@ -2093,8 +2093,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( for (auto &Edge : CallerEdges) { // Skip any that have been removed by an earlier recursive call. 
if (Edge->Callee == nullptr && Edge->Caller == nullptr) { - assert(!std::count(Node->CallerEdges.begin(), Node->CallerEdges.end(), - Edge)); + assert(!llvm::count(Node->CallerEdges, Edge)); continue; } // Ignore any caller we previously visited via another edge. @@ -2985,6 +2984,21 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { if (!mayHaveMemprofSummary(CB)) continue; + auto *CalledValue = CB->getCalledOperand(); + auto *CalledFunction = CB->getCalledFunction(); + if (CalledValue && !CalledFunction) { + CalledValue = CalledValue->stripPointerCasts(); + // Stripping pointer casts can reveal a called function. + CalledFunction = dyn_cast<Function>(CalledValue); + } + // Check if this is an alias to a function. If so, get the + // called aliasee for the checks below. + if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) { + assert(!CalledFunction && + "Expected null called function in callsite for alias"); + CalledFunction = dyn_cast<Function>(GA->getAliaseeObject()); + } + CallStack<MDNode, MDNode::op_iterator> CallsiteContext( I.getMetadata(LLVMContext::MD_callsite)); auto *MemProfMD = I.getMetadata(LLVMContext::MD_memprof); @@ -3116,13 +3130,13 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { CloneFuncIfNeeded(/*NumClones=*/StackNode.Clones.size()); // Should have skipped indirect calls via mayHaveMemprofSummary. - assert(CB->getCalledFunction()); - assert(!IsMemProfClone(*CB->getCalledFunction())); + assert(CalledFunction); + assert(!IsMemProfClone(*CalledFunction)); // Update the calls per the summary info. // Save orig name since it gets updated in the first iteration // below. - auto CalleeOrigName = CB->getCalledFunction()->getName(); + auto CalleeOrigName = CalledFunction->getName(); for (unsigned J = 0; J < StackNode.Clones.size(); J++) { // Do nothing if this version calls the original version of its // callee. @@ -3130,7 +3144,7 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { continue; auto NewF = M.getOrInsertFunction( getMemProfFuncName(CalleeOrigName, StackNode.Clones[J]), - CB->getCalledFunction()->getFunctionType()); + CalledFunction->getFunctionType()); CallBase *CBClone; // Copy 0 is the original function. if (!J) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp index feda5d6459cb..c8c011d94e4a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -107,6 +107,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/StructuralHash.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -171,15 +172,14 @@ namespace { class FunctionNode { mutable AssertingVH<Function> F; - FunctionComparator::FunctionHash Hash; + IRHash Hash; public: // Note the hash is recalculated potentially multiple times, but it is cheap. - FunctionNode(Function *F) - : F(F), Hash(FunctionComparator::functionHash(*F)) {} + FunctionNode(Function *F) : F(F), Hash(StructuralHash(*F)) {} Function *getFunc() const { return F; } - FunctionComparator::FunctionHash getHash() const { return Hash; } + IRHash getHash() const { return Hash; } /// Replace the reference to the function F by the function G, assuming their /// implementations are equal. 
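The FunctionNode change above swaps FunctionComparator::functionHash for the generic StructuralHash, but the surrounding strategy is unchanged: bucket functions by a cheap hash and only run the expensive comparison on functions whose hashes collide, as the runOnModule hunk below does with HashedFuncs. A plain-C++ sketch of that bucketing (stand-in hash and data, not the LLVM API):

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stand-ins for functions; equal strings model structurally equal bodies.
  std::vector<std::string> Funcs = {"f1", "dup", "f2", "dup"};

  std::vector<std::pair<uint64_t, const std::string *>> Hashed;
  for (const std::string &F : Funcs)
    Hashed.push_back({std::hash<std::string>{}(F), &F}); // cheap hash

  // Sort by hash so potentially-equal functions become adjacent; functions
  // with a unique hash are eliminated without any deep comparison.
  std::stable_sort(Hashed.begin(), Hashed.end(),
                   [](const auto &L, const auto &R) { return L.first < R.first; });

  for (size_t I = 1; I < Hashed.size(); ++I)
    if (Hashed[I].first == Hashed[I - 1].first)
      std::cout << "compare " << *Hashed[I - 1].second << " with "
                << *Hashed[I].second << "\n"; // deep comparison goes here
}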
@@ -375,9 +375,32 @@ bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) { } #endif +/// Check whether \p F has an intrinsic which references +/// distinct metadata as an operand. The most common +/// instance of this would be CFI checks for function-local types. +static bool hasDistinctMetadataIntrinsic(const Function &F) { + for (const BasicBlock &BB : F) { + for (const Instruction &I : BB.instructionsWithoutDebug()) { + if (!isa<IntrinsicInst>(&I)) + continue; + + for (Value *Op : I.operands()) { + auto *MDL = dyn_cast<MetadataAsValue>(Op); + if (!MDL) + continue; + if (MDNode *N = dyn_cast<MDNode>(MDL->getMetadata())) + if (N->isDistinct()) + return true; + } + } + } + return false; +} + /// Check whether \p F is eligible for function merging. static bool isEligibleForMerging(Function &F) { - return !F.isDeclaration() && !F.hasAvailableExternallyLinkage(); + return !F.isDeclaration() && !F.hasAvailableExternallyLinkage() && + !hasDistinctMetadataIntrinsic(F); } bool MergeFunctions::runOnModule(Module &M) { @@ -390,11 +413,10 @@ bool MergeFunctions::runOnModule(Module &M) { // All functions in the module, ordered by hash. Functions with a unique // hash value are easily eliminated. - std::vector<std::pair<FunctionComparator::FunctionHash, Function *>> - HashedFuncs; + std::vector<std::pair<IRHash, Function *>> HashedFuncs; for (Function &Func : M) { if (isEligibleForMerging(Func)) { - HashedFuncs.push_back({FunctionComparator::functionHash(Func), &Func}); + HashedFuncs.push_back({StructuralHash(Func), &Func}); } } @@ -441,7 +463,6 @@ bool MergeFunctions::runOnModule(Module &M) { // Replace direct callers of Old with New. void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { - Constant *BitcastNew = ConstantExpr::getBitCast(New, Old->getType()); for (Use &U : llvm::make_early_inc_range(Old->uses())) { CallBase *CB = dyn_cast<CallBase>(U.getUser()); if (CB && CB->isCallee(&U)) { @@ -450,7 +471,7 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { // type congruences in byval(), in which case we need to keep the byval // type of the call-site, not the callee function. remove(CB->getFunction()); - U.set(BitcastNew); + U.set(New); } } } @@ -632,7 +653,7 @@ static bool canCreateThunkFor(Function *F) { // Don't merge tiny functions using a thunk, since it can just end up // making the function larger. if (F->size() == 1) { - if (F->front().size() <= 2) { + if (F->front().sizeWithoutDebug() < 2) { LLVM_DEBUG(dbgs() << "canCreateThunkFor: " << F->getName() << " is too small to bother creating a thunk for\n"); return false; @@ -641,6 +662,13 @@ static bool canCreateThunkFor(Function *F) { return true; } +/// Copy metadata from one function to another. +static void copyMetadataIfPresent(Function *From, Function *To, StringRef Key) { + if (MDNode *MD = From->getMetadata(Key)) { + To->setMetadata(Key, MD); + } +} + // Replace G with a simple tail call to bitcast(F). Also (unless // MergeFunctionsPDI holds) replace direct uses of G with bitcast(F), // delete G. Under MergeFunctionsPDI, we use G itself for creating @@ -719,6 +747,9 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { } else { NewG->copyAttributesFrom(G); NewG->takeName(G); + // Ensure CFI type metadata is propagated to the new function. 
+ copyMetadataIfPresent(G, NewG, "type"); + copyMetadataIfPresent(G, NewG, "kcfi_type"); removeUsers(G); G->replaceAllUsesWith(NewG); G->eraseFromParent(); @@ -741,10 +772,9 @@ static bool canCreateAliasFor(Function *F) { // Replace G with an alias to F (deleting function G) void MergeFunctions::writeAlias(Function *F, Function *G) { - Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); PointerType *PtrType = G->getType(); auto *GA = GlobalAlias::create(G->getValueType(), PtrType->getAddressSpace(), - G->getLinkage(), "", BitcastF, G->getParent()); + G->getLinkage(), "", F, G->getParent()); const MaybeAlign FAlign = F->getAlign(); const MaybeAlign GAlign = G->getAlign(); @@ -795,6 +825,9 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { F->getAddressSpace(), "", F->getParent()); NewF->copyAttributesFrom(F); NewF->takeName(F); + // Ensure CFI type metadata is propagated to the new function. + copyMetadataIfPresent(F, NewF, "type"); + copyMetadataIfPresent(F, NewF, "kcfi_type"); removeUsers(F); F->replaceAllUsesWith(NewF); @@ -825,9 +858,8 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { // to replace a key in ValueMap<GlobalValue *> with a non-global. GlobalNumbers.erase(G); // If G's address is not significant, replace it entirely. - Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); removeUsers(G); - G->replaceAllUsesWith(BitcastF); + G->replaceAllUsesWith(F); } else { // Redirect direct callers of G to F. (See note on MergeFunctionsPDI // above). diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 588f3901e3cb..b2665161c090 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/Assumptions.h" #include "llvm/IR/BasicBlock.h" @@ -42,6 +43,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -156,6 +158,8 @@ STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified, "Number of OpenMP runtime function uses identified"); STATISTIC(NumOpenMPTargetRegionKernels, "Number of OpenMP target region entry points (=kernels) identified"); +STATISTIC(NumNonOpenMPTargetRegionKernels, + "Number of non-OpenMP target region kernels identified"); STATISTIC(NumOpenMPTargetRegionKernelsSPMD, "Number of OpenMP target region entry points (=kernels) executed in " "SPMD-mode instead of generic-mode"); @@ -181,6 +185,92 @@ STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated"); static constexpr auto TAG = "[" DEBUG_TYPE "]"; #endif +namespace KernelInfo { + +// struct ConfigurationEnvironmentTy { +// uint8_t UseGenericStateMachine; +// uint8_t MayUseNestedParallelism; +// llvm::omp::OMPTgtExecModeFlags ExecMode; +// int32_t MinThreads; +// int32_t MaxThreads; +// int32_t MinTeams; +// int32_t MaxTeams; +// }; + +// struct DynamicEnvironmentTy { +// uint16_t DebugIndentionLevel; +// }; + +// struct KernelEnvironmentTy { +// ConfigurationEnvironmentTy Configuration; +// IdentTy *Ident; +// DynamicEnvironmentTy 
*DynamicEnv; +// }; + +#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \ + constexpr const unsigned MEMBER##Idx = IDX; + +KERNEL_ENVIRONMENT_IDX(Configuration, 0) +KERNEL_ENVIRONMENT_IDX(Ident, 1) + +#undef KERNEL_ENVIRONMENT_IDX + +#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \ + constexpr const unsigned MEMBER##Idx = IDX; + +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5) +KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX + +#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \ + RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \ + return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \ + } + +KERNEL_ENVIRONMENT_GETTER(Ident, Constant) +KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct) + +#undef KERNEL_ENVIRONMENT_GETTER + +#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \ + ConstantInt *get##MEMBER##FromKernelEnvironment( \ + ConstantStruct *KernelEnvC) { \ + ConstantStruct *ConfigC = \ + getConfigurationFromKernelEnvironment(KernelEnvC); \ + return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \ + } + +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams) +KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER + +GlobalVariable * +getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) { + constexpr const int InitKernelEnvironmentArgNo = 0; + return cast<GlobalVariable>( + KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo) + ->stripPointerCasts()); +} + +ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) { + GlobalVariable *KernelEnvGV = + getKernelEnvironementGVFromKernelInitCB(KernelInitCB); + return cast<ConstantStruct>(KernelEnvGV->getInitializer()); +} +} // namespace KernelInfo + namespace { struct AAHeapToShared; @@ -196,6 +286,7 @@ struct OMPInformationCache : public InformationCache { : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M), OpenMPPostLink(OpenMPPostLink) { + OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M); OMPBuilder.initialize(); initializeRuntimeFunctions(M); initializeInternalControlVars(); @@ -531,7 +622,7 @@ struct OMPInformationCache : public InformationCache { for (Function &F : M) { for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"}) if (F.hasFnAttribute(Attribute::NoInline) && - F.getName().startswith(Prefix) && + F.getName().starts_with(Prefix) && !F.hasFnAttribute(Attribute::OptimizeNone)) F.removeFnAttr(Attribute::NoInline); } @@ -595,7 +686,7 @@ struct KernelInfoState : AbstractState { /// The parallel regions (identified by the outlined parallel functions) that /// can be reached from the associated function. - BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false> + BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false> ReachedKnownParallelRegions; /// State to track what parallel region we might reach. 
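The KERNEL_ENVIRONMENT getters above all reduce to positional getAggregateElement lookups on the constant kernel-environment struct. A stand-alone sketch of that access pattern on a toy two-field struct (not the real KernelEnvironmentTy layout; not part of the patch):

#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Type *I8 = Type::getInt8Ty(Ctx);
  Type *I32 = Type::getInt32Ty(Ctx);

  // Toy configuration struct: { UseGenericStateMachine, MaxThreads }.
  StructType *STy = StructType::get(Ctx, {I8, I32});
  Constant *Config = ConstantStruct::get(
      STy, {ConstantInt::get(I8, 1), ConstantInt::get(I32, 256)});

  // Fields of a constant aggregate are read by index, which is all the
  // getter macros above do with the real member indices.
  auto *MaxThreads = cast<ConstantInt>(Config->getAggregateElement(1u));
  outs() << "MaxThreads = " << MaxThreads->getZExtValue() << "\n";
}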
@@ -610,6 +701,10 @@ struct KernelInfoState : AbstractState { /// one we abort as the kernel is malformed. CallBase *KernelInitCB = nullptr; + /// The constant kernel environment as taken from and passed to + /// __kmpc_target_init. + ConstantStruct *KernelEnvC = nullptr; + /// The __kmpc_target_deinit call in this kernel, if any. If we find more than /// one we abort as the kernel is malformed. CallBase *KernelDeinitCB = nullptr; @@ -651,6 +746,7 @@ struct KernelInfoState : AbstractState { SPMDCompatibilityTracker.indicatePessimisticFixpoint(); ReachedKnownParallelRegions.indicatePessimisticFixpoint(); ReachedUnknownParallelRegions.indicatePessimisticFixpoint(); + NestedParallelism = true; return ChangeStatus::CHANGED; } @@ -680,6 +776,8 @@ struct KernelInfoState : AbstractState { return false; if (ParallelLevels != RHS.ParallelLevels) return false; + if (NestedParallelism != RHS.NestedParallelism) + return false; return true; } @@ -714,6 +812,12 @@ struct KernelInfoState : AbstractState { "assumptions."); KernelDeinitCB = KIS.KernelDeinitCB; } + if (KIS.KernelEnvC) { + if (KernelEnvC && KernelEnvC != KIS.KernelEnvC) + llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt " "assumptions."); + KernelEnvC = KIS.KernelEnvC; + } SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions; ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions; @@ -875,6 +979,9 @@ struct OpenMPOpt { } } + if (OMPInfoCache.OpenMPPostLink) + Changed |= removeRuntimeSymbols(); + return Changed; } @@ -903,7 +1010,7 @@ struct OpenMPOpt { /// Print OpenMP GPU kernels for testing. void printKernels() const { for (Function *F : SCC) { - if (!omp::isKernel(*F)) + if (!omp::isOpenMPKernel(*F)) continue; auto Remark = [&](OptimizationRemarkAnalysis ORA) { @@ -1404,6 +1511,37 @@ private: return Changed; } + /// Tries to remove known runtime symbols that are optional from the module. + bool removeRuntimeSymbols() { + // The RPC client symbol is defined in `libc` and indicates that something + // required an RPC server. If its users were all optimized out then we can + // safely remove it. + // TODO: This should be somewhere more common in the future. + if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) { + if (!GV->getType()->isPointerTy()) + return false; + + Constant *C = GV->getInitializer(); + if (!C) + return false; + + // Check to see if the only user of the RPC client is the external handle. + GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts()); + if (!Client || Client->getNumUses() > 1 || + Client->user_back() != GV->getInitializer()) + return false; + + Client->replaceAllUsesWith(PoisonValue::get(Client->getType())); + Client->eraseFromParent(); + + GV->replaceAllUsesWith(PoisonValue::get(GV->getType())); + GV->eraseFromParent(); + + return true; + } + return false; + } + /// Tries to hide the latency of runtime calls that involve host to /// device memory transfers by splitting them into their "issue" and "wait" /// versions. The "issue" is moved upwards as much as possible.
The "wait" is @@ -1858,7 +1996,7 @@ private: Function *F = I->getParent()->getParent(); auto &ORE = OREGetter(F); - if (RemarkName.startswith("OMP")) + if (RemarkName.starts_with("OMP")) ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)) << " [" << RemarkName << "]"; @@ -1874,7 +2012,7 @@ private: RemarkCallBack &&RemarkCB) const { auto &ORE = OREGetter(F); - if (RemarkName.startswith("OMP")) + if (RemarkName.starts_with("OMP")) ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)) << " [" << RemarkName << "]"; @@ -1944,7 +2082,7 @@ Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { // TODO: We should use an AA to create an (optimistic and callback // call-aware) call graph. For now we stick to simple patterns that // are less powerful, basically the worst fixpoint. - if (isKernel(F)) { + if (isOpenMPKernel(F)) { CachedKernel = Kernel(&F); return *CachedKernel; } @@ -2535,6 +2673,17 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker { } }; +/// Determines if \p BB exits the function unconditionally itself or reaches a +/// block that does through only unique successors. +static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) { + if (succ_empty(BB)) + return true; + const BasicBlock *const Successor = BB->getUniqueSuccessor(); + if (!Successor) + return false; + return hasFunctionEndAsUniqueSuccessor(Successor); +} + struct AAExecutionDomainFunction : public AAExecutionDomain { AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) : AAExecutionDomain(IRP, A) {} @@ -2587,18 +2736,22 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { if (!ED.IsReachedFromAlignedBarrierOnly || ED.EncounteredNonLocalSideEffect) return; + if (!ED.EncounteredAssumes.empty() && !A.isModulePass()) + return; - // We can remove this barrier, if it is one, or all aligned barriers - // reaching the kernel end. In the latter case we can transitively work - // our way back until we find a barrier that guards a side-effect if we - // are dealing with the kernel end here. + // We can remove this barrier, if it is one, or aligned barriers reaching + // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel + // end should only be removed if the kernel end is their unique successor; + // otherwise, they may have side-effects that aren't accounted for in the + // kernel end in their other successors. If those barriers have other + // barriers reaching them, those can be transitively removed as well as + // long as the kernel end is also their unique successor. if (CB) { DeletedBarriers.insert(CB); A.deleteAfterManifest(*CB); ++NumBarriersEliminated; Changed = ChangeStatus::CHANGED; } else if (!ED.AlignedBarriers.empty()) { - NumBarriersEliminated += ED.AlignedBarriers.size(); Changed = ChangeStatus::CHANGED; SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(), ED.AlignedBarriers.end()); @@ -2609,7 +2762,10 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { continue; if (LastCB->getFunction() != getAnchorScope()) continue; + if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent())) + continue; if (!DeletedBarriers.count(LastCB)) { + ++NumBarriersEliminated; A.deleteAfterManifest(*LastCB); continue; } @@ -2633,7 +2789,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { HandleAlignedBarrier(CB); // Handle the "kernel end barrier" for kernels too. 
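
// Standalone sketch of the hasFunctionEndAsUniqueSuccessor walk added above
// (toy CFG type, not LLVM's BasicBlock): a barrier may only be treated as
// "reaching the kernel end" if every block after it falls through uniquely to
// the function end; any real branch could hide unaccounted side effects.
#include <vector>

struct Block {
  std::vector<Block *> Succs;
};

// Assumes the unique-successor chain is acyclic, as it is in practice here.
bool hasFunctionEndAsUniqueSuccessor(const Block *BB) {
  if (BB->Succs.empty())     // no successors: this block exits the function
    return true;
  if (BB->Succs.size() != 1) // a branch: the end is not the unique successor
    return false;
  return hasFunctionEndAsUniqueSuccessor(BB->Succs.front());
}
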
-    if (omp::isKernel(*getAnchorScope()))
+    if (omp::isOpenMPKernel(*getAnchorScope()))
       HandleAlignedBarrier(nullptr);
 
     return Changed;
@@ -2779,9 +2935,11 @@ struct AAExecutionDomainFunction : public AAExecutionDomain {
       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
       if (!CB)
         return false;
-      const int InitModeArgNo = 1;
-      auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
-      return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
+      ConstantStruct *KernelEnvC =
+          KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
+      ConstantInt *ExecModeC =
+          KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
+      return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
     }
 
     if (C->isZero()) {
@@ -2884,11 +3042,11 @@ bool AAExecutionDomainFunction::handleCallees(Attributor &A,
     } else {
       // We could not find all predecessors, so this is either a kernel or a
       // function with external linkage (or with some other weird uses).
-      if (omp::isKernel(*getAnchorScope())) {
+      if (omp::isOpenMPKernel(*getAnchorScope())) {
         EntryBBED.IsExecutedByInitialThreadOnly = false;
         EntryBBED.IsReachedFromAlignedBarrierOnly = true;
         EntryBBED.EncounteredNonLocalSideEffect = false;
-        ExitED.IsReachingAlignedBarrierOnly = true;
+        ExitED.IsReachingAlignedBarrierOnly = false;
       } else {
         EntryBBED.IsExecutedByInitialThreadOnly = false;
         EntryBBED.IsReachedFromAlignedBarrierOnly = false;
@@ -2938,7 +3096,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
 
   Function *F = getAnchorScope();
   BasicBlock &EntryBB = F->getEntryBlock();
-  bool IsKernel = omp::isKernel(*F);
+  bool IsKernel = omp::isOpenMPKernel(*F);
 
   SmallVector<Instruction *> SyncInstWorklist;
   for (auto &RIt : *RPOT) {
@@ -3063,7 +3221,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
         if (EDAA && EDAA->getState().isValidState()) {
           const auto &CalleeED = EDAA->getFunctionExecutionDomain();
           ED.IsReachedFromAlignedBarrierOnly =
-                  CalleeED.IsReachedFromAlignedBarrierOnly;
+              CalleeED.IsReachedFromAlignedBarrierOnly;
           AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
           if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
             ED.EncounteredNonLocalSideEffect |=
@@ -3442,6 +3600,10 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
 
+  /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
+  /// unknown callees.
+  static bool requiresCalleeForCallBase() { return false; }
+
   /// Statistics are tracked as part of manifest for now.
   void trackStatistics() const override {}
 
@@ -3468,7 +3630,8 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
            ", #ParLevels: " +
            (ParallelLevels.isValidState()
                 ? std::to_string(ParallelLevels.size())
-                : "<invalid>");
+                : "<invalid>") +
+           ", NestedPar: " + (NestedParallelism ? "yes" : "no");
   }
 
   /// Create an abstract attribute view for the position \p IRP.
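
// The exec-mode tests above are plain bit tests. A sketch of the flags
// (values mirror LLVM's OMPTgtExecModeFlags; treat them as illustrative):
#include <cstdint>

enum ExecModeFlags : int8_t {
  ModeGeneric = 1 << 0,
  ModeSPMD = 1 << 1,
  ModeGenericSPMD = ModeGeneric | ModeSPMD, // generic kernel moved to SPMD
};

// Because a generic kernel can be upgraded by OR-ing in the SPMD bit, the
// checks use '&' rather than '==':
inline bool isGeneric(int8_t Mode) { return Mode & ModeGeneric; }
inline bool isSPMD(int8_t Mode) { return Mode & ModeSPMD; }
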
@@ -3500,6 +3663,33 @@ struct AAKernelInfoFunction : AAKernelInfo { return GuardedInstructions; } + void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) { + Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction( + KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx}); + assert(NewKernelEnvC && "Failed to create new kernel environment"); + KernelEnvC = cast<ConstantStruct>(NewKernelEnvC); + } + +#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \ + void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \ + ConstantStruct *ConfigC = \ + KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \ + Constant *NewConfigC = ConstantFoldInsertValueInstruction( \ + ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \ + assert(NewConfigC && "Failed to create new configuration environment"); \ + setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \ + } + + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams) + KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams) + +#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER + /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { // This is a high-level transform that might change the constant arguments @@ -3548,61 +3738,73 @@ struct AAKernelInfoFunction : AAKernelInfo { ReachingKernelEntries.insert(Fn); IsKernelEntry = true; - // For kernels we might need to initialize/finalize the IsSPMD state and - // we need to register a simplification callback so that the Attributor - // knows the constant arguments to __kmpc_target_init and - // __kmpc_target_deinit might actually change. - - Attributor::SimplifictionCallbackTy StateMachineSimplifyCB = - [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> std::optional<Value *> { - return nullptr; - }; - - Attributor::SimplifictionCallbackTy ModeSimplifyCB = - [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> std::optional<Value *> { - // IRP represents the "SPMDCompatibilityTracker" argument of an - // __kmpc_target_init or - // __kmpc_target_deinit call. We will answer this one with the internal - // state. - if (!SPMDCompatibilityTracker.isValidState()) - return nullptr; - if (!SPMDCompatibilityTracker.isAtFixpoint()) { - if (AA) - A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); + KernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); + GlobalVariable *KernelEnvGV = + KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB); + + Attributor::GlobalVariableSimplifictionCallbackTy + KernelConfigurationSimplifyCB = + [&](const GlobalVariable &GV, const AbstractAttribute *AA, + bool &UsedAssumedInformation) -> std::optional<Constant *> { + if (!isAtFixpoint()) { + if (!AA) + return nullptr; UsedAssumedInformation = true; - } else { - UsedAssumedInformation = false; + A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); } - auto *Val = ConstantInt::getSigned( - IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()), - SPMDCompatibilityTracker.isAssumed() ? 
OMP_TGT_EXEC_MODE_SPMD - : OMP_TGT_EXEC_MODE_GENERIC); - return Val; + return KernelEnvC; }; - constexpr const int InitModeArgNo = 1; - constexpr const int DeinitModeArgNo = 1; - constexpr const int InitUseStateMachineArgNo = 2; - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), - StateMachineSimplifyCB); - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo), - ModeSimplifyCB); - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo), - ModeSimplifyCB); + A.registerGlobalVariableSimplificationCallback( + *KernelEnvGV, KernelConfigurationSimplifyCB); // Check if we know we are in SPMD-mode already. - ConstantInt *ModeArg = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); - if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) + ConstantInt *ExecModeC = + KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC); + ConstantInt *AssumedExecModeC = ConstantInt::get( + ExecModeC->getType(), + ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD); + if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD) SPMDCompatibilityTracker.indicateOptimisticFixpoint(); - // This is a generic region but SPMDization is disabled so stop tracking. else if (DisableOpenMPOptSPMDization) + // This is a generic region but SPMDization is disabled so stop + // tracking. SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + else + setExecModeOfKernelEnvironment(AssumedExecModeC); + + const Triple T(Fn->getParent()->getTargetTriple()); + auto *Int32Ty = Type::getInt32Ty(Fn->getContext()); + auto [MinThreads, MaxThreads] = + OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn); + if (MinThreads) + setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads)); + if (MaxThreads) + setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads)); + auto [MinTeams, MaxTeams] = + OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn); + if (MinTeams) + setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams)); + if (MaxTeams) + setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams)); + + ConstantInt *MayUseNestedParallelismC = + KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC); + ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get( + MayUseNestedParallelismC->getType(), NestedParallelism); + setMayUseNestedParallelismOfKernelEnvironment( + AssumedMayUseNestedParallelismC); + + if (!DisableOpenMPOptStateMachineRewrite) { + ConstantInt *UseGenericStateMachineC = + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + KernelEnvC); + ConstantInt *AssumedUseGenericStateMachineC = + ConstantInt::get(UseGenericStateMachineC->getType(), false); + setUseGenericStateMachineOfKernelEnvironment( + AssumedUseGenericStateMachineC); + } // Register virtual uses of functions we might need to preserve. auto RegisterVirtualUse = [&](RuntimeFunction RFKind, @@ -3703,22 +3905,32 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!KernelInitCB || !KernelDeinitCB) return ChangeStatus::UNCHANGED; - /// Insert nested Parallelism global variable - Function *Kernel = getAnchorScope(); - Module &M = *Kernel->getParent(); - Type *Int8Ty = Type::getInt8Ty(M.getContext()); - auto *GV = new GlobalVariable( - M, Int8Ty, /* isConstant */ true, GlobalValue::WeakAnyLinkage, - ConstantInt::get(Int8Ty, NestedParallelism ? 
1 : 0),
-        Kernel->getName() + "_nested_parallelism");
-    GV->setVisibility(GlobalValue::HiddenVisibility);
-
-    // If we can we change the execution mode to SPMD-mode otherwise we build a
-    // custom state machine.
     ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+    bool HasBuiltStateMachine = true;
     if (!changeToSPMDMode(A, Changed)) {
       if (!KernelInitCB->getCalledFunction()->isDeclaration())
-        return buildCustomStateMachine(A);
+        HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
+      else
+        HasBuiltStateMachine = false;
+    }
+
+    // We need to reset KernelEnvC if specific rewriting is not done.
+    ConstantStruct *ExistingKernelEnvC =
+        KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
+    ConstantInt *OldUseGenericStateMachineVal =
+        KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
+            ExistingKernelEnvC);
+    if (!HasBuiltStateMachine)
+      setUseGenericStateMachineOfKernelEnvironment(
+          OldUseGenericStateMachineVal);
+
+    // Finally, update the KernelEnvC.
+    GlobalVariable *KernelEnvGV =
+        KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
+    if (KernelEnvGV->getInitializer() != KernelEnvC) {
+      KernelEnvGV->setInitializer(KernelEnvC);
+      Changed = ChangeStatus::CHANGED;
     }
 
     return Changed;
@@ -3788,14 +4000,14 @@ struct AAKernelInfoFunction : AAKernelInfo {
     // Find escaping outputs from the guarded region to outside users and
     // broadcast their values to them.
     for (Instruction &I : *RegionStartBB) {
-      SmallPtrSet<Instruction *, 4> OutsideUsers;
-      for (User *Usr : I.users()) {
-        Instruction &UsrI = *cast<Instruction>(Usr);
+      SmallVector<Use *, 4> OutsideUses;
+      for (Use &U : I.uses()) {
+        Instruction &UsrI = *cast<Instruction>(U.getUser());
         if (UsrI.getParent() != RegionStartBB)
-          OutsideUsers.insert(&UsrI);
+          OutsideUses.push_back(&U);
       }
 
-      if (OutsideUsers.empty())
+      if (OutsideUses.empty())
         continue;
 
       HasBroadcastValues = true;
@@ -3818,8 +4030,8 @@ struct AAKernelInfoFunction : AAKernelInfo {
           RegionBarrierBB->getTerminator());
 
       // Emit a load instruction and replace uses of the output value.
-      for (Instruction *UsrI : OutsideUsers)
-        UsrI->replaceUsesOfWith(&I, LoadI);
+      for (Use *U : OutsideUses)
+        A.changeUseAfterManifest(*U, *LoadI);
     }
 
     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
@@ -4043,19 +4255,14 @@ struct AAKernelInfoFunction : AAKernelInfo {
       auto *CB = cast<CallBase>(Kernel->user_back());
       Kernel = CB->getCaller();
     }
-    assert(omp::isKernel(*Kernel) && "Expected kernel function!");
+    assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
 
     // Check if the kernel is already in SPMD mode, if so, return success.
-    GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
-        (Kernel->getName() + "_exec_mode").str());
-    assert(ExecMode && "Kernel without exec mode?");
-    assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
-
-    // Set the global exec mode flag to indicate SPMD-Generic mode.
-    assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
-           "ExecMode is not an integer!");
-    const int8_t ExecModeVal =
-        cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
+    ConstantStruct *ExistingKernelEnvC =
+        KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
+    auto *ExecModeC =
+        KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
+    const int8_t ExecModeVal = ExecModeC->getSExtValue();
     if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
       return true;
 
@@ -4073,27 +4280,8 @@ struct AAKernelInfoFunction : AAKernelInfo {
     // kernel is executed in.
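
// Standalone sketch of the manifest logic above (illustrative types, not
// LLVM Constants): the kernel environment is immutable, so a "setter"
// rebuilds the aggregate with one element replaced, and the global's
// initializer is swapped only if the final value actually differs.
#include <array>
#include <cstdint>

struct KernelEnv {
  std::array<int32_t, 7> Config; // UseGenericStateMachine ... MaxTeams
};

KernelEnv withConfigMember(KernelEnv Env, unsigned Idx, int32_t NewVal) {
  Env.Config[Idx] = NewVal; // mutate the copy; the original stays intact
  return Env;
}

bool commitIfChanged(KernelEnv &GlobalInit, const KernelEnv &NewEnv) {
  if (GlobalInit.Config == NewEnv.Config)
    return false;         // report ChangeStatus::UNCHANGED
  GlobalInit = NewEnv;    // like KernelEnvGV->setInitializer(KernelEnvC)
  return true;            // report ChangeStatus::CHANGED
}
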
assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC && "Initially non-SPMD kernel has SPMD exec mode!"); - ExecMode->setInitializer( - ConstantInt::get(ExecMode->getInitializer()->getType(), - ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD)); - - // Next rewrite the init and deinit calls to indicate we use SPMD-mode now. - const int InitModeArgNo = 1; - const int DeinitModeArgNo = 1; - const int InitUseStateMachineArgNo = 2; - - auto &Ctx = getAnchorValue().getContext(); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitModeArgNo), - *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), - OMP_TGT_EXEC_MODE_SPMD)); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), - *ConstantInt::getBool(Ctx, false)); - A.changeUseAfterManifest( - KernelDeinitCB->getArgOperandUse(DeinitModeArgNo), - *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), - OMP_TGT_EXEC_MODE_SPMD)); + setExecModeOfKernelEnvironment(ConstantInt::get( + ExecModeC->getType(), ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD)); ++NumOpenMPTargetRegionKernelsSPMD; @@ -4104,46 +4292,47 @@ struct AAKernelInfoFunction : AAKernelInfo { return true; }; - ChangeStatus buildCustomStateMachine(Attributor &A) { + bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) { // If we have disabled state machine rewrites, don't make a custom one if (DisableOpenMPOptStateMachineRewrite) - return ChangeStatus::UNCHANGED; + return false; // Don't rewrite the state machine if we are not in a valid state. if (!ReachedKnownParallelRegions.isValidState()) - return ChangeStatus::UNCHANGED; + return false; auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); if (!OMPInfoCache.runtimeFnsAvailable( {OMPRTL___kmpc_get_hardware_num_threads_in_block, OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic, OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel})) - return ChangeStatus::UNCHANGED; + return false; - const int InitModeArgNo = 1; - const int InitUseStateMachineArgNo = 2; + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB); // Check if the current configuration is non-SPMD and generic state machine. // If we already have SPMD mode or a custom state machine we do not need to // go any further. If it is anything but a constant something is weird and // we give up. - ConstantInt *UseStateMachine = dyn_cast<ConstantInt>( - KernelInitCB->getArgOperand(InitUseStateMachineArgNo)); - ConstantInt *Mode = - dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo)); + ConstantInt *UseStateMachineC = + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + ExistingKernelEnvC); + ConstantInt *ModeC = + KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC); // If we are stuck with generic mode, try to create a custom device (=GPU) // state machine which is specialized for the parallel regions that are // reachable by the kernel. - if (!UseStateMachine || UseStateMachine->isZero() || !Mode || - (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) - return ChangeStatus::UNCHANGED; + if (UseStateMachineC->isZero() || + (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)) + return false; + + Changed = ChangeStatus::CHANGED; // If not SPMD mode, indicate we use a custom state machine now. 
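
// Standalone sketch of the OutsideUsers -> OutsideUses change above (toy
// def-use model, not LLVM's): recording Use edges instead of user
// instructions rewrites exactly the escaping operands, and stays correct
// when one instruction uses the broadcast value in several operands.
#include <vector>

struct Instr;

struct UseEdge {
  Instr *User;
  unsigned OperandNo;
};

struct Instr {
  std::vector<Instr *> Operands;
};

void rewriteOutsideUses(std::vector<UseEdge> &OutsideUses, Instr *LoadI) {
  for (UseEdge &U : OutsideUses)
    U.User->Operands[U.OperandNo] = LoadI; // per-edge, not per-user
}
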
- auto &Ctx = getAnchorValue().getContext(); - auto *FalseVal = ConstantInt::getBool(Ctx, false); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal); + setUseGenericStateMachineOfKernelEnvironment( + ConstantInt::get(UseStateMachineC->getType(), false)); // If we don't actually need a state machine we are done here. This can // happen if there simply are no parallel regions. In the resulting kernel @@ -4157,7 +4346,7 @@ struct AAKernelInfoFunction : AAKernelInfo { }; A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark); - return ChangeStatus::CHANGED; + return true; } // Keep track in the statistics of our new shiny custom state machine. @@ -4222,6 +4411,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // UserCodeEntryBB: // user code // __kmpc_target_deinit(...) // + auto &Ctx = getAnchorValue().getContext(); Function *Kernel = getAssociatedFunction(); assert(Kernel && "Expected an associated function!"); @@ -4292,7 +4482,7 @@ struct AAKernelInfoFunction : AAKernelInfo { // Create local storage for the work function pointer. const DataLayout &DL = M.getDataLayout(); - Type *VoidPtrTy = Type::getInt8PtrTy(Ctx); + Type *VoidPtrTy = PointerType::getUnqual(Ctx); Instruction *WorkFnAI = new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr, "worker.work_fn.addr", &Kernel->getEntryBlock().front()); @@ -4304,7 +4494,7 @@ struct AAKernelInfoFunction : AAKernelInfo { StateMachineBeginBB->end()), DLoc)); - Value *Ident = KernelInitCB->getArgOperand(0); + Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC); Value *GTid = KernelInitCB; FunctionCallee BarrierFn = @@ -4337,9 +4527,6 @@ struct AAKernelInfoFunction : AAKernelInfo { FunctionType *ParallelRegionFnTy = FunctionType::get( Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)}, false); - Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( - WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast", - StateMachineBeginBB); Instruction *IsDone = ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, @@ -4358,11 +4545,15 @@ struct AAKernelInfoFunction : AAKernelInfo { Value *ZeroArg = Constant::getNullValue(ParallelRegionFnTy->getParamType(0)); + const unsigned int WrapperFunctionArgNo = 6; + // Now that we have most of the CFG skeleton it is time for the if-cascade // that checks the function pointer we got from the runtime against the // parallel regions we expect, if there are any. for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) { - auto *ParallelRegion = ReachedKnownParallelRegions[I]; + auto *CB = ReachedKnownParallelRegions[I]; + auto *ParallelRegion = dyn_cast<Function>( + CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts()); BasicBlock *PRExecuteBB = BasicBlock::Create( Ctx, "worker_state_machine.parallel_region.execute", Kernel, StateMachineEndParallelBB); @@ -4374,13 +4565,15 @@ struct AAKernelInfoFunction : AAKernelInfo { BasicBlock *PRNextBB = BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check", Kernel, StateMachineEndParallelBB); + A.registerManifestAddedBasicBlock(*PRExecuteBB); + A.registerManifestAddedBasicBlock(*PRNextBB); // Check if we need to compare the pointer at all or if we can just // call the parallel region function. 
Value *IsPR; if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) { Instruction *CmpI = ICmpInst::Create( - ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion, + ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion, "worker.check_parallel_region", StateMachineIfCascadeCurrentBB); CmpI->setDebugLoc(DLoc); IsPR = CmpI; @@ -4400,7 +4593,7 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!ReachedUnknownParallelRegions.empty()) { StateMachineIfCascadeCurrentBB->setName( "worker_state_machine.parallel_region.fallback.execute"); - CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "", + CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "", StateMachineIfCascadeCurrentBB) ->setDebugLoc(DLoc); } @@ -4423,7 +4616,7 @@ struct AAKernelInfoFunction : AAKernelInfo { BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB) ->setDebugLoc(DLoc); - return ChangeStatus::CHANGED; + return true; } /// Fixpoint iteration update function. Will be called every time a dependence @@ -4431,6 +4624,46 @@ struct AAKernelInfoFunction : AAKernelInfo { ChangeStatus updateImpl(Attributor &A) override { KernelInfoState StateBefore = getState(); + // When we leave this function this RAII will make sure the member + // KernelEnvC is updated properly depending on the state. That member is + // used for simplification of values and needs to be up to date at all + // times. + struct UpdateKernelEnvCRAII { + AAKernelInfoFunction &AA; + + UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {} + + ~UpdateKernelEnvCRAII() { + if (!AA.KernelEnvC) + return; + + ConstantStruct *ExistingKernelEnvC = + KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB); + + if (!AA.isValidState()) { + AA.KernelEnvC = ExistingKernelEnvC; + return; + } + + if (!AA.ReachedKnownParallelRegions.isValidState()) + AA.setUseGenericStateMachineOfKernelEnvironment( + KernelInfo::getUseGenericStateMachineFromKernelEnvironment( + ExistingKernelEnvC)); + + if (!AA.SPMDCompatibilityTracker.isValidState()) + AA.setExecModeOfKernelEnvironment( + KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC)); + + ConstantInt *MayUseNestedParallelismC = + KernelInfo::getMayUseNestedParallelismFromKernelEnvironment( + AA.KernelEnvC); + ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get( + MayUseNestedParallelismC->getType(), AA.NestedParallelism); + AA.setMayUseNestedParallelismOfKernelEnvironment( + NewMayUseNestedParallelismC); + } + } RAII(*this); + // Callback to check a read/write instruction. auto CheckRWInst = [&](Instruction &I) { // We handle calls later. @@ -4634,15 +4867,13 @@ struct AAKernelInfoCallSite : AAKernelInfo { AAKernelInfo::initialize(A); CallBase &CB = cast<CallBase>(getAssociatedValue()); - Function *Callee = getAssociatedFunction(); - auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>( *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); // Check for SPMD-mode assumptions. if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) { - SPMDCompatibilityTracker.indicateOptimisticFixpoint(); indicateOptimisticFixpoint(); + return; } // First weed out calls we do not care about, that is readonly/readnone @@ -4657,124 +4888,156 @@ struct AAKernelInfoCallSite : AAKernelInfo { // we will handle them explicitly in the switch below. If it is not, we // will use an AAKernelInfo object on the callee to gather information and // merge that into the current state. The latter happens in the updateImpl. 
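
// Standalone sketch of the if-cascade being generated above (plain function
// pointers; the real code emits IR): a worker compares the work function it
// received against the parallel regions known at compile time, so known
// regions get direct, inlinable calls and only the fallback stays indirect.
using WorkFnTy = void (*)(short, int);

void workerDispatch(WorkFnTy WorkFn, WorkFnTy Known0, WorkFnTy Known1,
                    int GTid) {
  if (WorkFn == Known0)
    Known0(0, GTid);  // direct call to a known parallel region
  else if (WorkFn == Known1)
    Known1(0, GTid);
  else if (WorkFn)
    WorkFn(0, GTid);  // fallback: indirect call for unknown regions
}
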
- auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); - if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { - // Unknown caller or declarations are not analyzable, we give up. - if (!Callee || !A.isFunctionIPOAmendable(*Callee)) { - - // Unknown callees might contain parallel regions, except if they have - // an appropriate assumption attached. - if (!AssumptionAA || - !(AssumptionAA->hasAssumption("omp_no_openmp") || - AssumptionAA->hasAssumption("omp_no_parallelism"))) - ReachedUnknownParallelRegions.insert(&CB); - - // If SPMDCompatibilityTracker is not fixed, we need to give up on the - // idea we can run something unknown in SPMD-mode. - if (!SPMDCompatibilityTracker.isAtFixpoint()) { - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - } + auto CheckCallee = [&](Function *Callee, unsigned NumCallees) { + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); + if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { + // Unknown caller or declarations are not analyzable, we give up. + if (!Callee || !A.isFunctionIPOAmendable(*Callee)) { + + // Unknown callees might contain parallel regions, except if they have + // an appropriate assumption attached. + if (!AssumptionAA || + !(AssumptionAA->hasAssumption("omp_no_openmp") || + AssumptionAA->hasAssumption("omp_no_parallelism"))) + ReachedUnknownParallelRegions.insert(&CB); + + // If SPMDCompatibilityTracker is not fixed, we need to give up on the + // idea we can run something unknown in SPMD-mode. + if (!SPMDCompatibilityTracker.isAtFixpoint()) { + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); + } - // We have updated the state for this unknown call properly, there won't - // be any change so we indicate a fixpoint. - indicateOptimisticFixpoint(); + // We have updated the state for this unknown call properly, there + // won't be any change so we indicate a fixpoint. + indicateOptimisticFixpoint(); + } + // If the callee is known and can be used in IPO, we will update the + // state based on the callee state in updateImpl. + return; + } + if (NumCallees > 1) { + indicatePessimisticFixpoint(); + return; } - // If the callee is known and can be used in IPO, we will update the state - // based on the callee state in updateImpl. - return; - } - const unsigned int WrapperFunctionArgNo = 6; - RuntimeFunction RF = It->getSecond(); - switch (RF) { - // All the functions we know are compatible with SPMD mode. 
- case OMPRTL___kmpc_is_spmd_exec_mode: - case OMPRTL___kmpc_distribute_static_fini: - case OMPRTL___kmpc_for_static_fini: - case OMPRTL___kmpc_global_thread_num: - case OMPRTL___kmpc_get_hardware_num_threads_in_block: - case OMPRTL___kmpc_get_hardware_num_blocks: - case OMPRTL___kmpc_single: - case OMPRTL___kmpc_end_single: - case OMPRTL___kmpc_master: - case OMPRTL___kmpc_end_master: - case OMPRTL___kmpc_barrier: - case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2: - case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2: - case OMPRTL___kmpc_nvptx_end_reduce_nowait: - break; - case OMPRTL___kmpc_distribute_static_init_4: - case OMPRTL___kmpc_distribute_static_init_4u: - case OMPRTL___kmpc_distribute_static_init_8: - case OMPRTL___kmpc_distribute_static_init_8u: - case OMPRTL___kmpc_for_static_init_4: - case OMPRTL___kmpc_for_static_init_4u: - case OMPRTL___kmpc_for_static_init_8: - case OMPRTL___kmpc_for_static_init_8u: { - // Check the schedule and allow static schedule in SPMD mode. - unsigned ScheduleArgOpNo = 2; - auto *ScheduleTypeCI = - dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo)); - unsigned ScheduleTypeVal = - ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0; - switch (OMPScheduleType(ScheduleTypeVal)) { - case OMPScheduleType::UnorderedStatic: - case OMPScheduleType::UnorderedStaticChunked: - case OMPScheduleType::OrderedDistribute: - case OMPScheduleType::OrderedDistributeChunked: + RuntimeFunction RF = It->getSecond(); + switch (RF) { + // All the functions we know are compatible with SPMD mode. + case OMPRTL___kmpc_is_spmd_exec_mode: + case OMPRTL___kmpc_distribute_static_fini: + case OMPRTL___kmpc_for_static_fini: + case OMPRTL___kmpc_global_thread_num: + case OMPRTL___kmpc_get_hardware_num_threads_in_block: + case OMPRTL___kmpc_get_hardware_num_blocks: + case OMPRTL___kmpc_single: + case OMPRTL___kmpc_end_single: + case OMPRTL___kmpc_master: + case OMPRTL___kmpc_end_master: + case OMPRTL___kmpc_barrier: + case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2: + case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2: + case OMPRTL___kmpc_error: + case OMPRTL___kmpc_flush: + case OMPRTL___kmpc_get_hardware_thread_id_in_block: + case OMPRTL___kmpc_get_warp_size: + case OMPRTL_omp_get_thread_num: + case OMPRTL_omp_get_num_threads: + case OMPRTL_omp_get_max_threads: + case OMPRTL_omp_in_parallel: + case OMPRTL_omp_get_dynamic: + case OMPRTL_omp_get_cancellation: + case OMPRTL_omp_get_nested: + case OMPRTL_omp_get_schedule: + case OMPRTL_omp_get_thread_limit: + case OMPRTL_omp_get_supported_active_levels: + case OMPRTL_omp_get_max_active_levels: + case OMPRTL_omp_get_level: + case OMPRTL_omp_get_ancestor_thread_num: + case OMPRTL_omp_get_team_size: + case OMPRTL_omp_get_active_level: + case OMPRTL_omp_in_final: + case OMPRTL_omp_get_proc_bind: + case OMPRTL_omp_get_num_places: + case OMPRTL_omp_get_num_procs: + case OMPRTL_omp_get_place_proc_ids: + case OMPRTL_omp_get_place_num: + case OMPRTL_omp_get_partition_num_places: + case OMPRTL_omp_get_partition_place_nums: + case OMPRTL_omp_get_wtime: break; - default: + case OMPRTL___kmpc_distribute_static_init_4: + case OMPRTL___kmpc_distribute_static_init_4u: + case OMPRTL___kmpc_distribute_static_init_8: + case OMPRTL___kmpc_distribute_static_init_8u: + case OMPRTL___kmpc_for_static_init_4: + case OMPRTL___kmpc_for_static_init_4u: + case OMPRTL___kmpc_for_static_init_8: + case OMPRTL___kmpc_for_static_init_8u: { + // Check the schedule and allow static schedule in SPMD mode. 
+ unsigned ScheduleArgOpNo = 2; + auto *ScheduleTypeCI = + dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo)); + unsigned ScheduleTypeVal = + ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0; + switch (OMPScheduleType(ScheduleTypeVal)) { + case OMPScheduleType::UnorderedStatic: + case OMPScheduleType::UnorderedStaticChunked: + case OMPScheduleType::OrderedDistribute: + case OMPScheduleType::OrderedDistributeChunked: + break; + default: + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); + break; + }; + } break; + case OMPRTL___kmpc_target_init: + KernelInitCB = &CB; + break; + case OMPRTL___kmpc_target_deinit: + KernelDeinitCB = &CB; + break; + case OMPRTL___kmpc_parallel_51: + if (!handleParallel51(A, CB)) + indicatePessimisticFixpoint(); + return; + case OMPRTL___kmpc_omp_task: + // We do not look into tasks right now, just give up. SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); + ReachedUnknownParallelRegions.insert(&CB); break; - }; - } break; - case OMPRTL___kmpc_target_init: - KernelInitCB = &CB; - break; - case OMPRTL___kmpc_target_deinit: - KernelDeinitCB = &CB; - break; - case OMPRTL___kmpc_parallel_51: - if (auto *ParallelRegion = dyn_cast<Function>( - CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) { - ReachedKnownParallelRegions.insert(ParallelRegion); - /// Check nested parallelism - auto *FnAA = A.getAAFor<AAKernelInfo>( - *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL); - NestedParallelism |= !FnAA || !FnAA->getState().isValidState() || - !FnAA->ReachedKnownParallelRegions.empty() || - !FnAA->ReachedUnknownParallelRegions.empty(); + case OMPRTL___kmpc_alloc_shared: + case OMPRTL___kmpc_free_shared: + // Return without setting a fixpoint, to be resolved in updateImpl. + return; + default: + // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, + // generally. However, they do not hide parallel regions. + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + SPMDCompatibilityTracker.insert(&CB); break; } - // The condition above should usually get the parallel region function - // pointer and record it. In the off chance it doesn't we assume the - // worst. - ReachedUnknownParallelRegions.insert(&CB); - break; - case OMPRTL___kmpc_omp_task: - // We do not look into tasks right now, just give up. - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - ReachedUnknownParallelRegions.insert(&CB); - break; - case OMPRTL___kmpc_alloc_shared: - case OMPRTL___kmpc_free_shared: - // Return without setting a fixpoint, to be resolved in updateImpl. + // All other OpenMP runtime calls will not reach parallel regions so they + // can be safely ignored for now. Since it is a known OpenMP runtime call + // we have now modeled all effects and there is no need for any update. + indicateOptimisticFixpoint(); + }; + + const auto *AACE = + A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) { + CheckCallee(getAssociatedFunction(), 1); return; - default: - // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, - // generally. However, they do not hide parallel regions. - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); - break; } - // All other OpenMP runtime calls will not reach parallel regions so they - // can be safely ignored for now. 
Since it is a known OpenMP runtime call we - // have now modeled all effects and there is no need for any update. - indicateOptimisticFixpoint(); + const auto &OptimisticEdges = AACE->getOptimisticEdges(); + for (auto *Callee : OptimisticEdges) { + CheckCallee(Callee, OptimisticEdges.size()); + if (isAtFixpoint()) + break; + } } ChangeStatus updateImpl(Attributor &A) override { @@ -4782,62 +5045,115 @@ struct AAKernelInfoCallSite : AAKernelInfo { // call site specific liveness information and then it makes // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. - Function *F = getAssociatedFunction(); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F); - - // If F is not a runtime function, propagate the AAKernelInfo of the callee. - if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { - const IRPosition &FnPos = IRPosition::function(*F); - auto *FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); - if (!FnAA) - return indicatePessimisticFixpoint(); - if (getState() == FnAA->getState()) - return ChangeStatus::UNCHANGED; - getState() = FnAA->getState(); - return ChangeStatus::CHANGED; - } - - // F is a runtime function that allocates or frees memory, check - // AAHeapToStack and AAHeapToShared. KernelInfoState StateBefore = getState(); - assert((It->getSecond() == OMPRTL___kmpc_alloc_shared || - It->getSecond() == OMPRTL___kmpc_free_shared) && - "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call"); - - CallBase &CB = cast<CallBase>(getAssociatedValue()); - auto *HeapToStackAA = A.getAAFor<AAHeapToStack>( - *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); - auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>( - *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); + auto CheckCallee = [&](Function *F, int NumCallees) { + const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F); + + // If F is not a runtime function, propagate the AAKernelInfo of the + // callee. + if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) { + const IRPosition &FnPos = IRPosition::function(*F); + auto *FnAA = + A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED); + if (!FnAA) + return indicatePessimisticFixpoint(); + if (getState() == FnAA->getState()) + return ChangeStatus::UNCHANGED; + getState() = FnAA->getState(); + return ChangeStatus::CHANGED; + } + if (NumCallees > 1) + return indicatePessimisticFixpoint(); - RuntimeFunction RF = It->getSecond(); + CallBase &CB = cast<CallBase>(getAssociatedValue()); + if (It->getSecond() == OMPRTL___kmpc_parallel_51) { + if (!handleParallel51(A, CB)) + return indicatePessimisticFixpoint(); + return StateBefore == getState() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } - switch (RF) { - // If neither HeapToStack nor HeapToShared assume the call is removed, - // assume SPMD incompatibility. - case OMPRTL___kmpc_alloc_shared: - if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) && - (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB))) - SPMDCompatibilityTracker.insert(&CB); - break; - case OMPRTL___kmpc_free_shared: - if ((!HeapToStackAA || - !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) && - (!HeapToSharedAA || - !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB))) + // F is a runtime function that allocates or frees memory, check + // AAHeapToStack and AAHeapToShared. 
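
// Sketch of the CheckCallee pattern above (toy lattice, not the Attributor
// API): with optimistic call edges an indirect call site can have several
// possible callees; each candidate is checked in turn, and an unanalyzable
// callee drops the state to the pessimistic fixpoint.
#include <vector>

struct State {
  bool Valid = true;
  void merge(const State &Other) { Valid &= Other.Valid; }
  void indicatePessimisticFixpoint() { Valid = false; }
};

State mergeCallees(const std::vector<const State *> &CalleeStates) {
  State S;
  for (const State *CalleeState : CalleeStates) {
    if (!CalleeState) {            // callee we cannot analyze
      S.indicatePessimisticFixpoint();
      break;
    }
    S.merge(*CalleeState);         // intersect what is known
    if (!S.Valid)
      break;                       // already pessimistic, stop early
  }
  return S;
}
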
+ assert( + (It->getSecond() == OMPRTL___kmpc_alloc_shared || + It->getSecond() == OMPRTL___kmpc_free_shared) && + "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call"); + + auto *HeapToStackAA = A.getAAFor<AAHeapToStack>( + *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); + auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>( + *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL); + + RuntimeFunction RF = It->getSecond(); + + switch (RF) { + // If neither HeapToStack nor HeapToShared assume the call is removed, + // assume SPMD incompatibility. + case OMPRTL___kmpc_alloc_shared: + if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) && + (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB))) + SPMDCompatibilityTracker.insert(&CB); + break; + case OMPRTL___kmpc_free_shared: + if ((!HeapToStackAA || + !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) && + (!HeapToSharedAA || + !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB))) + SPMDCompatibilityTracker.insert(&CB); + break; + default: + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.insert(&CB); - break; - default: - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - SPMDCompatibilityTracker.insert(&CB); + } + return ChangeStatus::CHANGED; + }; + + const auto *AACE = + A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL); + if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) { + if (Function *F = getAssociatedFunction()) + CheckCallee(F, /*NumCallees=*/1); + } else { + const auto &OptimisticEdges = AACE->getOptimisticEdges(); + for (auto *Callee : OptimisticEdges) { + CheckCallee(Callee, OptimisticEdges.size()); + if (isAtFixpoint()) + break; + } } return StateBefore == getState() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } + + /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was + /// handled, if a problem occurred, false is returned. + bool handleParallel51(Attributor &A, CallBase &CB) { + const unsigned int NonWrapperFunctionArgNo = 5; + const unsigned int WrapperFunctionArgNo = 6; + auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed() + ? 
NonWrapperFunctionArgNo + : WrapperFunctionArgNo; + + auto *ParallelRegion = dyn_cast<Function>( + CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts()); + if (!ParallelRegion) + return false; + + ReachedKnownParallelRegions.insert(&CB); + /// Check nested parallelism + auto *FnAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL); + NestedParallelism |= !FnAA || !FnAA->getState().isValidState() || + !FnAA->ReachedKnownParallelRegions.empty() || + !FnAA->ReachedKnownParallelRegions.isValidState() || + !FnAA->ReachedUnknownParallelRegions.isValidState() || + !FnAA->ReachedUnknownParallelRegions.empty(); + return true; + } }; struct AAFoldRuntimeCall @@ -5251,6 +5567,11 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) { UsedAssumedInformation, AA::Interprocedural); continue; } + if (auto *CI = dyn_cast<CallBase>(&I)) { + if (CI->isIndirectCall()) + A.getOrCreateAAFor<AAIndirectCallInfo>( + IRPosition::callsite_function(*CI)); + } if (auto *SI = dyn_cast<StoreInst>(&I)) { A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI)); continue; @@ -5569,7 +5890,9 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::all(); } -bool llvm::omp::isKernel(Function &Fn) { return Fn.hasFnAttribute("kernel"); } +bool llvm::omp::isOpenMPKernel(Function &Fn) { + return Fn.hasFnAttribute("kernel"); +} KernelSet llvm::omp::getDeviceKernels(Module &M) { // TODO: Create a more cross-platform way of determining device kernels. @@ -5591,10 +5914,13 @@ KernelSet llvm::omp::getDeviceKernels(Module &M) { if (!KernelFn) continue; - assert(isKernel(*KernelFn) && "Inconsistent kernel function annotation"); - ++NumOpenMPTargetRegionKernels; - - Kernels.insert(KernelFn); + // We are only interested in OpenMP target regions. Others, such as kernels + // generated by CUDA but linked together, are not interesting to this pass. + if (isOpenMPKernel(*KernelFn)) { + ++NumOpenMPTargetRegionKernels; + Kernels.insert(KernelFn); + } else + ++NumNonOpenMPTargetRegionKernels; } return Kernels; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp index b88ba2dec24b..aa4f205ec5bd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -161,7 +161,7 @@ struct FunctionOutliningInfo { // The dominating block of the region to be outlined. 
BasicBlock *NonReturnBlock = nullptr; - // The set of blocks in Entries that that are predecessors to ReturnBlock + // The set of blocks in Entries that are predecessors to ReturnBlock SmallVector<BasicBlock *, 4> ReturnBlockPreds; }; @@ -767,7 +767,7 @@ bool PartialInlinerImpl::shouldPartialInline( const DataLayout &DL = Caller->getParent()->getDataLayout(); // The savings of eliminating the call: - int NonWeightedSavings = getCallsiteCost(CB, DL); + int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL); BlockFrequency NormWeightedSavings(NonWeightedSavings); // Weighted saving is smaller than weighted cost, return false @@ -842,12 +842,12 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB, } if (CallInst *CI = dyn_cast<CallInst>(&I)) { - InlineCost += getCallsiteCost(*CI, DL); + InlineCost += getCallsiteCost(*TTI, *CI, DL); continue; } if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) { - InlineCost += getCallsiteCost(*II, DL); + InlineCost += getCallsiteCost(*TTI, *II, DL); continue; } @@ -1042,7 +1042,7 @@ void PartialInlinerImpl::FunctionCloner::normalizeReturnBlock() const { ClonedOI->ReturnBlock = ClonedOI->ReturnBlock->splitBasicBlock( ClonedOI->ReturnBlock->getFirstNonPHI()->getIterator()); BasicBlock::iterator I = PreReturn->begin(); - Instruction *Ins = &ClonedOI->ReturnBlock->front(); + BasicBlock::iterator Ins = ClonedOI->ReturnBlock->begin(); SmallVector<Instruction *, 4> DeadPhis; while (I != PreReturn->end()) { PHINode *OldPhi = dyn_cast<PHINode>(I); @@ -1050,9 +1050,10 @@ void PartialInlinerImpl::FunctionCloner::normalizeReturnBlock() const { break; PHINode *RetPhi = - PHINode::Create(OldPhi->getType(), NumPredsFromEntries + 1, "", Ins); + PHINode::Create(OldPhi->getType(), NumPredsFromEntries + 1, ""); + RetPhi->insertBefore(Ins); OldPhi->replaceAllUsesWith(RetPhi); - Ins = ClonedOI->ReturnBlock->getFirstNonPHI(); + Ins = ClonedOI->ReturnBlock->getFirstNonPHIIt(); RetPhi->addIncoming(&*I, PreReturn); for (BasicBlock *E : ClonedOI->ReturnBlockPreds) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp index e2e6364df906..b1f9b827dcba 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp @@ -22,6 +22,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/AttributeMask.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ModRef.h" @@ -43,7 +44,7 @@ STATISTIC(NumInstReplaced, "Number of instructions replaced with (simpler) instruction"); static cl::opt<unsigned> FuncSpecMaxIters( - "funcspec-max-iters", cl::init(1), cl::Hidden, cl::desc( + "funcspec-max-iters", cl::init(10), cl::Hidden, cl::desc( "The maximum number of iterations function specialization is run")); static void findReturnsToZap(Function &F, @@ -235,11 +236,11 @@ static bool runIPSCCP( // nodes in executable blocks we found values for. The function's entry // block is not part of BlocksToErase, so we have to handle it separately. 
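
// Illustrative arithmetic for the partial-inlining profitability test above
// (simplified; the real code scales BlockFrequency fixed-point values): the
// savings from removing the call are weighted by how hot the call site is
// relative to the caller's entry.
#include <cstdint>

bool worthPartialInlining(uint64_t CallsiteFreq, uint64_t EntryFreq,
                          int64_t CallsiteCost, int64_t WeightedOutlineCost) {
  if (EntryFreq == 0)
    return false; // never executed, nothing to gain
  int64_t WeightedSavings =
      CallsiteCost * int64_t(CallsiteFreq) / int64_t(EntryFreq);
  return WeightedSavings > WeightedOutlineCost;
}
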
for (BasicBlock *BB : BlocksToErase) { - NumInstRemoved += changeToUnreachable(BB->getFirstNonPHI(), + NumInstRemoved += changeToUnreachable(BB->getFirstNonPHIOrDbg(), /*PreserveLCSSA=*/false, &DTU); } if (!Solver.isBlockExecutable(&F.front())) - NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHI(), + NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHIOrDbg(), /*PreserveLCSSA=*/false, &DTU); BasicBlock *NewUnreachableBB = nullptr; @@ -371,6 +372,18 @@ static bool runIPSCCP( StoreInst *SI = cast<StoreInst>(GV->user_back()); SI->eraseFromParent(); } + + // Try to create a debug constant expression for the global variable + // initializer value. + SmallVector<DIGlobalVariableExpression *, 1> GVEs; + GV->getDebugInfo(GVEs); + if (GVEs.size() == 1) { + DIBuilder DIB(M); + if (DIExpression *InitExpr = getExpressionForConstant( + DIB, *GV->getInitializer(), *GV->getValueType())) + GVEs[0]->replaceOperandWith(1, InitExpr); + } + MadeChanges = true; M.eraseGlobalVariable(GV); ++NumGlobalConst; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp index 3ddf5fe20edb..f7a54d428f20 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/SampleContextTracker.h" -#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/InstrTypes.h" @@ -29,7 +28,7 @@ using namespace sampleprof; namespace llvm { ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite, - StringRef CalleeName) { + FunctionId CalleeName) { if (CalleeName.empty()) return getHottestChildContext(CallSite); @@ -104,7 +103,7 @@ SampleContextTracker::moveContextSamples(ContextTrieNode &ToNodeParent, } void ContextTrieNode::removeChildContext(const LineLocation &CallSite, - StringRef CalleeName) { + FunctionId CalleeName) { uint64_t Hash = FunctionSamples::getCallSiteHash(CalleeName, CallSite); // Note this essentially calls dtor and destroys that child context AllChildContext.erase(Hash); @@ -114,7 +113,7 @@ std::map<uint64_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() { return AllChildContext; } -StringRef ContextTrieNode::getFuncName() const { return FuncName; } +FunctionId ContextTrieNode::getFuncName() const { return FuncName; } FunctionSamples *ContextTrieNode::getFunctionSamples() const { return FuncSamples; @@ -178,7 +177,7 @@ void ContextTrieNode::dumpTree() { } ContextTrieNode *ContextTrieNode::getOrCreateChildContext( - const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) { + const LineLocation &CallSite, FunctionId CalleeName, bool AllowCreate) { uint64_t Hash = FunctionSamples::getCallSiteHash(CalleeName, CallSite); auto It = AllChildContext.find(Hash); if (It != AllChildContext.end()) { @@ -201,7 +200,7 @@ SampleContextTracker::SampleContextTracker( : GUIDToFuncNameMap(GUIDToFuncNameMap) { for (auto &FuncSample : Profiles) { FunctionSamples *FSamples = &FuncSample.second; - SampleContext Context = FuncSample.first; + SampleContext Context = FuncSample.second.getContext(); LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context.toString() << "\n"); ContextTrieNode *NewNode = getOrCreateContextPath(Context, true); @@ -232,14 +231,12 @@ 
SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, return nullptr; CalleeName = FunctionSamples::getCanonicalFnName(CalleeName); - // Convert real function names to MD5 names, if the input profile is - // MD5-based. - std::string FGUID; - CalleeName = getRepInFormat(CalleeName, FunctionSamples::UseMD5, FGUID); + + FunctionId FName = getRepInFormat(CalleeName); // For indirect call, CalleeName will be empty, in which case the context // profile for callee with largest total samples will be returned. - ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName); + ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, FName); if (CalleeContext) { FunctionSamples *FSamples = CalleeContext->getFunctionSamples(); LLVM_DEBUG(if (FSamples) { @@ -305,27 +302,23 @@ SampleContextTracker::getContextSamplesFor(const SampleContext &Context) { SampleContextTracker::ContextSamplesTy & SampleContextTracker::getAllContextSamplesFor(const Function &Func) { StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); - return FuncToCtxtProfiles[CanonName]; + return FuncToCtxtProfiles[getRepInFormat(CanonName)]; } SampleContextTracker::ContextSamplesTy & SampleContextTracker::getAllContextSamplesFor(StringRef Name) { - return FuncToCtxtProfiles[Name]; + return FuncToCtxtProfiles[getRepInFormat(Name)]; } FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func, bool MergeContext) { StringRef CanonName = FunctionSamples::getCanonicalFnName(Func); - return getBaseSamplesFor(CanonName, MergeContext); + return getBaseSamplesFor(getRepInFormat(CanonName), MergeContext); } -FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, +FunctionSamples *SampleContextTracker::getBaseSamplesFor(FunctionId Name, bool MergeContext) { LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n"); - // Convert real function names to MD5 names, if the input profile is - // MD5-based. - std::string FGUID; - Name = getRepInFormat(Name, FunctionSamples::UseMD5, FGUID); // Base profile is top-level node (child of root node), so try to retrieve // existing top-level node for given function first. 
If it exists, it could be @@ -373,7 +366,7 @@ void SampleContextTracker::markContextSamplesInlined( ContextTrieNode &SampleContextTracker::getRootContext() { return RootContext; } void SampleContextTracker::promoteMergeContextSamplesTree( - const Instruction &Inst, StringRef CalleeName) { + const Instruction &Inst, FunctionId CalleeName) { LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n" << Inst << "\n"); // Get the caller context for the call instruction, we don't use callee @@ -458,9 +451,9 @@ void SampleContextTracker::dump() { RootContext.dumpTree(); } StringRef SampleContextTracker::getFuncNameFor(ContextTrieNode *Node) const { if (!FunctionSamples::UseMD5) - return Node->getFuncName(); + return Node->getFuncName().stringRef(); assert(GUIDToFuncNameMap && "GUIDToFuncNameMap needs to be populated first"); - return GUIDToFuncNameMap->lookup(std::stoull(Node->getFuncName().data())); + return GUIDToFuncNameMap->lookup(Node->getFuncName().getHashCode()); } ContextTrieNode * @@ -470,7 +463,7 @@ SampleContextTracker::getContextFor(const SampleContext &Context) { ContextTrieNode * SampleContextTracker::getCalleeContextFor(const DILocation *DIL, - StringRef CalleeName) { + FunctionId CalleeName) { assert(DIL && "Expect non-null location"); ContextTrieNode *CallContext = getContextFor(DIL); @@ -485,7 +478,7 @@ SampleContextTracker::getCalleeContextFor(const DILocation *DIL, ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { assert(DIL && "Expect non-null location"); - SmallVector<std::pair<LineLocation, StringRef>, 10> S; + SmallVector<std::pair<LineLocation, FunctionId>, 10> S; // Use C++ linkage name if possible. const DILocation *PrevDIL = DIL; @@ -494,7 +487,8 @@ ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { if (Name.empty()) Name = PrevDIL->getScope()->getSubprogram()->getName(); S.push_back( - std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL), Name)); + std::make_pair(FunctionSamples::getCallSiteIdentifier(DIL), + getRepInFormat(Name))); PrevDIL = DIL; } @@ -503,24 +497,14 @@ ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) { StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName(); if (RootName.empty()) RootName = PrevDIL->getScope()->getSubprogram()->getName(); - S.push_back(std::make_pair(LineLocation(0, 0), RootName)); - - // Convert real function names to MD5 names, if the input profile is - // MD5-based. 
- std::list<std::string> MD5Names; - if (FunctionSamples::UseMD5) { - for (auto &Location : S) { - MD5Names.emplace_back(); - getRepInFormat(Location.second, FunctionSamples::UseMD5, MD5Names.back()); - Location.second = MD5Names.back(); - } - } + S.push_back(std::make_pair(LineLocation(0, 0), + getRepInFormat(RootName))); ContextTrieNode *ContextNode = &RootContext; int I = S.size(); while (--I >= 0 && ContextNode) { LineLocation &CallSite = S[I].first; - StringRef CalleeName = S[I].second; + FunctionId CalleeName = S[I].second; ContextNode = ContextNode->getChildContext(CallSite, CalleeName); } @@ -540,10 +524,10 @@ SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, // Create child node at parent line/disc location if (AllowCreate) { ContextNode = - ContextNode->getOrCreateChildContext(CallSiteLoc, Callsite.FuncName); + ContextNode->getOrCreateChildContext(CallSiteLoc, Callsite.Func); } else { ContextNode = - ContextNode->getChildContext(CallSiteLoc, Callsite.FuncName); + ContextNode->getChildContext(CallSiteLoc, Callsite.Func); } CallSiteLoc = Callsite.Location; } @@ -553,12 +537,14 @@ SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, return ContextNode; } -ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) { +ContextTrieNode * +SampleContextTracker::getTopLevelContextNode(FunctionId FName) { assert(!FName.empty() && "Top level node query must provide valid name"); return RootContext.getChildContext(LineLocation(0, 0), FName); } -ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) { +ContextTrieNode & +SampleContextTracker::addTopLevelContextNode(FunctionId FName) { assert(!getTopLevelContextNode(FName) && "Node to add must not exist"); return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName); } @@ -638,7 +624,7 @@ void SampleContextTracker::createContextLessProfileMap( FunctionSamples *FProfile = Node->getFunctionSamples(); // Profile's context can be empty, use ContextNode's func name. 
if (FProfile) - ContextLessProfiles[Node->getFuncName()].merge(*FProfile); + ContextLessProfiles.Create(Node->getFuncName()).merge(*FProfile); } } } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp index a53baecd4776..6c6f0a0eca72 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -56,6 +56,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/PseudoProbe.h" #include "llvm/IR/ValueSymbolTable.h" #include "llvm/ProfileData/InstrProf.h" @@ -142,11 +143,6 @@ static cl::opt<bool> PersistProfileStaleness( cl::desc("Compute stale profile statistical metrics and write it into the " "native object file(.llvm_stats section).")); -static cl::opt<bool> FlattenProfileForMatching( - "flatten-profile-for-matching", cl::Hidden, cl::init(true), - cl::desc( - "Use flattened profile for stale profile detection and matching.")); - static cl::opt<bool> ProfileSampleAccurate( "profile-sample-accurate", cl::Hidden, cl::init(false), cl::desc("If the sample profile is accurate, we will mark all un-sampled " @@ -429,7 +425,7 @@ struct CandidateComparer { return LCS->getBodySamples().size() > RCS->getBodySamples().size(); // Tie breaker using GUID so we have stable/deterministic inlining order - return LCS->getGUID(LCS->getName()) < RCS->getGUID(RCS->getName()); + return LCS->getGUID() < RCS->getGUID(); } }; @@ -458,32 +454,44 @@ class SampleProfileMatcher { uint64_t MismatchedFuncHashSamples = 0; uint64_t TotalFuncHashSamples = 0; + // A dummy name for unknown indirect callee, used to differentiate from a + // non-call instruction that also has an empty callee name. 
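The comment just above motivates a dummy callee name, and the declaration follows below. The point is to keep three states distinguishable once anchors are stored as plain strings: not a call, a known direct callee, and an indirect call with an unknown target. A small sketch of that classification; the enum and constant names are illustrative:

    #include <string>

    // Sentinel distinguishing "call with unknown (indirect) target" from
    // "not a call", which is encoded as the empty string.
    inline const std::string kUnknownIndirectCallee = "unknown.indirect.callee";

    enum class AnchorKind { NonCall, DirectCall, IndirectCall };

    inline AnchorKind classify(const std::string &CalleeName) {
      if (CalleeName.empty())
        return AnchorKind::NonCall;
      if (CalleeName == kUnknownIndirectCallee)
        return AnchorKind::IndirectCall;
      return AnchorKind::DirectCall;
    }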
+ static constexpr const char *UnknownIndirectCallee = + "unknown.indirect.callee"; + public: SampleProfileMatcher(Module &M, SampleProfileReader &Reader, const PseudoProbeManager *ProbeManager) - : M(M), Reader(Reader), ProbeManager(ProbeManager) { - if (FlattenProfileForMatching) { - ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles, - FunctionSamples::ProfileIsCS); - } - } + : M(M), Reader(Reader), ProbeManager(ProbeManager){}; void runOnModule(); private: FunctionSamples *getFlattenedSamplesFor(const Function &F) { StringRef CanonFName = FunctionSamples::getCanonicalFnName(F); - auto It = FlattenedProfiles.find(CanonFName); + auto It = FlattenedProfiles.find(FunctionId(CanonFName)); if (It != FlattenedProfiles.end()) return &It->second; return nullptr; } - void runOnFunction(const Function &F, const FunctionSamples &FS); + void runOnFunction(const Function &F); + void findIRAnchors(const Function &F, + std::map<LineLocation, StringRef> &IRAnchors); + void findProfileAnchors( + const FunctionSamples &FS, + std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors); + void countMismatchedSamples(const FunctionSamples &FS); void countProfileMismatches( + const Function &F, const FunctionSamples &FS, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors); + void countProfileCallsiteMismatches( const FunctionSamples &FS, - const std::unordered_set<LineLocation, LineLocationHash> - &MatchedCallsiteLocs, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors, uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites); - LocToLocMap &getIRToProfileLocationMap(const Function &F) { auto Ret = FuncMappings.try_emplace( FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap()); @@ -491,12 +499,10 @@ private: } void distributeIRToProfileLocationMap(); void distributeIRToProfileLocationMap(FunctionSamples &FS); - void populateProfileCallsites( - const FunctionSamples &FS, - StringMap<std::set<LineLocation>> &CalleeToCallsitesMap); void runStaleProfileMatching( - const std::map<LineLocation, StringRef> &IRLocations, - StringMap<std::set<LineLocation>> &CalleeToCallsitesMap, + const Function &F, const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors, LocToLocMap &IRToProfileLocationMap); }; @@ -538,7 +544,6 @@ protected: findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const; void findExternalInlineCandidate(CallBase *CB, const FunctionSamples *Samples, DenseSet<GlobalValue::GUID> &InlinedGUIDs, - const StringMap<Function *> &SymbolMap, uint64_t Threshold); // Attempt to promote indirect call and also inline the promoted call bool tryPromoteAndInlineCandidate( @@ -573,7 +578,7 @@ protected: /// the function name. If the function name contains suffix, additional /// entry is added to map from the stripped name to the function if there /// is one-to-one mapping. - StringMap<Function *> SymbolMap; + HashKeyMap<std::unordered_map, FunctionId, Function *> SymbolMap; std::function<AssumptionCache &(Function &)> GetAC; std::function<TargetTransformInfo &(Function &)> GetTTI; @@ -615,6 +620,11 @@ protected: // All the Names used in FunctionSamples including outline function // names, inline instance names and call target names. StringSet<> NamesInProfile; + // MD5 version of NamesInProfile. 
Either NamesInProfile or GUIDsInProfile is + // populated, depends on whether the profile uses MD5. Because the name table + // generally contains several magnitude more entries than the number of + // functions, we do not want to convert all names from one form to another. + llvm::DenseSet<uint64_t> GUIDsInProfile; // For symbol in profile symbol list, whether to regard their profiles // to be accurate. It is mainly decided by existance of profile symbol @@ -759,8 +769,7 @@ SampleProfileLoader::findIndirectCallFunctionSamples( assert(L && R && "Expect non-null FunctionSamples"); if (L->getHeadSamplesEstimate() != R->getHeadSamplesEstimate()) return L->getHeadSamplesEstimate() > R->getHeadSamplesEstimate(); - return FunctionSamples::getGUID(L->getName()) < - FunctionSamples::getGUID(R->getName()); + return L->getGUID() < R->getGUID(); }; if (FunctionSamples::ProfileIsCS) { @@ -970,13 +979,13 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate( // This prevents allocating an array of zero length in callees below. if (MaxNumPromotions == 0) return false; - auto CalleeFunctionName = Candidate.CalleeSamples->getFuncName(); + auto CalleeFunctionName = Candidate.CalleeSamples->getFunction(); auto R = SymbolMap.find(CalleeFunctionName); - if (R == SymbolMap.end() || !R->getValue()) + if (R == SymbolMap.end() || !R->second) return false; auto &CI = *Candidate.CallInstr; - if (!doesHistoryAllowICP(CI, R->getValue()->getName())) + if (!doesHistoryAllowICP(CI, R->second->getName())) return false; const char *Reason = "Callee function not available"; @@ -986,17 +995,17 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate( // clone the caller first, and inline the cloned caller if it is // recursive. As llvm does not inline recursive calls, we will // simply ignore it instead of handling it explicitly. - if (!R->getValue()->isDeclaration() && R->getValue()->getSubprogram() && - R->getValue()->hasFnAttribute("use-sample-profile") && - R->getValue() != &F && isLegalToPromote(CI, R->getValue(), &Reason)) { + if (!R->second->isDeclaration() && R->second->getSubprogram() && + R->second->hasFnAttribute("use-sample-profile") && + R->second != &F && isLegalToPromote(CI, R->second, &Reason)) { // For promoted target, set its value with NOMORE_ICP_MAGICNUM count // in the value profile metadata so the target won't be promoted again. 
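The comment above, and the code that follows, mark a promoted target with NOMORE_ICP_MAGICNUM in the value-profile data so a later pass never promotes it again. A standalone sketch of the marker pattern; the constant's value here is a stand-in, not the one LLVM defines:

    #include <cstdint>
    #include <unordered_map>

    // Illustrative sentinel count meaning "already promoted, do not promote
    // again" in per-callsite value-profile data.
    constexpr uint64_t kNoMoreICPMagicNum = ~uint64_t(0);

    using TargetCounts = std::unordered_map<uint64_t /*GUID*/, uint64_t /*count*/>;

    // After promoting GUID at this callsite, overwrite its count with the
    // magic number so the marker survives for later passes.
    inline void markPromoted(TargetCounts &VP, uint64_t GUID) {
      VP[GUID] = kNoMoreICPMagicNum;
    }

    inline bool historyAllowsICP(const TargetCounts &VP, uint64_t GUID) {
      auto It = VP.find(GUID);
      return It == VP.end() || It->second != kNoMoreICPMagicNum;
    }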
SmallVector<InstrProfValueData, 1> SortedCallTargets = {InstrProfValueData{ - Function::getGUID(R->getValue()->getName()), NOMORE_ICP_MAGICNUM}}; + Function::getGUID(R->second->getName()), NOMORE_ICP_MAGICNUM}}; updateIDTMetaData(CI, SortedCallTargets, 0); auto *DI = &pgo::promoteIndirectCall( - CI, R->getValue(), Candidate.CallsiteCount, Sum, false, ORE); + CI, R->second, Candidate.CallsiteCount, Sum, false, ORE); if (DI) { Sum -= Candidate.CallsiteCount; // Do not prorate the indirect callsite distribution since the original @@ -1025,7 +1034,8 @@ bool SampleProfileLoader::tryPromoteAndInlineCandidate( } } else { LLVM_DEBUG(dbgs() << "\nFailed to promote indirect call to " - << Candidate.CalleeSamples->getFuncName() << " because " + << FunctionSamples::getCanonicalFnName( + Candidate.CallInstr->getName())<< " because " << Reason << "\n"); } return false; @@ -1070,8 +1080,7 @@ void SampleProfileLoader::emitOptimizationRemarksForInlineCandidates( void SampleProfileLoader::findExternalInlineCandidate( CallBase *CB, const FunctionSamples *Samples, - DenseSet<GlobalValue::GUID> &InlinedGUIDs, - const StringMap<Function *> &SymbolMap, uint64_t Threshold) { + DenseSet<GlobalValue::GUID> &InlinedGUIDs, uint64_t Threshold) { // If ExternalInlineAdvisor(ReplayInlineAdvisor) wants to inline an external // function make sure it's imported @@ -1080,7 +1089,7 @@ void SampleProfileLoader::findExternalInlineCandidate( // just add the direct GUID and move on if (!Samples) { InlinedGUIDs.insert( - FunctionSamples::getGUID(CB->getCalledFunction()->getName())); + Function::getGUID(CB->getCalledFunction()->getName())); return; } // Otherwise, drop the threshold to import everything that we can @@ -1121,22 +1130,20 @@ void SampleProfileLoader::findExternalInlineCandidate( CalleeSample->getContext().hasAttribute(ContextShouldBeInlined); if (!PreInline && CalleeSample->getHeadSamplesEstimate() < Threshold) continue; - - StringRef Name = CalleeSample->getFuncName(); - Function *Func = SymbolMap.lookup(Name); + + Function *Func = SymbolMap.lookup(CalleeSample->getFunction()); // Add to the import list only when it's defined out of module. if (!Func || Func->isDeclaration()) - InlinedGUIDs.insert(FunctionSamples::getGUID(CalleeSample->getName())); + InlinedGUIDs.insert(CalleeSample->getGUID()); // Import hot CallTargets, which may not be available in IR because full // profile annotation cannot be done until backend compilation in ThinLTO. for (const auto &BS : CalleeSample->getBodySamples()) for (const auto &TS : BS.second.getCallTargets()) - if (TS.getValue() > Threshold) { - StringRef CalleeName = CalleeSample->getFuncName(TS.getKey()); - const Function *Callee = SymbolMap.lookup(CalleeName); + if (TS.second > Threshold) { + const Function *Callee = SymbolMap.lookup(TS.first); if (!Callee || Callee->isDeclaration()) - InlinedGUIDs.insert(FunctionSamples::getGUID(TS.getKey())); + InlinedGUIDs.insert(TS.first.getHashCode()); } // Import hot child context profile associted with callees. 
Note that this @@ -1234,7 +1241,7 @@ bool SampleProfileLoader::inlineHotFunctions( for (const auto *FS : findIndirectCallFunctionSamples(*I, Sum)) { uint64_t SumOrigin = Sum; if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { - findExternalInlineCandidate(I, FS, InlinedGUIDs, SymbolMap, + findExternalInlineCandidate(I, FS, InlinedGUIDs, PSI->getOrCompHotCountThreshold()); continue; } @@ -1255,7 +1262,7 @@ bool SampleProfileLoader::inlineHotFunctions( } } else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { findExternalInlineCandidate(I, findCalleeFunctionSamples(*I), - InlinedGUIDs, SymbolMap, + InlinedGUIDs, PSI->getOrCompHotCountThreshold()); } } @@ -1504,7 +1511,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( for (const auto *FS : CalleeSamples) { // TODO: Consider disable pre-lTO ICP for MonoLTO as well if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { - findExternalInlineCandidate(I, FS, InlinedGUIDs, SymbolMap, + findExternalInlineCandidate(I, FS, InlinedGUIDs, PSI->getOrCompHotCountThreshold()); continue; } @@ -1557,7 +1564,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( } } else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { findExternalInlineCandidate(I, findCalleeFunctionSamples(*I), - InlinedGUIDs, SymbolMap, + InlinedGUIDs, PSI->getOrCompHotCountThreshold()); } } @@ -1619,7 +1626,12 @@ void SampleProfileLoader::promoteMergeNotInlinedContextSamples( // Note that we have to do the merge right after processing function. // This allows OutlineFS's profile to be used for annotation during // top-down processing of functions' annotation. - FunctionSamples *OutlineFS = Reader->getOrCreateSamplesFor(*Callee); + FunctionSamples *OutlineFS = Reader->getSamplesFor(*Callee); + // If outlined function does not exist in the profile, add it to a + // separate map so that it does not rehash the original profile. + if (!OutlineFS) + OutlineFS = &OutlineFunctionSamples[ + FunctionId(FunctionSamples::getCanonicalFnName(Callee->getName()))]; OutlineFS->merge(*FS, 1); // Set outlined profile to be synthetic to not bias the inliner. OutlineFS->SetContextSynthetic(); @@ -1638,7 +1650,7 @@ GetSortedValueDataFromCallTargets(const SampleRecord::CallTargetMap &M) { SmallVector<InstrProfValueData, 2> R; for (const auto &I : SampleRecord::SortCallTargets(M)) { R.emplace_back( - InstrProfValueData{FunctionSamples::getGUID(I.first), I.second}); + InstrProfValueData{I.first.getHashCode(), I.second}); } return R; } @@ -1699,9 +1711,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { else if (OverwriteExistingWeights) I.setMetadata(LLVMContext::MD_prof, nullptr); } else if (!isa<IntrinsicInst>(&I)) { - I.setMetadata(LLVMContext::MD_prof, - MDB.createBranchWeights( - {static_cast<uint32_t>(BlockWeights[BB])})); + setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])}); } } } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) { @@ -1709,10 +1719,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { // clear it for cold code. 
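Earlier in this hunk, promoteMergeNotInlinedContextSamples stops calling getOrCreateSamplesFor and routes profiles missing from the reader into the separate OutlineFunctionSamples map, so the reader's map never grows (and never rehashes) while it is in use. The pattern, reduced to plain unordered_maps with assumed types:

    #include <string>
    #include <unordered_map>

    struct Samples {
      long Total = 0;
      void merge(const Samples &O) { Total += O.Total; }
    };

    using ProfileMap = std::unordered_map<std::string, Samples>;

    // Look up in the primary map without inserting; route new entries to a
    // side map so pointers and references into Primary stay valid.
    inline Samples &getOrCreateWithoutRehash(ProfileMap &Primary,
                                             ProfileMap &Side,
                                             const std::string &Name) {
      if (auto It = Primary.find(Name); It != Primary.end())
        return It->second;
      return Side[Name]; // growth happens only in the side map
    }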
for (auto &I : *BB) { if (isa<CallInst>(I) || isa<InvokeInst>(I)) { - if (cast<CallBase>(I).isIndirectCall()) + if (cast<CallBase>(I).isIndirectCall()) { I.setMetadata(LLVMContext::MD_prof, nullptr); - else - I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0)); + } else { + setBranchWeights(I, {uint32_t(0)}); + } } } } @@ -1792,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { if (MaxWeight > 0 && (!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) { LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n"); - TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + setBranchWeights(*TI, Weights); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst) << "most popular destination for conditional branches at " @@ -1865,7 +1876,8 @@ SampleProfileLoader::buildProfiledCallGraph(Module &M) { for (Function &F : M) { if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) continue; - ProfiledCG->addProfiledFunction(FunctionSamples::getCanonicalFnName(F)); + ProfiledCG->addProfiledFunction( + getRepInFormat(FunctionSamples::getCanonicalFnName(F))); } return ProfiledCG; @@ -1913,7 +1925,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) { // on the profile to favor more inlining. This is only a problem with CS // profile. // 3. Transitive indirect call edges due to inlining. When a callee function - // (say B) is inlined into into a caller function (say A) in LTO prelink, + // (say B) is inlined into a caller function (say A) in LTO prelink, // every call edge originated from the callee B will be transferred to // the caller A. If any transferred edge (say A->C) is indirect, the // original profiled indirect edge B->C, even if considered, would not @@ -2016,8 +2028,16 @@ bool SampleProfileLoader::doInitialization(Module &M, ProfileAccurateForSymsInList && PSL && !ProfileSampleAccurate; if (ProfAccForSymsInList) { NamesInProfile.clear(); - if (auto NameTable = Reader->getNameTable()) - NamesInProfile.insert(NameTable->begin(), NameTable->end()); + GUIDsInProfile.clear(); + if (auto NameTable = Reader->getNameTable()) { + if (FunctionSamples::UseMD5) { + for (auto Name : *NameTable) + GUIDsInProfile.insert(Name.getHashCode()); + } else { + for (auto Name : *NameTable) + NamesInProfile.insert(Name.stringRef()); + } + } CoverageTracker.setProfAccForSymsInList(true); } @@ -2103,77 +2123,200 @@ bool SampleProfileLoader::doInitialization(Module &M, return true; } -void SampleProfileMatcher::countProfileMismatches( - const FunctionSamples &FS, - const std::unordered_set<LineLocation, LineLocationHash> - &MatchedCallsiteLocs, - uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) { +void SampleProfileMatcher::findIRAnchors( + const Function &F, std::map<LineLocation, StringRef> &IRAnchors) { + // For inlined code, recover the original callsite and callee by finding the + // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the + // top-level frame is "main:1", the callsite is "1" and the callee is "foo". 
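FindTopLevelInlinedCallsite, defined just below, implements the walk described in the comment above: follow the inlined-at chain until the next hop would leave the inline stack, then report that callsite and the callee inlined there. The same traversal over an assumed minimal debug-location type:

    #include <cstdint>
    #include <string>
    #include <utility>

    // Hypothetical flattened debug location: InlinedAt links toward the caller.
    struct DebugLoc {
      uint32_t Line;
      std::string Scope;         // function this location belongs to
      const DebugLoc *InlinedAt; // null when already at the top-level frame
    };

    // For "main:1 @ foo:2 @ bar:3" this returns {1, "foo"}: the top-level
    // callsite line and the callee that was inlined at it.
    inline std::pair<uint32_t, std::string>
    topLevelInlinedCallsite(const DebugLoc *DIL) {
      const DebugLoc *Prev = nullptr;
      do {
        Prev = DIL;
        DIL = DIL->InlinedAt;
      } while (DIL->InlinedAt);
      return {DIL->Line, Prev->Scope};
    }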
+ auto FindTopLevelInlinedCallsite = [](const DILocation *DIL) { + assert((DIL && DIL->getInlinedAt()) && "No inlined callsite"); + const DILocation *PrevDIL = nullptr; + do { + PrevDIL = DIL; + DIL = DIL->getInlinedAt(); + } while (DIL->getInlinedAt()); + + LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL); + StringRef CalleeName = PrevDIL->getSubprogramLinkageName(); + return std::make_pair(Callsite, CalleeName); + }; - auto isInvalidLineOffset = [](uint32_t LineOffset) { - return LineOffset & 0x8000; + auto GetCanonicalCalleeName = [](const CallBase *CB) { + StringRef CalleeName = UnknownIndirectCallee; + if (Function *Callee = CB->getCalledFunction()) + CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName()); + return CalleeName; }; - // Check if there are any callsites in the profile that does not match to any - // IR callsites, those callsite samples will be discarded. - for (auto &I : FS.getBodySamples()) { - const LineLocation &Loc = I.first; - if (isInvalidLineOffset(Loc.LineOffset)) - continue; + // Extract profile matching anchors in the IR. + for (auto &BB : F) { + for (auto &I : BB) { + DILocation *DIL = I.getDebugLoc(); + if (!DIL) + continue; + + if (FunctionSamples::ProfileIsProbeBased) { + if (auto Probe = extractProbe(I)) { + // Flatten inlined IR for the matching. + if (DIL->getInlinedAt()) { + IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL)); + } else { + // Use empty StringRef for basic block probe. + StringRef CalleeName; + if (const auto *CB = dyn_cast<CallBase>(&I)) { + // Skip the probe inst whose callee name is "llvm.pseudoprobe". + if (!isa<IntrinsicInst>(&I)) + CalleeName = GetCanonicalCalleeName(CB); + } + IRAnchors.emplace(LineLocation(Probe->Id, 0), CalleeName); + } + } + } else { + // TODO: For line-number based profile(AutoFDO), currently only support + // find callsite anchors. In future, we need to parse all the non-call + // instructions to extract the line locations for profile matching. + if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I)) + continue; - uint64_t Count = I.second.getSamples(); - if (!I.second.getCallTargets().empty()) { - TotalCallsiteSamples += Count; - FuncProfiledCallsites++; - if (!MatchedCallsiteLocs.count(Loc)) { - MismatchedCallsiteSamples += Count; - FuncMismatchedCallsites++; + if (DIL->getInlinedAt()) { + IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL)); + } else { + LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL); + StringRef CalleeName = GetCanonicalCalleeName(dyn_cast<CallBase>(&I)); + IRAnchors.emplace(Callsite, CalleeName); + } } } } +} - for (auto &I : FS.getCallsiteSamples()) { - const LineLocation &Loc = I.first; - if (isInvalidLineOffset(Loc.LineOffset)) - continue; +void SampleProfileMatcher::countMismatchedSamples(const FunctionSamples &FS) { + const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID()); + // Skip the function that is external or renamed. 
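countMismatchedSamples, whose body continues below, charges an entire subtree's samples as mismatched the moment its checksum disagrees, and otherwise recurses into the inlined-callsite children. A standalone rendering over an assumed nested sample tree:

    #include <cstdint>
    #include <vector>

    struct FuncSamples {
      uint64_t Total = 0;                 // includes children's samples
      bool HashMismatch = false;          // assumed precomputed per function
      std::vector<FuncSamples> Callsites; // inlined children
    };

    // If a node's checksum mismatches, all of its samples count as stale and
    // the children need no visit; otherwise descend.
    inline uint64_t countMismatched(const FuncSamples &FS) {
      if (FS.HashMismatch)
        return FS.Total;
      uint64_t N = 0;
      for (const FuncSamples &CS : FS.Callsites)
        N += countMismatched(CS);
      return N;
    }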
+ if (!FuncDesc) + return; - uint64_t Count = 0; - for (auto &FM : I.second) { - Count += FM.second.getHeadSamplesEstimate(); + if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) { + MismatchedFuncHashSamples += FS.getTotalSamples(); + return; + } + for (const auto &I : FS.getCallsiteSamples()) + for (const auto &CS : I.second) + countMismatchedSamples(CS.second); +} + +void SampleProfileMatcher::countProfileMismatches( + const Function &F, const FunctionSamples &FS, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors) { + [[maybe_unused]] bool IsFuncHashMismatch = false; + if (FunctionSamples::ProfileIsProbeBased) { + TotalFuncHashSamples += FS.getTotalSamples(); + TotalProfiledFunc++; + const auto *FuncDesc = ProbeManager->getDesc(F); + if (FuncDesc) { + if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) { + NumMismatchedFuncHash++; + IsFuncHashMismatch = true; + } + countMismatchedSamples(FS); } - TotalCallsiteSamples += Count; + } + + uint64_t FuncMismatchedCallsites = 0; + uint64_t FuncProfiledCallsites = 0; + countProfileCallsiteMismatches(FS, IRAnchors, ProfileAnchors, + FuncMismatchedCallsites, + FuncProfiledCallsites); + TotalProfiledCallsites += FuncProfiledCallsites; + NumMismatchedCallsites += FuncMismatchedCallsites; + LLVM_DEBUG({ + if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch && + FuncMismatchedCallsites) + dbgs() << "Function checksum is matched but there are " + << FuncMismatchedCallsites << "/" << FuncProfiledCallsites + << " mismatched callsites.\n"; + }); +} + +void SampleProfileMatcher::countProfileCallsiteMismatches( + const FunctionSamples &FS, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors, + uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) { + + // Check if there are any callsites in the profile that does not match to any + // IR callsites, those callsite samples will be discarded. + for (const auto &I : ProfileAnchors) { + const auto &Loc = I.first; + const auto &Callees = I.second; + assert(!Callees.empty() && "Callees should not be empty"); + + StringRef IRCalleeName; + const auto &IR = IRAnchors.find(Loc); + if (IR != IRAnchors.end()) + IRCalleeName = IR->second; + + // Compute number of samples in the original profile. + uint64_t CallsiteSamples = 0; + auto CTM = FS.findCallTargetMapAt(Loc); + if (CTM) { + for (const auto &I : CTM.get()) + CallsiteSamples += I.second; + } + const auto *FSMap = FS.findFunctionSamplesMapAt(Loc); + if (FSMap) { + for (const auto &I : *FSMap) + CallsiteSamples += I.second.getTotalSamples(); + } + + bool CallsiteIsMatched = false; + // Since indirect call does not have CalleeName, check conservatively if + // callsite in the profile is a callsite location. This is to reduce num of + // false positive since otherwise all the indirect call samples will be + // reported as mismatching. + if (IRCalleeName == UnknownIndirectCallee) + CallsiteIsMatched = true; + else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName))) + CallsiteIsMatched = true; + FuncProfiledCallsites++; - if (!MatchedCallsiteLocs.count(Loc)) { - MismatchedCallsiteSamples += Count; + TotalCallsiteSamples += CallsiteSamples; + if (!CallsiteIsMatched) { FuncMismatchedCallsites++; + MismatchedCallsiteSamples += CallsiteSamples; } } } -// Populate the anchors(direct callee name) from profile. 
-void SampleProfileMatcher::populateProfileCallsites( - const FunctionSamples &FS, - StringMap<std::set<LineLocation>> &CalleeToCallsitesMap) { +void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS, + std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) { + auto isInvalidLineOffset = [](uint32_t LineOffset) { + return LineOffset & 0x8000; + }; + for (const auto &I : FS.getBodySamples()) { - const auto &Loc = I.first; - const auto &CTM = I.second.getCallTargets(); - // Filter out possible indirect calls, use direct callee name as anchor. - if (CTM.size() == 1) { - StringRef CalleeName = CTM.begin()->first(); - const auto &Candidates = CalleeToCallsitesMap.try_emplace( - CalleeName, std::set<LineLocation>()); - Candidates.first->second.insert(Loc); + const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; + for (const auto &I : I.second.getCallTargets()) { + auto Ret = ProfileAnchors.try_emplace(Loc, + std::unordered_set<FunctionId>()); + Ret.first->second.insert(I.first); } } for (const auto &I : FS.getCallsiteSamples()) { const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; const auto &CalleeMap = I.second; - // Filter out possible indirect calls, use direct callee name as anchor. - if (CalleeMap.size() == 1) { - StringRef CalleeName = CalleeMap.begin()->first; - const auto &Candidates = CalleeToCallsitesMap.try_emplace( - CalleeName, std::set<LineLocation>()); - Candidates.first->second.insert(Loc); + for (const auto &I : CalleeMap) { + auto Ret = ProfileAnchors.try_emplace(Loc, + std::unordered_set<FunctionId>()); + Ret.first->second.insert(I.first); } } } @@ -2196,12 +2339,30 @@ void SampleProfileMatcher::populateProfileCallsites( // [1, 2, 3(foo), 4, 7, 8(bar), 9] // The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9]. void SampleProfileMatcher::runStaleProfileMatching( - const std::map<LineLocation, StringRef> &IRLocations, - StringMap<std::set<LineLocation>> &CalleeToCallsitesMap, + const Function &F, + const std::map<LineLocation, StringRef> &IRAnchors, + const std::map<LineLocation, std::unordered_set<FunctionId>> + &ProfileAnchors, LocToLocMap &IRToProfileLocationMap) { + LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName() + << "\n"); assert(IRToProfileLocationMap.empty() && "Run stale profile matching only once per function"); + std::unordered_map<FunctionId, std::set<LineLocation>> + CalleeToCallsitesMap; + for (const auto &I : ProfileAnchors) { + const auto &Loc = I.first; + const auto &Callees = I.second; + // Filter out possible indirect calls, use direct callee name as anchor. + if (Callees.size() == 1) { + FunctionId CalleeName = *Callees.begin(); + const auto &Candidates = CalleeToCallsitesMap.try_emplace( + CalleeName, std::set<LineLocation>()); + Candidates.first->second.insert(Loc); + } + } + auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) { // Skip the unchanged location mapping to save memory. if (From != To) @@ -2212,18 +2373,19 @@ void SampleProfileMatcher::runStaleProfileMatching( int32_t LocationDelta = 0; SmallVector<LineLocation> LastMatchedNonAnchors; - for (const auto &IR : IRLocations) { + for (const auto &IR : IRAnchors) { const auto &Loc = IR.first; - StringRef CalleeName = IR.second; + auto CalleeName = IR.second; bool IsMatchedAnchor = false; // Match the anchor location in lexical order. 
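The loop continuing below consumes anchors in lexical order: each IR occurrence of a direct callee claims the smallest unused profile location for that callee, and ambiguous multi-target callsites were already filtered out when CalleeToCallsitesMap was built above. A reduced sketch, with string callee names standing in for FunctionId:

    #include <map>
    #include <optional>
    #include <set>
    #include <string>
    #include <utility>

    using Loc = std::pair<unsigned, unsigned>; // (line offset, discriminator)

    // Only callsites with exactly one recorded callee become anchors; each IR
    // occurrence of that callee pops the lexically smallest unused location.
    struct AnchorMatcher {
      std::map<std::string, std::set<Loc>> CalleeToLocs;

      std::optional<Loc> match(const std::string &Callee) {
        auto It = CalleeToLocs.find(Callee);
        if (It == CalleeToLocs.end() || It->second.empty())
          return std::nullopt;
        Loc L = *It->second.begin();
        It->second.erase(It->second.begin());
        return L;
      }
    };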
if (!CalleeName.empty()) { - auto ProfileAnchors = CalleeToCallsitesMap.find(CalleeName); - if (ProfileAnchors != CalleeToCallsitesMap.end() && - !ProfileAnchors->second.empty()) { - auto CI = ProfileAnchors->second.begin(); + auto CandidateAnchors = CalleeToCallsitesMap.find( + getRepInFormat(CalleeName)); + if (CandidateAnchors != CalleeToCallsitesMap.end() && + !CandidateAnchors->second.empty()) { + auto CI = CandidateAnchors->second.begin(); const auto Candidate = *CI; - ProfileAnchors->second.erase(CI); + CandidateAnchors->second.erase(CI); InsertMatching(Loc, Candidate); LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName << " is matched from " << Loc << " to " << Candidate @@ -2261,122 +2423,56 @@ void SampleProfileMatcher::runStaleProfileMatching( } } -void SampleProfileMatcher::runOnFunction(const Function &F, - const FunctionSamples &FS) { - bool IsFuncHashMismatch = false; - if (FunctionSamples::ProfileIsProbeBased) { - uint64_t Count = FS.getTotalSamples(); - TotalFuncHashSamples += Count; - TotalProfiledFunc++; - if (!ProbeManager->profileIsValid(F, FS)) { - MismatchedFuncHashSamples += Count; - NumMismatchedFuncHash++; - IsFuncHashMismatch = true; - } - } - - std::unordered_set<LineLocation, LineLocationHash> MatchedCallsiteLocs; - // The value of the map is the name of direct callsite and use empty StringRef - // for non-direct-call site. - std::map<LineLocation, StringRef> IRLocations; - - // Extract profile matching anchors and profile mismatch metrics in the IR. - for (auto &BB : F) { - for (auto &I : BB) { - // TODO: Support line-number based location(AutoFDO). - if (FunctionSamples::ProfileIsProbeBased && isa<PseudoProbeInst>(&I)) { - if (std::optional<PseudoProbe> Probe = extractProbe(I)) - IRLocations.emplace(LineLocation(Probe->Id, 0), StringRef()); - } +void SampleProfileMatcher::runOnFunction(const Function &F) { + // We need to use flattened function samples for matching. + // Unlike IR, which includes all callsites from the source code, the callsites + // in profile only show up when they are hit by samples, i,e. the profile + // callsites in one context may differ from those in another context. To get + // the maximum number of callsites, we merge the function profiles from all + // contexts, aka, the flattened profile to find profile anchors. + const auto *FSFlattened = getFlattenedSamplesFor(F); + if (!FSFlattened) + return; - if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I)) - continue; - - const auto *CB = dyn_cast<CallBase>(&I); - if (auto &DLoc = I.getDebugLoc()) { - LineLocation IRCallsite = FunctionSamples::getCallSiteIdentifier(DLoc); - - StringRef CalleeName; - if (Function *Callee = CB->getCalledFunction()) - CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName()); - - // Force to overwrite the callee name in case any non-call location was - // written before. - auto R = IRLocations.emplace(IRCallsite, CalleeName); - R.first->second = CalleeName; - assert((!FunctionSamples::ProfileIsProbeBased || R.second || - R.first->second == CalleeName) && - "Overwrite non-call or different callee name location for " - "pseudo probe callsite"); - - // Go through all the callsites on the IR and flag the callsite if the - // target name is the same as the one in the profile. - const auto CTM = FS.findCallTargetMapAt(IRCallsite); - const auto CallsiteFS = FS.findFunctionSamplesMapAt(IRCallsite); - - // Indirect call case. 
- if (CalleeName.empty()) { - // Since indirect call does not have the CalleeName, check - // conservatively if callsite in the profile is a callsite location. - // This is to avoid nums of false positive since otherwise all the - // indirect call samples will be reported as mismatching. - if ((CTM && !CTM->empty()) || (CallsiteFS && !CallsiteFS->empty())) - MatchedCallsiteLocs.insert(IRCallsite); - } else { - // Check if the call target name is matched for direct call case. - if ((CTM && CTM->count(CalleeName)) || - (CallsiteFS && CallsiteFS->count(CalleeName))) - MatchedCallsiteLocs.insert(IRCallsite); - } - } - } - } + // Anchors for IR. It's a map from IR location to callee name, callee name is + // empty for non-call instruction and use a dummy name(UnknownIndirectCallee) + // for unknown indrect callee name. + std::map<LineLocation, StringRef> IRAnchors; + findIRAnchors(F, IRAnchors); + // Anchors for profile. It's a map from callsite location to a set of callee + // name. + std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors; + findProfileAnchors(*FSFlattened, ProfileAnchors); // Detect profile mismatch for profile staleness metrics report. - if (ReportProfileStaleness || PersistProfileStaleness) { - uint64_t FuncMismatchedCallsites = 0; - uint64_t FuncProfiledCallsites = 0; - countProfileMismatches(FS, MatchedCallsiteLocs, FuncMismatchedCallsites, - FuncProfiledCallsites); - TotalProfiledCallsites += FuncProfiledCallsites; - NumMismatchedCallsites += FuncMismatchedCallsites; - LLVM_DEBUG({ - if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch && - FuncMismatchedCallsites) - dbgs() << "Function checksum is matched but there are " - << FuncMismatchedCallsites << "/" << FuncProfiledCallsites - << " mismatched callsites.\n"; - }); - } - - if (IsFuncHashMismatch && SalvageStaleProfile) { - LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName() - << "\n"); - - StringMap<std::set<LineLocation>> CalleeToCallsitesMap; - populateProfileCallsites(FS, CalleeToCallsitesMap); - + // Skip reporting the metrics for imported functions. + if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) && + (ReportProfileStaleness || PersistProfileStaleness)) { + // Use top-level nested FS for counting profile mismatch metrics since + // currently once a callsite is mismatched, all its children profiles are + // dropped. + if (const auto *FS = Reader.getSamplesFor(F)) + countProfileMismatches(F, *FS, IRAnchors, ProfileAnchors); + } + + // Run profile matching for checksum mismatched profile, currently only + // support for pseudo-probe. + if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased && + !ProbeManager->profileIsValid(F, *FSFlattened)) { // The matching result will be saved to IRToProfileLocationMap, create a new // map for each function. 
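The comment at the top of the new runOnFunction explains why matching uses flattened samples: a callsite appears in a context profile only when that context saw samples, so merging every context maximizes the anchors available. A toy flattening pass under assumed map types:

    #include <cstdint>
    #include <map>
    #include <string>

    // Per-context body samples: location -> count, keyed here by a plain
    // string location for brevity.
    using Body = std::map<std::string, uint64_t>;

    // Merge every context's samples for the same function into one profile,
    // so a callsite hot in only one context still shows up as an anchor.
    inline void flatten(const std::multimap<std::string, Body> &PerContext,
                        std::map<std::string, Body> &Flat) {
      for (const auto &[Func, Samples] : PerContext)
        for (const auto &[Loc, Count] : Samples)
          Flat[Func][Loc] += Count;
    }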
- auto &IRToProfileLocationMap = getIRToProfileLocationMap(F); - - runStaleProfileMatching(IRLocations, CalleeToCallsitesMap, - IRToProfileLocationMap); + runStaleProfileMatching(F, IRAnchors, ProfileAnchors, + getIRToProfileLocationMap(F)); } } void SampleProfileMatcher::runOnModule() { + ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles, + FunctionSamples::ProfileIsCS); for (auto &F : M) { if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) continue; - FunctionSamples *FS = nullptr; - if (FlattenProfileForMatching) - FS = getFlattenedSamplesFor(F); - else - FS = Reader.getSamplesFor(F); - if (!FS) - continue; - runOnFunction(F, *FS); + runOnFunction(F); } if (SalvageStaleProfile) distributeIRToProfileLocationMap(); @@ -2424,7 +2520,7 @@ void SampleProfileMatcher::runOnModule() { void SampleProfileMatcher::distributeIRToProfileLocationMap( FunctionSamples &FS) { - const auto ProfileMappings = FuncMappings.find(FS.getName()); + const auto ProfileMappings = FuncMappings.find(FS.getFuncName()); if (ProfileMappings != FuncMappings.end()) { FS.setIRToProfileLocationMap(&(ProfileMappings->second)); } @@ -2466,10 +2562,10 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, Function *F = dyn_cast<Function>(N_F.getValue()); if (F == nullptr || OrigName.empty()) continue; - SymbolMap[OrigName] = F; + SymbolMap[FunctionId(OrigName)] = F; StringRef NewName = FunctionSamples::getCanonicalFnName(*F); if (OrigName != NewName && !NewName.empty()) { - auto r = SymbolMap.insert(std::make_pair(NewName, F)); + auto r = SymbolMap.emplace(FunctionId(NewName), F); // Failiing to insert means there is already an entry in SymbolMap, // thus there are multiple functions that are mapped to the same // stripped name. In this case of name conflicting, set the value @@ -2482,11 +2578,11 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, if (Remapper) { if (auto MapName = Remapper->lookUpNameInProfile(OrigName)) { if (*MapName != OrigName && !MapName->empty()) - SymbolMap.insert(std::make_pair(*MapName, F)); + SymbolMap.emplace(FunctionId(*MapName), F); } } } - assert(SymbolMap.count(StringRef()) == 0 && + assert(SymbolMap.count(FunctionId()) == 0 && "No empty StringRef should be added in SymbolMap"); if (ReportProfileStaleness || PersistProfileStaleness || @@ -2550,7 +2646,9 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) // but not cold accumulatively...), so the outline function showing up as // cold in sampled binary will actually not be cold after current build. StringRef CanonName = FunctionSamples::getCanonicalFnName(F); - if (NamesInProfile.count(CanonName)) + if ((FunctionSamples::UseMD5 && + GUIDsInProfile.count(Function::getGUID(CanonName))) || + (!FunctionSamples::UseMD5 && NamesInProfile.count(CanonName))) initialEntryCount = -1; } @@ -2571,8 +2669,24 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) if (FunctionSamples::ProfileIsCS) Samples = ContextTracker->getBaseSamplesFor(F); - else + else { Samples = Reader->getSamplesFor(F); + // Try search in previously inlined functions that were split or duplicated + // into base. 
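The comment above describes the fallback implemented just below: when the reader has no samples for the function, consult the map of previously outlined profiles, then retry that map under a remapper-provided alias. Condensed, with assumed types:

    #include <optional>
    #include <string>
    #include <unordered_map>

    struct Samples { long Total = 0; };

    using SampleMap = std::unordered_map<std::string, Samples>;

    // Lookup order: primary profile, then the outline side map, then the side
    // map again under a remapped alias, if one exists.
    inline Samples *findSamples(SampleMap &Primary, SampleMap &Outline,
                                const std::string &Name,
                                const std::optional<std::string> &Remapped) {
      if (auto It = Primary.find(Name); It != Primary.end())
        return &It->second;
      if (auto It = Outline.find(Name); It != Outline.end())
        return &It->second;
      if (Remapped)
        if (auto It = Outline.find(*Remapped); It != Outline.end())
          return &It->second;
      return nullptr;
    }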
+ if (!Samples) { + StringRef CanonName = FunctionSamples::getCanonicalFnName(F); + auto It = OutlineFunctionSamples.find(FunctionId(CanonName)); + if (It != OutlineFunctionSamples.end()) { + Samples = &It->second; + } else if (auto Remapper = Reader->getRemapper()) { + if (auto RemppedName = Remapper->lookUpNameInProfile(CanonName)) { + It = OutlineFunctionSamples.find(FunctionId(*RemppedName)); + if (It != OutlineFunctionSamples.end()) + Samples = &It->second; + } + } + } + } if (Samples && !Samples->empty()) return emitAnnotations(F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp index 0a42de7224b4..8f0b12d0cfed 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp @@ -18,6 +18,7 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" @@ -95,13 +96,13 @@ void PseudoProbeVerifier::runAfterPass(StringRef PassID, Any IR) { std::string Banner = "\n*** Pseudo Probe Verification After " + PassID.str() + " ***\n"; dbgs() << Banner; - if (const auto **M = any_cast<const Module *>(&IR)) + if (const auto **M = llvm::any_cast<const Module *>(&IR)) runAfterPass(*M); - else if (const auto **F = any_cast<const Function *>(&IR)) + else if (const auto **F = llvm::any_cast<const Function *>(&IR)) runAfterPass(*F); - else if (const auto **C = any_cast<const LazyCallGraph::SCC *>(&IR)) + else if (const auto **C = llvm::any_cast<const LazyCallGraph::SCC *>(&IR)) runAfterPass(*C); - else if (const auto **L = any_cast<const Loop *>(&IR)) + else if (const auto **L = llvm::any_cast<const Loop *>(&IR)) runAfterPass(*L); else llvm_unreachable("Unknown IR unit"); @@ -221,12 +222,26 @@ void SampleProfileProber::computeProbeIdForBlocks() { } void SampleProfileProber::computeProbeIdForCallsites() { + LLVMContext &Ctx = F->getContext(); + Module *M = F->getParent(); + for (auto &BB : *F) { for (auto &I : BB) { if (!isa<CallBase>(I)) continue; if (isa<IntrinsicInst>(&I)) continue; + + // The current implementation uses the lower 16 bits of the discriminator + // so anything larger than 0xFFFF will be ignored. 
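The guard that follows enforces the 16-bit budget the comment above describes: rather than letting the discriminator wrap, instrumentation stops and a warning is issued. A self-contained sketch; the diagnostic text mirrors the hunk, everything else is simplified:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // Assign ids 1..0xFFFF; on overflow warn and stop instrumenting the
    // function instead of emitting ids that cannot be encoded.
    inline bool assignProbeIds(const std::string &FuncName,
                               unsigned NumCallsites,
                               std::vector<uint32_t> &Ids) {
      uint32_t LastId = 0;
      for (unsigned I = 0; I < NumCallsites; ++I) {
        if (LastId >= 0xFFFF) {
          std::cerr << "warning: pseudo instrumentation incomplete for "
                    << FuncName << " because it's too large\n";
          return false;
        }
        Ids.push_back(++LastId);
      }
      return true;
    }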
+ if (LastProbeId >= 0xFFFF) { + std::string Msg = "Pseudo instrumentation incomplete for " + + std::string(F->getName()) + " because it's too large"; + Ctx.diagnose( + DiagnosticInfoSampleProfile(M->getName().data(), Msg, DS_Warning)); + return; + } + CallProbeIds[&I] = ++LastProbeId; } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp index 147513452789..28d7d4ba6b01 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -30,12 +30,18 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/TypeFinder.h" #include "llvm/IR/ValueSymbolTable.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/StripSymbols.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; +static cl::opt<bool> + StripGlobalConstants("strip-global-constants", cl::init(false), cl::Hidden, + cl::desc("Removes debug compile units which reference " + "to non-existing global constants")); + /// OnlyUsedBy - Return true if V is only used by Usr. static bool OnlyUsedBy(Value *V, Value *Usr) { for (User *U : V->users()) @@ -73,7 +79,7 @@ static void StripSymtab(ValueSymbolTable &ST, bool PreserveDbgInfo) { Value *V = VI->getValue(); ++VI; if (!isa<GlobalValue>(V) || cast<GlobalValue>(V)->hasLocalLinkage()) { - if (!PreserveDbgInfo || !V->getName().startswith("llvm.dbg")) + if (!PreserveDbgInfo || !V->getName().starts_with("llvm.dbg")) // Set name to "", removing from symbol table! V->setName(""); } @@ -88,7 +94,7 @@ static void StripTypeNames(Module &M, bool PreserveDbgInfo) { for (StructType *STy : StructTypes) { if (STy->isLiteral() || STy->getName().empty()) continue; - if (PreserveDbgInfo && STy->getName().startswith("llvm.dbg")) + if (PreserveDbgInfo && STy->getName().starts_with("llvm.dbg")) continue; STy->setName(""); @@ -118,13 +124,13 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) { for (GlobalVariable &GV : M.globals()) { if (GV.hasLocalLinkage() && !llvmUsedValues.contains(&GV)) - if (!PreserveDbgInfo || !GV.getName().startswith("llvm.dbg")) + if (!PreserveDbgInfo || !GV.getName().starts_with("llvm.dbg")) GV.setName(""); // Internal symbols can't participate in linkage } for (Function &I : M) { if (I.hasLocalLinkage() && !llvmUsedValues.contains(&I)) - if (!PreserveDbgInfo || !I.getName().startswith("llvm.dbg")) + if (!PreserveDbgInfo || !I.getName().starts_with("llvm.dbg")) I.setName(""); // Internal symbols can't participate in linkage if (auto *Symtab = I.getValueSymbolTable()) StripSymtab(*Symtab, PreserveDbgInfo); @@ -216,7 +222,8 @@ static bool stripDeadDebugInfoImpl(Module &M) { // Create our live global variable list. bool GlobalVariableChange = false; for (auto *DIG : DIC->getGlobalVariables()) { - if (DIG->getExpression() && DIG->getExpression()->isConstant()) + if (DIG->getExpression() && DIG->getExpression()->isConstant() && + !StripGlobalConstants) LiveGVs.insert(DIG); // Make sure we only visit each global variable only once. 
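In the StripSymbols.cpp hunks above, the new strip-global-constants flag gates whether a debug global whose expression folds to a constant still keeps its entry (and hence its compile unit) alive. The liveness test, reduced to a predicate over assumed types:

    #include <unordered_set>
    #include <vector>

    struct DebugGlobal {
      bool HasConstantExpr; // expression folds to a constant, no storage
      const void *Id;
    };

    // Constant-expression globals stay live unless the flag asks to drop
    // compile units that only reference such globals.
    inline void collectLive(const std::vector<DebugGlobal> &GVs,
                            bool StripGlobalConstants,
                            std::unordered_set<const void *> &Live) {
      for (const DebugGlobal &G : GVs)
        if (G.HasConstantExpr && !StripGlobalConstants)
          Live.insert(G.Id);
    }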
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index d46f9a6c6757..f6f895676084 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -111,7 +111,7 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, // Now compute the callsite count from relative frequency and // entry count: BasicBlock *CSBB = CB.getParent(); - Scaled64 EntryFreq(BFI.getEntryFreq(), 0); + Scaled64 EntryFreq(BFI.getEntryFreq().getFrequency(), 0); Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0); BBCount /= EntryFreq; BBCount *= Counts[Caller]; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index fc1e70b1b3d3..e5f9fa1dda88 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -186,7 +186,7 @@ void simplifyExternals(Module &M) { if (!F.isDeclaration() || F.getFunctionType() == EmptyFT || // Changing the type of an intrinsic may invalidate the IR. - F.getName().startswith("llvm.")) + F.getName().starts_with("llvm.")) continue; Function *NewF = @@ -198,7 +198,7 @@ void simplifyExternals(Module &M) { AttributeList::FunctionIndex, F.getAttributes().getFnAttrs())); NewF->takeName(&F); - F.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, F.getType())); + F.replaceAllUsesWith(NewF); F.eraseFromParent(); } @@ -329,7 +329,7 @@ void splitAndWriteThinLTOBitcode( // comdat in MergedM to keep the comdat together. DenseSet<const Comdat *> MergedMComdats; for (GlobalVariable &GV : M.globals()) - if (HasTypeMetadata(&GV)) { + if (!GV.isDeclaration() && HasTypeMetadata(&GV)) { if (const auto *C = GV.getComdat()) MergedMComdats.insert(C); forEachVirtualFunction(GV.getInitializer(), [&](Function *F) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index d33258642365..85afc020dbf8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -58,7 +58,6 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" @@ -369,8 +368,6 @@ template <> struct DenseMapInfo<VTableSlotSummary> { } // end namespace llvm -namespace { - // Returns true if the function must be unreachable based on ValueInfo. // // In particular, identifies a function as unreachable in the following @@ -378,7 +375,7 @@ namespace { // 1) All summaries are live. // 2) All function summaries indicate it's unreachable // 3) There is no non-function with the same GUID (which is rare) -bool mustBeUnreachableFunction(ValueInfo TheFnVI) { +static bool mustBeUnreachableFunction(ValueInfo TheFnVI) { if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) { // Returns false if ValueInfo is absent, or the summary list is empty // (e.g., function declarations). @@ -403,6 +400,7 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) { return true; } +namespace { // A virtual call site. 
VTable is the loaded virtual table pointer, and CS is // the indirect virtual call. struct VirtualCallSite { @@ -590,7 +588,7 @@ struct DevirtModule { : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree), ExportSummary(ExportSummary), ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), - Int8PtrTy(Type::getInt8PtrTy(M.getContext())), + Int8PtrTy(PointerType::getUnqual(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())), Int64Ty(Type::getInt64Ty(M.getContext())), IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), @@ -776,20 +774,59 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, return PreservedAnalyses::none(); } -namespace llvm { // Enable whole program visibility if enabled by client (e.g. linker) or // internal option, and not force disabled. -bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) { +bool llvm::hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) { return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) && !DisableWholeProgramVisibility; } +static bool +typeIDVisibleToRegularObj(StringRef TypeID, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + // TypeID for member function pointer type is an internal construct + // and won't exist in IsVisibleToRegularObj. The full TypeID + // will be present and participate in invalidation. + if (TypeID.ends_with(".virtual")) + return false; + + // TypeID that doesn't start with Itanium mangling (_ZTS) will be + // non-externally visible types which cannot interact with + // external native files. See CodeGenModule::CreateMetadataIdentifierImpl. + if (!TypeID.consume_front("_ZTS")) + return false; + + // TypeID is keyed off the type name symbol (_ZTS). However, the native + // object may not contain this symbol if it does not contain a key + // function for the base type and thus only contains a reference to the + // type info (_ZTI). To catch this case we query using the type info + // symbol corresponding to the TypeID. + std::string typeInfo = ("_ZTI" + TypeID).str(); + return IsVisibleToRegularObj(typeInfo); +} + +static bool +skipUpdateDueToValidation(GlobalVariable &GV, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + SmallVector<MDNode *, 2> Types; + GV.getMetadata(LLVMContext::MD_type, Types); + + for (auto Type : Types) + if (auto *TypeID = dyn_cast<MDString>(Type->getOperand(1).get())) + return typeIDVisibleToRegularObj(TypeID->getString(), + IsVisibleToRegularObj); + + return false; +} + /// If whole program visibility asserted, then upgrade all public vcall /// visibility metadata on vtable definitions to linkage unit visibility in /// Module IR (for regular or hybrid LTO). -void updateVCallVisibilityInModule( +void llvm::updateVCallVisibilityInModule( Module &M, bool WholeProgramVisibilityEnabledInLTO, - const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) { + const DenseSet<GlobalValue::GUID> &DynamicExportSymbols, + bool ValidateAllVtablesHaveTypeInfos, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) return; for (GlobalVariable &GV : M.globals()) { @@ -800,13 +837,19 @@ void updateVCallVisibilityInModule( GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic && // Don't upgrade the visibility for symbols exported to the dynamic // linker, as we have no information on their eventual use. 
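typeIDVisibleToRegularObj, added in full above, keys everything off Itanium mangling: ids ending in ".virtual" are internal constructs, only "_ZTS"-prefixed ids can be externally visible, and the actual query goes through the matching "_ZTI" type-info symbol, which exists even without a key function. The same checks in self-contained C++20:

    #include <functional>
    #include <string>
    #include <string_view>

    inline bool
    typeIdVisible(std::string_view TypeId,
                  const std::function<bool(std::string_view)> &IsVisible) {
      // Member function pointer type ids are internal; never visible.
      if (TypeId.ends_with(".virtual"))
        return false;
      // Non-Itanium (_ZTS) ids cannot interact with external native objects.
      if (!TypeId.starts_with("_ZTS"))
        return false;
      TypeId.remove_prefix(4); // consume "_ZTS"
      // Query via the _ZTI type-info symbol instead of the _ZTS name itself.
      return IsVisible("_ZTI" + std::string(TypeId));
    }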
- !DynamicExportSymbols.count(GV.getGUID())) + !DynamicExportSymbols.count(GV.getGUID()) && + // With validation enabled, we want to exclude symbols visible to + // regular objects. Local symbols will be in this group due to the + // current implementation but those with VCallVisibilityTranslationUnit + // will have already been marked in clang so are unaffected. + !(ValidateAllVtablesHaveTypeInfos && + skipUpdateDueToValidation(GV, IsVisibleToRegularObj))) GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit); } } -void updatePublicTypeTestCalls(Module &M, - bool WholeProgramVisibilityEnabledInLTO) { +void llvm::updatePublicTypeTestCalls(Module &M, + bool WholeProgramVisibilityEnabledInLTO) { Function *PublicTypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::public_type_test)); if (!PublicTypeTestFunc) @@ -832,12 +875,26 @@ void updatePublicTypeTestCalls(Module &M, } } +/// Based on typeID string, get all associated vtable GUIDS that are +/// visible to regular objects. +void llvm::getVisibleToRegularObjVtableGUIDs( + ModuleSummaryIndex &Index, + DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols, + function_ref<bool(StringRef)> IsVisibleToRegularObj) { + for (const auto &typeID : Index.typeIdCompatibleVtableMap()) { + if (typeIDVisibleToRegularObj(typeID.first, IsVisibleToRegularObj)) + for (const TypeIdOffsetVtableInfo &P : typeID.second) + VisibleToRegularObjSymbols.insert(P.VTableVI.getGUID()); + } +} + /// If whole program visibility asserted, then upgrade all public vcall /// visibility metadata on vtable definition summaries to linkage unit /// visibility in Module summary index (for ThinLTO). -void updateVCallVisibilityInIndex( +void llvm::updateVCallVisibilityInIndex( ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO, - const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) { + const DenseSet<GlobalValue::GUID> &DynamicExportSymbols, + const DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols) { if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) return; for (auto &P : Index) { @@ -850,18 +907,24 @@ void updateVCallVisibilityInIndex( if (!GVar || GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) continue; + // With validation enabled, we want to exclude symbols visible to regular + // objects. Local symbols will be in this group due to the current + // implementation but those with VCallVisibilityTranslationUnit will have + // already been marked in clang so are unaffected. 
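The skip implemented just below leaves any vtable whose GUID was collected by getVisibleToRegularObjVtableGUIDs at public visibility, since native objects may still reference it. The collect-then-exclude shape of the upgrade pass, reduced:

    #include <cstdint>
    #include <unordered_map>
    #include <unordered_set>

    enum class Visibility { Public, LinkageUnit };

    // Upgrade Public -> LinkageUnit for every vtable except those whose GUID
    // was recorded as visible to regular (non-LTO) objects.
    inline void upgradeVisibility(
        std::unordered_map<uint64_t, Visibility> &VTables,
        const std::unordered_set<uint64_t> &VisibleToRegularObj) {
      for (auto &[GUID, Vis] : VTables) {
        if (Vis != Visibility::Public)
          continue;
        if (VisibleToRegularObj.count(GUID))
          continue; // may be referenced by native code; leave public
        Vis = Visibility::LinkageUnit;
      }
    }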
+ if (VisibleToRegularObjSymbols.count(P.first)) + continue; GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); } } } -void runWholeProgramDevirtOnIndex( +void llvm::runWholeProgramDevirtOnIndex( ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs, std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run(); } -void updateIndexWPDForExports( +void llvm::updateIndexWPDForExports( ModuleSummaryIndex &Summary, function_ref<bool(StringRef, ValueInfo)> isExported, std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { @@ -887,8 +950,6 @@ void updateIndexWPDForExports( } } -} // end namespace llvm - static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) { // Check that summary index contains regular LTO module when performing // export to prevent occasional use of index from pure ThinLTO compilation @@ -942,7 +1003,7 @@ bool DevirtModule::runForTesting( ExitOnError ExitOnErr( "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); std::error_code EC; - if (StringRef(ClWriteSummary).endswith(".bc")) { + if (StringRef(ClWriteSummary).ends_with(".bc")) { raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None); ExitOnErr(errorCodeToError(EC)); writeIndexToFile(*Summary, OS); @@ -1045,8 +1106,8 @@ bool DevirtModule::tryFindVirtualCallTargets( } bool DevirtIndex::tryFindVirtualCallTargets( - std::vector<ValueInfo> &TargetsForSlot, const TypeIdCompatibleVtableInfo TIdInfo, - uint64_t ByteOffset) { + std::vector<ValueInfo> &TargetsForSlot, + const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) { for (const TypeIdOffsetVtableInfo &P : TIdInfo) { // Find a representative copy of the vtable initializer. // We can have multiple available_externally, linkonce_odr and weak_odr @@ -1203,7 +1264,8 @@ static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) { // to better ensure we have the opportunity to inline them. bool IsExported = false; auto &S = Callee.getSummaryList()[0]; - CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0); + CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* HasTailCall = */ false, + /* RelBF = */ 0); auto AddCalls = [&](CallSiteInfo &CSInfo) { for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) { FS->addCall({Callee, CI}); @@ -1437,7 +1499,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, IRBuilder<> IRB(&CB); std::vector<Value *> Args; - Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy)); + Args.push_back(VCallSite.VTable); llvm::append_range(Args, CB.args()); CallBase *NewCS = nullptr; @@ -1471,10 +1533,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, // llvm.type.test and therefore require an llvm.type.test resolution for the // type identifier. 
- std::for_each(CallBases.begin(), CallBases.end(), [](auto &CBs) { - CBs.first->replaceAllUsesWith(CBs.second); - CBs.first->eraseFromParent(); - }); + for (auto &[Old, New] : CallBases) { + Old->replaceAllUsesWith(New); + Old->eraseFromParent(); + } }; Apply(SlotInfo.CSInfo); for (auto &P : SlotInfo.ConstCSInfo) @@ -1648,8 +1710,7 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, } Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) { - Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy); - return ConstantExpr::getGetElementPtr(Int8Ty, C, + return ConstantExpr::getGetElementPtr(Int8Ty, M->Bits->GV, ConstantInt::get(Int64Ty, M->Offset)); } @@ -1708,8 +1769,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, continue; auto *RetType = cast<IntegerType>(Call.CB.getType()); IRBuilder<> B(&Call.CB); - Value *Addr = - B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte); + Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); if (RetType->getBitWidth() == 1) { Value *Bits = B.CreateLoad(Int8Ty, Addr); Value *BitsAndBit = B.CreateAnd(Bits, Bit); @@ -2007,17 +2067,14 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { if (TypeCheckedLoadFunc->getIntrinsicID() == Intrinsic::type_checked_load_relative) { Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); - Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty)); - LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr); + LoadedValue = LoadB.CreateLoad(Int32Ty, GEP); LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy); GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy); LoadedValue = LoadB.CreateAdd(GEP, LoadedValue); LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy); } else { Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset); - Value *GEPPtr = - LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy)); - LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr); + LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEP); } for (Instruction *LoadedPtr : LoadedPtrs) {
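The first hunk in this block trades std::for_each plus a lambda for a structured-bindings range-for. The underlying pattern, sketched with a stand-in Value type (not llvm::Value): replacements are queued during traversal and applied afterwards, so erasures never invalidate the iteration that discovered them.

    #include <utility>
    #include <vector>

    struct Value {
      void replaceAllUsesWith(Value *NewV) { (void)NewV; /* rewire users */ }
      void eraseFromParent() { /* unlink and delete */ }
    };

    // Collect (old, new) pairs first, then apply the rewrites in a second
    // pass; structured bindings name both elements without .first/.second.
    inline void
    applyReplacements(std::vector<std::pair<Value *, Value *>> &CallBases) {
      for (auto &[Old, New] : CallBases) {
        Old->replaceAllUsesWith(New);
        Old->eraseFromParent();
      }
    }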