Diffstat (limited to 'llvm/lib/Transforms')
239 files changed, 26610 insertions, 16233 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 35adaa3bde65..473b41241b8a 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -14,8 +14,6 @@ #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "AggressiveInstCombineInternal.h" -#include "llvm-c/Initialization.h" -#include "llvm-c/Transforms/AggressiveInstCombine.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" @@ -24,23 +22,17 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; -namespace llvm { -class DataLayout; -} - #define DEBUG_TYPE "aggressive-instcombine" STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded"); @@ -50,31 +42,9 @@ STATISTIC(NumGuardedFunnelShifts, "Number of guarded funnel shifts transformed into funnel shifts"); STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized"); -namespace { -/// Contains expression pattern combiner logic. -/// This class provides both the logic to combine expression patterns and -/// combine them. It differs from InstCombiner class in that each pattern -/// combiner runs only once as opposed to InstCombine's multi-iteration, -/// which allows pattern combiner to have higher complexity than the O(1) -/// required by the instruction combiner. -class AggressiveInstCombinerLegacyPass : public FunctionPass { -public: - static char ID; // Pass identification, replacement for typeid - - AggressiveInstCombinerLegacyPass() : FunctionPass(ID) { - initializeAggressiveInstCombinerLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override; - - /// Run all expression pattern optimizations on the given /p F function. - /// - /// \param F function to optimize. - /// \returns true if the IR is changed. - bool runOnFunction(Function &F) override; -}; -} // namespace +static cl::opt<unsigned> MaxInstrsToScan( + "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden, + cl::desc("Max number of instructions to scan for aggressive instcombine.")); /// Match a pattern for a bitwise funnel/rotate operation that partially guards /// against undefined behavior by branching around the funnel-shift/rotation @@ -446,21 +416,22 @@ foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) { if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl) return false; - // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created, - // and (3) we would not end up lowering to a libcall anyway (which could - // change the value of errno), then: - // (1) the operand arg must not be less than -0.0. - // (2) errno won't be set. - // (3) it is safe to convert this to an intrinsic call. - // TODO: Check if the arg is known non-negative. 
+ // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created + // (because NNAN or the operand arg must not be less than -0.0) and (2) we + // would not end up lowering to a libcall anyway (which could change the value + // of errno), then: + // (1) errno won't be set. + // (2) it is safe to convert this to an intrinsic call. Type *Ty = Call->getType(); - if (TTI.haveFastSqrt(Ty) && Call->hasNoNaNs()) { + Value *Arg = Call->getArgOperand(0); + if (TTI.haveFastSqrt(Ty) && + (Call->hasNoNaNs() || CannotBeOrderedLessThanZero(Arg, &TLI))) { IRBuilder<> Builder(&I); IRBuilderBase::FastMathFlagGuard Guard(Builder); Builder.setFastMathFlags(Call->getFastMathFlags()); Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty); - Value *NewSqrt = Builder.CreateCall(Sqrt, Call->getArgOperand(0), "sqrt"); + Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt"); I.replaceAllUsesWith(NewSqrt); // Explicitly erase the old call because a call with side effects is not @@ -472,18 +443,401 @@ foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) { return false; } +// Check if this array of constants represents a cttz table. +// Iterate over the elements from \p Table by trying to find/match all +// the numbers from 0 to \p InputBits that should represent cttz results. +static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul, + uint64_t Shift, uint64_t InputBits) { + unsigned Length = Table.getNumElements(); + if (Length < InputBits || Length > InputBits * 2) + return false; + + APInt Mask = APInt::getBitsSetFrom(InputBits, Shift); + unsigned Matched = 0; + + for (unsigned i = 0; i < Length; i++) { + uint64_t Element = Table.getElementAsInteger(i); + if (Element >= InputBits) + continue; + + // Check if \p Element matches a concrete answer. It could fail for some + // elements that are never accessed, so we keep iterating over each element + // from the table. The number of matched elements should be equal to the + // number of potential right answers which is \p InputBits actually. + if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i) + Matched++; + } + + return Matched == InputBits; +} + +// Try to recognize table-based ctz implementation. +// E.g., an example in C (for more cases please see the llvm/tests): +// int f(unsigned x) { +// static const char table[32] = +// {0, 1, 28, 2, 29, 14, 24, 3, 30, +// 22, 20, 15, 25, 17, 4, 8, 31, 27, +// 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9}; +// return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27]; +// } +// this can be lowered to `cttz` instruction. +// There is also a special case when the element is 0. 
+// +// Here are some examples or LLVM IR for a 64-bit target: +// +// CASE 1: +// %sub = sub i32 0, %x +// %and = and i32 %sub, %x +// %mul = mul i32 %and, 125613361 +// %shr = lshr i32 %mul, 27 +// %idxprom = zext i32 %shr to i64 +// %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0, +// i64 %idxprom %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 +// +// CASE 2: +// %sub = sub i32 0, %x +// %and = and i32 %sub, %x +// %mul = mul i32 %and, 72416175 +// %shr = lshr i32 %mul, 26 +// %idxprom = zext i32 %shr to i64 +// %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table, i64 +// 0, i64 %idxprom %0 = load i16, i16* %arrayidx, align 2, !tbaa !8 +// +// CASE 3: +// %sub = sub i32 0, %x +// %and = and i32 %sub, %x +// %mul = mul i32 %and, 81224991 +// %shr = lshr i32 %mul, 27 +// %idxprom = zext i32 %shr to i64 +// %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table, i64 +// 0, i64 %idxprom %0 = load i32, i32* %arrayidx, align 4, !tbaa !8 +// +// CASE 4: +// %sub = sub i64 0, %x +// %and = and i64 %sub, %x +// %mul = mul i64 %and, 283881067100198605 +// %shr = lshr i64 %mul, 58 +// %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0, i64 +// %shr %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 +// +// All this can be lowered to @llvm.cttz.i32/64 intrinsic. +static bool tryToRecognizeTableBasedCttz(Instruction &I) { + LoadInst *LI = dyn_cast<LoadInst>(&I); + if (!LI) + return false; + + Type *AccessType = LI->getType(); + if (!AccessType->isIntegerTy()) + return false; + + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand()); + if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2) + return false; + + if (!GEP->getSourceElementType()->isArrayTy()) + return false; + + uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements(); + if (ArraySize != 32 && ArraySize != 64) + return false; + + GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand()); + if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant()) + return false; + + ConstantDataArray *ConstData = + dyn_cast<ConstantDataArray>(GVTable->getInitializer()); + if (!ConstData) + return false; + + if (!match(GEP->idx_begin()->get(), m_ZeroInt())) + return false; + + Value *Idx2 = std::next(GEP->idx_begin())->get(); + Value *X1; + uint64_t MulConst, ShiftConst; + // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will + // probably fail for other (e.g. 32-bit) targets. + if (!match(Idx2, m_ZExtOrSelf( + m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)), + m_ConstantInt(MulConst)), + m_ConstantInt(ShiftConst))))) + return false; + + unsigned InputBits = X1->getType()->getScalarSizeInBits(); + if (InputBits != 32 && InputBits != 64) + return false; + + // Shift should extract top 5..7 bits. 
+ if (InputBits - Log2_32(InputBits) != ShiftConst && + InputBits - Log2_32(InputBits) - 1 != ShiftConst) + return false; + + if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits)) + return false; + + auto ZeroTableElem = ConstData->getElementAsInteger(0); + bool DefinedForZero = ZeroTableElem == InputBits; + + IRBuilder<> B(LI); + ConstantInt *BoolConst = B.getInt1(!DefinedForZero); + Type *XType = X1->getType(); + auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst}); + Value *ZExtOrTrunc = nullptr; + + if (DefinedForZero) { + ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType); + } else { + // If the value in elem 0 isn't the same as InputBits, we still want to + // produce the value from the table. + auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0)); + auto Select = + B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz); + + // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target + // it should be handled as: `cttz(x) & (typeSize - 1)`. + + ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType); + } + + LI->replaceAllUsesWith(ZExtOrTrunc); + + return true; +} + +/// This is used by foldLoadsRecursive() to capture a Root Load node which is +/// of type or(load, load) and recursively build the wide load. Also capture the +/// shift amount, zero extend type and loadSize. +struct LoadOps { + LoadInst *Root = nullptr; + LoadInst *RootInsert = nullptr; + bool FoundRoot = false; + uint64_t LoadSize = 0; + Value *Shift = nullptr; + Type *ZextType; + AAMDNodes AATags; +}; + +// Identify and Merge consecutive loads recursively which is of the form +// (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1 +// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3) +static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, + AliasAnalysis &AA) { + Value *ShAmt2 = nullptr; + Value *X; + Instruction *L1, *L2; + + // Go to the last node with loads. + if (match(V, m_OneUse(m_c_Or( + m_Value(X), + m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))), + m_Value(ShAmt2)))))) || + match(V, m_OneUse(m_Or(m_Value(X), + m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) { + if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot) + // Avoid Partial chain merge. + return false; + } else + return false; + + // Check if the pattern has loads + LoadInst *LI1 = LOps.Root; + Value *ShAmt1 = LOps.Shift; + if (LOps.FoundRoot == false && + (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) || + match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))), + m_Value(ShAmt1)))))) { + LI1 = dyn_cast<LoadInst>(L1); + } + LoadInst *LI2 = dyn_cast<LoadInst>(L2); + + // Check if loads are same, atomic, volatile and having same address space. + if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() || + LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace()) + return false; + + // Check if Loads come from same BB. + if (LI1->getParent() != LI2->getParent()) + return false; + + // Find the data layout + bool IsBigEndian = DL.isBigEndian(); + + // Check if loads are consecutive and same size. 
+ Value *Load1Ptr = LI1->getPointerOperand(); + APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0); + Load1Ptr = + Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1, + /* AllowNonInbounds */ true); + + Value *Load2Ptr = LI2->getPointerOperand(); + APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0); + Load2Ptr = + Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2, + /* AllowNonInbounds */ true); + + // Verify if both loads have same base pointers and load sizes are same. + uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits(); + uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits(); + if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2) + return false; + + // Support Loadsizes greater or equal to 8bits and only power of 2. + if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1)) + return false; + + // Alias Analysis to check for stores b/w the loads. + LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2; + MemoryLocation Loc; + if (!Start->comesBefore(End)) { + std::swap(Start, End); + Loc = MemoryLocation::get(End); + if (LOps.FoundRoot) + Loc = Loc.getWithNewSize(LOps.LoadSize); + } else + Loc = MemoryLocation::get(End); + unsigned NumScanned = 0; + for (Instruction &Inst : + make_range(Start->getIterator(), End->getIterator())) { + if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc))) + return false; + if (++NumScanned > MaxInstrsToScan) + return false; + } + + // Make sure Load with lower Offset is at LI1 + bool Reverse = false; + if (Offset2.slt(Offset1)) { + std::swap(LI1, LI2); + std::swap(ShAmt1, ShAmt2); + std::swap(Offset1, Offset2); + std::swap(Load1Ptr, Load2Ptr); + std::swap(LoadSize1, LoadSize2); + Reverse = true; + } + + // Big endian swap the shifts + if (IsBigEndian) + std::swap(ShAmt1, ShAmt2); + + // Find Shifts values. + const APInt *Temp; + uint64_t Shift1 = 0, Shift2 = 0; + if (ShAmt1 && match(ShAmt1, m_APInt(Temp))) + Shift1 = Temp->getZExtValue(); + if (ShAmt2 && match(ShAmt2, m_APInt(Temp))) + Shift2 = Temp->getZExtValue(); + + // First load is always LI1. This is where we put the new load. + // Use the merged load size available from LI1 for forward loads. + if (LOps.FoundRoot) { + if (!Reverse) + LoadSize1 = LOps.LoadSize; + else + LoadSize2 = LOps.LoadSize; + } + + // Verify if shift amount and load index aligns and verifies that loads + // are consecutive. + uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1; + uint64_t PrevSize = + DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1)); + if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize) + return false; + + // Update LOps + AAMDNodes AATags1 = LOps.AATags; + AAMDNodes AATags2 = LI2->getAAMetadata(); + if (LOps.FoundRoot == false) { + LOps.FoundRoot = true; + AATags1 = LI1->getAAMetadata(); + } + LOps.LoadSize = LoadSize1 + LoadSize2; + LOps.RootInsert = Start; + + // Concatenate the AATags of the Merged Loads. + LOps.AATags = AATags1.concat(AATags2); + + LOps.Root = LI1; + LOps.Shift = ShAmt1; + LOps.ZextType = X->getType(); + return true; +} + +// For a given BB instruction, evaluate all loads in the chain that form a +// pattern which suggests that the loads can be combined. The one and only use +// of the loads is to form a wider load. +static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL, + TargetTransformInfo &TTI, AliasAnalysis &AA) { + // Only consider load chains of scalar values. 
+ if (isa<VectorType>(I.getType())) + return false; + + LoadOps LOps; + if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot) + return false; + + IRBuilder<> Builder(&I); + LoadInst *NewLoad = nullptr, *LI1 = LOps.Root; + + IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize); + // TTI based checks if we want to proceed with wider load + bool Allowed = TTI.isTypeLegal(WiderType); + if (!Allowed) + return false; + + unsigned AS = LI1->getPointerAddressSpace(); + unsigned Fast = 0; + Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize, + AS, LI1->getAlign(), &Fast); + if (!Allowed || !Fast) + return false; + + // Make sure the Load pointer of type GEP/non-GEP is above insert point + Instruction *Inst = dyn_cast<Instruction>(LI1->getPointerOperand()); + if (Inst && Inst->getParent() == LI1->getParent() && + !Inst->comesBefore(LOps.RootInsert)) + Inst->moveBefore(LOps.RootInsert); + + // New load can be generated + Value *Load1Ptr = LI1->getPointerOperand(); + Builder.SetInsertPoint(LOps.RootInsert); + Value *NewPtr = Builder.CreateBitCast(Load1Ptr, WiderType->getPointerTo(AS)); + NewLoad = Builder.CreateAlignedLoad(WiderType, NewPtr, LI1->getAlign(), + LI1->isVolatile(), ""); + NewLoad->takeName(LI1); + // Set the New Load AATags Metadata. + if (LOps.AATags) + NewLoad->setAAMetadata(LOps.AATags); + + Value *NewOp = NewLoad; + // Check if zero extend needed. + if (LOps.ZextType) + NewOp = Builder.CreateZExt(NewOp, LOps.ZextType); + + // Check if shift needed. We need to shift with the amount of load1 + // shift if not zero. + if (LOps.Shift) + NewOp = Builder.CreateShl(NewOp, LOps.Shift); + I.replaceAllUsesWith(NewOp); + + return true; +} + /// This is the entry point for folds that could be implemented in regular /// InstCombine, but they are separated because they are not expected to /// occur frequently and/or have more than a constant-length pattern match. static bool foldUnusualPatterns(Function &F, DominatorTree &DT, TargetTransformInfo &TTI, - TargetLibraryInfo &TLI) { + TargetLibraryInfo &TLI, AliasAnalysis &AA) { bool MadeChange = false; for (BasicBlock &BB : F) { // Ignore unreachable basic blocks. if (!DT.isReachableFromEntry(&BB)) continue; + const DataLayout &DL = F.getParent()->getDataLayout(); + // Walk the block backwards for efficiency. We're matching a chain of // use->defs, so we're more likely to succeed by starting from the bottom. // Also, we want to avoid matching partial patterns. @@ -494,6 +848,11 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, MadeChange |= foldGuardedFunnelShift(I, DT); MadeChange |= tryToRecognizePopCount(I); MadeChange |= tryToFPToSat(I, TTI); + MadeChange |= tryToRecognizeTableBasedCttz(I); + MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA); + // NOTE: This function introduces erasing of the instruction `I`, so it + // needs to be called at the end of this sequence, otherwise we may make + // bugs. MadeChange |= foldSqrt(I, TTI, TLI); } } @@ -509,43 +868,24 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, /// This is the entry point for all transforms. Pass manager differences are /// handled in the callers of this function. 
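For illustration only (the snippet and function name below are not from this patch): the foldLoadsRecursive/foldConsecutiveLoads code above matches or-chains of shifted, zero-extended loads from consecutive addresses, guarded by the TTI isTypeLegal and allowsMisalignedMemoryAccesses checks. A minimal C++ sketch of such a source pattern: on a little-endian target where the wider type is legal and misaligned 32-bit access is fast, the four byte loads can be merged into a single i32 load.

#include <cstdint>

// Hypothetical example: the IR for this function contains the
// (ZExt(load) << shift) | ZExt(load) chain that the new fold recognizes.
uint32_t load_le32(const uint8_t *p) {
  return static_cast<uint32_t>(p[0]) |
         (static_cast<uint32_t>(p[1]) << 8) |
         (static_cast<uint32_t>(p[2]) << 16) |
         (static_cast<uint32_t>(p[3]) << 24);
}
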
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI, - TargetLibraryInfo &TLI, DominatorTree &DT) { + TargetLibraryInfo &TLI, DominatorTree &DT, + AliasAnalysis &AA) { bool MadeChange = false; const DataLayout &DL = F.getParent()->getDataLayout(); TruncInstCombine TIC(AC, TLI, DL, DT); MadeChange |= TIC.run(F); - MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI); + MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA); return MadeChange; } -void AggressiveInstCombinerLegacyPass::getAnalysisUsage( - AnalysisUsage &AU) const { - AU.setPreservesCFG(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<BasicAAWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); -} - -bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) { - auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - return runImpl(F, AC, TTI, TLI, DT); -} - PreservedAnalyses AggressiveInstCombinePass::run(Function &F, FunctionAnalysisManager &AM) { auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); - if (!runImpl(F, AC, TTI, TLI, DT)) { + auto &AA = AM.getResult<AAManager>(F); + if (!runImpl(F, AC, TTI, TLI, DT, AA)) { // No changes, all analyses are preserved. 
return PreservedAnalyses::all(); } @@ -554,31 +894,3 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F, PA.preserveSet<CFGAnalyses>(); return PA; } - -char AggressiveInstCombinerLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass, - "aggressive-instcombine", - "Combine pattern based expressions", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine", - "Combine pattern based expressions", false, false) - -// Initialization Routines -void llvm::initializeAggressiveInstCombine(PassRegistry &Registry) { - initializeAggressiveInstCombinerLegacyPassPass(Registry); -} - -void LLVMInitializeAggressiveInstCombiner(LLVMPassRegistryRef R) { - initializeAggressiveInstCombinerLegacyPassPass(*unwrap(R)); -} - -FunctionPass *llvm::createAggressiveInstCombinerPass() { - return new AggressiveInstCombinerLegacyPass(); -} - -void LLVMAddAggressiveInstCombinerPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createAggressiveInstCombinerPass()); -} diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp index 70ea68587b8e..6c62e84077ac 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -157,7 +157,7 @@ bool TruncInstCombine::buildTruncExpressionGraph() { getRelevantOperands(I, Operands); // Add only operands not in Stack to prevent cycle for (auto *Op : Operands) - if (all_of(Stack, [Op](Value *V) { return Op != V; })) + if (!llvm::is_contained(Stack, Op)) Worklist.push_back(Op); break; } diff --git a/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/llvm/lib/Transforms/CFGuard/CFGuard.cpp index 5fc5295969d0..bebaa6cb5969 100644 --- a/llvm/lib/Transforms/CFGuard/CFGuard.cpp +++ b/llvm/lib/Transforms/CFGuard/CFGuard.cpp @@ -241,14 +241,21 @@ bool CFGuard::doInitialization(Module &M) { GuardFnPtrType = PointerType::get(GuardFnType, 0); // Get or insert the guard check or dispatch global symbols. + llvm::StringRef GuardFnName; if (GuardMechanism == CF_Check) { - GuardFnGlobal = - M.getOrInsertGlobal("__guard_check_icall_fptr", GuardFnPtrType); + GuardFnName = "__guard_check_icall_fptr"; + } else if (GuardMechanism == CF_Dispatch) { + GuardFnName = "__guard_dispatch_icall_fptr"; } else { - assert(GuardMechanism == CF_Dispatch && "Invalid CFGuard mechanism"); - GuardFnGlobal = - M.getOrInsertGlobal("__guard_dispatch_icall_fptr", GuardFnPtrType); + assert(false && "Invalid CFGuard mechanism"); } + GuardFnGlobal = M.getOrInsertGlobal(GuardFnName, GuardFnPtrType, [&] { + auto *Var = new GlobalVariable(M, GuardFnPtrType, false, + GlobalVariable::ExternalLinkage, nullptr, + GuardFnName); + Var->setDSOLocal(true); + return Var; + }); return true; } @@ -265,8 +272,8 @@ bool CFGuard::runOnFunction(Function &F) { // instructions. Make a separate list of pointers to indirect // call/invoke/callbr instructions because the original instructions will be // deleted as the checks are added. 
- for (BasicBlock &BB : F.getBasicBlockList()) { - for (Instruction &I : BB.getInstList()) { + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { auto *CB = dyn_cast<CallBase>(&I); if (CB && CB->isIndirectCall() && !CB->hasFnAttr("guard_nocf")) { IndirectCalls.push_back(CB); diff --git a/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp b/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp index 3d26a43ceba7..974123fe36a1 100644 --- a/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp @@ -22,3 +22,11 @@ PreservedAnalyses CoroConditionalWrapper::run(Module &M, return PM.run(M, AM); } + +void CoroConditionalWrapper::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + OS << "coro-cond"; + OS << "("; + PM.printPipeline(OS, MapClassName2PassName); + OS << ")"; +} diff --git a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp index dd7cb23f3f3d..d510b90d9dec 100644 --- a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -8,6 +8,7 @@ #include "llvm/Transforms/Coroutines/CoroEarly.h" #include "CoroInternal.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" @@ -100,6 +101,25 @@ void Lowerer::lowerCoroDone(IntrinsicInst *II) { II->eraseFromParent(); } +static void buildDebugInfoForNoopResumeDestroyFunc(Function *NoopFn) { + Module &M = *NoopFn->getParent(); + if (M.debug_compile_units().empty()) + return; + + DICompileUnit *CU = *M.debug_compile_units_begin(); + DIBuilder DB(M, /*AllowUnresolved*/ false, CU); + std::array<Metadata *, 2> Params{nullptr, nullptr}; + auto *SubroutineType = + DB.createSubroutineType(DB.getOrCreateTypeArray(Params)); + StringRef Name = NoopFn->getName(); + auto *SP = DB.createFunction( + CU, /*Name=*/Name, /*LinkageName=*/Name, /*File=*/ CU->getFile(), + /*LineNo=*/0, SubroutineType, /*ScopeLine=*/0, DINode::FlagArtificial, + DISubprogram::SPFlagDefinition); + NoopFn->setSubprogram(SP); + DB.finalize(); +} + void Lowerer::lowerCoroNoop(IntrinsicInst *II) { if (!NoopCoro) { LLVMContext &C = Builder.getContext(); @@ -116,8 +136,9 @@ void Lowerer::lowerCoroNoop(IntrinsicInst *II) { // Create a Noop function that does nothing. Function *NoopFn = Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage, - "NoopCoro.ResumeDestroy", &M); + "__NoopCoro_ResumeDestroy", &M); NoopFn->setCallingConv(CallingConv::Fast); + buildDebugInfoForNoopResumeDestroyFunc(NoopFn); auto *Entry = BasicBlock::Create(C, "entry", NoopFn); ReturnInst::Create(C, Entry); diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp index 6f78fc8db311..f032c568449b 100644 --- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -16,6 +16,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" +#include <optional> using namespace llvm; @@ -101,11 +102,12 @@ static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) { // Given a resume function @f.resume(%f.frame* %frame), returns the size // and expected alignment of %f.frame type. 
-static Optional<std::pair<uint64_t, Align>> getFrameLayout(Function *Resume) { +static std::optional<std::pair<uint64_t, Align>> +getFrameLayout(Function *Resume) { // Pull information from the function attributes. auto Size = Resume->getParamDereferenceableBytes(0); if (!Size) - return None; + return std::nullopt; return std::make_pair(Size, Resume->getParamAlign(0).valueOrOne()); } @@ -244,7 +246,7 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { // Filter out the coro.destroy that lie along exceptional paths. SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins; - for (auto &It : DestroyAddr) { + for (const auto &It : DestroyAddr) { // If there is any coro.destroy dominates all of the terminators for the // coro.begin, we could know the corresponding coro.begin wouldn't escape. for (Instruction *DA : It.second) { diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp index 51eb8ebf0369..e98c601648e0 100644 --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -37,6 +37,7 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> +#include <optional> using namespace llvm; @@ -76,11 +77,14 @@ public: // // For every basic block 'i' it maintains a BlockData that consists of: // Consumes: a bit vector which contains a set of indices of blocks that can -// reach block 'i' +// reach block 'i'. A block can trivially reach itself. // Kills: a bit vector which contains a set of indices of blocks that can -// reach block 'i', but one of the path will cross a suspend point +// reach block 'i' but there is a path crossing a suspend point +// not repeating 'i' (path to 'i' without cycles containing 'i'). // Suspend: a boolean indicating whether block 'i' contains a suspend point. // End: a boolean indicating whether block 'i' contains a coro.end intrinsic. +// KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that +// crosses a suspend point. // namespace { struct SuspendCrossingInfo { @@ -91,6 +95,7 @@ struct SuspendCrossingInfo { BitVector Kills; bool Suspend = false; bool End = false; + bool KillLoop = false; }; SmallVector<BlockData, SmallVectorThreshold> Block; @@ -108,16 +113,31 @@ struct SuspendCrossingInfo { SuspendCrossingInfo(Function &F, coro::Shape &Shape); - bool hasPathCrossingSuspendPoint(BasicBlock *DefBB, BasicBlock *UseBB) const { - size_t const DefIndex = Mapping.blockToIndex(DefBB); - size_t const UseIndex = Mapping.blockToIndex(UseBB); - - bool const Result = Block[UseIndex].Kills[DefIndex]; - LLVM_DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName() + /// Returns true if there is a path from \p From to \p To crossing a suspend + /// point without crossing \p From a 2nd time. + bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const { + size_t const FromIndex = Mapping.blockToIndex(From); + size_t const ToIndex = Mapping.blockToIndex(To); + bool const Result = Block[ToIndex].Kills[FromIndex]; + LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName() << " answer is " << Result << "\n"); return Result; } + /// Returns true if there is a path from \p From to \p To crossing a suspend + /// point without crossing \p From a 2nd time. If \p From is the same as \p To + /// this will also check if there is a looping path crossing a suspend point. 
+ bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From, + BasicBlock *To) const { + size_t const FromIndex = Mapping.blockToIndex(From); + size_t const ToIndex = Mapping.blockToIndex(To); + bool Result = Block[ToIndex].Kills[FromIndex] || + (From == To && Block[ToIndex].KillLoop); + LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName() + << " answer is " << Result << " (path or loop)\n"); + return Result; + } + bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const { auto *I = cast<Instruction>(U); @@ -270,6 +290,7 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) } else { // This is reached when S block it not Suspend nor coro.end and it // need to make sure that it is not in the kill set. + S.KillLoop |= S.Kills[SuccNo]; S.Kills.reset(SuccNo); } @@ -302,10 +323,10 @@ class FrameTypeBuilder; using SpillInfo = SmallMapVector<Value *, SmallVector<Instruction *, 2>, 8>; struct AllocaInfo { AllocaInst *Alloca; - DenseMap<Instruction *, llvm::Optional<APInt>> Aliases; + DenseMap<Instruction *, std::optional<APInt>> Aliases; bool MayWriteBeforeCoroBegin; AllocaInfo(AllocaInst *Alloca, - DenseMap<Instruction *, llvm::Optional<APInt>> Aliases, + DenseMap<Instruction *, std::optional<APInt>> Aliases, bool MayWriteBeforeCoroBegin) : Alloca(Alloca), Aliases(std::move(Aliases)), MayWriteBeforeCoroBegin(MayWriteBeforeCoroBegin) {} @@ -437,20 +458,20 @@ private: Align StructAlign; bool IsFinished = false; - Optional<Align> MaxFrameAlignment; + std::optional<Align> MaxFrameAlignment; SmallVector<Field, 8> Fields; DenseMap<Value*, unsigned> FieldIndexByKey; public: FrameTypeBuilder(LLVMContext &Context, const DataLayout &DL, - Optional<Align> MaxFrameAlignment) + std::optional<Align> MaxFrameAlignment) : DL(DL), Context(Context), MaxFrameAlignment(MaxFrameAlignment) {} /// Add a field to this structure for the storage of an `alloca` /// instruction. - LLVM_NODISCARD FieldIDType addFieldForAlloca(AllocaInst *AI, - bool IsHeader = false) { + [[nodiscard]] FieldIDType addFieldForAlloca(AllocaInst *AI, + bool IsHeader = false) { Type *Ty = AI->getAllocatedType(); // Make an array type if this is a static array allocation. @@ -495,9 +516,9 @@ public: coro::Shape &Shape); /// Add a field to this structure. - LLVM_NODISCARD FieldIDType addField(Type *Ty, MaybeAlign MaybeFieldAlignment, - bool IsHeader = false, - bool IsSpillOfValue = false) { + [[nodiscard]] FieldIDType addField(Type *Ty, MaybeAlign MaybeFieldAlignment, + bool IsHeader = false, + bool IsSpillOfValue = false) { assert(!IsFinished && "adding fields to a finished builder"); assert(Ty && "must provide a type for a field"); @@ -629,8 +650,8 @@ void FrameTypeBuilder::addFieldForAllocas(const Function &F, // patterns since it just prevend putting the allocas to live in the same // slot. 
DenseMap<SwitchInst *, BasicBlock *> DefaultSuspendDest; - for (auto CoroSuspendInst : Shape.CoroSuspends) { - for (auto U : CoroSuspendInst->users()) { + for (auto *CoroSuspendInst : Shape.CoroSuspends) { + for (auto *U : CoroSuspendInst->users()) { if (auto *ConstSWI = dyn_cast<SwitchInst>(U)) { auto *SWI = const_cast<SwitchInst *>(ConstSWI); DefaultSuspendDest[SWI] = SWI->getDefaultDest(); @@ -654,10 +675,10 @@ void FrameTypeBuilder::addFieldForAllocas(const Function &F, StackLifetimeAnalyzer.getLiveRange(AI2)); }; auto GetAllocaSize = [&](const AllocaInfo &A) { - Optional<TypeSize> RetSize = A.Alloca->getAllocationSizeInBits(DL); + std::optional<TypeSize> RetSize = A.Alloca->getAllocationSize(DL); assert(RetSize && "Variable Length Arrays (VLA) are not supported.\n"); assert(!RetSize->isScalable() && "Scalable vectors are not yet supported"); - return RetSize->getFixedSize(); + return RetSize->getFixedValue(); }; // Put larger allocas in the front. So the larger allocas have higher // priority to merge, which can save more space potentially. Also each @@ -888,14 +909,15 @@ static DIType *solveDIType(DIBuilder &Builder, Type *Ty, // struct Node { // Node* ptr; // }; - RetType = Builder.createPointerType(nullptr, Layout.getTypeSizeInBits(Ty), - Layout.getABITypeAlignment(Ty), - /*DWARFAddressSpace=*/None, Name); + RetType = + Builder.createPointerType(nullptr, Layout.getTypeSizeInBits(Ty), + Layout.getABITypeAlign(Ty).value() * CHAR_BIT, + /*DWARFAddressSpace=*/std::nullopt, Name); } else if (Ty->isStructTy()) { auto *DIStruct = Builder.createStructType( Scope, Name, Scope->getFile(), LineNum, Layout.getTypeSizeInBits(Ty), - Layout.getPrefTypeAlignment(Ty), llvm::DINode::FlagArtificial, nullptr, - llvm::DINodeArray()); + Layout.getPrefTypeAlign(Ty).value() * CHAR_BIT, + llvm::DINode::FlagArtificial, nullptr, llvm::DINodeArray()); auto *StructTy = cast<StructType>(Ty); SmallVector<Metadata *, 16> Elements; @@ -1064,7 +1086,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, Type *Ty = FrameTy->getElementType(Index); assert(Ty->isSized() && "We can't handle type which is not sized.\n"); - SizeInBits = Layout.getTypeSizeInBits(Ty).getFixedSize(); + SizeInBits = Layout.getTypeSizeInBits(Ty).getFixedValue(); AlignInBits = OffsetCache[Index].first * 8; OffsetInBits = OffsetCache[Index].second * 8; @@ -1131,13 +1153,13 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, }(); // We will use this value to cap the alignment of spilled values. - Optional<Align> MaxFrameAlignment; + std::optional<Align> MaxFrameAlignment; if (Shape.ABI == coro::ABI::Async) MaxFrameAlignment = Shape.AsyncLowering.getContextAlignment(); FrameTypeBuilder B(C, DL, MaxFrameAlignment); AllocaInst *PromiseAlloca = Shape.getPromiseAlloca(); - Optional<FieldIDType> SwitchIndexFieldId; + std::optional<FieldIDType> SwitchIndexFieldId; if (Shape.ABI == coro::ABI::Switch) { auto *FramePtrTy = FrameTy->getPointerTo(); @@ -1147,8 +1169,8 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, // Add header fields for the resume and destroy functions. // We can rely on these being perfectly packed. - (void)B.addField(FnPtrTy, None, /*header*/ true); - (void)B.addField(FnPtrTy, None, /*header*/ true); + (void)B.addField(FnPtrTy, std::nullopt, /*header*/ true); + (void)B.addField(FnPtrTy, std::nullopt, /*header*/ true); // PromiseAlloca field needs to be explicitly added here because it's // a header field with a fixed offset based on its alignment. 
Hence it @@ -1162,7 +1184,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size())); Type *IndexType = Type::getIntNTy(C, IndexBits); - SwitchIndexFieldId = B.addField(IndexType, None); + SwitchIndexFieldId = B.addField(IndexType, std::nullopt); } else { assert(PromiseAlloca == nullptr && "lowering doesn't support promises"); } @@ -1178,7 +1200,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, // We assume that the promise alloca won't be modified before // CoroBegin and no alias will be create before CoroBegin. FrameData.Allocas.emplace_back( - PromiseAlloca, DenseMap<Instruction *, llvm::Optional<APInt>>{}, false); + PromiseAlloca, DenseMap<Instruction *, std::optional<APInt>>{}, false); // Create an entry for every spilled value. for (auto &S : FrameData.Spills) { Type *FieldType = S.first->getType(); @@ -1187,8 +1209,8 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, if (const Argument *A = dyn_cast<Argument>(S.first)) if (A->hasByValAttr()) FieldType = A->getParamByValType(); - FieldIDType Id = - B.addField(FieldType, None, false /*header*/, true /*IsSpillOfValue*/); + FieldIDType Id = B.addField(FieldType, std::nullopt, false /*header*/, + true /*IsSpillOfValue*/); FrameData.setFieldIndex(S.first, Id); } @@ -1403,7 +1425,7 @@ struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> { bool getMayWriteBeforeCoroBegin() const { return MayWriteBeforeCoroBegin; } - DenseMap<Instruction *, llvm::Optional<APInt>> getAliasesCopy() const { + DenseMap<Instruction *, std::optional<APInt>> getAliasesCopy() const { assert(getShouldLiveOnFrame() && "This method should only be called if the " "alloca needs to live on the frame."); for (const auto &P : AliasOffetMap) @@ -1420,13 +1442,13 @@ private: // All alias to the original AllocaInst, created before CoroBegin and used // after CoroBegin. Each entry contains the instruction and the offset in the // original Alloca. They need to be recreated after CoroBegin off the frame. - DenseMap<Instruction *, llvm::Optional<APInt>> AliasOffetMap{}; + DenseMap<Instruction *, std::optional<APInt>> AliasOffetMap{}; SmallPtrSet<Instruction *, 4> Users{}; SmallPtrSet<IntrinsicInst *, 2> LifetimeStarts{}; bool MayWriteBeforeCoroBegin{false}; bool ShouldUseLifetimeStartInfo{true}; - mutable llvm::Optional<bool> ShouldLiveOnFrame{}; + mutable std::optional<bool> ShouldLiveOnFrame{}; bool computeShouldLiveOnFrame() const { // If lifetime information is available, we check it first since it's @@ -1438,6 +1460,19 @@ private: for (auto *S : LifetimeStarts) if (Checker.isDefinitionAcrossSuspend(*S, I)) return true; + // Addresses are guaranteed to be identical after every lifetime.start so + // we cannot use the local stack if the address escaped and there is a + // suspend point between lifetime markers. This should also cover the + // case of a single lifetime.start intrinsic in a loop with suspend point. + if (PI.isEscaped()) { + for (auto *A : LifetimeStarts) { + for (auto *B : LifetimeStarts) { + if (Checker.hasPathOrLoopCrossingSuspendPoint(A->getParent(), + B->getParent())) + return true; + } + } + } return false; } // FIXME: Ideally the isEscaped check should come at the beginning. @@ -1599,7 +1634,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { // // Note: If we change the strategy dealing with alignment, we need to refine // this casting. 
- if (GEP->getResultElementType() != Orig->getType()) + if (GEP->getType() != Orig->getType()) return Builder.CreateBitCast(GEP, Orig->getType(), Orig->getName() + Twine(".cast")); } @@ -1777,8 +1812,15 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { for (auto *DVI : DIs) DVI->replaceUsesOfWith(Alloca, G); - for (Instruction *I : UsersToUpdate) + for (Instruction *I : UsersToUpdate) { + // It is meaningless to remain the lifetime intrinsics refer for the + // member of coroutine frames and the meaningless lifetime intrinsics + // are possible to block further optimizations. + if (I->isLifetimeStartOrEnd()) + continue; + I->replaceUsesOfWith(Alloca, G); + } } Builder.SetInsertPoint(Shape.getInsertPtAfterFramePtr()); for (const auto &A : FrameData.Allocas) { @@ -1810,6 +1852,47 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { AliasPtrTyped, [&](Use &U) { return DT.dominates(CB, U); }); } } + + // PromiseAlloca is not collected in FrameData.Allocas. So we don't handle + // the case that the PromiseAlloca may have writes before CoroBegin in the + // above codes. And it may be problematic in edge cases. See + // https://github.com/llvm/llvm-project/issues/57861 for an example. + if (Shape.ABI == coro::ABI::Switch && Shape.SwitchLowering.PromiseAlloca) { + AllocaInst *PA = Shape.SwitchLowering.PromiseAlloca; + // If there is memory accessing to promise alloca before CoroBegin; + bool HasAccessingPromiseBeforeCB = llvm::any_of(PA->uses(), [&](Use &U) { + auto *Inst = dyn_cast<Instruction>(U.getUser()); + if (!Inst || DT.dominates(CB, Inst)) + return false; + + if (auto *CI = dyn_cast<CallInst>(Inst)) { + // It is fine if the call wouldn't write to the Promise. + // This is possible for @llvm.coro.id intrinsics, which + // would take the promise as the second argument as a + // marker. + if (CI->onlyReadsMemory() || + CI->onlyReadsMemory(CI->getArgOperandNo(&U))) + return false; + return true; + } + + return isa<StoreInst>(Inst) || + // It may take too much time to track the uses. + // Be conservative about the case the use may escape. + isa<GetElementPtrInst>(Inst) || + // There would always be a bitcast for the promise alloca + // before we enabled Opaque pointers. And now given + // opaque pointers are enabled by default. This should be + // fine. + isa<BitCastInst>(Inst); + }); + if (HasAccessingPromiseBeforeCB) { + Builder.SetInsertPoint(Shape.getInsertPtAfterFramePtr()); + auto *G = GetFramePointer(PA); + auto *Value = Builder.CreateLoad(PA->getAllocatedType(), PA); + Builder.CreateStore(Value, G); + } + } } // Moves the values in the PHIs in SuccBB that correspong to PredBB into a new @@ -2099,7 +2182,7 @@ static bool isSuspendReachableFrom(BasicBlock *From, return true; // Recurse on the successors. - for (auto Succ : successors(From)) { + for (auto *Succ : successors(From)) { if (isSuspendReachableFrom(Succ, VisitedOrFreeBBs)) return true; } @@ -2113,7 +2196,7 @@ static bool isLocalAlloca(CoroAllocaAllocInst *AI) { // Seed the visited set with all the basic blocks containing a free // so that we won't pass them up. VisitedBlocksSet VisitedOrFreeBBs; - for (auto User : AI->users()) { + for (auto *User : AI->users()) { if (auto FI = dyn_cast<CoroAllocaFreeInst>(User)) VisitedOrFreeBBs.insert(FI->getParent()); } @@ -2133,7 +2216,7 @@ static bool willLeaveFunctionImmediatelyAfter(BasicBlock *BB, if (isSuspendBlock(BB)) return true; // Recurse into the successors. 
- for (auto Succ : successors(BB)) { + for (auto *Succ : successors(BB)) { if (!willLeaveFunctionImmediatelyAfter(Succ, depth - 1)) return false; } @@ -2146,7 +2229,7 @@ static bool localAllocaNeedsStackSave(CoroAllocaAllocInst *AI) { // Look for a free that isn't sufficiently obviously followed by // either a suspend or a termination, i.e. something that will leave // the coro resumption frame. - for (auto U : AI->users()) { + for (auto *U : AI->users()) { auto FI = dyn_cast<CoroAllocaFreeInst>(U); if (!FI) continue; @@ -2162,7 +2245,7 @@ static bool localAllocaNeedsStackSave(CoroAllocaAllocInst *AI) { /// instruction. static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas, SmallVectorImpl<Instruction*> &DeadInsts) { - for (auto AI : LocalAllocas) { + for (auto *AI : LocalAllocas) { auto M = AI->getModule(); IRBuilder<> Builder(AI); @@ -2177,7 +2260,7 @@ static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas, auto Alloca = Builder.CreateAlloca(Builder.getInt8Ty(), AI->getSize()); Alloca->setAlignment(AI->getAlignment()); - for (auto U : AI->users()) { + for (auto *U : AI->users()) { // Replace gets with the allocation. if (isa<CoroAllocaGetInst>(U)) { U->replaceAllUsesWith(Alloca); @@ -2340,12 +2423,12 @@ static void eliminateSwiftErrorArgument(Function &F, Argument &Arg, Builder.CreateStore(InitialValue, Alloca); // Find all the suspends in the function and save and restore around them. - for (auto Suspend : Shape.CoroSuspends) { + for (auto *Suspend : Shape.CoroSuspends) { (void) emitSetAndGetSwiftErrorValueAround(Suspend, Alloca, Shape); } // Find all the coro.ends in the function and restore the error value. - for (auto End : Shape.CoroEnds) { + for (auto *End : Shape.CoroEnds) { Builder.SetInsertPoint(End); auto FinalValue = Builder.CreateLoad(ValueTy, Alloca); (void) emitSetSwiftErrorValue(Builder, FinalValue, Shape); @@ -2523,34 +2606,32 @@ static void sinkLifetimeStartMarkers(Function &F, coro::Shape &Shape, } } -static void collectFrameAllocas(Function &F, coro::Shape &Shape, - const SuspendCrossingInfo &Checker, - SmallVectorImpl<AllocaInfo> &Allocas) { - for (Instruction &I : instructions(F)) { - auto *AI = dyn_cast<AllocaInst>(&I); - if (!AI) - continue; - // The PromiseAlloca will be specially handled since it needs to be in a - // fixed position in the frame. - if (AI == Shape.SwitchLowering.PromiseAlloca) { - continue; - } - DominatorTree DT(F); - // The code that uses lifetime.start intrinsic does not work for functions - // with loops without exit. Disable it on ABIs we know to generate such - // code. - bool ShouldUseLifetimeStartInfo = - (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon && - Shape.ABI != coro::ABI::RetconOnce); - AllocaUseVisitor Visitor{F.getParent()->getDataLayout(), DT, - *Shape.CoroBegin, Checker, - ShouldUseLifetimeStartInfo}; - Visitor.visitPtr(*AI); - if (!Visitor.getShouldLiveOnFrame()) - continue; - Allocas.emplace_back(AI, Visitor.getAliasesCopy(), - Visitor.getMayWriteBeforeCoroBegin()); - } +static void collectFrameAlloca(AllocaInst *AI, coro::Shape &Shape, + const SuspendCrossingInfo &Checker, + SmallVectorImpl<AllocaInfo> &Allocas, + const DominatorTree &DT) { + if (Shape.CoroSuspends.empty()) + return; + + // The PromiseAlloca will be specially handled since it needs to be in a + // fixed position in the frame. + if (AI == Shape.SwitchLowering.PromiseAlloca) + return; + + // The code that uses lifetime.start intrinsic does not work for functions + // with loops without exit. 
Disable it on ABIs we know to generate such + // code. + bool ShouldUseLifetimeStartInfo = + (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon && + Shape.ABI != coro::ABI::RetconOnce); + AllocaUseVisitor Visitor{AI->getModule()->getDataLayout(), DT, + *Shape.CoroBegin, Checker, + ShouldUseLifetimeStartInfo}; + Visitor.visitPtr(*AI); + if (!Visitor.getShouldLiveOnFrame()) + return; + Allocas.emplace_back(AI, Visitor.getAliasesCopy(), + Visitor.getMayWriteBeforeCoroBegin()); } void coro::salvageDebugInfo( @@ -2633,16 +2714,13 @@ void coro::salvageDebugInfo( // dbg.value or dbg.addr since they do not have the same function wide // guarantees that dbg.declare does. if (!isa<DbgValueInst>(DVI) && !isa<DbgAddrIntrinsic>(DVI)) { - if (auto *II = dyn_cast<InvokeInst>(Storage)) - DVI->moveBefore(II->getNormalDest()->getFirstNonPHI()); - else if (auto *CBI = dyn_cast<CallBrInst>(Storage)) - DVI->moveBefore(CBI->getDefaultDest()->getFirstNonPHI()); - else if (auto *InsertPt = dyn_cast<Instruction>(Storage)) { - assert(!InsertPt->isTerminator() && - "Unimaged terminator that could return a storage."); - DVI->moveAfter(InsertPt); - } else if (isa<Argument>(Storage)) - DVI->moveAfter(F->getEntryBlock().getFirstNonPHI()); + Instruction *InsertPt = nullptr; + if (auto *I = dyn_cast<Instruction>(Storage)) + InsertPt = I->getInsertionPointAfterDef(); + else if (isa<Argument>(Storage)) + InsertPt = &*F->getEntryBlock().begin(); + if (InsertPt) + DVI->moveBefore(InsertPt); } } @@ -2687,7 +2765,7 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { } // Later code makes structural assumptions about single predecessors phis e.g - // that they are not live accross a suspend point. + // that they are not live across a suspend point. cleanupSinglePredPHIs(F); // Transforms multi-edge PHI Nodes, so that any value feeding into a PHI will @@ -2706,6 +2784,8 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { SpillInfo Spills; for (int Repeat = 0; Repeat < 4; ++Repeat) { // See if there are materializable instructions across suspend points. + // FIXME: We can use a worklist to track the possible materialize + // instructions instead of iterating the whole function again and again. for (Instruction &I : instructions(F)) if (materializable(I)) { for (User *U : I.users()) @@ -2728,28 +2808,19 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { Shape.ABI != coro::ABI::RetconOnce) sinkLifetimeStartMarkers(F, Shape, Checker); - if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty()) - collectFrameAllocas(F, Shape, Checker, FrameData.Allocas); - LLVM_DEBUG(dumpAllocas(FrameData.Allocas)); - // Collect the spills for arguments and other not-materializable values. for (Argument &A : F.args()) for (User *U : A.users()) if (Checker.isDefinitionAcrossSuspend(A, U)) FrameData.Spills[&A].push_back(cast<Instruction>(U)); + const DominatorTree DT(F); for (Instruction &I : instructions(F)) { // Values returned from coroutine structure intrinsics should not be part // of the Coroutine Frame. if (isCoroutineStructureIntrinsic(I) || &I == Shape.CoroBegin) continue; - // The Coroutine Promise always included into coroutine frame, no need to - // check for suspend crossing. - if (Shape.ABI == coro::ABI::Switch && - Shape.SwitchLowering.PromiseAlloca == &I) - continue; - // Handle alloca.alloc specially here. if (auto AI = dyn_cast<CoroAllocaAllocInst>(&I)) { // Check whether the alloca's lifetime is bounded by suspend points. 
@@ -2776,8 +2847,10 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { if (isa<CoroAllocaGetInst>(I)) continue; - if (isa<AllocaInst>(I)) + if (auto *AI = dyn_cast<AllocaInst>(&I)) { + collectFrameAlloca(AI, Shape, Checker, FrameData.Allocas, DT); continue; + } for (User *U : I.users()) if (Checker.isDefinitionAcrossSuspend(I, U)) { @@ -2789,6 +2862,8 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { } } + LLVM_DEBUG(dumpAllocas(FrameData.Allocas)); + // We don't want the layout of coroutine frame to be affected // by debug information. So we only choose to salvage DbgValueInst for // whose value is already in the frame. @@ -2813,6 +2888,6 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { insertSpills(FrameData, Shape); lowerLocalAllocas(LocalAllocas, DeadInstructions); - for (auto I : DeadInstructions) + for (auto *I : DeadInstructions) I->eraseFromParent(); } diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h index af35b45c2eaf..032361c22045 100644 --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -112,6 +112,7 @@ struct LLVM_LIBRARY_VISIBILITY Shape { unsigned IndexAlign; unsigned IndexOffset; bool HasFinalSuspend; + bool HasUnwindCoroEnd; }; struct RetconLoweringStorage { diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 722a1c6ec0ce..1171878f749a 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -201,8 +201,8 @@ static bool replaceCoroEndAsync(AnyCoroEndInst *End) { assert(MustTailCallFuncBlock && "Must have a single predecessor block"); auto It = MustTailCallFuncBlock->getTerminator()->getIterator(); auto *MustTailCall = cast<CallInst>(&*std::prev(It)); - CoroEndBlock->getInstList().splice( - End->getIterator(), MustTailCallFuncBlock->getInstList(), MustTailCall); + CoroEndBlock->splice(End->getIterator(), MustTailCallFuncBlock, + MustTailCall->getIterator()); // Insert the return instruction. Builder.SetInsertPoint(End); @@ -396,11 +396,22 @@ static void createResumeEntryBlock(Function &F, coro::Shape &Shape) { // The coroutine should be marked done if it reaches the final suspend // point. markCoroutineAsDone(Builder, Shape, FramePtr); - } else { + } + + // If the coroutine don't have unwind coro end, we could omit the store to + // the final suspend point since we could infer the coroutine is suspended + // at the final suspend point by the nullness of ResumeFnAddr. + // However, we can't skip it if the coroutine have unwind coro end. Since + // the coroutine reaches unwind coro end is considered suspended at the + // final suspend point (the ResumeFnAddr is null) but in fact the coroutine + // didn't complete yet. We need the IndexVal for the final suspend point + // to make the states clear. + if (!S->isFinal() || Shape.SwitchLowering.HasUnwindCoroEnd) { auto *GepIndex = Builder.CreateStructGEP( FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr"); Builder.CreateStore(IndexVal, GepIndex); } + Save->replaceAllUsesWith(ConstantTokenNone::get(C)); Save->eraseFromParent(); @@ -449,19 +460,22 @@ static void createResumeEntryBlock(Function &F, coro::Shape &Shape) { Shape.SwitchLowering.ResumeEntryBlock = NewEntry; } - -// Rewrite final suspend point handling. We do not use suspend index to -// represent the final suspend point. 
Instead we zero-out ResumeFnAddr in the -// coroutine frame, since it is undefined behavior to resume a coroutine -// suspended at the final suspend point. Thus, in the resume function, we can -// simply remove the last case (when coro::Shape is built, the final suspend -// point (if present) is always the last element of CoroSuspends array). -// In the destroy function, we add a code sequence to check if ResumeFnAddress -// is Null, and if so, jump to the appropriate label to handle cleanup from the -// final suspend point. +// In the resume function, we remove the last case (when coro::Shape is built, +// the final suspend point (if present) is always the last element of +// CoroSuspends array) since it is an undefined behavior to resume a coroutine +// suspended at the final suspend point. +// In the destroy function, if it isn't possible that the ResumeFnAddr is NULL +// and the coroutine doesn't suspend at the final suspend point actually (this +// is possible since the coroutine is considered suspended at the final suspend +// point if promise.unhandled_exception() exits via an exception), we can +// remove the last case. void CoroCloner::handleFinalSuspend() { assert(Shape.ABI == coro::ABI::Switch && Shape.SwitchLowering.HasFinalSuspend); + + if (isSwitchDestroyFunction() && Shape.SwitchLowering.HasUnwindCoroEnd) + return; + auto *Switch = cast<SwitchInst>(VMap[Shape.SwitchLowering.ResumeSwitch]); auto FinalCaseIt = std::prev(Switch->case_end()); BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor(); @@ -502,13 +516,6 @@ static Function *createCloneDeclaration(Function &OrigF, coro::Shape &Shape, Function *NewF = Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage, OrigF.getName() + Suffix); - if (Shape.ABI != coro::ABI::Async) - NewF->addParamAttr(0, Attribute::NonNull); - - // For the async lowering ABI we can't guarantee that the context argument is - // not access via a different pointer not based on the argument. - if (Shape.ABI != coro::ABI::Async) - NewF->addParamAttr(0, Attribute::NoAlias); M->getFunctionList().insert(InsertBefore, NewF); @@ -835,11 +842,15 @@ Value *CoroCloner::deriveNewFramePointer() { } static void addFramePointerAttrs(AttributeList &Attrs, LLVMContext &Context, - unsigned ParamIndex, - uint64_t Size, Align Alignment) { + unsigned ParamIndex, uint64_t Size, + Align Alignment, bool NoAlias) { AttrBuilder ParamAttrs(Context); ParamAttrs.addAttribute(Attribute::NonNull); - ParamAttrs.addAttribute(Attribute::NoAlias); + ParamAttrs.addAttribute(Attribute::NoUndef); + + if (NoAlias) + ParamAttrs.addAttribute(Attribute::NoAlias); + ParamAttrs.addAlignmentAttr(Alignment); ParamAttrs.addDereferenceableAttr(Size); Attrs = Attrs.addParamAttributes(Context, ParamIndex, ParamAttrs); @@ -945,8 +956,8 @@ void CoroCloner::create() { NewAttrs = NewAttrs.addFnAttributes( Context, AttrBuilder(Context, OrigAttrs.getFnAttrs())); - addFramePointerAttrs(NewAttrs, Context, 0, - Shape.FrameSize, Shape.FrameAlign); + addFramePointerAttrs(NewAttrs, Context, 0, Shape.FrameSize, + Shape.FrameAlign, /*NoAlias=*/false); break; case coro::ABI::Async: { auto *ActiveAsyncSuspend = cast<CoroSuspendAsyncInst>(ActiveSuspend); @@ -975,9 +986,12 @@ void CoroCloner::create() { // full-stop. NewAttrs = Shape.RetconLowering.ResumePrototype->getAttributes(); + /// FIXME: Is it really good to add the NoAlias attribute? 
addFramePointerAttrs(NewAttrs, Context, 0, Shape.getRetconCoroId()->getStorageSize(), - Shape.getRetconCoroId()->getStorageAlignment()); + Shape.getRetconCoroId()->getStorageAlignment(), + /*NoAlias=*/true); + break; } @@ -1362,7 +1376,7 @@ static bool shouldBeMustTail(const CallInst &CI, const Function &F) { // for symmetrical coroutine control transfer (C++ Coroutines TS extension). // This transformation is done only in the resume part of the coroutine that has // identical signature and calling convention as the coro.resume call. -static void addMustTailToCoroResumes(Function &F) { +static void addMustTailToCoroResumes(Function &F, TargetTransformInfo &TTI) { bool changed = false; // Collect potential resume instructions. @@ -1374,7 +1388,9 @@ static void addMustTailToCoroResumes(Function &F) { // Set musttail on those that are followed by a ret instruction. for (CallInst *Call : Resumes) - if (simplifyTerminatorLeadingToRet(Call->getNextNode())) { + // Skip targets which don't support tail call on the specific case. + if (TTI.supportsTailCallFor(Call) && + simplifyTerminatorLeadingToRet(Call->getNextNode())) { Call->setTailCallKind(CallInst::TCK_MustTail); changed = true; } @@ -1555,6 +1571,8 @@ static void simplifySuspendPoints(coro::Shape &Shape) { size_t I = 0, N = S.size(); if (N == 0) return; + + size_t ChangedFinalIndex = std::numeric_limits<size_t>::max(); while (true) { auto SI = cast<CoroSuspendInst>(S[I]); // Leave final.suspend to handleFinalSuspend since it is undefined behavior @@ -1562,13 +1580,27 @@ static void simplifySuspendPoints(coro::Shape &Shape) { if (!SI->isFinal() && simplifySuspendPoint(SI, Shape.CoroBegin)) { if (--N == I) break; + std::swap(S[I], S[N]); + + if (cast<CoroSuspendInst>(S[I])->isFinal()) { + assert(Shape.SwitchLowering.HasFinalSuspend); + ChangedFinalIndex = I; + } + continue; } if (++I == N) break; } S.resize(N); + + // Maintain final.suspend in case final suspend was swapped. + // Due to we requrie the final suspend to be the last element of CoroSuspends. + if (ChangedFinalIndex < N) { + assert(cast<CoroSuspendInst>(S[ChangedFinalIndex])->isFinal()); + std::swap(S[ChangedFinalIndex], S.back()); + } } static void splitSwitchCoroutine(Function &F, coro::Shape &Shape, @@ -1594,7 +1626,7 @@ static void splitSwitchCoroutine(Function &F, coro::Shape &Shape, // FIXME: Could we support symmetric transfer effectively without musttail // call? if (TTI.supportsTailCalls()) - addMustTailToCoroResumes(*ResumeClone); + addMustTailToCoroResumes(*ResumeClone, TTI); // Store addresses resume/destroy/cleanup functions in the coroutine frame. updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); @@ -1629,7 +1661,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy, ArrayRef<Value *> FnArgs, SmallVectorImpl<Value *> &CallArgs) { size_t ArgIdx = 0; - for (auto paramTy : FnTy->params()) { + for (auto *paramTy : FnTy->params()) { assert(ArgIdx < FnArgs.size()); if (paramTy != FnArgs[ArgIdx]->getType()) CallArgs.push_back( @@ -1838,7 +1870,7 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, Shape.CoroSuspends.size())); // Next, all the directly-yielded values. - for (auto ResultTy : Shape.getRetconResultTypes()) + for (auto *ResultTy : Shape.getRetconResultTypes()) ReturnPHIs.push_back(Builder.CreatePHI(ResultTy, Shape.CoroSuspends.size())); @@ -1963,7 +1995,7 @@ static coro::Shape splitCoroutine(Function &F, /// Remove calls to llvm.coro.end in the original function. 
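The suspend-point pruning in simplifySuspendPoints above removes entries by swapping them with the back of CoroSuspends and, if the final suspend point happens to be displaced, swaps it back to the last slot afterwards, because later code relies on the final suspend being the last element. A small standalone sketch of that scheme (plain C++, not LLVM code):

#include <cstdio>
#include <limits>
#include <utility>
#include <vector>

struct Suspend { int Id; bool IsFinal; bool Simplifiable; };

static void pruneSuspends(std::vector<Suspend> &S) {
  size_t I = 0, N = S.size();
  if (N == 0)
    return;
  size_t ChangedFinalIndex = std::numeric_limits<size_t>::max();
  while (true) {
    if (!S[I].IsFinal && S[I].Simplifiable) {
      if (--N == I)
        break;
      std::swap(S[I], S[N]);   // last live element takes slot I
      if (S[I].IsFinal)
        ChangedFinalIndex = I; // the final suspend was displaced
      continue;
    }
    if (++I == N)
      break;
  }
  S.resize(N);
  if (ChangedFinalIndex < N)   // restore the "final is last" invariant
    std::swap(S[ChangedFinalIndex], S.back());
}

int main() {
  std::vector<Suspend> S = {
      {0, false, true}, {1, false, false}, {2, true, false}, {3, false, true}};
  pruneSuspends(S);
  for (const Suspend &X : S)
    std::printf("%d%s ", X.Id, X.IsFinal ? "(final)" : "");
  // Prints: 1 2(final)
}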
static void removeCoroEnds(const coro::Shape &Shape) { - for (auto End : Shape.CoroEnds) { + for (auto *End : Shape.CoroEnds) { replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr); } } diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 1742e9319c3b..ce4262e593b6 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -171,6 +171,7 @@ static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin, // Collect "interesting" coroutine intrinsics. void coro::Shape::buildFrom(Function &F) { bool HasFinalSuspend = false; + bool HasUnwindCoroEnd = false; size_t FinalSuspendIndex = 0; clear(*this); SmallVector<CoroFrameInst *, 8> CoroFrames; @@ -242,6 +243,10 @@ void coro::Shape::buildFrom(Function &F) { if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) { AsyncEnd->checkWellFormed(); } + + if (CoroEnds.back()->isUnwind()) + HasUnwindCoroEnd = true; + if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) { // Make sure that the fallthrough coro.end is the first element in the // CoroEnds vector. @@ -290,11 +295,12 @@ void coro::Shape::buildFrom(Function &F) { auto SwitchId = cast<CoroIdInst>(Id); this->ABI = coro::ABI::Switch; this->SwitchLowering.HasFinalSuspend = HasFinalSuspend; + this->SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd; this->SwitchLowering.ResumeSwitch = nullptr; this->SwitchLowering.PromiseAlloca = SwitchId->getPromise(); this->SwitchLowering.ResumeEntryBlock = nullptr; - for (auto AnySuspend : CoroSuspends) { + for (auto *AnySuspend : CoroSuspends) { auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend); if (!Suspend) { #ifndef NDEBUG @@ -340,7 +346,7 @@ void coro::Shape::buildFrom(Function &F) { auto ResultTys = getRetconResultTypes(); auto ResumeTys = getRetconResumeTypes(); - for (auto AnySuspend : CoroSuspends) { + for (auto *AnySuspend : CoroSuspends) { auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend); if (!Suspend) { #ifndef NDEBUG diff --git a/llvm/lib/Transforms/IPO/AlwaysInliner.cpp b/llvm/lib/Transforms/IPO/AlwaysInliner.cpp index 58cea7ebb749..09286482edff 100644 --- a/llvm/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/llvm/lib/Transforms/IPO/AlwaysInliner.cpp @@ -70,8 +70,9 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, &FAM.getResult<BlockFrequencyAnalysis>(*Caller), &FAM.getResult<BlockFrequencyAnalysis>(F)); - InlineResult Res = InlineFunction( - *CB, IFI, &FAM.getResult<AAManager>(F), InsertLifetime); + InlineResult Res = + InlineFunction(*CB, IFI, /*MergeAttributes=*/true, + &FAM.getResult<AAManager>(F), InsertLifetime); if (!Res.isSuccess()) { ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, @@ -88,9 +89,6 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, InlineCost::getAlways("always inline attribute"), /*ForProfileContext=*/false, DEBUG_TYPE); - // Merge the attributes based on the inlining. 
- AttributeFuncs::mergeAttributesForInlining(*Caller, F); - Changed = true; } diff --git a/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp b/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp index 5ca4e24df8fc..6cc04544cabc 100644 --- a/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp +++ b/llvm/lib/Transforms/IPO/Annotation2Metadata.cpp @@ -47,20 +47,13 @@ static bool convertAnnotation2Metadata(Module &M) { auto *OpC = dyn_cast<ConstantStruct>(&Op); if (!OpC || OpC->getNumOperands() != 4) continue; - auto *StrGEP = dyn_cast<ConstantExpr>(OpC->getOperand(1)); - if (!StrGEP || StrGEP->getNumOperands() < 2) - continue; - auto *StrC = dyn_cast<GlobalValue>(StrGEP->getOperand(0)); + auto *StrC = dyn_cast<GlobalValue>(OpC->getOperand(1)->stripPointerCasts()); if (!StrC) continue; auto *StrData = dyn_cast<ConstantDataSequential>(StrC->getOperand(0)); if (!StrData) continue; - // Look through bitcast. - auto *Bitcast = dyn_cast<ConstantExpr>(OpC->getOperand(0)); - if (!Bitcast || Bitcast->getOpcode() != Instruction::BitCast) - continue; - auto *Fn = dyn_cast<Function>(Bitcast->getOperand(0)); + auto *Fn = dyn_cast<Function>(OpC->getOperand(0)->stripPointerCasts()); if (!Fn) continue; diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index 8c77b6937737..dd1a3b78a378 100644 --- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -204,7 +204,7 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, for (auto *I : Params) if (auto *VT = dyn_cast<llvm::VectorType>(I)) LargestVectorWidth = std::max( - LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinSize()); + LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinValue()); // Recompute the parameter attributes list based on the new arguments for // the function. @@ -300,7 +300,7 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the // function empty. - NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); + NF->splice(NF->begin(), F); // We will collect all the new created allocas to promote them into registers // after the following loop @@ -476,10 +476,10 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR, bool AreStoresAllowed = Arg->getParamByValType() && Arg->getParamAlign(); // An end user of a pointer argument is a load or store instruction. - // Returns None if this load or store is not based on the argument. Return - // true if we can promote the instruction, false otherwise. + // Returns std::nullopt if this load or store is not based on the argument. + // Return true if we can promote the instruction, false otherwise. auto HandleEndUser = [&](auto *I, Type *Ty, - bool GuaranteedToExecute) -> Optional<bool> { + bool GuaranteedToExecute) -> std::optional<bool> { // Don't promote volatile or atomic instructions. if (!I->isSimple()) return false; @@ -489,7 +489,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR, Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset, /* AllowNonInbounds */ true); if (Ptr != Arg) - return None; + return std::nullopt; if (Offset.getSignificantBits() >= 64) return false; @@ -553,7 +553,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR, // Look for loads and stores that are guaranteed to execute on entry. 
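The reworded comment in findArgParts above describes a three-way result: std::nullopt when the access is not based on the promoted argument, true when it is based on it and can be promoted, and false when it blocks promotion. A minimal sketch of that convention in plain C++ (the enum and names are invented for illustration):

#include <optional>

enum class AccessKind { NotBasedOnArg, SimpleLoad, VolatileLoad };

static std::optional<bool> classifyAccess(AccessKind K) {
  if (K == AccessKind::NotBasedOnArg)
    return std::nullopt; // unrelated access: ignore it
  if (K == AccessKind::VolatileLoad)
    return false;        // based on the argument, but blocks promotion
  return true;           // based on the argument and promotable
}

int main() {
  std::optional<bool> Unrelated = classifyAccess(AccessKind::NotBasedOnArg);
  std::optional<bool> Load = classifyAccess(AccessKind::SimpleLoad);
  return (!Unrelated.has_value() && Load.value()) ? 0 : 1;
}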
for (Instruction &I : Arg->getParent()->getEntryBlock()) { - Optional<bool> Res{}; + std::optional<bool> Res{}; if (LoadInst *LI = dyn_cast<LoadInst>(&I)) Res = HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ true); else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp index 37c773bd47d6..b9134ce26e80 100644 --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -27,7 +27,9 @@ #include "llvm/Analysis/MustExecute.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantFold.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instruction.h" @@ -45,17 +47,20 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" +#include <cstdint> #ifdef EXPENSIVE_CHECKS #include "llvm/IR/Verifier.h" #endif #include <cassert> +#include <optional> #include <string> using namespace llvm; #define DEBUG_TYPE "attributor" +#define VERBOSE_DEBUG_TYPE DEBUG_TYPE "-verbose" DEBUG_COUNTER(ManifestDBGCounter, "attributor-manifest", "Determine what attributes are manifested in the IR"); @@ -219,7 +224,9 @@ bool AA::isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA, } Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + const DataLayout &DL, + AA::RangeTy *RangePtr) { if (isa<AllocaInst>(Obj)) return UndefValue::get(&Ty); if (Constant *Init = getInitialValueOfAllocation(&Obj, TLI, &Ty)) @@ -231,7 +238,13 @@ Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty, return nullptr; if (!GV->hasInitializer()) return UndefValue::get(&Ty); - return dyn_cast_or_null<Constant>(getWithType(*GV->getInitializer(), Ty)); + + if (RangePtr && !RangePtr->offsetOrSizeAreUnknown()) { + APInt Offset = APInt(64, RangePtr->Offset); + return ConstantFoldLoadFromConst(GV->getInitializer(), &Ty, Offset, DL); + } + + return ConstantFoldLoadFromUniformValue(GV->getInitializer(), &Ty); } bool AA::isValidInScope(const Value &V, const Function *Scope) { @@ -292,9 +305,10 @@ Value *AA::getWithType(Value &V, Type &Ty) { return nullptr; } -Optional<Value *> -AA::combineOptionalValuesInAAValueLatice(const Optional<Value *> &A, - const Optional<Value *> &B, Type *Ty) { +std::optional<Value *> +AA::combineOptionalValuesInAAValueLatice(const std::optional<Value *> &A, + const std::optional<Value *> &B, + Type *Ty) { if (A == B) return A; if (!B) @@ -326,14 +340,6 @@ static bool getPotentialCopiesOfMemoryValue( << " (only exact: " << OnlyExact << ")\n";); Value &Ptr = *I.getPointerOperand(); - SmallSetVector<Value *, 8> Objects; - if (!AA::getAssumedUnderlyingObjects(A, Ptr, Objects, QueryingAA, &I, - UsedAssumedInformation)) { - LLVM_DEBUG( - dbgs() << "Underlying objects stored into could not be determined\n";); - return false; - } - // Containers to remember the pointer infos and new copies while we are not // sure that we can find all of them. If we abort we want to avoid spurious // dependences and potential copies in the provided container. 
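The extended AA::getInitialValueForObj above takes an optional access range so that, when the load offset into a global's initializer is known, only the addressed slice is folded (via ConstantFoldLoadFromConst) instead of requiring a uniform initializer. A rough standalone analogue in plain C++, treating the initializer as an array of 32-bit slots (a simplification; unknown offsets simply fail):

#include <cstdint>
#include <optional>
#include <vector>

static std::optional<uint32_t>
loadFromInitializer(const std::vector<uint32_t> &Init, int64_t ByteOffset) {
  if (ByteOffset < 0 || ByteOffset % 4 != 0 ||
      ByteOffset / 4 >= (int64_t)Init.size())
    return std::nullopt;       // unknown or out-of-range access
  return Init[ByteOffset / 4]; // the folded "initial value"
}

int main() {
  std::vector<uint32_t> Init = {10, 20, 30};
  // A load at byte offset 4 reads the second element of the initializer.
  return loadFromInitializer(Init, 4).value() == 20 ? 0 : 1;
}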
@@ -343,42 +349,43 @@ static bool getPotentialCopiesOfMemoryValue( const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*I.getFunction()); - LLVM_DEBUG(dbgs() << "Visit " << Objects.size() << " objects:\n"); - for (Value *Obj : Objects) { - LLVM_DEBUG(dbgs() << "Visit underlying object " << *Obj << "\n"); - if (isa<UndefValue>(Obj)) - continue; - if (isa<ConstantPointerNull>(Obj)) { + + auto Pred = [&](Value &Obj) { + LLVM_DEBUG(dbgs() << "Visit underlying object " << Obj << "\n"); + if (isa<UndefValue>(&Obj)) + return true; + if (isa<ConstantPointerNull>(&Obj)) { // A null pointer access can be undefined but any offset from null may // be OK. We do not try to optimize the latter. if (!NullPointerIsDefined(I.getFunction(), Ptr.getType()->getPointerAddressSpace()) && A.getAssumedSimplified(Ptr, QueryingAA, UsedAssumedInformation, - AA::Interprocedural) == Obj) - continue; + AA::Interprocedural) == &Obj) + return true; LLVM_DEBUG( dbgs() << "Underlying object is a valid nullptr, giving up.\n";); return false; } // TODO: Use assumed noalias return. - if (!isa<AllocaInst>(Obj) && !isa<GlobalVariable>(Obj) && - !(IsLoad ? isAllocationFn(Obj, TLI) : isNoAliasCall(Obj))) { - LLVM_DEBUG(dbgs() << "Underlying object is not supported yet: " << *Obj + if (!isa<AllocaInst>(&Obj) && !isa<GlobalVariable>(&Obj) && + !(IsLoad ? isAllocationFn(&Obj, TLI) : isNoAliasCall(&Obj))) { + LLVM_DEBUG(dbgs() << "Underlying object is not supported yet: " << Obj << "\n";); return false; } - if (auto *GV = dyn_cast<GlobalVariable>(Obj)) + if (auto *GV = dyn_cast<GlobalVariable>(&Obj)) if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer())) { LLVM_DEBUG(dbgs() << "Underlying object is global with external " "linkage, not supported yet: " - << *Obj << "\n";); + << Obj << "\n";); return false; } bool NullOnly = true; bool NullRequired = false; - auto CheckForNullOnlyAndUndef = [&](Optional<Value *> V, bool IsExact) { + auto CheckForNullOnlyAndUndef = [&](std::optional<Value *> V, + bool IsExact) { if (!V || *V == nullptr) NullOnly = false; else if (isa<UndefValue>(*V)) @@ -390,7 +397,7 @@ static bool getPotentialCopiesOfMemoryValue( }; auto CheckAccess = [&](const AAPointerInfo::Access &Acc, bool IsExact) { - if ((IsLoad && !Acc.isWrite()) || (!IsLoad && !Acc.isRead())) + if ((IsLoad && !Acc.isWriteOrAssumption()) || (!IsLoad && !Acc.isRead())) return true; if (IsLoad && Acc.isWrittenValueYetUndetermined()) return true; @@ -441,21 +448,27 @@ static bool getPotentialCopiesOfMemoryValue( // object. 
bool HasBeenWrittenTo = false; - auto &PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(*Obj), + AA::RangeTy Range; + auto &PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(Obj), DepClassTy::NONE); if (!PI.forallInterferingAccesses(A, QueryingAA, I, CheckAccess, - HasBeenWrittenTo)) { + HasBeenWrittenTo, Range)) { LLVM_DEBUG( dbgs() << "Failed to verify all interfering accesses for underlying object: " - << *Obj << "\n"); + << Obj << "\n"); return false; } - if (IsLoad && !HasBeenWrittenTo) { - Value *InitialValue = AA::getInitialValueForObj(*Obj, *I.getType(), TLI); - if (!InitialValue) + if (IsLoad && !HasBeenWrittenTo && !Range.isUnassigned()) { + const DataLayout &DL = A.getDataLayout(); + Value *InitialValue = + AA::getInitialValueForObj(Obj, *I.getType(), TLI, DL, &Range); + if (!InitialValue) { + LLVM_DEBUG(dbgs() << "Could not determine required initial value of " + "underlying object, abort!\n"); return false; + } CheckForNullOnlyAndUndef(InitialValue, /* IsExact */ true); if (NullRequired && !NullOnly) { LLVM_DEBUG(dbgs() << "Non exact access but initial value that is not " @@ -468,12 +481,22 @@ static bool getPotentialCopiesOfMemoryValue( } PIs.push_back(&PI); + + return true; + }; + + const auto &AAUO = A.getAAFor<AAUnderlyingObjects>( + QueryingAA, IRPosition::value(Ptr), DepClassTy::OPTIONAL); + if (!AAUO.forallUnderlyingObjects(Pred)) { + LLVM_DEBUG( + dbgs() << "Underlying objects stored into could not be determined\n";); + return false; } // Only if we were successful collection all potential copies we record // dependences (on non-fix AAPointerInfo AAs). We also only then modify the // given PotentialCopies container. - for (auto *PI : PIs) { + for (const auto *PI : PIs) { if (!PI->getState().isAtFixpoint()) UsedAssumedInformation = true; A.recordDependence(*PI, QueryingAA, DepClassTy::OPTIONAL); @@ -549,19 +572,27 @@ static bool isPotentiallyReachable(Attributor &A, const Instruction &FromI, const Instruction *ToI, const Function &ToFn, const AbstractAttribute &QueryingAA, + const AA::InstExclusionSetTy *ExclusionSet, std::function<bool(const Function &F)> GoBackwardsCB) { - LLVM_DEBUG(dbgs() << "[AA] isPotentiallyReachable @" << ToFn.getName() - << " from " << FromI << " [GBCB: " << bool(GoBackwardsCB) - << "]\n"); - - // TODO: If we can go arbitrarily backwards we will eventually reach an - // entry point that can reach ToI. Only once this takes a set of blocks - // through which we cannot go, or once we track internal functions not - // accessible from the outside, it makes sense to perform backwards analysis - // in the absence of a GoBackwardsCB. - if (!GoBackwardsCB) { + LLVM_DEBUG({ + dbgs() << "[AA] isPotentiallyReachable @" << ToFn.getName() << " from " + << FromI << " [GBCB: " << bool(GoBackwardsCB) << "][#ExS: " + << (ExclusionSet ? std::to_string(ExclusionSet->size()) : "none") + << "]\n"; + if (ExclusionSet) + for (auto *ES : *ExclusionSet) + dbgs() << *ES << "\n"; + }); + + // If we can go arbitrarily backwards we will eventually reach an entry point + // that can reach ToI. Only if a set of blocks through which we cannot go is + // provided, or once we track internal functions not accessible from the + // outside, it makes sense to perform backwards analysis in the absence of a + // GoBackwardsCB. 
+ if (!GoBackwardsCB && !ExclusionSet) { LLVM_DEBUG(dbgs() << "[AA] check @" << ToFn.getName() << " from " << FromI - << " is not checked backwards, abort\n"); + << " is not checked backwards and does not have an " + "exclusion set, abort\n"); return true; } @@ -580,9 +611,10 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, return true; LLVM_DEBUG(dbgs() << "[AA] check " << *ToI << " from " << *CurFromI << " intraprocedurally\n"); - const auto &ReachabilityAA = A.getAAFor<AAReachability>( + const auto &ReachabilityAA = A.getAAFor<AAIntraFnReachability>( QueryingAA, IRPosition::function(ToFn), DepClassTy::OPTIONAL); - bool Result = ReachabilityAA.isAssumedReachable(A, *CurFromI, *ToI); + bool Result = + ReachabilityAA.isAssumedReachable(A, *CurFromI, *ToI, ExclusionSet); LLVM_DEBUG(dbgs() << "[AA] " << *CurFromI << " " << (Result ? "can potentially " : "cannot ") << "reach " << *ToI << " [Intra]\n"); @@ -590,16 +622,57 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, return true; } - // Check if the current instruction is already known to reach the ToFn. - const auto &FnReachabilityAA = A.getAAFor<AAFunctionReachability>( + bool Result = true; + if (!ToFn.isDeclaration() && ToI) { + const auto &ToReachabilityAA = A.getAAFor<AAIntraFnReachability>( + QueryingAA, IRPosition::function(ToFn), DepClassTy::OPTIONAL); + const Instruction &EntryI = ToFn.getEntryBlock().front(); + Result = + ToReachabilityAA.isAssumedReachable(A, EntryI, *ToI, ExclusionSet); + LLVM_DEBUG(dbgs() << "[AA] Entry " << EntryI << " of @" << ToFn.getName() + << " " << (Result ? "can potentially " : "cannot ") + << "reach @" << *ToI << " [ToFn]\n"); + } + + if (Result) { + // The entry of the ToFn can reach the instruction ToI. If the current + // instruction is already known to reach the ToFn. + const auto &FnReachabilityAA = A.getAAFor<AAInterFnReachability>( + QueryingAA, IRPosition::function(*FromFn), DepClassTy::OPTIONAL); + Result = FnReachabilityAA.instructionCanReach(A, *CurFromI, ToFn, + ExclusionSet); + LLVM_DEBUG(dbgs() << "[AA] " << *CurFromI << " in @" << FromFn->getName() + << " " << (Result ? "can potentially " : "cannot ") + << "reach @" << ToFn.getName() << " [FromFn]\n"); + if (Result) + return true; + } + + // TODO: Check assumed nounwind. + const auto &ReachabilityAA = A.getAAFor<AAIntraFnReachability>( QueryingAA, IRPosition::function(*FromFn), DepClassTy::OPTIONAL); - bool Result = FnReachabilityAA.instructionCanReach( - A, *CurFromI, ToFn); - LLVM_DEBUG(dbgs() << "[AA] " << *CurFromI << " in @" << FromFn->getName() - << " " << (Result ? "can potentially " : "cannot ") - << "reach @" << ToFn.getName() << " [FromFn]\n"); - if (Result) + auto ReturnInstCB = [&](Instruction &Ret) { + bool Result = + ReachabilityAA.isAssumedReachable(A, *CurFromI, Ret, ExclusionSet); + LLVM_DEBUG(dbgs() << "[AA][Ret] " << *CurFromI << " " + << (Result ? "can potentially " : "cannot ") << "reach " + << Ret << " [Intra]\n"); + return !Result; + }; + + // Check if we can reach returns. 
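The queries above thread an exclusion set through the reachability AAs: an instruction only counts as reachable along paths that avoid every excluded instruction. The core idea, reduced to a plain-C++ worklist search over a toy graph (illustrative only, not the AA implementation):

#include <queue>
#include <set>
#include <vector>

static bool isReachable(const std::vector<std::vector<int>> &Succ, int From,
                        int To, const std::set<int> &Excluded) {
  std::vector<bool> Seen(Succ.size(), false);
  std::queue<int> Worklist;
  Worklist.push(From);
  Seen[From] = true;
  while (!Worklist.empty()) {
    int N = Worklist.front();
    Worklist.pop();
    if (N == To)
      return true;
    if (Excluded.count(N)) // do not search past excluded nodes
      continue;
    for (int S : Succ[N])
      if (!Seen[S]) {
        Seen[S] = true;
        Worklist.push(S);
      }
  }
  return false;
}

int main() {
  // 0 -> 1 -> 2 and 0 -> 3 -> 2.
  std::vector<std::vector<int>> Succ = {{1, 3}, {2}, {}, {2}};
  bool A = isReachable(Succ, 0, 2, {});     // true
  bool B = isReachable(Succ, 0, 2, {1});    // still true, via 3
  bool C = isReachable(Succ, 0, 2, {1, 3}); // false, every path is blocked
  return (A && B && !C) ? 0 : 1;
}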
+ bool UsedAssumedInformation = false; + if (A.checkForAllInstructions(ReturnInstCB, FromFn, QueryingAA, + {Instruction::Ret}, UsedAssumedInformation)) { + LLVM_DEBUG(dbgs() << "[AA] No return is reachable, done\n"); + continue; + } + + if (!GoBackwardsCB) { + LLVM_DEBUG(dbgs() << "[AA] check @" << ToFn.getName() << " from " << FromI + << " is not checked backwards, abort\n"); return true; + } // If we do not go backwards from the FromFn we are done here and so far we // could not find a way to reach ToFn/ToI. @@ -622,7 +695,6 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, return true; }; - bool UsedAssumedInformation = false; Result = !A.checkForAllCallSites(CheckCallSite, *FromFn, /* RequireAllCallSites */ true, &QueryingAA, UsedAssumedInformation); @@ -643,20 +715,128 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, bool AA::isPotentiallyReachable( Attributor &A, const Instruction &FromI, const Instruction &ToI, const AbstractAttribute &QueryingAA, + const AA::InstExclusionSetTy *ExclusionSet, std::function<bool(const Function &F)> GoBackwardsCB) { - LLVM_DEBUG(dbgs() << "[AA] isPotentiallyReachable " << ToI << " from " - << FromI << " [GBCB: " << bool(GoBackwardsCB) << "]\n"); const Function *ToFn = ToI.getFunction(); return ::isPotentiallyReachable(A, FromI, &ToI, *ToFn, QueryingAA, - GoBackwardsCB); + ExclusionSet, GoBackwardsCB); } bool AA::isPotentiallyReachable( Attributor &A, const Instruction &FromI, const Function &ToFn, const AbstractAttribute &QueryingAA, + const AA::InstExclusionSetTy *ExclusionSet, std::function<bool(const Function &F)> GoBackwardsCB) { return ::isPotentiallyReachable(A, FromI, /* ToI */ nullptr, ToFn, QueryingAA, - GoBackwardsCB); + ExclusionSet, GoBackwardsCB); +} + +bool AA::isAssumedThreadLocalObject(Attributor &A, Value &Obj, + const AbstractAttribute &QueryingAA) { + if (isa<UndefValue>(Obj)) + return true; + if (isa<AllocaInst>(Obj)) { + InformationCache &InfoCache = A.getInfoCache(); + if (!InfoCache.stackIsAccessibleByOtherThreads()) { + LLVM_DEBUG( + dbgs() << "[AA] Object '" << Obj + << "' is thread local; stack objects are thread local.\n"); + return true; + } + const auto &NoCaptureAA = A.getAAFor<AANoCapture>( + QueryingAA, IRPosition::value(Obj), DepClassTy::OPTIONAL); + LLVM_DEBUG(dbgs() << "[AA] Object '" << Obj << "' is " + << (NoCaptureAA.isAssumedNoCapture() ? "" : "not") + << " thread local; " + << (NoCaptureAA.isAssumedNoCapture() ? 
"non-" : "") + << "captured stack object.\n"); + return NoCaptureAA.isAssumedNoCapture(); + } + if (auto *GV = dyn_cast<GlobalVariable>(&Obj)) { + if (GV->isConstant()) { + LLVM_DEBUG(dbgs() << "[AA] Object '" << Obj + << "' is thread local; constant global\n"); + return true; + } + if (GV->isThreadLocal()) { + LLVM_DEBUG(dbgs() << "[AA] Object '" << Obj + << "' is thread local; thread local global\n"); + return true; + } + } + + if (A.getInfoCache().targetIsGPU()) { + if (Obj.getType()->getPointerAddressSpace() == + (int)AA::GPUAddressSpace::Local) { + LLVM_DEBUG(dbgs() << "[AA] Object '" << Obj + << "' is thread local; GPU local memory\n"); + return true; + } + if (Obj.getType()->getPointerAddressSpace() == + (int)AA::GPUAddressSpace::Constant) { + LLVM_DEBUG(dbgs() << "[AA] Object '" << Obj + << "' is thread local; GPU constant memory\n"); + return true; + } + } + + LLVM_DEBUG(dbgs() << "[AA] Object '" << Obj << "' is not thread local\n"); + return false; +} + +bool AA::isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, + const AbstractAttribute &QueryingAA) { + if (!I.mayHaveSideEffects() && !I.mayReadFromMemory()) + return false; + + SmallSetVector<const Value *, 8> Ptrs; + + auto AddLocationPtr = [&](std::optional<MemoryLocation> Loc) { + if (!Loc || !Loc->Ptr) { + LLVM_DEBUG( + dbgs() << "[AA] Access to unknown location; -> requires barriers\n"); + return false; + } + Ptrs.insert(Loc->Ptr); + return true; + }; + + if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(&I)) { + if (!AddLocationPtr(MemoryLocation::getForDest(MI))) + return true; + if (const MemTransferInst *MTI = dyn_cast<MemTransferInst>(&I)) + if (!AddLocationPtr(MemoryLocation::getForSource(MTI))) + return true; + } else if (!AddLocationPtr(MemoryLocation::getOrNone(&I))) + return true; + + return isPotentiallyAffectedByBarrier(A, Ptrs.getArrayRef(), QueryingAA, &I); +} + +bool AA::isPotentiallyAffectedByBarrier(Attributor &A, + ArrayRef<const Value *> Ptrs, + const AbstractAttribute &QueryingAA, + const Instruction *CtxI) { + for (const Value *Ptr : Ptrs) { + if (!Ptr) { + LLVM_DEBUG(dbgs() << "[AA] nullptr; -> requires barriers\n"); + return true; + } + + auto Pred = [&](Value &Obj) { + if (AA::isAssumedThreadLocalObject(A, Obj, QueryingAA)) + return true; + LLVM_DEBUG(dbgs() << "[AA] Access to '" << Obj << "' via '" << *Ptr + << "'; -> requires barrier\n"); + return false; + }; + + const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>( + QueryingAA, IRPosition::value(*Ptr), DepClassTy::OPTIONAL); + if (!UnderlyingObjsAA.forallUnderlyingObjects(Pred)) + return true; + } + return false; } /// Return true if \p New is equal or worse than \p Old. @@ -720,7 +900,7 @@ Argument *IRPosition::getAssociatedArgument() const { // values and the ones in callbacks. If a callback was found that makes use // of the underlying call site operand, we want the corresponding callback // callee argument and not the direct callee argument. - Optional<Argument *> CBCandidateArg; + std::optional<Argument *> CBCandidateArg; SmallVector<const Use *, 4> CallbackUses; const auto &CB = cast<CallBase>(getAnchorValue()); AbstractCallSite::getCallbackUses(CB, CallbackUses); @@ -748,8 +928,8 @@ Argument *IRPosition::getAssociatedArgument() const { } // If we found a unique callback candidate argument, return it. 
- if (CBCandidateArg && CBCandidateArg.value()) - return CBCandidateArg.value(); + if (CBCandidateArg && *CBCandidateArg) + return *CBCandidateArg; // If no callbacks were found, or none used the underlying call site operand // exclusively, use the direct callee argument if available. @@ -977,7 +1157,7 @@ bool IRPosition::getAttrsFromAssumes(Attribute::AttrKind AK, MustBeExecutedContextExplorer &Explorer = A.getInfoCache().getMustBeExecutedContextExplorer(); auto EIt = Explorer.begin(getCtxI()), EEnd = Explorer.end(getCtxI()); - for (auto &It : A2K) + for (const auto &It : A2K) if (Explorer.findInContextOf(It.first, EIt, EEnd)) Attrs.push_back(Attribute::get(Ctx, AK, It.second.Max)); return AttrsSize != Attrs.size(); @@ -1051,17 +1231,17 @@ void IRPosition::verify() { #endif } -Optional<Constant *> +std::optional<Constant *> Attributor::getAssumedConstant(const IRPosition &IRP, const AbstractAttribute &AA, bool &UsedAssumedInformation) { // First check all callbacks provided by outside AAs. If any of them returns - // a non-null value that is different from the associated value, or None, we - // assume it's simplified. + // a non-null value that is different from the associated value, or + // std::nullopt, we assume it's simplified. for (auto &CB : SimplificationCallbacks.lookup(IRP)) { - Optional<Value *> SimplifiedV = CB(IRP, &AA, UsedAssumedInformation); + std::optional<Value *> SimplifiedV = CB(IRP, &AA, UsedAssumedInformation); if (!SimplifiedV) - return llvm::None; + return std::nullopt; if (isa_and_nonnull<Constant>(*SimplifiedV)) return cast<Constant>(*SimplifiedV); return nullptr; @@ -1073,7 +1253,7 @@ Attributor::getAssumedConstant(const IRPosition &IRP, AA::ValueScope::Interprocedural, UsedAssumedInformation)) { if (Values.empty()) - return llvm::None; + return std::nullopt; if (auto *C = dyn_cast_or_null<Constant>( AAPotentialValues::getSingleValue(*this, AA, IRP, Values))) return C; @@ -1081,13 +1261,12 @@ Attributor::getAssumedConstant(const IRPosition &IRP, return nullptr; } -Optional<Value *> Attributor::getAssumedSimplified(const IRPosition &IRP, - const AbstractAttribute *AA, - bool &UsedAssumedInformation, - AA::ValueScope S) { +std::optional<Value *> Attributor::getAssumedSimplified( + const IRPosition &IRP, const AbstractAttribute *AA, + bool &UsedAssumedInformation, AA::ValueScope S) { // First check all callbacks provided by outside AAs. If any of them returns - // a non-null value that is different from the associated value, or None, we - // assume it's simplified. + // a non-null value that is different from the associated value, or + // std::nullopt, we assume it's simplified. for (auto &CB : SimplificationCallbacks.lookup(IRP)) return CB(IRP, AA, UsedAssumedInformation); @@ -1095,7 +1274,7 @@ Optional<Value *> Attributor::getAssumedSimplified(const IRPosition &IRP, if (!getAssumedSimplifiedValues(IRP, AA, Values, S, UsedAssumedInformation)) return &IRP.getAssociatedValue(); if (Values.empty()) - return llvm::None; + return std::nullopt; if (AA) if (Value *V = AAPotentialValues::getSingleValue(*this, *AA, IRP, Values)) return V; @@ -1110,14 +1289,14 @@ bool Attributor::getAssumedSimplifiedValues( SmallVectorImpl<AA::ValueAndContext> &Values, AA::ValueScope S, bool &UsedAssumedInformation) { // First check all callbacks provided by outside AAs. If any of them returns - // a non-null value that is different from the associated value, or None, we - // assume it's simplified. 
+ // a non-null value that is different from the associated value, or + // std::nullopt, we assume it's simplified. const auto &SimplificationCBs = SimplificationCallbacks.lookup(IRP); - for (auto &CB : SimplificationCBs) { - Optional<Value *> CBResult = CB(IRP, AA, UsedAssumedInformation); + for (const auto &CB : SimplificationCBs) { + std::optional<Value *> CBResult = CB(IRP, AA, UsedAssumedInformation); if (!CBResult.has_value()) continue; - Value *V = CBResult.value(); + Value *V = *CBResult; if (!V) return false; if ((S & AA::ValueScope::Interprocedural) || @@ -1138,8 +1317,8 @@ bool Attributor::getAssumedSimplifiedValues( return true; } -Optional<Value *> Attributor::translateArgumentToCallSiteContent( - Optional<Value *> V, CallBase &CB, const AbstractAttribute &AA, +std::optional<Value *> Attributor::translateArgumentToCallSiteContent( + std::optional<Value *> V, CallBase &CB, const AbstractAttribute &AA, bool &UsedAssumedInformation) { if (!V) return V; @@ -1157,8 +1336,8 @@ Optional<Value *> Attributor::translateArgumentToCallSiteContent( Attributor::~Attributor() { // The abstract attributes are allocated via the BumpPtrAllocator Allocator, // thus we cannot delete them. We can, and want to, destruct them though. - for (auto &DepAA : DG.SyntheticRoot.Deps) { - AbstractAttribute *AA = cast<AbstractAttribute>(DepAA.getPointer()); + for (auto &It : AAMap) { + AbstractAttribute *AA = It.getSecond(); AA->~AbstractAttribute(); } } @@ -1225,23 +1404,26 @@ bool Attributor::isAssumedDead(const Instruction &I, const AbstractAttribute *QueryingAA, const AAIsDead *FnLivenessAA, bool &UsedAssumedInformation, - bool CheckBBLivenessOnly, DepClassTy DepClass) { + bool CheckBBLivenessOnly, DepClassTy DepClass, + bool CheckForDeadStore) { const IRPosition::CallBaseContext *CBCtx = QueryingAA ? QueryingAA->getCallBaseContext() : nullptr; if (ManifestAddedBlocks.contains(I.getParent())) return false; - if (!FnLivenessAA) - FnLivenessAA = - lookupAAFor<AAIsDead>(IRPosition::function(*I.getFunction(), CBCtx), - QueryingAA, DepClassTy::NONE); + const Function &F = *I.getFunction(); + if (!FnLivenessAA || FnLivenessAA->getAnchorScope() != &F) + FnLivenessAA = &getOrCreateAAFor<AAIsDead>(IRPosition::function(F, CBCtx), + QueryingAA, DepClassTy::NONE); + + // Don't use recursive reasoning. + if (QueryingAA == FnLivenessAA) + return false; // If we have a context instruction and a liveness AA we use it. - if (FnLivenessAA && - FnLivenessAA->getIRPosition().getAnchorScope() == I.getFunction() && - (CheckBBLivenessOnly ? FnLivenessAA->isAssumedDead(I.getParent()) - : FnLivenessAA->isAssumedDead(&I))) { + if (CheckBBLivenessOnly ? FnLivenessAA->isAssumedDead(I.getParent()) + : FnLivenessAA->isAssumedDead(&I)) { if (QueryingAA) recordDependence(*FnLivenessAA, *QueryingAA, DepClass); if (!FnLivenessAA->isKnownDead(&I)) @@ -1255,7 +1437,8 @@ bool Attributor::isAssumedDead(const Instruction &I, const IRPosition IRP = IRPosition::inst(I, CBCtx); const AAIsDead &IsDeadAA = getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE); - // Don't check liveness for AAIsDead. + + // Don't use recursive reasoning. 
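The repeated "don't use recursive reasoning" guards added in these hunks keep an abstract attribute from justifying a fact by querying itself, since a self-query could make an optimistic assumption its own only support. The guard in isolation, as a plain-C++ sketch with invented names:

struct AbstractAttribute {};

// Liveness queries bail out when the asking AA is the liveness AA itself.
static bool isAssumedDead(const AbstractAttribute *QueryingAA,
                          const AbstractAttribute *LivenessAA,
                          bool LivenessSaysDead) {
  if (QueryingAA == LivenessAA)
    return false; // never let an AA feed on its own assumption
  return LivenessSaysDead;
}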
if (QueryingAA == &IsDeadAA) return false; @@ -1267,6 +1450,14 @@ bool Attributor::isAssumedDead(const Instruction &I, return true; } + if (CheckForDeadStore && isa<StoreInst>(I) && IsDeadAA.isRemovableStore()) { + if (QueryingAA) + recordDependence(IsDeadAA, *QueryingAA, DepClass); + if (!IsDeadAA.isKnownDead()) + UsedAssumedInformation = true; + return true; + } + return false; } @@ -1275,6 +1466,13 @@ bool Attributor::isAssumedDead(const IRPosition &IRP, const AAIsDead *FnLivenessAA, bool &UsedAssumedInformation, bool CheckBBLivenessOnly, DepClassTy DepClass) { + // Don't check liveness for constants, e.g. functions, used as (floating) + // values since the context instruction and such is here meaningless. + if (IRP.getPositionKind() == IRPosition::IRP_FLOAT && + isa<Constant>(IRP.getAssociatedValue())) { + return false; + } + Instruction *CtxI = IRP.getCtxI(); if (CtxI && isAssumedDead(*CtxI, QueryingAA, FnLivenessAA, UsedAssumedInformation, @@ -1293,7 +1491,8 @@ bool Attributor::isAssumedDead(const IRPosition &IRP, QueryingAA, DepClassTy::NONE); else IsDeadAA = &getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE); - // Don't check liveness for AAIsDead. + + // Don't use recursive reasoning. if (QueryingAA == IsDeadAA) return false; @@ -1312,9 +1511,15 @@ bool Attributor::isAssumedDead(const BasicBlock &BB, const AbstractAttribute *QueryingAA, const AAIsDead *FnLivenessAA, DepClassTy DepClass) { - if (!FnLivenessAA) - FnLivenessAA = lookupAAFor<AAIsDead>(IRPosition::function(*BB.getParent()), - QueryingAA, DepClassTy::NONE); + const Function &F = *BB.getParent(); + if (!FnLivenessAA || FnLivenessAA->getAnchorScope() != &F) + FnLivenessAA = &getOrCreateAAFor<AAIsDead>(IRPosition::function(F), + QueryingAA, DepClassTy::NONE); + + // Don't use recursive reasoning. + if (QueryingAA == FnLivenessAA) + return false; + if (FnLivenessAA->isAssumedDead(&BB)) { if (QueryingAA) recordDependence(*FnLivenessAA, *QueryingAA, DepClass); @@ -1331,6 +1536,11 @@ bool Attributor::checkForAllUses( bool IgnoreDroppableUses, function_ref<bool(const Use &OldU, const Use &NewU)> EquivalentUseCB) { + // Check virtual uses first. + for (VirtualUseCallbackTy &CB : VirtualUseCallbacks.lookup(&V)) + if (!CB(*this, &QueryingAA)) + return false; + // Check the trivial case first as it catches void values. 
if (V.use_empty()) return true; @@ -1368,7 +1578,7 @@ bool Attributor::checkForAllUses( const Use *U = Worklist.pop_back_val(); if (isa<PHINode>(U->getUser()) && !Visited.insert(U).second) continue; - LLVM_DEBUG({ + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, { if (auto *Fn = dyn_cast<Function>(U->getUser())) dbgs() << "[Attributor] Check use: " << **U << " in " << Fn->getName() << "\n"; @@ -1379,11 +1589,13 @@ bool Attributor::checkForAllUses( bool UsedAssumedInformation = false; if (isAssumedDead(*U, &QueryingAA, LivenessAA, UsedAssumedInformation, CheckBBLivenessOnly, LivenessDepClass)) { - LLVM_DEBUG(dbgs() << "[Attributor] Dead use, skip!\n"); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, + dbgs() << "[Attributor] Dead use, skip!\n"); continue; } if (IgnoreDroppableUses && U->getUser()->isDroppable()) { - LLVM_DEBUG(dbgs() << "[Attributor] Droppable user, skip!\n"); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, + dbgs() << "[Attributor] Droppable user, skip!\n"); continue; } @@ -1395,9 +1607,11 @@ bool Attributor::checkForAllUses( if (AA::getPotentialCopiesOfStoredValue( *this, *SI, PotentialCopies, QueryingAA, UsedAssumedInformation, /* OnlyExact */ true)) { - LLVM_DEBUG(dbgs() << "[Attributor] Value is stored, continue with " - << PotentialCopies.size() - << " potential copies instead!\n"); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, + dbgs() + << "[Attributor] Value is stored, continue with " + << PotentialCopies.size() + << " potential copies instead!\n"); for (Value *PotentialCopy : PotentialCopies) if (!AddUsers(*PotentialCopy, U)) return false; @@ -1458,7 +1672,8 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred, const Function &Fn, bool RequireAllCallSites, const AbstractAttribute *QueryingAA, - bool &UsedAssumedInformation) { + bool &UsedAssumedInformation, + bool CheckPotentiallyDead) { if (RequireAllCallSites && !Fn.hasLocalLinkage()) { LLVM_DEBUG( dbgs() @@ -1466,11 +1681,15 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred, << " has no internal linkage, hence not all call sites are known\n"); return false; } + // Check virtual uses first. 
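Several of the chattier messages in this file move from LLVM_DEBUG onto a dedicated verbose channel. Assuming an assertions-enabled LLVM build, the two channels are selected with -debug-only=attributor versus -debug-only=attributor-verbose (or both at once with -debug); a minimal usage sketch:

#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

#define DEBUG_TYPE "attributor"
#define VERBOSE_DEBUG_TYPE DEBUG_TYPE "-verbose"

void reportSomething() {
  // Emitted for -debug-only=attributor (and for plain -debug).
  LLVM_DEBUG(dbgs() << "[Attributor] high-level event\n");
  // Emitted only for -debug-only=attributor-verbose (and for plain -debug).
  DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE,
                  dbgs() << "[Attributor] per-use detail\n");
}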
+ for (VirtualUseCallbackTy &CB : VirtualUseCallbacks.lookup(&Fn)) + if (!CB(*this, QueryingAA)) + return false; SmallVector<const Use *, 8> Uses(make_pointer_range(Fn.uses())); for (unsigned u = 0; u < Uses.size(); ++u) { const Use &U = *Uses[u]; - LLVM_DEBUG({ + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, { if (auto *Fn = dyn_cast<Function>(U)) dbgs() << "[Attributor] Check use: " << Fn->getName() << " in " << *U.getUser() << "\n"; @@ -1478,17 +1697,19 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred, dbgs() << "[Attributor] Check use: " << *U << " in " << *U.getUser() << "\n"; }); - if (isAssumedDead(U, QueryingAA, nullptr, UsedAssumedInformation, + if (!CheckPotentiallyDead && + isAssumedDead(U, QueryingAA, nullptr, UsedAssumedInformation, /* CheckBBLivenessOnly */ true)) { - LLVM_DEBUG(dbgs() << "[Attributor] Dead use, skip!\n"); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, + dbgs() << "[Attributor] Dead use, skip!\n"); continue; } if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U.getUser())) { if (CE->isCast() && CE->getType()->isPointerTy()) { - LLVM_DEBUG( - dbgs() << "[Attributor] Use, is constant cast expression, add " - << CE->getNumUses() - << " uses of that expression instead!\n"); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, { + dbgs() << "[Attributor] Use, is constant cast expression, add " + << CE->getNumUses() << " uses of that expression instead!\n"; + }); for (const Use &CEU : CE->uses()) Uses.push_back(&CEU); continue; @@ -1618,8 +1839,9 @@ static bool checkForAllInstructionsImpl( if (A && !CheckPotentiallyDead && A->isAssumedDead(IRPosition::inst(*I), QueryingAA, LivenessAA, UsedAssumedInformation, CheckBBLivenessOnly)) { - LLVM_DEBUG(dbgs() << "[Attributor] Instruction " << *I - << " is potentially dead, skip!\n";); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, + dbgs() << "[Attributor] Instruction " << *I + << " is potentially dead, skip!\n";); continue; } @@ -1728,19 +1950,22 @@ void Attributor::runTillFixpoint() { AbstractAttribute *InvalidAA = InvalidAAs[u]; // Check the dependences to fast track invalidation. - LLVM_DEBUG(dbgs() << "[Attributor] InvalidAA: " << *InvalidAA << " has " - << InvalidAA->Deps.size() - << " required & optional dependences\n"); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, + dbgs() << "[Attributor] InvalidAA: " << *InvalidAA + << " has " << InvalidAA->Deps.size() + << " required & optional dependences\n"); while (!InvalidAA->Deps.empty()) { const auto &Dep = InvalidAA->Deps.back(); InvalidAA->Deps.pop_back(); AbstractAttribute *DepAA = cast<AbstractAttribute>(Dep.getPointer()); if (Dep.getInt() == unsigned(DepClassTy::OPTIONAL)) { - LLVM_DEBUG(dbgs() << " - recompute: " << *DepAA); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, + dbgs() << " - recompute: " << *DepAA); Worklist.insert(DepAA); continue; } - LLVM_DEBUG(dbgs() << " - invalidate: " << *DepAA); + DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, dbgs() + << " - invalidate: " << *DepAA); DepAA->getState().indicatePessimisticFixpoint(); assert(DepAA->getState().isAtFixpoint() && "Expected fixpoint state!"); if (!DepAA->getState().isValidState()) @@ -1935,13 +2160,23 @@ void Attributor::identifyDeadInternalFunctions() { if (!Configuration.DeleteFns) return; + // To avoid triggering an assertion in the lazy call graph we will not delete + // any internal library functions. We should modify the assertion though and + // allow internals to be deleted. + const auto *TLI = + isModulePass() + ? 
nullptr + : getInfoCache().getTargetLibraryInfoForFunction(*Functions.back()); + LibFunc LF; + // Identify dead internal functions and delete them. This happens outside // the other fixpoint analysis as we might treat potentially dead functions // as live to lower the number of iterations. If they happen to be dead, the // below fixpoint loop will identify and eliminate them. + SmallVector<Function *, 8> InternalFns; for (Function *F : Functions) - if (F->hasLocalLinkage()) + if (F->hasLocalLinkage() && (isModulePass() || !TLI->getLibFunc(*F, LF))) InternalFns.push_back(F); SmallPtrSet<Function *, 8> LiveInternalFns; @@ -1999,9 +2234,9 @@ ChangeStatus Attributor::cleanupIR() { // If we plan to replace NewV we need to update it at this point. do { const auto &Entry = ToBeChangedValues.lookup(NewV); - if (!Entry.first) + if (!get<0>(Entry)) break; - NewV = Entry.first; + NewV = get<0>(Entry); } while (true); Instruction *I = dyn_cast<Instruction>(U->getUser()); @@ -2021,11 +2256,6 @@ ChangeStatus Attributor::cleanupIR() { Arg.removeAttr(Attribute::Returned); } - // Do not perform call graph altering changes outside the SCC. - if (auto *CB = dyn_cast_or_null<CallBase>(I)) - if (CB->isCallee(U)) - return; - LLVM_DEBUG(dbgs() << "Use " << *NewV << " in " << *U->getUser() << " instead of " << *OldV << "\n"); U->set(NewV); @@ -2065,11 +2295,10 @@ ChangeStatus Attributor::cleanupIR() { SmallVector<Use *, 4> Uses; for (auto &It : ToBeChangedValues) { Value *OldV = It.first; - auto &Entry = It.second; - Value *NewV = Entry.first; + auto [NewV, Done] = It.second; Uses.clear(); for (auto &U : OldV->uses()) - if (Entry.second || !U.getUser()->isDroppable()) + if (Done || !U.getUser()->isDroppable()) Uses.push_back(&U); for (Use *U : Uses) { if (auto *I = dyn_cast<Instruction>(U->getUser())) @@ -2079,7 +2308,7 @@ ChangeStatus Attributor::cleanupIR() { } } - for (auto &V : InvokeWithDeadSuccessor) + for (const auto &V : InvokeWithDeadSuccessor) if (InvokeInst *II = dyn_cast_or_null<InvokeInst>(V)) { assert(isRunOn(*II->getFunction()) && "Cannot replace an invoke outside the current SCC!"); @@ -2112,7 +2341,7 @@ ChangeStatus Attributor::cleanupIR() { CGModifiedFunctions.insert(I->getFunction()); ConstantFoldTerminator(I->getParent()); } - for (auto &V : ToBeChangedToUnreachableInsts) + for (const auto &V : ToBeChangedToUnreachableInsts) if (Instruction *I = dyn_cast_or_null<Instruction>(V)) { LLVM_DEBUG(dbgs() << "[Attributor] Change to unreachable: " << *I << "\n"); @@ -2122,10 +2351,10 @@ ChangeStatus Attributor::cleanupIR() { changeToUnreachable(I); } - for (auto &V : ToBeDeletedInsts) { + for (const auto &V : ToBeDeletedInsts) { if (Instruction *I = dyn_cast_or_null<Instruction>(V)) { if (auto *CB = dyn_cast<CallBase>(I)) { - assert(isRunOn(*I->getFunction()) && + assert((isa<IntrinsicInst>(CB) || isRunOn(*I->getFunction())) && "Cannot delete an instruction outside the current SCC!"); if (!isa<IntrinsicInst>(CB)) Configuration.CGUpdater.removeCallSite(*CB); @@ -2272,10 +2501,20 @@ ChangeStatus Attributor::updateAA(AbstractAttribute &AA) { /* CheckBBLivenessOnly */ true)) CS = AA.update(*this); - if (!AA.isQueryAA() && DV.empty()) { - // If the attribute did not query any non-fix information, the state - // will not change and we can indicate that right away. - AAState.indicateOptimisticFixpoint(); + if (!AA.isQueryAA() && DV.empty() && !AA.getState().isAtFixpoint()) { + // If the AA did not rely on outside information but changed, we run it + // again to see if it found a fixpoint. 
Most AAs do but we don't require + // them to. Hence, it might take the AA multiple iterations to get to a + // fixpoint even if it does not rely on outside information, which is fine. + ChangeStatus RerunCS = ChangeStatus::UNCHANGED; + if (CS == ChangeStatus::CHANGED) + RerunCS = AA.update(*this); + + // If the attribute did not change during the run or rerun, and it still did + // not query any non-fix information, the state will not change and we can + // indicate that right at this point. + if (RerunCS == ChangeStatus::UNCHANGED && !AA.isQueryAA() && DV.empty()) + AAState.indicateOptimisticFixpoint(); } if (!AAState.isAtFixpoint()) @@ -2572,8 +2811,9 @@ ChangeStatus Attributor::rewriteFunctionSignatures( uint64_t LargestVectorWidth = 0; for (auto *I : NewArgumentTypes) if (auto *VT = dyn_cast<llvm::VectorType>(I)) - LargestVectorWidth = std::max( - LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinSize()); + LargestVectorWidth = + std::max(LargestVectorWidth, + VT->getPrimitiveSizeInBits().getKnownMinValue()); FunctionType *OldFnTy = OldFn->getFunctionType(); Type *RetTy = OldFnTy->getReturnType(); @@ -2609,8 +2849,7 @@ ChangeStatus Attributor::rewriteFunctionSignatures( // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the // function empty. - NewFn->getBasicBlockList().splice(NewFn->begin(), - OldFn->getBasicBlockList()); + NewFn->splice(NewFn->begin(), OldFn); // Fixup block addresses to reference new function. SmallVector<BlockAddress *, 8u> BlockAddresses; @@ -2692,7 +2931,8 @@ ChangeStatus Attributor::rewriteFunctionSignatures( // Use the CallSiteReplacementCreator to create replacement call sites. bool UsedAssumedInformation = false; bool Success = checkForAllCallSites(CallSiteReplacementCreator, *OldFn, - true, nullptr, UsedAssumedInformation); + true, nullptr, UsedAssumedInformation, + /* CheckPotentiallyDead */ true); (void)Success; assert(Success && "Assumed call site replacement to succeed!"); @@ -2753,7 +2993,7 @@ void InformationCache::initializeInformationCache(const Function &CF, // queried by abstract attributes during their initialization or update. // This has to happen before we create attributes. - DenseMap<const Value *, Optional<short>> AssumeUsesMap; + DenseMap<const Value *, std::optional<short>> AssumeUsesMap; // Add \p V to the assume uses map which track the number of uses outside of // "visited" assumes. If no outside uses are left the value is added to the @@ -2764,11 +3004,11 @@ void InformationCache::initializeInformationCache(const Function &CF, Worklist.push_back(I); while (!Worklist.empty()) { const Instruction *I = Worklist.pop_back_val(); - Optional<short> &NumUses = AssumeUsesMap[I]; + std::optional<short> &NumUses = AssumeUsesMap[I]; if (!NumUses) NumUses = I->getNumUses(); - NumUses = NumUses.value() - /* this assume */ 1; - if (NumUses.value() != 0) + NumUses = *NumUses - /* this assume */ 1; + if (*NumUses != 0) continue; AssumeOnlyValues.insert(I); for (const Value *Op : I->operands()) @@ -2796,6 +3036,7 @@ void InformationCache::initializeInformationCache(const Function &CF, // For `llvm.assume` calls we also fill the KnowledgeMap as we find them. // For `must-tail` calls we remember the caller and callee. 
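The re-run policy added to Attributor::updateAA above amounts to: if an update changed the state but consulted no non-fixed outside information, run it once more, and only declare an optimistic fixpoint when that second run is a no-op. A toy rendering of the control flow (plain C++, the converging state is invented for illustration):

struct ToyAA {
  int Assumed = 8;
  // Returns true if the assumed state changed.
  bool update() {
    int Old = Assumed;
    if (Assumed > 1)
      Assumed /= 2; // monotonically converges towards 1
    return Assumed != Old;
  }
};

int main() {
  ToyAA AA;
  bool UsedOutsideInfo = false; // corresponds to DV.empty() above
  bool AtFixpoint = false;
  bool Changed = AA.update();
  if (!UsedOutsideInfo) {
    bool RerunChanged = Changed ? AA.update() : false;
    if (!RerunChanged)
      AtFixpoint = true; // second run was a no-op
  }
  return AtFixpoint ? 1 : 0; // 8 -> 4 -> 2: not a fixpoint yet
}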
if (auto *Assume = dyn_cast<AssumeInst>(&I)) { + AssumeOnlyValues.insert(Assume); fillMapFromAssume(*Assume, KnowledgeMap); AddToAssumeUsesMap(*Assume->getArgOperand(0)); } else if (cast<CallInst>(I).isMustTailCall()) { @@ -2803,7 +3044,7 @@ void InformationCache::initializeInformationCache(const Function &CF, if (const Function *Callee = cast<CallInst>(I).getCalledFunction()) getFunctionInfo(*Callee).CalledViaMustTail = true; } - LLVM_FALLTHROUGH; + [[fallthrough]]; case Instruction::CallBr: case Instruction::Invoke: case Instruction::CleanupRet: @@ -3190,7 +3431,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, if (!S.isValidState()) OS << "full-set"; else { - for (auto &It : S.getAssumedSet()) + for (const auto &It : S.getAssumedSet()) OS << It << ", "; if (S.undefIsContained()) OS << "undef "; @@ -3206,7 +3447,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, if (!S.isValidState()) OS << "full-set"; else { - for (auto &It : S.getAssumedSet()) { + for (const auto &It : S.getAssumedSet()) { if (auto *F = dyn_cast<Function>(It.first.getValue())) OS << "@" << F->getName() << "[" << int(It.second) << "], "; else @@ -3298,7 +3539,7 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache, // Internalize non-exact functions // TODO: for now we eagerly internalize functions without calculating the // cost, we need a cost interface to determine whether internalizing - // a function is "benefitial" + // a function is "beneficial" if (AllowDeepWrapper) { unsigned FunSize = Functions.size(); for (unsigned u = 0; u < FunSize; u++) { diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 83252fec3ea8..001ef55ba472 100644 --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/IPO/Attributor.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SCCIterator.h" @@ -27,6 +28,7 @@ #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/CycleAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -43,10 +45,13 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/NoFolder.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" @@ -60,6 +65,8 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <cassert> +#include <numeric> +#include <optional> using namespace llvm; @@ -91,12 +98,6 @@ static cl::opt<int> MaxPotentialValuesIterations( "Maximum number of iterations we keep dismantling potential values."), cl::init(64)); -static cl::opt<unsigned> MaxInterferingAccesses( - "attributor-max-interfering-accesses", cl::Hidden, - cl::desc("Maximum number of interfering accesses to " - "check before assuming all might interfere."), - cl::init(6)); - STATISTIC(NumAAs, "Number of abstract attributes created"); // Some helper macros to deal with statistics tracking. 
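The hunks below drop the free function AA::getAssumedUnderlyingObjects in favour of the new AAUnderlyingObjects attribute, whose forallUnderlyingObjects takes a callback that is invoked per object and can abort the walk early (see the Pred lambdas earlier in this patch). The shape of that interface, reduced to plain C++ with stand-in types:

#include <functional>
#include <vector>

struct UnderlyingObjects {
  std::vector<int> Objects; // stand-ins for the assumed underlying objects

  // Returns false as soon as the callback rejects an object.
  bool forallUnderlyingObjects(const std::function<bool(int)> &Pred) const {
    for (int Obj : Objects)
      if (!Pred(Obj))
        return false;
    return true;
  }
};

int main() {
  UnderlyingObjects UO{{1, 2, 3}};
  bool AllPositive = UO.forallUnderlyingObjects([](int O) { return O > 0; });
  bool NoTwo = UO.forallUnderlyingObjects([](int O) { return O != 2; });
  return (AllPositive && !NoTwo) ? 0 : 1;
}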
@@ -167,7 +168,7 @@ PIPE_OPERATOR(AANoCapture) PIPE_OPERATOR(AAValueSimplify) PIPE_OPERATOR(AANoFree) PIPE_OPERATOR(AAHeapToStack) -PIPE_OPERATOR(AAReachability) +PIPE_OPERATOR(AAIntraFnReachability) PIPE_OPERATOR(AAMemoryBehavior) PIPE_OPERATOR(AAMemoryLocation) PIPE_OPERATOR(AAValueConstantRange) @@ -177,9 +178,10 @@ PIPE_OPERATOR(AAPotentialConstantValues) PIPE_OPERATOR(AAPotentialValues) PIPE_OPERATOR(AANoUndef) PIPE_OPERATOR(AACallEdges) -PIPE_OPERATOR(AAFunctionReachability) +PIPE_OPERATOR(AAInterFnReachability) PIPE_OPERATOR(AAPointerInfo) PIPE_OPERATOR(AAAssumptionInfo) +PIPE_OPERATOR(AAUnderlyingObjects) #undef PIPE_OPERATOR @@ -306,38 +308,6 @@ static Value *constructPointer(Type *ResTy, Type *PtrElemTy, Value *Ptr, return Ptr; } -bool AA::getAssumedUnderlyingObjects(Attributor &A, const Value &Ptr, - SmallSetVector<Value *, 8> &Objects, - const AbstractAttribute &QueryingAA, - const Instruction *CtxI, - bool &UsedAssumedInformation, - AA::ValueScope S, - SmallPtrSetImpl<Value *> *SeenObjects) { - SmallPtrSet<Value *, 8> LocalSeenObjects; - if (!SeenObjects) - SeenObjects = &LocalSeenObjects; - - SmallVector<AA::ValueAndContext> Values; - if (!A.getAssumedSimplifiedValues(IRPosition::value(Ptr), &QueryingAA, Values, - S, UsedAssumedInformation)) { - Objects.insert(const_cast<Value *>(&Ptr)); - return true; - } - - for (auto &VAC : Values) { - Value *UO = getUnderlyingObject(VAC.getValue()); - if (UO && UO != VAC.getValue() && SeenObjects->insert(UO).second) { - if (!getAssumedUnderlyingObjects(A, *UO, Objects, QueryingAA, - VAC.getCtxI(), UsedAssumedInformation, S, - SeenObjects)) - return false; - continue; - } - Objects.insert(VAC.getValue()); - } - return true; -} - static const Value * stripAndAccumulateOffsets(Attributor &A, const AbstractAttribute &QueryingAA, const Value *Val, const DataLayout &DL, APInt &Offset, @@ -401,7 +371,7 @@ static void clampReturnedValueStates( // Use an optional state as there might not be any return values and we want // to join (IntegerState::operator&) the state of all there are. - Optional<StateType> T; + std::optional<StateType> T; // Callback for each possibly returned value. auto CheckReturnValue = [&](Value &RV) -> bool { @@ -460,7 +430,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, // Use an optional state as there might not be any return values and we want // to join (IntegerState::operator&) the state of all there are. - Optional<StateType> T; + std::optional<StateType> T; // The argument number which is also the call site argument number. unsigned ArgNo = QueryingAA.getIRPosition().getCallSiteArgNo(); @@ -707,7 +677,7 @@ struct State; } // namespace PointerInfo } // namespace AA -/// Helper for AA::PointerInfo::Acccess DenseMap/Set usage. +/// Helper for AA::PointerInfo::Access DenseMap/Set usage. template <> struct DenseMapInfo<AAPointerInfo::Access> : DenseMapInfo<Instruction *> { using Access = AAPointerInfo::Access; @@ -717,12 +687,30 @@ struct DenseMapInfo<AAPointerInfo::Access> : DenseMapInfo<Instruction *> { static bool isEqual(const Access &LHS, const Access &RHS); }; -/// Helper that allows OffsetAndSize as a key in a DenseMap. -template <> -struct DenseMapInfo<AAPointerInfo ::OffsetAndSize> - : DenseMapInfo<std::pair<int64_t, int64_t>> {}; +/// Helper that allows RangeTy as a key in a DenseMap. 
+template <> struct DenseMapInfo<AA::RangeTy> { + static inline AA::RangeTy getEmptyKey() { + auto EmptyKey = DenseMapInfo<int64_t>::getEmptyKey(); + return AA::RangeTy{EmptyKey, EmptyKey}; + } + + static inline AA::RangeTy getTombstoneKey() { + auto TombstoneKey = DenseMapInfo<int64_t>::getTombstoneKey(); + return AA::RangeTy{TombstoneKey, TombstoneKey}; + } + + static unsigned getHashValue(const AA::RangeTy &Range) { + return detail::combineHashValue( + DenseMapInfo<int64_t>::getHashValue(Range.Offset), + DenseMapInfo<int64_t>::getHashValue(Range.Size)); + } + + static bool isEqual(const AA::RangeTy &A, const AA::RangeTy B) { + return A == B; + } +}; -/// Helper for AA::PointerInfo::Acccess DenseMap/Set usage ignoring everythign +/// Helper for AA::PointerInfo::Access DenseMap/Set usage ignoring everythign /// but the instruction struct AccessAsInstructionInfo : DenseMapInfo<Instruction *> { using Base = DenseMapInfo<Instruction *>; @@ -737,13 +725,6 @@ struct AccessAsInstructionInfo : DenseMapInfo<Instruction *> { /// A type to track pointer/struct usage and accesses for AAPointerInfo. struct AA::PointerInfo::State : public AbstractState { - - ~State() { - // We do not delete the Accesses objects but need to destroy them still. - for (auto &It : AccessBins) - It.second->~Accesses(); - } - /// Return the best possible representable state. static State getBestState(const State &SIS) { return State(); } @@ -755,9 +736,7 @@ struct AA::PointerInfo::State : public AbstractState { } State() = default; - State(State &&SIS) : AccessBins(std::move(SIS.AccessBins)) { - SIS.AccessBins.clear(); - } + State(State &&SIS) = default; const State &getAssumed() const { return *this; } @@ -783,7 +762,9 @@ struct AA::PointerInfo::State : public AbstractState { if (this == &R) return *this; BS = R.BS; - AccessBins = R.AccessBins; + AccessList = R.AccessList; + OffsetBins = R.OffsetBins; + RemoteIMap = R.RemoteIMap; return *this; } @@ -791,114 +772,69 @@ struct AA::PointerInfo::State : public AbstractState { if (this == &R) return *this; std::swap(BS, R.BS); - std::swap(AccessBins, R.AccessBins); + std::swap(AccessList, R.AccessList); + std::swap(OffsetBins, R.OffsetBins); + std::swap(RemoteIMap, R.RemoteIMap); return *this; } - bool operator==(const State &R) const { - if (BS != R.BS) - return false; - if (AccessBins.size() != R.AccessBins.size()) - return false; - auto It = begin(), RIt = R.begin(), E = end(); - while (It != E) { - if (It->getFirst() != RIt->getFirst()) - return false; - auto &Accs = It->getSecond(); - auto &RAccs = RIt->getSecond(); - if (Accs->size() != RAccs->size()) - return false; - for (const auto &ZipIt : llvm::zip(*Accs, *RAccs)) - if (std::get<0>(ZipIt) != std::get<1>(ZipIt)) - return false; - ++It; - ++RIt; - } - return true; - } - bool operator!=(const State &R) const { return !(*this == R); } - - /// We store accesses in a set with the instruction as key. 
- struct Accesses { - SmallVector<AAPointerInfo::Access, 4> Accesses; - DenseMap<const Instruction *, unsigned> Map; - - unsigned size() const { return Accesses.size(); } - - using vec_iterator = decltype(Accesses)::iterator; - vec_iterator begin() { return Accesses.begin(); } - vec_iterator end() { return Accesses.end(); } - - using iterator = decltype(Map)::const_iterator; - iterator find(AAPointerInfo::Access &Acc) { - return Map.find(Acc.getRemoteInst()); - } - iterator find_end() { return Map.end(); } - - AAPointerInfo::Access &get(iterator &It) { - return Accesses[It->getSecond()]; - } + /// Add a new Access to the state at offset \p Offset and with size \p Size. + /// The access is associated with \p I, writes \p Content (if anything), and + /// is of kind \p Kind. If an Access already exists for the same \p I and same + /// \p RemoteI, the two are combined, potentially losing information about + /// offset and size. The resulting access must now be moved from its original + /// OffsetBin to the bin for its new offset. + /// + /// \Returns CHANGED, if the state changed, UNCHANGED otherwise. + ChangeStatus addAccess(Attributor &A, const AAPointerInfo::RangeList &Ranges, + Instruction &I, std::optional<Value *> Content, + AAPointerInfo::AccessKind Kind, Type *Ty, + Instruction *RemoteI = nullptr); - void insert(AAPointerInfo::Access &Acc) { - Map[Acc.getRemoteInst()] = Accesses.size(); - Accesses.push_back(Acc); - } - }; + using OffsetBinsTy = DenseMap<RangeTy, SmallSet<unsigned, 4>>; - /// We store all accesses in bins denoted by their offset and size. - using AccessBinsTy = DenseMap<AAPointerInfo::OffsetAndSize, Accesses *>; + using const_bin_iterator = OffsetBinsTy::const_iterator; + const_bin_iterator begin() const { return OffsetBins.begin(); } + const_bin_iterator end() const { return OffsetBins.end(); } - AccessBinsTy::const_iterator begin() const { return AccessBins.begin(); } - AccessBinsTy::const_iterator end() const { return AccessBins.end(); } + const AAPointerInfo::Access &getAccess(unsigned Index) const { + return AccessList[Index]; + } protected: - /// The bins with all the accesses for the associated pointer. - AccessBinsTy AccessBins; - - /// Add a new access to the state at offset \p Offset and with size \p Size. - /// The access is associated with \p I, writes \p Content (if anything), and - /// is of kind \p Kind. - /// \Returns CHANGED, if the state changed, UNCHANGED otherwise. - ChangeStatus addAccess(Attributor &A, int64_t Offset, int64_t Size, - Instruction &I, Optional<Value *> Content, - AAPointerInfo::AccessKind Kind, Type *Ty, - Instruction *RemoteI = nullptr, - Accesses *BinPtr = nullptr) { - AAPointerInfo::OffsetAndSize Key{Offset, Size}; - Accesses *&Bin = BinPtr ? BinPtr : AccessBins[Key]; - if (!Bin) - Bin = new (A.Allocator) Accesses; - AAPointerInfo::Access Acc(&I, RemoteI ? RemoteI : &I, Content, Kind, Ty); - // Check if we have an access for this instruction in this bin, if not, - // simply add it. - auto It = Bin->find(Acc); - if (It == Bin->find_end()) { - Bin->insert(Acc); - return ChangeStatus::CHANGED; - } - // If the existing access is the same as then new one, nothing changed. - AAPointerInfo::Access &Current = Bin->get(It); - AAPointerInfo::Access Before = Current; - // The new one will be combined with the existing one. - Current &= Acc; - return Current == Before ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; - } + // Every memory instruction results in an Access object. 
We maintain a list of + // all Access objects that we own, along with the following maps: + // + // - OffsetBins: RangeTy -> { Access } + // - RemoteIMap: RemoteI x LocalI -> Access + // + // A RemoteI is any instruction that accesses memory. RemoteI is different + // from LocalI if and only if LocalI is a call; then RemoteI is some + // instruction in the callgraph starting from LocalI. Multiple paths in the + // callgraph from LocalI to RemoteI may produce multiple accesses, but these + // are all combined into a single Access object. This may result in loss of + // information in RangeTy in the Access object. + SmallVector<AAPointerInfo::Access> AccessList; + OffsetBinsTy OffsetBins; + DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap; /// See AAPointerInfo::forallInterferingAccesses. bool forallInterferingAccesses( - AAPointerInfo::OffsetAndSize OAS, + AA::RangeTy Range, function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const { if (!isValidState()) return false; - for (auto &It : AccessBins) { - AAPointerInfo::OffsetAndSize ItOAS = It.getFirst(); - if (!OAS.mayOverlap(ItOAS)) + for (const auto &It : OffsetBins) { + AA::RangeTy ItRange = It.getFirst(); + if (!Range.mayOverlap(ItRange)) continue; - bool IsExact = OAS == ItOAS && !OAS.offsetOrSizeAreUnknown(); - for (auto &Access : *It.getSecond()) + bool IsExact = Range == ItRange && !Range.offsetOrSizeAreUnknown(); + for (auto Index : It.getSecond()) { + auto &Access = AccessList[Index]; if (!CB(Access, IsExact)) return false; + } } return true; } @@ -906,29 +842,24 @@ protected: /// See AAPointerInfo::forallInterferingAccesses. bool forallInterferingAccesses( Instruction &I, - function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const { + function_ref<bool(const AAPointerInfo::Access &, bool)> CB, + AA::RangeTy &Range) const { if (!isValidState()) return false; - // First find the offset and size of I. - AAPointerInfo::OffsetAndSize OAS(-1, -1); - for (auto &It : AccessBins) { - for (auto &Access : *It.getSecond()) { - if (Access.getRemoteInst() == &I) { - OAS = It.getFirst(); + auto LocalList = RemoteIMap.find(&I); + if (LocalList == RemoteIMap.end()) { + return true; + } + + for (unsigned Index : LocalList->getSecond()) { + for (auto &R : AccessList[Index]) { + Range &= R; + if (Range.offsetOrSizeAreUnknown()) break; - } } - if (OAS.getSize() != -1) - break; } - // No access for I was found, we are done. - if (OAS.getSize() == -1) - return true; - - // Now that we have an offset and size, find all overlapping ones and use - // the callback on the accesses. - return forallInterferingAccesses(OAS, CB); + return forallInterferingAccesses(Range, CB); } private: @@ -936,7 +867,144 @@ private: BooleanState BS; }; +ChangeStatus AA::PointerInfo::State::addAccess( + Attributor &A, const AAPointerInfo::RangeList &Ranges, Instruction &I, + std::optional<Value *> Content, AAPointerInfo::AccessKind Kind, Type *Ty, + Instruction *RemoteI) { + RemoteI = RemoteI ? RemoteI : &I; + + // Check if we have an access for this instruction, if not, simply add it. 
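The comment above describes the new bookkeeping: every access is stored once in AccessList, OffsetBins maps a range to the indices of accesses touching it, and interference queries only walk bins that may overlap. A simplified standalone sketch of that layout, using plain STL containers and hypothetical Range/AccessRecord/PointerState names rather than the Attributor types:

#include <cstdint>
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <vector>

struct Range {
  int64_t Offset, Size;
  bool operator<(const Range &R) const {
    return std::tie(Offset, Size) < std::tie(R.Offset, R.Size);
  }
  bool overlaps(const Range &R) const {
    return Offset < R.Offset + R.Size && R.Offset < Offset + Size;
  }
};

struct AccessRecord {
  std::string Desc;
  Range R;
};

struct PointerState {
  std::vector<AccessRecord> AccessList;           // every access, stored once
  std::map<Range, std::set<unsigned>> OffsetBins; // range -> indices into AccessList

  void addAccess(const AccessRecord &Acc) {
    unsigned Index = static_cast<unsigned>(AccessList.size());
    AccessList.push_back(Acc);
    OffsetBins[Acc.R].insert(Index);
  }

  // Visit every access whose bin may overlap the queried range.
  template <typename CallbackTy>
  void forallInterfering(Range Query, CallbackTy Callback) const {
    for (const auto &It : OffsetBins) {
      if (!Query.overlaps(It.first))
        continue;
      for (unsigned Index : It.second)
        Callback(AccessList[Index]);
    }
  }
};

int main() {
  PointerState S;
  S.addAccess({"store i32 at offset 0", {0, 4}});
  S.addAccess({"store i64 at offset 8", {8, 8}});
  S.forallInterfering({4, 8}, [](const AccessRecord &A) {
    std::printf("interferes: %s\n", A.Desc.c_str());
  });
}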
+ auto &LocalList = RemoteIMap[RemoteI]; + bool AccExists = false; + unsigned AccIndex = AccessList.size(); + for (auto Index : LocalList) { + auto &A = AccessList[Index]; + if (A.getLocalInst() == &I) { + AccExists = true; + AccIndex = Index; + break; + } + } + + auto AddToBins = [&](const AAPointerInfo::RangeList &ToAdd) { + LLVM_DEBUG( + if (ToAdd.size()) + dbgs() << "[AAPointerInfo] Inserting access in new offset bins\n"; + ); + + for (auto Key : ToAdd) { + LLVM_DEBUG(dbgs() << " key " << Key << "\n"); + OffsetBins[Key].insert(AccIndex); + } + }; + + if (!AccExists) { + AccessList.emplace_back(&I, RemoteI, Ranges, Content, Kind, Ty); + assert((AccessList.size() == AccIndex + 1) && + "New Access should have been at AccIndex"); + LocalList.push_back(AccIndex); + AddToBins(AccessList[AccIndex].getRanges()); + return ChangeStatus::CHANGED; + } + + // Combine the new Access with the existing Access, and then update the + // mapping in the offset bins. + AAPointerInfo::Access Acc(&I, RemoteI, Ranges, Content, Kind, Ty); + auto &Current = AccessList[AccIndex]; + auto Before = Current; + Current &= Acc; + if (Current == Before) + return ChangeStatus::UNCHANGED; + + auto &ExistingRanges = Before.getRanges(); + auto &NewRanges = Current.getRanges(); + + // Ranges that are in the old access but not the new access need to be removed + // from the offset bins. + AAPointerInfo::RangeList ToRemove; + AAPointerInfo::RangeList::set_difference(ExistingRanges, NewRanges, ToRemove); + LLVM_DEBUG( + if (ToRemove.size()) + dbgs() << "[AAPointerInfo] Removing access from old offset bins\n"; + ); + + for (auto Key : ToRemove) { + LLVM_DEBUG(dbgs() << " key " << Key << "\n"); + assert(OffsetBins.count(Key) && "Existing Access must be in some bin."); + auto &Bin = OffsetBins[Key]; + assert(Bin.count(AccIndex) && + "Expected bin to actually contain the Access."); + Bin.erase(AccIndex); + } + + // Ranges that are in the new access but not the old access need to be added + // to the offset bins. + AAPointerInfo::RangeList ToAdd; + AAPointerInfo::RangeList::set_difference(NewRanges, ExistingRanges, ToAdd); + AddToBins(ToAdd); + return ChangeStatus::CHANGED; +} + namespace { + +/// A helper containing a list of offsets computed for a Use. Ideally this +/// list should be strictly ascending, but we ensure that only when we +/// actually translate the list of offsets to a RangeList. +struct OffsetInfo { + using VecTy = SmallVector<int64_t>; + using const_iterator = VecTy::const_iterator; + VecTy Offsets; + + const_iterator begin() const { return Offsets.begin(); } + const_iterator end() const { return Offsets.end(); } + + bool operator==(const OffsetInfo &RHS) const { + return Offsets == RHS.Offsets; + } + + bool operator!=(const OffsetInfo &RHS) const { return !(*this == RHS); } + + void insert(int64_t Offset) { Offsets.push_back(Offset); } + bool isUnassigned() const { return Offsets.size() == 0; } + + bool isUnknown() const { + if (isUnassigned()) + return false; + if (Offsets.size() == 1) + return Offsets.front() == AA::RangeTy::Unknown; + return false; + } + + void setUnknown() { + Offsets.clear(); + Offsets.push_back(AA::RangeTy::Unknown); + } + + void addToAll(int64_t Inc) { + for (auto &Offset : Offsets) { + Offset += Inc; + } + } + + /// Copy offsets from \p R into the current list. + /// + /// Ideally all lists should be strictly ascending, but we defer that to the + /// actual use of the list. So we just blindly append here. 
+ void merge(const OffsetInfo &R) { Offsets.append(R.Offsets); } +}; + +#ifndef NDEBUG +static raw_ostream &operator<<(raw_ostream &OS, const OffsetInfo &OI) { + ListSeparator LS; + OS << "["; + for (auto Offset : OI) { + OS << LS << Offset; + } + OS << "]"; + return OS; +} +#endif // NDEBUG + struct AAPointerInfoImpl : public StateWrapper<AA::PointerInfo::State, AAPointerInfo> { using BaseTy = StateWrapper<AA::PointerInfo::State, AAPointerInfo>; @@ -946,7 +1014,7 @@ struct AAPointerInfoImpl const std::string getAsStr() const override { return std::string("PointerInfo ") + (isValidState() ? (std::string("#") + - std::to_string(AccessBins.size()) + " bins") + std::to_string(OffsetBins.size()) + " bins") : "<invalid>"); } @@ -956,17 +1024,16 @@ struct AAPointerInfoImpl } bool forallInterferingAccesses( - OffsetAndSize OAS, + AA::RangeTy Range, function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const override { - return State::forallInterferingAccesses(OAS, CB); + return State::forallInterferingAccesses(Range, CB); } - bool - forallInterferingAccesses(Attributor &A, const AbstractAttribute &QueryingAA, - Instruction &I, - function_ref<bool(const Access &, bool)> UserCB, - bool &HasBeenWrittenTo) const override { + bool forallInterferingAccesses( + Attributor &A, const AbstractAttribute &QueryingAA, Instruction &I, + function_ref<bool(const Access &, bool)> UserCB, bool &HasBeenWrittenTo, + AA::RangeTy &Range) const override { HasBeenWrittenTo = false; SmallPtrSet<const Access *, 8> DominatingWrites; @@ -977,25 +1044,43 @@ struct AAPointerInfoImpl QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL); const auto *ExecDomainAA = A.lookupAAFor<AAExecutionDomain>( IRPosition::function(Scope), &QueryingAA, DepClassTy::OPTIONAL); - const bool NoSync = NoSyncAA.isAssumedNoSync(); + bool AllInSameNoSyncFn = NoSyncAA.isAssumedNoSync(); + bool InstIsExecutedByInitialThreadOnly = + ExecDomainAA && ExecDomainAA->isExecutedByInitialThreadOnly(I); + bool InstIsExecutedInAlignedRegion = + ExecDomainAA && ExecDomainAA->isExecutedInAlignedRegion(A, I); + + InformationCache &InfoCache = A.getInfoCache(); + bool IsThreadLocalObj = + AA::isAssumedThreadLocalObject(A, getAssociatedValue(), *this); // Helper to determine if we need to consider threading, which we cannot // right now. However, if the function is (assumed) nosync or the thread // executing all instructions is the main thread only we can ignore - // threading. - auto CanIgnoreThreading = [&](const Instruction &I) -> bool { - if (NoSync) + // threading. Also, thread-local objects do not require threading reasoning. + // Finally, we can ignore threading if either access is executed in an + // aligned region. + auto CanIgnoreThreadingForInst = [&](const Instruction &I) -> bool { + if (IsThreadLocalObj || AllInSameNoSyncFn) + return true; + if (!ExecDomainAA) + return false; + if (InstIsExecutedInAlignedRegion || + ExecDomainAA->isExecutedInAlignedRegion(A, I)) return true; - if (ExecDomainAA && ExecDomainAA->isExecutedByInitialThreadOnly(I)) + if (InstIsExecutedByInitialThreadOnly && + ExecDomainAA->isExecutedByInitialThreadOnly(I)) return true; return false; }; // Helper to determine if the access is executed by the same thread as the - // load, for now it is sufficient to avoid any potential threading effects - // as we cannot deal with them anyway. 
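OffsetInfo above keeps the set of candidate byte offsets for a use, with a single sentinel value standing in for an unknown offset. A minimal standalone sketch of that bookkeeping, assuming a toy Offsets struct; the Unknown sentinel plays the role of AA::RangeTy::Unknown, and deduplication is deferred to the consumer, as in the patch:

#include <cstdint>
#include <cstdio>
#include <vector>

struct Offsets {
  static constexpr int64_t Unknown = -1;
  std::vector<int64_t> Values;

  bool isUnknown() const { return Values.size() == 1 && Values[0] == Unknown; }
  void setUnknown() { Values.assign(1, Unknown); }
  void addToAll(int64_t Inc) {
    for (int64_t &V : Values)
      V += Inc;
  }
  // Blindly append; sorting and deduplication happen later, at the point of
  // use, as noted in the comment above.
  void merge(const Offsets &R) {
    Values.insert(Values.end(), R.Values.begin(), R.Values.end());
  }
};

int main() {
  Offsets O;
  O.Values = {0, 16};
  O.addToAll(8);  // e.g. a GEP adding a constant 8 to every candidate
  Offsets Other;
  Other.Values = {4};
  O.merge(Other); // e.g. a second incoming value of a PHI
  for (int64_t V : O.Values)
    std::printf("candidate offset %lld\n", static_cast<long long>(V));
}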
- auto IsSameThreadAsLoad = [&](const Access &Acc) -> bool { - return CanIgnoreThreading(*Acc.getLocalInst()); + // given instruction, for now it is sufficient to avoid any potential + // threading effects as we cannot deal with them anyway. + auto CanIgnoreThreading = [&](const Access &Acc) -> bool { + return CanIgnoreThreadingForInst(*Acc.getRemoteInst()) || + (Acc.getRemoteInst() != Acc.getLocalInst() && + CanIgnoreThreadingForInst(*Acc.getLocalInst())); }; // TODO: Use inter-procedural reachability and dominance. @@ -1006,19 +1091,9 @@ struct AAPointerInfoImpl const bool FindInterferingReads = I.mayWriteToMemory(); const bool UseDominanceReasoning = FindInterferingWrites && NoRecurseAA.isKnownNoRecurse(); - const bool CanUseCFGResoning = CanIgnoreThreading(I); - InformationCache &InfoCache = A.getInfoCache(); const DominatorTree *DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(Scope); - enum GPUAddressSpace : unsigned { - Generic = 0, - Global = 1, - Shared = 3, - Constant = 4, - Local = 5, - }; - // Helper to check if a value has "kernel lifetime", that is it will not // outlive a GPU kernel. This is true for shared, constant, and local // globals on AMD and NVIDIA GPUs. @@ -1026,10 +1101,10 @@ struct AAPointerInfoImpl Triple T(M.getTargetTriple()); if (!(T.isAMDGPU() || T.isNVPTX())) return false; - switch (V->getType()->getPointerAddressSpace()) { - case GPUAddressSpace::Shared: - case GPUAddressSpace::Constant: - case GPUAddressSpace::Local: + switch (AA::GPUAddressSpace(V->getType()->getPointerAddressSpace())) { + case AA::GPUAddressSpace::Shared: + case AA::GPUAddressSpace::Constant: + case AA::GPUAddressSpace::Local: return true; default: return false; @@ -1061,72 +1136,121 @@ struct AAPointerInfoImpl }; } + // Set of accesses/instructions that will overwrite the result and are + // therefore blockers in the reachability traversal. + AA::InstExclusionSetTy ExclusionSet; + auto AccessCB = [&](const Access &Acc, bool Exact) { - if ((!FindInterferingWrites || !Acc.isWrite()) && + if (Exact && Acc.isMustAccess() && Acc.getRemoteInst() != &I) { + if (Acc.isWrite() || (isa<LoadInst>(I) && Acc.isWriteOrAssumption())) + ExclusionSet.insert(Acc.getRemoteInst()); + } + + if ((!FindInterferingWrites || !Acc.isWriteOrAssumption()) && (!FindInterferingReads || !Acc.isRead())) return true; - bool Dominates = DT && Exact && Acc.isMustAccess() && - (Acc.getLocalInst()->getFunction() == &Scope) && + bool Dominates = FindInterferingWrites && DT && Exact && + Acc.isMustAccess() && + (Acc.getRemoteInst()->getFunction() == &Scope) && DT->dominates(Acc.getRemoteInst(), &I); - if (FindInterferingWrites && Dominates) - HasBeenWrittenTo = true; - - // For now we only filter accesses based on CFG reasoning which does not - // work yet if we have threading effects, or the access is complicated. - if (CanUseCFGResoning && Dominates && UseDominanceReasoning && - IsSameThreadAsLoad(Acc)) + if (Dominates) DominatingWrites.insert(&Acc); + // Track if all interesting accesses are in the same `nosync` function as + // the given instruction. 
+ AllInSameNoSyncFn &= Acc.getRemoteInst()->getFunction() == &Scope; + InterferingAccesses.push_back({&Acc, Exact}); return true; }; - if (!State::forallInterferingAccesses(I, AccessCB)) + if (!State::forallInterferingAccesses(I, AccessCB, Range)) return false; - if (HasBeenWrittenTo) { - const Function *ScopePtr = &Scope; - IsLiveInCalleeCB = [ScopePtr](const Function &Fn) { - return ScopePtr != &Fn; - }; + HasBeenWrittenTo = !DominatingWrites.empty(); + + // Dominating writes form a chain, find the least/lowest member. + Instruction *LeastDominatingWriteInst = nullptr; + for (const Access *Acc : DominatingWrites) { + if (!LeastDominatingWriteInst) { + LeastDominatingWriteInst = Acc->getRemoteInst(); + } else if (DT->dominates(LeastDominatingWriteInst, + Acc->getRemoteInst())) { + LeastDominatingWriteInst = Acc->getRemoteInst(); + } } - // Helper to determine if we can skip a specific write access. This is in - // the worst case quadratic as we are looking for another write that will - // hide the effect of this one. + // Helper to determine if we can skip a specific write access. auto CanSkipAccess = [&](const Access &Acc, bool Exact) { - if ((!Acc.isWrite() || - !AA::isPotentiallyReachable(A, *Acc.getLocalInst(), I, QueryingAA, - IsLiveInCalleeCB)) && - (!Acc.isRead() || - !AA::isPotentiallyReachable(A, I, *Acc.getLocalInst(), QueryingAA, - IsLiveInCalleeCB))) + if (!CanIgnoreThreading(Acc)) + return false; + + // Check read (RAW) dependences and write (WAR) dependences as necessary. + // If we successfully excluded all effects we are interested in, the + // access can be skipped. + bool ReadChecked = !FindInterferingReads; + bool WriteChecked = !FindInterferingWrites; + + // If the instruction cannot reach the access, the former does not + // interfere with what the access reads. + if (!ReadChecked) { + if (!AA::isPotentiallyReachable(A, I, *Acc.getRemoteInst(), QueryingAA, + &ExclusionSet, IsLiveInCalleeCB)) + ReadChecked = true; + } + // If the instruction cannot be reach from the access, the latter does not + // interfere with what the instruction reads. + if (!WriteChecked) { + if (!AA::isPotentiallyReachable(A, *Acc.getRemoteInst(), I, QueryingAA, + &ExclusionSet, IsLiveInCalleeCB)) + WriteChecked = true; + } + + // If we still might be affected by the write of the access but there are + // dominating writes in the function of the instruction + // (HasBeenWrittenTo), we can try to reason that the access is overwritten + // by them. This would have happend above if they are all in the same + // function, so we only check the inter-procedural case. Effectively, we + // want to show that there is no call after the dominting write that might + // reach the access, and when it returns reach the instruction with the + // updated value. To this end, we iterate all call sites, check if they + // might reach the instruction without going through another access + // (ExclusionSet) and at the same time might reach the access. However, + // that is all part of AAInterFnReachability. + if (!WriteChecked && HasBeenWrittenTo && + Acc.getRemoteInst()->getFunction() != &Scope) { + + const auto &FnReachabilityAA = A.getAAFor<AAInterFnReachability>( + QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL); + + // Without going backwards in the call tree, can we reach the access + // from the least dominating write. Do not allow to pass the instruction + // itself either. 
+ bool Inserted = ExclusionSet.insert(&I).second; + + if (!FnReachabilityAA.instructionCanReach( + A, *LeastDominatingWriteInst, + *Acc.getRemoteInst()->getFunction(), &ExclusionSet)) + WriteChecked = true; + + if (Inserted) + ExclusionSet.erase(&I); + } + + if (ReadChecked && WriteChecked) return true; if (!DT || !UseDominanceReasoning) return false; - if (!IsSameThreadAsLoad(Acc)) - return false; if (!DominatingWrites.count(&Acc)) return false; - for (const Access *DomAcc : DominatingWrites) { - assert(Acc.getLocalInst()->getFunction() == - DomAcc->getLocalInst()->getFunction() && - "Expected dominating writes to be in the same function!"); - - if (DomAcc != &Acc && - DT->dominates(Acc.getLocalInst(), DomAcc->getLocalInst())) { - return true; - } - } - return false; + return LeastDominatingWriteInst != Acc.getRemoteInst(); }; - // Run the user callback on all accesses we cannot skip and return if that - // succeeded for all or not. - unsigned NumInterferingAccesses = InterferingAccesses.size(); + // Run the user callback on all accesses we cannot skip and return if + // that succeeded for all or not. for (auto &It : InterferingAccesses) { - if (NumInterferingAccesses > MaxInterferingAccesses || + if ((!AllInSameNoSyncFn && !IsThreadLocalObj && !ExecDomainAA) || !CanSkipAccess(*It.first, It.second)) { if (!UserCB(*It.first, It.second)) return false; @@ -1135,40 +1259,63 @@ struct AAPointerInfoImpl return true; } - ChangeStatus translateAndAddState(Attributor &A, const AAPointerInfo &OtherAA, - int64_t Offset, CallBase &CB, - bool FromCallee = false) { + ChangeStatus translateAndAddStateFromCallee(Attributor &A, + const AAPointerInfo &OtherAA, + CallBase &CB) { using namespace AA::PointerInfo; if (!OtherAA.getState().isValidState() || !isValidState()) return indicatePessimisticFixpoint(); const auto &OtherAAImpl = static_cast<const AAPointerInfoImpl &>(OtherAA); - bool IsByval = - FromCallee && OtherAAImpl.getAssociatedArgument()->hasByValAttr(); + bool IsByval = OtherAAImpl.getAssociatedArgument()->hasByValAttr(); // Combine the accesses bin by bin. ChangeStatus Changed = ChangeStatus::UNCHANGED; - for (auto &It : OtherAAImpl.getState()) { - OffsetAndSize OAS = OffsetAndSize::getUnknown(); - if (Offset != OffsetAndSize::Unknown) - OAS = OffsetAndSize(It.first.getOffset() + Offset, It.first.getSize()); - Accesses *Bin = AccessBins.lookup(OAS); - for (const AAPointerInfo::Access &RAcc : *It.second) { + const auto &State = OtherAAImpl.getState(); + for (const auto &It : State) { + for (auto Index : It.getSecond()) { + const auto &RAcc = State.getAccess(Index); if (IsByval && !RAcc.isRead()) continue; bool UsedAssumedInformation = false; AccessKind AK = RAcc.getKind(); - Optional<Value *> Content = RAcc.getContent(); - if (FromCallee) { - Content = A.translateArgumentToCallSiteContent( - RAcc.getContent(), CB, *this, UsedAssumedInformation); - AK = - AccessKind(AK & (IsByval ? AccessKind::AK_R : AccessKind::AK_RW)); - AK = AccessKind(AK | (RAcc.isMayAccess() ? AK_MAY : AK_MUST)); + auto Content = A.translateArgumentToCallSiteContent( + RAcc.getContent(), CB, *this, UsedAssumedInformation); + AK = AccessKind(AK & (IsByval ? AccessKind::AK_R : AccessKind::AK_RW)); + AK = AccessKind(AK | (RAcc.isMayAccess() ? 
AK_MAY : AK_MUST)); + + Changed |= addAccess(A, RAcc.getRanges(), CB, Content, AK, + RAcc.getType(), RAcc.getRemoteInst()); + } + } + return Changed; + } + + ChangeStatus translateAndAddState(Attributor &A, const AAPointerInfo &OtherAA, + const OffsetInfo &Offsets, CallBase &CB) { + using namespace AA::PointerInfo; + if (!OtherAA.getState().isValidState() || !isValidState()) + return indicatePessimisticFixpoint(); + + const auto &OtherAAImpl = static_cast<const AAPointerInfoImpl &>(OtherAA); + + // Combine the accesses bin by bin. + ChangeStatus Changed = ChangeStatus::UNCHANGED; + const auto &State = OtherAAImpl.getState(); + for (const auto &It : State) { + for (auto Index : It.getSecond()) { + const auto &RAcc = State.getAccess(Index); + for (auto Offset : Offsets) { + auto NewRanges = Offset == AA::RangeTy::Unknown + ? AA::RangeTy::getUnknown() + : RAcc.getRanges(); + if (!NewRanges.isUnknown()) { + NewRanges.addToAllOffsets(Offset); + } + Changed |= + addAccess(A, NewRanges, CB, RAcc.getContent(), RAcc.getKind(), + RAcc.getType(), RAcc.getRemoteInst()); } - Changed = - Changed | addAccess(A, OAS.getOffset(), OAS.getSize(), CB, Content, - AK, RAcc.getType(), RAcc.getRemoteInst(), Bin); } } return Changed; @@ -1180,11 +1327,11 @@ struct AAPointerInfoImpl /// Dump the state into \p O. void dumpState(raw_ostream &O) { - for (auto &It : AccessBins) { - O << "[" << It.first.getOffset() << "-" - << It.first.getOffset() + It.first.getSize() - << "] : " << It.getSecond()->size() << "\n"; - for (auto &Acc : *It.getSecond()) { + for (auto &It : OffsetBins) { + O << "[" << It.first.Offset << "-" << It.first.Offset + It.first.Size + << "] : " << It.getSecond().size() << "\n"; + for (auto AccIndex : It.getSecond()) { + auto &Acc = AccessList[AccIndex]; O << " - " << Acc.getKind() << " - " << *Acc.getLocalInst() << "\n"; if (Acc.getLocalInst() != Acc.getRemoteInst()) O << " --> " << *Acc.getRemoteInst() @@ -1206,245 +1353,478 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { : AAPointerInfoImpl(IRP, A) {} /// Deal with an access and signal if it was handled successfully. - bool handleAccess(Attributor &A, Instruction &I, Value &Ptr, - Optional<Value *> Content, AccessKind Kind, int64_t Offset, - ChangeStatus &Changed, Type *Ty, - int64_t Size = OffsetAndSize::Unknown) { + bool handleAccess(Attributor &A, Instruction &I, + std::optional<Value *> Content, AccessKind Kind, + SmallVectorImpl<int64_t> &Offsets, ChangeStatus &Changed, + Type &Ty) { using namespace AA::PointerInfo; - // No need to find a size if one is given or the offset is unknown. 
- if (Offset != OffsetAndSize::Unknown && Size == OffsetAndSize::Unknown && - Ty) { - const DataLayout &DL = A.getDataLayout(); - TypeSize AccessSize = DL.getTypeStoreSize(Ty); - if (!AccessSize.isScalable()) - Size = AccessSize.getFixedSize(); - } - Changed = Changed | addAccess(A, Offset, Size, I, Content, Kind, Ty); + auto Size = AA::RangeTy::Unknown; + const DataLayout &DL = A.getDataLayout(); + TypeSize AccessSize = DL.getTypeStoreSize(&Ty); + if (!AccessSize.isScalable()) + Size = AccessSize.getFixedValue(); + + // Make a strictly ascending list of offsets as required by addAccess() + llvm::sort(Offsets); + auto *Last = std::unique(Offsets.begin(), Offsets.end()); + Offsets.erase(Last, Offsets.end()); + + VectorType *VT = dyn_cast<VectorType>(&Ty); + if (!VT || VT->getElementCount().isScalable() || + !Content.value_or(nullptr) || !isa<Constant>(*Content) || + (*Content)->getType() != VT || + DL.getTypeStoreSize(VT->getElementType()).isScalable()) { + Changed = Changed | addAccess(A, {Offsets, Size}, I, Content, Kind, &Ty); + } else { + // Handle vector stores with constant content element-wise. + // TODO: We could look for the elements or create instructions + // representing them. + // TODO: We need to push the Content into the range abstraction + // (AA::RangeTy) to allow different content values for different + // ranges. ranges. Hence, support vectors storing different values. + Type *ElementType = VT->getElementType(); + int64_t ElementSize = DL.getTypeStoreSize(ElementType).getFixedValue(); + auto *ConstContent = cast<Constant>(*Content); + Type *Int32Ty = Type::getInt32Ty(ElementType->getContext()); + SmallVector<int64_t> ElementOffsets(Offsets.begin(), Offsets.end()); + + for (int i = 0, e = VT->getElementCount().getFixedValue(); i != e; ++i) { + Value *ElementContent = ConstantExpr::getExtractElement( + ConstContent, ConstantInt::get(Int32Ty, i)); + + // Add the element access. + Changed = Changed | addAccess(A, {ElementOffsets, ElementSize}, I, + ElementContent, Kind, ElementType); + + // Advance the offsets for the next element. + for (auto &ElementOffset : ElementOffsets) + ElementOffset += ElementSize; + } + } return true; }; - /// Helper struct, will support ranges eventually. - struct OffsetInfo { - int64_t Offset = OffsetAndSize::Unknown; + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override; - bool operator==(const OffsetInfo &OI) const { return Offset == OI.Offset; } - }; + /// If the indices to \p GEP can be traced to constants, incorporate all + /// of these into \p UsrOI. + /// + /// \return true iff \p UsrOI is updated. + bool collectConstantsForGEP(Attributor &A, const DataLayout &DL, + OffsetInfo &UsrOI, const OffsetInfo &PtrOI, + const GEPOperator *GEP); - /// See AbstractAttribute::updateImpl(...). 
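The new handleAccess above splits a vector store with constant content into one access per element, advancing every candidate offset by the element size on each iteration. A standalone sketch of just that offset arithmetic, with illustrative numbers and none of the type or content handling:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // A store of a <4 x i32> vector reachable at two possible base offsets
  // (e.g. via a PHI), 4 bytes per element.
  std::vector<int64_t> BaseOffsets = {0, 32};
  const int64_t ElementSize = 4;
  const int NumElements = 4;

  std::vector<int64_t> ElementOffsets = BaseOffsets;
  for (int I = 0; I != NumElements; ++I) {
    for (int64_t Offset : ElementOffsets)
      std::printf("element %d -> access at offset %lld, size %lld\n", I,
                  static_cast<long long>(Offset),
                  static_cast<long long>(ElementSize));
    // Advance every candidate offset to the next element, as the loop in
    // handleAccess does.
    for (int64_t &Offset : ElementOffsets)
      Offset += ElementSize;
  }
}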
- ChangeStatus updateImpl(Attributor &A) override { - using namespace AA::PointerInfo; - ChangeStatus Changed = ChangeStatus::UNCHANGED; - Value &AssociatedValue = getAssociatedValue(); + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override { + AAPointerInfoImpl::trackPointerInfoStatistics(getIRPosition()); + } +}; - const DataLayout &DL = A.getDataLayout(); - DenseMap<Value *, OffsetInfo> OffsetInfoMap; - OffsetInfoMap[&AssociatedValue] = OffsetInfo{0}; +bool AAPointerInfoFloating::collectConstantsForGEP(Attributor &A, + const DataLayout &DL, + OffsetInfo &UsrOI, + const OffsetInfo &PtrOI, + const GEPOperator *GEP) { + unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP->getType()); + MapVector<Value *, APInt> VariableOffsets; + APInt ConstantOffset(BitWidth, 0); - auto HandlePassthroughUser = [&](Value *Usr, OffsetInfo PtrOI, - bool &Follow) { - OffsetInfo &UsrOI = OffsetInfoMap[Usr]; - UsrOI = PtrOI; - Follow = true; + assert(!UsrOI.isUnknown() && !PtrOI.isUnknown() && + "Don't look for constant values if the offset has already been " + "determined to be unknown."); + + if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) { + UsrOI.setUnknown(); + return true; + } + + LLVM_DEBUG(dbgs() << "[AAPointerInfo] GEP offset is " + << (VariableOffsets.empty() ? "" : "not") << " constant " + << *GEP << "\n"); + + auto Union = PtrOI; + Union.addToAll(ConstantOffset.getSExtValue()); + + // Each VI in VariableOffsets has a set of potential constant values. Every + // combination of elements, picked one each from these sets, is separately + // added to the original set of offsets, thus resulting in more offsets. + for (const auto &VI : VariableOffsets) { + auto &PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(*VI.first), DepClassTy::OPTIONAL); + if (!PotentialConstantsAA.isValidState()) { + UsrOI.setUnknown(); return true; - }; + } - const auto *TLI = getAnchorScope() - ? A.getInfoCache().getTargetLibraryInfoForFunction( - *getAnchorScope()) - : nullptr; - auto UsePred = [&](const Use &U, bool &Follow) -> bool { - Value *CurPtr = U.get(); - User *Usr = U.getUser(); - LLVM_DEBUG(dbgs() << "[AAPointerInfo] Analyze " << *CurPtr << " in " - << *Usr << "\n"); - assert(OffsetInfoMap.count(CurPtr) && - "The current pointer offset should have been seeded!"); - - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Usr)) { - if (CE->isCast()) - return HandlePassthroughUser(Usr, OffsetInfoMap[CurPtr], Follow); - if (CE->isCompare()) - return true; - if (!isa<GEPOperator>(CE)) { - LLVM_DEBUG(dbgs() << "[AAPointerInfo] Unhandled constant user " << *CE - << "\n"); - return false; - } + // UndefValue is treated as a zero, which leaves Union as is. + if (PotentialConstantsAA.undefIsContained()) + continue; + + // We need at least one constant in every set to compute an actual offset. + // Otherwise, we end up pessimizing AAPointerInfo by respecting offsets that + // don't actually exist. In other words, the absence of constant values + // implies that the operation can be assumed dead for now. 
+ auto &AssumedSet = PotentialConstantsAA.getAssumedSet(); + if (AssumedSet.empty()) + return false; + + OffsetInfo Product; + for (const auto &ConstOffset : AssumedSet) { + auto CopyPerOffset = Union; + CopyPerOffset.addToAll(ConstOffset.getSExtValue() * + VI.second.getZExtValue()); + Product.merge(CopyPerOffset); + } + Union = Product; + } + + UsrOI = std::move(Union); + return true; +} + +ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { + using namespace AA::PointerInfo; + ChangeStatus Changed = ChangeStatus::UNCHANGED; + const DataLayout &DL = A.getDataLayout(); + Value &AssociatedValue = getAssociatedValue(); + + DenseMap<Value *, OffsetInfo> OffsetInfoMap; + OffsetInfoMap[&AssociatedValue].insert(0); + + auto HandlePassthroughUser = [&](Value *Usr, Value *CurPtr, bool &Follow) { + // One does not simply walk into a map and assign a reference to a possibly + // new location. That can cause an invalidation before the assignment + // happens, like so: + // + // OffsetInfoMap[Usr] = OffsetInfoMap[CurPtr]; /* bad idea! */ + // + // The RHS is a reference that may be invalidated by an insertion caused by + // the LHS. So we ensure that the side-effect of the LHS happens first. + auto &UsrOI = OffsetInfoMap[Usr]; + auto &PtrOI = OffsetInfoMap[CurPtr]; + assert(!PtrOI.isUnassigned() && + "Cannot pass through if the input Ptr was not visited!"); + UsrOI = PtrOI; + Follow = true; + return true; + }; + + const auto *F = getAnchorScope(); + const auto *CI = + F ? A.getInfoCache().getAnalysisResultForFunction<CycleAnalysis>(*F) + : nullptr; + const auto *TLI = + F ? A.getInfoCache().getTargetLibraryInfoForFunction(*F) : nullptr; + + auto UsePred = [&](const Use &U, bool &Follow) -> bool { + Value *CurPtr = U.get(); + User *Usr = U.getUser(); + LLVM_DEBUG(dbgs() << "[AAPointerInfo] Analyze " << *CurPtr << " in " << *Usr + << "\n"); + assert(OffsetInfoMap.count(CurPtr) && + "The current pointer offset should have been seeded!"); + + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Usr)) { + if (CE->isCast()) + return HandlePassthroughUser(Usr, CurPtr, Follow); + if (CE->isCompare()) + return true; + if (!isa<GEPOperator>(CE)) { + LLVM_DEBUG(dbgs() << "[AAPointerInfo] Unhandled constant user " << *CE + << "\n"); + return false; } - if (auto *GEP = dyn_cast<GEPOperator>(Usr)) { - // Note the order here, the Usr access might change the map, CurPtr is - // already in it though. - OffsetInfo &UsrOI = OffsetInfoMap[Usr]; - OffsetInfo &PtrOI = OffsetInfoMap[CurPtr]; - UsrOI = PtrOI; - - // TODO: Use range information. - if (PtrOI.Offset == OffsetAndSize::Unknown || - !GEP->hasAllConstantIndices()) { - UsrOI.Offset = OffsetAndSize::Unknown; - Follow = true; - return true; - } + } + if (auto *GEP = dyn_cast<GEPOperator>(Usr)) { + // Note the order here, the Usr access might change the map, CurPtr is + // already in it though. 
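collectConstantsForGEP above expands the offset set multiplicatively: the constant part of the GEP shifts every existing offset, and each variable index with a known set of potential constants multiplies the candidates. A standalone sketch of that cross-product step; the names and the scale/candidate inputs are stand-ins for VI.second and the AAPotentialConstantValues results:

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Expand Base by every combination of (candidate * scale) for each variable
// GEP index, mirroring the Union/Product loop above.
static std::vector<int64_t>
expandOffsets(std::vector<int64_t> Base, int64_t ConstantOffset,
              const std::vector<std::pair<int64_t, std::vector<int64_t>>>
                  &VariableIndices) {
  for (int64_t &O : Base)
    O += ConstantOffset;
  for (const auto &[Scale, Candidates] : VariableIndices) {
    std::vector<int64_t> Product;
    for (int64_t Candidate : Candidates)
      for (int64_t O : Base)
        Product.push_back(O + Candidate * Scale);
    Base = std::move(Product);
  }
  return Base;
}

int main() {
  // GEP into an i32 array: base offsets {0}, constant part 8, one variable
  // index with potential constant values {0, 2} and scale (element size) 4.
  std::vector<std::pair<int64_t, std::vector<int64_t>>> VariableIndices = {
      {4, {0, 2}}};
  for (int64_t O : expandOffsets({0}, 8, VariableIndices))
    std::printf("possible offset %lld\n", static_cast<long long>(O));
  // Prints 8 and 16.
}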
+ auto &UsrOI = OffsetInfoMap[Usr]; + auto &PtrOI = OffsetInfoMap[CurPtr]; - SmallVector<Value *, 8> Indices; - for (Use &Idx : GEP->indices()) { - if (auto *CIdx = dyn_cast<ConstantInt>(Idx)) { - Indices.push_back(CIdx); - continue; - } + if (UsrOI.isUnknown()) + return true; - LLVM_DEBUG(dbgs() << "[AAPointerInfo] Non constant GEP index " << *GEP - << " : " << *Idx << "\n"); - return false; - } - UsrOI.Offset = PtrOI.Offset + DL.getIndexedOffsetInType( - GEP->getSourceElementType(), Indices); + if (PtrOI.isUnknown()) { Follow = true; + UsrOI.setUnknown(); + return true; + } + + Follow = collectConstantsForGEP(A, DL, UsrOI, PtrOI, GEP); + return true; + } + if (isa<PtrToIntInst>(Usr)) + return false; + if (isa<CastInst>(Usr) || isa<SelectInst>(Usr) || isa<ReturnInst>(Usr)) + return HandlePassthroughUser(Usr, CurPtr, Follow); + + // For PHIs we need to take care of the recurrence explicitly as the value + // might change while we iterate through a loop. For now, we give up if + // the PHI is not invariant. + if (isa<PHINode>(Usr)) { + // Note the order here, the Usr access might change the map, CurPtr is + // already in it though. + bool IsFirstPHIUser = !OffsetInfoMap.count(Usr); + auto &UsrOI = OffsetInfoMap[Usr]; + auto &PtrOI = OffsetInfoMap[CurPtr]; + + // Check if the PHI operand has already an unknown offset as we can't + // improve on that anymore. + if (PtrOI.isUnknown()) { + LLVM_DEBUG(dbgs() << "[AAPointerInfo] PHI operand offset unknown " + << *CurPtr << " in " << *Usr << "\n"); + Follow = !UsrOI.isUnknown(); + UsrOI.setUnknown(); return true; } - if (isa<CastInst>(Usr) || isa<SelectInst>(Usr) || isa<ReturnInst>(Usr)) - return HandlePassthroughUser(Usr, OffsetInfoMap[CurPtr], Follow); - - // For PHIs we need to take care of the recurrence explicitly as the value - // might change while we iterate through a loop. For now, we give up if - // the PHI is not invariant. - if (isa<PHINode>(Usr)) { - // Note the order here, the Usr access might change the map, CurPtr is - // already in it though. - bool IsFirstPHIUser = !OffsetInfoMap.count(Usr); - OffsetInfo &UsrOI = OffsetInfoMap[Usr]; - OffsetInfo &PtrOI = OffsetInfoMap[CurPtr]; - // Check if the PHI is invariant (so far). - if (UsrOI == PtrOI) - return true; - // Check if the PHI operand has already an unknown offset as we can't - // improve on that anymore. - if (PtrOI.Offset == OffsetAndSize::Unknown) { - UsrOI = PtrOI; - Follow = true; + // Check if the PHI is invariant (so far). + if (UsrOI == PtrOI) { + assert(!PtrOI.isUnassigned() && + "Cannot assign if the current Ptr was not visited!"); + LLVM_DEBUG(dbgs() << "[AAPointerInfo] PHI is invariant (so far)"); + return true; + } + + // Check if the PHI operand can be traced back to AssociatedValue. + APInt Offset( + DL.getIndexSizeInBits(CurPtr->getType()->getPointerAddressSpace()), + 0); + Value *CurPtrBase = CurPtr->stripAndAccumulateConstantOffsets( + DL, Offset, /* AllowNonInbounds */ true); + auto It = OffsetInfoMap.find(CurPtrBase); + if (It == OffsetInfoMap.end()) { + LLVM_DEBUG(dbgs() << "[AAPointerInfo] PHI operand is too complex " + << *CurPtr << " in " << *Usr << "\n"); + UsrOI.setUnknown(); + Follow = true; + return true; + } + + auto mayBeInCycleHeader = [](const CycleInfo *CI, const Instruction *I) { + if (!CI) return true; - } + auto *BB = I->getParent(); + auto *C = CI->getCycle(BB); + if (!C) + return false; + return BB == C->getHeader(); + }; - // Check if the PHI operand is not dependent on the PHI itself. 
- APInt Offset( - DL.getIndexSizeInBits(CurPtr->getType()->getPointerAddressSpace()), - 0); - Value *CurPtrBase = CurPtr->stripAndAccumulateConstantOffsets( - DL, Offset, /* AllowNonInbounds */ true); - auto It = OffsetInfoMap.find(CurPtrBase); - if (It != OffsetInfoMap.end()) { - Offset += It->getSecond().Offset; - if (IsFirstPHIUser || Offset == UsrOI.Offset) - return HandlePassthroughUser(Usr, PtrOI, Follow); - LLVM_DEBUG(dbgs() - << "[AAPointerInfo] PHI operand pointer offset mismatch " - << *CurPtr << " in " << *Usr << "\n"); - } else { - LLVM_DEBUG(dbgs() << "[AAPointerInfo] PHI operand is too complex " - << *CurPtr << " in " << *Usr << "\n"); + // Check if the PHI operand is not dependent on the PHI itself. Every + // recurrence is a cyclic net of PHIs in the data flow, and has an + // equivalent Cycle in the control flow. One of those PHIs must be in the + // header of that control flow Cycle. This is independent of the choice of + // Cycles reported by CycleInfo. It is sufficient to check the PHIs in + // every Cycle header; if such a node is marked unknown, this will + // eventually propagate through the whole net of PHIs in the recurrence. + if (mayBeInCycleHeader(CI, cast<Instruction>(Usr))) { + auto BaseOI = It->getSecond(); + BaseOI.addToAll(Offset.getZExtValue()); + if (IsFirstPHIUser || BaseOI == UsrOI) { + LLVM_DEBUG(dbgs() << "[AAPointerInfo] PHI is invariant " << *CurPtr + << " in " << *Usr << "\n"); + return HandlePassthroughUser(Usr, CurPtr, Follow); } - // TODO: Approximate in case we know the direction of the recurrence. - UsrOI = PtrOI; - UsrOI.Offset = OffsetAndSize::Unknown; + LLVM_DEBUG( + dbgs() << "[AAPointerInfo] PHI operand pointer offset mismatch " + << *CurPtr << " in " << *Usr << "\n"); + UsrOI.setUnknown(); Follow = true; return true; } - if (auto *LoadI = dyn_cast<LoadInst>(Usr)) { - // If the access is to a pointer that may or may not be the associated - // value, e.g. due to a PHI, we cannot assume it will be read. - AccessKind AK = AccessKind::AK_R; - if (getUnderlyingObject(CurPtr) == &AssociatedValue) - AK = AccessKind(AK | AccessKind::AK_MUST); - else - AK = AccessKind(AK | AccessKind::AK_MAY); - return handleAccess(A, *LoadI, *CurPtr, /* Content */ nullptr, AK, - OffsetInfoMap[CurPtr].Offset, Changed, - LoadI->getType()); - } + UsrOI.merge(PtrOI); + Follow = true; + return true; + } + + if (auto *LoadI = dyn_cast<LoadInst>(Usr)) { + // If the access is to a pointer that may or may not be the associated + // value, e.g. due to a PHI, we cannot assume it will be read. + AccessKind AK = AccessKind::AK_R; + if (getUnderlyingObject(CurPtr) == &AssociatedValue) + AK = AccessKind(AK | AccessKind::AK_MUST); + else + AK = AccessKind(AK | AccessKind::AK_MAY); + if (!handleAccess(A, *LoadI, /* Content */ nullptr, AK, + OffsetInfoMap[CurPtr].Offsets, Changed, + *LoadI->getType())) + return false; + + auto IsAssumption = [](Instruction &I) { + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + return II->isAssumeLikeIntrinsic(); + return false; + }; + + auto IsImpactedInRange = [&](Instruction *FromI, Instruction *ToI) { + // Check if the assumption and the load are executed together without + // memory modification. 
+ do { + if (FromI->mayWriteToMemory() && !IsAssumption(*FromI)) + return true; + FromI = FromI->getNextNonDebugInstruction(); + } while (FromI && FromI != ToI); + return false; + }; - if (auto *StoreI = dyn_cast<StoreInst>(Usr)) { - if (StoreI->getValueOperand() == CurPtr) { - LLVM_DEBUG(dbgs() << "[AAPointerInfo] Escaping use in store " - << *StoreI << "\n"); + BasicBlock *BB = LoadI->getParent(); + auto IsValidAssume = [&](IntrinsicInst &IntrI) { + if (IntrI.getIntrinsicID() != Intrinsic::assume) return false; + BasicBlock *IntrBB = IntrI.getParent(); + if (IntrI.getParent() == BB) { + if (IsImpactedInRange(LoadI->getNextNonDebugInstruction(), &IntrI)) + return false; + } else { + auto PredIt = pred_begin(IntrBB); + if ((*PredIt) != BB) + return false; + if (++PredIt != pred_end(IntrBB)) + return false; + for (auto *SuccBB : successors(BB)) { + if (SuccBB == IntrBB) + continue; + if (isa<UnreachableInst>(SuccBB->getTerminator())) + continue; + return false; + } + if (IsImpactedInRange(LoadI->getNextNonDebugInstruction(), + BB->getTerminator())) + return false; + if (IsImpactedInRange(&IntrBB->front(), &IntrI)) + return false; } - // If the access is to a pointer that may or may not be the associated - // value, e.g. due to a PHI, we cannot assume it will be written. - AccessKind AK = AccessKind::AK_W; - if (getUnderlyingObject(CurPtr) == &AssociatedValue) - AK = AccessKind(AK | AccessKind::AK_MUST); - else - AK = AccessKind(AK | AccessKind::AK_MAY); - bool UsedAssumedInformation = false; - Optional<Value *> Content = - A.getAssumedSimplified(*StoreI->getValueOperand(), *this, - UsedAssumedInformation, AA::Interprocedural); - return handleAccess(A, *StoreI, *CurPtr, Content, AK, - OffsetInfoMap[CurPtr].Offset, Changed, - StoreI->getValueOperand()->getType()); + return true; + }; + + std::pair<Value *, IntrinsicInst *> Assumption; + for (const Use &LoadU : LoadI->uses()) { + if (auto *CmpI = dyn_cast<CmpInst>(LoadU.getUser())) { + if (!CmpI->isEquality() || !CmpI->isTrueWhenEqual()) + continue; + for (const Use &CmpU : CmpI->uses()) { + if (auto *IntrI = dyn_cast<IntrinsicInst>(CmpU.getUser())) { + if (!IsValidAssume(*IntrI)) + continue; + int Idx = CmpI->getOperandUse(0) == LoadU; + Assumption = {CmpI->getOperand(Idx), IntrI}; + break; + } + } + } + if (Assumption.first) + break; } - if (auto *CB = dyn_cast<CallBase>(Usr)) { - if (CB->isLifetimeStartOrEnd()) - return true; - if (getFreedOperand(CB, TLI) == U) - return true; - if (CB->isArgOperand(&U)) { - unsigned ArgNo = CB->getArgOperandNo(&U); - const auto &CSArgPI = A.getAAFor<AAPointerInfo>( - *this, IRPosition::callsite_argument(*CB, ArgNo), - DepClassTy::REQUIRED); - Changed = translateAndAddState(A, CSArgPI, - OffsetInfoMap[CurPtr].Offset, *CB) | - Changed; - return isValidState(); + + // Check if we found an assumption associated with this load. 
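IsImpactedInRange above walks forward from one instruction to the next and gives up if anything in between may write memory, with assume-like calls exempted because they can report themselves as writing. A toy standalone sketch of that straight-line scan; the Instr type and its flags are illustrative, not LLVM's:

#include <cstdio>
#include <vector>

struct Instr {
  const char *Text;
  bool MayWrite;
  bool IsAssume;
};

// Return true if any instruction in [From, To) may modify memory, which would
// invalidate connecting the loaded value to the later assume.
static bool isImpactedInRange(const std::vector<Instr> &Block, size_t From,
                              size_t To) {
  for (size_t I = From; I < To && I < Block.size(); ++I)
    if (Block[I].MayWrite && !Block[I].IsAssume)
      return true;
  return false;
}

int main() {
  std::vector<Instr> Block = {
      {"%v = load i32, ptr %p", false, false},
      {"%c = icmp eq i32 %v, 7", false, false},
      {"call void @llvm.assume(i1 %c)", true, true}, // assume-like, exempted
  };
  // Scan the instructions strictly between the load and the assume.
  std::printf("impacted: %d\n", isImpactedInRange(Block, 1, 2));
}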
+ if (!Assumption.first || !Assumption.second) + return true; + + LLVM_DEBUG(dbgs() << "[AAPointerInfo] Assumption found " + << *Assumption.second << ": " << *LoadI + << " == " << *Assumption.first << "\n"); + + return handleAccess( + A, *Assumption.second, Assumption.first, AccessKind::AK_ASSUMPTION, + OffsetInfoMap[CurPtr].Offsets, Changed, *LoadI->getType()); + } + + auto HandleStoreLike = [&](Instruction &I, Value *ValueOp, Type &ValueTy, + ArrayRef<Value *> OtherOps, AccessKind AK) { + for (auto *OtherOp : OtherOps) { + if (OtherOp == CurPtr) { + LLVM_DEBUG( + dbgs() + << "[AAPointerInfo] Escaping use in store like instruction " << I + << "\n"); + return false; } - LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB - << "\n"); - // TODO: Allow some call uses - return false; } - LLVM_DEBUG(dbgs() << "[AAPointerInfo] User not handled " << *Usr << "\n"); - return false; + // If the access is to a pointer that may or may not be the associated + // value, e.g. due to a PHI, we cannot assume it will be written. + if (getUnderlyingObject(CurPtr) == &AssociatedValue) + AK = AccessKind(AK | AccessKind::AK_MUST); + else + AK = AccessKind(AK | AccessKind::AK_MAY); + bool UsedAssumedInformation = false; + std::optional<Value *> Content = nullptr; + if (ValueOp) + Content = A.getAssumedSimplified( + *ValueOp, *this, UsedAssumedInformation, AA::Interprocedural); + return handleAccess(A, I, Content, AK, OffsetInfoMap[CurPtr].Offsets, + Changed, ValueTy); }; - auto EquivalentUseCB = [&](const Use &OldU, const Use &NewU) { - if (OffsetInfoMap.count(NewU)) { - LLVM_DEBUG({ - if (!(OffsetInfoMap[NewU] == OffsetInfoMap[OldU])) { - dbgs() << "[AAPointerInfo] Equivalent use callback failed: " - << OffsetInfoMap[NewU].Offset << " vs " - << OffsetInfoMap[OldU].Offset << "\n"; - } - }); - return OffsetInfoMap[NewU] == OffsetInfoMap[OldU]; + + if (auto *StoreI = dyn_cast<StoreInst>(Usr)) + return HandleStoreLike(*StoreI, StoreI->getValueOperand(), + *StoreI->getValueOperand()->getType(), + {StoreI->getValueOperand()}, AccessKind::AK_W); + if (auto *RMWI = dyn_cast<AtomicRMWInst>(Usr)) + return HandleStoreLike(*RMWI, nullptr, *RMWI->getValOperand()->getType(), + {RMWI->getValOperand()}, AccessKind::AK_RW); + if (auto *CXI = dyn_cast<AtomicCmpXchgInst>(Usr)) + return HandleStoreLike( + *CXI, nullptr, *CXI->getNewValOperand()->getType(), + {CXI->getCompareOperand(), CXI->getNewValOperand()}, + AccessKind::AK_RW); + + if (auto *CB = dyn_cast<CallBase>(Usr)) { + if (CB->isLifetimeStartOrEnd()) + return true; + if (getFreedOperand(CB, TLI) == U) + return true; + if (CB->isArgOperand(&U)) { + unsigned ArgNo = CB->getArgOperandNo(&U); + const auto &CSArgPI = A.getAAFor<AAPointerInfo>( + *this, IRPosition::callsite_argument(*CB, ArgNo), + DepClassTy::REQUIRED); + Changed = translateAndAddState(A, CSArgPI, OffsetInfoMap[CurPtr], *CB) | + Changed; + return isValidState(); } - OffsetInfoMap[NewU] = OffsetInfoMap[OldU]; - return true; - }; - if (!A.checkForAllUses(UsePred, *this, AssociatedValue, - /* CheckBBLivenessOnly */ true, DepClassTy::OPTIONAL, - /* IgnoreDroppableUses */ true, EquivalentUseCB)) { - LLVM_DEBUG( - dbgs() << "[AAPointerInfo] Check for all uses failed, abort!\n"); - return indicatePessimisticFixpoint(); + LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB + << "\n"); + // TODO: Allow some call uses + return false; } - LLVM_DEBUG({ - dbgs() << "Accesses by bin after update:\n"; - dumpState(dbgs()); - }); - - return Changed; + LLVM_DEBUG(dbgs() << "[AAPointerInfo] User 
not handled " << *Usr << "\n"); + return false; + }; + auto EquivalentUseCB = [&](const Use &OldU, const Use &NewU) { + assert(OffsetInfoMap.count(OldU) && "Old use should be known already!"); + if (OffsetInfoMap.count(NewU)) { + LLVM_DEBUG({ + if (!(OffsetInfoMap[NewU] == OffsetInfoMap[OldU])) { + dbgs() << "[AAPointerInfo] Equivalent use callback failed: " + << OffsetInfoMap[NewU] << " vs " << OffsetInfoMap[OldU] + << "\n"; + } + }); + return OffsetInfoMap[NewU] == OffsetInfoMap[OldU]; + } + OffsetInfoMap[NewU] = OffsetInfoMap[OldU]; + return true; + }; + if (!A.checkForAllUses(UsePred, *this, AssociatedValue, + /* CheckBBLivenessOnly */ true, DepClassTy::OPTIONAL, + /* IgnoreDroppableUses */ true, EquivalentUseCB)) { + LLVM_DEBUG(dbgs() << "[AAPointerInfo] Check for all uses failed, abort!\n"); + return indicatePessimisticFixpoint(); } - /// See AbstractAttribute::trackStatistics() - void trackStatistics() const override { - AAPointerInfoImpl::trackPointerInfoStatistics(getIRPosition()); - } -}; + LLVM_DEBUG({ + dbgs() << "Accesses by bin after update:\n"; + dumpState(dbgs()); + }); + + return Changed; +} struct AAPointerInfoReturned final : AAPointerInfoImpl { AAPointerInfoReturned(const IRPosition &IRP, Attributor &A) @@ -1490,24 +1870,21 @@ struct AAPointerInfoCallSiteArgument final : AAPointerInfoFloating { // accessed. if (auto *MI = dyn_cast_or_null<MemIntrinsic>(getCtxI())) { ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength()); - int64_t LengthVal = OffsetAndSize::Unknown; + int64_t LengthVal = AA::RangeTy::Unknown; if (Length) LengthVal = Length->getSExtValue(); - Value &Ptr = getAssociatedValue(); unsigned ArgNo = getIRPosition().getCallSiteArgNo(); ChangeStatus Changed = ChangeStatus::UNCHANGED; - if (ArgNo == 0) { - handleAccess(A, *MI, Ptr, nullptr, AccessKind::AK_MUST_WRITE, 0, - Changed, nullptr, LengthVal); - } else if (ArgNo == 1) { - handleAccess(A, *MI, Ptr, nullptr, AccessKind::AK_MUST_READ, 0, Changed, - nullptr, LengthVal); - } else { + if (ArgNo > 1) { LLVM_DEBUG(dbgs() << "[AAPointerInfo] Unhandled memory intrinsic " << *MI << "\n"); return indicatePessimisticFixpoint(); + } else { + auto Kind = + ArgNo == 0 ? AccessKind::AK_MUST_WRITE : AccessKind::AK_MUST_READ; + Changed = + Changed | addAccess(A, {0, LengthVal}, *MI, nullptr, Kind, nullptr); } - LLVM_DEBUG({ dbgs() << "Accesses by bin after update:\n"; dumpState(dbgs()); @@ -1521,13 +1898,31 @@ struct AAPointerInfoCallSiteArgument final : AAPointerInfoFloating { // sense to specialize attributes for call sites arguments instead of // redirecting requests to the callee argument. 
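For memory intrinsics, the call-site argument attribute above records one access per pointer operand: the destination gets a must-write over [0, length) and the source a must-read, falling back to an unknown size when the length is not a constant. A standalone sketch of that mapping with toy structs rather than the Attributor API:

#include <cstdint>
#include <cstdio>
#include <optional>

constexpr int64_t Unknown = -1;

struct Access {
  const char *Kind;
  int64_t Offset;
  int64_t Size;
};

// Describe what a memcpy-like call does to its pointer argument ArgNo.
static std::optional<Access>
accessForMemIntrinsicArg(unsigned ArgNo, std::optional<int64_t> Len) {
  int64_t Size = Len.value_or(Unknown);
  if (ArgNo == 0)
    return Access{"must-write", 0, Size};
  if (ArgNo == 1)
    return Access{"must-read", 0, Size};
  return std::nullopt; // other operands (e.g. the length) are not handled here
}

int main() {
  for (unsigned ArgNo = 0; ArgNo != 3; ++ArgNo) {
    if (auto A = accessForMemIntrinsicArg(ArgNo, 16))
      std::printf("arg %u: %s [%lld, %lld)\n", ArgNo, A->Kind,
                  static_cast<long long>(A->Offset),
                  static_cast<long long>(A->Offset + A->Size));
    else
      std::printf("arg %u: no pointer access\n", ArgNo);
  }
}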
Argument *Arg = getAssociatedArgument(); - if (!Arg) + if (Arg) { + const IRPosition &ArgPos = IRPosition::argument(*Arg); + auto &ArgAA = + A.getAAFor<AAPointerInfo>(*this, ArgPos, DepClassTy::REQUIRED); + if (ArgAA.getState().isValidState()) + return translateAndAddStateFromCallee(A, ArgAA, + *cast<CallBase>(getCtxI())); + if (!Arg->getParent()->isDeclaration()) + return indicatePessimisticFixpoint(); + } + + const auto &NoCaptureAA = + A.getAAFor<AANoCapture>(*this, getIRPosition(), DepClassTy::OPTIONAL); + + if (!NoCaptureAA.isAssumedNoCapture()) return indicatePessimisticFixpoint(); - const IRPosition &ArgPos = IRPosition::argument(*Arg); - auto &ArgAA = - A.getAAFor<AAPointerInfo>(*this, ArgPos, DepClassTy::REQUIRED); - return translateAndAddState(A, ArgAA, 0, *cast<CallBase>(getCtxI()), - /* FromCallee */ true); + + bool IsKnown = false; + if (AA::isAssumedReadNone(A, getIRPosition(), *this, IsKnown)) + return ChangeStatus::UNCHANGED; + bool ReadOnly = AA::isAssumedReadOnly(A, getIRPosition(), *this, IsKnown); + auto Kind = + ReadOnly ? AccessKind::AK_MAY_READ : AccessKind::AK_MAY_READ_WRITE; + return addAccess(A, AA::RangeTy::getUnknown(), *getCtxI(), nullptr, Kind, + nullptr); } /// See AbstractAttribute::trackStatistics() @@ -1709,9 +2104,9 @@ public: } /// Return an assumed unique return value if a single candidate is found. If - /// there cannot be one, return a nullptr. If it is not clear yet, return the - /// Optional::NoneType. - Optional<Value *> getAssumedUniqueReturnValue(Attributor &A) const; + /// there cannot be one, return a nullptr. If it is not clear yet, return + /// std::nullopt. + std::optional<Value *> getAssumedUniqueReturnValue(Attributor &A) const; /// See AbstractState::checkForAllReturnedValues(...). bool checkForAllReturnedValuesAndReturnInsts( @@ -1749,16 +2144,16 @@ ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) { "Number of function with known return values"); // Check if we have an assumed unique return value that we could manifest. - Optional<Value *> UniqueRV = getAssumedUniqueReturnValue(A); + std::optional<Value *> UniqueRV = getAssumedUniqueReturnValue(A); - if (!UniqueRV || !UniqueRV.value()) + if (!UniqueRV || !*UniqueRV) return Changed; // Bookkeeping. STATS_DECLTRACK(UniqueReturnValue, FunctionReturn, "Number of function with unique return"); // If the assumed unique return value is an argument, annotate it. - if (auto *UniqueRVArg = dyn_cast<Argument>(UniqueRV.value())) { + if (auto *UniqueRVArg = dyn_cast<Argument>(*UniqueRV)) { if (UniqueRVArg->getType()->canLosslesslyBitCastTo( getAssociatedFunction()->getReturnType())) { getIRPosition() = IRPosition::argument(*UniqueRVArg); @@ -1773,19 +2168,19 @@ const std::string AAReturnedValuesImpl::getAsStr() const { (isValidState() ? std::to_string(getNumReturnValues()) : "?") + ")"; } -Optional<Value *> +std::optional<Value *> AAReturnedValuesImpl::getAssumedUniqueReturnValue(Attributor &A) const { // If checkForAllReturnedValues provides a unique value, ignoring potential // undef values that can also be present, it is assumed to be the actual // return value and forwarded to the caller of this method. If there are // multiple, a nullptr is returned indicating there cannot be a unique // returned value. 
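getAssumedUniqueReturnValue, defined next, folds all returned values into a three-way lattice: no value seen yet, one unique value, or a null marker once distinct values conflict. A standalone sketch of that combine over plain pointers; undef handling is omitted and the names only loosely mirror the patch:

#include <cstdio>
#include <optional>
#include <vector>

// Combine candidate returned values into: unset (nothing seen yet), a unique
// value, or nullptr (multiple distinct values).
static std::optional<const int *> combine(std::optional<const int *> Acc,
                                          const int *V) {
  if (!Acc)
    return V;                                   // first value seen
  if (*Acc == V)
    return Acc;                                 // still unique
  return std::optional<const int *>(nullptr);   // conflicting values
}

int main() {
  int A = 1, B = 2;
  std::vector<const int *> ReturnedValues = {&A, &A, &B};
  std::optional<const int *> UniqueRV;
  for (const int *RV : ReturnedValues) {
    UniqueRV = combine(UniqueRV, RV);
    if (UniqueRV == std::optional<const int *>(nullptr))
      break; // no unique return value, stop early
  }
  std::printf("unique return value: %s\n",
              (UniqueRV && *UniqueRV) ? "yes" : "no");
}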
- Optional<Value *> UniqueRV; + std::optional<Value *> UniqueRV; Type *Ty = getAssociatedFunction()->getReturnType(); auto Pred = [&](Value &RV) -> bool { UniqueRV = AA::combineOptionalValuesInAAValueLatice(UniqueRV, &RV, Ty); - return UniqueRV != Optional<Value *>(nullptr); + return UniqueRV != std::optional<Value *>(nullptr); }; if (!A.checkForAllReturnedValues(Pred, *this)) @@ -1802,7 +2197,7 @@ bool AAReturnedValuesImpl::checkForAllReturnedValuesAndReturnInsts( // Check all returned values but ignore call sites as long as we have not // encountered an overdefined one during an update. - for (auto &It : ReturnedValues) { + for (const auto &It : ReturnedValues) { Value *RV = It.first; if (!Pred(*RV, It.second)) return false; @@ -1876,6 +2271,23 @@ struct AAReturnedValuesCallSite final : AAReturnedValuesImpl { /// ------------------------ NoSync Function Attribute ------------------------- +bool AANoSync::isAlignedBarrier(const CallBase &CB, bool ExecutedAligned) { + switch (CB.getIntrinsicID()) { + case Intrinsic::nvvm_barrier0: + case Intrinsic::nvvm_barrier0_and: + case Intrinsic::nvvm_barrier0_or: + case Intrinsic::nvvm_barrier0_popc: + return true; + case Intrinsic::amdgcn_s_barrier: + if (ExecutedAligned) + return true; + break; + default: + break; + } + return hasAssumption(CB, KnownAssumptionString("ompx_aligned_barrier")); +} + bool AANoSync::isNonRelaxedAtomic(const Instruction *I) { if (!I->isAtomic()) return false; @@ -2235,7 +2647,7 @@ static int64_t getKnownNonNullAndDerefBytesForUse( return DerefAA.getKnownDereferenceableBytes(); } - Optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I); + std::optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I); if (!Loc || Loc->Ptr != UseV || !Loc->Size.isPrecise() || I->isVolatile()) return 0; @@ -2461,9 +2873,9 @@ struct AANoRecurseFunction final : AANoRecurseImpl { return ChangeStatus::UNCHANGED; } - const AAFunctionReachability &EdgeReachability = - A.getAAFor<AAFunctionReachability>(*this, getIRPosition(), - DepClassTy::REQUIRED); + const AAInterFnReachability &EdgeReachability = + A.getAAFor<AAInterFnReachability>(*this, getIRPosition(), + DepClassTy::REQUIRED); if (EdgeReachability.canReach(A, *getAnchorScope())) return indicatePessimisticFixpoint(); return ChangeStatus::UNCHANGED; @@ -2534,10 +2946,11 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { // Either we stopped and the appropriate action was taken, // or we got back a simplified value to continue. - Optional<Value *> SimplifiedPtrOp = stopOnUndefOrAssumed(A, PtrOp, &I); - if (!SimplifiedPtrOp || !SimplifiedPtrOp.value()) + std::optional<Value *> SimplifiedPtrOp = + stopOnUndefOrAssumed(A, PtrOp, &I); + if (!SimplifiedPtrOp || !*SimplifiedPtrOp) return true; - const Value *PtrOpVal = SimplifiedPtrOp.value(); + const Value *PtrOpVal = *SimplifiedPtrOp; // A memory access through a pointer is considered UB // only if the pointer has constant null value. @@ -2578,7 +2991,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { // Either we stopped and the appropriate action was taken, // or we got back a simplified value to continue. 
- Optional<Value *> SimplifiedCond = + std::optional<Value *> SimplifiedCond = stopOnUndefOrAssumed(A, BrInst->getCondition(), BrInst); if (!SimplifiedCond || !*SimplifiedCond) return true; @@ -2622,19 +3035,19 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { if (!NoUndefAA.isKnownNoUndef()) continue; bool UsedAssumedInformation = false; - Optional<Value *> SimplifiedVal = + std::optional<Value *> SimplifiedVal = A.getAssumedSimplified(IRPosition::value(*ArgVal), *this, UsedAssumedInformation, AA::Interprocedural); if (UsedAssumedInformation) continue; - if (SimplifiedVal && !SimplifiedVal.value()) + if (SimplifiedVal && !*SimplifiedVal) return true; - if (!SimplifiedVal || isa<UndefValue>(*SimplifiedVal.value())) { + if (!SimplifiedVal || isa<UndefValue>(**SimplifiedVal)) { KnownUBInsts.insert(&I); continue; } if (!ArgVal->getType()->isPointerTy() || - !isa<ConstantPointerNull>(*SimplifiedVal.value())) + !isa<ConstantPointerNull>(**SimplifiedVal)) continue; auto &NonNullAA = A.getAAFor<AANonNull>(*this, CalleeArgumentIRP, DepClassTy::NONE); @@ -2648,7 +3061,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { auto &RI = cast<ReturnInst>(I); // Either we stopped and the appropriate action was taken, // or we got back a simplified return value to continue. - Optional<Value *> SimplifiedRetValue = + std::optional<Value *> SimplifiedRetValue = stopOnUndefOrAssumed(A, RI.getReturnValue(), &I); if (!SimplifiedRetValue || !*SimplifiedRetValue) return true; @@ -2788,14 +3201,14 @@ private: // - If the value is assumed, then stop. // - If the value is known but undef, then consider it UB. // - Otherwise, do specific processing with the simplified value. - // We return None in the first 2 cases to signify that an appropriate + // We return std::nullopt in the first 2 cases to signify that an appropriate // action was taken and the caller should stop. // Otherwise, we return the simplified value that the caller should // use for specific processing. - Optional<Value *> stopOnUndefOrAssumed(Attributor &A, Value *V, - Instruction *I) { + std::optional<Value *> stopOnUndefOrAssumed(Attributor &A, Value *V, + Instruction *I) { bool UsedAssumedInformation = false; - Optional<Value *> SimplifiedV = + std::optional<Value *> SimplifiedV = A.getAssumedSimplified(IRPosition::value(*V), *this, UsedAssumedInformation, AA::Interprocedural); if (!UsedAssumedInformation) { @@ -2804,7 +3217,7 @@ private: // If it is known (which we tested above) but it doesn't have a value, // then we can assume `undef` and hence the instruction is UB. KnownUBInsts.insert(I); - return llvm::None; + return std::nullopt; } if (!*SimplifiedV) return nullptr; @@ -2812,7 +3225,7 @@ private: } if (isa<UndefValue>(V)) { KnownUBInsts.insert(I); - return llvm::None; + return std::nullopt; } return V; } @@ -2975,30 +3388,245 @@ struct AAWillReturnCallSite final : AAWillReturnImpl { }; } // namespace -/// -------------------AAReachability Attribute-------------------------- +/// -------------------AAIntraFnReachability Attribute-------------------------- + +/// All information associated with a reachability query. This boilerplate code +/// is used by both AAIntraFnReachability and AAInterFnReachability, with +/// different \p ToTy values. 
+template <typename ToTy> struct ReachabilityQueryInfo { + enum class Reachable { + No, + Yes, + }; + + /// Start here, + const Instruction *From = nullptr; + /// reach this place, + const ToTy *To = nullptr; + /// without going through any of these instructions, + const AA::InstExclusionSetTy *ExclusionSet = nullptr; + /// and remember if it worked: + Reachable Result = Reachable::No; + + ReachabilityQueryInfo(const Instruction *From, const ToTy *To) + : From(From), To(To) {} + + /// Constructor replacement to ensure unique and stable sets are used for the + /// cache. + ReachabilityQueryInfo(Attributor &A, const Instruction &From, const ToTy &To, + const AA::InstExclusionSetTy *ES) + : From(&From), To(&To), ExclusionSet(ES) { + + if (ExclusionSet && !ExclusionSet->empty()) { + ExclusionSet = + A.getInfoCache().getOrCreateUniqueBlockExecutionSet(ExclusionSet); + } else { + ExclusionSet = nullptr; + } + } + + ReachabilityQueryInfo(const ReachabilityQueryInfo &RQI) + : From(RQI.From), To(RQI.To), ExclusionSet(RQI.ExclusionSet) { + assert(RQI.Result == Reachable::No && + "Didn't expect to copy an explored RQI!"); + } +}; + +namespace llvm { +template <typename ToTy> struct DenseMapInfo<ReachabilityQueryInfo<ToTy> *> { + using InstSetDMI = DenseMapInfo<const AA::InstExclusionSetTy *>; + using PairDMI = DenseMapInfo<std::pair<const Instruction *, const ToTy *>>; + + static ReachabilityQueryInfo<ToTy> EmptyKey; + static ReachabilityQueryInfo<ToTy> TombstoneKey; + + static inline ReachabilityQueryInfo<ToTy> *getEmptyKey() { return &EmptyKey; } + static inline ReachabilityQueryInfo<ToTy> *getTombstoneKey() { + return &TombstoneKey; + } + static unsigned getHashValue(const ReachabilityQueryInfo<ToTy> *RQI) { + unsigned H = PairDMI ::getHashValue({RQI->From, RQI->To}); + H += InstSetDMI::getHashValue(RQI->ExclusionSet); + return H; + } + static bool isEqual(const ReachabilityQueryInfo<ToTy> *LHS, + const ReachabilityQueryInfo<ToTy> *RHS) { + if (!PairDMI::isEqual({LHS->From, LHS->To}, {RHS->From, RHS->To})) + return false; + return InstSetDMI::isEqual(LHS->ExclusionSet, RHS->ExclusionSet); + } +}; + +#define DefineKeys(ToTy) \ + template <> \ + ReachabilityQueryInfo<ToTy> \ + DenseMapInfo<ReachabilityQueryInfo<ToTy> *>::EmptyKey = \ + ReachabilityQueryInfo<ToTy>( \ + DenseMapInfo<const Instruction *>::getEmptyKey(), \ + DenseMapInfo<const ToTy *>::getEmptyKey()); \ + template <> \ + ReachabilityQueryInfo<ToTy> \ + DenseMapInfo<ReachabilityQueryInfo<ToTy> *>::TombstoneKey = \ + ReachabilityQueryInfo<ToTy>( \ + DenseMapInfo<const Instruction *>::getTombstoneKey(), \ + DenseMapInfo<const ToTy *>::getTombstoneKey()); + +DefineKeys(Instruction) DefineKeys(Function) +#undef DefineKeys + +} // namespace llvm namespace { -struct AAReachabilityImpl : AAReachability { - AAReachabilityImpl(const IRPosition &IRP, Attributor &A) - : AAReachability(IRP, A) {} + +template <typename BaseTy, typename ToTy> +struct CachedReachabilityAA : public BaseTy { + using RQITy = ReachabilityQueryInfo<ToTy>; + + CachedReachabilityAA<BaseTy, ToTy>(const IRPosition &IRP, Attributor &A) + : BaseTy(IRP, A) {} + + /// See AbstractAttribute::isQueryAA. + bool isQueryAA() const override { return true; } + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Changed = ChangeStatus::UNCHANGED; + InUpdate = true; + for (RQITy *RQI : QueryVector) { + if (RQI->Result == RQITy::Reachable::No && isReachableImpl(A, *RQI)) + Changed = ChangeStatus::CHANGED; + } + InUpdate = false; + return Changed; + } + + virtual bool isReachableImpl(Attributor &A, RQITy &RQI) = 0; + + bool rememberResult(Attributor &A, typename RQITy::Reachable Result, + RQITy &RQI) { + if (Result == RQITy::Reachable::No) { + if (!InUpdate) + A.registerForUpdate(*this); + return false; + } + assert(RQI.Result == RQITy::Reachable::No && "Already reachable?"); + RQI.Result = Result; + return true; + } const std::string getAsStr() const override { // TODO: Return the number of reachable queries. - return "reachable"; + return "#queries(" + std::to_string(QueryVector.size()) + ")"; } - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - return ChangeStatus::UNCHANGED; + RQITy *checkQueryCache(Attributor &A, RQITy &StackRQI, + typename RQITy::Reachable &Result) { + if (!this->getState().isValidState()) { + Result = RQITy::Reachable::Yes; + return nullptr; + } + + auto It = QueryCache.find(&StackRQI); + if (It != QueryCache.end()) { + Result = (*It)->Result; + return nullptr; + } + + RQITy *RQIPtr = new (A.Allocator) RQITy(StackRQI); + QueryVector.push_back(RQIPtr); + QueryCache.insert(RQIPtr); + return RQIPtr; } + +private: + bool InUpdate = false; + SmallVector<RQITy *> QueryVector; + DenseSet<RQITy *> QueryCache; }; -struct AAReachabilityFunction final : public AAReachabilityImpl { - AAReachabilityFunction(const IRPosition &IRP, Attributor &A) - : AAReachabilityImpl(IRP, A) {} +struct AAIntraFnReachabilityFunction final + : public CachedReachabilityAA<AAIntraFnReachability, Instruction> { + AAIntraFnReachabilityFunction(const IRPosition &IRP, Attributor &A) + : CachedReachabilityAA<AAIntraFnReachability, Instruction>(IRP, A) {} + + bool isAssumedReachable( + Attributor &A, const Instruction &From, const Instruction &To, + const AA::InstExclusionSetTy *ExclusionSet) const override { + auto *NonConstThis = const_cast<AAIntraFnReachabilityFunction *>(this); + if (&From == &To) + return true; + + RQITy StackRQI(A, From, To, ExclusionSet); + typename RQITy::Reachable Result; + if (RQITy *RQIPtr = NonConstThis->checkQueryCache(A, StackRQI, Result)) { + return NonConstThis->isReachableImpl(A, *RQIPtr); + } + return Result == RQITy::Reachable::Yes; + } + + bool isReachableImpl(Attributor &A, RQITy &RQI) override { + const Instruction *Origin = RQI.From; + + auto WillReachInBlock = [=](const Instruction &From, const Instruction &To, + const AA::InstExclusionSetTy *ExclusionSet) { + const Instruction *IP = &From; + while (IP && IP != &To) { + if (ExclusionSet && IP != Origin && ExclusionSet->count(IP)) + break; + IP = IP->getNextNode(); + } + return IP == &To; + }; + + const BasicBlock *FromBB = RQI.From->getParent(); + const BasicBlock *ToBB = RQI.To->getParent(); + assert(FromBB->getParent() == ToBB->getParent() && + "Not an intra-procedural query!"); + + // Check intra-block reachability, however, other reaching paths are still + // possible. 
+ if (FromBB == ToBB && + WillReachInBlock(*RQI.From, *RQI.To, RQI.ExclusionSet)) + return rememberResult(A, RQITy::Reachable::Yes, RQI); + + SmallPtrSet<const BasicBlock *, 16> ExclusionBlocks; + if (RQI.ExclusionSet) + for (auto *I : *RQI.ExclusionSet) + ExclusionBlocks.insert(I->getParent()); + + // Check if we make it out of the FromBB block at all. + if (ExclusionBlocks.count(FromBB) && + !WillReachInBlock(*RQI.From, *FromBB->getTerminator(), + RQI.ExclusionSet)) + return rememberResult(A, RQITy::Reachable::No, RQI); + + SmallPtrSet<const BasicBlock *, 16> Visited; + SmallVector<const BasicBlock *, 16> Worklist; + Worklist.push_back(FromBB); + + auto &LivenessAA = + A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); + while (!Worklist.empty()) { + const BasicBlock *BB = Worklist.pop_back_val(); + if (!Visited.insert(BB).second) + continue; + for (const BasicBlock *SuccBB : successors(BB)) { + if (LivenessAA.isEdgeDead(BB, SuccBB)) + continue; + if (SuccBB == ToBB && + WillReachInBlock(SuccBB->front(), *RQI.To, RQI.ExclusionSet)) + return rememberResult(A, RQITy::Reachable::Yes, RQI); + if (ExclusionBlocks.count(SuccBB)) + continue; + Worklist.push_back(SuccBB); + } + } + + return rememberResult(A, RQITy::Reachable::No, RQI); + } /// See AbstractAttribute::trackStatistics() - void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(reachable); } + void trackStatistics() const override {} }; } // namespace @@ -3241,7 +3869,7 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { } if (!AA::isPotentiallyReachable( - A, *UserI, *getCtxI(), *this, + A, *UserI, *getCtxI(), *this, /* ExclusionSet */ nullptr, [ScopeFn](const Function &Fn) { return &Fn != ScopeFn; })) return true; } @@ -3443,7 +4071,7 @@ struct AAIsDeadValueImpl : public AAIsDead { if (!A.isRunOn(*I->getFunction())) return false; bool UsedAssumedInformation = false; - Optional<Constant *> C = + std::optional<Constant *> C = A.getAssumedConstant(V, *this, UsedAssumedInformation); if (!C || *C) return true; @@ -3503,19 +4131,48 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { } } - bool isDeadStore(Attributor &A, StoreInst &SI) { + bool isDeadStore(Attributor &A, StoreInst &SI, + SmallSetVector<Instruction *, 8> *AssumeOnlyInst = nullptr) { // Lang ref now states volatile store is not UB/dead, let's skip them. if (SI.isVolatile()) return false; + // If we are collecting assumes to be deleted we are in the manifest stage. + // It's problematic to collect the potential copies again now so we use the + // cached ones. 
bool UsedAssumedInformation = false; - SmallSetVector<Value *, 4> PotentialCopies; - if (!AA::getPotentialCopiesOfStoredValue(A, SI, PotentialCopies, *this, - UsedAssumedInformation)) - return false; + if (!AssumeOnlyInst) { + PotentialCopies.clear(); + if (!AA::getPotentialCopiesOfStoredValue(A, SI, PotentialCopies, *this, + UsedAssumedInformation)) { + LLVM_DEBUG( + dbgs() + << "[AAIsDead] Could not determine potential copies of store!\n"); + return false; + } + } + LLVM_DEBUG(dbgs() << "[AAIsDead] Store has " << PotentialCopies.size() + << " potential copies.\n"); + + InformationCache &InfoCache = A.getInfoCache(); return llvm::all_of(PotentialCopies, [&](Value *V) { - return A.isAssumedDead(IRPosition::value(*V), this, nullptr, - UsedAssumedInformation); + if (A.isAssumedDead(IRPosition::value(*V), this, nullptr, + UsedAssumedInformation)) + return true; + if (auto *LI = dyn_cast<LoadInst>(V)) { + if (llvm::all_of(LI->uses(), [&](const Use &U) { + return InfoCache.isOnlyUsedByAssume( + cast<Instruction>(*U.getUser())) || + A.isAssumedDead(U, this, nullptr, UsedAssumedInformation); + })) { + if (AssumeOnlyInst) + AssumeOnlyInst->insert(LI); + return true; + } + } + LLVM_DEBUG(dbgs() << "[AAIsDead] Potential copy " << *V + << " is assumed live!\n"); + return false; }); } @@ -3555,8 +4212,21 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { // isAssumedSideEffectFree returns true here again because it might not be // the case and only the users are dead but the instruction (=call) is // still needed. - if (isa<StoreInst>(I) || - (isAssumedSideEffectFree(A, I) && !isa<InvokeInst>(I))) { + if (auto *SI = dyn_cast<StoreInst>(I)) { + SmallSetVector<Instruction *, 8> AssumeOnlyInst; + bool IsDead = isDeadStore(A, *SI, &AssumeOnlyInst); + (void)IsDead; + assert(IsDead && "Store was assumed to be dead!"); + A.deleteAfterManifest(*I); + for (size_t i = 0; i < AssumeOnlyInst.size(); ++i) { + Instruction *AOI = AssumeOnlyInst[i]; + for (auto *Usr : AOI->users()) + AssumeOnlyInst.insert(cast<Instruction>(Usr)); + A.deleteAfterManifest(*AOI); + } + return ChangeStatus::CHANGED; + } + if (isAssumedSideEffectFree(A, I) && !isa<InvokeInst>(I)) { A.deleteAfterManifest(*I); return ChangeStatus::CHANGED; } @@ -3568,6 +4238,10 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { void trackStatistics() const override { STATS_DECLTRACK_FLOATING_ATTR(IsDead) } + +private: + // The potential copies of a dead store, used for deletion during manifest. + SmallSetVector<Value *, 4> PotentialCopies; }; struct AAIsDeadArgument : public AAIsDeadFloating { @@ -3746,8 +4420,18 @@ struct AAIsDeadFunction : public AAIsDead { indicatePessimisticFixpoint(); return; } - ToBeExploredFrom.insert(&F->getEntryBlock().front()); - assumeLive(A, F->getEntryBlock()); + if (!isAssumedDeadInternalFunction(A)) { + ToBeExploredFrom.insert(&F->getEntryBlock().front()); + assumeLive(A, F->getEntryBlock()); + } + } + + bool isAssumedDeadInternalFunction(Attributor &A) { + if (!getAnchorScope()->hasLocalLinkage()) + return false; + bool UsedAssumedInformation = false; + return A.checkForAllCallSites([](AbstractCallSite) { return false; }, *this, + true, UsedAssumedInformation); } /// See AbstractAttribute::getAsStr(). 
@@ -3950,7 +4634,7 @@ identifyAliveSuccessors(Attributor &A, const BranchInst &BI, if (BI.getNumSuccessors() == 1) { AliveSuccessors.push_back(&BI.getSuccessor(0)->front()); } else { - Optional<Constant *> C = + std::optional<Constant *> C = A.getAssumedConstant(*BI.getCondition(), AA, UsedAssumedInformation); if (!C || isa_and_nonnull<UndefValue>(*C)) { // No value yet, assume both edges are dead. @@ -3972,13 +4656,13 @@ identifyAliveSuccessors(Attributor &A, const SwitchInst &SI, AbstractAttribute &AA, SmallVectorImpl<const Instruction *> &AliveSuccessors) { bool UsedAssumedInformation = false; - Optional<Constant *> C = + std::optional<Constant *> C = A.getAssumedConstant(*SI.getCondition(), AA, UsedAssumedInformation); - if (!C || isa_and_nonnull<UndefValue>(C.value())) { + if (!C || isa_and_nonnull<UndefValue>(*C)) { // No value yet, assume all edges are dead. - } else if (isa_and_nonnull<ConstantInt>(C.value())) { - for (auto &CaseIt : SI.cases()) { - if (CaseIt.getCaseValue() == C.value()) { + } else if (isa_and_nonnull<ConstantInt>(*C)) { + for (const auto &CaseIt : SI.cases()) { + if (CaseIt.getCaseValue() == *C) { AliveSuccessors.push_back(&CaseIt.getCaseSuccessor()->front()); return UsedAssumedInformation; } @@ -3995,6 +4679,16 @@ identifyAliveSuccessors(Attributor &A, const SwitchInst &SI, ChangeStatus AAIsDeadFunction::updateImpl(Attributor &A) { ChangeStatus Change = ChangeStatus::UNCHANGED; + if (AssumedLiveBlocks.empty()) { + if (isAssumedDeadInternalFunction(A)) + return ChangeStatus::UNCHANGED; + + Function *F = getAnchorScope(); + ToBeExploredFrom.insert(&F->getEntryBlock().front()); + assumeLive(A, F->getEntryBlock()); + Change = ChangeStatus::CHANGED; + } + LLVM_DEBUG(dbgs() << "[AAIsDead] Live [" << AssumedLiveBlocks.size() << "/" << getAnchorScope()->size() << "] BBs and " << ToBeExploredFrom.size() << " exploration points and " @@ -4171,7 +4865,7 @@ struct AADereferenceableImpl : AADereferenceable { if (!UseV->getType()->isPointerTy()) return; - Optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I); + std::optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I); if (!Loc || Loc->Ptr != UseV || !Loc->Size.isPrecise() || I->isVolatile()) return; @@ -4296,7 +4990,7 @@ struct AADereferenceableFloating : AADereferenceableImpl { } else if (OffsetSExt > 0) { // If something was stripped but there is circular reasoning we look // for the offset. If it is positive we basically decrease the - // dereferenceable bytes in a circluar loop now, which will simply + // dereferenceable bytes in a circular loop now, which will simply // drive them down to the known value in a very slow way which we // can accelerate. T.indicatePessimisticFixpoint(); @@ -4427,8 +5121,7 @@ static unsigned getKnownAlignForUse(Attributor &A, AAAlign &QueryingAA, // So we can say that the maximum power of two which is a divisor of // gcd(Offset, Alignment) is an alignment. - uint32_t gcd = - greatestCommonDivisor(uint32_t(abs((int32_t)Offset)), Alignment); + uint32_t gcd = std::gcd(uint32_t(abs((int32_t)Offset)), Alignment); Alignment = llvm::PowerOf2Floor(gcd); } } @@ -4563,8 +5256,8 @@ struct AAAlignFloating : AAAlignImpl { // So we can say that the maximum power of two which is a divisor of // gcd(Offset, Alignment) is an alignment. 
- uint32_t gcd = greatestCommonDivisor(uint32_t(abs((int32_t)Offset)), - uint32_t(PA.value())); + uint32_t gcd = + std::gcd(uint32_t(abs((int32_t)Offset)), uint32_t(PA.value())); Alignment = llvm::PowerOf2Floor(gcd); } else { Alignment = V.getPointerAlignment(DL).value(); @@ -4834,7 +5527,7 @@ struct AAInstanceInfoImpl : public AAInstanceInfo { // If this call base might reach the scope again we might forward the // argument back here. This is very conservative. if (AA::isPotentiallyReachable( - A, *CB, *Scope, *this, + A, *CB, *Scope, *this, /* ExclusionSet */ nullptr, [Scope](const Function &Fn) { return &Fn != Scope; })) return false; return true; @@ -4845,11 +5538,8 @@ struct AAInstanceInfoImpl : public AAInstanceInfo { auto EquivalentUseCB = [&](const Use &OldU, const Use &NewU) { if (auto *SI = dyn_cast<StoreInst>(OldU.getUser())) { auto *Ptr = SI->getPointerOperand()->stripPointerCasts(); - if (isa<AllocaInst>(Ptr) && AA::isDynamicallyUnique(A, *this, *Ptr)) - return true; - auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction( - *SI->getFunction()); - if (isAllocationFn(Ptr, TLI) && AA::isDynamicallyUnique(A, *this, *Ptr)) + if ((isa<AllocaInst>(Ptr) || isNoAliasCall(Ptr)) && + AA::isDynamicallyUnique(A, *this, *Ptr)) return true; } return false; @@ -5144,7 +5834,7 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { if (!RVAA.getState().isValidState()) return false; bool SeenConstant = false; - for (auto &It : RVAA.returned_values()) { + for (const auto &It : RVAA.returned_values()) { if (isa<Constant>(It.first)) { if (SeenConstant) return false; @@ -5302,11 +5992,11 @@ struct AANoCaptureCallSiteReturned final : AANoCaptureImpl { /// ------------------ Value Simplify Attribute ---------------------------- -bool ValueSimplifyStateType::unionAssumed(Optional<Value *> Other) { +bool ValueSimplifyStateType::unionAssumed(std::optional<Value *> Other) { // FIXME: Add a typecast support. SimplifiedAssociatedValue = AA::combineOptionalValuesInAAValueLatice( SimplifiedAssociatedValue, Other, Ty); - if (SimplifiedAssociatedValue == Optional<Value *>(nullptr)) + if (SimplifiedAssociatedValue == std::optional<Value *>(nullptr)) return false; LLVM_DEBUG({ @@ -5347,7 +6037,8 @@ struct AAValueSimplifyImpl : AAValueSimplify { void trackStatistics() const override {} /// See AAValueSimplify::getAssumedSimplifiedValue() - Optional<Value *> getAssumedSimplifiedValue(Attributor &A) const override { + std::optional<Value *> + getAssumedSimplifiedValue(Attributor &A) const override { return SimplifiedAssociatedValue; } @@ -5411,13 +6102,13 @@ struct AAValueSimplifyImpl : AAValueSimplify { if (const auto &NewV = VMap.lookup(&V)) return NewV; bool UsedAssumedInformation = false; - Optional<Value *> SimpleV = A.getAssumedSimplified( + std::optional<Value *> SimpleV = A.getAssumedSimplified( V, QueryingAA, UsedAssumedInformation, AA::Interprocedural); if (!SimpleV.has_value()) return PoisonValue::get(&Ty); Value *EffectiveV = &V; - if (SimpleV.value()) - EffectiveV = SimpleV.value(); + if (*SimpleV) + EffectiveV = *SimpleV; if (auto *C = dyn_cast<Constant>(EffectiveV)) return C; if (CtxI && AA::isValidAtPosition(AA::ValueAndContext(*EffectiveV, *CtxI), @@ -5433,7 +6124,7 @@ struct AAValueSimplifyImpl : AAValueSimplify { /// nullptr if we don't have one that makes sense. Value *manifestReplacementValue(Attributor &A, Instruction *CtxI) const { Value *NewV = SimplifiedAssociatedValue - ? SimplifiedAssociatedValue.value() + ? 
*SimplifiedAssociatedValue : UndefValue::get(getAssociatedType()); if (NewV && NewV != &getAssociatedValue()) { ValueToValueMapTy VMap; @@ -5447,12 +6138,12 @@ struct AAValueSimplifyImpl : AAValueSimplify { return nullptr; } - /// Helper function for querying AAValueSimplify and updating candicate. + /// Helper function for querying AAValueSimplify and updating candidate. /// \param IRP The value position we are trying to unify with SimplifiedValue bool checkAndUpdate(Attributor &A, const AbstractAttribute &QueryingAA, const IRPosition &IRP, bool Simplify = true) { bool UsedAssumedInformation = false; - Optional<Value *> QueryingValueSimplified = &IRP.getAssociatedValue(); + std::optional<Value *> QueryingValueSimplified = &IRP.getAssociatedValue(); if (Simplify) QueryingValueSimplified = A.getAssumedSimplified( IRP, QueryingAA, UsedAssumedInformation, AA::Interprocedural); @@ -5468,10 +6159,10 @@ struct AAValueSimplifyImpl : AAValueSimplify { const auto &AA = A.getAAFor<AAType>(*this, getIRPosition(), DepClassTy::NONE); - Optional<Constant *> COpt = AA.getAssumedConstant(A); + std::optional<Constant *> COpt = AA.getAssumedConstant(A); if (!COpt) { - SimplifiedAssociatedValue = llvm::None; + SimplifiedAssociatedValue = std::nullopt; A.recordDependence(AA, *this, DepClassTy::OPTIONAL); return true; } @@ -5560,11 +6251,11 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl { // in other functions, e.g., we don't want to say a an argument in a // static function is actually an argument in a different function. bool UsedAssumedInformation = false; - Optional<Constant *> SimpleArgOp = + std::optional<Constant *> SimpleArgOp = A.getAssumedConstant(ACSArgPos, *this, UsedAssumedInformation); if (!SimpleArgOp) return true; - if (!SimpleArgOp.value()) + if (!*SimpleArgOp) return false; if (!AA::isDynamicallyUnique(A, *this, **SimpleArgOp)) return false; @@ -5586,7 +6277,7 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl { if (!askSimplifiedValueForOtherAAs(A)) return indicatePessimisticFixpoint(); - // If a candicate was found in this update, return CHANGED. + // If a candidate was found in this update, return CHANGED. return Before == SimplifiedAssociatedValue ? ChangeStatus::UNCHANGED : ChangeStatus ::CHANGED; } @@ -5602,7 +6293,8 @@ struct AAValueSimplifyReturned : AAValueSimplifyImpl { : AAValueSimplifyImpl(IRP, A) {} /// See AAValueSimplify::getAssumedSimplifiedValue() - Optional<Value *> getAssumedSimplifiedValue(Attributor &A) const override { + std::optional<Value *> + getAssumedSimplifiedValue(Attributor &A) const override { if (!isValidState()) return nullptr; return SimplifiedAssociatedValue; @@ -5625,7 +6317,7 @@ struct AAValueSimplifyReturned : AAValueSimplifyImpl { if (!askSimplifiedValueForOtherAAs(A)) return indicatePessimisticFixpoint(); - // If a candicate was found in this update, return CHANGED. + // If a candidate was found in this update, return CHANGED. return Before == SimplifiedAssociatedValue ? ChangeStatus::UNCHANGED : ChangeStatus ::CHANGED; } @@ -5662,7 +6354,7 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { if (!askSimplifiedValueForOtherAAs(A)) return indicatePessimisticFixpoint(); - // If a candicate was found in this update, return CHANGED. + // If a candidate was found in this update, return CHANGED. return Before == SimplifiedAssociatedValue ? 
ChangeStatus::UNCHANGED : ChangeStatus ::CHANGED; } @@ -5736,12 +6428,13 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl { auto PredForReturned = [&](Value &RetVal, const SmallSetVector<ReturnInst *, 4> &RetInsts) { bool UsedAssumedInformation = false; - Optional<Value *> CSRetVal = A.translateArgumentToCallSiteContent( - &RetVal, *cast<CallBase>(getCtxI()), *this, - UsedAssumedInformation); + std::optional<Value *> CSRetVal = + A.translateArgumentToCallSiteContent( + &RetVal, *cast<CallBase>(getCtxI()), *this, + UsedAssumedInformation); SimplifiedAssociatedValue = AA::combineOptionalValuesInAAValueLatice( SimplifiedAssociatedValue, CSRetVal, getAssociatedType()); - return SimplifiedAssociatedValue != Optional<Value *>(nullptr); + return SimplifiedAssociatedValue != std::optional<Value *>(nullptr); }; if (!RetAA.checkForAllReturnedValuesAndReturnInsts(PredForReturned)) if (!askSimplifiedValueForOtherAAs(A)) @@ -5879,7 +6572,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { Attributor::SimplifictionCallbackTy SCB = [](const IRPosition &, const AbstractAttribute *, - bool &) -> Optional<Value *> { return nullptr; }; + bool &) -> std::optional<Value *> { return nullptr; }; for (const auto &It : AllocationInfos) A.registerSimplificationCallback(IRPosition::callsite_returned(*It.first), SCB); @@ -5905,7 +6598,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { STATS_DECL( MallocCalls, Function, "Number of malloc/calloc/aligned_alloc calls converted to allocas"); - for (auto &It : AllocationInfos) + for (const auto &It : AllocationInfos) if (It.second->Status != AllocationInfo::INVALID) ++BUILD_STAT_NAME(MallocCalls, Function); } @@ -5922,7 +6615,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { if (!isValidState()) return false; - for (auto &It : AllocationInfos) { + for (const auto &It : AllocationInfos) { AllocationInfo &AI = *It.second; if (AI.Status == AllocationInfo::INVALID) continue; @@ -5970,7 +6663,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { const DataLayout &DL = A.getInfoCache().getDL(); Value *Size; - Optional<APInt> SizeAPI = getSize(A, *this, AI); + std::optional<APInt> SizeAPI = getSize(A, *this, AI); if (SizeAPI) { Size = ConstantInt::get(AI.CB->getContext(), *SizeAPI); } else { @@ -5990,11 +6683,11 @@ struct AAHeapToStackFunction final : public AAHeapToStack { if (MaybeAlign RetAlign = AI.CB->getRetAlign()) Alignment = std::max(Alignment, *RetAlign); if (Value *Align = getAllocAlignment(AI.CB, TLI)) { - Optional<APInt> AlignmentAPI = getAPInt(A, *this, *Align); - assert(AlignmentAPI && AlignmentAPI.value().getZExtValue() > 0 && + std::optional<APInt> AlignmentAPI = getAPInt(A, *this, *Align); + assert(AlignmentAPI && AlignmentAPI->getZExtValue() > 0 && "Expected an alignment during manifest!"); - Alignment = std::max( - Alignment, assumeAligned(AlignmentAPI.value().getZExtValue())); + Alignment = + std::max(Alignment, assumeAligned(AlignmentAPI->getZExtValue())); } // TODO: Hoist the alloca towards the function entry. 
@@ -6028,7 +6721,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { if (!isa<UndefValue>(InitVal)) { IRBuilder<> Builder(Alloca->getNextNode()); // TODO: Use alignment above if align!=1 - Builder.CreateMemSet(Alloca, InitVal, Size, None); + Builder.CreateMemSet(Alloca, InitVal, Size, std::nullopt); } HasChanged = ChangeStatus::CHANGED; } @@ -6036,23 +6729,23 @@ struct AAHeapToStackFunction final : public AAHeapToStack { return HasChanged; } - Optional<APInt> getAPInt(Attributor &A, const AbstractAttribute &AA, - Value &V) { + std::optional<APInt> getAPInt(Attributor &A, const AbstractAttribute &AA, + Value &V) { bool UsedAssumedInformation = false; - Optional<Constant *> SimpleV = + std::optional<Constant *> SimpleV = A.getAssumedConstant(V, AA, UsedAssumedInformation); if (!SimpleV) return APInt(64, 0); - if (auto *CI = dyn_cast_or_null<ConstantInt>(SimpleV.value())) + if (auto *CI = dyn_cast_or_null<ConstantInt>(*SimpleV)) return CI->getValue(); - return llvm::None; + return std::nullopt; } - Optional<APInt> getSize(Attributor &A, const AbstractAttribute &AA, - AllocationInfo &AI) { + std::optional<APInt> getSize(Attributor &A, const AbstractAttribute &AA, + AllocationInfo &AI) { auto Mapper = [&](const Value *V) -> const Value * { bool UsedAssumedInformation = false; - if (Optional<Constant *> SimpleV = + if (std::optional<Constant *> SimpleV = A.getAssumedConstant(*V, AA, UsedAssumedInformation)) if (*SimpleV) return *SimpleV; @@ -6091,13 +6784,13 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { LoopInfo *LI = A.getInfoCache().getAnalysisResultForFunction<LoopAnalysis>(*F); - Optional<bool> MayContainIrreducibleControl; + std::optional<bool> MayContainIrreducibleControl; auto IsInLoop = [&](BasicBlock &BB) { if (&F->getEntryBlock() == &BB) return false; if (!MayContainIrreducibleControl.has_value()) MayContainIrreducibleControl = mayContainIrreducibleControl(*F, LI); - if (MayContainIrreducibleControl.value()) + if (*MayContainIrreducibleControl) return true; if (!LI) return true; @@ -6304,7 +6997,7 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { continue; if (Value *Align = getAllocAlignment(AI.CB, TLI)) { - Optional<APInt> APAlign = getAPInt(A, *this, *Align); + std::optional<APInt> APAlign = getAPInt(A, *this, *Align); if (!APAlign) { // Can't generate an alloca which respects the required alignment // on the allocation. @@ -6324,9 +7017,9 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { } } - Optional<APInt> Size = getSize(A, *this, AI); + std::optional<APInt> Size = getSize(A, *this, AI); if (MaxHeapToStackSize != -1) { - if (!Size || Size.value().ugt(MaxHeapToStackSize)) { + if (!Size || Size->ugt(MaxHeapToStackSize)) { LLVM_DEBUG({ if (!Size) dbgs() << "[H2S] Unknown allocation size: " << *AI.CB << "\n"; @@ -6346,7 +7039,7 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { if (UsesCheck(AI)) break; AI.Status = AllocationInfo::STACK_DUE_TO_FREE; - LLVM_FALLTHROUGH; + [[fallthrough]]; case AllocationInfo::STACK_DUE_TO_FREE: if (FreeCheck(AI)) break; @@ -6357,9 +7050,14 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { llvm_unreachable("Invalid allocations should never reach this point!"); }; - // Check if we still think we can move it into the entry block. + // Check if we still think we can move it into the entry block. If the + // alloca comes from a converted __kmpc_alloc_shared then we can usually + // ignore the potential complications associated with loops. 
+ bool IsGlobalizedLocal = + AI.LibraryFunctionId == LibFunc___kmpc_alloc_shared; if (AI.MoveAllocaIntoEntry && - (!Size.has_value() || IsInLoop(*AI.CB->getParent()))) + (!Size.has_value() || + (!IsGlobalizedLocal && IsInLoop(*AI.CB->getParent())))) AI.MoveAllocaIntoEntry = false; } @@ -6371,7 +7069,7 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { namespace { struct AAPrivatizablePtrImpl : public AAPrivatizablePtr { AAPrivatizablePtrImpl(const IRPosition &IRP, Attributor &A) - : AAPrivatizablePtr(IRP, A), PrivatizableType(llvm::None) {} + : AAPrivatizablePtr(IRP, A), PrivatizableType(std::nullopt) {} ChangeStatus indicatePessimisticFixpoint() override { AAPrivatizablePtr::indicatePessimisticFixpoint(); @@ -6381,11 +7079,12 @@ struct AAPrivatizablePtrImpl : public AAPrivatizablePtr { /// Identify the type we can chose for a private copy of the underlying /// argument. None means it is not clear yet, nullptr means there is none. - virtual Optional<Type *> identifyPrivatizableType(Attributor &A) = 0; + virtual std::optional<Type *> identifyPrivatizableType(Attributor &A) = 0; /// Return a privatizable type that encloses both T0 and T1. /// TODO: This is merely a stub for now as we should manage a mapping as well. - Optional<Type *> combineTypes(Optional<Type *> T0, Optional<Type *> T1) { + std::optional<Type *> combineTypes(std::optional<Type *> T0, + std::optional<Type *> T1) { if (!T0) return T1; if (!T1) @@ -6395,7 +7094,7 @@ struct AAPrivatizablePtrImpl : public AAPrivatizablePtr { return nullptr; } - Optional<Type *> getPrivatizableType() const override { + std::optional<Type *> getPrivatizableType() const override { return PrivatizableType; } @@ -6404,7 +7103,7 @@ struct AAPrivatizablePtrImpl : public AAPrivatizablePtr { } protected: - Optional<Type *> PrivatizableType; + std::optional<Type *> PrivatizableType; }; // TODO: Do this for call site arguments (probably also other values) as well. @@ -6414,7 +7113,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { : AAPrivatizablePtrImpl(IRP, A) {} /// See AAPrivatizablePtrImpl::identifyPrivatizableType(...) - Optional<Type *> identifyPrivatizableType(Attributor &A) override { + std::optional<Type *> identifyPrivatizableType(Attributor &A) override { // If this is a byval argument and we know all the call sites (so we can // rewrite them), there is no need to check them explicitly. bool UsedAssumedInformation = false; @@ -6425,7 +7124,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { true, UsedAssumedInformation)) return Attrs[0].getValueAsType(); - Optional<Type *> Ty; + std::optional<Type *> Ty; unsigned ArgNo = getIRPosition().getCallSiteArgNo(); // Make sure the associated call site argument has the same type at all call @@ -6444,12 +7143,12 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { // Check that all call sites agree on a type. 
auto &PrivCSArgAA = A.getAAFor<AAPrivatizablePtr>(*this, ACSArgPos, DepClassTy::REQUIRED); - Optional<Type *> CSTy = PrivCSArgAA.getPrivatizableType(); + std::optional<Type *> CSTy = PrivCSArgAA.getPrivatizableType(); LLVM_DEBUG({ dbgs() << "[AAPrivatizablePtr] ACSPos: " << ACSArgPos << ", CSTy: "; - if (CSTy && CSTy.value()) - CSTy.value()->print(dbgs()); + if (CSTy && *CSTy) + (*CSTy)->print(dbgs()); else if (CSTy) dbgs() << "<nullptr>"; else @@ -6460,8 +7159,8 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { LLVM_DEBUG({ dbgs() << " : New Type: "; - if (Ty && Ty.value()) - Ty.value()->print(dbgs()); + if (Ty && *Ty) + (*Ty)->print(dbgs()); else if (Ty) dbgs() << "<nullptr>"; else @@ -6469,7 +7168,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { dbgs() << "\n"; }); - return !Ty || Ty.value(); + return !Ty || *Ty; }; if (!A.checkForAllCallSites(CallSiteCheck, *this, true, @@ -6483,7 +7182,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { PrivatizableType = identifyPrivatizableType(A); if (!PrivatizableType) return ChangeStatus::UNCHANGED; - if (!PrivatizableType.value()) + if (!*PrivatizableType) return indicatePessimisticFixpoint(); // The dependence is optional so we don't give up once we give up on the @@ -6571,7 +7270,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { auto CBArgPrivTy = CBArgPrivAA.getPrivatizableType(); if (!CBArgPrivTy) continue; - if (CBArgPrivTy.value() == PrivatizableType) + if (*CBArgPrivTy == PrivatizableType) continue; } @@ -6618,7 +7317,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { auto DCArgPrivTy = DCArgPrivAA.getPrivatizableType(); if (!DCArgPrivTy) return true; - if (DCArgPrivTy.value() == PrivatizableType) + if (*DCArgPrivTy == PrivatizableType) return true; } } @@ -6760,7 +7459,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { ChangeStatus manifest(Attributor &A) override { if (!PrivatizableType) return ChangeStatus::UNCHANGED; - assert(PrivatizableType.value() && "Expected privatizable type!"); + assert(*PrivatizableType && "Expected privatizable type!"); // Collect all tail calls in the function as we cannot allow new allocas to // escape into tail recursion. @@ -6793,9 +7492,9 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { Instruction *IP = &*EntryBB.getFirstInsertionPt(); const DataLayout &DL = IP->getModule()->getDataLayout(); unsigned AS = DL.getAllocaAddrSpace(); - Instruction *AI = new AllocaInst(PrivatizableType.value(), AS, + Instruction *AI = new AllocaInst(*PrivatizableType, AS, Arg->getName() + ".priv", IP); - createInitialization(PrivatizableType.value(), *AI, ReplacementFn, + createInitialization(*PrivatizableType, *AI, ReplacementFn, ArgIt->getArgNo(), *IP); if (AI->getType() != Arg->getType()) @@ -6857,7 +7556,7 @@ struct AAPrivatizablePtrFloating : public AAPrivatizablePtrImpl { } /// See AAPrivatizablePtrImpl::identifyPrivatizableType(...) 
- Optional<Type *> identifyPrivatizableType(Attributor &A) override { + std::optional<Type *> identifyPrivatizableType(Attributor &A) override { Value *Obj = getUnderlyingObject(&getAssociatedValue()); if (!Obj) { LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] No underlying object found!\n"); @@ -6903,7 +7602,7 @@ struct AAPrivatizablePtrCallSiteArgument final PrivatizableType = identifyPrivatizableType(A); if (!PrivatizableType) return ChangeStatus::UNCHANGED; - if (!PrivatizableType.value()) + if (!*PrivatizableType) return indicatePessimisticFixpoint(); const IRPosition &IRP = getIRPosition(); @@ -7224,13 +7923,28 @@ struct AAMemoryBehaviorFunction final : public AAMemoryBehaviorImpl { /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { + // TODO: It would be better to merge this with AAMemoryLocation, so that + // we could determine read/write per location. This would also have the + // benefit of only one place trying to manifest the memory attribute. Function &F = cast<Function>(getAnchorValue()); - if (isAssumedReadNone()) { - F.removeFnAttr(Attribute::ArgMemOnly); - F.removeFnAttr(Attribute::InaccessibleMemOnly); - F.removeFnAttr(Attribute::InaccessibleMemOrArgMemOnly); - } - return AAMemoryBehaviorImpl::manifest(A); + MemoryEffects ME = MemoryEffects::unknown(); + if (isAssumedReadNone()) + ME = MemoryEffects::none(); + else if (isAssumedReadOnly()) + ME = MemoryEffects::readOnly(); + else if (isAssumedWriteOnly()) + ME = MemoryEffects::writeOnly(); + + // Intersect with existing memory attribute, as we currently deduce the + // location and modref portion separately. + MemoryEffects ExistingME = F.getMemoryEffects(); + ME &= ExistingME; + if (ME == ExistingME) + return ChangeStatus::UNCHANGED; + + return IRAttributeManifest::manifestAttrs( + A, getIRPosition(), Attribute::getWithMemoryEffects(F.getContext(), ME), + /*ForceReplace*/ true); } /// See AbstractAttribute::trackStatistics() @@ -7270,6 +7984,31 @@ struct AAMemoryBehaviorCallSite final : AAMemoryBehaviorImpl { return clampStateAndIndicateChange(getState(), FnAA.getState()); } + /// See AbstractAttribute::manifest(...). + ChangeStatus manifest(Attributor &A) override { + // TODO: Deduplicate this with AAMemoryBehaviorFunction. + CallBase &CB = cast<CallBase>(getAnchorValue()); + MemoryEffects ME = MemoryEffects::unknown(); + if (isAssumedReadNone()) + ME = MemoryEffects::none(); + else if (isAssumedReadOnly()) + ME = MemoryEffects::readOnly(); + else if (isAssumedWriteOnly()) + ME = MemoryEffects::writeOnly(); + + // Intersect with existing memory attribute, as we currently deduce the + // location and modref portion separately. + MemoryEffects ExistingME = CB.getMemoryEffects(); + ME &= ExistingME; + if (ME == ExistingME) + return ChangeStatus::UNCHANGED; + + return IRAttributeManifest::manifestAttrs( + A, getIRPosition(), + Attribute::getWithMemoryEffects(CB.getContext(), ME), + /*ForceReplace*/ true); + } + /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { if (isAssumedReadNone()) @@ -7539,36 +8278,54 @@ struct AAMemoryLocationImpl : public AAMemoryLocation { // unlikely this will cause real performance problems. If we are deriving // attributes for the anchor function we even remove the attribute in // addition to ignoring it. + // TODO: A better way to handle this would be to add ~NO_GLOBAL_MEM / + // MemoryEffects::Other as a possible location. 
bool UseArgMemOnly = true; Function *AnchorFn = IRP.getAnchorScope(); if (AnchorFn && A.isRunOn(*AnchorFn)) UseArgMemOnly = !AnchorFn->hasLocalLinkage(); SmallVector<Attribute, 2> Attrs; - IRP.getAttrs(AttrKinds, Attrs, IgnoreSubsumingPositions); + IRP.getAttrs({Attribute::Memory}, Attrs, IgnoreSubsumingPositions); for (const Attribute &Attr : Attrs) { - switch (Attr.getKindAsEnum()) { - case Attribute::ReadNone: + // TODO: We can map MemoryEffects to Attributor locations more precisely. + MemoryEffects ME = Attr.getMemoryEffects(); + if (ME.doesNotAccessMemory()) { State.addKnownBits(NO_LOCAL_MEM | NO_CONST_MEM); - break; - case Attribute::InaccessibleMemOnly: + continue; + } + if (ME.onlyAccessesInaccessibleMem()) { State.addKnownBits(inverseLocation(NO_INACCESSIBLE_MEM, true, true)); - break; - case Attribute::ArgMemOnly: + continue; + } + if (ME.onlyAccessesArgPointees()) { if (UseArgMemOnly) State.addKnownBits(inverseLocation(NO_ARGUMENT_MEM, true, true)); - else - IRP.removeAttrs({Attribute::ArgMemOnly}); - break; - case Attribute::InaccessibleMemOrArgMemOnly: + else { + // Remove location information, only keep read/write info. + ME = MemoryEffects(ME.getModRef()); + IRAttributeManifest::manifestAttrs( + A, IRP, + Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(), + ME), + /*ForceReplace*/ true); + } + continue; + } + if (ME.onlyAccessesInaccessibleOrArgMem()) { if (UseArgMemOnly) State.addKnownBits(inverseLocation( NO_INACCESSIBLE_MEM | NO_ARGUMENT_MEM, true, true)); - else - IRP.removeAttrs({Attribute::InaccessibleMemOrArgMemOnly}); - break; - default: - llvm_unreachable("Unexpected attribute!"); + else { + // Remove location information, only keep read/write info. + ME = MemoryEffects(ME.getModRef()); + IRAttributeManifest::manifestAttrs( + A, IRP, + Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(), + ME), + /*ForceReplace*/ true); + } + continue; } } } @@ -7576,41 +8333,53 @@ struct AAMemoryLocationImpl : public AAMemoryLocation { /// See AbstractAttribute::getDeducedAttributes(...). void getDeducedAttributes(LLVMContext &Ctx, SmallVectorImpl<Attribute> &Attrs) const override { + // TODO: We can map Attributor locations to MemoryEffects more precisely. assert(Attrs.size() == 0); - if (isAssumedReadNone()) { - Attrs.push_back(Attribute::get(Ctx, Attribute::ReadNone)); - } else if (getIRPosition().getPositionKind() == IRPosition::IRP_FUNCTION) { - if (isAssumedInaccessibleMemOnly()) - Attrs.push_back(Attribute::get(Ctx, Attribute::InaccessibleMemOnly)); + if (getIRPosition().getPositionKind() == IRPosition::IRP_FUNCTION) { + if (isAssumedReadNone()) + Attrs.push_back( + Attribute::getWithMemoryEffects(Ctx, MemoryEffects::none())); + else if (isAssumedInaccessibleMemOnly()) + Attrs.push_back(Attribute::getWithMemoryEffects( + Ctx, MemoryEffects::inaccessibleMemOnly())); else if (isAssumedArgMemOnly()) - Attrs.push_back(Attribute::get(Ctx, Attribute::ArgMemOnly)); - else if (isAssumedInaccessibleOrArgMemOnly()) Attrs.push_back( - Attribute::get(Ctx, Attribute::InaccessibleMemOrArgMemOnly)); + Attribute::getWithMemoryEffects(Ctx, MemoryEffects::argMemOnly())); + else if (isAssumedInaccessibleOrArgMemOnly()) + Attrs.push_back(Attribute::getWithMemoryEffects( + Ctx, MemoryEffects::inaccessibleOrArgMemOnly())); } assert(Attrs.size() <= 1); } /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { + // TODO: If AAMemoryLocation and AAMemoryBehavior are merged, we could + // provide per-location modref information here. 
const IRPosition &IRP = getIRPosition(); - // Check if we would improve the existing attributes first. - SmallVector<Attribute, 4> DeducedAttrs; + SmallVector<Attribute, 1> DeducedAttrs; getDeducedAttributes(IRP.getAnchorValue().getContext(), DeducedAttrs); - if (llvm::all_of(DeducedAttrs, [&](const Attribute &Attr) { - return IRP.hasAttr(Attr.getKindAsEnum(), - /* IgnoreSubsumingPositions */ true); - })) + if (DeducedAttrs.size() != 1) return ChangeStatus::UNCHANGED; + MemoryEffects ME = DeducedAttrs[0].getMemoryEffects(); + + // Intersect with existing memory attribute, as we currently deduce the + // location and modref portion separately. + SmallVector<Attribute, 1> ExistingAttrs; + IRP.getAttrs({Attribute::Memory}, ExistingAttrs, + /* IgnoreSubsumingPositions */ true); + if (ExistingAttrs.size() == 1) { + MemoryEffects ExistingME = ExistingAttrs[0].getMemoryEffects(); + ME &= ExistingME; + if (ME == ExistingME) + return ChangeStatus::UNCHANGED; + } - // Clear existing attributes. - IRP.removeAttrs(AttrKinds); - if (isAssumedReadNone()) - IRP.removeAttrs(AAMemoryBehaviorImpl::AttrKinds); - - // Use the generic manifest method. - return IRAttribute::manifest(A); + return IRAttributeManifest::manifestAttrs( + A, IRP, + Attribute::getWithMemoryEffects(IRP.getAnchorValue().getContext(), ME), + /*ForceReplace*/ true); } /// See AAMemoryLocation::checkForAllAccessesToMemoryKind(...). @@ -7733,15 +8502,8 @@ protected: /// Used to allocate access sets. BumpPtrAllocator &Allocator; - - /// The set of IR attributes AAMemoryLocation deals with. - static const Attribute::AttrKind AttrKinds[4]; }; -const Attribute::AttrKind AAMemoryLocationImpl::AttrKinds[] = { - Attribute::ReadNone, Attribute::InaccessibleMemOnly, Attribute::ArgMemOnly, - Attribute::InaccessibleMemOrArgMemOnly}; - void AAMemoryLocationImpl::categorizePtrValue( Attributor &A, const Instruction &I, const Value &Ptr, AAMemoryLocation::StateType &State, bool &Changed) { @@ -7749,50 +8511,38 @@ void AAMemoryLocationImpl::categorizePtrValue( << Ptr << " [" << getMemoryLocationsAsStr(State.getAssumed()) << "]\n"); - SmallSetVector<Value *, 8> Objects; - bool UsedAssumedInformation = false; - if (!AA::getAssumedUnderlyingObjects(A, Ptr, Objects, *this, &I, - UsedAssumedInformation, - AA::Intraprocedural)) { - LLVM_DEBUG( - dbgs() << "[AAMemoryLocation] Pointer locations not categorized\n"); - updateStateAndAccessesMap(State, NO_UNKOWN_MEM, &I, nullptr, Changed, - getAccessKindFromInst(&I)); - return; - } - - for (Value *Obj : Objects) { + auto Pred = [&](Value &Obj) { // TODO: recognize the TBAA used for constant accesses. MemoryLocationsKind MLK = NO_LOCATIONS; - if (isa<UndefValue>(Obj)) - continue; - if (isa<Argument>(Obj)) { + if (isa<UndefValue>(&Obj)) + return true; + if (isa<Argument>(&Obj)) { // TODO: For now we do not treat byval arguments as local copies performed // on the call edge, though, we should. To make that happen we need to // teach various passes, e.g., DSE, about the copy effect of a byval. That // would also allow us to mark functions only accessing byval arguments as - // readnone again, atguably their acceses have no effect outside of the + // readnone again, arguably their accesses have no effect outside of the // function, like accesses to allocas. MLK = NO_ARGUMENT_MEM; - } else if (auto *GV = dyn_cast<GlobalValue>(Obj)) { + } else if (auto *GV = dyn_cast<GlobalValue>(&Obj)) { // Reading constant memory is not treated as a read "effect" by the // function attr pass so we won't neither. 
Constants defined by TBAA are // similar. (We know we do not write it because it is constant.) if (auto *GVar = dyn_cast<GlobalVariable>(GV)) if (GVar->isConstant()) - continue; + return true; if (GV->hasLocalLinkage()) MLK = NO_GLOBAL_INTERNAL_MEM; else MLK = NO_GLOBAL_EXTERNAL_MEM; - } else if (isa<ConstantPointerNull>(Obj) && + } else if (isa<ConstantPointerNull>(&Obj) && !NullPointerIsDefined(getAssociatedFunction(), Ptr.getType()->getPointerAddressSpace())) { - continue; - } else if (isa<AllocaInst>(Obj)) { + return true; + } else if (isa<AllocaInst>(&Obj)) { MLK = NO_LOCAL_MEM; - } else if (const auto *CB = dyn_cast<CallBase>(Obj)) { + } else if (const auto *CB = dyn_cast<CallBase>(&Obj)) { const auto &NoAliasAA = A.getAAFor<AANoAlias>( *this, IRPosition::callsite_returned(*CB), DepClassTy::OPTIONAL); if (NoAliasAA.isAssumedNoAlias()) @@ -7805,10 +8555,21 @@ void AAMemoryLocationImpl::categorizePtrValue( assert(MLK != NO_LOCATIONS && "No location specified!"); LLVM_DEBUG(dbgs() << "[AAMemoryLocation] Ptr value can be categorized: " - << *Obj << " -> " << getMemoryLocationsAsStr(MLK) - << "\n"); - updateStateAndAccessesMap(getState(), MLK, &I, Obj, Changed, + << Obj << " -> " << getMemoryLocationsAsStr(MLK) << "\n"); + updateStateAndAccessesMap(getState(), MLK, &I, &Obj, Changed, + getAccessKindFromInst(&I)); + + return true; + }; + + const auto &AA = A.getAAFor<AAUnderlyingObjects>( + *this, IRPosition::value(Ptr), DepClassTy::OPTIONAL); + if (!AA.forallUnderlyingObjects(Pred, AA::Intraprocedural)) { + LLVM_DEBUG( + dbgs() << "[AAMemoryLocation] Pointer locations not categorized\n"); + updateStateAndAccessesMap(State, NO_UNKOWN_MEM, &I, nullptr, Changed, getAccessKindFromInst(&I)); + return; } LLVM_DEBUG( @@ -8363,7 +9124,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { UsedAssumedInformation, AA::Interprocedural); if (!SimplifiedLHS.has_value()) return true; - if (!SimplifiedLHS.value()) + if (!*SimplifiedLHS) return false; LHS = *SimplifiedLHS; @@ -8372,7 +9133,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { UsedAssumedInformation, AA::Interprocedural); if (!SimplifiedRHS.has_value()) return true; - if (!SimplifiedRHS.value()) + if (!*SimplifiedRHS) return false; RHS = *SimplifiedRHS; @@ -8416,7 +9177,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { UsedAssumedInformation, AA::Interprocedural); if (!SimplifiedOpV.has_value()) return true; - if (!SimplifiedOpV.value()) + if (!*SimplifiedOpV) return false; OpV = *SimplifiedOpV; @@ -8446,7 +9207,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { UsedAssumedInformation, AA::Interprocedural); if (!SimplifiedLHS.has_value()) return true; - if (!SimplifiedLHS.value()) + if (!*SimplifiedLHS) return false; LHS = *SimplifiedLHS; @@ -8455,7 +9216,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { UsedAssumedInformation, AA::Interprocedural); if (!SimplifiedRHS.has_value()) return true; - if (!SimplifiedRHS.value()) + if (!*SimplifiedRHS) return false; RHS = *SimplifiedRHS; @@ -8521,7 +9282,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { UsedAssumedInformation, AA::Interprocedural); if (!SimplifiedOpV.has_value()) return true; - if (!SimplifiedOpV.value()) + if (!*SimplifiedOpV) return false; Value *VPtr = *SimplifiedOpV; @@ -8682,11 +9443,15 @@ struct AAPotentialConstantValuesImpl : AAPotentialConstantValues { } bool fillSetWithConstantValues(Attributor &A, const IRPosition &IRP, SetTy &S, - bool &ContainsUndef) { + bool 
&ContainsUndef, bool ForSelf) { SmallVector<AA::ValueAndContext> Values; bool UsedAssumedInformation = false; if (!A.getAssumedSimplifiedValues(IRP, *this, Values, AA::Interprocedural, UsedAssumedInformation)) { + // Avoid recursion when the caller is computing constant values for this + // IRP itself. + if (ForSelf) + return false; if (!IRP.getAssociatedType()->isIntegerTy()) return false; auto &PotentialValuesAA = A.getAAFor<AAPotentialConstantValues>( @@ -8698,15 +9463,21 @@ struct AAPotentialConstantValuesImpl : AAPotentialConstantValues { return true; } + // Copy all the constant values, except UndefValue. ContainsUndef is true + // iff Values contains only UndefValue instances. If there are other known + // constants, then UndefValue is dropped. + ContainsUndef = false; for (auto &It : Values) { - if (isa<UndefValue>(It.getValue())) + if (isa<UndefValue>(It.getValue())) { + ContainsUndef = true; continue; + } auto *CI = dyn_cast<ConstantInt>(It.getValue()); if (!CI) return false; S.insert(CI->getValue()); } - ContainsUndef = S.empty(); + ContainsUndef &= S.empty(); return true; } @@ -8902,9 +9673,9 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { bool LHSContainsUndef = false, RHSContainsUndef = false; SetTy LHSAAPVS, RHSAAPVS; if (!fillSetWithConstantValues(A, IRPosition::value(*LHS), LHSAAPVS, - LHSContainsUndef) || + LHSContainsUndef, /* ForSelf */ false) || !fillSetWithConstantValues(A, IRPosition::value(*RHS), RHSAAPVS, - RHSContainsUndef)) + RHSContainsUndef, /* ForSelf */ false)) return indicatePessimisticFixpoint(); // TODO: make use of undef flag to limit potential values aggressively. @@ -8955,8 +9726,8 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { Value *RHS = SI->getFalseValue(); bool UsedAssumedInformation = false; - Optional<Constant *> C = A.getAssumedConstant(*SI->getCondition(), *this, - UsedAssumedInformation); + std::optional<Constant *> C = A.getAssumedConstant( + *SI->getCondition(), *this, UsedAssumedInformation); // Check if we only need one operand. bool OnlyLeft = false, OnlyRight = false; @@ -8967,12 +9738,14 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { bool LHSContainsUndef = false, RHSContainsUndef = false; SetTy LHSAAPVS, RHSAAPVS; - if (!OnlyRight && !fillSetWithConstantValues(A, IRPosition::value(*LHS), - LHSAAPVS, LHSContainsUndef)) + if (!OnlyRight && + !fillSetWithConstantValues(A, IRPosition::value(*LHS), LHSAAPVS, + LHSContainsUndef, /* ForSelf */ false)) return indicatePessimisticFixpoint(); - if (!OnlyLeft && !fillSetWithConstantValues(A, IRPosition::value(*RHS), - RHSAAPVS, RHSContainsUndef)) + if (!OnlyLeft && + !fillSetWithConstantValues(A, IRPosition::value(*RHS), RHSAAPVS, + RHSContainsUndef, /* ForSelf */ false)) return indicatePessimisticFixpoint(); if (OnlyLeft || OnlyRight) { @@ -8983,7 +9756,7 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { if (Undef) unionAssumedWithUndef(); else { - for (auto &It : *OpAA) + for (const auto &It : *OpAA) unionAssumed(It); } @@ -8991,9 +9764,9 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { // select i1 *, undef , undef => undef unionAssumedWithUndef(); } else { - for (auto &It : LHSAAPVS) + for (const auto &It : LHSAAPVS) unionAssumed(It); - for (auto &It : RHSAAPVS) + for (const auto &It : RHSAAPVS) unionAssumed(It); } return AssumedBefore == getAssumed() ? 
ChangeStatus::UNCHANGED @@ -9011,7 +9784,7 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { bool SrcContainsUndef = false; SetTy SrcPVS; if (!fillSetWithConstantValues(A, IRPosition::value(*Src), SrcPVS, - SrcContainsUndef)) + SrcContainsUndef, /* ForSelf */ false)) return indicatePessimisticFixpoint(); if (SrcContainsUndef) @@ -9034,9 +9807,9 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { bool LHSContainsUndef = false, RHSContainsUndef = false; SetTy LHSAAPVS, RHSAAPVS; if (!fillSetWithConstantValues(A, IRPosition::value(*LHS), LHSAAPVS, - LHSContainsUndef) || + LHSContainsUndef, /* ForSelf */ false) || !fillSetWithConstantValues(A, IRPosition::value(*RHS), RHSAAPVS, - RHSContainsUndef)) + RHSContainsUndef, /* ForSelf */ false)) return indicatePessimisticFixpoint(); const APInt Zero = APInt(LHS->getType()->getIntegerBitWidth(), 0); @@ -9067,6 +9840,23 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { : ChangeStatus::CHANGED; } + ChangeStatus updateWithInstruction(Attributor &A, Instruction *Inst) { + auto AssumedBefore = getAssumed(); + SetTy Incoming; + bool ContainsUndef; + if (!fillSetWithConstantValues(A, IRPosition::value(*Inst), Incoming, + ContainsUndef, /* ForSelf */ true)) + return indicatePessimisticFixpoint(); + if (ContainsUndef) { + unionAssumedWithUndef(); + } else { + for (const auto &It : Incoming) + unionAssumed(It); + } + return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { Value &V = getAssociatedValue(); @@ -9084,6 +9874,9 @@ struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { if (auto *BinOp = dyn_cast<BinaryOperator>(I)) return updateWithBinaryOperator(A, BinOp); + if (isa<PHINode>(I) || isa<LoadInst>(I)) + return updateWithInstruction(A, I); + return indicatePessimisticFixpoint(); } @@ -9410,16 +10203,18 @@ struct AACallEdgesCallSite : public AACallEdgesImpl { CallBase *CB = cast<CallBase>(getCtxI()); - if (CB->isInlineAsm()) { - if (!hasAssumption(*CB->getCaller(), "ompx_no_call_asm") && - !hasAssumption(*CB, "ompx_no_call_asm")) + if (auto *IA = dyn_cast<InlineAsm>(CB->getCalledOperand())) { + if (IA->hasSideEffects() && + !hasAssumption(*CB->getCaller(), "ompx_no_call_asm") && + !hasAssumption(*CB, "ompx_no_call_asm")) { setHasUnknownCallee(false, Change); + } return Change; } // Process callee metadata if available. if (auto *MD = getCtxI()->getMetadata(LLVMContext::MD_callees)) { - for (auto &Op : MD->operands()) { + for (const auto &Op : MD->operands()) { Function *Callee = mdconst::dyn_extract_or_null<Function>(Op); if (Callee) addCalledFunction(Callee, Change); @@ -9478,294 +10273,103 @@ struct AACallEdgesFunction : public AACallEdgesImpl { } }; -struct AAFunctionReachabilityFunction : public AAFunctionReachability { -private: - struct QuerySet { - void markReachable(const Function &Fn) { - Reachable.insert(&Fn); - Unreachable.erase(&Fn); - } +/// -------------------AAInterFnReachability Attribute-------------------------- - /// If there is no information about the function None is returned. - Optional<bool> isCachedReachable(const Function &Fn) { - // Assume that we can reach the function. - // TODO: Be more specific with the unknown callee. 
- if (CanReachUnknownCallee) - return true; +struct AAInterFnReachabilityFunction + : public CachedReachabilityAA<AAInterFnReachability, Function> { + AAInterFnReachabilityFunction(const IRPosition &IRP, Attributor &A) + : CachedReachabilityAA<AAInterFnReachability, Function>(IRP, A) {} - if (Reachable.count(&Fn)) - return true; + bool instructionCanReach( + Attributor &A, const Instruction &From, const Function &To, + const AA::InstExclusionSetTy *ExclusionSet, + SmallPtrSet<const Function *, 16> *Visited) const override { + assert(From.getFunction() == getAnchorScope() && "Queried the wrong AA!"); + auto *NonConstThis = const_cast<AAInterFnReachabilityFunction *>(this); - if (Unreachable.count(&Fn)) - return false; - - return llvm::None; - } - - /// Set of functions that we know for sure is reachable. - DenseSet<const Function *> Reachable; - - /// Set of functions that are unreachable, but might become reachable. - DenseSet<const Function *> Unreachable; - - /// If we can reach a function with a call to a unknown function we assume - /// that we can reach any function. - bool CanReachUnknownCallee = false; - }; - - struct QueryResolver : public QuerySet { - ChangeStatus update(Attributor &A, const AAFunctionReachability &AA, - ArrayRef<const AACallEdges *> AAEdgesList) { - ChangeStatus Change = ChangeStatus::UNCHANGED; - - for (auto *AAEdges : AAEdgesList) { - if (AAEdges->hasUnknownCallee()) { - if (!CanReachUnknownCallee) { - LLVM_DEBUG(dbgs() - << "[QueryResolver] Edges include unknown callee!\n"); - Change = ChangeStatus::CHANGED; - } - CanReachUnknownCallee = true; - return Change; - } - } - - for (const Function *Fn : make_early_inc_range(Unreachable)) { - if (checkIfReachable(A, AA, AAEdgesList, *Fn)) { - Change = ChangeStatus::CHANGED; - markReachable(*Fn); - } - } - return Change; - } - - bool isReachable(Attributor &A, AAFunctionReachability &AA, - ArrayRef<const AACallEdges *> AAEdgesList, - const Function &Fn) { - Optional<bool> Cached = isCachedReachable(Fn); - if (Cached) - return Cached.value(); - - // The query was not cached, thus it is new. We need to request an update - // explicitly to make sure this the information is properly run to a - // fixpoint. - A.registerForUpdate(AA); - - // We need to assume that this function can't reach Fn to prevent - // an infinite loop if this function is recursive. - Unreachable.insert(&Fn); - - bool Result = checkIfReachable(A, AA, AAEdgesList, Fn); - if (Result) - markReachable(Fn); - return Result; - } - - bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA, - ArrayRef<const AACallEdges *> AAEdgesList, - const Function &Fn) const { - - // Handle the most trivial case first. - for (auto *AAEdges : AAEdgesList) { - const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges(); - - if (Edges.count(const_cast<Function *>(&Fn))) - return true; - } - - SmallVector<const AAFunctionReachability *, 8> Deps; - for (auto &AAEdges : AAEdgesList) { - const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges(); - - for (Function *Edge : Edges) { - // Functions that do not call back into the module can be ignored. - if (Edge->hasFnAttribute(Attribute::NoCallback)) - continue; + RQITy StackRQI(A, From, To, ExclusionSet); + typename RQITy::Reachable Result; + if (RQITy *RQIPtr = NonConstThis->checkQueryCache(A, StackRQI, Result)) + return NonConstThis->isReachableImpl(A, *RQIPtr); + return Result == RQITy::Reachable::Yes; + } - // We don't need a dependency if the result is reachable. 
- const AAFunctionReachability &EdgeReachability = - A.getAAFor<AAFunctionReachability>( - AA, IRPosition::function(*Edge), DepClassTy::NONE); - Deps.push_back(&EdgeReachability); + bool isReachableImpl(Attributor &A, RQITy &RQI) override { + return isReachableImpl(A, RQI, nullptr); + } - if (EdgeReachability.canReach(A, Fn)) - return true; - } - } + bool isReachableImpl(Attributor &A, RQITy &RQI, + SmallPtrSet<const Function *, 16> *Visited) { - // The result is false for now, set dependencies and leave. - for (auto *Dep : Deps) - A.recordDependence(*Dep, AA, DepClassTy::REQUIRED); + SmallPtrSet<const Function *, 16> LocalVisited; + if (!Visited) + Visited = &LocalVisited; - return false; - } - }; + const auto &IntraFnReachability = A.getAAFor<AAIntraFnReachability>( + *this, IRPosition::function(*RQI.From->getFunction()), + DepClassTy::OPTIONAL); - /// Get call edges that can be reached by this instruction. - bool getReachableCallEdges(Attributor &A, const AAReachability &Reachability, - const Instruction &Inst, - SmallVector<const AACallEdges *> &Result) const { // Determine call like instructions that we can reach from the inst. + SmallVector<CallBase *> ReachableCallBases; auto CheckCallBase = [&](Instruction &CBInst) { - if (!Reachability.isAssumedReachable(A, Inst, CBInst)) - return true; - - auto &CB = cast<CallBase>(CBInst); - const AACallEdges &AAEdges = A.getAAFor<AACallEdges>( - *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED); - - Result.push_back(&AAEdges); + if (IntraFnReachability.isAssumedReachable(A, *RQI.From, CBInst, + RQI.ExclusionSet)) + ReachableCallBases.push_back(cast<CallBase>(&CBInst)); return true; }; bool UsedAssumedInformation = false; - return A.checkForAllCallLikeInstructions(CheckCallBase, *this, - UsedAssumedInformation, - /* CheckBBLivenessOnly */ true); - } - -public: - AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A) - : AAFunctionReachability(IRP, A) {} - - bool canReach(Attributor &A, const Function &Fn) const override { - if (!isValidState()) - return true; - - const AACallEdges &AAEdges = - A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED); - - // Attributor returns attributes as const, so this function has to be - // const for users of this attribute to use it without having to do - // a const_cast. - // This is a hack for us to be able to cache queries. - auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this); - bool Result = NonConstThis->WholeFunction.isReachable(A, *NonConstThis, - {&AAEdges}, Fn); - - return Result; - } - - /// Can \p CB reach \p Fn - bool canReach(Attributor &A, CallBase &CB, - const Function &Fn) const override { - if (!isValidState()) - return true; - - const AACallEdges &AAEdges = A.getAAFor<AACallEdges>( - *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED); - - // Attributor returns attributes as const, so this function has to be - // const for users of this attribute to use it without having to do - // a const_cast. - // This is a hack for us to be able to cache queries. 
- auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this); - QueryResolver &CBQuery = NonConstThis->CBQueries[&CB]; - - bool Result = CBQuery.isReachable(A, *NonConstThis, {&AAEdges}, Fn); - - return Result; - } - - bool instructionCanReach(Attributor &A, const Instruction &Inst, - const Function &Fn) const override { - if (!isValidState()) - return true; - - const auto &Reachability = A.getAAFor<AAReachability>( - *this, IRPosition::function(*getAssociatedFunction()), - DepClassTy::REQUIRED); - - SmallVector<const AACallEdges *> CallEdges; - bool AllKnown = getReachableCallEdges(A, Reachability, Inst, CallEdges); - // Attributor returns attributes as const, so this function has to be - // const for users of this attribute to use it without having to do - // a const_cast. - // This is a hack for us to be able to cache queries. - auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this); - QueryResolver &InstQSet = NonConstThis->InstQueries[&Inst]; - if (!AllKnown) { - LLVM_DEBUG(dbgs() << "[AAReachability] Not all reachable edges known, " - "may reach unknown callee!\n"); - InstQSet.CanReachUnknownCallee = true; - } - - return InstQSet.isReachable(A, *NonConstThis, CallEdges, Fn); - } - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override { - const AACallEdges &AAEdges = - A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED); - ChangeStatus Change = ChangeStatus::UNCHANGED; - - Change |= WholeFunction.update(A, *this, {&AAEdges}); + if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this, + UsedAssumedInformation, + /* CheckBBLivenessOnly */ true)) + return rememberResult(A, RQITy::Reachable::Yes, RQI); - for (auto &CBPair : CBQueries) { - const AACallEdges &AAEdges = A.getAAFor<AACallEdges>( - *this, IRPosition::callsite_function(*CBPair.first), - DepClassTy::REQUIRED); + for (CallBase *CB : ReachableCallBases) { + auto &CBEdges = A.getAAFor<AACallEdges>( + *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL); + if (!CBEdges.getState().isValidState()) + return rememberResult(A, RQITy::Reachable::Yes, RQI); + // TODO Check To backwards in this case. + if (CBEdges.hasUnknownCallee()) + return rememberResult(A, RQITy::Reachable::Yes, RQI); - Change |= CBPair.second.update(A, *this, {&AAEdges}); - } + for (Function *Fn : CBEdges.getOptimisticEdges()) { + if (Fn == RQI.To) + return rememberResult(A, RQITy::Reachable::Yes, RQI); + if (!Visited->insert(Fn).second) + continue; + if (Fn->isDeclaration()) { + if (Fn->hasFnAttribute(Attribute::NoCallback)) + continue; + // TODO Check To backwards in this case. + return rememberResult(A, RQITy::Reachable::Yes, RQI); + } - // Update the Instruction queries. - if (!InstQueries.empty()) { - const AAReachability *Reachability = &A.getAAFor<AAReachability>( - *this, IRPosition::function(*getAssociatedFunction()), - DepClassTy::REQUIRED); + const AAInterFnReachability *InterFnReachability = this; + if (Fn != getAnchorScope()) + InterFnReachability = &A.getAAFor<AAInterFnReachability>( + *this, IRPosition::function(*Fn), DepClassTy::OPTIONAL); - // Check for local callbases first. - for (auto &InstPair : InstQueries) { - SmallVector<const AACallEdges *> CallEdges; - bool AllKnown = - getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges); - // Update will return change if we this effects any queries. 
- if (!AllKnown) { - LLVM_DEBUG(dbgs() << "[AAReachability] Not all reachable edges " - "known, may reach unknown callee!\n"); - InstPair.second.CanReachUnknownCallee = true; - } - Change |= InstPair.second.update(A, *this, CallEdges); + const Instruction &FnFirstInst = Fn->getEntryBlock().front(); + if (InterFnReachability->instructionCanReach(A, FnFirstInst, *RQI.To, + RQI.ExclusionSet, Visited)) + return rememberResult(A, RQITy::Reachable::Yes, RQI); } } - return Change; - } - - const std::string getAsStr() const override { - size_t QueryCount = - WholeFunction.Reachable.size() + WholeFunction.Unreachable.size(); - - return "FunctionReachability [" + - (canReachUnknownCallee() - ? "unknown" - : (std::to_string(WholeFunction.Reachable.size()) + "," + - std::to_string(QueryCount))) + - "]"; + return rememberResult(A, RQITy::Reachable::No, RQI); } void trackStatistics() const override {} private: - bool canReachUnknownCallee() const override { - return WholeFunction.CanReachUnknownCallee; - } - - /// Used to answer if a the whole function can reacha a specific function. - QueryResolver WholeFunction; - - /// Used to answer if a call base inside this function can reach a specific - /// function. - MapVector<const CallBase *, QueryResolver> CBQueries; - - /// This is for instruction queries than scan "forward". - MapVector<const Instruction *, QueryResolver> InstQueries; + SmallVector<RQITy *> QueryVector; + DenseSet<RQITy *> QueryCache; }; } // namespace template <typename AAType> -static Optional<Constant *> +static std::optional<Constant *> askForAssumedConstant(Attributor &A, const AbstractAttribute &QueryingAA, const IRPosition &IRP, Type &Ty) { if (!Ty.isIntegerTy()) @@ -9774,13 +10378,13 @@ askForAssumedConstant(Attributor &A, const AbstractAttribute &QueryingAA, // This will also pass the call base context. 
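The AAInterFnReachabilityFunction added above replaces the old per-query QueryResolver machinery with a cached reachability query (RQITy plus checkQueryCache/rememberResult) that expands through call edges, avoids revisiting functions, and answers "reachable" conservatively when it meets an unknown callee. The following standalone sketch shows the same memoize-then-walk-the-call-graph idea in plain C++; the names (ReachabilityOracle, FunctionNode) are invented and nothing here is an LLVM API.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct FunctionNode {
  std::vector<std::string> Callees; // optimistic call edges
  bool HasUnknownCallee = false;    // e.g. an indirect call or side-effecting inline asm
};

struct ReachabilityOracle {
  std::map<std::string, FunctionNode> Graph;
  std::map<std::pair<std::string, std::string>, bool> Cache; // memoized query results

  bool canReach(const std::string &From, const std::string &To) {
    auto Key = std::make_pair(From, To);
    if (auto It = Cache.find(Key); It != Cache.end())
      return It->second; // answered before, reuse the result
    bool Result = false;
    std::set<std::string> Visited{From};
    std::vector<std::string> Worklist{From};
    while (!Worklist.empty() && !Result) {
      FunctionNode &Node = Graph[Worklist.back()];
      Worklist.pop_back();
      if (Node.HasUnknownCallee) {
        Result = true; // could call anything, conservatively assume reachable
        continue;
      }
      for (const std::string &Callee : Node.Callees) {
        if (Callee == To) {
          Result = true;
          break;
        }
        if (Visited.insert(Callee).second) // avoid revisiting (handles recursion)
          Worklist.push_back(Callee);
      }
    }
    return Cache[Key] = Result;
  }
};

int main() {
  ReachabilityOracle O;
  O.Graph["main"].Callees = {"helper"};
  O.Graph["helper"].Callees = {"leaf"};
  std::cout << O.canReach("main", "leaf") << "\n"; // 1
  std::cout << O.canReach("leaf", "main") << "\n"; // 0
}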
const auto &AA = A.getAAFor<AAType>(QueryingAA, IRP, DepClassTy::NONE); - Optional<Constant *> COpt = AA.getAssumedConstant(A); + std::optional<Constant *> COpt = AA.getAssumedConstant(A); if (!COpt.has_value()) { A.recordDependence(AA, QueryingAA, DepClassTy::OPTIONAL); - return llvm::None; + return std::nullopt; } - if (auto *C = COpt.value()) { + if (auto *C = *COpt) { A.recordDependence(AA, QueryingAA, DepClassTy::OPTIONAL); return C; } @@ -9791,15 +10395,15 @@ Value *AAPotentialValues::getSingleValue( Attributor &A, const AbstractAttribute &AA, const IRPosition &IRP, SmallVectorImpl<AA::ValueAndContext> &Values) { Type &Ty = *IRP.getAssociatedType(); - Optional<Value *> V; + std::optional<Value *> V; for (auto &It : Values) { V = AA::combineOptionalValuesInAAValueLatice(V, It.getValue(), &Ty); - if (V.has_value() && !V.value()) + if (V.has_value() && !*V) break; } if (!V.has_value()) return UndefValue::get(&Ty); - return V.value(); + return *V; } namespace { @@ -9816,7 +10420,9 @@ struct AAPotentialValuesImpl : AAPotentialValues { return; } Value *Stripped = getAssociatedValue().stripPointerCasts(); - if (isa<Constant>(Stripped)) { + auto *CE = dyn_cast<ConstantExpr>(Stripped); + if (isa<Constant>(Stripped) && + (!CE || CE->getOpcode() != Instruction::ICmp)) { addValue(A, getState(), *Stripped, getCtxI(), AA::AnyScope, getAnchorScope()); indicateOptimisticFixpoint(); @@ -9834,15 +10440,15 @@ struct AAPotentialValuesImpl : AAPotentialValues { } template <typename AAType> - static Optional<Value *> askOtherAA(Attributor &A, - const AbstractAttribute &AA, - const IRPosition &IRP, Type &Ty) { + static std::optional<Value *> askOtherAA(Attributor &A, + const AbstractAttribute &AA, + const IRPosition &IRP, Type &Ty) { if (isa<Constant>(IRP.getAssociatedValue())) return &IRP.getAssociatedValue(); - Optional<Constant *> C = askForAssumedConstant<AAType>(A, AA, IRP, Ty); + std::optional<Constant *> C = askForAssumedConstant<AAType>(A, AA, IRP, Ty); if (!C) - return llvm::None; - if (C.value()) + return std::nullopt; + if (*C) if (auto *CC = AA::getWithType(**C, Ty)) return CC; return nullptr; @@ -9854,7 +10460,7 @@ struct AAPotentialValuesImpl : AAPotentialValues { IRPosition ValIRP = IRPosition::value(V); if (auto *CB = dyn_cast_or_null<CallBase>(CtxI)) { - for (auto &U : CB->args()) { + for (const auto &U : CB->args()) { if (U.get() != &V) continue; ValIRP = IRPosition::callsite_argument(*CB, CB->getArgOperandNo(&U)); @@ -9865,25 +10471,24 @@ struct AAPotentialValuesImpl : AAPotentialValues { Value *VPtr = &V; if (ValIRP.getAssociatedType()->isIntegerTy()) { Type &Ty = *getAssociatedType(); - Optional<Value *> SimpleV = + std::optional<Value *> SimpleV = askOtherAA<AAValueConstantRange>(A, *this, ValIRP, Ty); - if (SimpleV.has_value() && !SimpleV.value()) { + if (SimpleV.has_value() && !*SimpleV) { auto &PotentialConstantsAA = A.getAAFor<AAPotentialConstantValues>( *this, ValIRP, DepClassTy::OPTIONAL); if (PotentialConstantsAA.isValidState()) { - for (auto &It : PotentialConstantsAA.getAssumedSet()) { + for (const auto &It : PotentialConstantsAA.getAssumedSet()) State.unionAssumed({{*ConstantInt::get(&Ty, It), nullptr}, S}); - } - assert(!PotentialConstantsAA.undefIsContained() && - "Undef should be an explicit value!"); + if (PotentialConstantsAA.undefIsContained()) + State.unionAssumed({{*UndefValue::get(&Ty), nullptr}, S}); return; } } if (!SimpleV.has_value()) return; - if (SimpleV.value()) - VPtr = SimpleV.value(); + if (*SimpleV) + VPtr = *SimpleV; } if (isa<ConstantInt>(VPtr)) @@ 
-9899,6 +10504,15 @@ struct AAPotentialValuesImpl : AAPotentialValues { struct ItemInfo { AA::ValueAndContext I; AA::ValueScope S; + + bool operator==(const ItemInfo &II) const { + return II.I == I && II.S == S; + }; + bool operator<(const ItemInfo &II) const { + if (I == II.I) + return S < II.S; + return I < II.I; + }; }; bool recurseForValue(Attributor &A, const IRPosition &IRP, AA::ValueScope S) { @@ -9925,7 +10539,7 @@ struct AAPotentialValuesImpl : AAPotentialValues { void giveUpOnIntraprocedural(Attributor &A) { auto NewS = StateType::getBestState(getState()); - for (auto &It : getAssumedSet()) { + for (const auto &It : getAssumedSet()) { if (It.second == AA::Intraprocedural) continue; addValue(A, NewS, *It.first.getValue(), It.first.getCtxI(), @@ -9977,7 +10591,7 @@ struct AAPotentialValuesImpl : AAPotentialValues { AA::ValueScope S) const override { if (!isValidState()) return false; - for (auto &It : getAssumedSet()) + for (const auto &It : getAssumedSet()) if (It.second & S) Values.push_back(It.first); assert(!undefIsContained() && "Undef should be an explicit value!"); @@ -10010,10 +10624,9 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { /// We handle multiple cases, one in which at least one operand is an /// (assumed) nullptr. If so, try to simplify it using AANonNull on the other /// operand. Return true if successful, in that case Worklist will be updated. - bool handleCmp(Attributor &A, CmpInst &Cmp, ItemInfo II, + bool handleCmp(Attributor &A, Value &Cmp, Value *LHS, Value *RHS, + CmpInst::Predicate Pred, ItemInfo II, SmallVectorImpl<ItemInfo> &Worklist) { - Value *LHS = Cmp.getOperand(0); - Value *RHS = Cmp.getOperand(1); // Simplify the operands first. bool UsedAssumedInformation = false; @@ -10022,7 +10635,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { UsedAssumedInformation, AA::Intraprocedural); if (!SimplifiedLHS.has_value()) return true; - if (!SimplifiedLHS.value()) + if (!*SimplifiedLHS) return false; LHS = *SimplifiedLHS; @@ -10031,24 +10644,24 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { UsedAssumedInformation, AA::Intraprocedural); if (!SimplifiedRHS.has_value()) return true; - if (!SimplifiedRHS.value()) + if (!*SimplifiedRHS) return false; RHS = *SimplifiedRHS; - LLVMContext &Ctx = Cmp.getContext(); + LLVMContext &Ctx = LHS->getContext(); // Handle the trivial case first in which we don't even need to think about // null or non-null. - if (LHS == RHS && (Cmp.isTrueWhenEqual() || Cmp.isFalseWhenEqual())) { - Constant *NewV = - ConstantInt::get(Type::getInt1Ty(Ctx), Cmp.isTrueWhenEqual()); + if (LHS == RHS && + (CmpInst::isTrueWhenEqual(Pred) || CmpInst::isFalseWhenEqual(Pred))) { + Constant *NewV = ConstantInt::get(Type::getInt1Ty(Ctx), + CmpInst::isTrueWhenEqual(Pred)); addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, getAnchorScope()); return true; } // From now on we only handle equalities (==, !=). - ICmpInst *ICmp = dyn_cast<ICmpInst>(&Cmp); - if (!ICmp || !ICmp->isEquality()) + if (!CmpInst::isEquality(Pred)) return false; bool LHSIsNull = isa<ConstantPointerNull>(LHS); @@ -10065,14 +10678,13 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { // The index is the operand that we assume is not null. unsigned PtrIdx = LHSIsNull; auto &PtrNonNullAA = A.getAAFor<AANonNull>( - *this, IRPosition::value(*ICmp->getOperand(PtrIdx)), - DepClassTy::REQUIRED); + *this, IRPosition::value(*(PtrIdx ? 
RHS : LHS)), DepClassTy::REQUIRED); if (!PtrNonNullAA.isAssumedNonNull()) return false; // The new value depends on the predicate, true for != and false for ==. - Constant *NewV = ConstantInt::get(Type::getInt1Ty(Ctx), - ICmp->getPredicate() == CmpInst::ICMP_NE); + Constant *NewV = + ConstantInt::get(Type::getInt1Ty(Ctx), Pred == CmpInst::ICMP_NE); addValue(A, getState(), *NewV, /* CtxI */ nullptr, II.S, getAnchorScope()); return true; } @@ -10082,7 +10694,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const Instruction *CtxI = II.I.getCtxI(); bool UsedAssumedInformation = false; - Optional<Constant *> C = + std::optional<Constant *> C = A.getAssumedConstant(*SI.getCondition(), *this, UsedAssumedInformation); bool NoValueYet = !C.has_value(); if (NoValueYet || isa_and_nonnull<UndefValue>(*C)) @@ -10092,10 +10704,20 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { Worklist.push_back({{*SI.getFalseValue(), CtxI}, II.S}); else Worklist.push_back({{*SI.getTrueValue(), CtxI}, II.S}); - } else { + } else if (&SI == &getAssociatedValue()) { // We could not simplify the condition, assume both values. Worklist.push_back({{*SI.getTrueValue(), CtxI}, II.S}); Worklist.push_back({{*SI.getFalseValue(), CtxI}, II.S}); + } else { + std::optional<Value *> SimpleV = A.getAssumedSimplified( + IRPosition::inst(SI), *this, UsedAssumedInformation, II.S); + if (!SimpleV.has_value()) + return true; + if (*SimpleV) { + addValue(A, getState(), **SimpleV, CtxI, II.S, getAnchorScope()); + return true; + } + return false; } return true; } @@ -10180,16 +10802,28 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { return LI; }; - LivenessInfo &LI = GetLivenessInfo(*PHI.getFunction()); - for (unsigned u = 0, e = PHI.getNumIncomingValues(); u < e; u++) { - BasicBlock *IncomingBB = PHI.getIncomingBlock(u); - if (LI.LivenessAA->isEdgeDead(IncomingBB, PHI.getParent())) { - LI.AnyDead = true; - continue; + if (&PHI == &getAssociatedValue()) { + LivenessInfo &LI = GetLivenessInfo(*PHI.getFunction()); + for (unsigned u = 0, e = PHI.getNumIncomingValues(); u < e; u++) { + BasicBlock *IncomingBB = PHI.getIncomingBlock(u); + if (LI.LivenessAA->isEdgeDead(IncomingBB, PHI.getParent())) { + LI.AnyDead = true; + continue; + } + Worklist.push_back( + {{*PHI.getIncomingValue(u), IncomingBB->getTerminator()}, II.S}); } - Worklist.push_back( - {{*PHI.getIncomingValue(u), IncomingBB->getTerminator()}, II.S}); + return true; } + + bool UsedAssumedInformation = false; + std::optional<Value *> SimpleV = A.getAssumedSimplified( + IRPosition::inst(PHI), *this, UsedAssumedInformation, II.S); + if (!SimpleV.has_value()) + return true; + if (!(*SimpleV)) + return false; + addValue(A, getState(), **SimpleV, &PHI, II.S, getAnchorScope()); return true; } @@ -10212,8 +10846,8 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { if (!SimplifiedOp.has_value()) return true; - if (SimplifiedOp.value()) - NewOps[Idx] = SimplifiedOp.value(); + if (*SimplifiedOp) + NewOps[Idx] = *SimplifiedOp; else NewOps[Idx] = Op; @@ -10251,7 +10885,8 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { SmallVectorImpl<ItemInfo> &Worklist, SmallMapVector<const Function *, LivenessInfo, 4> &LivenessAAs) { if (auto *CI = dyn_cast<CmpInst>(&I)) - if (handleCmp(A, *CI, II, Worklist)) + if (handleCmp(A, *CI, CI->getOperand(0), CI->getOperand(1), + CI->getPredicate(), II, Worklist)) return true; switch (I.getOpcode()) { @@ -10271,7 +10906,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { SmallMapVector<const 
Function *, LivenessInfo, 4> LivenessAAs; Value *InitialV = &getAssociatedValue(); - SmallSet<AA::ValueAndContext, 16> Visited; + SmallSet<ItemInfo, 16> Visited; SmallVector<ItemInfo, 16> Worklist; Worklist.push_back({{*InitialV, getCtxI()}, AA::AnyScope}); @@ -10285,7 +10920,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { // Check if we should process the current value. To prevent endless // recursion keep a record of the values we followed! - if (!Visited.insert(II.I).second) + if (!Visited.insert(II).second) continue; // Make sure we limit the compile time for complex expressions. @@ -10316,6 +10951,13 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { continue; } + if (auto *CE = dyn_cast<ConstantExpr>(V)) { + if (CE->getOpcode() == Instruction::ICmp) + if (handleCmp(A, *CE, CE->getOperand(0), CE->getOperand(1), + CmpInst::Predicate(CE->getPredicate()), II, Worklist)) + continue; + } + if (auto *I = dyn_cast<Instruction>(V)) { if (simplifyInstruction(A, *I, II, Worklist, LivenessAAs)) continue; @@ -10406,8 +11048,7 @@ struct AAPotentialValuesArgument final : AAPotentialValuesImpl { getAnchorScope()); AnyNonLocal = true; } - if (undefIsContained()) - unionAssumedWithUndef(); + assert(!undefIsContained() && "Undef should be an explicit value!"); if (AnyNonLocal) giveUpOnIntraprocedural(A); @@ -10508,16 +11149,16 @@ struct AAPotentialValuesCallSiteReturned : AAPotentialValuesImpl { bool AnyNonLocal = false; for (auto &It : Values) { Value *V = It.getValue(); - Optional<Value *> CallerV = A.translateArgumentToCallSiteContent( + std::optional<Value *> CallerV = A.translateArgumentToCallSiteContent( V, *CB, *this, UsedAssumedInformation); if (!CallerV.has_value()) { // Nothing to do as long as no value was determined. continue; } - V = CallerV.value() ? CallerV.value() : V; + V = *CallerV ? *CallerV : V; if (AA::isDynamicallyUnique(A, *this, *V) && AA::isValidInScope(*V, Caller)) { - if (CallerV.value()) { + if (*CallerV) { SmallVector<AA::ValueAndContext> ArgValues; IRPosition IRP = IRPosition::value(*V); if (auto *Arg = dyn_cast<Argument>(V)) @@ -10708,7 +11349,7 @@ private: DenseSet<StringRef> getInitialAssumptions(const IRPosition &IRP) { const CallBase &CB = cast<CallBase>(IRP.getAssociatedValue()); auto Assumptions = getAssumptions(CB); - if (Function *F = IRP.getAssociatedFunction()) + if (const Function *F = CB.getCaller()) set_union(Assumptions, getAssumptions(*F)); if (Function *F = IRP.getAssociatedFunction()) set_union(Assumptions, getAssumptions(*F)); @@ -10724,6 +11365,159 @@ AACallGraphNode *AACallEdgeIterator::operator*() const { void AttributorCallGraph::print() { llvm::WriteGraph(outs(), this); } +/// ------------------------ UnderlyingObjects --------------------------------- + +namespace { +struct AAUnderlyingObjectsImpl + : StateWrapper<BooleanState, AAUnderlyingObjects> { + using BaseTy = StateWrapper<BooleanState, AAUnderlyingObjects>; + AAUnderlyingObjectsImpl(const IRPosition &IRP, Attributor &A) : BaseTy(IRP) {} + + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + return std::string("UnderlyingObjects ") + + (isValidState() + ? (std::string("inter #") + + std::to_string(InterAssumedUnderlyingObjects.size()) + + " objs" + std::string(", intra #") + + std::to_string(IntraAssumedUnderlyingObjects.size()) + + " objs") + : "<invalid>"); + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + auto &Ptr = getAssociatedValue(); + + auto DoUpdate = [&](SmallSetVector<Value *, 8> &UnderlyingObjects, + AA::ValueScope Scope) { + bool UsedAssumedInformation = false; + SmallPtrSet<Value *, 8> SeenObjects; + SmallVector<AA::ValueAndContext> Values; + + if (!A.getAssumedSimplifiedValues(IRPosition::value(Ptr), *this, Values, + Scope, UsedAssumedInformation)) + return UnderlyingObjects.insert(&Ptr); + + bool Changed = false; + + for (unsigned I = 0; I < Values.size(); ++I) { + auto &VAC = Values[I]; + auto *Obj = VAC.getValue(); + Value *UO = getUnderlyingObject(Obj); + if (UO && UO != VAC.getValue() && SeenObjects.insert(UO).second) { + const auto &OtherAA = A.getAAFor<AAUnderlyingObjects>( + *this, IRPosition::value(*UO), DepClassTy::OPTIONAL); + auto Pred = [&Values](Value &V) { + Values.emplace_back(V, nullptr); + return true; + }; + + if (!OtherAA.forallUnderlyingObjects(Pred, Scope)) + llvm_unreachable( + "The forall call should not return false at this position"); + + continue; + } + + if (isa<SelectInst>(Obj) || isa<PHINode>(Obj)) { + Changed |= handleIndirect(A, *Obj, UnderlyingObjects, Scope); + continue; + } + + Changed |= UnderlyingObjects.insert(Obj); + } + + return Changed; + }; + + bool Changed = false; + Changed |= DoUpdate(IntraAssumedUnderlyingObjects, AA::Intraprocedural); + Changed |= DoUpdate(InterAssumedUnderlyingObjects, AA::Interprocedural); + + return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; + } + + bool forallUnderlyingObjects( + function_ref<bool(Value &)> Pred, + AA::ValueScope Scope = AA::Interprocedural) const override { + if (!isValidState()) + return Pred(getAssociatedValue()); + + auto &AssumedUnderlyingObjects = Scope == AA::Intraprocedural + ? IntraAssumedUnderlyingObjects + : InterAssumedUnderlyingObjects; + for (Value *Obj : AssumedUnderlyingObjects) + if (!Pred(*Obj)) + return false; + + return true; + } + +private: + /// Handle the case where the value is not the actual underlying value, such + /// as a phi node or a select instruction. + bool handleIndirect(Attributor &A, Value &V, + SmallSetVector<Value *, 8> &UnderlyingObjects, + AA::ValueScope Scope) { + bool Changed = false; + const auto &AA = A.getAAFor<AAUnderlyingObjects>( + *this, IRPosition::value(V), DepClassTy::OPTIONAL); + auto Pred = [&](Value &V) { + Changed |= UnderlyingObjects.insert(&V); + return true; + }; + if (!AA.forallUnderlyingObjects(Pred, Scope)) + llvm_unreachable( + "The forall call should not return false at this position"); + return Changed; + } + + /// All the underlying objects collected so far via intra procedural scope. + SmallSetVector<Value *, 8> IntraAssumedUnderlyingObjects; + /// All the underlying objects collected so far via inter procedural scope. 
+ SmallSetVector<Value *, 8> InterAssumedUnderlyingObjects; +}; + +struct AAUnderlyingObjectsFloating final : AAUnderlyingObjectsImpl { + AAUnderlyingObjectsFloating(const IRPosition &IRP, Attributor &A) + : AAUnderlyingObjectsImpl(IRP, A) {} +}; + +struct AAUnderlyingObjectsArgument final : AAUnderlyingObjectsImpl { + AAUnderlyingObjectsArgument(const IRPosition &IRP, Attributor &A) + : AAUnderlyingObjectsImpl(IRP, A) {} +}; + +struct AAUnderlyingObjectsCallSite final : AAUnderlyingObjectsImpl { + AAUnderlyingObjectsCallSite(const IRPosition &IRP, Attributor &A) + : AAUnderlyingObjectsImpl(IRP, A) {} +}; + +struct AAUnderlyingObjectsCallSiteArgument final : AAUnderlyingObjectsImpl { + AAUnderlyingObjectsCallSiteArgument(const IRPosition &IRP, Attributor &A) + : AAUnderlyingObjectsImpl(IRP, A) {} +}; + +struct AAUnderlyingObjectsReturned final : AAUnderlyingObjectsImpl { + AAUnderlyingObjectsReturned(const IRPosition &IRP, Attributor &A) + : AAUnderlyingObjectsImpl(IRP, A) {} +}; + +struct AAUnderlyingObjectsCallSiteReturned final : AAUnderlyingObjectsImpl { + AAUnderlyingObjectsCallSiteReturned(const IRPosition &IRP, Attributor &A) + : AAUnderlyingObjectsImpl(IRP, A) {} +}; + +struct AAUnderlyingObjectsFunction final : AAUnderlyingObjectsImpl { + AAUnderlyingObjectsFunction(const IRPosition &IRP, Attributor &A) + : AAUnderlyingObjectsImpl(IRP, A) {} +}; +} + const char AAReturnedValues::ID = 0; const char AANoUnwind::ID = 0; const char AANoSync::ID = 0; @@ -10733,7 +11527,7 @@ const char AANoRecurse::ID = 0; const char AAWillReturn::ID = 0; const char AAUndefinedBehavior::ID = 0; const char AANoAlias::ID = 0; -const char AAReachability::ID = 0; +const char AAIntraFnReachability::ID = 0; const char AANoReturn::ID = 0; const char AAIsDead::ID = 0; const char AADereferenceable::ID = 0; @@ -10750,9 +11544,10 @@ const char AAPotentialConstantValues::ID = 0; const char AAPotentialValues::ID = 0; const char AANoUndef::ID = 0; const char AACallEdges::ID = 0; -const char AAFunctionReachability::ID = 0; +const char AAInterFnReachability::ID = 0; const char AAPointerInfo::ID = 0; const char AAAssumptionInfo::ID = 0; +const char AAUnderlyingObjects::ID = 0; // Macro magic to create the static generator function for attributes that // follow the naming scheme. 
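The new AAUnderlyingObjects attribute above collects the set of objects a pointer may be based on, looking through phis and selects (handleIndirect) and exposing the result through a forallUnderlyingObjects callback. A simplified standalone sketch of that collection, with invented types and no LLVM APIs, might look like this:

#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct PointerDef {
  std::vector<std::string> Inputs; // empty: this value is itself a base object
};

struct UnderlyingObjects {
  std::map<std::string, PointerDef> Defs;

  // Walk through phi/select-like definitions until only base objects remain.
  std::set<std::string> collect(const std::string &Ptr) const {
    std::set<std::string> Objects, Seen{Ptr};
    std::vector<std::string> Worklist{Ptr};
    while (!Worklist.empty()) {
      std::string V = Worklist.back();
      Worklist.pop_back();
      auto It = Defs.find(V);
      if (It == Defs.end() || It->second.Inputs.empty()) {
        Objects.insert(V); // a base object (alloca, global, argument, ...)
        continue;
      }
      for (const std::string &In : It->second.Inputs)
        if (Seen.insert(In).second)
          Worklist.push_back(In);
    }
    return Objects;
  }

  // Visit every collected object; returning false from the callback stops early.
  bool forallUnderlyingObjects(
      const std::string &Ptr,
      const std::function<bool(const std::string &)> &Pred) const {
    for (const std::string &Obj : collect(Ptr))
      if (!Pred(Obj))
        return false;
    return true;
  }
};

int main() {
  UnderlyingObjects UO;
  UO.Defs["p"] = {{"a", "q"}}; // p = select(cond, a, q)
  UO.Defs["q"] = {{"b"}};      // q = phi(b)
  UO.forallUnderlyingObjects("p", [](const std::string &Obj) {
    std::cout << Obj << "\n";  // prints a and b
    return true;
  });
}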
@@ -10873,11 +11668,12 @@ CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPointerInfo) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead) CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoFree) +CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUnderlyingObjects) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAHeapToStack) -CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReachability) CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAUndefinedBehavior) -CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAFunctionReachability) +CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIntraFnReachability) +CREATE_FUNCTION_ONLY_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAInterFnReachability) CREATE_NON_RET_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryBehavior) diff --git a/llvm/lib/Transforms/IPO/BlockExtractor.cpp b/llvm/lib/Transforms/IPO/BlockExtractor.cpp index 9e27ae49a901..a68cf7db7c85 100644 --- a/llvm/lib/Transforms/IPO/BlockExtractor.cpp +++ b/llvm/lib/Transforms/IPO/BlockExtractor.cpp @@ -45,20 +45,15 @@ class BlockExtractor { public: BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {} bool runOnModule(Module &M); - void init(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> - &GroupsOfBlocksToExtract) { - for (const SmallVectorImpl<BasicBlock *> &GroupOfBlocks : - GroupsOfBlocksToExtract) { - SmallVector<BasicBlock *, 16> NewGroup; - NewGroup.append(GroupOfBlocks.begin(), GroupOfBlocks.end()); - GroupsOfBlocks.emplace_back(NewGroup); - } + void + init(const std::vector<std::vector<BasicBlock *>> &GroupsOfBlocksToExtract) { + GroupsOfBlocks = GroupsOfBlocksToExtract; if (!BlockExtractorFile.empty()) loadFile(); } private: - SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupsOfBlocks; + std::vector<std::vector<BasicBlock *>> GroupsOfBlocks; bool EraseFunctions; /// Map a function name to groups of blocks. SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> @@ -68,56 +63,8 @@ private: void splitLandingPadPreds(Function &F); }; -class BlockExtractorLegacyPass : public ModulePass { - BlockExtractor BE; - bool runOnModule(Module &M) override; - -public: - static char ID; - BlockExtractorLegacyPass(const SmallVectorImpl<BasicBlock *> &BlocksToExtract, - bool EraseFunctions) - : ModulePass(ID), BE(EraseFunctions) { - // We want one group per element of the input list. 
- SmallVector<SmallVector<BasicBlock *, 16>, 4> MassagedGroupsOfBlocks; - for (BasicBlock *BB : BlocksToExtract) { - SmallVector<BasicBlock *, 16> NewGroup; - NewGroup.push_back(BB); - MassagedGroupsOfBlocks.push_back(NewGroup); - } - BE.init(MassagedGroupsOfBlocks); - } - - BlockExtractorLegacyPass(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> - &GroupsOfBlocksToExtract, - bool EraseFunctions) - : ModulePass(ID), BE(EraseFunctions) { - BE.init(GroupsOfBlocksToExtract); - } - - BlockExtractorLegacyPass() - : BlockExtractorLegacyPass(SmallVector<BasicBlock *, 0>(), false) {} -}; - } // end anonymous namespace -char BlockExtractorLegacyPass::ID = 0; -INITIALIZE_PASS(BlockExtractorLegacyPass, "extract-blocks", - "Extract basic blocks from module", false, false) - -ModulePass *llvm::createBlockExtractorPass() { - return new BlockExtractorLegacyPass(); -} -ModulePass *llvm::createBlockExtractorPass( - const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) { - return new BlockExtractorLegacyPass(BlocksToExtract, EraseFunctions); -} -ModulePass *llvm::createBlockExtractorPass( - const SmallVectorImpl<SmallVector<BasicBlock *, 16>> - &GroupsOfBlocksToExtract, - bool EraseFunctions) { - return new BlockExtractorLegacyPass(GroupsOfBlocksToExtract, EraseFunctions); -} - /// Gets all of the blocks specified in the input file. void BlockExtractor::loadFile() { auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile); @@ -161,7 +108,7 @@ void BlockExtractor::splitLandingPadPreds(Function &F) { // Look through the landing pad's predecessors. If one of them ends in an // 'invoke', then we want to split the landing pad. bool Split = false; - for (auto PredBB : predecessors(LPad)) { + for (auto *PredBB : predecessors(LPad)) { if (PredBB->isLandingPad() && PredBB != Parent && isa<InvokeInst>(Parent->getTerminator())) { Split = true; @@ -179,7 +126,6 @@ void BlockExtractor::splitLandingPadPreds(Function &F) { } bool BlockExtractor::runOnModule(Module &M) { - bool Changed = false; // Get all the functions. @@ -251,14 +197,15 @@ bool BlockExtractor::runOnModule(Module &M) { return Changed; } -bool BlockExtractorLegacyPass::runOnModule(Module &M) { - return BE.runOnModule(M); -} +BlockExtractorPass::BlockExtractorPass( + std::vector<std::vector<BasicBlock *>> &&GroupsOfBlocks, + bool EraseFunctions) + : GroupsOfBlocks(GroupsOfBlocks), EraseFunctions(EraseFunctions) {} PreservedAnalyses BlockExtractorPass::run(Module &M, ModuleAnalysisManager &AM) { - BlockExtractor BE(false); - BE.init(SmallVector<SmallVector<BasicBlock *, 16>, 0>()); + BlockExtractor BE(EraseFunctions); + BE.init(GroupsOfBlocks); return BE.runOnModule(M) ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/IPO/ConstantMerge.cpp b/llvm/lib/Transforms/IPO/ConstantMerge.cpp index 73af30ece47c..77bc377f4514 100644 --- a/llvm/lib/Transforms/IPO/ConstantMerge.cpp +++ b/llvm/lib/Transforms/IPO/ConstantMerge.cpp @@ -80,7 +80,7 @@ static void copyDebugLocMetadata(const GlobalVariable *From, GlobalVariable *To) { SmallVector<DIGlobalVariableExpression *, 1> MDs; From->getDebugInfo(MDs); - for (auto MD : MDs) + for (auto *MD : MDs) To->addDebugInfo(MD); } diff --git a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp index dfe33ac9da0d..4fe7bb6c757c 100644 --- a/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -87,7 +87,7 @@ void CrossDSOCFI::buildCFICheck(Module &M) { NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions"); if (CfiFunctionsMD) { - for (auto Func : CfiFunctionsMD->operands()) { + for (auto *Func : CfiFunctionsMD->operands()) { assert(Func->getNumOperands() >= 2); for (unsigned I = 2; I < Func->getNumOperands(); ++I) if (ConstantInt *TypeId = diff --git a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index 99fa4baf355d..bf2c65a2402c 100644 --- a/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -222,7 +222,7 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) { // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the // function empty. - NF->getBasicBlockList().splice(NF->begin(), F.getBasicBlockList()); + NF->splice(NF->begin(), &F); // Loop over the argument list, transferring uses of the old arguments over to // the new arguments, also transferring over the names as well. While we're @@ -238,8 +238,8 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) { // Clone metadata from the old function, including debug info descriptor. SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; F.getAllMetadata(MDs); - for (auto MD : MDs) - NF->addMetadata(MD.first, *MD.second); + for (auto [KindID, Node] : MDs) + NF->addMetadata(KindID, *Node); // Fix up any BlockAddresses that refer to the function. F.replaceAllUsesWith(ConstantExpr::getBitCast(NF, F.getType())); @@ -996,7 +996,7 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the // function empty. - NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); + NF->splice(NF->begin(), F); // Loop over the argument list, transferring uses of the old arguments over to // the new arguments, also transferring over the names as well. @@ -1056,14 +1056,14 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { // value (possibly 0 if we became void). auto *NewRet = ReturnInst::Create(F->getContext(), RetVal, RI); NewRet->setDebugLoc(RI->getDebugLoc()); - BB.getInstList().erase(RI); + RI->eraseFromParent(); } // Clone metadata from the old function, including debug info descriptor. 
SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; F->getAllMetadata(MDs); - for (auto MD : MDs) - NF->addMetadata(MD.first, *MD.second); + for (auto [KindID, Node] : MDs) + NF->addMetadata(KindID, *Node); // If either the return value(s) or argument(s) are removed, then probably the // function does not follow standard calling conventions anymore. Hence, add diff --git a/llvm/lib/Transforms/IPO/ExtractGV.cpp b/llvm/lib/Transforms/IPO/ExtractGV.cpp index 84280781ee70..d5073eed2fef 100644 --- a/llvm/lib/Transforms/IPO/ExtractGV.cpp +++ b/llvm/lib/Transforms/IPO/ExtractGV.cpp @@ -10,11 +10,11 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/SetVector.h" +#include "llvm/Transforms/IPO/ExtractGV.h" #include "llvm/IR/Module.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/IPO.h" +#include "llvm/IR/PassManager.h" #include <algorithm> + using namespace llvm; /// Make sure GV is visible from both modules. Delete is true if it is @@ -48,110 +48,86 @@ static void makeVisible(GlobalValue &GV, bool Delete) { } } -namespace { - /// A pass to extract specific global values and their dependencies. - class GVExtractorPass : public ModulePass { - SetVector<GlobalValue *> Named; - bool deleteStuff; - bool keepConstInit; - public: - static char ID; // Pass identification, replacement for typeid /// If deleteS is true, this pass deletes the specified global values. /// Otherwise, it deletes as much of the module as possible, except for the /// global values specified. - explicit GVExtractorPass(std::vector<GlobalValue*> &GVs, - bool deleteS = true, bool keepConstInit = false) - : ModulePass(ID), Named(GVs.begin(), GVs.end()), deleteStuff(deleteS), - keepConstInit(keepConstInit) {} - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - // Visit the global inline asm. - if (!deleteStuff) - M.setModuleInlineAsm(""); - - // For simplicity, just give all GlobalValues ExternalLinkage. A trickier - // implementation could figure out which GlobalValues are actually - // referenced by the Named set, and which GlobalValues in the rest of - // the module are referenced by the NamedSet, and get away with leaving - // more internal and private things internal and private. But for now, - // be conservative and simple. - - // Visit the GlobalVariables. - for (GlobalVariable &GV : M.globals()) { - bool Delete = deleteStuff == (bool)Named.count(&GV) && - !GV.isDeclaration() && - (!GV.isConstant() || !keepConstInit); - if (!Delete) { - if (GV.hasAvailableExternallyLinkage()) - continue; - if (GV.getName() == "llvm.global_ctors") - continue; - } - - makeVisible(GV, Delete); - - if (Delete) { - // Make this a declaration and drop it's comdat. - GV.setInitializer(nullptr); - GV.setComdat(nullptr); - } - } +ExtractGVPass::ExtractGVPass(std::vector<GlobalValue *> &GVs, bool deleteS, + bool keepConstInit) + : Named(GVs.begin(), GVs.end()), deleteStuff(deleteS), + keepConstInit(keepConstInit) {} + +PreservedAnalyses ExtractGVPass::run(Module &M, ModuleAnalysisManager &) { + // Visit the global inline asm. + if (!deleteStuff) + M.setModuleInlineAsm(""); + + // For simplicity, just give all GlobalValues ExternalLinkage. A trickier + // implementation could figure out which GlobalValues are actually + // referenced by the Named set, and which GlobalValues in the rest of + // the module are referenced by the NamedSet, and get away with leaving + // more internal and private things internal and private. But for now, + // be conservative and simple. 
+ + // Visit the GlobalVariables. + for (GlobalVariable &GV : M.globals()) { + bool Delete = deleteStuff == (bool)Named.count(&GV) && + !GV.isDeclaration() && (!GV.isConstant() || !keepConstInit); + if (!Delete) { + if (GV.hasAvailableExternallyLinkage()) + continue; + if (GV.getName() == "llvm.global_ctors") + continue; + } - // Visit the Functions. - for (Function &F : M) { - bool Delete = - deleteStuff == (bool)Named.count(&F) && !F.isDeclaration(); - if (!Delete) { - if (F.hasAvailableExternallyLinkage()) - continue; - } - - makeVisible(F, Delete); - - if (Delete) { - // Make this a declaration and drop it's comdat. - F.deleteBody(); - F.setComdat(nullptr); - } - } + makeVisible(GV, Delete); - // Visit the Aliases. - for (GlobalAlias &GA : llvm::make_early_inc_range(M.aliases())) { - bool Delete = deleteStuff == (bool)Named.count(&GA); - makeVisible(GA, Delete); - - if (Delete) { - Type *Ty = GA.getValueType(); - - GA.removeFromParent(); - llvm::Value *Declaration; - if (FunctionType *FTy = dyn_cast<FunctionType>(Ty)) { - Declaration = - Function::Create(FTy, GlobalValue::ExternalLinkage, - GA.getAddressSpace(), GA.getName(), &M); - - } else { - Declaration = - new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, - nullptr, GA.getName()); - } - GA.replaceAllUsesWith(Declaration); - delete &GA; - } - } + if (Delete) { + // Make this a declaration and drop it's comdat. + GV.setInitializer(nullptr); + GV.setComdat(nullptr); + } + } - return true; + // Visit the Functions. + for (Function &F : M) { + bool Delete = deleteStuff == (bool)Named.count(&F) && !F.isDeclaration(); + if (!Delete) { + if (F.hasAvailableExternallyLinkage()) + continue; } - }; - char GVExtractorPass::ID = 0; -} + makeVisible(F, Delete); + + if (Delete) { + // Make this a declaration and drop it's comdat. + F.deleteBody(); + F.setComdat(nullptr); + } + } + + // Visit the Aliases. 
+ for (GlobalAlias &GA : llvm::make_early_inc_range(M.aliases())) { + bool Delete = deleteStuff == (bool)Named.count(&GA); + makeVisible(GA, Delete); + + if (Delete) { + Type *Ty = GA.getValueType(); + + GA.removeFromParent(); + llvm::Value *Declaration; + if (FunctionType *FTy = dyn_cast<FunctionType>(Ty)) { + Declaration = Function::Create(FTy, GlobalValue::ExternalLinkage, + GA.getAddressSpace(), GA.getName(), &M); + + } else { + Declaration = new GlobalVariable( + M, Ty, false, GlobalValue::ExternalLinkage, nullptr, GA.getName()); + } + GA.replaceAllUsesWith(Declaration); + delete &GA; + } + } -ModulePass *llvm::createGVExtractionPass(std::vector<GlobalValue *> &GVs, - bool deleteFn, bool keepConstInit) { - return new GVExtractorPass(GVs, deleteFn, keepConstInit); + return PreservedAnalyses::none(); } diff --git a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 50710eaa1b57..3f61dbe3354e 100644 --- a/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -63,16 +63,14 @@ #include <cassert> #include <iterator> #include <map> +#include <optional> #include <vector> using namespace llvm; #define DEBUG_TYPE "function-attrs" -STATISTIC(NumArgMemOnly, "Number of functions marked argmemonly"); -STATISTIC(NumReadNone, "Number of functions marked readnone"); -STATISTIC(NumReadOnly, "Number of functions marked readonly"); -STATISTIC(NumWriteOnly, "Number of functions marked writeonly"); +STATISTIC(NumMemoryAttr, "Number of functions with improved memory attribute"); STATISTIC(NumNoCapture, "Number of arguments marked nocapture"); STATISTIC(NumReturned, "Number of arguments marked returned"); STATISTIC(NumReadNoneArg, "Number of arguments marked readnone"); @@ -122,28 +120,43 @@ using SCCNodeSet = SmallSetVector<Function *, 8>; /// result will be based only on AA results for the function declaration; it /// will be assumed that some other (perhaps less optimized) version of the /// function may be selected at link time. -static FunctionModRefBehavior -checkFunctionMemoryAccess(Function &F, bool ThisBody, AAResults &AAR, - const SCCNodeSet &SCCNodes) { - FunctionModRefBehavior MRB = AAR.getModRefBehavior(&F); - if (MRB == FMRB_DoesNotAccessMemory) +static MemoryEffects checkFunctionMemoryAccess(Function &F, bool ThisBody, + AAResults &AAR, + const SCCNodeSet &SCCNodes) { + MemoryEffects OrigME = AAR.getMemoryEffects(&F); + if (OrigME.doesNotAccessMemory()) // Already perfect! - return MRB; + return OrigME; if (!ThisBody) - return MRB; + return OrigME; + + MemoryEffects ME = MemoryEffects::none(); + // Inalloca and preallocated arguments are always clobbered by the call. + if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca) || + F.getAttributes().hasAttrSomewhere(Attribute::Preallocated)) + ME |= MemoryEffects::argMemOnly(ModRefInfo::ModRef); + + auto AddLocAccess = [&](const MemoryLocation &Loc, ModRefInfo MR) { + // Ignore accesses to known-invariant or local memory. + MR &= AAR.getModRefInfoMask(Loc, /*IgnoreLocal=*/true); + if (isNoModRef(MR)) + return; - // Scan the function body for instructions that may read or write memory. - bool ReadsMemory = false; - bool WritesMemory = false; - // Track if the function accesses memory not based on pointer arguments or - // allocas. - bool AccessesNonArgsOrAlloca = false; - // Returns true if Ptr is not based on a function argument. 
- auto IsArgumentOrAlloca = [](const Value *Ptr) { - const Value *UO = getUnderlyingObject(Ptr); - return isa<Argument>(UO) || isa<AllocaInst>(UO); + const Value *UO = getUnderlyingObject(Loc.Ptr); + assert(!isa<AllocaInst>(UO) && + "Should have been handled by getModRefInfoMask()"); + if (isa<Argument>(UO)) { + ME |= MemoryEffects::argMemOnly(MR); + return; + } + + // If it's not an identified object, it might be an argument. + if (!isIdentifiedObject(UO)) + ME |= MemoryEffects::argMemOnly(MR); + ME |= MemoryEffects(MemoryEffects::Other, MR); }; + // Scan the function body for instructions that may read or write memory. for (Instruction &I : instructions(F)) { // Some instructions can be ignored even if they read or write memory. // Detect these now, skipping to the next instruction if one is found. @@ -155,11 +168,10 @@ checkFunctionMemoryAccess(Function &F, bool ThisBody, AAResults &AAR, if (!Call->hasOperandBundles() && Call->getCalledFunction() && SCCNodes.count(Call->getCalledFunction())) continue; - FunctionModRefBehavior MRB = AAR.getModRefBehavior(Call); - ModRefInfo MRI = createModRefInfo(MRB); + MemoryEffects CallME = AAR.getMemoryEffects(Call); // If the call doesn't access memory, we're done. - if (isNoModRef(MRI)) + if (CallME.doesNotAccessMemory()) continue; // A pseudo probe call shouldn't change any function attribute since it @@ -169,92 +181,57 @@ checkFunctionMemoryAccess(Function &F, bool ThisBody, AAResults &AAR, if (isa<PseudoProbeInst>(I)) continue; - if (!AliasAnalysis::onlyAccessesArgPointees(MRB)) { - // The call could access any memory. If that includes writes, note it. - if (isModSet(MRI)) - WritesMemory = true; - // If it reads, note it. - if (isRefSet(MRI)) - ReadsMemory = true; - AccessesNonArgsOrAlloca = true; - continue; - } + ME |= CallME.getWithoutLoc(MemoryEffects::ArgMem); + + // If the call accesses captured memory (currently part of "other") and + // an argument is captured (currently not tracked), then it may also + // access argument memory. + ModRefInfo OtherMR = CallME.getModRef(MemoryEffects::Other); + ME |= MemoryEffects::argMemOnly(OtherMR); // Check whether all pointer arguments point to local memory, and // ignore calls that only access local memory. - for (const Use &U : Call->args()) { - const Value *Arg = U; - if (!Arg->getType()->isPtrOrPtrVectorTy()) - continue; + ModRefInfo ArgMR = CallME.getModRef(MemoryEffects::ArgMem); + if (ArgMR != ModRefInfo::NoModRef) { + for (const Use &U : Call->args()) { + const Value *Arg = U; + if (!Arg->getType()->isPtrOrPtrVectorTy()) + continue; - MemoryLocation Loc = - MemoryLocation::getBeforeOrAfter(Arg, I.getAAMetadata()); - // Skip accesses to local or constant memory as they don't impact the - // externally visible mod/ref behavior. - if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) - continue; + AddLocAccess(MemoryLocation::getBeforeOrAfter(Arg, I.getAAMetadata()), ArgMR); + } + } + continue; + } - AccessesNonArgsOrAlloca |= !IsArgumentOrAlloca(Loc.Ptr); + ModRefInfo MR = ModRefInfo::NoModRef; + if (I.mayWriteToMemory()) + MR |= ModRefInfo::Mod; + if (I.mayReadFromMemory()) + MR |= ModRefInfo::Ref; + if (MR == ModRefInfo::NoModRef) + continue; - if (isModSet(MRI)) - // Writes non-local memory. - WritesMemory = true; - if (isRefSet(MRI)) - // Ok, it reads non-local memory. - ReadsMemory = true; - } + std::optional<MemoryLocation> Loc = MemoryLocation::getOrNone(&I); + if (!Loc) { + // If no location is known, conservatively assume anything can be + // accessed. 
+ ME |= MemoryEffects(MR); continue; - } else if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { - MemoryLocation Loc = MemoryLocation::get(LI); - // Ignore non-volatile loads from local memory. (Atomic is okay here.) - if (!LI->isVolatile() && - AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) - continue; - AccessesNonArgsOrAlloca |= !IsArgumentOrAlloca(Loc.Ptr); - } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { - MemoryLocation Loc = MemoryLocation::get(SI); - // Ignore non-volatile stores to local memory. (Atomic is okay here.) - if (!SI->isVolatile() && - AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) - continue; - AccessesNonArgsOrAlloca |= !IsArgumentOrAlloca(Loc.Ptr); - } else if (VAArgInst *VI = dyn_cast<VAArgInst>(&I)) { - // Ignore vaargs on local memory. - MemoryLocation Loc = MemoryLocation::get(VI); - if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) - continue; - AccessesNonArgsOrAlloca |= !IsArgumentOrAlloca(Loc.Ptr); - } else { - // If AccessesNonArgsOrAlloca has not been updated above, set it - // conservatively. - AccessesNonArgsOrAlloca |= I.mayReadOrWriteMemory(); } - // Any remaining instructions need to be taken seriously! Check if they - // read or write memory. - // - // Writes memory, remember that. - WritesMemory |= I.mayWriteToMemory(); + // Volatile operations may access inaccessible memory. + if (I.isVolatile()) + ME |= MemoryEffects::inaccessibleMemOnly(MR); - // If this instruction may read memory, remember that. - ReadsMemory |= I.mayReadFromMemory(); + AddLocAccess(*Loc, MR); } - if (!WritesMemory && !ReadsMemory) - return FMRB_DoesNotAccessMemory; - - FunctionModRefBehavior Result = FunctionModRefBehavior(FMRL_Anywhere); - if (!AccessesNonArgsOrAlloca) - Result = FunctionModRefBehavior(FMRL_ArgumentPointees); - if (WritesMemory) - Result = FunctionModRefBehavior(Result | static_cast<int>(ModRefInfo::Mod)); - if (ReadsMemory) - Result = FunctionModRefBehavior(Result | static_cast<int>(ModRefInfo::Ref)); - return Result; + return OrigME & ME; } -FunctionModRefBehavior llvm::computeFunctionBodyMemoryAccess(Function &F, - AAResults &AAR) { +MemoryEffects llvm::computeFunctionBodyMemoryAccess(Function &F, + AAResults &AAR) { return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}); } @@ -262,91 +239,27 @@ FunctionModRefBehavior llvm::computeFunctionBodyMemoryAccess(Function &F, template <typename AARGetterT> static void addMemoryAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter, SmallSet<Function *, 8> &Changed) { - // Check if any of the functions in the SCC read or write memory. If they - // write memory then they can't be marked readnone or readonly. - bool ReadsMemory = false; - bool WritesMemory = false; - // Check if all functions only access memory through their arguments. - bool ArgMemOnly = true; + MemoryEffects ME = MemoryEffects::none(); for (Function *F : SCCNodes) { // Call the callable parameter to look up AA results for this function. AAResults &AAR = AARGetter(*F); // Non-exact function definitions may not be selected at link time, and an // alternative version that writes to memory may be selected. See the // comment on GlobalValue::isDefinitionExact for more details. 
- FunctionModRefBehavior FMRB = - checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes); - if (FMRB == FMRB_DoesNotAccessMemory) - continue; - ModRefInfo MR = createModRefInfo(FMRB); - ReadsMemory |= isRefSet(MR); - WritesMemory |= isModSet(MR); - ArgMemOnly &= AliasAnalysis::onlyAccessesArgPointees(FMRB); - // Reached neither readnone, readonly, writeonly nor argmemonly can be - // inferred. Exit. - if (ReadsMemory && WritesMemory && !ArgMemOnly) + ME |= checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes); + // Reached bottom of the lattice, we will not be able to improve the result. + if (ME == MemoryEffects::unknown()) return; } - assert((!ReadsMemory || !WritesMemory || ArgMemOnly) && - "no memory attributes can be added for this SCC, should have exited " - "earlier"); - // Success! Functions in this SCC do not access memory, only read memory, - // only write memory, or only access memory through its arguments. Give them - // the appropriate attribute. - for (Function *F : SCCNodes) { - // If possible add argmemonly attribute to F, if it accesses memory. - if (ArgMemOnly && !F->onlyAccessesArgMemory() && - (ReadsMemory || WritesMemory)) { - NumArgMemOnly++; - F->addFnAttr(Attribute::ArgMemOnly); + MemoryEffects OldME = F->getMemoryEffects(); + MemoryEffects NewME = ME & OldME; + if (NewME != OldME) { + ++NumMemoryAttr; + F->setMemoryEffects(NewME); Changed.insert(F); } - - // The SCC contains functions both writing and reading from memory. We - // cannot add readonly or writeonline attributes. - if (ReadsMemory && WritesMemory) - continue; - if (F->doesNotAccessMemory()) - // Already perfect! - continue; - - if (F->onlyReadsMemory() && ReadsMemory) - // No change. - continue; - - if (F->onlyWritesMemory() && WritesMemory) - continue; - - Changed.insert(F); - - // Clear out any existing attributes. - AttributeMask AttrsToRemove; - AttrsToRemove.addAttribute(Attribute::ReadOnly); - AttrsToRemove.addAttribute(Attribute::ReadNone); - AttrsToRemove.addAttribute(Attribute::WriteOnly); - - if (!WritesMemory && !ReadsMemory) { - // Clear out any "access range attributes" if readnone was deduced. - AttrsToRemove.addAttribute(Attribute::ArgMemOnly); - AttrsToRemove.addAttribute(Attribute::InaccessibleMemOnly); - AttrsToRemove.addAttribute(Attribute::InaccessibleMemOrArgMemOnly); - } - F->removeFnAttrs(AttrsToRemove); - - // Add in the new attribute. - if (WritesMemory && !ReadsMemory) - F->addFnAttr(Attribute::WriteOnly); - else - F->addFnAttr(ReadsMemory ? Attribute::ReadOnly : Attribute::ReadNone); - - if (WritesMemory && !ReadsMemory) - ++NumWriteOnly; - else if (ReadsMemory) - ++NumReadOnly; - else - ++NumReadNone; } } @@ -517,7 +430,7 @@ bool llvm::thinLTOPropagateFunctionAttrs( ++NumThinLinkNoUnwind; } - for (auto &S : V.getSummaryList()) { + for (const auto &S : V.getSummaryList()) { if (auto *FS = dyn_cast<FunctionSummary>(S.get())) { if (InferredFlags.NoRecurse) FS->setNoRecurse(); @@ -1146,7 +1059,7 @@ static bool isFunctionMallocLike(Function *F, const SCCNodeSet &SCCNodes) { break; if (CB.getCalledFunction() && SCCNodes.count(CB.getCalledFunction())) break; - LLVM_FALLTHROUGH; + [[fallthrough]]; } default: return false; // Did not come from an allocation. 
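The FunctionAttrs changes above drop the separate readnone/readonly/writeonly/argmemonly bookkeeping in favor of a single MemoryEffects summary that is joined across the SCC and then intersected with each function's existing effects. The following standalone sketch models that lattice with a deliberately tiny two-location version; the types and helpers (MemEffects, argMemOnly) are illustrative only and do not mirror the real MemoryEffects interface.

#include <cstdint>
#include <iostream>

enum ModRef : uint8_t { NoModRef = 0, Ref = 1, Mod = 2, ModRefBoth = 3 };

struct MemEffects {
  uint8_t ArgMem = NoModRef;   // accesses based on pointer arguments
  uint8_t OtherMem = NoModRef; // everything else (globals, escaped memory, ...)

  static MemEffects none() { return {}; }
  static MemEffects unknown() { return {ModRefBoth, ModRefBoth}; }
  static MemEffects argMemOnly(uint8_t MR) { return {MR, NoModRef}; }

  MemEffects &operator|=(const MemEffects &O) { // join: accumulate effects
    ArgMem |= O.ArgMem;
    OtherMem |= O.OtherMem;
    return *this;
  }
  MemEffects operator&(const MemEffects &O) const { // meet: refine a summary
    return {uint8_t(ArgMem & O.ArgMem), uint8_t(OtherMem & O.OtherMem)};
  }
};

int main() {
  // Two SCC members: one only reads its arguments, one writes other memory.
  MemEffects SCC = MemEffects::none();
  SCC |= MemEffects::argMemOnly(Ref);
  SCC |= MemEffects{NoModRef, Mod};

  // A member previously known to be read-only keeps only the read effects.
  MemEffects Old{Ref, Ref};
  MemEffects New = SCC & Old;
  std::cout << int(New.ArgMem) << " " << int(New.OtherMem) << "\n"; // 1 0
}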
diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp index 360ec24a0509..7c994657e5c8 100644 --- a/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -274,7 +274,7 @@ static void computeImportForReferencedGlobals( SmallVectorImpl<EdgeInfo> &Worklist, FunctionImporter::ImportMapTy &ImportList, StringMap<FunctionImporter::ExportSetTy> *ExportLists) { - for (auto &VI : Summary.refs()) { + for (const auto &VI : Summary.refs()) { if (!shouldImportGlobal(VI, DefinedGVSummaries)) { LLVM_DEBUG( dbgs() << "Ref ignored! Target already in destination module.\n"); @@ -294,7 +294,7 @@ static void computeImportForReferencedGlobals( RefSummary->modulePath() != Summary.modulePath(); }; - for (auto &RefSummary : VI.getSummaryList()) + for (const auto &RefSummary : VI.getSummaryList()) if (isa<GlobalVarSummary>(RefSummary.get()) && Index.canImportGlobalVar(RefSummary.get(), /* AnalyzeRefs */ true) && !LocalNotInModule(RefSummary.get())) { @@ -355,7 +355,7 @@ static void computeImportForFunction( computeImportForReferencedGlobals(Summary, Index, DefinedGVSummaries, Worklist, ImportList, ExportLists); static int ImportCount = 0; - for (auto &Edge : Summary.calls()) { + for (const auto &Edge : Summary.calls()) { ValueInfo VI = Edge.first; LLVM_DEBUG(dbgs() << " edge -> " << VI << " Threshold:" << Threshold << "\n"); @@ -529,7 +529,7 @@ static void ComputeImportForModule( // Populate the worklist with the import for the functions in the current // module - for (auto &GVSummary : DefinedGVSummaries) { + for (const auto &GVSummary : DefinedGVSummaries) { #ifndef NDEBUG // FIXME: Change the GVSummaryMapTy to hold ValueInfo instead of GUID // so this map look up (and possibly others) can be avoided. @@ -656,7 +656,7 @@ void llvm::ComputeCrossModuleImport( StringMap<FunctionImporter::ImportMapTy> &ImportLists, StringMap<FunctionImporter::ExportSetTy> &ExportLists) { // For each module that has function defined, compute the import/export lists. - for (auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) { + for (const auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) { auto &ImportList = ImportLists[DefinedGVSummaries.first()]; LLVM_DEBUG(dbgs() << "Computing import for Module '" << DefinedGVSummaries.first() << "'\n"); @@ -697,9 +697,9 @@ void llvm::ComputeCrossModuleImport( NewExports.insert(VI); } else { auto *FS = cast<FunctionSummary>(S); - for (auto &Edge : FS->calls()) + for (const auto &Edge : FS->calls()) NewExports.insert(Edge.first); - for (auto &Ref : FS->refs()) + for (const auto &Ref : FS->refs()) NewExports.insert(Ref); } } @@ -780,7 +780,7 @@ void llvm::ComputeCrossModuleImportForModule( void llvm::ComputeCrossModuleImportForModuleFromIndex( StringRef ModulePath, const ModuleSummaryIndex &Index, FunctionImporter::ImportMapTy &ImportList) { - for (auto &GlobalList : Index) { + for (const auto &GlobalList : Index) { // Ignore entries for undefined references. 
if (GlobalList.second.SummaryList.empty()) continue; @@ -837,7 +837,7 @@ void updateValueInfoForIndirectCalls(ModuleSummaryIndex &Index, void llvm::updateIndirectCalls(ModuleSummaryIndex &Index) { for (const auto &Entry : Index) { - for (auto &S : Entry.second.SummaryList) { + for (const auto &S : Entry.second.SummaryList) { if (auto *FS = dyn_cast<FunctionSummary>(S.get())) updateValueInfoForIndirectCalls(Index, FS); } @@ -863,14 +863,14 @@ void llvm::computeDeadSymbolsAndUpdateIndirectCalls( ValueInfo VI = Index.getValueInfo(GUID); if (!VI) continue; - for (auto &S : VI.getSummaryList()) + for (const auto &S : VI.getSummaryList()) S->setLive(true); } // Add values flagged in the index as live roots to the worklist. for (const auto &Entry : Index) { auto VI = Index.getValueInfo(Entry); - for (auto &S : Entry.second.SummaryList) { + for (const auto &S : Entry.second.SummaryList) { if (auto *FS = dyn_cast<FunctionSummary>(S.get())) updateValueInfoForIndirectCalls(Index, FS); if (S->isLive()) { @@ -907,7 +907,7 @@ void llvm::computeDeadSymbolsAndUpdateIndirectCalls( if (isPrevailing(VI.getGUID()) == PrevailingType::No) { bool KeepAliveLinkage = false; bool Interposable = false; - for (auto &S : VI.getSummaryList()) { + for (const auto &S : VI.getSummaryList()) { if (S->linkage() == GlobalValue::AvailableExternallyLinkage || S->linkage() == GlobalValue::WeakODRLinkage || S->linkage() == GlobalValue::LinkOnceODRLinkage) @@ -927,7 +927,7 @@ void llvm::computeDeadSymbolsAndUpdateIndirectCalls( } } - for (auto &S : VI.getSummaryList()) + for (const auto &S : VI.getSummaryList()) S->setLive(true); ++LiveSymbols; Worklist.push_back(VI); @@ -935,7 +935,7 @@ void llvm::computeDeadSymbolsAndUpdateIndirectCalls( while (!Worklist.empty()) { auto VI = Worklist.pop_back_val(); - for (auto &Summary : VI.getSummaryList()) { + for (const auto &Summary : VI.getSummaryList()) { if (auto *AS = dyn_cast<AliasSummary>(Summary.get())) { // If this is an alias, visit the aliasee VI to ensure that all copies // are marked live and it is added to the worklist for further @@ -982,12 +982,12 @@ void llvm::gatherImportedSummariesForModule( ModuleToSummariesForIndex[std::string(ModulePath)] = ModuleToDefinedGVSummaries.lookup(ModulePath); // Include summaries for imports. - for (auto &ILI : ImportList) { + for (const auto &ILI : ImportList) { auto &SummariesForIndex = ModuleToSummariesForIndex[std::string(ILI.first())]; const auto &DefinedGVSummaries = ModuleToDefinedGVSummaries.lookup(ILI.first()); - for (auto &GI : ILI.second) { + for (const auto &GI : ILI.second) { const auto &DS = DefinedGVSummaries.find(GI); assert(DS != DefinedGVSummaries.end() && "Expected a defined summary for imported global value"); @@ -1004,7 +1004,7 @@ std::error_code llvm::EmitImportsFiles( raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::OF_None); if (EC) return EC; - for (auto &ILI : ModuleToSummariesForIndex) + for (const auto &ILI : ModuleToSummariesForIndex) // The ModuleToSummariesForIndex map includes an entry for the current // Module (needed for writing out the index files). We don't want to // include it in the imports file, however, so filter it out. 
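The dead-symbol computation touched above is a standard worklist liveness propagation: summaries for the GUIDs the linker must preserve are seeded as live, and liveness then flows along call, reference, and aliasee edges until the worklist drains. The following stripped-down sketch shows that propagation with hypothetical Symbol/Index types rather than the ModuleSummaryIndex API.

#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

// Minimal model: each symbol lists the symbols it references; calls, refs and
// aliasees are all treated as the same kind of edge here.
struct Symbol {
  std::vector<std::string> Edges;
  bool Live = false;
};

int main() {
  std::unordered_map<std::string, Symbol> Index = {
      {"main", {{"foo", "gv"}, false}},
      {"foo", {{"bar"}, false}},
      {"bar", {{}, false}},
      {"unused", {{"bar"}, false}}, // references live code but is dead itself
      {"gv", {{}, false}},
  };
  std::vector<std::string> GUIDPreservedSymbols = {"main"};

  // Seed the worklist with the symbols the linker keeps.
  std::vector<std::string> Worklist;
  for (const std::string &Root : GUIDPreservedSymbols) {
    Index[Root].Live = true;
    Worklist.push_back(Root);
  }

  // Propagate to a fixed point: anything referenced by a live symbol is live.
  while (!Worklist.empty()) {
    std::string Cur = Worklist.back();
    Worklist.pop_back();
    for (const std::string &Succ : Index[Cur].Edges) {
      Symbol &S = Index[Succ];
      if (!S.Live) {
        S.Live = true;
        Worklist.push_back(Succ);
      }
    }
  }

  for (const auto &[Name, S] : Index)
    std::printf("%s: %s\n", Name.c_str(), S.Live ? "live" : "dead");
  return 0;
}

Note that "unused" stays dead even though it points at live code; liveness only flows away from the preserved roots, which is exactly why unreferenced summaries can be dropped.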
@@ -1051,6 +1051,7 @@ bool llvm::convertToDeclaration(GlobalValue &GV) { void llvm::thinLTOFinalizeInModule(Module &TheModule, const GVSummaryMapTy &DefinedGlobals, bool PropagateAttrs) { + DenseSet<Comdat *> NonPrevailingComdats; auto FinalizeInModule = [&](GlobalValue &GV, bool Propagate = false) { // See if the global summary analysis computed a new resolved linkage. const auto &GS = DefinedGlobals.find(GV.getGUID()); @@ -1128,8 +1129,11 @@ void llvm::thinLTOFinalizeInModule(Module &TheModule, // as this is a declaration for the linker, and will be dropped eventually. // It is illegal for comdats to contain declarations. auto *GO = dyn_cast_or_null<GlobalObject>(&GV); - if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) + if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) { + if (GO->getComdat()->getName() == GO->getName()) + NonPrevailingComdats.insert(GO->getComdat()); GO->setComdat(nullptr); + } }; // Process functions and global now @@ -1139,6 +1143,36 @@ void llvm::thinLTOFinalizeInModule(Module &TheModule, FinalizeInModule(GV); for (auto &GV : TheModule.aliases()) FinalizeInModule(GV); + + // For a non-prevailing comdat, all its members must be available_externally. + // FinalizeInModule has handled non-local-linkage GlobalValues. Here we handle + // local linkage GlobalValues. + if (NonPrevailingComdats.empty()) + return; + for (auto &GO : TheModule.global_objects()) { + if (auto *C = GO.getComdat(); C && NonPrevailingComdats.count(C)) { + GO.setComdat(nullptr); + GO.setLinkage(GlobalValue::AvailableExternallyLinkage); + } + } + bool Changed; + do { + Changed = false; + // If an alias references a GlobalValue in a non-prevailing comdat, change + // it to available_externally. For simplicity we only handle GlobalValue and + // ConstantExpr with a base object. ConstantExpr without a base object is + // unlikely used in a COMDAT. + for (auto &GA : TheModule.aliases()) { + if (GA.hasAvailableExternallyLinkage()) + continue; + GlobalObject *Obj = GA.getAliaseeObject(); + assert(Obj && "aliasee without an base object is unimplemented"); + if (Obj->hasAvailableExternallyLinkage()) { + GA.setLinkage(GlobalValue::AvailableExternallyLinkage); + Changed = true; + } + } + } while (Changed); } /// Run internalization on \p TheModule based on symmary analysis. 
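The new non-prevailing-comdat handling in thinLTOFinalizeInModule demotes comdat members to available_externally and then loops over the module's aliases until nothing changes. The toy model below illustrates that style of fixed-point demotion; it deliberately tracks the direct aliasee (unlike getAliaseeObject, which resolves to the base object), so an alias-of-alias chain is what forces the second pass here.

#include <cstdio>
#include <string>
#include <vector>

enum class Linkage { External, AvailableExternally };

struct Global {
  std::string Name;
  Linkage L = Linkage::External;
  Global *Aliasee = nullptr; // non-null for aliases
};

int main() {
  // g_obj belongs to a non-prevailing comdat and has already been demoted;
  // a1 aliases g_obj, and a2 aliases a1, so a2 only becomes eligible after
  // a1 itself has been demoted.
  Global Obj{"g_obj", Linkage::AvailableExternally, nullptr};
  Global A1{"a1", Linkage::External, &Obj};
  Global A2{"a2", Linkage::External, &A1};
  std::vector<Global *> Aliases = {&A2, &A1}; // order chosen to need a 2nd pass

  bool Changed;
  do {
    Changed = false;
    for (Global *GA : Aliases) {
      if (GA->L == Linkage::AvailableExternally)
        continue;
      // If the aliasee has been demoted, the alias must follow.
      if (GA->Aliasee && GA->Aliasee->L == Linkage::AvailableExternally) {
        GA->L = Linkage::AvailableExternally;
        Changed = true;
      }
    }
  } while (Changed);

  for (const Global *G : {&Obj, &A1, &A2})
    std::printf("%s: %s\n", G->Name.c_str(),
                G->L == Linkage::AvailableExternally ? "available_externally"
                                                     : "external");
  return 0;
}

In this simplified setting the do/while terminates because each pass can only move symbols down the two-element linkage lattice, so the number of passes is bounded by the length of the longest alias chain.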
@@ -1226,10 +1260,10 @@ Expected<bool> FunctionImporter::importFunctions( IRMover Mover(DestModule); // Do the actual import of functions now, one Module at a time std::set<StringRef> ModuleNameOrderedList; - for (auto &FunctionsToImportPerModule : ImportList) { + for (const auto &FunctionsToImportPerModule : ImportList) { ModuleNameOrderedList.insert(FunctionsToImportPerModule.first()); } - for (auto &Name : ModuleNameOrderedList) { + for (const auto &Name : ModuleNameOrderedList) { // Get the module for the import const auto &FunctionsToImportPerModule = ImportList.find(Name); assert(FunctionsToImportPerModule != ImportList.end()); diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index dafd0dc865a2..4a7efb28e853 100644 --- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -45,6 +45,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/IPO/FunctionSpecialization.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/InlineCost.h" @@ -70,11 +71,6 @@ static cl::opt<bool> ForceFunctionSpecialization( cl::desc("Force function specialization for every call site with a " "constant argument")); -static cl::opt<unsigned> FuncSpecializationMaxIters( - "func-specialization-max-iters", cl::Hidden, - cl::desc("The maximum number of iterations function specialization is run"), - cl::init(1)); - static cl::opt<unsigned> MaxClonesThreshold( "func-specialization-max-clones", cl::Hidden, cl::desc("The maximum number of clones allowed for a single function " @@ -97,9 +93,6 @@ static cl::opt<bool> SpecializeOnAddresses( cl::desc("Enable function specialization on the address of global values")); // Disabled by default as it can significantly increase compilation times. -// Running nikic's compile time tracker on x86 with instruction count as the -// metric shows 3-4% regression for SPASS while being neutral for all other -// benchmarks of the llvm test suite. // // https://llvm-compile-time-tracker.com // https://github.com/nikic/llvm-compile-time-tracker @@ -108,37 +101,8 @@ static cl::opt<bool> EnableSpecializationForLiteralConstant( cl::desc("Enable specialization of functions that take a literal constant " "as an argument.")); -namespace { -// Bookkeeping struct to pass data from the analysis and profitability phase -// to the actual transform helper functions. -struct SpecializationInfo { - SmallVector<ArgInfo, 8> Args; // Stores the {formal,actual} argument pairs. - InstructionCost Gain; // Profitability: Gain = Bonus - Cost. -}; -} // Anonymous namespace - -using FuncList = SmallVectorImpl<Function *>; -using CallArgBinding = std::pair<CallBase *, Constant *>; -using CallSpecBinding = std::pair<CallBase *, SpecializationInfo>; -// We are using MapVector because it guarantees deterministic iteration -// order across executions. -using SpecializationMap = SmallMapVector<CallBase *, SpecializationInfo, 8>; - -// Helper to check if \p LV is either a constant or a constant -// range with a single element. This should cover exactly the same cases as the -// old ValueLatticeElement::isConstant() and is intended to be used in the -// transition to ValueLatticeElement. 
-static bool isConstant(const ValueLatticeElement &LV) { - return LV.isConstant() || - (LV.isConstantRange() && LV.getConstantRange().isSingleElement()); -} - -// Helper to check if \p LV is either overdefined or a constant int. -static bool isOverdefined(const ValueLatticeElement &LV) { - return !LV.isUnknownOrUndef() && !isConstant(LV); -} - -static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) { +Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca, + CallInst *Call) { Value *StoreValue = nullptr; for (auto *User : Alloca->users()) { // We can't use llvm::isAllocaPromotable() as that would fail because of @@ -161,14 +125,14 @@ static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) { // Bail if there is any other unknown usage. return nullptr; } - return dyn_cast_or_null<Constant>(StoreValue); + return getCandidateConstant(StoreValue); } // A constant stack value is an AllocaInst that has a single constant // value stored to it. Return this constant if such an alloca stack value // is a function argument. -static Constant *getConstantStackValue(CallInst *Call, Value *Val, - SCCPSolver &Solver) { +Constant *FunctionSpecializer::getConstantStackValue(CallInst *Call, + Value *Val) { if (!Val) return nullptr; Val = Val->stripPointerCasts(); @@ -201,19 +165,23 @@ static Constant *getConstantStackValue(CallInst *Call, Value *Val, // ret void // } // -static void constantArgPropagation(FuncList &WorkList, Module &M, - SCCPSolver &Solver) { +void FunctionSpecializer::promoteConstantStackValues() { // Iterate over the argument tracked functions see if there // are any new constant values for the call instruction via // stack variables. - for (auto *F : WorkList) { + for (Function &F : M) { + if (!Solver.isArgumentTrackedFunction(&F)) + continue; - for (auto *User : F->users()) { + for (auto *User : F.users()) { auto *Call = dyn_cast<CallInst>(User); if (!Call) continue; + if (!Solver.isBlockExecutable(Call->getParent())) + continue; + bool Changed = false; for (const Use &U : Call->args()) { unsigned Idx = Call->getArgOperandNo(&U); @@ -223,7 +191,7 @@ static void constantArgPropagation(FuncList &WorkList, Module &M, if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy()) continue; - auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver); + auto *ConstVal = getConstantStackValue(Call, ArgOp); if (!ConstVal) continue; @@ -245,7 +213,7 @@ static void constantArgPropagation(FuncList &WorkList, Module &M, } // ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics -// interfere with the constantArgPropagation optimization. +// interfere with the promoteConstantStackValues() optimization. static void removeSSACopy(Function &F) { for (BasicBlock &BB : F) { for (Instruction &Inst : llvm::make_early_inc_range(BB)) { @@ -260,690 +228,552 @@ static void removeSSACopy(Function &F) { } } -static void removeSSACopy(Module &M) { - for (Function &F : M) - removeSSACopy(F); +/// Remove any ssa_copy intrinsics that may have been introduced. +void FunctionSpecializer::cleanUpSSA() { + for (Function *F : SpecializedFuncs) + removeSSACopy(*F); } -namespace { -class FunctionSpecializer { - - /// The IPSCCP Solver. - SCCPSolver &Solver; - - /// Analyses used to help determine if a function should be specialized. 
- std::function<AssumptionCache &(Function &)> GetAC; - std::function<TargetTransformInfo &(Function &)> GetTTI; - std::function<TargetLibraryInfo &(Function &)> GetTLI; - - SmallPtrSet<Function *, 4> SpecializedFuncs; - SmallPtrSet<Function *, 4> FullySpecialized; - SmallVector<Instruction *> ReplacedWithConstant; - DenseMap<Function *, CodeMetrics> FunctionMetrics; - -public: - FunctionSpecializer(SCCPSolver &Solver, - std::function<AssumptionCache &(Function &)> GetAC, - std::function<TargetTransformInfo &(Function &)> GetTTI, - std::function<TargetLibraryInfo &(Function &)> GetTLI) - : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {} - - ~FunctionSpecializer() { - // Eliminate dead code. - removeDeadInstructions(); - removeDeadFunctions(); - } - /// Attempt to specialize functions in the module to enable constant - /// propagation across function boundaries. - /// - /// \returns true if at least one function is specialized. - bool specializeFunctions(FuncList &Candidates, FuncList &WorkList) { - bool Changed = false; - for (auto *F : Candidates) { - if (!isCandidateFunction(F)) - continue; +template <> struct llvm::DenseMapInfo<SpecSig> { + static inline SpecSig getEmptyKey() { return {~0U, {}}; } - auto Cost = getSpecializationCost(F); - if (!Cost.isValid()) { - LLVM_DEBUG( - dbgs() << "FnSpecialization: Invalid specialization cost.\n"); - continue; - } + static inline SpecSig getTombstoneKey() { return {~1U, {}}; } - LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for " - << F->getName() << " is " << Cost << "\n"); + static unsigned getHashValue(const SpecSig &S) { + return static_cast<unsigned>(hash_value(S)); + } - SmallVector<CallSpecBinding, 8> Specializations; - if (!calculateGains(F, Cost, Specializations)) { - LLVM_DEBUG(dbgs() << "FnSpecialization: No possible constants found\n"); - continue; - } + static bool isEqual(const SpecSig &LHS, const SpecSig &RHS) { + return LHS == RHS; + } +}; + +/// Attempt to specialize functions in the module to enable constant +/// propagation across function boundaries. +/// +/// \returns true if at least one function is specialized. +bool FunctionSpecializer::run() { + // Find possible specializations for each function. 
+ SpecMap SM; + SmallVector<Spec, 32> AllSpecs; + unsigned NumCandidates = 0; + for (Function &F : M) { + if (!isCandidateFunction(&F)) + continue; - Changed = true; - for (auto &Entry : Specializations) - specializeFunction(F, Entry.second, WorkList); + auto Cost = getSpecializationCost(&F); + if (!Cost.isValid()) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialization cost for " + << F.getName() << "\n"); + continue; } - updateSpecializedFuncs(Candidates, WorkList); - NumFuncSpecialized += NbFunctionsSpecialized; - return Changed; - } + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for " + << F.getName() << " is " << Cost << "\n"); - void removeDeadInstructions() { - for (auto *I : ReplacedWithConstant) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead instruction " << *I - << "\n"); - I->eraseFromParent(); + if (!findSpecializations(&F, Cost, AllSpecs, SM)) { + LLVM_DEBUG( + dbgs() << "FnSpecialization: No possible specializations found for " + << F.getName() << "\n"); + continue; } - ReplacedWithConstant.clear(); + + ++NumCandidates; } - void removeDeadFunctions() { - for (auto *F : FullySpecialized) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function " - << F->getName() << "\n"); - F->eraseFromParent(); - } - FullySpecialized.clear(); + if (!NumCandidates) { + LLVM_DEBUG( + dbgs() + << "FnSpecialization: No possible specializations found in module\n"); + return false; } - bool tryToReplaceWithConstant(Value *V) { - if (!V->getType()->isSingleValueType() || isa<CallBase>(V) || - V->user_empty()) - return false; - - const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); - if (isOverdefined(IV)) - return false; - auto *Const = - isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType()); - - LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing " << *V - << "\nFnSpecialization: with " << *Const << "\n"); - - // Record uses of V to avoid visiting irrelevant uses of const later. - SmallVector<Instruction *> UseInsts; - for (auto *U : V->users()) - if (auto *I = dyn_cast<Instruction>(U)) - if (Solver.isBlockExecutable(I->getParent())) - UseInsts.push_back(I); - - V->replaceAllUsesWith(Const); - - for (auto *I : UseInsts) - Solver.visit(I); - - // Remove the instruction from Block and Solver. - if (auto *I = dyn_cast<Instruction>(V)) { - if (I->isSafeToRemove()) { - ReplacedWithConstant.push_back(I); - Solver.removeLatticeValueFor(I); - } + // Choose the most profitable specialisations, which fit in the module + // specialization budget, which is derived from maximum number of + // specializations per specialization candidate function. 
+ auto CompareGain = [&AllSpecs](unsigned I, unsigned J) { + return AllSpecs[I].Gain > AllSpecs[J].Gain; + }; + const unsigned NSpecs = + std::min(NumCandidates * MaxClonesThreshold, unsigned(AllSpecs.size())); + SmallVector<unsigned> BestSpecs(NSpecs + 1); + std::iota(BestSpecs.begin(), BestSpecs.begin() + NSpecs, 0); + if (AllSpecs.size() > NSpecs) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed " + << "the maximum number of clones threshold.\n" + << "FnSpecialization: Specializing the " + << NSpecs + << " most profitable candidates.\n"); + std::make_heap(BestSpecs.begin(), BestSpecs.begin() + NSpecs, CompareGain); + for (unsigned I = NSpecs, N = AllSpecs.size(); I < N; ++I) { + BestSpecs[NSpecs] = I; + std::push_heap(BestSpecs.begin(), BestSpecs.end(), CompareGain); + std::pop_heap(BestSpecs.begin(), BestSpecs.end(), CompareGain); } - return true; } -private: - // The number of functions specialised, used for collecting statistics and - // also in the cost model. - unsigned NbFunctionsSpecialized = 0; - - // Compute the code metrics for function \p F. - CodeMetrics &analyzeFunction(Function *F) { - auto I = FunctionMetrics.insert({F, CodeMetrics()}); - CodeMetrics &Metrics = I.first->second; - if (I.second) { - // The code metrics were not cached. - SmallPtrSet<const Value *, 32> EphValues; - CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues); - for (BasicBlock &BB : *F) - Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues); - - LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function " - << F->getName() << " is " << Metrics.NumInsts - << " instructions\n"); + LLVM_DEBUG(dbgs() << "FnSpecialization: List of specializations \n"; + for (unsigned I = 0; I < NSpecs; ++I) { + const Spec &S = AllSpecs[BestSpecs[I]]; + dbgs() << "FnSpecialization: Function " << S.F->getName() + << " , gain " << S.Gain << "\n"; + for (const ArgInfo &Arg : S.Sig.Args) + dbgs() << "FnSpecialization: FormalArg = " + << Arg.Formal->getNameOrAsOperand() + << ", ActualArg = " << Arg.Actual->getNameOrAsOperand() + << "\n"; + }); + + // Create the chosen specializations. + SmallPtrSet<Function *, 8> OriginalFuncs; + SmallVector<Function *> Clones; + for (unsigned I = 0; I < NSpecs; ++I) { + Spec &S = AllSpecs[BestSpecs[I]]; + S.Clone = createSpecialization(S.F, S.Sig); + + // Update the known call sites to call the clone. + for (CallBase *Call : S.CallSites) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call + << " to call " << S.Clone->getName() << "\n"); + Call->setCalledFunction(S.Clone); } - return Metrics; - } - /// Clone the function \p F and remove the ssa_copy intrinsics added by - /// the SCCPSolver in the cloned version. - Function *cloneCandidateFunction(Function *F, ValueToValueMapTy &Mappings) { - Function *Clone = CloneFunction(F, Mappings); - removeSSACopy(*Clone); - return Clone; + Clones.push_back(S.Clone); + OriginalFuncs.insert(S.F); } - /// This function decides whether it's worthwhile to specialize function - /// \p F based on the known constant values its arguments can take on. It - /// only discovers potential specialization opportunities without actually - /// applying them. - /// - /// \returns true if any specializations have been found. - bool calculateGains(Function *F, InstructionCost Cost, - SmallVectorImpl<CallSpecBinding> &WorkList) { - SpecializationMap Specializations; - // Determine if we should specialize the function based on the values the - // argument can take on. 
If specialization is not profitable, we continue - // on to the next argument. - for (Argument &FormalArg : F->args()) { - // Determine if this argument is interesting. If we know the argument can - // take on any constant values, they are collected in Constants. - SmallVector<CallArgBinding, 8> ActualArgs; - if (!isArgumentInteresting(&FormalArg, ActualArgs)) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Argument " - << FormalArg.getNameOrAsOperand() - << " is not interesting\n"); - continue; - } + Solver.solveWhileResolvedUndefsIn(Clones); - for (const auto &Entry : ActualArgs) { - CallBase *Call = Entry.first; - Constant *ActualArg = Entry.second; + // Update the rest of the call sites - these are the recursive calls, calls + // to discarded specialisations and calls that may match a specialisation + // after the solver runs. + for (Function *F : OriginalFuncs) { + auto [Begin, End] = SM[F]; + updateCallSites(F, AllSpecs.begin() + Begin, AllSpecs.begin() + End); + } - auto I = Specializations.insert({Call, SpecializationInfo()}); - SpecializationInfo &S = I.first->second; + promoteConstantStackValues(); + LLVM_DEBUG(if (NbFunctionsSpecialized) dbgs() + << "FnSpecialization: Specialized " << NbFunctionsSpecialized + << " functions in module " << M.getName() << "\n"); - if (I.second) - S.Gain = ForceFunctionSpecialization ? 1 : 0 - Cost; - if (!ForceFunctionSpecialization) - S.Gain += getSpecializationBonus(&FormalArg, ActualArg); - S.Args.push_back({&FormalArg, ActualArg}); - } - } + NumFuncSpecialized += NbFunctionsSpecialized; + return true; +} - // Remove unprofitable specializations. - Specializations.remove_if( - [](const auto &Entry) { return Entry.second.Gain <= 0; }); - - // Clear the MapVector and return the underlying vector. - WorkList = Specializations.takeVector(); - - // Sort the candidates in descending order. - llvm::stable_sort(WorkList, [](const auto &L, const auto &R) { - return L.second.Gain > R.second.Gain; - }); - - // Truncate the worklist to 'MaxClonesThreshold' candidates if necessary. - if (WorkList.size() > MaxClonesThreshold) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed " - << "the maximum number of clones threshold.\n" - << "FnSpecialization: Truncating worklist to " - << MaxClonesThreshold << " candidates.\n"); - WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end()); - } +void FunctionSpecializer::removeDeadFunctions() { + for (Function *F : FullySpecialized) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function " + << F->getName() << "\n"); + if (FAM) + FAM->clear(*F, F->getName()); + F->eraseFromParent(); + } + FullySpecialized.clear(); +} - LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function " - << F->getName() << "\n"; - for (const auto &Entry - : WorkList) { - dbgs() << "FnSpecialization: Gain = " << Entry.second.Gain - << "\n"; - for (const ArgInfo &Arg : Entry.second.Args) - dbgs() << "FnSpecialization: FormalArg = " - << Arg.Formal->getNameOrAsOperand() - << ", ActualArg = " - << Arg.Actual->getNameOrAsOperand() << "\n"; - }); - - return !WorkList.empty(); +// Compute the code metrics for function \p F. +CodeMetrics &FunctionSpecializer::analyzeFunction(Function *F) { + auto I = FunctionMetrics.insert({F, CodeMetrics()}); + CodeMetrics &Metrics = I.first->second; + if (I.second) { + // The code metrics were not cached. 
+ SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues); + for (BasicBlock &BB : *F) + Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function " + << F->getName() << " is " << Metrics.NumInsts + << " instructions\n"); } + return Metrics; +} - bool isCandidateFunction(Function *F) { - // Do not specialize the cloned function again. - if (SpecializedFuncs.contains(F)) - return false; +/// Clone the function \p F and remove the ssa_copy intrinsics added by +/// the SCCPSolver in the cloned version. +static Function *cloneCandidateFunction(Function *F) { + ValueToValueMapTy Mappings; + Function *Clone = CloneFunction(F, Mappings); + removeSSACopy(*Clone); + return Clone; +} - // If we're optimizing the function for size, we shouldn't specialize it. - if (F->hasOptSize() || - shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass)) - return false; +bool FunctionSpecializer::findSpecializations(Function *F, InstructionCost Cost, + SmallVectorImpl<Spec> &AllSpecs, + SpecMap &SM) { + // A mapping from a specialisation signature to the index of the respective + // entry in the all specialisation array. Used to ensure uniqueness of + // specialisations. + DenseMap<SpecSig, unsigned> UM; + + // Get a list of interesting arguments. + SmallVector<Argument *> Args; + for (Argument &Arg : F->args()) + if (isArgumentInteresting(&Arg)) + Args.push_back(&Arg); + + if (Args.empty()) + return false; - // Exit if the function is not executable. There's no point in specializing - // a dead function. - if (!Solver.isBlockExecutable(&F->getEntryBlock())) - return false; + bool Found = false; + for (User *U : F->users()) { + if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) + continue; + auto &CS = *cast<CallBase>(U); - // It wastes time to specialize a function which would get inlined finally. - if (F->hasFnAttribute(Attribute::AlwaysInline)) - return false; + // The user instruction does not call our function. + if (CS.getCalledFunction() != F) + continue; - LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName() - << "\n"); - return true; - } + // If the call site has attribute minsize set, that callsite won't be + // specialized. + if (CS.hasFnAttr(Attribute::MinSize)) + continue; - void specializeFunction(Function *F, SpecializationInfo &S, - FuncList &WorkList) { - ValueToValueMapTy Mappings; - Function *Clone = cloneCandidateFunction(F, Mappings); - - // Rewrite calls to the function so that they call the clone instead. - rewriteCallSites(Clone, S.Args, Mappings); - - // Initialize the lattice state of the arguments of the function clone, - // marking the argument on which we specialized the function constant - // with the given value. - Solver.markArgInFuncSpecialization(Clone, S.Args); - - // Mark all the specialized functions - WorkList.push_back(Clone); - NbFunctionsSpecialized++; - - // If the function has been completely specialized, the original function - // is no longer needed. Mark it unreachable. - if (F->getNumUses() == 0 || all_of(F->users(), [F](User *U) { - if (auto *CS = dyn_cast<CallBase>(U)) - return CS->getFunction() == F; - return false; - })) { - Solver.markFunctionUnreachable(F); - FullySpecialized.insert(F); - } - } + // If the parent of the call site will never be executed, we don't need + // to worry about the passed value. 
+ if (!Solver.isBlockExecutable(CS.getParent())) + continue; - /// Compute and return the cost of specializing function \p F. - InstructionCost getSpecializationCost(Function *F) { - CodeMetrics &Metrics = analyzeFunction(F); - // If the code metrics reveal that we shouldn't duplicate the function, we - // shouldn't specialize it. Set the specialization cost to Invalid. - // Or if the lines of codes implies that this function is easy to get - // inlined so that we shouldn't specialize it. - if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() || - (!ForceFunctionSpecialization && - *Metrics.NumInsts.getValue() < SmallFunctionThreshold)) { - InstructionCost C{}; - C.setInvalid(); - return C; + // Examine arguments and create a specialisation candidate from the + // constant operands of this call site. + SpecSig S; + for (Argument *A : Args) { + Constant *C = getCandidateConstant(CS.getArgOperand(A->getArgNo())); + if (!C) + continue; + LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument " + << A->getName() << " : " << C->getNameOrAsOperand() + << "\n"); + S.Args.push_back({A, C}); } - // Otherwise, set the specialization cost to be the cost of all the - // instructions in the function and penalty for specializing more functions. - unsigned Penalty = NbFunctionsSpecialized + 1; - return Metrics.NumInsts * InlineConstants::InstrCost * Penalty; - } - - InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI, - LoopInfo &LI) { - auto *I = dyn_cast_or_null<Instruction>(U); - // If not an instruction we do not know how to evaluate. - // Keep minimum possible cost for now so that it doesnt affect - // specialization. - if (!I) - return std::numeric_limits<unsigned>::min(); - - auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency); - - // Traverse recursively if there are more uses. - // TODO: Any other instructions to be added here? - if (I->mayReadFromMemory() || I->isCast()) - for (auto *User : I->users()) - Cost += getUserBonus(User, TTI, LI); - - // Increase the cost if it is inside the loop. - auto LoopDepth = LI.getLoopDepth(I->getParent()); - Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth); - return Cost; - } - - /// Compute a bonus for replacing argument \p A with constant \p C. - InstructionCost getSpecializationBonus(Argument *A, Constant *C) { - Function *F = A->getParent(); - DominatorTree DT(*F); - LoopInfo LI(DT); - auto &TTI = (GetTTI)(*F); - LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " - << C->getNameOrAsOperand() << "\n"); - - InstructionCost TotalCost = 0; - for (auto *U : A->users()) { - TotalCost += getUserBonus(U, TTI, LI); - LLVM_DEBUG(dbgs() << "FnSpecialization: User cost "; - TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n"); - } + if (S.Args.empty()) + continue; - // The below heuristic is only concerned with exposing inlining - // opportunities via indirect call promotion. If the argument is not a - // (potentially casted) function pointer, give up. - Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts()); - if (!CalledFunction) - return TotalCost; - - // Get TTI for the called function (used for the inline cost). - auto &CalleeTTI = (GetTTI)(*CalledFunction); - - // Look at all the call sites whose called value is the argument. - // Specializing the function on the argument would allow these indirect - // calls to be promoted to direct calls. 
If the indirect call promotion - // would likely enable the called function to be inlined, specializing is a - // good idea. - int Bonus = 0; - for (User *U : A->users()) { - if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) + // Check if we have encountered the same specialisation already. + if (auto It = UM.find(S); It != UM.end()) { + // Existing specialisation. Add the call to the list to rewrite, unless + // it's a recursive call. A specialisation, generated because of a + // recursive call may end up as not the best specialisation for all + // the cloned instances of this call, which result from specialising + // functions. Hence we don't rewrite the call directly, but match it with + // the best specialisation once all specialisations are known. + if (CS.getFunction() == F) continue; - auto *CS = cast<CallBase>(U); - if (CS->getCalledOperand() != A) + const unsigned Index = It->second; + AllSpecs[Index].CallSites.push_back(&CS); + } else { + // Calculate the specialisation gain. + InstructionCost Gain = 0 - Cost; + for (ArgInfo &A : S.Args) + Gain += + getSpecializationBonus(A.Formal, A.Actual, Solver.getLoopInfo(*F)); + + // Discard unprofitable specialisations. + if (!ForceFunctionSpecialization && Gain <= 0) continue; - // Get the cost of inlining the called function at this call site. Note - // that this is only an estimate. The called function may eventually - // change in a way that leads to it not being inlined here, even though - // inlining looks profitable now. For example, one of its called - // functions may be inlined into it, making the called function too large - // to be inlined into this call site. - // - // We apply a boost for performing indirect call promotion by increasing - // the default threshold by the threshold for indirect calls. - auto Params = getInlineParams(); - Params.DefaultThreshold += InlineConstants::IndirectCallThreshold; - InlineCost IC = - getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI); - - // We clamp the bonus for this call to be between zero and the default - // threshold. - if (IC.isAlways()) - Bonus += Params.DefaultThreshold; - else if (IC.isVariable() && IC.getCostDelta() > 0) - Bonus += IC.getCostDelta(); - - LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << Bonus - << " for user " << *U << "\n"); + // Create a new specialisation entry. + auto &Spec = AllSpecs.emplace_back(F, S, Gain); + if (CS.getFunction() != F) + Spec.CallSites.push_back(&CS); + const unsigned Index = AllSpecs.size() - 1; + UM[S] = Index; + if (auto [It, Inserted] = SM.try_emplace(F, Index, Index + 1); !Inserted) + It->second.second = Index + 1; + Found = true; } - - return TotalCost + Bonus; } - /// Determine if we should specialize a function based on the incoming values - /// of the given argument. - /// - /// This function implements the goal-directed heuristic. It determines if - /// specializing the function based on the incoming values of argument \p A - /// would result in any significant optimization opportunities. If - /// optimization opportunities exist, the constant values of \p A on which to - /// specialize the function are collected in \p Constants. - /// - /// \returns true if the function should be specialized on the given - /// argument. - bool isArgumentInteresting(Argument *A, - SmallVectorImpl<CallArgBinding> &Constants) { - // For now, don't attempt to specialize functions based on the values of - // composite types. 
- if (!A->getType()->isSingleValueType() || A->user_empty()) - return false; - - // If the argument isn't overdefined, there's nothing to do. It should - // already be constant. - if (!Solver.getLatticeValueFor(A).isOverdefined()) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, argument " - << A->getNameOrAsOperand() - << " is already constant?\n"); - return false; - } - - // Collect the constant values that the argument can take on. If the - // argument can't take on any constant values, we aren't going to - // specialize the function. While it's possible to specialize the function - // based on non-constant arguments, there's likely not much benefit to - // constant propagation in doing so. - // - // TODO 1: currently it won't specialize if there are over the threshold of - // calls using the same argument, e.g foo(a) x 4 and foo(b) x 1, but it - // might be beneficial to take the occurrences into account in the cost - // model, so we would need to find the unique constants. - // - // TODO 2: this currently does not support constants, i.e. integer ranges. - // - getPossibleConstants(A, Constants); - - if (Constants.empty()) - return false; + return Found; +} - LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument " - << A->getNameOrAsOperand() << "\n"); - return true; - } +bool FunctionSpecializer::isCandidateFunction(Function *F) { + if (F->isDeclaration()) + return false; - /// Collect in \p Constants all the constant values that argument \p A can - /// take on. - void getPossibleConstants(Argument *A, - SmallVectorImpl<CallArgBinding> &Constants) { - Function *F = A->getParent(); + if (F->hasFnAttribute(Attribute::NoDuplicate)) + return false; - // SCCP solver does not record an argument that will be constructed on - // stack. - if (A->hasByValAttr() && !F->onlyReadsMemory()) - return; + if (!Solver.isArgumentTrackedFunction(F)) + return false; - // Iterate over all the call sites of the argument's parent function. - for (User *U : F->users()) { - if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) - continue; - auto &CS = *cast<CallBase>(U); - // If the call site has attribute minsize set, that callsite won't be - // specialized. - if (CS.hasFnAttr(Attribute::MinSize)) - continue; + // Do not specialize the cloned function again. + if (SpecializedFuncs.contains(F)) + return false; - // If the parent of the call site will never be executed, we don't need - // to worry about the passed value. - if (!Solver.isBlockExecutable(CS.getParent())) - continue; + // If we're optimizing the function for size, we shouldn't specialize it. + if (F->hasOptSize() || + shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass)) + return false; - auto *V = CS.getArgOperand(A->getArgNo()); - if (isa<PoisonValue>(V)) - return; + // Exit if the function is not executable. There's no point in specializing + // a dead function. + if (!Solver.isBlockExecutable(&F->getEntryBlock())) + return false; - // TrackValueOfGlobalVariable only tracks scalar global variables. - if (auto *GV = dyn_cast<GlobalVariable>(V)) { - // Check if we want to specialize on the address of non-constant - // global values. - if (!GV->isConstant()) - if (!SpecializeOnAddresses) - return; + // It wastes time to specialize a function which would get inlined finally. 
+ if (F->hasFnAttribute(Attribute::AlwaysInline)) + return false; - if (!GV->getValueType()->isSingleValueType()) - return; - } + LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName() + << "\n"); + return true; +} - if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() || - EnableSpecializationForLiteralConstant)) - Constants.push_back({&CS, cast<Constant>(V)}); - } - } +Function *FunctionSpecializer::createSpecialization(Function *F, const SpecSig &S) { + Function *Clone = cloneCandidateFunction(F); - /// Rewrite calls to function \p F to call function \p Clone instead. - /// - /// This function modifies calls to function \p F as long as the actual - /// arguments match those in \p Args. Note that for recursive calls we - /// need to compare against the cloned formal arguments. - /// - /// Callsites that have been marked with the MinSize function attribute won't - /// be specialized and rewritten. - void rewriteCallSites(Function *Clone, const SmallVectorImpl<ArgInfo> &Args, - ValueToValueMapTy &Mappings) { - assert(!Args.empty() && "Specialization without arguments"); - Function *F = Args[0].Formal->getParent(); - - SmallVector<CallBase *, 8> CallSitesToRewrite; - for (auto *U : F->users()) { - if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) - continue; - auto &CS = *cast<CallBase>(U); - if (!CS.getCalledFunction() || CS.getCalledFunction() != F) - continue; - CallSitesToRewrite.push_back(&CS); - } + // Initialize the lattice state of the arguments of the function clone, + // marking the argument on which we specialized the function constant + // with the given value. + Solver.markArgInFuncSpecialization(Clone, S.Args); - LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing call sites of " - << F->getName() << " with " << Clone->getName() << "\n"); + Solver.addArgumentTrackedFunction(Clone); + Solver.markBlockExecutable(&Clone->front()); - for (auto *CS : CallSitesToRewrite) { - LLVM_DEBUG(dbgs() << "FnSpecialization: " - << CS->getFunction()->getName() << " ->" << *CS - << "\n"); - if (/* recursive call */ - (CS->getFunction() == Clone && - all_of(Args, - [CS, &Mappings](const ArgInfo &Arg) { - unsigned ArgNo = Arg.Formal->getArgNo(); - return CS->getArgOperand(ArgNo) == Mappings[Arg.Formal]; - })) || - /* normal call */ - all_of(Args, [CS](const ArgInfo &Arg) { - unsigned ArgNo = Arg.Formal->getArgNo(); - return CS->getArgOperand(ArgNo) == Arg.Actual; - })) { - CS->setCalledFunction(Clone); - Solver.markOverdefined(CS); - } - } - } + // Mark all the specialized functions + SpecializedFuncs.insert(Clone); + NbFunctionsSpecialized++; - void updateSpecializedFuncs(FuncList &Candidates, FuncList &WorkList) { - for (auto *F : WorkList) { - SpecializedFuncs.insert(F); + return Clone; +} - // Initialize the state of the newly created functions, marking them - // argument-tracked and executable. - if (F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked)) - Solver.addTrackedFunction(F); +/// Compute and return the cost of specializing function \p F. +InstructionCost FunctionSpecializer::getSpecializationCost(Function *F) { + CodeMetrics &Metrics = analyzeFunction(F); + // If the code metrics reveal that we shouldn't duplicate the function, we + // shouldn't specialize it. Set the specialization cost to Invalid. + // Or if the lines of codes implies that this function is easy to get + // inlined so that we shouldn't specialize it. 
+ if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() || + (!ForceFunctionSpecialization && + !F->hasFnAttribute(Attribute::NoInline) && + Metrics.NumInsts < SmallFunctionThreshold)) + return InstructionCost::getInvalid(); + + // Otherwise, set the specialization cost to be the cost of all the + // instructions in the function. + return Metrics.NumInsts * InlineConstants::getInstrCost(); +} - Solver.addArgumentTrackedFunction(F); - Candidates.push_back(F); - Solver.markBlockExecutable(&F->front()); +static InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI, + const LoopInfo &LI) { + auto *I = dyn_cast_or_null<Instruction>(U); + // If not an instruction we do not know how to evaluate. + // Keep minimum possible cost for now so that it doesnt affect + // specialization. + if (!I) + return std::numeric_limits<unsigned>::min(); + + InstructionCost Cost = + TTI.getInstructionCost(U, TargetTransformInfo::TCK_SizeAndLatency); + + // Increase the cost if it is inside the loop. + unsigned LoopDepth = LI.getLoopDepth(I->getParent()); + Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth); + + // Traverse recursively if there are more uses. + // TODO: Any other instructions to be added here? + if (I->mayReadFromMemory() || I->isCast()) + for (auto *User : I->users()) + Cost += getUserBonus(User, TTI, LI); + + return Cost; +} - // Replace the function arguments for the specialized functions. - for (Argument &Arg : F->args()) - if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg)) - LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: " - << Arg.getNameOrAsOperand() << "\n"); - } +/// Compute a bonus for replacing argument \p A with constant \p C. +InstructionCost +FunctionSpecializer::getSpecializationBonus(Argument *A, Constant *C, + const LoopInfo &LI) { + Function *F = A->getParent(); + auto &TTI = (GetTTI)(*F); + LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " + << C->getNameOrAsOperand() << "\n"); + + InstructionCost TotalCost = 0; + for (auto *U : A->users()) { + TotalCost += getUserBonus(U, TTI, LI); + LLVM_DEBUG(dbgs() << "FnSpecialization: User cost "; + TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n"); } -}; -} // namespace - -bool llvm::runFunctionSpecialization( - Module &M, const DataLayout &DL, - std::function<TargetLibraryInfo &(Function &)> GetTLI, - std::function<TargetTransformInfo &(Function &)> GetTTI, - std::function<AssumptionCache &(Function &)> GetAC, - function_ref<AnalysisResultsForFn(Function &)> GetAnalysis) { - SCCPSolver Solver(DL, GetTLI, M.getContext()); - FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI); - bool Changed = false; - - // Loop over all functions, marking arguments to those with their addresses - // taken or that are external as overdefined. - for (Function &F : M) { - if (F.isDeclaration()) + + // The below heuristic is only concerned with exposing inlining + // opportunities via indirect call promotion. If the argument is not a + // (potentially casted) function pointer, give up. + Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts()); + if (!CalledFunction) + return TotalCost; + + // Get TTI for the called function (used for the inline cost). + auto &CalleeTTI = (GetTTI)(*CalledFunction); + + // Look at all the call sites whose called value is the argument. + // Specializing the function on the argument would allow these indirect + // calls to be promoted to direct calls. 
If the indirect call promotion + // would likely enable the called function to be inlined, specializing is a + // good idea. + int Bonus = 0; + for (User *U : A->users()) { + if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) + continue; + auto *CS = cast<CallBase>(U); + if (CS->getCalledOperand() != A) continue; - if (F.hasFnAttribute(Attribute::NoDuplicate)) + if (CS->getFunctionType() != CalledFunction->getFunctionType()) continue; - LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName() - << "\n"); - Solver.addAnalysis(F, GetAnalysis(F)); + // Get the cost of inlining the called function at this call site. Note + // that this is only an estimate. The called function may eventually + // change in a way that leads to it not being inlined here, even though + // inlining looks profitable now. For example, one of its called + // functions may be inlined into it, making the called function too large + // to be inlined into this call site. + // + // We apply a boost for performing indirect call promotion by increasing + // the default threshold by the threshold for indirect calls. + auto Params = getInlineParams(); + Params.DefaultThreshold += InlineConstants::IndirectCallThreshold; + InlineCost IC = + getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI); + + // We clamp the bonus for this call to be between zero and the default + // threshold. + if (IC.isAlways()) + Bonus += Params.DefaultThreshold; + else if (IC.isVariable() && IC.getCostDelta() > 0) + Bonus += IC.getCostDelta(); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << Bonus + << " for user " << *U << "\n"); + } - // Determine if we can track the function's arguments. If so, add the - // function to the solver's set of argument-tracked functions. - if (canTrackArgumentsInterprocedurally(&F)) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n"); - Solver.addArgumentTrackedFunction(&F); - continue; - } else { - LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n" - << "FnSpecialization: Doesn't have local linkage, or " - << "has its address taken\n"); - } + return TotalCost + Bonus; +} - // Assume the function is called. - Solver.markBlockExecutable(&F.front()); +/// Determine if it is possible to specialise the function for constant values +/// of the formal parameter \p A. +bool FunctionSpecializer::isArgumentInteresting(Argument *A) { + // No point in specialization if the argument is unused. + if (A->user_empty()) + return false; - // Assume nothing about the incoming arguments. - for (Argument &AI : F.args()) - Solver.markOverdefined(&AI); - } + // For now, don't attempt to specialize functions based on the values of + // composite types. + Type *ArgTy = A->getType(); + if (!ArgTy->isSingleValueType()) + return false; - // Determine if we can track any of the module's global variables. If so, add - // the global variables we can track to the solver's set of tracked global - // variables. - for (GlobalVariable &G : M.globals()) { - G.removeDeadConstantUsers(); - if (canTrackGlobalVariableInterprocedurally(&G)) - Solver.trackValueOfGlobalVariable(&G); - } + // Specialization of integer and floating point types needs to be explicitly + // enabled. 
+ if (!EnableSpecializationForLiteralConstant && + (ArgTy->isIntegerTy() || ArgTy->isFloatingPointTy())) + return false; - auto &TrackedFuncs = Solver.getArgumentTrackedFunctions(); - SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(), - TrackedFuncs.end()); + // SCCP solver does not record an argument that will be constructed on + // stack. + if (A->hasByValAttr() && !A->getParent()->onlyReadsMemory()) + return false; - // No tracked functions, so nothing to do: don't run the solver and remove - // the ssa_copy intrinsics that may have been introduced. - if (TrackedFuncs.empty()) { - removeSSACopy(M); + // Check the lattice value and decide if we should attemt to specialize, + // based on this argument. No point in specialization, if the lattice value + // is already a constant. + const ValueLatticeElement &LV = Solver.getLatticeValueFor(A); + if (LV.isUnknownOrUndef() || LV.isConstant() || + (LV.isConstantRange() && LV.getConstantRange().isSingleElement())) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, parameter " + << A->getNameOrAsOperand() << " is already constant\n"); return false; } - // Solve for constants. - auto RunSCCPSolver = [&](auto &WorkList) { - bool ResolvedUndefs = true; - - while (ResolvedUndefs) { - // Not running the solver unnecessary is checked in regression test - // nothing-to-do.ll, so if this debug message is changed, this regression - // test needs updating too. - LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n"); - - Solver.solve(); - LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n"); - ResolvedUndefs = false; - for (Function *F : WorkList) - if (Solver.resolvedUndefsIn(*F)) - ResolvedUndefs = true; - } + LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting parameter " + << A->getNameOrAsOperand() << "\n"); - for (auto *F : WorkList) { - for (BasicBlock &BB : *F) { - if (!Solver.isBlockExecutable(&BB)) - continue; - // FIXME: The solver may make changes to the function here, so set - // Changed, even if later function specialization does not trigger. - for (auto &I : make_early_inc_range(BB)) - Changed |= FS.tryToReplaceWithConstant(&I); - } - } - }; + return true; +} + +/// Check if the valuy \p V (an actual argument) is a constant or can only +/// have a constant value. Return that constant. +Constant *FunctionSpecializer::getCandidateConstant(Value *V) { + if (isa<PoisonValue>(V)) + return nullptr; -#ifndef NDEBUG - LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n"); - for (auto *F : FuncDecls) - LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n"); -#endif + // TrackValueOfGlobalVariable only tracks scalar global variables. + if (auto *GV = dyn_cast<GlobalVariable>(V)) { + // Check if we want to specialize on the address of non-constant + // global values. + if (!GV->isConstant() && !SpecializeOnAddresses) + return nullptr; - // Initially resolve the constants in all the argument tracked functions. - RunSCCPSolver(FuncDecls); + if (!GV->getValueType()->isSingleValueType()) + return nullptr; + } - SmallVector<Function *, 8> WorkList; - unsigned I = 0; - while (FuncSpecializationMaxIters != I++ && - FS.specializeFunctions(FuncDecls, WorkList)) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Finished iteration " << I << "\n"); + // Select for possible specialisation values that are constants or + // are deduced to be constants or constant ranges with a single element. 
+ Constant *C = dyn_cast<Constant>(V); + if (!C) { + const ValueLatticeElement &LV = Solver.getLatticeValueFor(V); + if (LV.isConstant()) + C = LV.getConstant(); + else if (LV.isConstantRange() && LV.getConstantRange().isSingleElement()) { + assert(V->getType()->isIntegerTy() && "Non-integral constant range"); + C = Constant::getIntegerValue(V->getType(), + *LV.getConstantRange().getSingleElement()); + } else + return nullptr; + } - // Run the solver for the specialized functions. - RunSCCPSolver(WorkList); + return C; +} - // Replace some unresolved constant arguments. - constantArgPropagation(FuncDecls, M, Solver); +void FunctionSpecializer::updateCallSites(Function *F, const Spec *Begin, + const Spec *End) { + // Collect the call sites that need updating. + SmallVector<CallBase *> ToUpdate; + for (User *U : F->users()) + if (auto *CS = dyn_cast<CallBase>(U); + CS && CS->getCalledFunction() == F && + Solver.isBlockExecutable(CS->getParent())) + ToUpdate.push_back(CS); + + unsigned NCallsLeft = ToUpdate.size(); + for (CallBase *CS : ToUpdate) { + bool ShouldDecrementCount = CS->getFunction() == F; + + // Find the best matching specialisation. + const Spec *BestSpec = nullptr; + for (const Spec &S : make_range(Begin, End)) { + if (!S.Clone || (BestSpec && S.Gain <= BestSpec->Gain)) + continue; - WorkList.clear(); - Changed = true; - } + if (any_of(S.Sig.Args, [CS, this](const ArgInfo &Arg) { + unsigned ArgNo = Arg.Formal->getArgNo(); + return getCandidateConstant(CS->getArgOperand(ArgNo)) != Arg.Actual; + })) + continue; + + BestSpec = &S; + } - LLVM_DEBUG(dbgs() << "FnSpecialization: Number of specializations = " - << NumFuncSpecialized << "\n"); + if (BestSpec) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *CS + << " to call " << BestSpec->Clone->getName() << "\n"); + CS->setCalledFunction(BestSpec->Clone); + ShouldDecrementCount = true; + } + + if (ShouldDecrementCount) + --NCallsLeft; + } - // Remove any ssa_copy intrinsics that may have been introduced. - removeSSACopy(M); - return Changed; + // If the function has been completely specialized, the original function + // is no longer needed. Mark it unreachable. + if (NCallsLeft == 0) { + Solver.markFunctionUnreachable(F); + FullySpecialized.insert(F); + } } diff --git a/llvm/lib/Transforms/IPO/GlobalDCE.cpp b/llvm/lib/Transforms/IPO/GlobalDCE.cpp index f35827220bb6..2f2bb174a8c8 100644 --- a/llvm/lib/Transforms/IPO/GlobalDCE.cpp +++ b/llvm/lib/Transforms/IPO/GlobalDCE.cpp @@ -206,7 +206,7 @@ void GlobalDCEPass::ScanVTables(Module &M) { void GlobalDCEPass::ScanVTableLoad(Function *Caller, Metadata *TypeId, uint64_t CallOffset) { - for (auto &VTableInfo : TypeIdMap[TypeId]) { + for (const auto &VTableInfo : TypeIdMap[TypeId]) { GlobalVariable *VTable = VTableInfo.first; uint64_t VTableOffset = VTableInfo.second; @@ -240,7 +240,7 @@ void GlobalDCEPass::ScanTypeCheckedLoadIntrinsics(Module &M) { if (!TypeCheckedLoadFunc) return; - for (auto U : TypeCheckedLoadFunc->users()) { + for (auto *U : TypeCheckedLoadFunc->users()) { auto CI = dyn_cast<CallInst>(U); if (!CI) continue; @@ -254,7 +254,7 @@ void GlobalDCEPass::ScanTypeCheckedLoadIntrinsics(Module &M) { } else { // type.checked.load with a non-constant offset, so assume every entry in // every matching vtable is used. 
- for (auto &VTableInfo : TypeIdMap[TypeId]) { + for (const auto &VTableInfo : TypeIdMap[TypeId]) { VFESafeVTables.erase(VTableInfo.first); } } diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 6df0409256bb..0317a8bcb6bc 100644 --- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -68,6 +68,7 @@ #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> +#include <optional> #include <utility> #include <vector> @@ -140,9 +141,7 @@ static bool isLeakCheckerRoot(GlobalVariable *GV) { case Type::StructTyID: { StructType *STy = cast<StructType>(Ty); if (STy->isOpaque()) return true; - for (StructType::element_iterator I = STy->element_begin(), - E = STy->element_end(); I != E; ++I) { - Type *InnerTy = *I; + for (Type *InnerTy : STy->elements()) { if (isa<PointerType>(InnerTy)) return true; if (isa<StructType>(InnerTy) || isa<ArrayType>(InnerTy) || isa<VectorType>(InnerTy)) @@ -377,6 +376,11 @@ static bool collectSRATypes(DenseMap<uint64_t, Type *> &Types, GlobalValue *GV, auto It = Types.try_emplace(Offset.getZExtValue(), Ty).first; if (Ty != It->second) return false; + + // Scalable types not currently supported. + if (isa<ScalableVectorType>(Ty)) + return false; + continue; } @@ -652,7 +656,7 @@ static bool allUsesOfLoadedValueWillTrapIfNull(const GlobalVariable *GV) { Worklist.push_back(GV); while (!Worklist.empty()) { const Value *P = Worklist.pop_back_val(); - for (auto *U : P->users()) { + for (const auto *U : P->users()) { if (auto *LI = dyn_cast<LoadInst>(U)) { SmallPtrSet<const PHINode *, 8> PHIs; if (!AllUsesOfValueWillTrapIfNull(LI, PHIs)) @@ -879,7 +883,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, if (!isa<UndefValue>(InitVal)) { IRBuilder<> Builder(CI->getNextNode()); // TODO: Use alignment above if align!=1 - Builder.CreateMemSet(NewGV, InitVal, AllocSize, None); + Builder.CreateMemSet(NewGV, InitVal, AllocSize, std::nullopt); } // Update users of the allocation to use the new global instead. @@ -1378,8 +1382,8 @@ static bool isPointerValueDeadOnEntryToFunction( // and the number of bits loaded in L is less than or equal to // the number of bits stored in S. return DT.dominates(S, L) && - DL.getTypeStoreSize(LTy).getFixedSize() <= - DL.getTypeStoreSize(STy).getFixedSize(); + DL.getTypeStoreSize(LTy).getFixedValue() <= + DL.getTypeStoreSize(STy).getFixedValue(); })) return false; } @@ -1818,11 +1822,14 @@ hasOnlyColdCalls(Function &F, Function *CalledFn = CI->getCalledFunction(); if (!CalledFn) return false; - if (!CalledFn->hasLocalLinkage()) - return false; // Skip over intrinsics since they won't remain as function calls. + // Important to do this check before the linkage check below so we + // won't bail out on debug intrinsics, possibly making the generated + // code dependent on the presence of debug info. if (CalledFn->getIntrinsicID() != Intrinsic::not_intrinsic) continue; + if (!CalledFn->hasLocalLinkage()) + return false; // Check if it's valid to use coldcc calling convention. if (!hasChangeableCC(CalledFn) || CalledFn->isVarArg() || CalledFn->hasAddressTaken()) @@ -2003,7 +2010,7 @@ OptimizeFunctions(Module &M, // FIXME: We should also hoist alloca affected by this to the entry // block if possible. 
if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca) && - !F.hasAddressTaken() && !hasMustTailCallers(&F)) { + !F.hasAddressTaken() && !hasMustTailCallers(&F) && !F.isVarArg()) { RemoveAttribute(&F, Attribute::InAlloca); Changed = true; } @@ -2399,7 +2406,7 @@ static bool cxxDtorIsEmpty(const Function &Fn) { if (Fn.isDeclaration()) return false; - for (auto &I : Fn.getEntryBlock()) { + for (const auto &I : Fn.getEntryBlock()) { if (I.isDebugOrPseudoInst()) continue; if (isa<ReturnInst>(I)) @@ -2462,7 +2469,7 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL, SmallPtrSet<const Comdat *, 8> NotDiscardableComdats; bool Changed = false; bool LocalChange = true; - Optional<uint32_t> FirstNotFullyEvaluatedPriority; + std::optional<uint32_t> FirstNotFullyEvaluatedPriority; while (LocalChange) { LocalChange = false; diff --git a/llvm/lib/Transforms/IPO/IPO.cpp b/llvm/lib/Transforms/IPO/IPO.cpp index dfd434e61d5b..4163c448dc8f 100644 --- a/llvm/lib/Transforms/IPO/IPO.cpp +++ b/llvm/lib/Transforms/IPO/IPO.cpp @@ -23,7 +23,6 @@ using namespace llvm; void llvm::initializeIPO(PassRegistry &Registry) { - initializeOpenMPOptCGSCCLegacyPassPass(Registry); initializeAnnotation2MetadataLegacyPass(Registry); initializeCalledValuePropagationLegacyPassPass(Registry); initializeConstantMergeLegacyPassPass(Registry); @@ -31,7 +30,6 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeDAEPass(Registry); initializeDAHPass(Registry); initializeForceFunctionAttrsLegacyPassPass(Registry); - initializeFunctionSpecializationLegacyPassPass(Registry); initializeGlobalDCELegacyPassPass(Registry); initializeGlobalOptLegacyPassPass(Registry); initializeGlobalSplitPass(Registry); @@ -42,7 +40,6 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeInferFunctionAttrsLegacyPassPass(Registry); initializeInternalizeLegacyPassPass(Registry); initializeLoopExtractorLegacyPassPass(Registry); - initializeBlockExtractorLegacyPassPass(Registry); initializeSingleLoopExtractorPass(Registry); initializeMergeFunctionsLegacyPassPass(Registry); initializePartialInlinerLegacyPassPass(Registry); @@ -50,7 +47,6 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeAttributorCGSCCLegacyPassPass(Registry); initializePostOrderFunctionAttrsLegacyPassPass(Registry); initializeReversePostOrderFunctionAttrsLegacyPassPass(Registry); - initializePruneEHPass(Registry); initializeIPSCCPLegacyPassPass(Registry); initializeStripDeadPrototypesLegacyPassPass(Registry); initializeStripSymbolsPass(Registry); @@ -97,10 +93,6 @@ void LLVMAddGlobalOptimizerPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createGlobalOptimizerPass()); } -void LLVMAddPruneEHPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createPruneEHPass()); -} - void LLVMAddIPSCCPPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createIPSCCPPass()); } diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp index 28bc43aa1633..f5c52e5c7f5d 100644 --- a/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -26,6 +26,7 @@ #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" +#include <optional> #include <vector> #define DEBUG_TYPE "iroutliner" @@ -133,7 +134,7 @@ struct OutlinableGroup { /// The argument that needs to be marked with the swifterr attribute. If not /// needed, there is no value. - Optional<unsigned> SwiftErrorArgument; + std::optional<unsigned> SwiftErrorArgument; /// For the \ref Regions, we look at every Value. 
If it is a constant, /// we check whether it is the same in Region. @@ -169,7 +170,15 @@ static void getSortedConstantKeys(std::vector<Value *> &SortedKeys, for (auto &VtoBB : Map) SortedKeys.push_back(VtoBB.first); + // Here we expect to have either 1 value that is void (nullptr) or multiple + // values that are all constant integers. + if (SortedKeys.size() == 1) { + assert(!SortedKeys[0] && "Expected a single void value."); + return; + } + stable_sort(SortedKeys, [](const Value *LHS, const Value *RHS) { + assert(LHS && RHS && "Expected non void values."); const ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS); const ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS); assert(RHSC && "Not a constant integer in return value?"); @@ -181,11 +190,12 @@ static void getSortedConstantKeys(std::vector<Value *> &SortedKeys, Value *OutlinableRegion::findCorrespondingValueIn(const OutlinableRegion &Other, Value *V) { - Optional<unsigned> GVN = Candidate->getGVN(V); + std::optional<unsigned> GVN = Candidate->getGVN(V); assert(GVN && "No GVN for incoming value"); - Optional<unsigned> CanonNum = Candidate->getCanonicalNum(*GVN); - Optional<unsigned> FirstGVN = Other.Candidate->fromCanonicalNum(*CanonNum); - Optional<Value *> FoundValueOpt = Other.Candidate->fromGVN(*FirstGVN); + std::optional<unsigned> CanonNum = Candidate->getCanonicalNum(*GVN); + std::optional<unsigned> FirstGVN = + Other.Candidate->fromCanonicalNum(*CanonNum); + std::optional<Value *> FoundValueOpt = Other.Candidate->fromGVN(*FirstGVN); return FoundValueOpt.value_or(nullptr); } @@ -453,14 +463,14 @@ void OutlinableRegion::reattachCandidate() { /// \param GVNToConstant - The mapping of global value number to Constants. /// \returns true if the Value matches the Constant mapped to by V and false if /// it \p V is a Constant but does not match. -/// \returns None if \p V is not a Constant. -static Optional<bool> +/// \returns std::nullopt if \p V is not a Constant. +static std::optional<bool> constantMatches(Value *V, unsigned GVN, DenseMap<unsigned, Constant *> &GVNToConstant) { // See if we have a constant Constant *CST = dyn_cast<Constant>(V); if (!CST) - return None; + return std::nullopt; // Holds a mapping from a global value number to a Constant. DenseMap<unsigned, Constant *>::iterator GVNToConstantIt; @@ -553,9 +563,9 @@ collectRegionsConstants(OutlinableRegion &Region, // assigned by the IRSimilarityCandidate, has been seen before, we check if // the number has been found to be not the same value in each instance. for (Value *V : ID.OperVals) { - Optional<unsigned> GVNOpt = C.getGVN(V); + std::optional<unsigned> GVNOpt = C.getGVN(V); assert(GVNOpt && "Expected a GVN for operand?"); - unsigned GVN = GVNOpt.value(); + unsigned GVN = *GVNOpt; // Check if this global value has been found to not be the same already. if (NotSame.contains(GVN)) { @@ -568,9 +578,10 @@ collectRegionsConstants(OutlinableRegion &Region, // associated Constant value match the previous instances of the same // global value number. If the global value does not map to a Constant, // it is considered to not be the same value. - Optional<bool> ConstantMatches = constantMatches(V, GVN, GVNToConstant); + std::optional<bool> ConstantMatches = + constantMatches(V, GVN, GVNToConstant); if (ConstantMatches) { - if (ConstantMatches.value()) + if (*ConstantMatches) continue; else ConstantsTheSame = false; @@ -651,7 +662,7 @@ Function *IROutliner::createFunction(Module &M, OutlinableGroup &Group, // Transfer the swifterr attribute to the correct function parameter.
if (Group.SwiftErrorArgument) - Group.OutlinedFunction->addParamAttr(Group.SwiftErrorArgument.value(), + Group.OutlinedFunction->addParamAttr(*Group.SwiftErrorArgument, Attribute::SwiftError); Group.OutlinedFunction->addFnAttr(Attribute::OptimizeForSize); @@ -675,7 +686,8 @@ Function *IROutliner::createFunction(Module &M, OutlinableGroup &Group, Unit /* Context */, F->getName(), MangledNameStream.str(), Unit /* File */, 0 /* Line 0 is reserved for compiler-generated code. */, - DB.createSubroutineType(DB.getOrCreateTypeArray(None)), /* void type */ + DB.createSubroutineType( + DB.getOrCreateTypeArray(std::nullopt)), /* void type */ 0, /* Line 0 is reserved for compiler-generated code. */ DINode::DIFlags::FlagArtificial /* Compiler-generated code. */, /* Outlined code is optimized code by definition. */ @@ -809,7 +821,7 @@ static void mapInputsToGVNs(IRSimilarityCandidate &C, if (OutputMappings.find(Input) != OutputMappings.end()) Input = OutputMappings.find(Input)->second; assert(C.getGVN(Input) && "Could not find a numbering for the given input"); - EndInputNumbers.push_back(C.getGVN(Input).value()); + EndInputNumbers.push_back(*C.getGVN(Input)); } } @@ -946,13 +958,13 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region, // we find argument locations for the canonical value numbering. This // numbering overrides any discovered location for the extracted code. for (unsigned InputVal : InputGVNs) { - Optional<unsigned> CanonicalNumberOpt = C.getCanonicalNum(InputVal); + std::optional<unsigned> CanonicalNumberOpt = C.getCanonicalNum(InputVal); assert(CanonicalNumberOpt && "Canonical number not found?"); - unsigned CanonicalNumber = CanonicalNumberOpt.value(); + unsigned CanonicalNumber = *CanonicalNumberOpt; - Optional<Value *> InputOpt = C.fromGVN(InputVal); + std::optional<Value *> InputOpt = C.fromGVN(InputVal); assert(InputOpt && "Global value number not found?"); - Value *Input = InputOpt.value(); + Value *Input = *InputOpt; DenseMap<unsigned, unsigned>::iterator AggArgIt = Group.CanonicalNumberToAggArg.find(CanonicalNumber); @@ -1161,12 +1173,12 @@ static hash_code encodePHINodeData(PHINodeData &PND) { /// \param PN - The PHINode we are analyzing. /// \param Blocks - The blocks for the region we are analyzing. /// \param AggArgIdx - The argument \p PN will be stored into. -/// \returns An optional holding the assigned canonical number, or None if -/// there is some attribute of the PHINode blocking it from being used. -static Optional<unsigned> getGVNForPHINode(OutlinableRegion &Region, - PHINode *PN, - DenseSet<BasicBlock *> &Blocks, - unsigned AggArgIdx) { +/// \returns An optional holding the assigned canonical number, or std::nullopt +/// if there is some attribute of the PHINode blocking it from being used. +static std::optional<unsigned> getGVNForPHINode(OutlinableRegion &Region, + PHINode *PN, + DenseSet<BasicBlock *> &Blocks, + unsigned AggArgIdx) { OutlinableGroup &Group = *Region.Parent; IRSimilarityCandidate &Cand = *Region.Candidate; BasicBlock *PHIBB = PN->getParent(); @@ -1181,10 +1193,10 @@ static Optional<unsigned> getGVNForPHINode(OutlinableRegion &Region, // are trying to analyze, meaning, that if it was outlined, we would be // adding an extra input. We ignore this case for now, and so ignore the // region. 
- Optional<unsigned> OGVN = Cand.getGVN(Incoming); + std::optional<unsigned> OGVN = Cand.getGVN(Incoming); if (!OGVN && Blocks.contains(IncomingBlock)) { Region.IgnoreRegion = true; - return None; + return std::nullopt; } // If the incoming block isn't in the region, we don't have to worry about @@ -1202,7 +1214,7 @@ static Optional<unsigned> getGVNForPHINode(OutlinableRegion &Region, // the hash for the PHINode. OGVN = Cand.getGVN(IncomingBlock); - // If there is no number for the incoming block, it is becaause we have + // If there is no number for the incoming block, it is because we have // split the candidate basic blocks. So we use the previous block that it // was split from to find the valid global value numbering for the PHINode. if (!OGVN) { @@ -1233,16 +1245,16 @@ static Optional<unsigned> getGVNForPHINode(OutlinableRegion &Region, // PHINode to generate a hash value representing this instance of the PHINode. DenseMap<hash_code, unsigned>::iterator GVNToPHIIt; DenseMap<unsigned, PHINodeData>::iterator PHIToGVNIt; - Optional<unsigned> BBGVN = Cand.getGVN(PHIBB); + std::optional<unsigned> BBGVN = Cand.getGVN(PHIBB); assert(BBGVN && "Could not find GVN for the incoming block!"); - BBGVN = Cand.getCanonicalNum(BBGVN.value()); + BBGVN = Cand.getCanonicalNum(*BBGVN); assert(BBGVN && "Could not find canonical number for the incoming block!"); // Create a pair of the exit block canonical value, and the aggregate // argument location, connected to the canonical numbers stored in the // PHINode. PHINodeData TemporaryPair = - std::make_pair(std::make_pair(BBGVN.value(), AggArgIdx), PHIGVNs); + std::make_pair(std::make_pair(*BBGVN, AggArgIdx), PHIGVNs); hash_code PHINodeDataHash = encodePHINodeData(TemporaryPair); // Look for and create a new entry in our connection between canonical @@ -1265,7 +1277,7 @@ static Optional<unsigned> getGVNForPHINode(OutlinableRegion &Region, /// \param [in,out] Region - The region of code to be analyzed. /// \param [in] Outputs - The values found by the code extractor. static void -findExtractedOutputToOverallOutputMapping(OutlinableRegion &Region, +findExtractedOutputToOverallOutputMapping(Module &M, OutlinableRegion &Region, SetVector<Value *> &Outputs) { OutlinableGroup &Group = *Region.Parent; IRSimilarityCandidate &C = *Region.Candidate; @@ -1338,7 +1350,8 @@ findExtractedOutputToOverallOutputMapping(OutlinableRegion &Region, // the output, so we add a pointer type to the argument types of the overall // function to handle this output and create a mapping to it. if (!TypeFound) { - Group.ArgumentTypes.push_back(PointerType::getUnqual(Output->getType())); + Group.ArgumentTypes.push_back(Output->getType()->getPointerTo( + M.getDataLayout().getAllocaAddrSpace())); // Mark the new pointer type as the last value in the aggregate argument // list. unsigned ArgTypeIdx = Group.ArgumentTypes.size() - 1; @@ -1353,7 +1366,7 @@ findExtractedOutputToOverallOutputMapping(OutlinableRegion &Region, // TODO: Adapt to the extra input from the PHINode. PHINode *PN = dyn_cast<PHINode>(Output); - Optional<unsigned> GVN; + std::optional<unsigned> GVN; if (PN && !BlocksInRegion.contains(PN->getParent())) { // Values outside the region can be combined into PHINode when we // have multiple exits. We collect both of these into a list to identify @@ -1406,7 +1419,7 @@ void IROutliner::findAddInputsOutputs(Module &M, OutlinableRegion &Region, // Map the outputs found by the CodeExtractor to the arguments found for // the overall function. 
- findExtractedOutputToOverallOutputMapping(Region, Outputs); + findExtractedOutputToOverallOutputMapping(M, Region, Outputs); } /// Replace the extracted function in the Region with a call to the overall @@ -1516,7 +1529,7 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) { // Make sure that the argument in the new function has the SwiftError // argument. if (Group.SwiftErrorArgument) - Call->addParamAttr(Group.SwiftErrorArgument.value(), Attribute::SwiftError); + Call->addParamAttr(*Group.SwiftErrorArgument, Attribute::SwiftError); return Call; } @@ -1646,9 +1659,9 @@ static void findCanonNumsForPHI( IVal = findOutputMapping(OutputMappings, IVal); // Find and add the canonical number for the incoming value. - Optional<unsigned> GVN = Region.Candidate->getGVN(IVal); + std::optional<unsigned> GVN = Region.Candidate->getGVN(IVal); assert(GVN && "No GVN for incoming value"); - Optional<unsigned> CanonNum = Region.Candidate->getCanonicalNum(*GVN); + std::optional<unsigned> CanonNum = Region.Candidate->getCanonicalNum(*GVN); assert(CanonNum && "No Canonical Number for GVN"); CanonNums.push_back(std::make_pair(*CanonNum, IBlock)); } @@ -1861,7 +1874,7 @@ replaceArgumentUses(OutlinableRegion &Region, StoreInst *NewI = cast<StoreInst>(I->clone()); NewI->setDebugLoc(DebugLoc()); BasicBlock *OutputBB = VBBIt->second; - OutputBB->getInstList().push_back(NewI); + NewI->insertInto(OutputBB, OutputBB->end()); LLVM_DEBUG(dbgs() << "Move store for instruction " << *I << " to " << *OutputBB << "\n"); @@ -1958,7 +1971,7 @@ void replaceConstants(OutlinableRegion &Region) { /// \param OutputBBs [in] the blocks we are looking for a duplicate of. /// \param OutputStoreBBs [in] The existing output blocks. /// \returns an optional value with the number output block if there is a match. -Optional<unsigned> findDuplicateOutputBlock( +std::optional<unsigned> findDuplicateOutputBlock( DenseMap<Value *, BasicBlock *> &OutputBBs, std::vector<DenseMap<Value *, BasicBlock *>> &OutputStoreBBs) { @@ -2004,7 +2017,7 @@ Optional<unsigned> findDuplicateOutputBlock( MatchingNum++; } - return None; + return std::nullopt; } /// Remove empty output blocks from the outlined region. @@ -2073,17 +2086,16 @@ static void alignOutputBlockWithAggFunc( return; // Determine is there is a duplicate set of blocks. - Optional<unsigned> MatchingBB = + std::optional<unsigned> MatchingBB = findDuplicateOutputBlock(OutputBBs, OutputStoreBBs); // If there is, we remove the new output blocks. If it does not, // we add it to our list of sets of output blocks. if (MatchingBB) { LLVM_DEBUG(dbgs() << "Set output block for region in function" - << Region.ExtractedFunction << " to " - << MatchingBB.value()); + << Region.ExtractedFunction << " to " << *MatchingBB); - Region.OutputBlockNum = MatchingBB.value(); + Region.OutputBlockNum = *MatchingBB; for (std::pair<Value *, BasicBlock *> &VtoBB : OutputBBs) VtoBB.second->eraseFromParent(); return; @@ -2415,6 +2427,7 @@ void IROutliner::pruneIncompatibleRegions( PreviouslyOutlined = false; unsigned StartIdx = IRSC.getStartIdx(); unsigned EndIdx = IRSC.getEndIdx(); + const Function &FnForCurrCand = *IRSC.getFunction(); for (unsigned Idx = StartIdx; Idx <= EndIdx; Idx++) if (Outlined.contains(Idx)) { @@ -2434,9 +2447,17 @@ void IROutliner::pruneIncompatibleRegions( if (BBHasAddressTaken) continue; - if (IRSC.getFunction()->hasOptNone()) + if (FnForCurrCand.hasOptNone()) continue; + if (FnForCurrCand.hasFnAttribute("nooutline")) { + LLVM_DEBUG({ + dbgs() << "... 
Skipping function with nooutline attribute: " + << FnForCurrCand.getName() << "\n"; + }); + continue; + } + if (IRSC.front()->Inst->getFunction()->hasLinkOnceODRLinkage() && !OutlineFromLinkODRs) continue; @@ -2500,9 +2521,10 @@ static Value *findOutputValueInRegion(OutlinableRegion &Region, assert(It->second.second.size() > 0 && "PHINode does not have any values!"); OutputCanon = *It->second.second.begin(); } - Optional<unsigned> OGVN = Region.Candidate->fromCanonicalNum(OutputCanon); + std::optional<unsigned> OGVN = + Region.Candidate->fromCanonicalNum(OutputCanon); assert(OGVN && "Could not find GVN for Canonical Number?"); - Optional<Value *> OV = Region.Candidate->fromGVN(*OGVN); + std::optional<Value *> OV = Region.Candidate->fromGVN(*OGVN); assert(OV && "Could not find value for GVN?"); return *OV; } @@ -2663,7 +2685,7 @@ void IROutliner::updateOutputMapping(OutlinableRegion &Region, LoadInst *LI) { // For and load instructions following the call Value *Operand = LI->getPointerOperand(); - Optional<unsigned> OutputIdx = None; + std::optional<unsigned> OutputIdx; // Find if the operand it is an output register. for (unsigned ArgIdx = Region.NumExtractedInputs; ArgIdx < Region.Call->arg_size(); ArgIdx++) { @@ -2678,14 +2700,14 @@ void IROutliner::updateOutputMapping(OutlinableRegion &Region, if (!OutputIdx) return; - if (OutputMappings.find(Outputs[OutputIdx.value()]) == OutputMappings.end()) { + if (OutputMappings.find(Outputs[*OutputIdx]) == OutputMappings.end()) { LLVM_DEBUG(dbgs() << "Mapping extracted output " << *LI << " to " - << *Outputs[OutputIdx.value()] << "\n"); - OutputMappings.insert(std::make_pair(LI, Outputs[OutputIdx.value()])); + << *Outputs[*OutputIdx] << "\n"); + OutputMappings.insert(std::make_pair(LI, Outputs[*OutputIdx])); } else { - Value *Orig = OutputMappings.find(Outputs[OutputIdx.value()])->second; + Value *Orig = OutputMappings.find(Outputs[*OutputIdx])->second; LLVM_DEBUG(dbgs() << "Mapping extracted output " << *Orig << " to " - << *Outputs[OutputIdx.value()] << "\n"); + << *Outputs[*OutputIdx] << "\n"); OutputMappings.insert(std::make_pair(LI, Orig)); } } diff --git a/llvm/lib/Transforms/IPO/InlineSimple.cpp b/llvm/lib/Transforms/IPO/InlineSimple.cpp index 2143e39d488d..eba0d6636d6c 100644 --- a/llvm/lib/Transforms/IPO/InlineSimple.cpp +++ b/llvm/lib/Transforms/IPO/InlineSimple.cpp @@ -50,7 +50,7 @@ public: TargetTransformInfo &TTI = TTIWP->getTTI(*Callee); bool RemarksEnabled = false; - const auto &BBs = CB.getCaller()->getBasicBlockList(); + const auto &BBs = *CB.getCaller(); if (!BBs.empty()) { auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBs.front()); if (DI.isEnabled()) diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp index 4d32266eb9ea..5bcfc38c585b 100644 --- a/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/llvm/lib/Transforms/IPO/Inliner.cpp @@ -14,7 +14,6 @@ #include "llvm/Transforms/IPO/Inliner.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" @@ -31,7 +30,6 @@ #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/InlineOrder.h" #include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" @@ -110,7 +108,9 @@ static cl::opt<bool> EnablePostSCCAdvisorPrinting("enable-scc-inline-advisor-printing", cl::init(false), 
cl::Hidden); +namespace llvm { extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats; +} static cl::opt<std::string> CGSCCInlineReplayFile( "cgscc-inline-replay", cl::init(""), cl::value_desc("filename"), @@ -316,15 +316,15 @@ static InlineResult inlineCallIfPossible( // Try to inline the function. Get the list of static allocas that were // inlined. - InlineResult IR = InlineFunction(CB, IFI, &AAR, InsertLifetime); + InlineResult IR = + InlineFunction(CB, IFI, + /*MergeAttributes=*/true, &AAR, InsertLifetime); if (!IR.isSuccess()) return IR; if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) ImportedFunctionsStats.recordInline(*Caller, *Callee); - AttributeFuncs::mergeAttributesForInlining(*Caller, *Callee); - if (!DisableInlinedAllocaMerging) mergeInlinedArrayAllocas(Caller, IFI, InlinedArrayAllocas, InlineHistory); @@ -785,7 +785,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // this model, but it is uniformly spread across all the functions in the SCC // and eventually they all become too large to inline, rather than // incrementally making a single function grow in a super linear fashion. - DefaultInlineOrder<std::pair<CallBase *, int>> Calls; + SmallVector<std::pair<CallBase *, int>, 16> Calls; // Populate the initial list of calls in this SCC. for (auto &N : InitialC) { @@ -800,7 +800,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (auto *CB = dyn_cast<CallBase>(&I)) if (Function *Callee = CB->getCalledFunction()) { if (!Callee->isDeclaration()) - Calls.push({CB, -1}); + Calls.push_back({CB, -1}); else if (!isa<IntrinsicInst>(I)) { using namespace ore; setInlineRemark(*CB, "unavailable definition"); @@ -839,18 +839,17 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // be deleted as a batch after inlining. SmallVector<Function *, 4> DeadFunctionsInComdats; - // Loop forward over all of the calls. - while (!Calls.empty()) { + // Loop forward over all of the calls. Note that we cannot cache the size as + // inlining can introduce new calls that need to be processed. + for (int I = 0; I < (int)Calls.size(); ++I) { // We expect the calls to typically be batched with sequences of calls that // have the same caller, so we first set up some shared infrastructure for // this caller. We also do any pruning we can at this layer on the caller // alone. - Function &F = *Calls.front().first->getCaller(); + Function &F = *Calls[I].first->getCaller(); LazyCallGraph::Node &N = *CG.lookup(F); - if (CG.lookupSCC(N) != C) { - Calls.pop(); + if (CG.lookupSCC(N) != C) continue; - } LLVM_DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n" << " Function size: " << F.getInstructionCount() @@ -864,8 +863,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // We bail out as soon as the caller has to change so we can update the // call graph and prepare the context of that new caller.
bool DidInline = false; - while (!Calls.empty() && Calls.front().first->getCaller() == &F) { - auto P = Calls.pop(); + for (; I < (int)Calls.size() && Calls[I].first->getCaller() == &F; ++I) { + auto &P = Calls[I]; CallBase *CB = P.first; const int InlineHistoryID = P.second; Function &Callee = *CB->getCalledFunction(); @@ -917,7 +916,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, &FAM.getResult<BlockFrequencyAnalysis>(Callee)); InlineResult IR = - InlineFunction(*CB, IFI, &FAM.getResult<AAManager>(*CB->getCaller())); + InlineFunction(*CB, IFI, /*MergeAttributes=*/true, + &FAM.getResult<AAManager>(*CB->getCaller())); if (!IR.isSuccess()) { Advice->recordUnsuccessfulInlining(IR); continue; @@ -949,7 +949,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, } if (NewCallee) { if (!NewCallee->isDeclaration()) { - Calls.push({ICB, NewHistoryID}); + Calls.push_back({ICB, NewHistoryID}); // Continually inlining through an SCC can result in huge compile // times and bloated code since we arbitrarily stop at some point // when the inliner decides it's not profitable to inline anymore. @@ -972,9 +972,6 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, } } - // Merge the attributes based on the inlining. - AttributeFuncs::mergeAttributesForInlining(F, Callee); - // For local functions or discardable functions without comdats, check // whether this makes the callee trivially dead. In that case, we can drop // the body of the function eagerly which may reduce the number of callers @@ -984,9 +981,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (Callee.isDiscardableIfUnused() && Callee.hasZeroLiveUses() && !CG.isLibFunction(Callee)) { if (Callee.hasLocalLinkage() || !Callee.hasComdat()) { - Calls.erase_if([&](const std::pair<CallBase *, int> &Call) { - return Call.first->getCaller() == &Callee; - }); + Calls.erase( + std::remove_if(Calls.begin() + I + 1, Calls.end(), + [&](const std::pair<CallBase *, int> &Call) { + return Call.first->getCaller() == &Callee; + }), + Calls.end()); + // Clear the body and queue the function itself for deletion when we // finish inlining and call graph updates. // Note that after this point, it is an error to do anything other @@ -1006,6 +1007,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, Advice->recordInlining(); } + // Back the call index up by one to put us in a good position to go around + // the outer loop. + --I; + if (!DidInline) continue; Changed = true; diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index e3e4908f085b..ddfcace6acf8 100644 --- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -172,7 +172,7 @@ BitSetInfo BitSetBuilder::build() { BSI.AlignLog2 = 0; if (Mask != 0) - BSI.AlignLog2 = countTrailingZeros(Mask, ZB_Undefined); + BSI.AlignLog2 = countTrailingZeros(Mask); // Build the compressed bitset while normalizing the offsets against the // computed alignment. 
@@ -309,7 +309,7 @@ public: } ArrayRef<MDNode *> types() const { - return makeArrayRef(getTrailingObjects<MDNode *>(), NTypes); + return ArrayRef(getTrailingObjects<MDNode *>(), NTypes); } }; @@ -331,7 +331,7 @@ struct ICallBranchFunnel final CallInst *CI; ArrayRef<GlobalTypeMember *> targets() const { - return makeArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets); + return ArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets); } unsigned UniqueId; @@ -539,7 +539,7 @@ BitSetInfo LowerTypeTestsModule::buildBitSet( // Compute the byte offset of each address associated with this type // identifier. - for (auto &GlobalAndOffset : GlobalLayout) { + for (const auto &GlobalAndOffset : GlobalLayout) { for (MDNode *Type : GlobalAndOffset.first->types()) { if (Type->getOperand(1) != TypeId) continue; @@ -1179,6 +1179,7 @@ void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { } static const unsigned kX86JumpTableEntrySize = 8; +static const unsigned kX86IBTJumpTableEntrySize = 16; static const unsigned kARMJumpTableEntrySize = 4; static const unsigned kARMBTIJumpTableEntrySize = 8; static const unsigned kRISCVJumpTableEntrySize = 8; @@ -1187,6 +1188,10 @@ unsigned LowerTypeTestsModule::getJumpTableEntrySize() { switch (Arch) { case Triple::x86: case Triple::x86_64: + if (const auto *MD = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("cf-protection-branch"))) + if (MD->getZExtValue()) + return kX86IBTJumpTableEntrySize; return kX86JumpTableEntrySize; case Triple::arm: case Triple::thumb: @@ -1215,8 +1220,17 @@ void LowerTypeTestsModule::createJumpTableEntry( unsigned ArgIndex = AsmArgs.size(); if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) { + bool Endbr = false; + if (const auto *MD = mdconst::extract_or_null<ConstantInt>( + Dest->getParent()->getModuleFlag("cf-protection-branch"))) + Endbr = MD->getZExtValue() != 0; + if (Endbr) + AsmOS << (JumpTableArch == Triple::x86 ? "endbr32\n" : "endbr64\n"); AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n"; - AsmOS << "int3\nint3\nint3\n"; + if (Endbr) + AsmOS << ".balign 16, 0xcc\n"; + else + AsmOS << "int3\nint3\nint3\n"; } else if (JumpTableArch == Triple::arm) { AsmOS << "b $" << ArgIndex << "\n"; } else if (JumpTableArch == Triple::aarch64) { @@ -1300,7 +1314,7 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( // (all?) targets. Switch to a runtime initializer. SmallSetVector<GlobalVariable *, 8> GlobalVarUsers; findGlobalVariableUsersOf(F, GlobalVarUsers); - for (auto GV : GlobalVarUsers) + for (auto *GV : GlobalVarUsers) moveInitializerToModuleConstructor(GV); // Can not RAUW F with an expression that uses F. Replace with a temporary @@ -1369,9 +1383,9 @@ void LowerTypeTestsModule::createJumpTable( Triple::ArchType JumpTableArch = selectJumpTableArmEncoding(Functions, Arch); - for (unsigned I = 0; I != Functions.size(); ++I) + for (GlobalTypeMember *GTM : Functions) createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs, - cast<Function>(Functions[I]->getGlobal())); + cast<Function>(GTM->getGlobal())); // Align the whole table by entry size. F->setAlignment(Align(getJumpTableEntrySize())); @@ -1389,6 +1403,9 @@ void LowerTypeTestsModule::createJumpTable( // by Clang for -march=armv7. F->addFnAttr("target-cpu", "cortex-a8"); } + // When -mbranch-protection= is used, the inline asm adds a BTI. Suppress BTI + // for the function to avoid double BTI. This is a no-op without + // -mbranch-protection=. 
if (JumpTableArch == Triple::aarch64) { F->addFnAttr("branch-target-enforcement", "false"); F->addFnAttr("sign-return-address", "none"); @@ -1398,6 +1415,11 @@ void LowerTypeTestsModule::createJumpTable( // the linker. F->addFnAttr("target-features", "-c,-relax"); } + // When -fcf-protection= is used, the inline asm adds an ENDBR. Suppress ENDBR + // for the function to avoid double ENDBR. This is a no-op without + // -fcf-protection=. + if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) + F->addFnAttr(Attribute::NoCfCheck); // Make sure we don't emit .eh_frame for this function. F->addFnAttr(Attribute::NoUnwind); @@ -1863,9 +1885,9 @@ bool LowerTypeTestsModule::lower() { std::vector<GlobalAlias *> AliasesToErase; { ScopedSaveAliaseesAndUsed S(M); - for (auto F : Defs) + for (auto *F : Defs) importFunction(F, /*isJumpTableCanonical*/ true, AliasesToErase); - for (auto F : Decls) + for (auto *F : Decls) importFunction(F, /*isJumpTableCanonical*/ false, AliasesToErase); } for (GlobalAlias *GA : AliasesToErase) @@ -1912,12 +1934,12 @@ bool LowerTypeTestsModule::lower() { for (auto &I : *ExportSummary) for (auto &GVS : I.second.SummaryList) if (GVS->isLive()) - for (auto &Ref : GVS->refs()) + for (const auto &Ref : GVS->refs()) AddressTaken.insert(Ref.getGUID()); NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions"); if (CfiFunctionsMD) { - for (auto FuncMD : CfiFunctionsMD->operands()) { + for (auto *FuncMD : CfiFunctionsMD->operands()) { assert(FuncMD->getNumOperands() >= 2); StringRef FunctionName = cast<MDString>(FuncMD->getOperand(0))->getString(); @@ -1938,7 +1960,7 @@ bool LowerTypeTestsModule::lower() { bool Exported = false; if (auto VI = ExportSummary->getValueInfo(GUID)) - for (auto &GVS : VI.getSummaryList()) + for (const auto &GVS : VI.getSummaryList()) if (GVS->isLive() && !GlobalValue::isLocalLinkage(GVS->linkage())) Exported = true; @@ -2212,7 +2234,7 @@ bool LowerTypeTestsModule::lower() { // with an alias to the intended target. if (ExportSummary) { if (NamedMDNode *AliasesMD = M.getNamedMetadata("aliases")) { - for (auto AliasMD : AliasesMD->operands()) { + for (auto *AliasMD : AliasesMD->operands()) { assert(AliasMD->getNumOperands() >= 4); StringRef AliasName = cast<MDString>(AliasMD->getOperand(0))->getString(); @@ -2254,7 +2276,7 @@ bool LowerTypeTestsModule::lower() { // Emit .symver directives for exported functions, if they exist. 
if (ExportSummary) { if (NamedMDNode *SymversMD = M.getNamedMetadata("symvers")) { - for (auto Symver : SymversMD->operands()) { + for (auto *Symver : SymversMD->operands()) { assert(Symver->getNumOperands() >= 2); StringRef SymbolName = cast<MDString>(Symver->getOperand(0))->getString(); diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp index b850591b4aa6..590f62ca58dd 100644 --- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -215,7 +215,7 @@ private: if (LHS.getHash() != RHS.getHash()) return LHS.getHash() < RHS.getHash(); FunctionComparator FCmp(LHS.getFunc(), RHS.getFunc(), GlobalNumbers); - return FCmp.compare() == -1; + return FCmp.compare() < 0; } }; using FnTreeType = std::set<FunctionNode, FunctionNodeCmp>; @@ -493,12 +493,11 @@ static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements()); Value *Result = PoisonValue::get(DestTy); for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) { - Value *Element = createCast( - Builder, Builder.CreateExtractValue(V, makeArrayRef(I)), - DestTy->getStructElementType(I)); + Value *Element = + createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)), + DestTy->getStructElementType(I)); - Result = - Builder.CreateInsertValue(Result, Element, makeArrayRef(I)); + Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I)); } return Result; } @@ -775,7 +774,12 @@ void MergeFunctions::writeAlias(Function *F, Function *G) { auto *GA = GlobalAlias::create(G->getValueType(), PtrType->getAddressSpace(), G->getLinkage(), "", BitcastF, G->getParent()); - F->setAlignment(MaybeAlign(std::max(F->getAlignment(), G->getAlignment()))); + const MaybeAlign FAlign = F->getAlign(); + const MaybeAlign GAlign = G->getAlign(); + if (FAlign || GAlign) + F->setAlignment(std::max(FAlign.valueOrOne(), GAlign.valueOrOne())); + else + F->setAlignment(std::nullopt); GA->takeName(G); GA->setVisibility(G->getVisibility()); GA->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); @@ -822,12 +826,18 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { removeUsers(F); F->replaceAllUsesWith(NewF); - MaybeAlign MaxAlignment(std::max(G->getAlignment(), NewF->getAlignment())); + // We collect alignment before writeThunkOrAlias that overwrites NewF and + // G's content. 
+ const MaybeAlign NewFAlign = NewF->getAlign(); + const MaybeAlign GAlign = G->getAlign(); writeThunkOrAlias(F, G); writeThunkOrAlias(F, NewF); - F->setAlignment(MaxAlignment); + if (NewFAlign || GAlign) + F->setAlignment(std::max(NewFAlign.valueOrOne(), GAlign.valueOrOne())); + else + F->setAlignment(std::nullopt); F->setLinkage(GlobalValue::PrivateLinkage); ++NumDoubleWeak; ++NumFunctionsMerged; diff --git a/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/llvm/lib/Transforms/IPO/ModuleInliner.cpp index 143715006512..ee382657f5e6 100644 --- a/llvm/lib/Transforms/IPO/ModuleInliner.cpp +++ b/llvm/lib/Transforms/IPO/ModuleInliner.cpp @@ -15,7 +15,6 @@ #include "llvm/Transforms/IPO/ModuleInliner.h" #include "llvm/ADT/ScopeExit.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -49,10 +48,6 @@ using namespace llvm; STATISTIC(NumInlined, "Number of functions inlined"); STATISTIC(NumDeleted, "Number of functions deleted because all callers found"); -static cl::opt<bool> InlineEnablePriorityOrder( - "module-inline-enable-priority-order", cl::Hidden, cl::init(true), - cl::desc("Enable the priority inline order for the module inliner")); - /// Return true if the specified inline history ID /// indicates an inline history that includes the specified function. static bool inlineHistoryIncludes( @@ -85,8 +80,7 @@ InlineAdvisor &ModuleInlinerPass::getAdvisor(const ModuleAnalysisManager &MAM, // would get from the MAM can be invalidated as a result of the inliner's // activity. OwnedAdvisor = std::make_unique<DefaultInlineAdvisor>( - M, FAM, Params, - InlineContext{LTOPhase, InlinePass::ModuleInliner}); + M, FAM, Params, InlineContext{LTOPhase, InlinePass::ModuleInliner}); return *OwnedAdvisor; } @@ -111,9 +105,8 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M, LLVM_DEBUG(dbgs() << "---- Module Inliner is Running ---- \n"); auto &IAA = MAM.getResult<InlineAdvisorAnalysis>(M); - if (!IAA.tryCreate( - Params, Mode, {}, - InlineContext{LTOPhase, InlinePass::ModuleInliner})) { + if (!IAA.tryCreate(Params, Mode, {}, + InlineContext{LTOPhase, InlinePass::ModuleInliner})) { M.getContext().emitError( "Could not setup Inlining Advisor for the requested " "mode and/or options"); @@ -145,12 +138,7 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M, // // TODO: Here is a huge amount duplicate code between the module inliner and // the SCC inliner, which need some refactoring. - std::unique_ptr<InlineOrder<std::pair<CallBase *, int>>> Calls; - if (InlineEnablePriorityOrder) - Calls = std::make_unique<PriorityInlineOrder>( - std::make_unique<SizePriority>()); - else - Calls = std::make_unique<DefaultInlineOrder<std::pair<CallBase *, int>>>(); + auto Calls = getInlineOrder(FAM, Params); assert(Calls != nullptr && "Expected an initialized InlineOrder"); // Populate the initial list of calls in this module. @@ -188,135 +176,111 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M, // index into the InlineHistory vector. SmallVector<std::pair<Function *, int>, 16> InlineHistory; - // Track a set vector of inlined callees so that we can augment the caller - // with all of their edges in the call graph before pruning out the ones that - // got simplified away. - SmallSetVector<Function *, 4> InlinedCallees; - // Track the dead functions to delete once finished with inlining calls. We // defer deleting these to make it easier to handle the call graph updates. 
SmallVector<Function *, 4> DeadFunctions; // Loop forward over all of the calls. while (!Calls->empty()) { - // We expect the calls to typically be batched with sequences of calls that - // have the same caller, so we first set up some shared infrastructure for - // this caller. We also do any pruning we can at this layer on the caller - // alone. - Function &F = *Calls->front().first->getCaller(); + auto P = Calls->pop(); + CallBase *CB = P.first; + const int InlineHistoryID = P.second; + Function &F = *CB->getCaller(); + Function &Callee = *CB->getCalledFunction(); LLVM_DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n" << " Function size: " << F.getInstructionCount() << "\n"); + (void)F; auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { return FAM.getResult<AssumptionAnalysis>(F); }; - // Now process as many calls as we have within this caller in the sequence. - // We bail out as soon as the caller has to change so we can - // prepare the context of that new caller. - bool DidInline = false; - while (!Calls->empty() && Calls->front().first->getCaller() == &F) { - auto P = Calls->pop(); - CallBase *CB = P.first; - const int InlineHistoryID = P.second; - Function &Callee = *CB->getCalledFunction(); - - if (InlineHistoryID != -1 && - inlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) { - setInlineRemark(*CB, "recursive"); - continue; - } + if (InlineHistoryID != -1 && + inlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) { + setInlineRemark(*CB, "recursive"); + continue; + } - auto Advice = Advisor.getAdvice(*CB, /*OnlyMandatory*/ false); - // Check whether we want to inline this callsite. - if (!Advice->isInliningRecommended()) { - Advice->recordUnattemptedInlining(); - continue; - } + auto Advice = Advisor.getAdvice(*CB, /*OnlyMandatory*/ false); + // Check whether we want to inline this callsite. + if (!Advice->isInliningRecommended()) { + Advice->recordUnattemptedInlining(); + continue; + } - // Setup the data structure used to plumb customization into the - // `InlineFunction` routine. - InlineFunctionInfo IFI( - /*cg=*/nullptr, GetAssumptionCache, PSI, - &FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())), - &FAM.getResult<BlockFrequencyAnalysis>(Callee)); - - InlineResult IR = - InlineFunction(*CB, IFI, &FAM.getResult<AAManager>(*CB->getCaller())); - if (!IR.isSuccess()) { - Advice->recordUnsuccessfulInlining(IR); - continue; - } + // Setup the data structure used to plumb customization into the + // `InlineFunction` routine. + InlineFunctionInfo IFI( + /*cg=*/nullptr, GetAssumptionCache, PSI, + &FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())), + &FAM.getResult<BlockFrequencyAnalysis>(Callee)); + + InlineResult IR = + InlineFunction(*CB, IFI, /*MergeAttributes=*/true, + &FAM.getResult<AAManager>(*CB->getCaller())); + if (!IR.isSuccess()) { + Advice->recordUnsuccessfulInlining(IR); + continue; + } - DidInline = true; - InlinedCallees.insert(&Callee); - ++NumInlined; - - LLVM_DEBUG(dbgs() << " Size after inlining: " - << F.getInstructionCount() << "\n"); - - // Add any new callsites to defined functions to the worklist. 
- if (!IFI.InlinedCallSites.empty()) { - int NewHistoryID = InlineHistory.size(); - InlineHistory.push_back({&Callee, InlineHistoryID}); - - for (CallBase *ICB : reverse(IFI.InlinedCallSites)) { - Function *NewCallee = ICB->getCalledFunction(); - if (!NewCallee) { - // Try to promote an indirect (virtual) call without waiting for - // the post-inline cleanup and the next DevirtSCCRepeatedPass - // iteration because the next iteration may not happen and we may - // miss inlining it. - if (tryPromoteCall(*ICB)) - NewCallee = ICB->getCalledFunction(); - } - if (NewCallee) - if (!NewCallee->isDeclaration()) - Calls->push({ICB, NewHistoryID}); - } - } + Changed = true; + ++NumInlined; + + LLVM_DEBUG(dbgs() << " Size after inlining: " << F.getInstructionCount() + << "\n"); - // Merge the attributes based on the inlining. - AttributeFuncs::mergeAttributesForInlining(F, Callee); - - // For local functions, check whether this makes the callee trivially - // dead. In that case, we can drop the body of the function eagerly - // which may reduce the number of callers of other functions to one, - // changing inline cost thresholds. - bool CalleeWasDeleted = false; - if (Callee.hasLocalLinkage()) { - // To check this we also need to nuke any dead constant uses (perhaps - // made dead by this operation on other functions). - Callee.removeDeadConstantUsers(); - // if (Callee.use_empty() && !CG.isLibFunction(Callee)) { - if (Callee.use_empty() && !isKnownLibFunction(Callee, GetTLI(Callee))) { - Calls->erase_if([&](const std::pair<CallBase *, int> &Call) { - return Call.first->getCaller() == &Callee; - }); - // Clear the body and queue the function itself for deletion when we - // finish inlining. - // Note that after this point, it is an error to do anything other - // than use the callee's address or delete it. - Callee.dropAllReferences(); - assert(!is_contained(DeadFunctions, &Callee) && - "Cannot put cause a function to become dead twice!"); - DeadFunctions.push_back(&Callee); - CalleeWasDeleted = true; + // Add any new callsites to defined functions to the worklist. + if (!IFI.InlinedCallSites.empty()) { + int NewHistoryID = InlineHistory.size(); + InlineHistory.push_back({&Callee, InlineHistoryID}); + + for (CallBase *ICB : reverse(IFI.InlinedCallSites)) { + Function *NewCallee = ICB->getCalledFunction(); + if (!NewCallee) { + // Try to promote an indirect (virtual) call without waiting for + // the post-inline cleanup and the next DevirtSCCRepeatedPass + // iteration because the next iteration may not happen and we may + // miss inlining it. + if (tryPromoteCall(*ICB)) + NewCallee = ICB->getCalledFunction(); } + if (NewCallee) + if (!NewCallee->isDeclaration()) + Calls->push({ICB, NewHistoryID}); } - if (CalleeWasDeleted) - Advice->recordInliningWithCalleeDeleted(); - else - Advice->recordInlining(); } - if (!DidInline) - continue; - Changed = true; - - InlinedCallees.clear(); + // For local functions, check whether this makes the callee trivially + // dead. In that case, we can drop the body of the function eagerly + // which may reduce the number of callers of other functions to one, + // changing inline cost thresholds. + bool CalleeWasDeleted = false; + if (Callee.hasLocalLinkage()) { + // To check this we also need to nuke any dead constant uses (perhaps + // made dead by this operation on other functions). 
+ Callee.removeDeadConstantUsers(); + // if (Callee.use_empty() && !CG.isLibFunction(Callee)) { + if (Callee.use_empty() && !isKnownLibFunction(Callee, GetTLI(Callee))) { + Calls->erase_if([&](const std::pair<CallBase *, int> &Call) { + return Call.first->getCaller() == &Callee; + }); + // Clear the body and queue the function itself for deletion when we + // finish inlining. + // Note that after this point, it is an error to do anything other + // than use the callee's address or delete it. + Callee.dropAllReferences(); + assert(!is_contained(DeadFunctions, &Callee) && + "Cannot put cause a function to become dead twice!"); + DeadFunctions.push_back(&Callee); + CalleeWasDeleted = true; + } + } + if (CalleeWasDeleted) + Advice->recordInliningWithCalleeDeleted(); + else + Advice->recordInlining(); } // Now that we've finished inlining all of the calls across this module, diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index ef2384faa273..bee154dab10f 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/EnumeratedArray.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/CallGraph.h" @@ -32,6 +33,7 @@ #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/Assumptions.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/GlobalValue.h" @@ -45,12 +47,13 @@ #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/Attributor.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CallGraphUpdater.h" #include <algorithm> +#include <optional> +#include <string> using namespace llvm; using namespace omp; @@ -71,6 +74,8 @@ static cl::opt<bool> cl::desc("Disable function internalization."), cl::Hidden, cl::init(false)); +static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values", + cl::init(false), cl::Hidden); static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden); static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels", @@ -182,13 +187,13 @@ struct AAICVTracker; /// Attributor runs. struct OMPInformationCache : public InformationCache { OMPInformationCache(Module &M, AnalysisGetter &AG, - BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC, + BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC, KernelSet &Kernels) - : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M), + : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M), Kernels(Kernels) { OMPBuilder.initialize(); - initializeRuntimeFunctions(); + initializeRuntimeFunctions(M); initializeInternalControlVars(); } @@ -412,7 +417,7 @@ struct OMPInformationCache : public InformationCache { // TODO: We directly convert uses into proper calls and unknown uses. 
for (Use &U : RFI.Declaration->uses()) { if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) { - if (ModuleSlice.count(UserI->getFunction())) { + if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) { RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U); ++NumUses; } @@ -445,8 +450,7 @@ struct OMPInformationCache : public InformationCache { /// Helper to initialize all runtime function information for those defined /// in OpenMPKinds.def. - void initializeRuntimeFunctions() { - Module &M = *((*ModuleSlice.begin())->getParent()); + void initializeRuntimeFunctions(Module &M) { // Helper macros for handling __VA_ARGS__ in OMP_RTL #define OMP_TYPE(VarName, ...) \ @@ -499,6 +503,18 @@ struct OMPInformationCache : public InformationCache { } #include "llvm/Frontend/OpenMP/OMPKinds.def" + // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_` + // functions, except if `optnone` is present. + if (isOpenMPDevice(M)) { + for (Function &F : M) { + for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"}) + if (F.hasFnAttribute(Attribute::NoInline) && + F.getName().startswith(Prefix) && + !F.hasFnAttribute(Attribute::OptimizeNone)) + F.removeFnAttr(Attribute::NoInline); + } + } + // TODO: We should attach the attributes defined in OMPKinds.def. } @@ -587,6 +603,9 @@ struct KernelInfoState : AbstractState { /// caller is __kmpc_parallel_51. BooleanStateWithSetVector<uint8_t> ParallelLevels; + /// Flag that indicates if the kernel has nested Parallelism + bool NestedParallelism = false; + /// Abstract State interface ///{ @@ -605,6 +624,7 @@ struct KernelInfoState : AbstractState { /// See AbstractState::indicatePessimisticFixpoint(...) ChangeStatus indicatePessimisticFixpoint() override { IsAtFixpoint = true; + ParallelLevels.indicatePessimisticFixpoint(); ReachingKernelEntries.indicatePessimisticFixpoint(); SPMDCompatibilityTracker.indicatePessimisticFixpoint(); ReachedKnownParallelRegions.indicatePessimisticFixpoint(); @@ -615,6 +635,7 @@ struct KernelInfoState : AbstractState { /// See AbstractState::indicateOptimisticFixpoint(...) 
ChangeStatus indicateOptimisticFixpoint() override { IsAtFixpoint = true; + ParallelLevels.indicateOptimisticFixpoint(); ReachingKernelEntries.indicateOptimisticFixpoint(); SPMDCompatibilityTracker.indicateOptimisticFixpoint(); ReachedKnownParallelRegions.indicateOptimisticFixpoint(); @@ -635,6 +656,8 @@ struct KernelInfoState : AbstractState { return false; if (ReachingKernelEntries != RHS.ReachingKernelEntries) return false; + if (ParallelLevels != RHS.ParallelLevels) + return false; return true; } @@ -672,6 +695,7 @@ struct KernelInfoState : AbstractState { SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions; ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions; + NestedParallelism |= KIS.NestedParallelism; return *this; } @@ -806,8 +830,6 @@ struct OpenMPOpt { if (remarksEnabled()) analysisGlobalization(); - - Changed |= eliminateBarriers(); } else { if (PrintICVValues) printICVs(); @@ -830,8 +852,6 @@ struct OpenMPOpt { Changed = true; } } - - Changed |= eliminateBarriers(); } return Changed; @@ -843,7 +863,7 @@ struct OpenMPOpt { InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel, ICV_proc_bind}; - for (Function *F : OMPInfoCache.ModuleSlice) { + for (Function *F : SCC) { for (auto ICV : ICVs) { auto ICVInfo = OMPInfoCache.ICVs[ICV]; auto Remark = [&](OptimizationRemarkAnalysis ORA) { @@ -1397,212 +1417,6 @@ private: return Changed; } - /// Eliminates redundant, aligned barriers in OpenMP offloaded kernels. - /// TODO: Make this an AA and expand it to work across blocks and functions. - bool eliminateBarriers() { - bool Changed = false; - - if (DisableOpenMPOptBarrierElimination) - return /*Changed=*/false; - - if (OMPInfoCache.Kernels.empty()) - return /*Changed=*/false; - - enum ImplicitBarrierType { IBT_ENTRY, IBT_EXIT }; - - class BarrierInfo { - Instruction *I; - enum ImplicitBarrierType Type; - - public: - BarrierInfo(enum ImplicitBarrierType Type) : I(nullptr), Type(Type) {} - BarrierInfo(Instruction &I) : I(&I) {} - - bool isImplicit() { return !I; } - - bool isImplicitEntry() { return isImplicit() && Type == IBT_ENTRY; } - - bool isImplicitExit() { return isImplicit() && Type == IBT_EXIT; } - - Instruction *getInstruction() { return I; } - }; - - for (Function *Kernel : OMPInfoCache.Kernels) { - for (BasicBlock &BB : *Kernel) { - SmallVector<BarrierInfo, 8> BarriersInBlock; - SmallPtrSet<Instruction *, 8> BarriersToBeDeleted; - - // Add the kernel entry implicit barrier. - if (&Kernel->getEntryBlock() == &BB) - BarriersInBlock.push_back(IBT_ENTRY); - - // Find implicit and explicit aligned barriers in the same basic block. - for (Instruction &I : BB) { - if (isa<ReturnInst>(I)) { - // Add the implicit barrier when exiting the kernel. - BarriersInBlock.push_back(IBT_EXIT); - continue; - } - CallBase *CB = dyn_cast<CallBase>(&I); - if (!CB) - continue; - - auto IsAlignBarrierCB = [&](CallBase &CB) { - switch (CB.getIntrinsicID()) { - case Intrinsic::nvvm_barrier0: - case Intrinsic::nvvm_barrier0_and: - case Intrinsic::nvvm_barrier0_or: - case Intrinsic::nvvm_barrier0_popc: - return true; - default: - break; - } - return hasAssumption(CB, - KnownAssumptionString("ompx_aligned_barrier")); - }; - - if (IsAlignBarrierCB(*CB)) { - // Add an explicit aligned barrier. 
- BarriersInBlock.push_back(I); - } - } - - if (BarriersInBlock.size() <= 1) - continue; - - // A barrier in a barrier pair is removeable if all instructions - // between the barriers in the pair are side-effect free modulo the - // barrier operation. - auto IsBarrierRemoveable = [&Kernel](BarrierInfo *StartBI, - BarrierInfo *EndBI) { - assert( - !StartBI->isImplicitExit() && - "Expected start barrier to be other than a kernel exit barrier"); - assert( - !EndBI->isImplicitEntry() && - "Expected end barrier to be other than a kernel entry barrier"); - // If StarBI instructions is null then this the implicit - // kernel entry barrier, so iterate from the first instruction in the - // entry block. - Instruction *I = (StartBI->isImplicitEntry()) - ? &Kernel->getEntryBlock().front() - : StartBI->getInstruction()->getNextNode(); - assert(I && "Expected non-null start instruction"); - Instruction *E = (EndBI->isImplicitExit()) - ? I->getParent()->getTerminator() - : EndBI->getInstruction(); - assert(E && "Expected non-null end instruction"); - - for (; I != E; I = I->getNextNode()) { - if (!I->mayHaveSideEffects() && !I->mayReadFromMemory()) - continue; - - auto IsPotentiallyAffectedByBarrier = - [](Optional<MemoryLocation> Loc) { - const Value *Obj = (Loc && Loc->Ptr) - ? getUnderlyingObject(Loc->Ptr) - : nullptr; - if (!Obj) { - LLVM_DEBUG( - dbgs() - << "Access to unknown location requires barriers\n"); - return true; - } - if (isa<UndefValue>(Obj)) - return false; - if (isa<AllocaInst>(Obj)) - return false; - if (auto *GV = dyn_cast<GlobalVariable>(Obj)) { - if (GV->isConstant()) - return false; - if (GV->isThreadLocal()) - return false; - if (GV->getAddressSpace() == (int)AddressSpace::Local) - return false; - if (GV->getAddressSpace() == (int)AddressSpace::Constant) - return false; - } - LLVM_DEBUG(dbgs() << "Access to '" << *Obj - << "' requires barriers\n"); - return true; - }; - - if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) { - Optional<MemoryLocation> Loc = MemoryLocation::getForDest(MI); - if (IsPotentiallyAffectedByBarrier(Loc)) - return false; - if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(I)) { - Optional<MemoryLocation> Loc = - MemoryLocation::getForSource(MTI); - if (IsPotentiallyAffectedByBarrier(Loc)) - return false; - } - continue; - } - - if (auto *LI = dyn_cast<LoadInst>(I)) - if (LI->hasMetadata(LLVMContext::MD_invariant_load)) - continue; - - Optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I); - if (IsPotentiallyAffectedByBarrier(Loc)) - return false; - } - - return true; - }; - - // Iterate barrier pairs and remove an explicit barrier if analysis - // deems it removeable. - for (auto *It = BarriersInBlock.begin(), - *End = BarriersInBlock.end() - 1; - It != End; ++It) { - - BarrierInfo *StartBI = It; - BarrierInfo *EndBI = (It + 1); - - // Cannot remove when both are implicit barriers, continue. - if (StartBI->isImplicit() && EndBI->isImplicit()) - continue; - - if (!IsBarrierRemoveable(StartBI, EndBI)) - continue; - - assert(!(StartBI->isImplicit() && EndBI->isImplicit()) && - "Expected at least one explicit barrier to remove."); - - // Remove an explicit barrier, check first, then second. 
- if (!StartBI->isImplicit()) { - LLVM_DEBUG(dbgs() << "Remove start barrier " - << *StartBI->getInstruction() << "\n"); - BarriersToBeDeleted.insert(StartBI->getInstruction()); - } else { - LLVM_DEBUG(dbgs() << "Remove end barrier " - << *EndBI->getInstruction() << "\n"); - BarriersToBeDeleted.insert(EndBI->getInstruction()); - } - } - - if (BarriersToBeDeleted.empty()) - continue; - - Changed = true; - for (Instruction *I : BarriersToBeDeleted) { - ++NumBarriersEliminated; - auto Remark = [&](OptimizationRemark OR) { - return OR << "Redundant barrier eliminated."; - }; - - if (EnableVerboseRemarks) - emitRemark<OptimizationRemark>(I, "OMP190", Remark); - I->eraseFromParent(); - } - } - } - - return Changed; - } - void analysisGlobalization() { auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; @@ -1743,10 +1557,14 @@ private: // function. Used for storing information of the async transfer, allowing to // wait on it later. auto &IRBuilder = OMPInfoCache.OMPBuilder; - auto *F = RuntimeCall.getCaller(); - Instruction *FirstInst = &(F->getEntryBlock().front()); - AllocaInst *Handle = new AllocaInst( - IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst); + Function *F = RuntimeCall.getCaller(); + BasicBlock &Entry = F->getEntryBlock(); + IRBuilder.Builder.SetInsertPoint(&Entry, + Entry.getFirstNonPHIOrDbgOrAlloca()); + Value *Handle = IRBuilder.Builder.CreateAlloca( + IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle"); + Handle = + IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr); // Add "issue" runtime call declaration: // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32, @@ -1995,7 +1813,7 @@ private: bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); } /// Cache to remember the unique kernel for a function. - DenseMap<Function *, Optional<Kernel>> UniqueKernelMap; + DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap; /// Find the unique kernel that will execute \p F, if any. Kernel getUniqueKernelFor(Function &F); @@ -2055,30 +1873,6 @@ private: [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); }); } - /// RAII struct to temporarily change an RTL function's linkage to external. - /// This prevents it from being mistakenly removed by other optimizations. - struct ExternalizationRAII { - ExternalizationRAII(OMPInformationCache &OMPInfoCache, - RuntimeFunction RFKind) - : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) { - if (!Declaration) - return; - - LinkageType = Declaration->getLinkage(); - Declaration->setLinkage(GlobalValue::ExternalLinkage); - } - - ~ExternalizationRAII() { - if (!Declaration) - return; - - Declaration->setLinkage(LinkageType); - } - - Function *Declaration; - GlobalValue::LinkageTypes LinkageType; - }; - /// The underlying module. Module &M; @@ -2103,21 +1897,6 @@ private: if (SCC.empty()) return false; - // Temporarily make these function have external linkage so the Attributor - // doesn't remove them when we try to look them up later. 
- ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel); - ExternalizationRAII EndParallel(OMPInfoCache, - OMPRTL___kmpc_kernel_end_parallel); - ExternalizationRAII BarrierSPMD(OMPInfoCache, - OMPRTL___kmpc_barrier_simple_spmd); - ExternalizationRAII BarrierGeneric(OMPInfoCache, - OMPRTL___kmpc_barrier_simple_generic); - ExternalizationRAII ThreadId(OMPInfoCache, - OMPRTL___kmpc_get_hardware_thread_id_in_block); - ExternalizationRAII NumThreads( - OMPInfoCache, OMPRTL___kmpc_get_hardware_num_threads_in_block); - ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size); - registerAAs(IsModulePass); ChangeStatus Changed = A.run(); @@ -2131,17 +1910,22 @@ private: void registerFoldRuntimeCall(RuntimeFunction RF); /// Populate the Attributor with abstract attribute opportunities in the - /// function. + /// functions. void registerAAs(bool IsModulePass); + +public: + /// Callback to register AAs for live functions, including internal functions + /// marked live during the traversal. + static void registerAAsForFunction(Attributor &A, const Function &F); }; Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { - if (!OMPInfoCache.ModuleSlice.count(&F)) + if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&F)) return nullptr; // Use a scope to keep the lifetime of the CachedKernel short. { - Optional<Kernel> &CachedKernel = UniqueKernelMap[&F]; + std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F]; if (CachedKernel) return *CachedKernel; @@ -2327,16 +2111,16 @@ struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> { static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A); /// Return the value with which \p I can be replaced for specific \p ICV. - virtual Optional<Value *> getReplacementValue(InternalControlVar ICV, - const Instruction *I, - Attributor &A) const { - return None; + virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV, + const Instruction *I, + Attributor &A) const { + return std::nullopt; } /// Return an assumed unique ICV value if a single candidate is found. If - /// there cannot be one, return a nullptr. If it is not clear yet, return the - /// Optional::NoneType. - virtual Optional<Value *> + /// there cannot be one, return a nullptr. If it is not clear yet, return + /// std::nullopt. + virtual std::optional<Value *> getUniqueReplacementValue(InternalControlVar ICV) const = 0; // Currently only nthreads is being tracked. @@ -2402,7 +2186,7 @@ struct AAICVTrackerFunction : public AAICVTracker { }; auto CallCheck = [&](Instruction &I) { - Optional<Value *> ReplVal = getValueForCall(A, I, ICV); + std::optional<Value *> ReplVal = getValueForCall(A, I, ICV); if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second) HasChanged = ChangeStatus::CHANGED; @@ -2429,13 +2213,13 @@ struct AAICVTrackerFunction : public AAICVTracker { /// Helper to check if \p I is a call and get the value for it if it is /// unique. 
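The ICV tracker uses a tri-state return convention that this patch merely respells: std::nullopt means "no information yet", an engaged nullptr means "provably no unique replacement value", and any other pointer is the unique value. A minimal sketch of how two such observations combine; mergeObservation and ValueT are illustrative names, not part of the pass.

#include <optional>

// Sketch of combining two tri-state observations of a replacement value:
//   std::nullopt -> nothing observed yet
//   nullptr      -> conflicting observations, no unique value exists
//   V            -> the unique value seen so far
template <typename ValueT>
std::optional<ValueT *> mergeObservation(std::optional<ValueT *> Acc,
                                         std::optional<ValueT *> New) {
  if (!New)          // The new observation tells us nothing.
    return Acc;
  if (!Acc)          // First concrete observation wins.
    return New;
  if (*Acc == *New)  // Agreement keeps the unique value.
    return Acc;
  return std::optional<ValueT *>(nullptr); // Disagreement: no unique value.
}

This is why the llvm::Optional to std::optional migration in this hunk is mechanical: only the spelling of the "unknown" state changes (None becomes std::nullopt), not the convention itself.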
- Optional<Value *> getValueForCall(Attributor &A, const Instruction &I, - InternalControlVar &ICV) const { + std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I, + InternalControlVar &ICV) const { const auto *CB = dyn_cast<CallBase>(&I); if (!CB || CB->hasFnAttr("no_openmp") || CB->hasFnAttr("no_openmp_routines")) - return None; + return std::nullopt; auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter]; @@ -2446,7 +2230,7 @@ struct AAICVTrackerFunction : public AAICVTracker { if (CalledFunction == nullptr) return nullptr; if (CalledFunction == GetterRFI.Declaration) - return None; + return std::nullopt; if (CalledFunction == SetterRFI.Declaration) { if (ICVReplacementValuesMap[ICV].count(&I)) return ICVReplacementValuesMap[ICV].lookup(&I); @@ -2462,7 +2246,7 @@ struct AAICVTrackerFunction : public AAICVTracker { *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED); if (ICVTrackingAA.isAssumedTracked()) { - Optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV); + std::optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV); if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I), OMPInfoCache))) return URV; @@ -2472,16 +2256,16 @@ struct AAICVTrackerFunction : public AAICVTracker { return nullptr; } - // We don't check unique value for a function, so return None. - Optional<Value *> + // We don't check unique value for a function, so return std::nullopt. + std::optional<Value *> getUniqueReplacementValue(InternalControlVar ICV) const override { - return None; + return std::nullopt; } /// Return the value with which \p I can be replaced for specific \p ICV. - Optional<Value *> getReplacementValue(InternalControlVar ICV, - const Instruction *I, - Attributor &A) const override { + std::optional<Value *> getReplacementValue(InternalControlVar ICV, + const Instruction *I, + Attributor &A) const override { const auto &ValuesMap = ICVReplacementValuesMap[ICV]; if (ValuesMap.count(I)) return ValuesMap.lookup(I); @@ -2490,7 +2274,7 @@ struct AAICVTrackerFunction : public AAICVTracker { SmallPtrSet<const Instruction *, 16> Visited; Worklist.push_back(I); - Optional<Value *> ReplVal; + std::optional<Value *> ReplVal; while (!Worklist.empty()) { const Instruction *CurrInst = Worklist.pop_back_val(); @@ -2503,7 +2287,7 @@ struct AAICVTrackerFunction : public AAICVTracker { // ICV. while ((CurrInst = CurrInst->getPrevNode())) { if (ValuesMap.count(CurrInst)) { - Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); + std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); // Unknown value, track new. if (!ReplVal) { ReplVal = NewReplVal; @@ -2518,7 +2302,7 @@ struct AAICVTrackerFunction : public AAICVTracker { break; } - Optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV); + std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV); if (!NewReplVal) continue; @@ -2566,12 +2350,12 @@ struct AAICVTrackerFunctionReturned : AAICVTracker { } // Map of ICV to their values at specific program point. - EnumeratedArray<Optional<Value *>, InternalControlVar, + EnumeratedArray<std::optional<Value *>, InternalControlVar, InternalControlVar::ICV___last> ICVReplacementValuesMap; /// Return the value with which \p I can be replaced for specific \p ICV. 
- Optional<Value *> + std::optional<Value *> getUniqueReplacementValue(InternalControlVar ICV) const override { return ICVReplacementValuesMap[ICV]; } @@ -2585,11 +2369,11 @@ struct AAICVTrackerFunctionReturned : AAICVTracker { return indicatePessimisticFixpoint(); for (InternalControlVar ICV : TrackableICVs) { - Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; - Optional<Value *> UniqueICVValue; + std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; + std::optional<Value *> UniqueICVValue; auto CheckReturnInst = [&](Instruction &I) { - Optional<Value *> NewReplVal = + std::optional<Value *> NewReplVal = ICVTrackingAA.getReplacementValue(ICV, &I, A); // If we found a second ICV value there is no unique returned value. @@ -2660,7 +2444,7 @@ struct AAICVTrackerCallSite : AAICVTracker { void trackStatistics() const override {} InternalControlVar AssociatedICV; - Optional<Value *> ReplVal; + std::optional<Value *> ReplVal; ChangeStatus updateImpl(Attributor &A) override { const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>( @@ -2670,7 +2454,7 @@ struct AAICVTrackerCallSite : AAICVTracker { if (!ICVTrackingAA.isAssumedTracked()) return indicatePessimisticFixpoint(); - Optional<Value *> NewReplVal = + std::optional<Value *> NewReplVal = ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A); if (ReplVal == NewReplVal) @@ -2682,7 +2466,7 @@ struct AAICVTrackerCallSite : AAICVTracker { // Return the value with which associated value can be replaced for specific // \p ICV. - Optional<Value *> + std::optional<Value *> getUniqueReplacementValue(InternalControlVar ICV) const override { return ReplVal; } @@ -2706,13 +2490,13 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker { } // Map of ICV to their values at specific program point. - EnumeratedArray<Optional<Value *>, InternalControlVar, + EnumeratedArray<std::optional<Value *>, InternalControlVar, InternalControlVar::ICV___last> ICVReplacementValuesMap; /// Return the value with which associated value can be replaced for specific /// \p ICV. 
- Optional<Value *> + std::optional<Value *> getUniqueReplacementValue(InternalControlVar ICV) const override { return ICVReplacementValuesMap[ICV]; } @@ -2728,8 +2512,8 @@ struct AAICVTrackerCallSiteReturned : AAICVTracker { return indicatePessimisticFixpoint(); for (InternalControlVar ICV : TrackableICVs) { - Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; - Optional<Value *> NewReplVal = + std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV]; + std::optional<Value *> NewReplVal = ICVTrackingAA.getUniqueReplacementValue(ICV); if (ReplVal == NewReplVal) @@ -2746,77 +2530,216 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A) : AAExecutionDomain(IRP, A) {} + ~AAExecutionDomainFunction() { + delete RPOT; + } + + void initialize(Attributor &A) override { + if (getAnchorScope()->isDeclaration()) { + indicatePessimisticFixpoint(); + return; + } + RPOT = new ReversePostOrderTraversal<Function *>(getAnchorScope()); + } + const std::string getAsStr() const override { - return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) + - "/" + std::to_string(NumBBs) + " BBs thread 0 only."; + unsigned TotalBlocks = 0, InitialThreadBlocks = 0; + for (auto &It : BEDMap) { + TotalBlocks++; + InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly; + } + return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" + + std::to_string(TotalBlocks) + " executed by initial thread only"; } /// See AbstractAttribute::trackStatistics(). void trackStatistics() const override {} - void initialize(Attributor &A) override { - Function *F = getAnchorScope(); - for (const auto &BB : *F) - SingleThreadedBBs.insert(&BB); - NumBBs = SingleThreadedBBs.size(); - } - ChangeStatus manifest(Attributor &A) override { LLVM_DEBUG({ - for (const BasicBlock *BB : SingleThreadedBBs) + for (const BasicBlock &BB : *getAnchorScope()) { + if (!isExecutedByInitialThreadOnly(BB)) + continue; dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " " - << BB->getName() << " is executed by a single thread.\n"; + << BB.getName() << " is executed by a single thread.\n"; + } }); - return ChangeStatus::UNCHANGED; - } - ChangeStatus updateImpl(Attributor &A) override; + ChangeStatus Changed = ChangeStatus::UNCHANGED; - /// Check if an instruction is executed by a single thread. - bool isExecutedByInitialThreadOnly(const Instruction &I) const override { - return isExecutedByInitialThreadOnly(*I.getParent()); + if (DisableOpenMPOptBarrierElimination) + return Changed; + + SmallPtrSet<CallBase *, 16> DeletedBarriers; + auto HandleAlignedBarrier = [&](CallBase *CB) { + const ExecutionDomainTy &ED = CEDMap[CB]; + if (!ED.IsReachedFromAlignedBarrierOnly || + ED.EncounteredNonLocalSideEffect) + return; + + // We can remove this barrier, if it is one, or all aligned barriers + // reaching the kernel end. In the latter case we can transitively work + // our way back until we find a barrier that guards a side-effect if we + // are dealing with the kernel end here. 
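The code below implements that transitive walk with a worklist over already-deleted barriers: starting from the barriers that reach the kernel end, each one is removed, and when a barrier on the worklist was itself removed earlier, the barriers that reached it become candidates too. A compact standalone model of the same walk, with integers standing in for call sites and AlignedPreds standing in for the per-barrier AlignedBarriers sets (both names are illustrative):

#include <map>
#include <set>
#include <vector>

// AlignedPreds[B] holds the aligned barriers that reach barrier B without an
// intervening side effect. Starting from the barriers reaching the kernel
// end, delete each one; when a barrier was already deleted, recurse into the
// barriers that reached it instead.
std::set<int> deleteBarriersReachingExit(
    const std::vector<int> &ReachingExit,
    const std::map<int, std::vector<int>> &AlignedPreds,
    const std::set<int> &AlreadyDeleted) {
  std::set<int> NewlyDeleted;
  std::vector<int> Worklist(ReachingExit.begin(), ReachingExit.end());
  std::set<int> Visited;
  while (!Worklist.empty()) {
    int B = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(B).second)
      continue;
    if (!AlreadyDeleted.count(B)) {
      NewlyDeleted.insert(B);
      continue;
    }
    auto It = AlignedPreds.find(B);
    if (It != AlignedPreds.end())
      Worklist.insert(Worklist.end(), It->second.begin(), It->second.end());
  }
  return NewlyDeleted;
}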
+ if (CB) { + DeletedBarriers.insert(CB); + A.deleteAfterManifest(*CB); + ++NumBarriersEliminated; + Changed = ChangeStatus::CHANGED; + } else if (!ED.AlignedBarriers.empty()) { + NumBarriersEliminated += ED.AlignedBarriers.size(); + Changed = ChangeStatus::CHANGED; + SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(), + ED.AlignedBarriers.end()); + SmallSetVector<CallBase *, 16> Visited; + while (!Worklist.empty()) { + CallBase *LastCB = Worklist.pop_back_val(); + if (!Visited.insert(LastCB)) + continue; + if (!DeletedBarriers.count(LastCB)) { + A.deleteAfterManifest(*LastCB); + continue; + } + // The final aligned barrier (LastCB) reaching the kernel end was + // removed already. This means we can go one step further and remove + // the barriers encoutered last before (LastCB). + const ExecutionDomainTy &LastED = CEDMap[LastCB]; + Worklist.append(LastED.AlignedBarriers.begin(), + LastED.AlignedBarriers.end()); + } + } + + // If we actually eliminated a barrier we need to eliminate the associated + // llvm.assumes as well to avoid creating UB. + if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty())) + for (auto *AssumeCB : ED.EncounteredAssumes) + A.deleteAfterManifest(*AssumeCB); + }; + + for (auto *CB : AlignedBarriers) + HandleAlignedBarrier(CB); + + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + // Handle the "kernel end barrier" for kernels too. + if (OMPInfoCache.Kernels.count(getAnchorScope())) + HandleAlignedBarrier(nullptr); + + return Changed; } + /// Merge barrier and assumption information from \p PredED into the successor + /// \p ED. + void + mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED, + const ExecutionDomainTy &PredED); + + /// Merge all information from \p PredED into the successor \p ED. If + /// \p InitialEdgeOnly is set, only the initial edge will enter the block + /// represented by \p ED from this predecessor. + void mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED, + const ExecutionDomainTy &PredED, + bool InitialEdgeOnly = false); + + /// Accumulate information for the entry block in \p EntryBBED. + void handleEntryBB(Attributor &A, ExecutionDomainTy &EntryBBED); + + /// See AbstractAttribute::updateImpl. + ChangeStatus updateImpl(Attributor &A) override; + + /// Query interface, see AAExecutionDomain + ///{ bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override { - return isValidState() && SingleThreadedBBs.contains(&BB); + if (!isValidState()) + return false; + return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly; } - /// Set of basic blocks that are executed by a single thread. - SmallSetVector<const BasicBlock *, 16> SingleThreadedBBs; + bool isExecutedInAlignedRegion(Attributor &A, + const Instruction &I) const override { + if (!isValidState() || isa<CallBase>(I)) + return false; - /// Total number of basic blocks in this function. - long unsigned NumBBs = 0; -}; + const Instruction *CurI; -ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { - Function *F = getAnchorScope(); - ReversePostOrderTraversal<Function *> RPOT(F); - auto NumSingleThreadedBBs = SingleThreadedBBs.size(); + // Check forward until a call or the block end is reached. 
+ CurI = &I; + do { + auto *CB = dyn_cast<CallBase>(CurI); + if (!CB) + continue; + const auto &It = CEDMap.find(CB); + if (It == CEDMap.end()) + continue; + if (!It->getSecond().IsReachedFromAlignedBarrierOnly) + return false; + } while ((CurI = CurI->getNextNonDebugInstruction())); - bool AllCallSitesKnown; - auto PredForCallSite = [&](AbstractCallSite ACS) { - const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>( - *this, IRPosition::function(*ACS.getInstruction()->getFunction()), - DepClassTy::REQUIRED); - return ACS.isDirectCall() && - ExecutionDomainAA.isExecutedByInitialThreadOnly( - *ACS.getInstruction()); - }; + if (!CurI && !BEDMap.lookup(I.getParent()).IsReachedFromAlignedBarrierOnly) + return false; + + // Check backward until a call or the block beginning is reached. + CurI = &I; + do { + auto *CB = dyn_cast<CallBase>(CurI); + if (!CB) + continue; + const auto &It = CEDMap.find(CB); + if (It == CEDMap.end()) + continue; + if (!AA::isNoSyncInst(A, *CB, *this)) { + if (It->getSecond().IsReachedFromAlignedBarrierOnly) + break; + return false; + } - if (!A.checkForAllCallSites(PredForCallSite, *this, - /* RequiresAllCallSites */ true, - AllCallSitesKnown)) - SingleThreadedBBs.remove(&F->getEntryBlock()); + Function *Callee = CB->getCalledFunction(); + if (!Callee || Callee->isDeclaration()) + return false; + const auto &EDAA = A.getAAFor<AAExecutionDomain>( + *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL); + if (!EDAA.getState().isValidState()) + return false; + if (!EDAA.getFunctionExecutionDomain().IsReachedFromAlignedBarrierOnly) + return false; + break; + } while ((CurI = CurI->getPrevNonDebugInstruction())); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; + if (!CurI && + !llvm::all_of( + predecessors(I.getParent()), [&](const BasicBlock *PredBB) { + return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly; + })) { + return false; + } + + // On neither traversal we found a anything but aligned barriers. + return true; + } + + ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override { + assert(isValidState() && + "No request should be made against an invalid state!"); + return BEDMap.lookup(&BB); + } + ExecutionDomainTy getExecutionDomain(const CallBase &CB) const override { + assert(isValidState() && + "No request should be made against an invalid state!"); + return CEDMap.lookup(&CB); + } + ExecutionDomainTy getFunctionExecutionDomain() const override { + assert(isValidState() && + "No request should be made against an invalid state!"); + return BEDMap.lookup(nullptr); + } + ///} // Check if the edge into the successor block contains a condition that only // lets the main thread execute it. - auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { + static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge, + BasicBlock &SuccessorBB) { if (!Edge || !Edge->isConditional()) return false; - if (Edge->getSuccessor(0) != SuccessorBB) + if (Edge->getSuccessor(0) != &SuccessorBB) return false; auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition()); @@ -2830,6 +2753,8 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!) if (C->isAllOnesValue()) { auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0)); + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; CB = CB ? 
OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr; if (!CB) return false; @@ -2853,30 +2778,335 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { return false; }; - // Merge all the predecessor states into the current basic block. A basic - // block is executed by a single thread if all of its predecessors are. - auto MergePredecessorStates = [&](BasicBlock *BB) { - if (pred_empty(BB)) - return SingleThreadedBBs.contains(BB); - - bool IsInitialThread = true; - for (BasicBlock *PredBB : predecessors(BB)) { - if (!IsInitialThreadOnly(dyn_cast<BranchInst>(PredBB->getTerminator()), - BB)) - IsInitialThread &= SingleThreadedBBs.contains(PredBB); + /// Mapping containing information per block. + DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap; + DenseMap<const CallBase *, ExecutionDomainTy> CEDMap; + SmallSetVector<CallBase *, 16> AlignedBarriers; + + ReversePostOrderTraversal<Function *> *RPOT = nullptr; +}; + +void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions( + Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) { + for (auto *EA : PredED.EncounteredAssumes) + ED.addAssumeInst(A, *EA); + + for (auto *AB : PredED.AlignedBarriers) + ED.addAlignedBarrier(A, *AB); +} + +void AAExecutionDomainFunction::mergeInPredecessor( + Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED, + bool InitialEdgeOnly) { + ED.IsExecutedByInitialThreadOnly = + InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly && + ED.IsExecutedByInitialThreadOnly); + + ED.IsReachedFromAlignedBarrierOnly = ED.IsReachedFromAlignedBarrierOnly && + PredED.IsReachedFromAlignedBarrierOnly; + ED.EncounteredNonLocalSideEffect = + ED.EncounteredNonLocalSideEffect | PredED.EncounteredNonLocalSideEffect; + if (ED.IsReachedFromAlignedBarrierOnly) + mergeInPredecessorBarriersAndAssumptions(A, ED, PredED); + else + ED.clearAssumeInstAndAlignedBarriers(); +} + +void AAExecutionDomainFunction::handleEntryBB(Attributor &A, + ExecutionDomainTy &EntryBBED) { + SmallVector<ExecutionDomainTy> PredExecDomains; + auto PredForCallSite = [&](AbstractCallSite ACS) { + const auto &EDAA = A.getAAFor<AAExecutionDomain>( + *this, IRPosition::function(*ACS.getInstruction()->getFunction()), + DepClassTy::OPTIONAL); + if (!EDAA.getState().isValidState()) + return false; + PredExecDomains.emplace_back( + EDAA.getExecutionDomain(*cast<CallBase>(ACS.getInstruction()))); + return true; + }; + + bool AllCallSitesKnown; + if (A.checkForAllCallSites(PredForCallSite, *this, + /* RequiresAllCallSites */ true, + AllCallSitesKnown)) { + for (const auto &PredED : PredExecDomains) + mergeInPredecessor(A, EntryBBED, PredED); + + } else { + // We could not find all predecessors, so this is either a kernel or a + // function with external linkage (or with some other weird uses). 
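mergeInPredecessor above combines the per-block flags monotonically: "executed by the initial thread only" and "reached from aligned barriers only" survive only if every incoming edge agrees, while "encountered a non-local side effect" is sticky once set. A self-contained restatement of just that flag algebra (the barrier and assume sets merged by mergeInPredecessorBarriersAndAssumptions are omitted, and ExecDomainFlags is an illustrative struct, not the pass's ExecutionDomainTy):

// Minimal model of the execution-domain flags merged across CFG edges.
struct ExecDomainFlags {
  bool IsExecutedByInitialThreadOnly = true;
  bool IsReachedFromAlignedBarrierOnly = true;
  bool EncounteredNonLocalSideEffect = false;
};

// InitialEdgeOnly models an edge guarded by "thread id == 0": the successor
// is executed by the initial thread regardless of the predecessor's state.
void mergeFlagsFromPredecessor(ExecDomainFlags &ED,
                               const ExecDomainFlags &PredED,
                               bool InitialEdgeOnly = false) {
  ED.IsExecutedByInitialThreadOnly =
      InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
                          ED.IsExecutedByInitialThreadOnly);
  ED.IsReachedFromAlignedBarrierOnly &= PredED.IsReachedFromAlignedBarrierOnly;
  ED.EncounteredNonLocalSideEffect |= PredED.EncounteredNonLocalSideEffect;
}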
+ auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + if (OMPInfoCache.Kernels.count(getAnchorScope())) { + EntryBBED.IsExecutedByInitialThreadOnly = false; + EntryBBED.IsReachedFromAlignedBarrierOnly = true; + EntryBBED.EncounteredNonLocalSideEffect = false; + } else { + EntryBBED.IsExecutedByInitialThreadOnly = false; + EntryBBED.IsReachedFromAlignedBarrierOnly = false; + EntryBBED.EncounteredNonLocalSideEffect = true; } + } + + auto &FnED = BEDMap[nullptr]; + FnED.IsReachingAlignedBarrierOnly &= + EntryBBED.IsReachedFromAlignedBarrierOnly; +} + +ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { - return IsInitialThread; + bool Changed = false; + + // Helper to deal with an aligned barrier encountered during the forward + // traversal. \p CB is the aligned barrier, \p ED is the execution domain when + // it was encountered. + auto HandleAlignedBarrier = [&](CallBase *CB, ExecutionDomainTy &ED) { + if (CB) + Changed |= AlignedBarriers.insert(CB); + // First, update the barrier ED kept in the separate CEDMap. + auto &CallED = CEDMap[CB]; + mergeInPredecessor(A, CallED, ED); + // Next adjust the ED we use for the traversal. + ED.EncounteredNonLocalSideEffect = false; + ED.IsReachedFromAlignedBarrierOnly = true; + // Aligned barrier collection has to come last. + ED.clearAssumeInstAndAlignedBarriers(); + if (CB) + ED.addAlignedBarrier(A, *CB); + }; + + auto &LivenessAA = + A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL); + + // Set \p R to \V and report true if that changed \p R. + auto SetAndRecord = [&](bool &R, bool V) { + bool Eq = (R == V); + R = V; + return !Eq; }; - for (auto *BB : RPOT) { - if (!MergePredecessorStates(BB)) - SingleThreadedBBs.remove(BB); + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + + Function *F = getAnchorScope(); + BasicBlock &EntryBB = F->getEntryBlock(); + bool IsKernel = OMPInfoCache.Kernels.count(F); + + SmallVector<Instruction *> SyncInstWorklist; + for (auto &RIt : *RPOT) { + BasicBlock &BB = *RIt; + + bool IsEntryBB = &BB == &EntryBB; + // TODO: We use local reasoning since we don't have a divergence analysis + // running as well. We could basically allow uniform branches here. + bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel; + ExecutionDomainTy ED; + // Propagate "incoming edges" into information about this block. + if (IsEntryBB) { + handleEntryBB(A, ED); + } else { + // For live non-entry blocks we only propagate + // information via live edges. + if (LivenessAA.isAssumedDead(&BB)) + continue; + + for (auto *PredBB : predecessors(&BB)) { + if (LivenessAA.isEdgeDead(PredBB, &BB)) + continue; + bool InitialEdgeOnly = isInitialThreadOnlyEdge( + A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB); + mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly); + } + } + + // Now we traverse the block, accumulate effects in ED and attach + // information to calls. + for (Instruction &I : BB) { + bool UsedAssumedInformation; + if (A.isAssumedDead(I, *this, &LivenessAA, UsedAssumedInformation, + /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL, + /* CheckForDeadStore */ true)) + continue; + + // Asummes and "assume-like" (dbg, lifetime, ...) are handled first, the + // former is collected the latter is ignored. + if (auto *II = dyn_cast<IntrinsicInst>(&I)) { + if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) { + ED.addAssumeInst(A, *AI); + continue; + } + // TODO: Should we also collect and delete lifetime markers? 
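The update above is one forward pass in reverse post-order: every block starts from the merge of its live predecessors' out-states (the entry block gets its state from call sites via handleEntryBB) and is then walked instruction by instruction. A plain-C++ sketch of that driver loop over an adjacency-list CFG; forwardPassRPO, the integer block ids, and the Merge/Transfer callbacks are assumed names that abstract away liveness and the per-instruction handling.

#include <cstddef>
#include <functional>
#include <vector>

// One forward dataflow pass over blocks given in reverse post-order.
// Preds[B] lists the predecessors of block B; Transfer applies the effect of
// the block's instructions to the incoming state.
template <typename StateT>
std::vector<StateT> forwardPassRPO(
    const std::vector<int> &RPO,
    const std::vector<std::vector<int>> &Preds, StateT EntryState,
    const std::function<void(StateT &, const StateT &)> &Merge,
    const std::function<void(int, StateT &)> &Transfer) {
  std::vector<StateT> BlockState(Preds.size());
  for (std::size_t Idx = 0; Idx != RPO.size(); ++Idx) {
    int BB = RPO[Idx];
    // In the real pass the entry state comes from the kernel/call sites.
    StateT ED = (Idx == 0) ? EntryState : StateT();
    for (int PredBB : Preds[BB])
      Merge(ED, BlockState[PredBB]);
    Transfer(BB, ED);
    BlockState[BB] = ED;
  }
  return BlockState;
}

In the pass itself this runs inside the Attributor's fixpoint iteration, so a single sweep per update is enough; convergence comes from repeated updateImpl calls rather than an inner worklist.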
+ if (II->isAssumeLikeIntrinsic()) + continue; + } + + auto *CB = dyn_cast<CallBase>(&I); + bool IsNoSync = AA::isNoSyncInst(A, I, *this); + bool IsAlignedBarrier = + !IsNoSync && CB && + AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock); + + AlignedBarrierLastInBlock &= IsNoSync; + + // Next we check for calls. Aligned barriers are handled + // explicitly, everything else is kept for the backward traversal and will + // also affect our state. + if (CB) { + if (IsAlignedBarrier) { + HandleAlignedBarrier(CB, ED); + AlignedBarrierLastInBlock = true; + continue; + } + + // Check the pointer(s) of a memory intrinsic explicitly. + if (isa<MemIntrinsic>(&I)) { + if (!ED.EncounteredNonLocalSideEffect && + AA::isPotentiallyAffectedByBarrier(A, I, *this)) + ED.EncounteredNonLocalSideEffect = true; + if (!IsNoSync) { + ED.IsReachedFromAlignedBarrierOnly = false; + SyncInstWorklist.push_back(&I); + } + continue; + } + + // Record how we entered the call, then accumulate the effect of the + // call in ED for potential use by the callee. + auto &CallED = CEDMap[CB]; + mergeInPredecessor(A, CallED, ED); + + // If we have a sync-definition we can check if it starts/ends in an + // aligned barrier. If we are unsure we assume any sync breaks + // alignment. + Function *Callee = CB->getCalledFunction(); + if (!IsNoSync && Callee && !Callee->isDeclaration()) { + const auto &EDAA = A.getAAFor<AAExecutionDomain>( + *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL); + if (EDAA.getState().isValidState()) { + const auto &CalleeED = EDAA.getFunctionExecutionDomain(); + ED.IsReachedFromAlignedBarrierOnly = + CalleeED.IsReachedFromAlignedBarrierOnly; + AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly; + if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly) + ED.EncounteredNonLocalSideEffect |= + CalleeED.EncounteredNonLocalSideEffect; + else + ED.EncounteredNonLocalSideEffect = + CalleeED.EncounteredNonLocalSideEffect; + if (!CalleeED.IsReachingAlignedBarrierOnly) + SyncInstWorklist.push_back(&I); + if (CalleeED.IsReachedFromAlignedBarrierOnly) + mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED); + continue; + } + } + ED.IsReachedFromAlignedBarrierOnly = + IsNoSync && ED.IsReachedFromAlignedBarrierOnly; + AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly; + ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory(); + if (!IsNoSync) + SyncInstWorklist.push_back(&I); + } + + if (!I.mayHaveSideEffects() && !I.mayReadFromMemory()) + continue; + + // If we have a callee we try to use fine-grained information to + // determine local side-effects. 
+ if (CB) { + const auto &MemAA = A.getAAFor<AAMemoryLocation>( + *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL); + + auto AccessPred = [&](const Instruction *I, const Value *Ptr, + AAMemoryLocation::AccessKind, + AAMemoryLocation::MemoryLocationsKind) { + return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I); + }; + if (MemAA.getState().isValidState() && + MemAA.checkForAllAccessesToMemoryKind( + AccessPred, AAMemoryLocation::ALL_LOCATIONS)) + continue; + } + + if (!I.mayHaveSideEffects() && OMPInfoCache.isOnlyUsedByAssume(I)) + continue; + + if (auto *LI = dyn_cast<LoadInst>(&I)) + if (LI->hasMetadata(LLVMContext::MD_invariant_load)) + continue; + + if (!ED.EncounteredNonLocalSideEffect && + AA::isPotentiallyAffectedByBarrier(A, I, *this)) + ED.EncounteredNonLocalSideEffect = true; + } + + if (!isa<UnreachableInst>(BB.getTerminator()) && + !BB.getTerminator()->getNumSuccessors()) { + + auto &FnED = BEDMap[nullptr]; + mergeInPredecessor(A, FnED, ED); + + if (IsKernel) + HandleAlignedBarrier(nullptr, ED); + } + + ExecutionDomainTy &StoredED = BEDMap[&BB]; + ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly; + + // Check if we computed anything different as part of the forward + // traversal. We do not take assumptions and aligned barriers into account + // as they do not influence the state we iterate. Backward traversal values + // are handled later on. + if (ED.IsExecutedByInitialThreadOnly != + StoredED.IsExecutedByInitialThreadOnly || + ED.IsReachedFromAlignedBarrierOnly != + StoredED.IsReachedFromAlignedBarrierOnly || + ED.EncounteredNonLocalSideEffect != + StoredED.EncounteredNonLocalSideEffect) + Changed = true; + + // Update the state with the new value. + StoredED = std::move(ED); } - return (NumSingleThreadedBBs == SingleThreadedBBs.size()) - ? ChangeStatus::UNCHANGED - : ChangeStatus::CHANGED; + // Propagate (non-aligned) sync instruction effects backwards until the + // entry is hit or an aligned barrier. + SmallSetVector<BasicBlock *, 16> Visited; + while (!SyncInstWorklist.empty()) { + Instruction *SyncInst = SyncInstWorklist.pop_back_val(); + Instruction *CurInst = SyncInst; + bool HitAlignedBarrier = false; + while ((CurInst = CurInst->getPrevNode())) { + auto *CB = dyn_cast<CallBase>(CurInst); + if (!CB) + continue; + auto &CallED = CEDMap[CB]; + if (SetAndRecord(CallED.IsReachingAlignedBarrierOnly, false)) + Changed = true; + HitAlignedBarrier = AlignedBarriers.count(CB); + if (HitAlignedBarrier) + break; + } + if (HitAlignedBarrier) + continue; + BasicBlock *SyncBB = SyncInst->getParent(); + for (auto *PredBB : predecessors(SyncBB)) { + if (LivenessAA.isEdgeDead(PredBB, SyncBB)) + continue; + if (!Visited.insert(PredBB)) + continue; + SyncInstWorklist.push_back(PredBB->getTerminator()); + auto &PredED = BEDMap[PredBB]; + if (SetAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) + Changed = true; + } + if (SyncBB != &EntryBB) + continue; + auto &FnED = BEDMap[nullptr]; + if (SetAndRecord(FnED.IsReachingAlignedBarrierOnly, false)) + Changed = true; + } + + return Changed ? 
ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; } /// Try to replace memory allocation calls called by a single thread with a @@ -2955,12 +3185,18 @@ struct AAHeapToSharedFunction : public AAHeapToShared { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; + if (!RFI.Declaration) + return; Attributor::SimplifictionCallbackTy SCB = [](const IRPosition &, const AbstractAttribute *, - bool &) -> Optional<Value *> { return nullptr; }; + bool &) -> std::optional<Value *> { return nullptr; }; + + Function *F = getAnchorScope(); for (User *U : RFI.Declaration->users()) if (CallBase *CB = dyn_cast<CallBase>(U)) { + if (CB->getFunction() != F) + continue; MallocCalls.insert(CB); A.registerSimplificationCallback(IRPosition::callsite_returned(*CB), SCB); @@ -3057,20 +3293,33 @@ struct AAHeapToSharedFunction : public AAHeapToShared { } ChangeStatus updateImpl(Attributor &A) override { + if (MallocCalls.empty()) + return indicatePessimisticFixpoint(); auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; + if (!RFI.Declaration) + return ChangeStatus::UNCHANGED; + Function *F = getAnchorScope(); auto NumMallocCalls = MallocCalls.size(); // Only consider malloc calls executed by a single thread with a constant. for (User *U : RFI.Declaration->users()) { - const auto &ED = A.getAAFor<AAExecutionDomain>( - *this, IRPosition::function(*F), DepClassTy::REQUIRED); - if (CallBase *CB = dyn_cast<CallBase>(U)) - if (!isa<ConstantInt>(CB->getArgOperand(0)) || - !ED.isExecutedByInitialThreadOnly(*CB)) + if (CallBase *CB = dyn_cast<CallBase>(U)) { + if (CB->getCaller() != F) + continue; + if (!MallocCalls.count(CB)) + continue; + if (!isa<ConstantInt>(CB->getArgOperand(0))) { MallocCalls.remove(CB); + continue; + } + const auto &ED = A.getAAFor<AAExecutionDomain>( + *this, IRPosition::function(*F), DepClassTy::REQUIRED); + if (!ED.isExecutedByInitialThreadOnly(*CB)) + MallocCalls.remove(CB); + } } findPotentialRemovedFreeCalls(A); @@ -3115,6 +3364,10 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { ", #Reaching Kernels: " + (ReachingKernelEntries.isValidState() ? std::to_string(ReachingKernelEntries.size()) + : "<invalid>") + + ", #ParLevels: " + + (ParallelLevels.isValidState() + ? std::to_string(ParallelLevels.size()) : "<invalid>"); } @@ -3202,7 +3455,7 @@ struct AAKernelInfoFunction : AAKernelInfo { Attributor::SimplifictionCallbackTy StateMachineSimplifyCB = [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> Optional<Value *> { + bool &UsedAssumedInformation) -> std::optional<Value *> { // IRP represents the "use generic state machine" argument of an // __kmpc_target_init call. We will answer this one with the internal // state. As long as we are not in an invalid state, we will create a @@ -3223,7 +3476,7 @@ struct AAKernelInfoFunction : AAKernelInfo { Attributor::SimplifictionCallbackTy ModeSimplifyCB = [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> Optional<Value *> { + bool &UsedAssumedInformation) -> std::optional<Value *> { // IRP represents the "SPMDCompatibilityTracker" argument of an // __kmpc_target_init or // __kmpc_target_deinit call. 
We will answer this one with the internal @@ -3244,32 +3497,9 @@ struct AAKernelInfoFunction : AAKernelInfo { return Val; }; - Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB = - [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> Optional<Value *> { - // IRP represents the "RequiresFullRuntime" argument of an - // __kmpc_target_init or __kmpc_target_deinit call. We will answer this - // one with the internal state of the SPMDCompatibilityTracker, so if - // generic then true, if SPMD then false. - if (!SPMDCompatibilityTracker.isValidState()) - return nullptr; - if (!SPMDCompatibilityTracker.isAtFixpoint()) { - if (AA) - A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); - UsedAssumedInformation = true; - } else { - UsedAssumedInformation = false; - } - auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), - !SPMDCompatibilityTracker.isAssumed()); - return Val; - }; - constexpr const int InitModeArgNo = 1; constexpr const int DeinitModeArgNo = 1; constexpr const int InitUseStateMachineArgNo = 2; - constexpr const int InitRequiresFullRuntimeArgNo = 3; - constexpr const int DeinitRequiresFullRuntimeArgNo = 2; A.registerSimplificationCallback( IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), StateMachineSimplifyCB); @@ -3279,14 +3509,6 @@ struct AAKernelInfoFunction : AAKernelInfo { A.registerSimplificationCallback( IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo), ModeSimplifyCB); - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelInitCB, - InitRequiresFullRuntimeArgNo), - IsGenericModeSimplifyCB); - A.registerSimplificationCallback( - IRPosition::callsite_argument(*KernelDeinitCB, - DeinitRequiresFullRuntimeArgNo), - IsGenericModeSimplifyCB); // Check if we know we are in SPMD-mode already. ConstantInt *ModeArg = @@ -3296,6 +3518,84 @@ struct AAKernelInfoFunction : AAKernelInfo { // This is a generic region but SPMDization is disabled so stop tracking. else if (DisableOpenMPOptSPMDization) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + + // Register virtual uses of functions we might need to preserve. + auto RegisterVirtualUse = [&](RuntimeFunction RFKind, + Attributor::VirtualUseCallbackTy &CB) { + if (!OMPInfoCache.RFIs[RFKind].Declaration) + return; + A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB); + }; + + // Add a dependence to ensure updates if the state changes. + auto AddDependence = [](Attributor &A, const AAKernelInfo *KI, + const AbstractAttribute *QueryingAA) { + if (QueryingAA) { + A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL); + } + return true; + }; + + Attributor::VirtualUseCallbackTy CustomStateMachineUseCB = + [&](Attributor &A, const AbstractAttribute *QueryingAA) { + // Whenever we create a custom state machine we will insert calls to + // __kmpc_get_hardware_num_threads_in_block, + // __kmpc_get_warp_size, + // __kmpc_barrier_simple_generic, + // __kmpc_kernel_parallel, and + // __kmpc_kernel_end_parallel. + // Not needed if we are on track for SPMDzation. + if (SPMDCompatibilityTracker.isValidState()) + return AddDependence(A, this, QueryingAA); + // Not needed if we can't rewrite due to an invalid state. + if (!ReachedKnownParallelRegions.isValidState()) + return AddDependence(A, this, QueryingAA); + return false; + }; + + // Not needed if we are pre-runtime merge. 
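These virtual-use callbacks replace the ExternalizationRAII workaround removed earlier in this patch: instead of temporarily giving runtime declarations external linkage, the pass tells the Attributor that a declaration may gain uses later, and the callback answers "still needed?" from the current SPMD/state-machine state while recording a dependence so the answer is revisited. A stripped-down model of that contract; VirtualUseRegistry and its members are hypothetical names, not the Attributor API.

#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical keep-alive registry: a declaration is removable only if every
// registered callback agrees no later rewrite will emit a call to it.
class VirtualUseRegistry {
  std::map<std::string, std::vector<std::function<bool()>>> Callbacks;

public:
  void registerVirtualUse(const std::string &RuntimeFn,
                          std::function<bool()> MayStillBeNeeded) {
    Callbacks[RuntimeFn].push_back(std::move(MayStillBeNeeded));
  }

  bool isRemovable(const std::string &RuntimeFn) const {
    auto It = Callbacks.find(RuntimeFn);
    if (It == Callbacks.end())
      return true;
    for (const auto &CB : It->second)
      if (CB()) // Some pending rewrite may still call this function.
        return false;
    return true;
  }
};

The benefit over the RAII approach is precision: functions such as __kmpc_kernel_parallel are kept alive only while a custom state machine is still a possible outcome, instead of unconditionally for the whole run.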
+ if (!KernelInitCB->getCalledFunction()->isDeclaration()) { + RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block, + CustomStateMachineUseCB); + RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB); + RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic, + CustomStateMachineUseCB); + RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel, + CustomStateMachineUseCB); + RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel, + CustomStateMachineUseCB); + } + + // If we do not perform SPMDzation we do not need the virtual uses below. + if (SPMDCompatibilityTracker.isAtFixpoint()) + return; + + Attributor::VirtualUseCallbackTy HWThreadIdUseCB = + [&](Attributor &A, const AbstractAttribute *QueryingAA) { + // Whenever we perform SPMDzation we will insert + // __kmpc_get_hardware_thread_id_in_block calls. + if (!SPMDCompatibilityTracker.isValidState()) + return AddDependence(A, this, QueryingAA); + return false; + }; + RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block, + HWThreadIdUseCB); + + Attributor::VirtualUseCallbackTy SPMDBarrierUseCB = + [&](Attributor &A, const AbstractAttribute *QueryingAA) { + // Whenever we perform SPMDzation with guarding we will insert + // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is + // nothing to guard, or there are no parallel regions, we don't need + // the calls. + if (!SPMDCompatibilityTracker.isValidState()) + return AddDependence(A, this, QueryingAA); + if (SPMDCompatibilityTracker.empty()) + return AddDependence(A, this, QueryingAA); + if (!mayContainParallelRegion()) + return AddDependence(A, this, QueryingAA); + return false; + }; + RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB); } /// Sanitize the string \p S such that it is a suitable global symbol name. @@ -3318,77 +3618,29 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!KernelInitCB || !KernelDeinitCB) return ChangeStatus::UNCHANGED; + /// Insert nested Parallelism global variable + Function *Kernel = getAnchorScope(); + Module &M = *Kernel->getParent(); + Type *Int8Ty = Type::getInt8Ty(M.getContext()); + new GlobalVariable(M, Int8Ty, /* isConstant */ true, + GlobalValue::WeakAnyLinkage, + ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0), + Kernel->getName() + "_nested_parallelism"); + // If we can we change the execution mode to SPMD-mode otherwise we build a // custom state machine. ChangeStatus Changed = ChangeStatus::UNCHANGED; - if (!changeToSPMDMode(A, Changed)) - return buildCustomStateMachine(A); + if (!changeToSPMDMode(A, Changed)) { + if (!KernelInitCB->getCalledFunction()->isDeclaration()) + return buildCustomStateMachine(A); + } return Changed; } - bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) { - if (!mayContainParallelRegion()) - return false; - + void insertInstructionGuardsHelper(Attributor &A) { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); - if (!SPMDCompatibilityTracker.isAssumed()) { - for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) { - if (!NonCompatibleI) - continue; - - // Skip diagnostics on calls to known OpenMP runtime functions for now. - if (auto *CB = dyn_cast<CallBase>(NonCompatibleI)) - if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction())) - continue; - - auto Remark = [&](OptimizationRemarkAnalysis ORA) { - ORA << "Value has potential side effects preventing SPMD-mode " - "execution"; - if (isa<CallBase>(NonCompatibleI)) { - ORA << ". 
Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to " - "the called function to override"; - } - return ORA << "."; - }; - A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121", - Remark); - - LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: " - << *NonCompatibleI << "\n"); - } - - return false; - } - - // Get the actual kernel, could be the caller of the anchor scope if we have - // a debug wrapper. - Function *Kernel = getAnchorScope(); - if (Kernel->hasLocalLinkage()) { - assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper."); - auto *CB = cast<CallBase>(Kernel->user_back()); - Kernel = CB->getCaller(); - } - assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!"); - - // Check if the kernel is already in SPMD mode, if so, return success. - GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( - (Kernel->getName() + "_exec_mode").str()); - assert(ExecMode && "Kernel without exec mode?"); - assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!"); - - // Set the global exec mode flag to indicate SPMD-Generic mode. - assert(isa<ConstantInt>(ExecMode->getInitializer()) && - "ExecMode is not an integer!"); - const int8_t ExecModeVal = - cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue(); - if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC) - return true; - - // We will now unconditionally modify the IR, indicate a change. - Changed = ChangeStatus::CHANGED; - auto CreateGuardedRegion = [&](Instruction *RegionStartI, Instruction *RegionEndI) { LoopInfo *LI = nullptr; @@ -3605,6 +3857,125 @@ struct AAKernelInfoFunction : AAKernelInfo { for (auto &GR : GuardedRegions) CreateGuardedRegion(GR.first, GR.second); + } + + void forceSingleThreadPerWorkgroupHelper(Attributor &A) { + // Only allow 1 thread per workgroup to continue executing the user code. + // + // InitCB = __kmpc_target_init(...) + // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block(); + // if (ThreadIdInBlock != 0) return; + // UserCode: + // // user code + // + auto &Ctx = getAnchorValue().getContext(); + Function *Kernel = getAssociatedFunction(); + assert(Kernel && "Expected an associated function!"); + + // Create block for user code to branch to from initial block. + BasicBlock *InitBB = KernelInitCB->getParent(); + BasicBlock *UserCodeBB = InitBB->splitBasicBlock( + KernelInitCB->getNextNode(), "main.thread.user_code"); + BasicBlock *ReturnBB = + BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB); + + // Register blocks with attributor: + A.registerManifestAddedBasicBlock(*InitBB); + A.registerManifestAddedBasicBlock(*UserCodeBB); + A.registerManifestAddedBasicBlock(*ReturnBB); + + // Debug location: + const DebugLoc &DLoc = KernelInitCB->getDebugLoc(); + ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc); + InitBB->getTerminator()->eraseFromParent(); + + // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block. + Module &M = *Kernel->getParent(); + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + FunctionCallee ThreadIdInBlockFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_hardware_thread_id_in_block); + + // Get thread ID in block. 
+ CallInst *ThreadIdInBlock = + CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB); + OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock); + ThreadIdInBlock->setDebugLoc(DLoc); + + // Eliminate all threads in the block with ID not equal to 0: + Instruction *IsMainThread = + ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock, + ConstantInt::get(ThreadIdInBlock->getType(), 0), + "thread.is_main", InitBB); + IsMainThread->setDebugLoc(DLoc); + BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB); + } + + bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) { + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + + if (!SPMDCompatibilityTracker.isAssumed()) { + for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) { + if (!NonCompatibleI) + continue; + + // Skip diagnostics on calls to known OpenMP runtime functions for now. + if (auto *CB = dyn_cast<CallBase>(NonCompatibleI)) + if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction())) + continue; + + auto Remark = [&](OptimizationRemarkAnalysis ORA) { + ORA << "Value has potential side effects preventing SPMD-mode " + "execution"; + if (isa<CallBase>(NonCompatibleI)) { + ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to " + "the called function to override"; + } + return ORA << "."; + }; + A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121", + Remark); + + LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: " + << *NonCompatibleI << "\n"); + } + + return false; + } + + // Get the actual kernel, could be the caller of the anchor scope if we have + // a debug wrapper. + Function *Kernel = getAnchorScope(); + if (Kernel->hasLocalLinkage()) { + assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper."); + auto *CB = cast<CallBase>(Kernel->user_back()); + Kernel = CB->getCaller(); + } + assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!"); + + // Check if the kernel is already in SPMD mode, if so, return success. + GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( + (Kernel->getName() + "_exec_mode").str()); + assert(ExecMode && "Kernel without exec mode?"); + assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!"); + + // Set the global exec mode flag to indicate SPMD-Generic mode. + assert(isa<ConstantInt>(ExecMode->getInitializer()) && + "ExecMode is not an integer!"); + const int8_t ExecModeVal = + cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue(); + if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC) + return true; + + // We will now unconditionally modify the IR, indicate a change. + Changed = ChangeStatus::CHANGED; + + // Do not use instruction guards when no parallel is present inside + // the target region. + if (mayContainParallelRegion()) + insertInstructionGuardsHelper(A); + else + forceSingleThreadPerWorkgroupHelper(A); // Adjust the global exec mode flag that tells the runtime what mode this // kernel is executed in. 
@@ -3618,8 +3989,6 @@ struct AAKernelInfoFunction : AAKernelInfo { const int InitModeArgNo = 1; const int DeinitModeArgNo = 1; const int InitUseStateMachineArgNo = 2; - const int InitRequiresFullRuntimeArgNo = 3; - const int DeinitRequiresFullRuntimeArgNo = 2; auto &Ctx = getAnchorValue().getContext(); A.changeUseAfterManifest( @@ -3633,12 +4002,6 @@ struct AAKernelInfoFunction : AAKernelInfo { KernelDeinitCB->getArgOperandUse(DeinitModeArgNo), *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_SPMD)); - A.changeUseAfterManifest( - KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo), - *ConstantInt::getBool(Ctx, false)); - A.changeUseAfterManifest( - KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo), - *ConstantInt::getBool(Ctx, false)); ++NumOpenMPTargetRegionKernelsSPMD; @@ -3982,23 +4345,21 @@ struct AAKernelInfoFunction : AAKernelInfo { if (!I.mayWriteToMemory()) return true; if (auto *SI = dyn_cast<StoreInst>(&I)) { - SmallVector<const Value *> Objects; - getUnderlyingObjects(SI->getPointerOperand(), Objects); - if (llvm::all_of(Objects, - [](const Value *Obj) { return isa<AllocaInst>(Obj); })) - return true; - // Check for AAHeapToStack moved objects which must not be guarded. + const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>( + *this, IRPosition::value(*SI->getPointerOperand()), + DepClassTy::OPTIONAL); auto &HS = A.getAAFor<AAHeapToStack>( *this, IRPosition::function(*I.getFunction()), DepClassTy::OPTIONAL); - if (llvm::all_of(Objects, [&HS](const Value *Obj) { - auto *CB = dyn_cast<CallBase>(Obj); - if (!CB) - return false; - return HS.isAssumedHeapToStack(*CB); - })) { + if (UnderlyingObjsAA.forallUnderlyingObjects([&](Value &Obj) { + if (AA::isAssumedThreadLocalObject(A, Obj, *this)) + return true; + // Check for AAHeapToStack moved objects which must not be + // guarded. + auto *CB = dyn_cast<CallBase>(&Obj); + return CB && HS.isAssumedHeapToStack(*CB); + })) return true; - } } // Insert instruction that needs guarding. @@ -4020,28 +4381,30 @@ struct AAKernelInfoFunction : AAKernelInfo { updateReachingKernelEntries(A, AllReachingKernelsKnown); UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown; - if (!ParallelLevels.isValidState()) - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - else if (!ReachingKernelEntries.isValidState()) - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - else if (!SPMDCompatibilityTracker.empty()) { - // Check if all reaching kernels agree on the mode as we can otherwise - // not guard instructions. We might not be sure about the mode so we - // we cannot fix the internal spmd-zation state either. - int SPMD = 0, Generic = 0; - for (auto *Kernel : ReachingKernelEntries) { - auto &CBAA = A.getAAFor<AAKernelInfo>( - *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL); - if (CBAA.SPMDCompatibilityTracker.isValidState() && - CBAA.SPMDCompatibilityTracker.isAssumed()) - ++SPMD; - else - ++Generic; - if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint()) - UsedAssumedInformationFromReachingKernels = true; - } - if (SPMD != 0 && Generic != 0) + if (!SPMDCompatibilityTracker.empty()) { + if (!ParallelLevels.isValidState()) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + else if (!ReachingKernelEntries.isValidState()) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + else { + // Check if all reaching kernels agree on the mode as we can otherwise + // not guard instructions. 
We might not be sure about the mode so we + // we cannot fix the internal spmd-zation state either. + int SPMD = 0, Generic = 0; + for (auto *Kernel : ReachingKernelEntries) { + auto &CBAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL); + if (CBAA.SPMDCompatibilityTracker.isValidState() && + CBAA.SPMDCompatibilityTracker.isAssumed()) + ++SPMD; + else + ++Generic; + if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint()) + UsedAssumedInformationFromReachingKernels = true; + } + if (SPMD != 0 && Generic != 0) + SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + } } } @@ -4077,15 +4440,6 @@ struct AAKernelInfoFunction : AAKernelInfo { ReachedUnknownParallelRegions.indicateOptimisticFixpoint(); } - // If we are sure there are no parallel regions in the kernel we do not - // want SPMD mode. - if (IsKernelEntry && ReachedUnknownParallelRegions.isAtFixpoint() && - ReachedKnownParallelRegions.isAtFixpoint() && - ReachedUnknownParallelRegions.isValidState() && - ReachedKnownParallelRegions.isValidState() && - !mayContainParallelRegion()) - SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - // If we haven't used any assumed information for the SPMD state we can fix // it. if (!UsedAssumedInformationInCheckRWInst && @@ -4288,6 +4642,12 @@ struct AAKernelInfoCallSite : AAKernelInfo { if (auto *ParallelRegion = dyn_cast<Function>( CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) { ReachedKnownParallelRegions.insert(ParallelRegion); + /// Check nested parallelism + auto &FnAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL); + NestedParallelism |= !FnAA.getState().isValidState() || + !FnAA.ReachedKnownParallelRegions.empty() || + !FnAA.ReachedUnknownParallelRegions.empty(); break; } // The condition above should usually get the parallel region function @@ -4419,10 +4779,10 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { if (!SimplifiedValue) return Str + std::string("none"); - if (!SimplifiedValue.value()) + if (!*SimplifiedValue) return Str + std::string("nullptr"); - if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.value())) + if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue)) return Str + std::to_string(CI->getSExtValue()); return Str + std::string("unknown"); @@ -4445,9 +4805,9 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { A.registerSimplificationCallback( IRPosition::callsite_returned(CB), [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> Optional<Value *> { + bool &UsedAssumedInformation) -> std::optional<Value *> { assert((isValidState() || - (SimplifiedValue && SimplifiedValue.value() == nullptr)) && + (SimplifiedValue && *SimplifiedValue == nullptr)) && "Unexpected invalid state!"); if (!isAtFixpoint()) { @@ -4465,9 +4825,6 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { case OMPRTL___kmpc_is_spmd_exec_mode: Changed |= foldIsSPMDExecMode(A); break; - case OMPRTL___kmpc_is_generic_main_thread_id: - Changed |= foldIsGenericMainThread(A); - break; case OMPRTL___kmpc_parallel_level: Changed |= foldParallelLevel(A); break; @@ -4522,7 +4879,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { private: /// Fold __kmpc_is_spmd_exec_mode into a constant if possible. 
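Both the guarding decision above and foldIsSPMDExecMode below hinge on the same vote: do all kernels that can reach this code agree on the execution mode? A small standalone version of that vote, where each reaching kernel reports an assumed mode and whether the assumption is already fixed; ExecMode, ReachingKernel, and agreedExecMode are illustrative names.

#include <optional>
#include <vector>

enum class ExecMode { SPMD, Generic };

struct ReachingKernel {
  ExecMode AssumedMode;
  bool AtFixpoint; // false if the assumption may still change.
};

// Returns the agreed-upon mode, or std::nullopt if the reaching kernels
// disagree (then neither guarding nor constant folding is safe).
// UsedAssumedInformation is set if any vote may still change.
std::optional<ExecMode>
agreedExecMode(const std::vector<ReachingKernel> &Kernels,
               bool &UsedAssumedInformation) {
  unsigned SPMD = 0, Generic = 0;
  UsedAssumedInformation = false;
  for (const ReachingKernel &K : Kernels) {
    if (K.AssumedMode == ExecMode::SPMD)
      ++SPMD;
    else
      ++Generic;
    UsedAssumedInformation |= !K.AtFixpoint;
  }
  // Disagreement, or no reaching kernel at all, means no single answer.
  if (Kernels.empty() || (SPMD != 0 && Generic != 0))
    return std::nullopt;
  return SPMD != 0 ? ExecMode::SPMD : ExecMode::Generic;
}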
ChangeStatus foldIsSPMDExecMode(Attributor &A) { - Optional<Value *> SimplifiedValueBefore = SimplifiedValue; + std::optional<Value *> SimplifiedValueBefore = SimplifiedValue; unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; @@ -4582,31 +4939,9 @@ private: : ChangeStatus::CHANGED; } - /// Fold __kmpc_is_generic_main_thread_id into a constant if possible. - ChangeStatus foldIsGenericMainThread(Attributor &A) { - Optional<Value *> SimplifiedValueBefore = SimplifiedValue; - - CallBase &CB = cast<CallBase>(getAssociatedValue()); - Function *F = CB.getFunction(); - const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>( - *this, IRPosition::function(*F), DepClassTy::REQUIRED); - - if (!ExecutionDomainAA.isValidState()) - return indicatePessimisticFixpoint(); - - auto &Ctx = getAnchorValue().getContext(); - if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB)) - SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true); - else - return indicatePessimisticFixpoint(); - - return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED - : ChangeStatus::CHANGED; - } - /// Fold __kmpc_parallel_level into a constant if possible. ChangeStatus foldParallelLevel(Attributor &A) { - Optional<Value *> SimplifiedValueBefore = SimplifiedValue; + std::optional<Value *> SimplifiedValueBefore = SimplifiedValue; auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); @@ -4668,7 +5003,7 @@ private: ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) { // Specialize only if all the calls agree with the attribute constant value int32_t CurrentAttrValue = -1; - Optional<Value *> SimplifiedValueBefore = SimplifiedValue; + std::optional<Value *> SimplifiedValueBefore = SimplifiedValue; auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); @@ -4678,10 +5013,7 @@ private: // Iterate over the kernels that reach this function for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { - int32_t NextAttrVal = -1; - if (K->hasFnAttribute(Attr)) - NextAttrVal = - std::stoi(K->getFnAttribute(Attr).getValueAsString().str()); + int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1); if (NextAttrVal == -1 || (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal)) @@ -4701,7 +5033,7 @@ private: /// An optional value the associated value is assumed to fold to. That is, we /// assume the associated value (which is a call) can be replaced by this /// simplified value. - Optional<Value *> SimplifiedValue; + std::optional<Value *> SimplifiedValue; /// The runtime function kind of the callee of the associated call site. RuntimeFunction RFKind; @@ -4744,7 +5076,6 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; InitRFI.foreachUse(SCC, CreateKernelInfoCB); - registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id); registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode); registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level); registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block); @@ -4752,32 +5083,27 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { } // Create CallSite AA for all Getters. 
- for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) { - auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)]; + if (DeduceICVValues) { + for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) { + auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)]; - auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; + auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; - auto CreateAA = [&](Use &U, Function &Caller) { - CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI); - if (!CI) - return false; + auto CreateAA = [&](Use &U, Function &Caller) { + CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI); + if (!CI) + return false; - auto &CB = cast<CallBase>(*CI); + auto &CB = cast<CallBase>(*CI); - IRPosition CBPos = IRPosition::callsite_function(CB); - A.getOrCreateAAFor<AAICVTracker>(CBPos); - return false; - }; + IRPosition CBPos = IRPosition::callsite_function(CB); + A.getOrCreateAAFor<AAICVTracker>(CBPos); + return false; + }; - GetterRFI.foreachUse(SCC, CreateAA); + GetterRFI.foreachUse(SCC, CreateAA); + } } - auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; - auto CreateAA = [&](Use &U, Function &F) { - A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F)); - return false; - }; - if (!DisableOpenMPOptDeglobalization) - GlobalizationRFI.foreachUse(SCC, CreateAA); // Create an ExecutionDomain AA for every function and a HeapToStack AA for // every function if there is a device kernel. @@ -4788,17 +5114,44 @@ void OpenMPOpt::registerAAs(bool IsModulePass) { if (F->isDeclaration()) continue; - A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F)); - if (!DisableOpenMPOptDeglobalization) - A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F)); - - for (auto &I : instructions(*F)) { - if (auto *LI = dyn_cast<LoadInst>(&I)) { - bool UsedAssumedInformation = false; - A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr, - UsedAssumedInformation, AA::Interprocedural); - } else if (auto *SI = dyn_cast<StoreInst>(&I)) { - A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI)); + // We look at internal functions only on-demand but if any use is not a + // direct call or outside the current set of analyzed functions, we have + // to do it eagerly. 
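The ICV-getter tracking above is now gated on DeduceICVValues, which is not defined in any hunk shown here and is presumably a command-line flag declared earlier in OpenMPOpt.cpp. For readers of this hunk only, a hedged sketch of what such a gate typically looks like; the option name, default, and description are assumptions, not the actual definition:

#include "llvm/Support/CommandLine.h"

// Hypothetical stand-in; the real flag name, default, and description in
// OpenMPOpt.cpp may differ.
static llvm::cl::opt<bool> DeduceICVValues(
    "openmp-opt-deduce-icv-values", llvm::cl::init(false), llvm::cl::Hidden,
    llvm::cl::desc("Track OpenMP ICV getter calls to deduce their values."));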
+ if (F->hasLocalLinkage()) { + if (llvm::all_of(F->uses(), [this](const Use &U) { + const auto *CB = dyn_cast<CallBase>(U.getUser()); + return CB && CB->isCallee(&U) && + A.isRunOn(const_cast<Function *>(CB->getCaller())); + })) + continue; + } + registerAAsForFunction(A, *F); + } +} + +void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) { + if (!DisableOpenMPOptDeglobalization) + A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F)); + A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F)); + if (!DisableOpenMPOptDeglobalization) + A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F)); + + for (auto &I : instructions(F)) { + if (auto *LI = dyn_cast<LoadInst>(&I)) { + bool UsedAssumedInformation = false; + A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr, + UsedAssumedInformation, AA::Interprocedural); + continue; + } + if (auto *SI = dyn_cast<StoreInst>(&I)) { + A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI)); + continue; + } + if (auto *II = dyn_cast<IntrinsicInst>(&I)) { + if (II->getIntrinsicID() == Intrinsic::assume) { + A.getOrCreateAAFor<AAPotentialValues>( + IRPosition::value(*II->getArgOperand(0))); + continue; } } } @@ -4970,10 +5323,13 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { } // Look at every function in the Module unless it was internalized. + SetVector<Function *> Functions; SmallVector<Function *, 16> SCC; for (Function &F : M) - if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) + if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) { SCC.push_back(&F); + Functions.insert(&F); + } if (SCC.empty()) return PreservedAnalyses::all(); @@ -4987,18 +5343,19 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { BumpPtrAllocator Allocator; CallGraphUpdater CGUpdater; - SetVector<Function *> Functions(SCC.begin(), SCC.end()); - OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels); + OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? SetFixpointIterations : 32; AttributorConfig AC(CGUpdater); AC.DefaultInitializeLiveInternals = false; + AC.IsModulePass = true; AC.RewriteSignatures = false; AC.MaxFixpointIterations = MaxFixpointIterations; AC.OREGetter = OREGetter; AC.PassName = DEBUG_TYPE; + AC.InitializationCallback = OpenMPOpt::registerAAsForFunction; Attributor A(Functions, InfoCache, AC); @@ -5062,7 +5419,7 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, SetVector<Function *> Functions(SCC.begin(), SCC.end()); OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, - /*CGSCC*/ Functions, Kernels); + /*CGSCC*/ &Functions, Kernels); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 
SetFixpointIterations : 32; @@ -5074,6 +5431,7 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, AC.MaxFixpointIterations = MaxFixpointIterations; AC.OREGetter = OREGetter; AC.PassName = DEBUG_TYPE; + AC.InitializationCallback = OpenMPOpt::registerAAsForFunction; Attributor A(Functions, InfoCache, AC); @@ -5089,90 +5447,9 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::all(); } -namespace { - -struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { - CallGraphUpdater CGUpdater; - static char ID; - - OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) { - initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - CallGraphSCCPass::getAnalysisUsage(AU); - } - - bool runOnSCC(CallGraphSCC &CGSCC) override { - if (!containsOpenMP(CGSCC.getCallGraph().getModule())) - return false; - if (DisableOpenMPOptimizations || skipSCC(CGSCC)) - return false; - - SmallVector<Function *, 16> SCC; - // If there are kernels in the module, we have to run on all SCC's. - for (CallGraphNode *CGN : CGSCC) { - Function *Fn = CGN->getFunction(); - if (!Fn || Fn->isDeclaration()) - continue; - SCC.push_back(Fn); - } - - if (SCC.empty()) - return false; - - Module &M = CGSCC.getCallGraph().getModule(); - KernelSet Kernels = getDeviceKernels(M); - - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - CGUpdater.initialize(CG, CGSCC); - - // Maintain a map of functions to avoid rebuilding the ORE - DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap; - auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & { - std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F]; - if (!ORE) - ORE = std::make_unique<OptimizationRemarkEmitter>(F); - return *ORE; - }; - - AnalysisGetter AG; - SetVector<Function *> Functions(SCC.begin(), SCC.end()); - BumpPtrAllocator Allocator; - OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, - Allocator, - /*CGSCC*/ Functions, Kernels); - - unsigned MaxFixpointIterations = - (isOpenMPDevice(M)) ? SetFixpointIterations : 32; - - AttributorConfig AC(CGUpdater); - AC.DefaultInitializeLiveInternals = false; - AC.IsModulePass = false; - AC.RewriteSignatures = false; - AC.MaxFixpointIterations = MaxFixpointIterations; - AC.OREGetter = OREGetter; - AC.PassName = DEBUG_TYPE; - - Attributor A(Functions, InfoCache, AC); - - OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); - bool Result = OMPOpt.run(false); - - if (PrintModuleAfterOptimizations) - LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M); - - return Result; - } - - bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); } -}; - -} // end anonymous namespace - KernelSet llvm::omp::getDeviceKernels(Module &M) { // TODO: Create a more cross-platform way of determining device kernels. 
- NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations"); + NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations"); KernelSet Kernels; if (!MD) @@ -5213,15 +5490,3 @@ bool llvm::omp::isOpenMPDevice(Module &M) { return true; } - -char OpenMPOptCGSCCLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", - "OpenMP specific optimizations", false, false) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc", - "OpenMP specific optimizations", false, false) - -Pass *llvm::createOpenMPOptCGSCCLegacyPass() { - return new OpenMPOptCGSCCLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp index 54c72bdbb203..310e4d4164a5 100644 --- a/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -14,7 +14,6 @@ #include "llvm/Transforms/IPO/PartialInlining.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -40,6 +39,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/User.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -716,8 +716,7 @@ static bool hasProfileData(const Function &F, const FunctionOutliningInfo &OI) { BranchInst *BR = dyn_cast<BranchInst>(E->getTerminator()); if (!BR || BR->isUnconditional()) continue; - uint64_t T, F; - if (BR->extractProfMetadata(T, F)) + if (hasBranchWeightMD(*BR)) return true; } return false; @@ -752,7 +751,7 @@ BranchProbability PartialInlinerImpl::getOutliningCallBBRelativeFreq( // is predicted to be less likely, the predicted probablity is usually // higher than the actual. For instance, the actual probability of the // less likely target is only 5%, but the guessed probablity can be - // 40%. In the latter case, there is no need for further adjustement. + // 40%. In the latter case, there is no need for further adjustment. // FIXME: add an option for this. if (OutlineRegionRelFreq < BranchProbability(45, 100)) return OutlineRegionRelFreq; @@ -853,6 +852,7 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB, TargetTransformInfo *TTI) { InstructionCost InlineCost = 0; const DataLayout &DL = BB->getParent()->getParent()->getDataLayout(); + int InstrCost = InlineConstants::getInstrCost(); for (Instruction &I : BB->instructionsWithoutDebug()) { // Skip free instructions. switch (I.getOpcode()) { @@ -899,10 +899,10 @@ PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB, } if (SwitchInst *SI = dyn_cast<SwitchInst>(&I)) { - InlineCost += (SI->getNumCases() + 1) * InlineConstants::InstrCost; + InlineCost += (SI->getNumCases() + 1) * InstrCost; continue; } - InlineCost += InlineConstants::InstrCost; + InlineCost += InstrCost; } return InlineCost; @@ -931,7 +931,7 @@ PartialInlinerImpl::computeOutliningCosts(FunctionCloner &Cloner) const { // additional unconditional branches. Those branches will be eliminated // later with bb layout. 
The cost should be adjusted accordingly: OutlinedFunctionCost -= - 2 * InlineConstants::InstrCost * Cloner.OutlinedFunctions.size(); + 2 * InlineConstants::getInstrCost() * Cloner.OutlinedFunctions.size(); InstructionCost OutliningRuntimeOverhead = OutliningFuncCallCost + @@ -1081,10 +1081,8 @@ void PartialInlinerImpl::FunctionCloner::normalizeReturnBlock() const { return; auto IsTrivialPhi = [](PHINode *PN) -> Value * { - Value *CommonValue = PN->getIncomingValue(0); - if (all_of(PN->incoming_values(), - [&](Value *V) { return V == CommonValue; })) - return CommonValue; + if (llvm::all_equal(PN->incoming_values())) + return PN->getIncomingValue(0); return nullptr; }; @@ -1351,16 +1349,13 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { if (Cloner.OutlinedFunctions.empty()) return false; - int SizeCost = 0; - BlockFrequency WeightedRcost; - int NonWeightedRcost; - auto OutliningCosts = computeOutliningCosts(Cloner); - assert(std::get<0>(OutliningCosts).isValid() && - std::get<1>(OutliningCosts).isValid() && "Expected valid costs"); - SizeCost = *std::get<0>(OutliningCosts).getValue(); - NonWeightedRcost = *std::get<1>(OutliningCosts).getValue(); + InstructionCost SizeCost = std::get<0>(OutliningCosts); + InstructionCost NonWeightedRcost = std::get<1>(OutliningCosts); + + assert(SizeCost.isValid() && NonWeightedRcost.isValid() && + "Expected valid costs"); // Only calculate RelativeToEntryFreq when we are doing single region // outlining. @@ -1375,7 +1370,8 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { // execute the calls to outlined functions. RelativeToEntryFreq = BranchProbability(0, 1); - WeightedRcost = BlockFrequency(NonWeightedRcost) * RelativeToEntryFreq; + BlockFrequency WeightedRcost = + BlockFrequency(*NonWeightedRcost.getValue()) * RelativeToEntryFreq; // The call sequence(s) to the outlined function(s) are larger than the sum of // the original outlined region size(s), it does not increase the chances of @@ -1436,7 +1432,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { InlineFunctionInfo IFI(nullptr, GetAssumptionCache, &PSI); // We can only forward varargs when we outlined a single region, else we // bail on vararg functions. - if (!InlineFunction(*CB, IFI, nullptr, true, + if (!InlineFunction(*CB, IFI, /*MergeAttributes=*/false, nullptr, true, (Cloner.ClonedOI ? 
Cloner.OutlinedFunctions.back().first : nullptr)) .isSuccess()) @@ -1492,16 +1488,6 @@ bool PartialInlinerImpl::run(Module &M) { if (CurrFunc->use_empty()) continue; - bool Recursive = false; - for (User *U : CurrFunc->users()) - if (Instruction *I = dyn_cast<Instruction>(U)) - if (I->getParent()->getParent() == CurrFunc) { - Recursive = true; - break; - } - if (Recursive) - continue; - std::pair<bool, Function *> Result = unswitchFunction(*CurrFunc); if (Result.second) Worklist.push_back(Result.second); diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp index f1b6f2bb7de4..6b91c8494f39 100644 --- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -15,8 +15,6 @@ #include "llvm-c/Transforms/PassManagerBuilder.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/CFLAndersAliasAnalysis.h" -#include "llvm/Analysis/CFLSteensAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ScopedNoAliasAA.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -43,129 +41,6 @@ using namespace llvm; -namespace llvm { -cl::opt<bool> RunPartialInlining("enable-partial-inlining", cl::Hidden, - cl::desc("Run Partial inlinining pass")); - -static cl::opt<bool> -UseGVNAfterVectorization("use-gvn-after-vectorization", - cl::init(false), cl::Hidden, - cl::desc("Run GVN instead of Early CSE after vectorization passes")); - -cl::opt<bool> ExtraVectorizerPasses( - "extra-vectorizer-passes", cl::init(false), cl::Hidden, - cl::desc("Run cleanup optimization passes after vectorization.")); - -static cl::opt<bool> -RunLoopRerolling("reroll-loops", cl::Hidden, - cl::desc("Run the loop rerolling pass")); - -cl::opt<bool> RunNewGVN("enable-newgvn", cl::init(false), cl::Hidden, - cl::desc("Run the NewGVN pass")); - -// Experimental option to use CFL-AA -static cl::opt<::CFLAAType> - UseCFLAA("use-cfl-aa", cl::init(::CFLAAType::None), cl::Hidden, - cl::desc("Enable the new, experimental CFL alias analysis"), - cl::values(clEnumValN(::CFLAAType::None, "none", "Disable CFL-AA"), - clEnumValN(::CFLAAType::Steensgaard, "steens", - "Enable unification-based CFL-AA"), - clEnumValN(::CFLAAType::Andersen, "anders", - "Enable inclusion-based CFL-AA"), - clEnumValN(::CFLAAType::Both, "both", - "Enable both variants of CFL-AA"))); - -cl::opt<bool> EnableLoopInterchange( - "enable-loopinterchange", cl::init(false), cl::Hidden, - cl::desc("Enable the experimental LoopInterchange Pass")); - -cl::opt<bool> EnableUnrollAndJam("enable-unroll-and-jam", cl::init(false), - cl::Hidden, - cl::desc("Enable Unroll And Jam Pass")); - -cl::opt<bool> EnableLoopFlatten("enable-loop-flatten", cl::init(false), - cl::Hidden, - cl::desc("Enable the LoopFlatten Pass")); - -cl::opt<bool> EnableDFAJumpThreading("enable-dfa-jump-thread", - cl::desc("Enable DFA jump threading."), - cl::init(false), cl::Hidden); - -cl::opt<bool> EnableHotColdSplit("hot-cold-split", - cl::desc("Enable hot-cold splitting pass")); - -cl::opt<bool> EnableIROutliner("ir-outliner", cl::init(false), cl::Hidden, - cl::desc("Enable ir outliner pass")); - -static cl::opt<bool> UseLoopVersioningLICM( - "enable-loop-versioning-licm", cl::init(false), cl::Hidden, - cl::desc("Enable the experimental Loop Versioning LICM pass")); - -cl::opt<bool> - DisablePreInliner("disable-preinline", cl::init(false), cl::Hidden, - cl::desc("Disable pre-instrumentation inliner")); - -cl::opt<int> PreInlineThreshold( - 
"preinline-threshold", cl::Hidden, cl::init(75), - cl::desc("Control the amount of inlining in pre-instrumentation inliner " - "(default = 75)")); - -cl::opt<bool> - EnableGVNHoist("enable-gvn-hoist", - cl::desc("Enable the GVN hoisting pass (default = off)")); - -static cl::opt<bool> - DisableLibCallsShrinkWrap("disable-libcalls-shrinkwrap", cl::init(false), - cl::Hidden, - cl::desc("Disable shrink-wrap library calls")); - -cl::opt<bool> - EnableGVNSink("enable-gvn-sink", - cl::desc("Enable the GVN sinking pass (default = off)")); - -// This option is used in simplifying testing SampleFDO optimizations for -// profile loading. -cl::opt<bool> - EnableCHR("enable-chr", cl::init(true), cl::Hidden, - cl::desc("Enable control height reduction optimization (CHR)")); - -cl::opt<bool> FlattenedProfileUsed( - "flattened-profile-used", cl::init(false), cl::Hidden, - cl::desc("Indicate the sample profile being used is flattened, i.e., " - "no inline hierachy exists in the profile. ")); - -cl::opt<bool> EnableOrderFileInstrumentation( - "enable-order-file-instrumentation", cl::init(false), cl::Hidden, - cl::desc("Enable order file instrumentation (default = off)")); - -cl::opt<bool> EnableMatrix( - "enable-matrix", cl::init(false), cl::Hidden, - cl::desc("Enable lowering of the matrix intrinsics")); - -cl::opt<bool> EnableConstraintElimination( - "enable-constraint-elimination", cl::init(false), cl::Hidden, - cl::desc( - "Enable pass to eliminate conditions based on linear constraints.")); - -cl::opt<bool> EnableFunctionSpecialization( - "enable-function-specialization", cl::init(false), cl::Hidden, - cl::desc("Enable Function Specialization pass")); - -cl::opt<AttributorRunOption> AttributorRun( - "attributor-enable", cl::Hidden, cl::init(AttributorRunOption::NONE), - cl::desc("Enable the attributor inter-procedural deduction pass."), - cl::values(clEnumValN(AttributorRunOption::ALL, "all", - "enable all attributor runs"), - clEnumValN(AttributorRunOption::MODULE, "module", - "enable module-wide attributor runs"), - clEnumValN(AttributorRunOption::CGSCC, "cgscc", - "enable call graph SCC attributor runs"), - clEnumValN(AttributorRunOption::NONE, "none", - "disable attributor runs"))); - -extern cl::opt<bool> EnableKnowledgeRetention; -} // namespace llvm - PassManagerBuilder::PassManagerBuilder() { OptLevel = 2; SizeLevel = 0; @@ -175,8 +50,6 @@ PassManagerBuilder::PassManagerBuilder() { SLPVectorize = false; LoopVectorize = true; LoopsInterleaved = true; - RerollLoops = RunLoopRerolling; - NewGVN = RunNewGVN; LicmMssaOptCap = SetLicmMssaOptCap; LicmMssaNoAccForPromotionCap = SetLicmMssaNoAccForPromotionCap; DisableGVNLoadPRE = false; @@ -193,81 +66,8 @@ PassManagerBuilder::~PassManagerBuilder() { delete Inliner; } -/// Set of global extensions, automatically added as part of the standard set. -static ManagedStatic< - SmallVector<std::tuple<PassManagerBuilder::ExtensionPointTy, - PassManagerBuilder::ExtensionFn, - PassManagerBuilder::GlobalExtensionID>, - 8>> - GlobalExtensions; -static PassManagerBuilder::GlobalExtensionID GlobalExtensionsCounter; - -/// Check if GlobalExtensions is constructed and not empty. -/// Since GlobalExtensions is a managed static, calling 'empty()' will trigger -/// the construction of the object. 
-static bool GlobalExtensionsNotEmpty() { - return GlobalExtensions.isConstructed() && !GlobalExtensions->empty(); -} - -PassManagerBuilder::GlobalExtensionID -PassManagerBuilder::addGlobalExtension(PassManagerBuilder::ExtensionPointTy Ty, - PassManagerBuilder::ExtensionFn Fn) { - auto ExtensionID = GlobalExtensionsCounter++; - GlobalExtensions->push_back(std::make_tuple(Ty, std::move(Fn), ExtensionID)); - return ExtensionID; -} - -void PassManagerBuilder::removeGlobalExtension( - PassManagerBuilder::GlobalExtensionID ExtensionID) { - // RegisterStandardPasses may try to call this function after GlobalExtensions - // has already been destroyed; doing so should not generate an error. - if (!GlobalExtensions.isConstructed()) - return; - - auto GlobalExtension = - llvm::find_if(*GlobalExtensions, [ExtensionID](const auto &elem) { - return std::get<2>(elem) == ExtensionID; - }); - assert(GlobalExtension != GlobalExtensions->end() && - "The extension ID to be removed should always be valid."); - - GlobalExtensions->erase(GlobalExtension); -} - -void PassManagerBuilder::addExtension(ExtensionPointTy Ty, ExtensionFn Fn) { - Extensions.push_back(std::make_pair(Ty, std::move(Fn))); -} - -void PassManagerBuilder::addExtensionsToPM(ExtensionPointTy ETy, - legacy::PassManagerBase &PM) const { - if (GlobalExtensionsNotEmpty()) { - for (auto &Ext : *GlobalExtensions) { - if (std::get<0>(Ext) == ETy) - std::get<1>(Ext)(*this, PM); - } - } - for (unsigned i = 0, e = Extensions.size(); i != e; ++i) - if (Extensions[i].first == ETy) - Extensions[i].second(*this, PM); -} - void PassManagerBuilder::addInitialAliasAnalysisPasses( legacy::PassManagerBase &PM) const { - switch (UseCFLAA) { - case ::CFLAAType::Steensgaard: - PM.add(createCFLSteensAAWrapperPass()); - break; - case ::CFLAAType::Andersen: - PM.add(createCFLAndersAAWrapperPass()); - break; - case ::CFLAAType::Both: - PM.add(createCFLSteensAAWrapperPass()); - PM.add(createCFLAndersAAWrapperPass()); - break; - default: - break; - } - // Add TypeBasedAliasAnalysis before BasicAliasAnalysis so that // BasicAliasAnalysis wins if they disagree. This is intended to help // support "obvious" type-punning idioms. @@ -277,19 +77,10 @@ void PassManagerBuilder::addInitialAliasAnalysisPasses( void PassManagerBuilder::populateFunctionPassManager( legacy::FunctionPassManager &FPM) { - addExtensionsToPM(EP_EarlyAsPossible, FPM); - // Add LibraryInfo if we have some. if (LibraryInfo) FPM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); - // The backends do not handle matrix intrinsics currently. - // Make sure they are also lowered in O0. - // FIXME: A lightweight version of the pass should run in the backend - // pipeline on demand. - if (EnableMatrix && OptLevel == 0) - FPM.add(createLowerMatrixIntrinsicsMinimalPass()); - if (OptLevel == 0) return; addInitialAliasAnalysisPasses(FPM); @@ -309,21 +100,6 @@ void PassManagerBuilder::addFunctionSimplificationPasses( assert(OptLevel >= 1 && "Calling function optimizer with no optimization level!"); MPM.add(createSROAPass()); MPM.add(createEarlyCSEPass(true /* Enable mem-ssa. 
*/)); // Catch trivial redundancies - if (EnableKnowledgeRetention) - MPM.add(createAssumeSimplifyPass()); - - if (OptLevel > 1) { - if (EnableGVNHoist) - MPM.add(createGVNHoistPass()); - if (EnableGVNSink) { - MPM.add(createGVNSinkPass()); - MPM.add(createCFGSimplificationPass( - SimplifyCFGOptions().convertSwitchRangeToICmp(true))); - } - } - - if (EnableConstraintElimination) - MPM.add(createConstraintEliminationPass()); if (OptLevel > 1) { // Speculative execution if the target has divergent branches; otherwise nop. @@ -336,12 +112,9 @@ void PassManagerBuilder::addFunctionSimplificationPasses( createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp( true))); // Merge & remove BBs // Combine silly seq's - if (OptLevel > 2) - MPM.add(createAggressiveInstCombinerPass()); MPM.add(createInstructionCombiningPass()); - if (SizeLevel == 0 && !DisableLibCallsShrinkWrap) + if (SizeLevel == 0) MPM.add(createLibCallsShrinkWrapPass()); - addExtensionsToPM(EP_Peephole, MPM); // TODO: Investigate the cost/benefit of tail call elimination on debugging. if (OptLevel > 1) @@ -351,11 +124,6 @@ void PassManagerBuilder::addFunctionSimplificationPasses( true))); // Merge & remove BBs MPM.add(createReassociatePass()); // Reassociate expressions - // The matrix extension can introduce large vector operations early, which can - // benefit from running vector-combine early on. - if (EnableMatrix) - MPM.add(createVectorCombinePass()); - // Begin the loop pass pipeline. // The simple loop unswitch pass relies on separate cleanup passes. Schedule @@ -385,22 +153,13 @@ void PassManagerBuilder::addFunctionSimplificationPasses( SimplifyCFGOptions().convertSwitchRangeToICmp(true))); MPM.add(createInstructionCombiningPass()); // We resume loop passes creating a second loop pipeline here. - if (EnableLoopFlatten) { - MPM.add(createLoopFlattenPass()); // Flatten loops - MPM.add(createLoopSimplifyCFGPass()); - } MPM.add(createLoopIdiomPass()); // Recognize idioms like memset. MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars - addExtensionsToPM(EP_LateLoopOptimizations, MPM); MPM.add(createLoopDeletionPass()); // Delete dead loops - if (EnableLoopInterchange) - MPM.add(createLoopInterchangePass()); // Interchange loops - // Unroll small loops and perform peeling. MPM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, ForgetAllSCEVInLoopUnroll)); - addExtensionsToPM(EP_LoopOptimizerEnd, MPM); // This ends the loop pass pipelines. // Break up allocas that may now be splittable after loop unrolling. @@ -408,14 +167,10 @@ void PassManagerBuilder::addFunctionSimplificationPasses( if (OptLevel > 1) { MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds - MPM.add(NewGVN ? createNewGVNPass() - : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies + MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies } MPM.add(createSCCPPass()); // Constant prop with SCCP - if (EnableConstraintElimination) - MPM.add(createConstraintEliminationPass()); - // Delete dead bit computations (instcombine runs after to fold away the dead // computations, and then ADCE will run later to exploit any new DCE // opportunities that creates). @@ -424,11 +179,7 @@ void PassManagerBuilder::addFunctionSimplificationPasses( // Run instcombine after redundancy elimination to exploit opportunities // opened up by them. 
MPM.add(createInstructionCombiningPass()); - addExtensionsToPM(EP_Peephole, MPM); if (OptLevel > 1) { - if (EnableDFAJumpThreading && SizeLevel == 0) - MPM.add(createDFAJumpThreadingPass()); - MPM.add(createJumpThreadingPass()); // Thread jumps MPM.add(createCorrelatedValuePropagationPass()); } @@ -442,17 +193,11 @@ void PassManagerBuilder::addFunctionSimplificationPasses( /*AllowSpeculation=*/true)); } - addExtensionsToPM(EP_ScalarOptimizerLate, MPM); - - if (RerollLoops) - MPM.add(createLoopRerollPass()); - // Merge & remove BBs and sink & hoist common instructions. MPM.add(createCFGSimplificationPass( SimplifyCFGOptions().hoistCommonInsts(true).sinkCommonInsts(true))); // Clean up after everything. MPM.add(createInstructionCombiningPass()); - addExtensionsToPM(EP_Peephole, MPM); } /// FIXME: Should LTO cause any differences to this set of passes? @@ -468,9 +213,6 @@ void PassManagerBuilder::addVectorPasses(legacy::PassManagerBase &PM, // FIXME: It would be really good to use a loop-integrated instruction // combiner for cleanup here so that the unrolling and LICM can be pipelined // across the loop nests. - // We do UnrollAndJam in a separate LPM to ensure it happens before unroll - if (EnableUnrollAndJam && !DisableUnrollLoops) - PM.add(createLoopUnrollAndJamPass(OptLevel)); PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops, ForgetAllSCEVInLoopUnroll)); PM.add(createWarnMissedTransformationsPass()); @@ -484,24 +226,6 @@ void PassManagerBuilder::addVectorPasses(legacy::PassManagerBase &PM, // Cleanup after the loop optimization passes. PM.add(createInstructionCombiningPass()); - if (OptLevel > 1 && ExtraVectorizerPasses) { - // At higher optimization levels, try to clean up any runtime overlap and - // alignment checks inserted by the vectorizer. We want to track correlated - // runtime checks for two inner loops in the same outer loop, fold any - // common computations, hoist loop-invariant aspects out of any outer loop, - // and unswitch the runtime checks if possible. Once hoisted, we may have - // dead (or speculatable) control flows or more combining opportunities. - PM.add(createEarlyCSEPass()); - PM.add(createCorrelatedValuePropagationPass()); - PM.add(createInstructionCombiningPass()); - PM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - /*AllowSpeculation=*/true)); - PM.add(createSimpleLoopUnswitchLegacyPass()); - PM.add(createCFGSimplificationPass( - SimplifyCFGOptions().convertSwitchRangeToICmp(true))); - PM.add(createInstructionCombiningPass()); - } - // Now that we've formed fast to execute loop structures, we do further // optimizations. These are run afterward as they might block doing complex // analyses and transforms such as what are needed for loop vectorization. @@ -528,24 +252,14 @@ void PassManagerBuilder::addVectorPasses(legacy::PassManagerBase &PM, // Optimize parallel scalar instruction chains into SIMD instructions. if (SLPVectorize) { PM.add(createSLPVectorizerPass()); - if (OptLevel > 1 && ExtraVectorizerPasses) - PM.add(createEarlyCSEPass()); } // Enhance/cleanup vector code. PM.add(createVectorCombinePass()); if (!IsFullLTO) { - addExtensionsToPM(EP_Peephole, PM); PM.add(createInstructionCombiningPass()); - if (EnableUnrollAndJam && !DisableUnrollLoops) { - // Unroll and Jam. We do this before unroll but need to be in a separate - // loop pass manager in order for the outer loop to be processed by - // unroll and jam before the inner loop is unrolled. 
- PM.add(createLoopUnrollAndJamPass(OptLevel)); - } - // Unroll small loops PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops, ForgetAllSCEVInLoopUnroll)); @@ -595,12 +309,6 @@ void PassManagerBuilder::populateModulePassManager( // builds. The function merging pass is if (MergeFunctions) MPM.add(createMergeFunctionsPass()); - else if (GlobalExtensionsNotEmpty() || !Extensions.empty()) - MPM.add(createBarrierNoopPass()); - - addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); - - MPM.add(createAnnotationRemarksLegacyPass()); return; } @@ -613,19 +321,9 @@ void PassManagerBuilder::populateModulePassManager( // Infer attributes about declarations if possible. MPM.add(createInferFunctionAttrsLegacyPass()); - // Infer attributes on declarations, call sites, arguments, etc. - if (AttributorRun & AttributorRunOption::MODULE) - MPM.add(createAttributorLegacyPass()); - - addExtensionsToPM(EP_ModuleOptimizerEarly, MPM); - if (OptLevel > 2) MPM.add(createCallSiteSplittingPass()); - // Propage constant function arguments by specializing the functions. - if (OptLevel > 2 && EnableFunctionSpecialization) - MPM.add(createFunctionSpecializationPass()); - MPM.add(createIPSCCPPass()); // IP SCCP MPM.add(createCalledValuePropagationPass()); @@ -636,7 +334,6 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createDeadArgEliminationPass()); // Dead argument elimination MPM.add(createInstructionCombiningPass()); // Clean up after IPCP & DAE - addExtensionsToPM(EP_Peephole, MPM); MPM.add( createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp( true))); // Clean up after IPCP & DAE @@ -647,7 +344,6 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createGlobalsAAWrapperPass()); // Start of CallGraph SCC passes. - MPM.add(createPruneEHPass()); // Remove dead EH info bool RunInliner = false; if (Inliner) { MPM.add(Inliner); @@ -655,18 +351,8 @@ void PassManagerBuilder::populateModulePassManager( RunInliner = true; } - // Infer attributes on declarations, call sites, arguments, etc. for an SCC. - if (AttributorRun & AttributorRunOption::CGSCC) - MPM.add(createAttributorCGSCCLegacyPass()); - - // Try to perform OpenMP specific optimizations. This is a (quick!) no-op if - // there are no OpenMP runtime calls present in the module. - if (OptLevel > 1) - MPM.add(createOpenMPOptCGSCCLegacyPass()); - MPM.add(createPostOrderFunctionAttrsLegacyPass()); - addExtensionsToPM(EP_CGSCCOptimizerLate, MPM); addFunctionSimplificationPasses(MPM); // FIXME: This is a HACK! The inliner pass above implicitly creates a CGSCC @@ -674,9 +360,6 @@ void PassManagerBuilder::populateModulePassManager( // we must insert a no-op module pass to reset the pass manager. MPM.add(createBarrierNoopPass()); - if (RunPartialInlining) - MPM.add(createPartialInliningPass()); - if (OptLevel > 1) // Remove avail extern fns and globals definitions if we aren't // compiling an object file for later LTO. For LTO we want to preserve @@ -702,17 +385,6 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createGlobalDCEPass()); } - // Scheduling LoopVersioningLICM when inlining is over, because after that - // we may see more accurate aliasing. Reason to run this late is that too - // early versioning may prevent further inlining due to increase of code - // size. By placing it just after inlining other optimizations which runs - // later might get benefit of no-alias assumption in clone loop. 
- if (UseLoopVersioningLICM) { - MPM.add(createLoopVersioningLICMPass()); // Do LoopVersioningLICM - MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - /*AllowSpeculation=*/true)); - } - // We add a fresh GlobalsModRef run at this point. This is particularly // useful as the above will have inlined, DCE'ed, and function-attr // propagated everything. We should at this point have a reasonably minimal @@ -733,16 +405,6 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createFloat2IntPass()); MPM.add(createLowerConstantIntrinsicsPass()); - if (EnableMatrix) { - MPM.add(createLowerMatrixIntrinsicsPass()); - // CSE the pointer arithmetic of the column vectors. This allows alias - // analysis to establish no-aliasing between loads and stores of different - // columns of the same matrix. - MPM.add(createEarlyCSEPass(false)); - } - - addExtensionsToPM(EP_VectorizerStart, MPM); - // Re-rotate loops in all our loop nests. These may have fallout out of // rotated form due to GVN or other transformations, and the vectorizer relies // on the rotated form. Disable header duplication at -Oz. @@ -766,14 +428,6 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createConstantMergePass()); // Merge dup global constants } - // See comment in the new PM for justification of scheduling splitting at - // this stage (\ref buildModuleSimplificationPipeline). - if (EnableHotColdSplit) - MPM.add(createHotColdSplittingPass()); - - if (EnableIROutliner) - MPM.add(createIROutlinerPass()); - if (MergeFunctions) MPM.add(createMergeFunctionsPass()); @@ -794,10 +448,6 @@ void PassManagerBuilder::populateModulePassManager( // resulted in single-entry-single-exit or empty blocks. Clean up the CFG. MPM.add(createCFGSimplificationPass( SimplifyCFGOptions().convertSwitchRangeToICmp(true))); - - addExtensionsToPM(EP_OptimizerLast, MPM); - - MPM.add(createAnnotationRemarksLegacyPass()); } LLVMPassManagerBuilderRef LLVMPassManagerBuilderCreate() { diff --git a/llvm/lib/Transforms/IPO/PruneEH.cpp b/llvm/lib/Transforms/IPO/PruneEH.cpp deleted file mode 100644 index e0836a9fd699..000000000000 --- a/llvm/lib/Transforms/IPO/PruneEH.cpp +++ /dev/null @@ -1,261 +0,0 @@ -//===- PruneEH.cpp - Pass which deletes unused exception handlers ---------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a simple interprocedural pass which walks the -// call-graph, turning invoke instructions into calls, iff the callee cannot -// throw an exception, and marking functions 'nounwind' if they cannot throw. -// It implements this as a bottom-up traversal of the call-graph. 
-// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/CallGraphSCCPass.h" -#include "llvm/Analysis/EHPersonalities.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/InlineAsm.h" -#include "llvm/IR/Instructions.h" -#include "llvm/InitializePasses.h" -#include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/Utils/CallGraphUpdater.h" -#include "llvm/Transforms/Utils/Local.h" -#include <algorithm> - -using namespace llvm; - -#define DEBUG_TYPE "prune-eh" - -STATISTIC(NumRemoved, "Number of invokes removed"); -STATISTIC(NumUnreach, "Number of noreturn calls optimized"); - -namespace { - struct PruneEH : public CallGraphSCCPass { - static char ID; // Pass identification, replacement for typeid - PruneEH() : CallGraphSCCPass(ID) { - initializePruneEHPass(*PassRegistry::getPassRegistry()); - } - - // runOnSCC - Analyze the SCC, performing the transformation if possible. - bool runOnSCC(CallGraphSCC &SCC) override; - }; -} -static bool SimplifyFunction(Function *F, CallGraphUpdater &CGU); -static void DeleteBasicBlock(BasicBlock *BB, CallGraphUpdater &CGU); - -char PruneEH::ID = 0; -INITIALIZE_PASS_BEGIN(PruneEH, "prune-eh", - "Remove unused exception handling info", false, false) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END(PruneEH, "prune-eh", - "Remove unused exception handling info", false, false) - -Pass *llvm::createPruneEHPass() { return new PruneEH(); } - -static bool runImpl(CallGraphUpdater &CGU, SetVector<Function *> &Functions) { -#ifndef NDEBUG - for (auto *F : Functions) - assert(F && "null Function"); -#endif - bool MadeChange = false; - - // First pass, scan all of the functions in the SCC, simplifying them - // according to what we know. - for (Function *F : Functions) - MadeChange |= SimplifyFunction(F, CGU); - - // Next, check to see if any callees might throw or if there are any external - // functions in this SCC: if so, we cannot prune any functions in this SCC. - // Definitions that are weak and not declared non-throwing might be - // overridden at linktime with something that throws, so assume that. - // If this SCC includes the unwind instruction, we KNOW it throws, so - // obviously the SCC might throw. - // - bool SCCMightUnwind = false, SCCMightReturn = false; - for (Function *F : Functions) { - if (!F->hasExactDefinition()) { - SCCMightUnwind |= !F->doesNotThrow(); - SCCMightReturn |= !F->doesNotReturn(); - } else { - bool CheckUnwind = !SCCMightUnwind && !F->doesNotThrow(); - bool CheckReturn = !SCCMightReturn && !F->doesNotReturn(); - // Determine if we should scan for InlineAsm in a naked function as it - // is the only way to return without a ReturnInst. Only do this for - // no-inline functions as functions which may be inlined cannot - // meaningfully return via assembly. 
- bool CheckReturnViaAsm = CheckReturn && - F->hasFnAttribute(Attribute::Naked) && - F->hasFnAttribute(Attribute::NoInline); - - if (!CheckUnwind && !CheckReturn) - continue; - - for (const BasicBlock &BB : *F) { - const Instruction *TI = BB.getTerminator(); - if (CheckUnwind && TI->mayThrow()) { - SCCMightUnwind = true; - } else if (CheckReturn && isa<ReturnInst>(TI)) { - SCCMightReturn = true; - } - - for (const Instruction &I : BB) { - if ((!CheckUnwind || SCCMightUnwind) && - (!CheckReturnViaAsm || SCCMightReturn)) - break; - - // Check to see if this function performs an unwind or calls an - // unwinding function. - if (CheckUnwind && !SCCMightUnwind && I.mayThrow()) { - bool InstMightUnwind = true; - if (const auto *CI = dyn_cast<CallInst>(&I)) { - if (Function *Callee = CI->getCalledFunction()) { - // If the callee is outside our current SCC then we may throw - // because it might. If it is inside, do nothing. - if (Functions.contains(Callee)) - InstMightUnwind = false; - } - } - SCCMightUnwind |= InstMightUnwind; - } - if (CheckReturnViaAsm && !SCCMightReturn) - if (const auto *CB = dyn_cast<CallBase>(&I)) - if (const auto *IA = dyn_cast<InlineAsm>(CB->getCalledOperand())) - if (IA->hasSideEffects()) - SCCMightReturn = true; - } - } - if (SCCMightUnwind && SCCMightReturn) - break; - } - } - - // If the SCC doesn't unwind or doesn't throw, note this fact. - if (!SCCMightUnwind || !SCCMightReturn) - for (Function *F : Functions) { - if (!SCCMightUnwind && !F->hasFnAttribute(Attribute::NoUnwind)) { - F->addFnAttr(Attribute::NoUnwind); - MadeChange = true; - } - - if (!SCCMightReturn && !F->hasFnAttribute(Attribute::NoReturn)) { - F->addFnAttr(Attribute::NoReturn); - MadeChange = true; - } - } - - for (Function *F : Functions) { - // Convert any invoke instructions to non-throwing functions in this node - // into call instructions with a branch. This makes the exception blocks - // dead. - MadeChange |= SimplifyFunction(F, CGU); - } - - return MadeChange; -} - -bool PruneEH::runOnSCC(CallGraphSCC &SCC) { - if (skipSCC(SCC)) - return false; - SetVector<Function *> Functions; - for (auto &N : SCC) { - if (auto *F = N->getFunction()) - Functions.insert(F); - } - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - CallGraphUpdater CGU; - CGU.initialize(CG, SCC); - return runImpl(CGU, Functions); -} - - -// SimplifyFunction - Given information about callees, simplify the specified -// function if we have invokes to non-unwinding functions or code after calls to -// no-return functions. -static bool SimplifyFunction(Function *F, CallGraphUpdater &CGU) { - bool MadeChange = false; - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { - if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) - if (II->doesNotThrow() && canSimplifyInvokeNoUnwind(F)) { - BasicBlock *UnwindBlock = II->getUnwindDest(); - removeUnwindEdge(&*BB); - - // If the unwind block is now dead, nuke it. - if (pred_empty(UnwindBlock)) - DeleteBasicBlock(UnwindBlock, CGU); // Delete the new BB. - - ++NumRemoved; - MadeChange = true; - } - - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) - if (CallInst *CI = dyn_cast<CallInst>(I++)) - if (CI->doesNotReturn() && !CI->isMustTailCall() && - !isa<UnreachableInst>(I)) { - // This call calls a function that cannot return. Insert an - // unreachable instruction after it and simplify the code. Do this - // by splitting the BB, adding the unreachable, then deleting the - // new BB. 
- BasicBlock *New = BB->splitBasicBlock(I); - - // Remove the uncond branch and add an unreachable. - BB->getInstList().pop_back(); - new UnreachableInst(BB->getContext(), &*BB); - - DeleteBasicBlock(New, CGU); // Delete the new BB. - MadeChange = true; - ++NumUnreach; - break; - } - } - - return MadeChange; -} - -/// DeleteBasicBlock - remove the specified basic block from the program, -/// updating the callgraph to reflect any now-obsolete edges due to calls that -/// exist in the BB. -static void DeleteBasicBlock(BasicBlock *BB, CallGraphUpdater &CGU) { - assert(pred_empty(BB) && "BB is not dead!"); - - Instruction *TokenInst = nullptr; - - for (BasicBlock::iterator I = BB->end(), E = BB->begin(); I != E; ) { - --I; - - if (I->getType()->isTokenTy()) { - TokenInst = &*I; - break; - } - - if (auto *Call = dyn_cast<CallBase>(&*I)) { - const Function *Callee = Call->getCalledFunction(); - if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID())) - CGU.removeCallSite(*Call); - else if (!Callee->isIntrinsic()) - CGU.removeCallSite(*Call); - } - - if (!I->use_empty()) - I->replaceAllUsesWith(PoisonValue::get(I->getType())); - } - - if (TokenInst) { - if (!TokenInst->isTerminator()) - changeToUnreachable(TokenInst->getNextNode()); - } else { - // Get the list of successors of this block. - std::vector<BasicBlock *> Succs(succ_begin(BB), succ_end(BB)); - - for (unsigned i = 0, e = Succs.size(); i != e; ++i) - Succs[i]->removePredecessor(BB); - - BB->eraseFromParent(); - } -} diff --git a/llvm/lib/Transforms/IPO/SCCP.cpp b/llvm/lib/Transforms/IPO/SCCP.cpp index 0453af184a72..5c1582ddfdae 100644 --- a/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/llvm/lib/Transforms/IPO/SCCP.cpp @@ -11,31 +11,394 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/SCCP.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueLattice.h" +#include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/InitializePasses.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ModRef.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/FunctionSpecialization.h" #include "llvm/Transforms/Scalar/SCCP.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SCCPSolver.h" using namespace llvm; +#define DEBUG_TYPE "sccp" + +STATISTIC(NumInstRemoved, "Number of instructions removed"); +STATISTIC(NumArgsElimed ,"Number of arguments constant propagated"); +STATISTIC(NumGlobalConst, "Number of globals found to be constant"); +STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable"); +STATISTIC(NumInstReplaced, + "Number of instructions replaced with (simpler) instruction"); + +static cl::opt<unsigned> FuncSpecializationMaxIters( + "func-specialization-max-iters", cl::init(1), cl::Hidden, cl::desc( + "The maximum number of iterations function specialization is run")); + +static void findReturnsToZap(Function &F, + SmallVector<ReturnInst *, 8> &ReturnsToZap, + SCCPSolver &Solver) { + // We can only do this if we know that nothing else can call the function. 
+ if (!Solver.isArgumentTrackedFunction(&F)) + return; + + if (Solver.mustPreserveReturn(&F)) { + LLVM_DEBUG( + dbgs() + << "Can't zap returns of the function : " << F.getName() + << " due to present musttail or \"clang.arc.attachedcall\" call of " + "it\n"); + return; + } + + assert( + all_of(F.users(), + [&Solver](User *U) { + if (isa<Instruction>(U) && + !Solver.isBlockExecutable(cast<Instruction>(U)->getParent())) + return true; + // Non-callsite uses are not impacted by zapping. Also, constant + // uses (like blockaddresses) could stuck around, without being + // used in the underlying IR, meaning we do not have lattice + // values for them. + if (!isa<CallBase>(U)) + return true; + if (U->getType()->isStructTy()) { + return all_of(Solver.getStructLatticeValueFor(U), + [](const ValueLatticeElement &LV) { + return !SCCPSolver::isOverdefined(LV); + }); + } + + // We don't consider assume-like intrinsics to be actual address + // captures. + if (auto *II = dyn_cast<IntrinsicInst>(U)) { + if (II->isAssumeLikeIntrinsic()) + return true; + } + + return !SCCPSolver::isOverdefined(Solver.getLatticeValueFor(U)); + }) && + "We can only zap functions where all live users have a concrete value"); + + for (BasicBlock &BB : F) { + if (CallInst *CI = BB.getTerminatingMustTailCall()) { + LLVM_DEBUG(dbgs() << "Can't zap return of the block due to present " + << "musttail call : " << *CI << "\n"); + (void)CI; + return; + } + + if (auto *RI = dyn_cast<ReturnInst>(BB.getTerminator())) + if (!isa<UndefValue>(RI->getOperand(0))) + ReturnsToZap.push_back(RI); + } +} + +static bool runIPSCCP( + Module &M, const DataLayout &DL, FunctionAnalysisManager *FAM, + std::function<const TargetLibraryInfo &(Function &)> GetTLI, + std::function<TargetTransformInfo &(Function &)> GetTTI, + std::function<AssumptionCache &(Function &)> GetAC, + function_ref<AnalysisResultsForFn(Function &)> getAnalysis, + bool IsFuncSpecEnabled) { + SCCPSolver Solver(DL, GetTLI, M.getContext()); + FunctionSpecializer Specializer(Solver, M, FAM, GetTLI, GetTTI, GetAC); + + // Loop over all functions, marking arguments to those with their addresses + // taken or that are external as overdefined. + for (Function &F : M) { + if (F.isDeclaration()) + continue; + + Solver.addAnalysis(F, getAnalysis(F)); + + // Determine if we can track the function's return values. If so, add the + // function to the solver's set of return-tracked functions. + if (canTrackReturnsInterprocedurally(&F)) + Solver.addTrackedFunction(&F); + + // Determine if we can track the function's arguments. If so, add the + // function to the solver's set of argument-tracked functions. + if (canTrackArgumentsInterprocedurally(&F)) { + Solver.addArgumentTrackedFunction(&F); + continue; + } + + // Assume the function is called. + Solver.markBlockExecutable(&F.front()); + + // Assume nothing about the incoming arguments. + for (Argument &AI : F.args()) + Solver.markOverdefined(&AI); + } + + // Determine if we can track any of the module's global variables. If so, add + // the global variables we can track to the solver's set of tracked global + // variables. + for (GlobalVariable &G : M.globals()) { + G.removeDeadConstantUsers(); + if (canTrackGlobalVariableInterprocedurally(&G)) + Solver.trackValueOfGlobalVariable(&G); + } + + // Solve for constants. 
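The solver setup above (markOverdefined on the arguments of functions whose callers cannot all be seen, the canTrack* queries, and the isOverdefined checks in findReturnsToZap) all operates on the solver's value lattice. As a conceptual aid only, and not the ValueLatticeElement or SCCPSolver implementation, here is a self-contained model of that lattice and its merge rule:

// Conceptual three-level SCCP lattice: values start Unknown, may become a
// single Constant, and collapse to Overdefined once conflicting facts meet.
enum class LatticeKind { Unknown, Constant, Overdefined };

struct LatticeValue {
  LatticeKind Kind = LatticeKind::Unknown;
  long Val = 0; // only meaningful when Kind == Constant

  // Merge an incoming state; returns true if this value changed (the solver
  // only needs to revisit users when something changed).
  bool merge(const LatticeValue &In) {
    if (In.Kind == LatticeKind::Unknown || Kind == LatticeKind::Overdefined)
      return false;
    if (Kind == LatticeKind::Unknown) {
      *this = In;
      return true;
    }
    // Both sides are known: a conflicting constant (or an Overdefined input)
    // forces Overdefined, which is what markOverdefined pins up front for
    // arguments that cannot be tracked interprocedurally.
    if (In.Kind == LatticeKind::Overdefined || In.Val != Val) {
      Kind = LatticeKind::Overdefined;
      return true;
    }
    return false;
  }
};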
+ Solver.solveWhileResolvedUndefsIn(M); + + if (IsFuncSpecEnabled) { + unsigned Iters = 0; + while (Iters++ < FuncSpecializationMaxIters && Specializer.run()); + } + + // Iterate over all of the instructions in the module, replacing them with + // constants if we have found them to be of constant values. + bool MadeChanges = false; + for (Function &F : M) { + if (F.isDeclaration()) + continue; + + SmallVector<BasicBlock *, 512> BlocksToErase; + + if (Solver.isBlockExecutable(&F.front())) { + bool ReplacedPointerArg = false; + for (Argument &Arg : F.args()) { + if (!Arg.use_empty() && Solver.tryToReplaceWithConstant(&Arg)) { + ReplacedPointerArg |= Arg.getType()->isPointerTy(); + ++NumArgsElimed; + } + } + + // If we replaced an argument, we may now also access a global (currently + // classified as "other" memory). Update memory attribute to reflect this. + if (ReplacedPointerArg) { + auto UpdateAttrs = [&](AttributeList AL) { + MemoryEffects ME = AL.getMemoryEffects(); + if (ME == MemoryEffects::unknown()) + return AL; + + ME |= MemoryEffects(MemoryEffects::Other, + ME.getModRef(MemoryEffects::ArgMem)); + return AL.addFnAttribute( + F.getContext(), + Attribute::getWithMemoryEffects(F.getContext(), ME)); + }; + + F.setAttributes(UpdateAttrs(F.getAttributes())); + for (User *U : F.users()) { + auto *CB = dyn_cast<CallBase>(U); + if (!CB || CB->getCalledFunction() != &F) + continue; + + CB->setAttributes(UpdateAttrs(CB->getAttributes())); + } + } + MadeChanges |= ReplacedPointerArg; + } + + SmallPtrSet<Value *, 32> InsertedValues; + for (BasicBlock &BB : F) { + if (!Solver.isBlockExecutable(&BB)) { + LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB); + ++NumDeadBlocks; + + MadeChanges = true; + + if (&BB != &F.front()) + BlocksToErase.push_back(&BB); + continue; + } + + MadeChanges |= Solver.simplifyInstsInBlock( + BB, InsertedValues, NumInstRemoved, NumInstReplaced); + } + + DomTreeUpdater DTU = IsFuncSpecEnabled && Specializer.isClonedFunction(&F) + ? DomTreeUpdater(DomTreeUpdater::UpdateStrategy::Lazy) + : Solver.getDTU(F); + + // Change dead blocks to unreachable. We do it after replacing constants + // in all executable blocks, because changeToUnreachable may remove PHI + // nodes in executable blocks we found values for. The function's entry + // block is not part of BlocksToErase, so we have to handle it separately. + for (BasicBlock *BB : BlocksToErase) { + NumInstRemoved += changeToUnreachable(BB->getFirstNonPHI(), + /*PreserveLCSSA=*/false, &DTU); + } + if (!Solver.isBlockExecutable(&F.front())) + NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHI(), + /*PreserveLCSSA=*/false, &DTU); + + BasicBlock *NewUnreachableBB = nullptr; + for (BasicBlock &BB : F) + MadeChanges |= Solver.removeNonFeasibleEdges(&BB, DTU, NewUnreachableBB); + + for (BasicBlock *DeadBB : BlocksToErase) + if (!DeadBB->hasAddressTaken()) + DTU.deleteBB(DeadBB); + + for (BasicBlock &BB : F) { + for (Instruction &Inst : llvm::make_early_inc_range(BB)) { + if (Solver.getPredicateInfoFor(&Inst)) { + if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { + if (II->getIntrinsicID() == Intrinsic::ssa_copy) { + Value *Op = II->getOperand(0); + Inst.replaceAllUsesWith(Op); + Inst.eraseFromParent(); + } + } + } + } + } + } + + // If we inferred constant or undef return values for a function, we replaced + // all call uses with the inferred value. This means we don't need to bother + // actually returning anything from the function. Replace all return + // instructions with return undef. 
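The cleanup loop above erases llvm.ssa.copy intrinsics while iterating over each block, which is only safe because make_early_inc_range advances the iterator before the current instruction can be deleted. A standalone sketch of the idiom; the wrapper function is illustrative, and the pass additionally checks getPredicateInfoFor before touching an instruction:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IntrinsicInst.h"
using namespace llvm;

static void stripSSACopies(BasicBlock &BB) {
  for (Instruction &Inst : make_early_inc_range(BB)) {
    auto *II = dyn_cast<IntrinsicInst>(&Inst);
    if (!II || II->getIntrinsicID() != Intrinsic::ssa_copy)
      continue;
    // Forward the copied value to its users, then drop the intrinsic. The
    // early-inc iterator has already advanced, so erasing here is safe.
    II->replaceAllUsesWith(II->getOperand(0));
    II->eraseFromParent();
  }
}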
+ // + // Do this in two stages: first identify the functions we should process, then + // actually zap their returns. This is important because we can only do this + // if the address of the function isn't taken. In cases where a return is the + // last use of a function, the order of processing functions would affect + // whether other functions are optimizable. + SmallVector<ReturnInst*, 8> ReturnsToZap; + + for (const auto &I : Solver.getTrackedRetVals()) { + Function *F = I.first; + const ValueLatticeElement &ReturnValue = I.second; + + // If there is a known constant range for the return value, add !range + // metadata to the function's call sites. + if (ReturnValue.isConstantRange() && + !ReturnValue.getConstantRange().isSingleElement()) { + // Do not add range metadata if the return value may include undef. + if (ReturnValue.isConstantRangeIncludingUndef()) + continue; + + auto &CR = ReturnValue.getConstantRange(); + for (User *User : F->users()) { + auto *CB = dyn_cast<CallBase>(User); + if (!CB || CB->getCalledFunction() != F) + continue; + + // Limit to cases where the return value is guaranteed to be neither + // poison nor undef. Poison will be outside any range and currently + // values outside of the specified range cause immediate undefined + // behavior. + if (!isGuaranteedNotToBeUndefOrPoison(CB, nullptr, CB)) + continue; + + // Do not touch existing metadata for now. + // TODO: We should be able to take the intersection of the existing + // metadata and the inferred range. + if (CB->getMetadata(LLVMContext::MD_range)) + continue; + + LLVMContext &Context = CB->getParent()->getContext(); + Metadata *RangeMD[] = { + ConstantAsMetadata::get(ConstantInt::get(Context, CR.getLower())), + ConstantAsMetadata::get(ConstantInt::get(Context, CR.getUpper()))}; + CB->setMetadata(LLVMContext::MD_range, MDNode::get(Context, RangeMD)); + } + continue; + } + if (F->getReturnType()->isVoidTy()) + continue; + if (SCCPSolver::isConstant(ReturnValue) || ReturnValue.isUnknownOrUndef()) + findReturnsToZap(*F, ReturnsToZap, Solver); + } + + for (auto *F : Solver.getMRVFunctionsTracked()) { + assert(F->getReturnType()->isStructTy() && + "The return type should be a struct"); + StructType *STy = cast<StructType>(F->getReturnType()); + if (Solver.isStructLatticeConstant(F, STy)) + findReturnsToZap(*F, ReturnsToZap, Solver); + } + + // Zap all returns which we've identified as zap to change. + SmallSetVector<Function *, 8> FuncZappedReturn; + for (ReturnInst *RI : ReturnsToZap) { + Function *F = RI->getParent()->getParent(); + RI->setOperand(0, UndefValue::get(F->getReturnType())); + // Record all functions that are zapped. + FuncZappedReturn.insert(F); + } + + // Remove the returned attribute for zapped functions and the + // corresponding call sites. + for (Function *F : FuncZappedReturn) { + for (Argument &A : F->args()) + F->removeParamAttr(A.getArgNo(), Attribute::Returned); + for (Use &U : F->uses()) { + CallBase *CB = dyn_cast<CallBase>(U.getUser()); + if (!CB) { + assert(isa<BlockAddress>(U.getUser()) || + (isa<Constant>(U.getUser()) && + all_of(U.getUser()->users(), [](const User *UserUser) { + return cast<IntrinsicInst>(UserUser)->isAssumeLikeIntrinsic(); + }))); + continue; + } + + for (Use &Arg : CB->args()) + CB->removeParamAttr(CB->getArgOperandNo(&Arg), Attribute::Returned); + } + } + + // If we inferred constant or undef values for globals variables, we can + // delete the global and any stores that remain to it. 
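Several guards have to pass before the loop above attaches !range metadata to a call site. Condensed into a single predicate for readability; the helper name is illustrative and the logic mirrors the hunk:

#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

static bool canAnnotateCallRange(const ValueLatticeElement &RV, CallBase &CB) {
  // A single-element range is better served by plain constant replacement,
  // and a range that may include undef gives no usable guarantee.
  if (!RV.isConstantRange() || RV.getConstantRange().isSingleElement() ||
      RV.isConstantRangeIncludingUndef())
    return false;
  // Poison and undef results lie outside any range, so they must be ruled out.
  if (!isGuaranteedNotToBeUndefOrPoison(&CB, /*AC=*/nullptr, &CB))
    return false;
  // Existing !range metadata is kept as-is rather than intersected.
  return !CB.getMetadata(LLVMContext::MD_range);
}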
+ for (const auto &I : make_early_inc_range(Solver.getTrackedGlobals())) { + GlobalVariable *GV = I.first; + if (SCCPSolver::isOverdefined(I.second)) + continue; + LLVM_DEBUG(dbgs() << "Found that GV '" << GV->getName() + << "' is constant!\n"); + while (!GV->use_empty()) { + StoreInst *SI = cast<StoreInst>(GV->user_back()); + SI->eraseFromParent(); + MadeChanges = true; + } + M.getGlobalList().erase(GV); + ++NumGlobalConst; + } + + return MadeChanges; +} + PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { const DataLayout &DL = M.getDataLayout(); auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto GetTLI = [&FAM](Function &F) -> const TargetLibraryInfo & { return FAM.getResult<TargetLibraryAnalysis>(F); }; - auto getAnalysis = [&FAM](Function &F) -> AnalysisResultsForFn { + auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & { + return FAM.getResult<TargetIRAnalysis>(F); + }; + auto GetAC = [&FAM](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + auto getAnalysis = [&FAM, this](Function &F) -> AnalysisResultsForFn { DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); return { std::make_unique<PredicateInfo>(F, DT, FAM.getResult<AssumptionAnalysis>(F)), - &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F)}; + &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F), + isFuncSpecEnabled() ? &FAM.getResult<LoopAnalysis>(F) : nullptr }; }; - if (!runIPSCCP(M, DL, GetTLI, getAnalysis)) + if (!runIPSCCP(M, DL, &FAM, GetTLI, GetTTI, GetAC, getAnalysis, + isFuncSpecEnabled())) return PreservedAnalyses::all(); PreservedAnalyses PA; @@ -67,6 +430,12 @@ public: auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); }; + auto GetTTI = [this](Function &F) -> TargetTransformInfo & { + return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + }; + auto GetAC = [this](Function &F) -> AssumptionCache & { + return this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + }; auto getAnalysis = [this](Function &F) -> AnalysisResultsForFn { DominatorTree &DT = this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); @@ -75,17 +444,19 @@ public: F, DT, this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache( F)), - nullptr, // We cannot preserve the DT or PDT with the legacy pass - nullptr}; // manager, so set them to nullptr. + nullptr, // We cannot preserve the LI, DT or PDT with the legacy pass + nullptr, // manager, so set them to nullptr. + nullptr}; }; - return runIPSCCP(M, DL, GetTLI, getAnalysis); + return runIPSCCP(M, DL, nullptr, GetTLI, GetTTI, GetAC, getAnalysis, false); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); } }; @@ -106,93 +477,3 @@ INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", // createIPSCCPPass - This is the public interface to this file. 
ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); } -PreservedAnalyses FunctionSpecializationPass::run(Module &M, - ModuleAnalysisManager &AM) { - const DataLayout &DL = M.getDataLayout(); - auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & { - return FAM.getResult<TargetLibraryAnalysis>(F); - }; - auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & { - return FAM.getResult<TargetIRAnalysis>(F); - }; - auto GetAC = [&FAM](Function &F) -> AssumptionCache & { - return FAM.getResult<AssumptionAnalysis>(F); - }; - auto GetAnalysis = [&FAM](Function &F) -> AnalysisResultsForFn { - DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); - return {std::make_unique<PredicateInfo>( - F, DT, FAM.getResult<AssumptionAnalysis>(F)), - &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F)}; - }; - - if (!runFunctionSpecialization(M, DL, GetTLI, GetTTI, GetAC, GetAnalysis)) - return PreservedAnalyses::all(); - - PreservedAnalyses PA; - PA.preserve<DominatorTreeAnalysis>(); - PA.preserve<PostDominatorTreeAnalysis>(); - PA.preserve<FunctionAnalysisManagerModuleProxy>(); - return PA; -} - -namespace { -struct FunctionSpecializationLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid - FunctionSpecializationLegacyPass() : ModulePass(ID) {} - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - - const DataLayout &DL = M.getDataLayout(); - auto GetTLI = [this](Function &F) -> TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - auto GetTTI = [this](Function &F) -> TargetTransformInfo & { - return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - }; - auto GetAC = [this](Function &F) -> AssumptionCache & { - return this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - }; - - auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn { - DominatorTree &DT = - this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); - return { - std::make_unique<PredicateInfo>( - F, DT, - this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache( - F)), - nullptr, // We cannot preserve the DT or PDT with the legacy pass - nullptr}; // manager, so set them to nullptr. 
- }; - return runFunctionSpecialization(M, DL, GetTLI, GetTTI, GetAC, GetAnalysis); - } -}; -} // namespace - -char FunctionSpecializationLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN( - FunctionSpecializationLegacyPass, "function-specialization", - "Propagate constant arguments by specializing the function", false, false) - -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(FunctionSpecializationLegacyPass, "function-specialization", - "Propagate constant arguments by specializing the function", - false, false) - -ModulePass *llvm::createFunctionSpecializationPass() { - return new FunctionSpecializationLegacyPass(); -} diff --git a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp index 764fd57d245f..3ddf5fe20edb 100644 --- a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp +++ b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp @@ -124,13 +124,15 @@ void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) { FuncSamples = FSamples; } -Optional<uint32_t> ContextTrieNode::getFunctionSize() const { return FuncSize; } +std::optional<uint32_t> ContextTrieNode::getFunctionSize() const { + return FuncSize; +} void ContextTrieNode::addFunctionSize(uint32_t FSize) { if (!FuncSize) FuncSize = 0; - FuncSize = FuncSize.value() + FSize; + FuncSize = *FuncSize + FSize; } LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; } @@ -534,7 +536,7 @@ SampleContextTracker::getOrCreateContextPath(const SampleContext &Context, ContextTrieNode *ContextNode = &RootContext; LineLocation CallSiteLoc(0, 0); - for (auto &Callsite : Context.getContextFrames()) { + for (const auto &Callsite : Context.getContextFrames()) { // Create child node at parent line/disc location if (AllowCreate) { ContextNode = diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp index f76b886e810a..93b368fd72a6 100644 --- a/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/PriorityQueue.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/SmallVector.h" @@ -74,6 +75,7 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/MisExpect.h" #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h" #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" #include <algorithm> @@ -127,6 +129,15 @@ static cl::opt<std::string> SampleProfileRemappingFile( "sample-profile-remapping-file", cl::init(""), cl::value_desc("filename"), cl::desc("Profile remapping file loaded by -sample-profile"), cl::Hidden); +static cl::opt<bool> ReportProfileStaleness( + "report-profile-staleness", cl::Hidden, cl::init(false), + cl::desc("Compute and report stale profile statistical metrics.")); + +static cl::opt<bool> PersistProfileStaleness( + "persist-profile-staleness", cl::Hidden, cl::init(false), + cl::desc("Compute stale profile statistical metrics and write it into the " + "native object file(.llvm_stats section).")); + static cl::opt<bool> ProfileSampleAccurate( "profile-sample-accurate", cl::Hidden, cl::init(false), 
cl::desc("If the sample profile is accurate, we will mark all un-sampled " @@ -362,7 +373,7 @@ private: FS->GUIDToFuncNameMap = Map; for (const auto &ICS : FS->getCallsiteSamples()) { const FunctionSamplesMap &FSMap = ICS.second; - for (auto &IFS : FSMap) { + for (const auto &IFS : FSMap) { FunctionSamples &FS = const_cast<FunctionSamples &>(IFS.second); FSToUpdate.push(&FS); } @@ -412,6 +423,30 @@ using CandidateQueue = PriorityQueue<InlineCandidate, std::vector<InlineCandidate>, CandidateComparer>; +// Sample profile matching - fuzzy match. +class SampleProfileMatcher { + Module &M; + SampleProfileReader &Reader; + const PseudoProbeManager *ProbeManager; + + // Profile mismatching statstics. + uint64_t TotalProfiledCallsites = 0; + uint64_t NumMismatchedCallsites = 0; + uint64_t MismatchedCallsiteSamples = 0; + uint64_t TotalCallsiteSamples = 0; + uint64_t TotalProfiledFunc = 0; + uint64_t NumMismatchedFuncHash = 0; + uint64_t MismatchedFuncHashSamples = 0; + uint64_t TotalFuncHashSamples = 0; + +public: + SampleProfileMatcher(Module &M, SampleProfileReader &Reader, + const PseudoProbeManager *ProbeManager) + : M(M), Reader(Reader), ProbeManager(ProbeManager) {} + void detectProfileMismatch(); + void detectProfileMismatch(const Function &F, const FunctionSamples &FS); +}; + /// Sample profile pass. /// /// This pass reads profile data from the file specified by @@ -459,7 +494,7 @@ protected: bool inlineHotFunctions(Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs); - Optional<InlineCost> getExternalInlineAdvisorCost(CallBase &CB); + std::optional<InlineCost> getExternalInlineAdvisorCost(CallBase &CB); bool getExternalInlineAdvisorShouldInline(CallBase &CB); InlineCost shouldInlineCandidate(InlineCandidate &Candidate); bool getInlineCandidate(InlineCandidate *NewCandidate, CallBase *CB); @@ -475,7 +510,7 @@ protected: const SmallVectorImpl<CallBase *> &Candidates, const Function &F, bool Hot); void promoteMergeNotInlinedContextSamples( - DenseMap<CallBase *, const FunctionSamples *> NonInlinedCallSites, + MapVector<CallBase *, const FunctionSamples *> NonInlinedCallSites, const Function &F); std::vector<Function *> buildFunctionOrder(Module &M, CallGraph *CG); std::unique_ptr<ProfiledCallGraph> buildProfiledCallGraph(CallGraph &CG); @@ -541,6 +576,9 @@ protected: // A pseudo probe helper to correlate the imported sample counts. std::unique_ptr<PseudoProbeManager> ProbeManager; + // A helper to implement the sample profile matching algorithm. + std::unique_ptr<SampleProfileMatcher> MatchingManager; + private: const char *getAnnotatedRemarkPassName() const { return AnnotatedPassName.c_str(); @@ -582,7 +620,7 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { ErrorOr<uint64_t> SampleProfileLoader::getProbeWeight(const Instruction &Inst) { assert(FunctionSamples::ProfileIsProbeBased && "Profile is not pseudo probe based"); - Optional<PseudoProbe> Probe = extractProbe(Inst); + std::optional<PseudoProbe> Probe = extractProbe(Inst); // Ignore the non-probe instruction. If none of the instruction in the BB is // probe, we choose to infer the BB's weight. 
if (!Probe) @@ -735,7 +773,7 @@ SampleProfileLoader::findIndirectCallFunctionSamples( const FunctionSamples * SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { if (FunctionSamples::ProfileIsProbeBased) { - Optional<PseudoProbe> Probe = extractProbe(Inst); + std::optional<PseudoProbe> Probe = extractProbe(Inst); if (!Probe) return nullptr; } @@ -984,7 +1022,7 @@ bool SampleProfileLoader::shouldInlineColdCallee(CallBase &CallInst) { void SampleProfileLoader::emitOptimizationRemarksForInlineCandidates( const SmallVectorImpl<CallBase *> &Candidates, const Function &F, bool Hot) { - for (auto I : Candidates) { + for (auto *I : Candidates) { Function *CalledFunction = I->getCalledFunction(); if (CalledFunction) { ORE->emit(OptimizationRemarkAnalysis(getAnnotatedRemarkPassName(), @@ -1106,7 +1144,7 @@ bool SampleProfileLoader::inlineHotFunctions( "ProfAccForSymsInList should be false when profile-sample-accurate " "is enabled"); - DenseMap<CallBase *, const FunctionSamples *> LocalNotInlinedCallSites; + MapVector<CallBase *, const FunctionSamples *> LocalNotInlinedCallSites; bool Changed = false; bool LocalChanged = true; while (LocalChanged) { @@ -1116,7 +1154,7 @@ bool SampleProfileLoader::inlineHotFunctions( bool Hot = false; SmallVector<CallBase *, 10> AllCandidates; SmallVector<CallBase *, 10> ColdCandidates; - for (auto &I : BB.getInstList()) { + for (auto &I : BB) { const FunctionSamples *FS = nullptr; if (auto *CB = dyn_cast<CallBase>(&I)) { if (!isa<IntrinsicInst>(I)) { @@ -1126,7 +1164,7 @@ bool SampleProfileLoader::inlineHotFunctions( AllCandidates.push_back(CB); if (FS->getHeadSamplesEstimate() > 0 || FunctionSamples::ProfileIsCS) - LocalNotInlinedCallSites.try_emplace(CB, FS); + LocalNotInlinedCallSites.insert({CB, FS}); if (callsiteIsHot(FS, PSI, ProfAccForSymsInList)) Hot = true; else if (shouldInlineColdCallee(*CB)) @@ -1219,13 +1257,11 @@ bool SampleProfileLoader::tryInlineCandidate( InlineFunctionInfo IFI(nullptr, GetAC); IFI.UpdateProfile = false; - if (!InlineFunction(CB, IFI).isSuccess()) + InlineResult IR = InlineFunction(CB, IFI, + /*MergeAttributes=*/true); + if (!IR.isSuccess()) return false; - // Merge the attributes based on the inlining. - AttributeFuncs::mergeAttributesForInlining(*BB->getParent(), - *CalledFunction); - // The call to InlineFunction erases I, so we can't pass it here. emitInlinedIntoBasedOnCost(*ORE, DLoc, BB, *CalledFunction, *BB->getParent(), Cost, true, getAnnotatedRemarkPassName()); @@ -1250,7 +1286,7 @@ bool SampleProfileLoader::tryInlineCandidate( // aggregation of duplication. 
if (Candidate.CallsiteDistribution < 1) { for (auto &I : IFI.InlinedCallSites) { - if (Optional<PseudoProbe> Probe = extractProbe(*I)) + if (std::optional<PseudoProbe> Probe = extractProbe(*I)) setProbeDistributionFactor(*I, Probe->Factor * Candidate.CallsiteDistribution); } @@ -1275,7 +1311,7 @@ bool SampleProfileLoader::getInlineCandidate(InlineCandidate *NewCandidate, return false; float Factor = 1.0; - if (Optional<PseudoProbe> Probe = extractProbe(*CB)) + if (std::optional<PseudoProbe> Probe = extractProbe(*CB)) Factor = Probe->Factor; uint64_t CallsiteCount = @@ -1284,7 +1320,7 @@ bool SampleProfileLoader::getInlineCandidate(InlineCandidate *NewCandidate, return true; } -Optional<InlineCost> +std::optional<InlineCost> SampleProfileLoader::getExternalInlineAdvisorCost(CallBase &CB) { std::unique_ptr<InlineAdvice> Advice = nullptr; if (ExternalInlineAdvisor) { @@ -1303,15 +1339,15 @@ SampleProfileLoader::getExternalInlineAdvisorCost(CallBase &CB) { } bool SampleProfileLoader::getExternalInlineAdvisorShouldInline(CallBase &CB) { - Optional<InlineCost> Cost = getExternalInlineAdvisorCost(CB); - return Cost ? !!Cost.value() : false; + std::optional<InlineCost> Cost = getExternalInlineAdvisorCost(CB); + return Cost ? !!*Cost : false; } InlineCost SampleProfileLoader::shouldInlineCandidate(InlineCandidate &Candidate) { - if (Optional<InlineCost> ReplayCost = + if (std::optional<InlineCost> ReplayCost = getExternalInlineAdvisorCost(*Candidate.CallInstr)) - return ReplayCost.value(); + return *ReplayCost; // Adjust threshold based on call site hotness, only do this for callsite // prioritized inliner because otherwise cost-benefit check is done earlier. int SampleThreshold = SampleColdCallSiteThreshold; @@ -1387,7 +1423,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( CandidateQueue CQueue; InlineCandidate NewCandidate; for (auto &BB : F) { - for (auto &I : BB.getInstList()) { + for (auto &I : BB) { auto *CB = dyn_cast<CallBase>(&I); if (!CB) continue; @@ -1409,7 +1445,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( if (ExternalInlineAdvisor) SizeLimit = std::numeric_limits<unsigned>::max(); - DenseMap<CallBase *, const FunctionSamples *> LocalNotInlinedCallSites; + MapVector<CallBase *, const FunctionSamples *> LocalNotInlinedCallSites; // Perform iterative BFS call site prioritized inlining bool Changed = false; @@ -1466,7 +1502,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( ICPCount++; Changed = true; } else if (!ContextTracker) { - LocalNotInlinedCallSites.try_emplace(I, FS); + LocalNotInlinedCallSites.insert({I, FS}); } } } else if (CalledFunction && CalledFunction->getSubprogram() && @@ -1479,7 +1515,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( } Changed = true; } else if (!ContextTracker) { - LocalNotInlinedCallSites.try_emplace(I, Candidate.CalleeSamples); + LocalNotInlinedCallSites.insert({I, Candidate.CalleeSamples}); } } else if (LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) { findExternalInlineCandidate(I, findCalleeFunctionSamples(*I), @@ -1505,11 +1541,11 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( } void SampleProfileLoader::promoteMergeNotInlinedContextSamples( - DenseMap<CallBase *, const FunctionSamples *> NonInlinedCallSites, + MapVector<CallBase *, const FunctionSamples *> NonInlinedCallSites, const Function &F) { // Accumulate not inlined callsite information into notInlinedSamples for (const auto &Pair : NonInlinedCallSites) { - CallBase *I = Pair.getFirst(); + CallBase *I = Pair.first; 
Function *Callee = I->getCalledFunction(); if (!Callee || Callee->isDeclaration()) continue; @@ -1521,7 +1557,7 @@ void SampleProfileLoader::promoteMergeNotInlinedContextSamples( << "' into '" << ore::NV("Caller", &F) << "'"); ++NumCSNotInlined; - const FunctionSamples *FS = Pair.getSecond(); + const FunctionSamples *FS = Pair.second; if (FS->getTotalSamples() == 0 && FS->getHeadSamplesEstimate() == 0) { continue; } @@ -1581,7 +1617,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { BasicBlock *BB = &BI; if (BlockWeights[BB]) { - for (auto &I : BB->getInstList()) { + for (auto &I : *BB) { if (!isa<CallInst>(I) && !isa<InvokeInst>(I)) continue; if (!cast<CallBase>(I).getCalledFunction()) { @@ -1600,7 +1636,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { // Prorate the callsite counts based on the pre-ICP distribution // factor to reflect what is already done to the callsite before // ICP, such as calliste cloning. - if (Optional<PseudoProbe> Probe = extractProbe(I)) { + if (std::optional<PseudoProbe> Probe = extractProbe(I)) { if (Probe->Factor < 1) T = SampleRecord::adjustCallTargets(T.get(), Probe->Factor); } @@ -1633,7 +1669,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) { // Set profile metadata (possibly annotated by LTO prelink) to zero or // clear it for cold code. - for (auto &I : BB->getInstList()) { + for (auto &I : *BB) { if (isa<CallInst>(I) || isa<InvokeInst>(I)) { if (cast<CallBase>(I).isIndirectCall()) I.setMetadata(LLVMContext::MD_prof, nullptr); @@ -1704,10 +1740,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { } } - // FIXME: Re-enable for sample profiling after investigating why the sum - // of branch weights can be 0 - // - // misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false); + misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false); uint64_t TempWeight; // Only set weights if there is at least one non-zero weight. @@ -2013,9 +2046,156 @@ bool SampleProfileLoader::doInitialization(Module &M, } } + if (ReportProfileStaleness || PersistProfileStaleness) { + MatchingManager = + std::make_unique<SampleProfileMatcher>(M, *Reader, ProbeManager.get()); + } + return true; } +void SampleProfileMatcher::detectProfileMismatch(const Function &F, + const FunctionSamples &FS) { + if (FunctionSamples::ProfileIsProbeBased) { + uint64_t Count = FS.getTotalSamples(); + TotalFuncHashSamples += Count; + TotalProfiledFunc++; + if (!ProbeManager->profileIsValid(F, FS)) { + MismatchedFuncHashSamples += Count; + NumMismatchedFuncHash++; + return; + } + } + + std::unordered_set<LineLocation, LineLocationHash> MatchedCallsiteLocs; + + // Go through all the callsites on the IR and flag the callsite if the target + // name is the same as the one in the profile. + for (auto &BB : F) { + for (auto &I : BB) { + if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I)) + continue; + + const auto *CB = dyn_cast<CallBase>(&I); + if (auto &DLoc = I.getDebugLoc()) { + LineLocation IRCallsite = FunctionSamples::getCallSiteIdentifier(DLoc); + + StringRef CalleeName; + if (Function *Callee = CB->getCalledFunction()) + CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName()); + + const auto CTM = FS.findCallTargetMapAt(IRCallsite); + const auto CallsiteFS = FS.findFunctionSamplesMapAt(IRCallsite); + + // Indirect call case. 
+ if (CalleeName.empty()) { + // Since indirect call does not have the CalleeName, check + // conservatively if callsite in the profile is a callsite location. + // This is to avoid nums of false positive since otherwise all the + // indirect call samples will be reported as mismatching. + if ((CTM && !CTM->empty()) || (CallsiteFS && !CallsiteFS->empty())) + MatchedCallsiteLocs.insert(IRCallsite); + } else { + // Check if the call target name is matched for direct call case. + if ((CTM && CTM->count(CalleeName)) || + (CallsiteFS && CallsiteFS->count(CalleeName))) + MatchedCallsiteLocs.insert(IRCallsite); + } + } + } + } + + auto isInvalidLineOffset = [](uint32_t LineOffset) { + return LineOffset & 0x8000; + }; + + // Check if there are any callsites in the profile that does not match to any + // IR callsites, those callsite samples will be discarded. + for (auto &I : FS.getBodySamples()) { + const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; + + uint64_t Count = I.second.getSamples(); + if (!I.second.getCallTargets().empty()) { + TotalCallsiteSamples += Count; + TotalProfiledCallsites++; + if (!MatchedCallsiteLocs.count(Loc)) { + MismatchedCallsiteSamples += Count; + NumMismatchedCallsites++; + } + } + } + + for (auto &I : FS.getCallsiteSamples()) { + const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; + + uint64_t Count = 0; + for (auto &FM : I.second) { + Count += FM.second.getHeadSamplesEstimate(); + } + TotalCallsiteSamples += Count; + TotalProfiledCallsites++; + if (!MatchedCallsiteLocs.count(Loc)) { + MismatchedCallsiteSamples += Count; + NumMismatchedCallsites++; + } + } +} + +void SampleProfileMatcher::detectProfileMismatch() { + for (auto &F : M) { + if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) + continue; + FunctionSamples *FS = Reader.getSamplesFor(F); + if (!FS) + continue; + detectProfileMismatch(F, *FS); + } + + if (ReportProfileStaleness) { + if (FunctionSamples::ProfileIsProbeBased) { + errs() << "(" << NumMismatchedFuncHash << "/" << TotalProfiledFunc << ")" + << " of functions' profile are invalid and " + << " (" << MismatchedFuncHashSamples << "/" << TotalFuncHashSamples + << ")" + << " of samples are discarded due to function hash mismatch.\n"; + } + errs() << "(" << NumMismatchedCallsites << "/" << TotalProfiledCallsites + << ")" + << " of callsites' profile are invalid and " + << "(" << MismatchedCallsiteSamples << "/" << TotalCallsiteSamples + << ")" + << " of samples are discarded due to callsite location mismatch.\n"; + } + + if (PersistProfileStaleness) { + LLVMContext &Ctx = M.getContext(); + MDBuilder MDB(Ctx); + + SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec; + if (FunctionSamples::ProfileIsProbeBased) { + ProfStatsVec.emplace_back("NumMismatchedFuncHash", NumMismatchedFuncHash); + ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc); + ProfStatsVec.emplace_back("MismatchedFuncHashSamples", + MismatchedFuncHashSamples); + ProfStatsVec.emplace_back("TotalFuncHashSamples", TotalFuncHashSamples); + } + + ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites); + ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites); + ProfStatsVec.emplace_back("MismatchedCallsiteSamples", + MismatchedCallsiteSamples); + ProfStatsVec.emplace_back("TotalCallsiteSamples", TotalCallsiteSamples); + + auto *MD = MDB.createLLVMStats(ProfStatsVec); + auto *NMD = M.getOrInsertNamedMetadata("llvm.stats"); + NMD->addOperand(MD); 
+ } +} + bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, ProfileSummaryInfo *_PSI, CallGraph *CG) { GUIDToFuncNameMapper Mapper(M, *Reader, GUIDToFuncNameMap); @@ -2060,8 +2240,11 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, assert(SymbolMap.count(StringRef()) == 0 && "No empty StringRef should be added in SymbolMap"); + if (ReportProfileStaleness || PersistProfileStaleness) + MatchingManager->detectProfileMismatch(); + bool retval = false; - for (auto F : buildFunctionOrder(M, CG)) { + for (auto *F : buildFunctionOrder(M, CG)) { assert(!F->isDeclaration()); clearFunctionData(); retval |= runOnFunction(*F, AM); diff --git a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp index d1ab2649ee2e..c4844dbe7f3c 100644 --- a/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp +++ b/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp @@ -98,14 +98,14 @@ void PseudoProbeVerifier::runAfterPass(StringRef PassID, Any IR) { std::string Banner = "\n*** Pseudo Probe Verification After " + PassID.str() + " ***\n"; dbgs() << Banner; - if (any_isa<const Module *>(IR)) - runAfterPass(any_cast<const Module *>(IR)); - else if (any_isa<const Function *>(IR)) - runAfterPass(any_cast<const Function *>(IR)); - else if (any_isa<const LazyCallGraph::SCC *>(IR)) - runAfterPass(any_cast<const LazyCallGraph::SCC *>(IR)); - else if (any_isa<const Loop *>(IR)) - runAfterPass(any_cast<const Loop *>(IR)); + if (const auto **M = any_cast<const Module *>(&IR)) + runAfterPass(*M); + else if (const auto **F = any_cast<const Function *>(&IR)) + runAfterPass(*F); + else if (const auto **C = any_cast<const LazyCallGraph::SCC *>(&IR)) + runAfterPass(*C); + else if (const auto **L = any_cast<const Loop *>(&IR)) + runAfterPass(*L); else llvm_unreachable("Unknown IR unit"); } @@ -137,7 +137,7 @@ void PseudoProbeVerifier::runAfterPass(const Loop *L) { void PseudoProbeVerifier::collectProbeFactors(const BasicBlock *Block, ProbeFactorMap &ProbeFactors) { for (const auto &I : *Block) { - if (Optional<PseudoProbe> Probe = extractProbe(I)) { + if (std::optional<PseudoProbe> Probe = extractProbe(I)) { uint64_t Hash = computeCallStackHash(I); ProbeFactors[{Probe->Id, Hash}] += Probe->Factor; } @@ -421,7 +421,7 @@ void PseudoProbeUpdatePass::runOnFunction(Function &F, ProbeFactorMap ProbeFactors; for (auto &Block : F) { for (auto &I : Block) { - if (Optional<PseudoProbe> Probe = extractProbe(I)) { + if (std::optional<PseudoProbe> Probe = extractProbe(I)) { uint64_t Hash = computeCallStackHash(I); ProbeFactors[{Probe->Id, Hash}] += BBProfileCount(&Block); } @@ -431,7 +431,7 @@ void PseudoProbeUpdatePass::runOnFunction(Function &F, // Fix up over-counted probes. 
for (auto &Block : F) { for (auto &I : Block) { - if (Optional<PseudoProbe> Probe = extractProbe(I)) { + if (std::optional<PseudoProbe> Probe = extractProbe(I)) { uint64_t Hash = computeCallStackHash(I); float Sum = ProbeFactors[{Probe->Id, Hash}]; if (Sum != 0) diff --git a/llvm/lib/Transforms/IPO/StripSymbols.cpp b/llvm/lib/Transforms/IPO/StripSymbols.cpp index 9d4e9464f361..34f8c4316cca 100644 --- a/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" @@ -32,6 +33,7 @@ #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/StripSymbols.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -181,8 +183,7 @@ static void StripTypeNames(Module &M, bool PreserveDbgInfo) { TypeFinder StructTypes; StructTypes.run(M, false); - for (unsigned i = 0, e = StructTypes.size(); i != e; ++i) { - StructType *STy = StructTypes[i]; + for (StructType *STy : StructTypes) { if (STy->isLiteral() || STy->getName().empty()) continue; if (PreserveDbgInfo && STy->getName().startswith("llvm.dbg")) @@ -295,6 +296,44 @@ bool StripDebugDeclare::runOnModule(Module &M) { return stripDebugDeclareImpl(M); } +/// Collects compilation units referenced by functions or lexical scopes. +/// Accepts any DIScope and uses recursive bottom-up approach to reach either +/// DISubprogram or DILexicalBlockBase. +static void +collectCUsWithScope(const DIScope *Scope, std::set<DICompileUnit *> &LiveCUs, + SmallPtrSet<const DIScope *, 8> &VisitedScopes) { + if (!Scope) + return; + + auto InS = VisitedScopes.insert(Scope); + if (!InS.second) + return; + + if (const auto *SP = dyn_cast<DISubprogram>(Scope)) { + if (SP->getUnit()) + LiveCUs.insert(SP->getUnit()); + return; + } + if (const auto *LB = dyn_cast<DILexicalBlockBase>(Scope)) { + const DISubprogram *SP = LB->getSubprogram(); + if (SP && SP->getUnit()) + LiveCUs.insert(SP->getUnit()); + return; + } + + collectCUsWithScope(Scope->getScope(), LiveCUs, VisitedScopes); +} + +static void +collectCUsForInlinedFuncs(const DILocation *Loc, + std::set<DICompileUnit *> &LiveCUs, + SmallPtrSet<const DIScope *, 8> &VisitedScopes) { + if (!Loc || !Loc->getInlinedAt()) + return; + collectCUsWithScope(Loc->getScope(), LiveCUs, VisitedScopes); + collectCUsForInlinedFuncs(Loc->getInlinedAt(), LiveCUs, VisitedScopes); +} + static bool stripDeadDebugInfoImpl(Module &M) { bool Changed = false; @@ -322,10 +361,18 @@ static bool stripDeadDebugInfoImpl(Module &M) { } std::set<DICompileUnit *> LiveCUs; - // Any CU referenced from a subprogram is live. - for (DISubprogram *SP : F.subprograms()) { - if (SP->getUnit()) - LiveCUs.insert(SP->getUnit()); + SmallPtrSet<const DIScope *, 8> VisitedScopes; + // Any CU is live if is referenced from a subprogram metadata that is attached + // to a function defined or inlined in the module. 
+ for (const Function &Fn : M.functions()) { + collectCUsWithScope(Fn.getSubprogram(), LiveCUs, VisitedScopes); + for (const_inst_iterator I = inst_begin(&Fn), E = inst_end(&Fn); I != E; + ++I) { + if (!I->getDebugLoc()) + continue; + const DILocation *DILoc = I->getDebugLoc().get(); + collectCUsForInlinedFuncs(DILoc, LiveCUs, VisitedScopes); + } } bool HasDeadCUs = false; diff --git a/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index c7d54b8cdeb0..d46f9a6c6757 100644 --- a/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -101,7 +101,7 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, // parameter. auto GetCallSiteProfCount = [&](const CallGraphNode *, const CallGraphNode::CallRecord &Edge) { - Optional<Scaled64> Res = None; + std::optional<Scaled64> Res; if (!Edge.first) return Res; CallBase &CB = *cast<CallBase>(*Edge.first); @@ -115,7 +115,7 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0); BBCount /= EntryFreq; BBCount *= Counts[Caller]; - return Optional<Scaled64>(BBCount); + return std::optional<Scaled64>(BBCount); }; CallGraph CG(M); diff --git a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index ef7af551a328..670097010085 100644 --- a/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -153,7 +153,7 @@ void promoteTypeIds(Module &M, StringRef ModuleId) { GO.getMetadata(LLVMContext::MD_type, MDs); GO.eraseMetadata(LLVMContext::MD_type); - for (auto MD : MDs) { + for (auto *MD : MDs) { auto I = LocalToGlobal.find(MD->getOperand(1)); if (I == LocalToGlobal.end()) { GO.addMetadata(LLVMContext::MD_type, *MD); @@ -318,8 +318,8 @@ void splitAndWriteThinLTOBitcode( return; } if (!F->isDeclaration() && - computeFunctionBodyMemoryAccess(*F, AARGetter(*F)) == - FMRB_DoesNotAccessMemory) + computeFunctionBodyMemoryAccess(*F, AARGetter(*F)) + .doesNotAccessMemory()) EligibleVirtualFns.insert(F); }); } @@ -376,7 +376,7 @@ void splitAndWriteThinLTOBitcode( auto &Ctx = MergedM->getContext(); SmallVector<MDNode *, 8> CfiFunctionMDs; - for (auto V : CfiFunctions) { + for (auto *V : CfiFunctions) { Function &F = *cast<Function>(V); SmallVector<MDNode *, 2> Types; F.getMetadata(LLVMContext::MD_type, Types); @@ -398,7 +398,7 @@ void splitAndWriteThinLTOBitcode( if(!CfiFunctionMDs.empty()) { NamedMDNode *NMD = MergedM->getOrInsertNamedMetadata("cfi.functions"); - for (auto MD : CfiFunctionMDs) + for (auto *MD : CfiFunctionMDs) NMD->addOperand(MD); } @@ -423,7 +423,7 @@ void splitAndWriteThinLTOBitcode( if (!FunctionAliases.empty()) { NamedMDNode *NMD = MergedM->getOrInsertNamedMetadata("aliases"); - for (auto MD : FunctionAliases) + for (auto *MD : FunctionAliases) NMD->addOperand(MD); } @@ -439,7 +439,7 @@ void splitAndWriteThinLTOBitcode( if (!Symvers.empty()) { NamedMDNode *NMD = MergedM->getOrInsertNamedMetadata("symvers"); - for (auto MD : Symvers) + for (auto *MD : Symvers) NMD->addOperand(MD); } @@ -546,54 +546,8 @@ void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, writeThinLinkBitcodeToFile(M, *ThinLinkOS, *Index, ModHash); } -class WriteThinLTOBitcode : public ModulePass { - raw_ostream &OS; // raw_ostream to print on - // The output stream on which to emit a minimized module for use - // just in the thin link, if requested. 
- raw_ostream *ThinLinkOS = nullptr; - -public: - static char ID; // Pass identification, replacement for typeid - WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()) { - initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry()); - } - - explicit WriteThinLTOBitcode(raw_ostream &o, raw_ostream *ThinLinkOS) - : ModulePass(ID), OS(o), ThinLinkOS(ThinLinkOS) { - initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { return "ThinLTO Bitcode Writer"; } - - bool runOnModule(Module &M) override { - const ModuleSummaryIndex *Index = - &(getAnalysis<ModuleSummaryIndexWrapperPass>().getIndex()); - writeThinLTOBitcode(OS, ThinLinkOS, LegacyAARGetter(*this), M, Index); - return true; - } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<ModuleSummaryIndexWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; } // anonymous namespace -char WriteThinLTOBitcode::ID = 0; -INITIALIZE_PASS_BEGIN(WriteThinLTOBitcode, "write-thinlto-bitcode", - "Write ThinLTO Bitcode", false, true) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(ModuleSummaryIndexWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(WriteThinLTOBitcode, "write-thinlto-bitcode", - "Write ThinLTO Bitcode", false, true) - -ModulePass *llvm::createWriteThinLTOBitcodePass(raw_ostream &Str, - raw_ostream *ThinLinkOS) { - return new WriteThinLTOBitcode(Str, ThinLinkOS); -} - PreservedAnalyses llvm::ThinLTOBitcodeWriterPass::run(Module &M, ModuleAnalysisManager &AM) { FunctionAnalysisManager &FAM = diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 18efe99f7cb4..487a0a4a97f7 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -259,8 +259,7 @@ wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, if (I < B.size()) BitsUsed |= B[I]; if (BitsUsed != 0xff) - return (MinByte + I) * 8 + - countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined); + return (MinByte + I) * 8 + countTrailingZeros(uint8_t(~BitsUsed)); } } else { // Find a free (Size/8) byte region in each member of Used. @@ -387,7 +386,7 @@ bool mustBeUnreachableFunction(ValueInfo TheFnVI) { return false; } - for (auto &Summary : TheFnVI.getSummaryList()) { + for (const auto &Summary : TheFnVI.getSummaryList()) { // Conservatively returns false if any non-live functions are seen. // In general either all summaries should be live or all should be dead. if (!Summary->isLive()) @@ -814,8 +813,8 @@ void updatePublicTypeTestCalls(Module &M, for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) { auto *CI = cast<CallInst>(U.getUser()); auto *NewCI = CallInst::Create( - TypeTestFunc, {CI->getArgOperand(0), CI->getArgOperand(1)}, None, "", - CI); + TypeTestFunc, {CI->getArgOperand(0), CI->getArgOperand(1)}, + std::nullopt, "", CI); CI->replaceAllUsesWith(NewCI); CI->eraseFromParent(); } @@ -1048,7 +1047,7 @@ bool DevirtIndex::tryFindVirtualCallTargets( // conservatively return false early. 
const GlobalVarSummary *VS = nullptr; bool LocalFound = false; - for (auto &S : P.VTableVI.getSummaryList()) { + for (const auto &S : P.VTableVI.getSummaryList()) { if (GlobalValue::isLocalLinkage(S->linkage())) { if (LocalFound) return false; @@ -1278,7 +1277,7 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, // If the summary list contains multiple summaries where at least one is // a local, give up, as we won't know which (possibly promoted) name to use. - for (auto &S : TheFn.getSummaryList()) + for (const auto &S : TheFn.getSummaryList()) if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1) return false; @@ -1709,8 +1708,8 @@ bool DevirtModule::tryVirtualConstProp( // rather than using function attributes to perform local optimization. for (VirtualCallTarget &Target : TargetsForSlot) { if (Target.Fn->isDeclaration() || - computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != - FMRB_DoesNotAccessMemory || + !computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) + .doesNotAccessMemory() || Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || Target.Fn->getReturnType() != RetType) return false; @@ -1836,10 +1835,9 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { bool DevirtModule::areRemarksEnabled() { const auto &FL = M.getFunctionList(); for (const Function &Fn : FL) { - const auto &BBL = Fn.getBasicBlockList(); - if (BBL.empty()) + if (Fn.empty()) continue; - auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front()); + auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &Fn.front()); return DI.isEnabled(); } return false; @@ -1875,7 +1873,7 @@ void DevirtModule::scanTypeTestUsers( auto RemoveTypeTestAssumes = [&]() { // We no longer need the assumes or the type test. - for (auto Assume : Assumes) + for (auto *Assume : Assumes) Assume->eraseFromParent(); // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we // may use the vtable argument later. @@ -2265,10 +2263,10 @@ bool DevirtModule::run() { if (ExportSummary && isa<MDString>(S.first.TypeID)) { auto GUID = GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString()); - for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) + for (auto *FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) FS->addTypeTest(GUID); for (auto &CCS : S.second.ConstCSInfo) - for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers) + for (auto *FS : CCS.second.SummaryTypeCheckedLoadUsers) FS->addTypeTest(GUID); } } @@ -2309,8 +2307,15 @@ void DevirtIndex::run() { return; DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID; - for (auto &P : ExportSummary.typeIdCompatibleVtableMap()) { + for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) { NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first); + // Create the type id summary resolution regardlness of whether we can + // devirtualize, so that lower type tests knows the type id is used on + // a global and not Unsat. We do this here rather than in the loop over the + // CallSlots, since that handling will only see type tests that directly + // feed assumes, and we would miss any that aren't currently handled by WPD + // (such as type tests that feed assumes via phis). + ExportSummary.getOrInsertTypeIdSummary(P.first); } // Collect information from summary about which calls to try to devirtualize. 
@@ -2358,12 +2363,11 @@ void DevirtIndex::run() { std::vector<ValueInfo> TargetsForSlot; auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID); assert(TidSummary); - // Create the type id summary resolution regardlness of whether we can - // devirtualize, so that lower type tests knows the type id is used on - // a global and not Unsat. + // The type id summary would have been created while building the NameByGUID + // map earlier. WholeProgramDevirtResolution *Res = - &ExportSummary.getOrInsertTypeIdSummary(S.first.TypeID) - .WPDRes[S.first.ByteOffset]; + &ExportSummary.getTypeIdSummary(S.first.TypeID) + ->WPDRes[S.first.ByteOffset]; if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary, S.first.ByteOffset)) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 4a459ec6c550..b68efc993723 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -576,8 +576,7 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { } } - assert((NextTmpIdx <= array_lengthof(TmpResult) + 1) && - "out-of-bound access"); + assert((NextTmpIdx <= std::size(TmpResult) + 1) && "out-of-bound access"); Value *Result; if (!SimpVect.empty()) @@ -849,6 +848,7 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add, Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); + Type *Ty = Add.getType(); Constant *Op1C; if (!match(Op1, m_ImmConstant(Op1C))) return nullptr; @@ -883,7 +883,14 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { if (match(Op0, m_Not(m_Value(X)))) return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X); + // (iN X s>> (N - 1)) + 1 --> zext (X > -1) const APInt *C; + unsigned BitWidth = Ty->getScalarSizeInBits(); + if (match(Op0, m_OneUse(m_AShr(m_Value(X), + m_SpecificIntAllowUndef(BitWidth - 1)))) && + match(Op1, m_One())) + return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty); + if (!match(Op1, m_APInt(C))) return nullptr; @@ -911,7 +918,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { // Is this add the last step in a convoluted sext? // add(zext(xor i16 X, -32768), -32768) --> sext X - Type *Ty = Add.getType(); if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) && C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C) return CastInst::Create(Instruction::SExt, X, Ty); @@ -969,15 +975,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { } } - // If all bits affected by the add are included in a high-bit-mask, do the - // add before the mask op: - // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00 - if (match(Op0, m_OneUse(m_And(m_Value(X), m_APInt(C2)))) && - C2->isNegative() && C2->isShiftedMask() && *C == (*C & *C2)) { - Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(Ty, *C)); - return BinaryOperator::CreateAnd(NewAdd, ConstantInt::get(Ty, *C2)); - } - return nullptr; } @@ -1132,6 +1129,35 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) { return nullptr; } +/// Try to reduce signed division by power-of-2 to an arithmetic shift right. +static Instruction *foldAddToAshr(BinaryOperator &Add) { + // Division must be by power-of-2, but not the minimum signed value. 
+ Value *X; + const APInt *DivC; + if (!match(Add.getOperand(0), m_SDiv(m_Value(X), m_Power2(DivC))) || + DivC->isNegative()) + return nullptr; + + // Rounding is done by adding -1 if the dividend (X) is negative and has any + // low bits set. The canonical pattern for that is an "ugt" compare with SMIN: + // sext (icmp ugt (X & (DivC - 1)), SMIN) + const APInt *MaskC; + ICmpInst::Predicate Pred; + if (!match(Add.getOperand(1), + m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)), + m_SignMask()))) || + Pred != ICmpInst::ICMP_UGT) + return nullptr; + + APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits()); + if (*MaskC != (SMin | (*DivC - 1))) + return nullptr; + + // (X / DivC) + sext ((X & (SMin | (DivC - 1)) >u SMin) --> X >>s log2(DivC) + return BinaryOperator::CreateAShr( + X, ConstantInt::get(Add.getType(), DivC->exactLogBase2())); +} + Instruction *InstCombinerImpl:: canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( BinaryOperator &I) { @@ -1234,7 +1260,7 @@ Instruction *InstCombinerImpl:: } /// This is a specialization of a more general transform from -/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally +/// foldUsingDistributiveLaws. If that code can be made to work optimally /// for multi-use cases or propagating nsw/nuw, then we would not need this. static Instruction *factorizeMathWithShlOps(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { @@ -1270,6 +1296,45 @@ static Instruction *factorizeMathWithShlOps(BinaryOperator &I, return NewShl; } +/// Reduce a sequence of masked half-width multiplies to a single multiply. +/// ((XLow * YHigh) + (YLow * XHigh)) << HalfBits) + (XLow * YLow) --> X * Y +static Instruction *foldBoxMultiply(BinaryOperator &I) { + unsigned BitWidth = I.getType()->getScalarSizeInBits(); + // Skip the odd bitwidth types. 
+ if ((BitWidth & 0x1)) + return nullptr; + + unsigned HalfBits = BitWidth >> 1; + APInt HalfMask = APInt::getMaxValue(HalfBits); + + // ResLo = (CrossSum << HalfBits) + (YLo * XLo) + Value *XLo, *YLo; + Value *CrossSum; + if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)), + m_Mul(m_Value(YLo), m_Value(XLo))))) + return nullptr; + + // XLo = X & HalfMask + // YLo = Y & HalfMask + // TODO: Refactor with SimplifyDemandedBits or KnownBits known leading zeros + // to enhance robustness + Value *X, *Y; + if (!match(XLo, m_And(m_Value(X), m_SpecificInt(HalfMask))) || + !match(YLo, m_And(m_Value(Y), m_SpecificInt(HalfMask)))) + return nullptr; + + // CrossSum = (X' * (Y >> Halfbits)) + (Y' * (X >> HalfBits)) + // X' can be either X or XLo in the pattern (and the same for Y') + if (match(CrossSum, + m_c_Add(m_c_Mul(m_LShr(m_Specific(Y), m_SpecificInt(HalfBits)), + m_CombineOr(m_Specific(X), m_Specific(XLo))), + m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(HalfBits)), + m_CombineOr(m_Specific(Y), m_Specific(YLo)))))) + return BinaryOperator::CreateMul(X, Y); + + return nullptr; +} + Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Value *V = simplifyAddInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), @@ -1286,9 +1351,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return Phi; // (A*B)+(A*C) -> A*(B+C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); + if (Instruction *R = foldBoxMultiply(I)) + return R; + if (Instruction *R = factorizeMathWithShlOps(I, Builder)) return R; @@ -1376,35 +1444,17 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateAnd(A, NewMask); } + // ZExt (B - A) + ZExt(A) --> ZExt(B) + if ((match(RHS, m_ZExt(m_Value(A))) && + match(LHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))) || + (match(LHS, m_ZExt(m_Value(A))) && + match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A)))))) + return new ZExtInst(B, LHS->getType()); + // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); - // add (select X 0 (sub n A)) A --> select X A n - { - SelectInst *SI = dyn_cast<SelectInst>(LHS); - Value *A = RHS; - if (!SI) { - SI = dyn_cast<SelectInst>(RHS); - A = LHS; - } - if (SI && SI->hasOneUse()) { - Value *TV = SI->getTrueValue(); - Value *FV = SI->getFalseValue(); - Value *N; - - // Can we fold the add into the argument of the select? - // We check both true and false select arguments for a matching subtract. - if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A)))) - // Fold the add into the true select value. - return SelectInst::Create(SI->getCondition(), N, A); - - if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A)))) - // Fold the add into the false select value. 
- return SelectInst::Create(SI->getCondition(), A, N); - } - } - if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; @@ -1424,6 +1474,68 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return &I; } + // (add A (or A, -A)) --> (and (add A, -1) A) + // (add A (or -A, A)) --> (and (add A, -1) A) + // (add (or A, -A) A) --> (and (add A, -1) A) + // (add (or -A, A) A) --> (and (add A, -1) A) + if (match(&I, m_c_BinOp(m_Value(A), m_OneUse(m_c_Or(m_Neg(m_Deferred(A)), + m_Deferred(A)))))) { + Value *Add = + Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()), "", + I.hasNoUnsignedWrap(), I.hasNoSignedWrap()); + return BinaryOperator::CreateAnd(Add, A); + } + + // Canonicalize ((A & -A) - 1) --> ((A - 1) & ~A) + // Forms all commutable operations, and simplifies ctpop -> cttz folds. + if (match(&I, + m_Add(m_OneUse(m_c_And(m_Value(A), m_OneUse(m_Neg(m_Deferred(A))))), + m_AllOnes()))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(RHS->getType()); + Value *Dec = Builder.CreateAdd(A, AllOnes); + Value *Not = Builder.CreateXor(A, AllOnes); + return BinaryOperator::CreateAnd(Dec, Not); + } + + // Disguised reassociation/factorization: + // ~(A * C1) + A + // ((A * -C1) - 1) + A + // ((A * -C1) + A) - 1 + // (A * (1 - C1)) - 1 + if (match(&I, + m_c_Add(m_OneUse(m_Not(m_OneUse(m_Mul(m_Value(A), m_APInt(C1))))), + m_Deferred(A)))) { + Type *Ty = I.getType(); + Constant *NewMulC = ConstantInt::get(Ty, 1 - *C1); + Value *NewMul = Builder.CreateMul(A, NewMulC); + return BinaryOperator::CreateAdd(NewMul, ConstantInt::getAllOnesValue(Ty)); + } + + // (A * -2**C) + B --> B - (A << C) + const APInt *NegPow2C; + if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))), + m_Value(B)))) { + Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros()); + Value *Shl = Builder.CreateShl(A, ShiftAmtC); + return BinaryOperator::CreateSub(B, Shl); + } + + // Canonicalize signum variant that ends in add: + // (A s>> (BW - 1)) + (zext (A s> 0)) --> (A s>> (BW - 1)) | (zext (A != 0)) + ICmpInst::Predicate Pred; + uint64_t BitWidth = Ty->getScalarSizeInBits(); + if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowUndef(BitWidth - 1))) && + match(RHS, m_OneUse(m_ZExt( + m_OneUse(m_ICmp(Pred, m_Specific(A), m_ZeroInt()))))) && + Pred == CmpInst::ICMP_SGT) { + Value *NotZero = Builder.CreateIsNotNull(A, "isnotnull"); + Value *Zext = Builder.CreateZExt(NotZero, Ty, "isnotnull.zext"); + return BinaryOperator::CreateOr(LHS, Zext); + } + + if (Instruction *Ashr = foldAddToAshr(I)) + return Ashr; + // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. @@ -1665,6 +1777,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { return BinaryOperator::CreateFMulFMF(X, NewMulC, &I); } + // (-X - Y) + (X + Z) --> Z - Y + if (match(&I, m_c_FAdd(m_FSub(m_FNeg(m_Value(X)), m_Value(Y)), + m_c_FAdd(m_Deferred(X), m_Value(Z))))) + return BinaryOperator::CreateFSubFMF(Z, Y, &I); + if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } @@ -1879,7 +1996,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return TryToNarrowDeduceFlags(); // Should have been handled in Negator! 
// (A*B)-(A*C) -> A*(B-C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); if (I.getType()->isIntOrIntVectorTy(1)) @@ -1967,12 +2084,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { } const APInt *Op0C; - if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) { - // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known - // zero. - KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); - if ((*Op0C | RHSKnown.Zero).isAllOnes()) - return BinaryOperator::CreateXor(Op1, Op0); + if (match(Op0, m_APInt(Op0C))) { + if (Op0C->isMask()) { + // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known + // zero. + KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); + if ((*Op0C | RHSKnown.Zero).isAllOnes()) + return BinaryOperator::CreateXor(Op1, Op0); + } + + // C - ((C3 -nuw X) & C2) --> (C - (C2 & C3)) + (X & C2) when: + // (C3 - ((C2 & C3) - 1)) is pow2 + // ((C2 + C3) & ((C2 & C3) - 1)) == ((C2 & C3) - 1) + // C2 is negative pow2 || sub nuw + const APInt *C2, *C3; + BinaryOperator *InnerSub; + if (match(Op1, m_OneUse(m_And(m_BinOp(InnerSub), m_APInt(C2)))) && + match(InnerSub, m_Sub(m_APInt(C3), m_Value(X))) && + (InnerSub->hasNoUnsignedWrap() || C2->isNegatedPowerOf2())) { + APInt C2AndC3 = *C2 & *C3; + APInt C2AndC3Minus1 = C2AndC3 - 1; + APInt C2AddC3 = *C2 + *C3; + if ((*C3 - C2AndC3Minus1).isPowerOf2() && + C2AndC3Minus1.isSubsetOf(C2AddC3)) { + Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(), *C2)); + return BinaryOperator::CreateAdd( + And, ConstantInt::get(I.getType(), *Op0C - C2AndC3)); + } + } } { @@ -2165,8 +2304,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { Value *A; const APInt *ShAmt; Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && - Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && + Op1->hasNUses(2) && *ShAmt == BitWidth - 1 && match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) { // B = ashr i32 A, 31 ; smear the sign bit // sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1) @@ -2185,7 +2325,6 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { const APInt *AddC, *AndC; if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) && match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) { - unsigned BitWidth = Ty->getScalarSizeInBits(); unsigned Cttz = AddC->countTrailingZeros(); APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz)); if ((HighMask & *AndC).isZero()) @@ -2227,18 +2366,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { } // C - ctpop(X) => ctpop(~X) if C is bitwidth - if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) && + if (match(Op0, m_SpecificInt(BitWidth)) && match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))) return replaceInstUsesWith( I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()}, {Builder.CreateNot(X)})); + // Reduce multiplies for difference-of-squares by factoring: + // (X * X) - (Y * Y) --> (X + Y) * (X - Y) + if (match(Op0, m_OneUse(m_Mul(m_Value(X), m_Deferred(X)))) && + match(Op1, m_OneUse(m_Mul(m_Value(Y), m_Deferred(Y))))) { + auto *OBO0 = cast<OverflowingBinaryOperator>(Op0); + auto *OBO1 = cast<OverflowingBinaryOperator>(Op1); + bool PropagateNSW = I.hasNoSignedWrap() && OBO0->hasNoSignedWrap() && + OBO1->hasNoSignedWrap() && BitWidth > 2; + bool PropagateNUW = I.hasNoUnsignedWrap() && OBO0->hasNoUnsignedWrap() && + 
OBO1->hasNoUnsignedWrap() && BitWidth > 1; + Value *Add = Builder.CreateAdd(X, Y, "add", PropagateNUW, PropagateNSW); + Value *Sub = Builder.CreateSub(X, Y, "sub", PropagateNUW, PropagateNSW); + Value *Mul = Builder.CreateMul(Add, Sub, "", PropagateNUW, PropagateNSW); + return replaceInstUsesWith(I, Mul); + } + return TryToNarrowDeduceFlags(); } /// This eliminates floating-point negation in either 'fneg(X)' or /// 'fsub(-0.0, X)' form by combining into a constant operand. -static Instruction *foldFNegIntoConstant(Instruction &I) { +static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) { // This is limited with one-use because fneg is assumed better for // reassociation and cheaper in codegen than fmul/fdiv. // TODO: Should the m_OneUse restriction be removed? @@ -2252,28 +2407,31 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { // Fold negation into constant operand. // -(X * C) --> X * (-C) if (match(FNegOp, m_FMul(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFMulFMF(X, NegC, &I); // -(X / C) --> X / (-C) if (match(FNegOp, m_FDiv(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFDivFMF(X, NegC, &I); // -(C / X) --> (-C) / X - if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) { - Instruction *FDiv = - BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); - - // Intersect 'nsz' and 'ninf' because those special value exceptions may not - // apply to the fdiv. Everything else propagates from the fneg. - // TODO: We could propagate nsz/ninf from fdiv alone? - FastMathFlags FMF = I.getFastMathFlags(); - FastMathFlags OpFMF = FNegOp->getFastMathFlags(); - FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros()); - FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs()); - return FDiv; - } + if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) { + Instruction *FDiv = BinaryOperator::CreateFDivFMF(NegC, X, &I); + + // Intersect 'nsz' and 'ninf' because those special value exceptions may + // not apply to the fdiv. Everything else propagates from the fneg. + // TODO: We could propagate nsz/ninf from fdiv alone? 
+ FastMathFlags FMF = I.getFastMathFlags(); + FastMathFlags OpFMF = FNegOp->getFastMathFlags(); + FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros()); + FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs()); + return FDiv; + } // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: // -(X + C) --> -X + -C --> -C - X if (I.hasNoSignedZeros() && match(FNegOp, m_FAdd(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFSubFMF(NegC, X, &I); return nullptr; } @@ -2301,7 +2459,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { getSimplifyQuery().getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldFNegIntoConstant(I)) + if (Instruction *X = foldFNegIntoConstant(I, DL)) return X; Value *X, *Y; @@ -2314,18 +2472,26 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) return R; + Value *OneUse; + if (!match(Op, m_OneUse(m_Value(OneUse)))) + return nullptr; + // Try to eliminate fneg if at least 1 arm of the select is negated. Value *Cond; - if (match(Op, m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) { + if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) { // Unlike most transforms, this one is not safe to propagate nsz unless - // it is present on the original select. (We are conservatively intersecting - // the nsz flags from the select and root fneg instruction.) + // it is present on the original select. We union the flags from the select + // and fneg and then remove nsz if needed. auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) { S->copyFastMathFlags(&I); - if (auto *OldSel = dyn_cast<SelectInst>(Op)) + if (auto *OldSel = dyn_cast<SelectInst>(Op)) { + FastMathFlags FMF = I.getFastMathFlags(); + FMF |= OldSel->getFastMathFlags(); + S->setFastMathFlags(FMF); if (!OldSel->hasNoSignedZeros() && !CommonOperand && !isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition())) S->setHasNoSignedZeros(false); + } }; // -(Cond ? -P : Y) --> Cond ? P : -Y Value *P; @@ -2344,6 +2510,21 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { } } + // fneg (copysign x, y) -> copysign x, (fneg y) + if (match(OneUse, m_CopySign(m_Value(X), m_Value(Y)))) { + // The source copysign has an additional value input, so we can't propagate + // flags the copysign doesn't also have. + FastMathFlags FMF = I.getFastMathFlags(); + FMF &= cast<FPMathOperator>(OneUse)->getFastMathFlags(); + + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(FMF); + + Value *NegY = Builder.CreateFNeg(Y); + Value *NewCopySign = Builder.CreateCopySign(X, NegY); + return replaceInstUsesWith(I, NewCopySign); + } + return nullptr; } @@ -2370,7 +2551,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { if (match(&I, m_FNeg(m_Value(Op)))) return UnaryOperator::CreateFNegFMF(Op, &I); - if (Instruction *X = foldFNegIntoConstant(I)) + if (Instruction *X = foldFNegIntoConstant(I, DL)) return X; if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) @@ -2409,7 +2590,8 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. 
if (match(Op1, m_ImmConstant(C))) - return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFAddFMF(Op0, NegC, &I); // X - (-Y) --> X + Y if (match(Op1, m_FNeg(m_Value(Y)))) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 8253c575bc37..97a001b2ed32 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -233,17 +233,13 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre /// the right hand side as a pair. /// LHS and RHS are the left hand side and the right hand side ICmps and PredL /// and PredR are their predicates, respectively. -static -Optional<std::pair<unsigned, unsigned>> -getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, - Value *&D, Value *&E, ICmpInst *LHS, - ICmpInst *RHS, - ICmpInst::Predicate &PredL, - ICmpInst::Predicate &PredR) { +static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair( + Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS, + ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { // Don't allow pointers. Splat vectors are fine. if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() || !RHS->getOperand(0)->getType()->isIntOrIntVectorTy()) - return None; + return std::nullopt; // Here comes the tricky part: // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -274,7 +270,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, // Bail if LHS was a icmp that can't be decomposed into an equality. if (!ICmpInst::isEquality(PredL)) - return None; + return std::nullopt; Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); @@ -288,7 +284,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, A = R12; D = R11; } else { - return None; + return std::nullopt; } E = R2; R1 = nullptr; @@ -316,7 +312,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, // Bail if RHS was a icmp that can't be decomposed into an equality. if (!ICmpInst::isEquality(PredR)) - return None; + return std::nullopt; // Look for ANDs on the right side of the RHS icmp. 
if (!Ok) { @@ -336,7 +332,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, E = R1; Ok = true; } else { - return None; + return std::nullopt; } assert(Ok && "Failed to find AND on the right side of the RHS icmp."); @@ -358,7 +354,8 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, unsigned LeftType = getMaskedICmpType(A, B, C, PredL); unsigned RightType = getMaskedICmpType(A, D, E, PredR); - return Optional<std::pair<unsigned, unsigned>>(std::make_pair(LeftType, RightType)); + return std::optional<std::pair<unsigned, unsigned>>( + std::make_pair(LeftType, RightType)); } /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) into a single @@ -526,7 +523,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, InstCombiner::BuilderTy &Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - Optional<std::pair<unsigned, unsigned>> MaskPair = + std::optional<std::pair<unsigned, unsigned>> MaskPair = getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); if (!MaskPair) return nullptr; @@ -1016,10 +1013,10 @@ struct IntPart { }; /// Match an extraction of bits from an integer. -static Optional<IntPart> matchIntPart(Value *V) { +static std::optional<IntPart> matchIntPart(Value *V) { Value *X; if (!match(V, m_OneUse(m_Trunc(m_Value(X))))) - return None; + return std::nullopt; unsigned NumOriginalBits = X->getType()->getScalarSizeInBits(); unsigned NumExtractedBits = V->getType()->getScalarSizeInBits(); @@ -1056,10 +1053,10 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred) return nullptr; - Optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0)); - Optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1)); - Optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0)); - Optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1)); + std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0)); + std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1)); + std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0)); + std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1)); if (!L0 || !R0 || !L1 || !R1) return nullptr; @@ -1094,7 +1091,7 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, /// common operand with the constant. Callers are expected to call this with /// Cmp0/Cmp1 switched to handle logic op commutativity. static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, - bool IsAnd, + bool IsAnd, bool IsLogical, InstCombiner::BuilderTy &Builder, const SimplifyQuery &Q) { // Match an equality compare with a non-poison constant as Cmp0. @@ -1130,6 +1127,9 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, return nullptr; SubstituteCmp = Builder.CreateICmp(Pred1, Y, C); } + if (IsLogical) + return IsAnd ? Builder.CreateLogicalAnd(Cmp0, SubstituteCmp) + : Builder.CreateLogicalOr(Cmp0, SubstituteCmp); return Builder.CreateBinOp(IsAnd ? 
Instruction::And : Instruction::Or, Cmp0, SubstituteCmp); } @@ -1174,7 +1174,7 @@ Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, Type *Ty = V1->getType(); Value *NewV = V1; - Optional<ConstantRange> CR = CR1.exactUnionWith(CR2); + std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2); if (!CR) { if (!(ICmp1->hasOneUse() && ICmp2->hasOneUse()) || CR1.isWrappedSet() || CR2.isWrappedSet()) @@ -1205,6 +1205,47 @@ Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, return Builder.CreateICmp(NewPred, NewV, ConstantInt::get(Ty, NewC)); } +/// Ignore all operations which only change the sign of a value, returning the +/// underlying magnitude value. +static Value *stripSignOnlyFPOps(Value *Val) { + match(Val, m_FNeg(m_Value(Val))); + match(Val, m_FAbs(m_Value(Val))); + match(Val, m_CopySign(m_Value(Val), m_Value())); + return Val; +} + +/// Matches canonical form of isnan, fcmp ord x, 0 +static bool matchIsNotNaN(FCmpInst::Predicate P, Value *LHS, Value *RHS) { + return P == FCmpInst::FCMP_ORD && match(RHS, m_AnyZeroFP()); +} + +/// Matches fcmp u__ x, +/-inf +static bool matchUnorderedInfCompare(FCmpInst::Predicate P, Value *LHS, + Value *RHS) { + return FCmpInst::isUnordered(P) && match(RHS, m_Inf()); +} + +/// and (fcmp ord x, 0), (fcmp u* x, inf) -> fcmp o* x, inf +/// +/// Clang emits this pattern for doing an isfinite check in __builtin_isnormal. +static Value *matchIsFiniteTest(InstCombiner::BuilderTy &Builder, FCmpInst *LHS, + FCmpInst *RHS) { + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + + if (!matchIsNotNaN(PredL, LHS0, LHS1) || + !matchUnorderedInfCompare(PredR, RHS0, RHS1)) + return nullptr; + + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + FastMathFlags FMF = LHS->getFastMathFlags(); + FMF &= RHS->getFastMathFlags(); + Builder.setFastMathFlags(FMF); + + return Builder.CreateFCmp(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1); +} + Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd, bool IsLogicalSelect) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); @@ -1263,9 +1304,79 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, return Builder.CreateFCmp(PredL, LHS0, RHS0); } + if (IsAnd && stripSignOnlyFPOps(LHS0) == stripSignOnlyFPOps(RHS0)) { + // and (fcmp ord x, 0), (fcmp u* x, inf) -> fcmp o* x, inf + // and (fcmp ord x, 0), (fcmp u* fabs(x), inf) -> fcmp o* x, inf + if (Value *Left = matchIsFiniteTest(Builder, LHS, RHS)) + return Left; + if (Value *Right = matchIsFiniteTest(Builder, RHS, LHS)) + return Right; + } + return nullptr; } +/// or (is_fpclass x, mask0), (is_fpclass x, mask1) +/// -> is_fpclass x, (mask0 | mask1) +/// and (is_fpclass x, mask0), (is_fpclass x, mask1) +/// -> is_fpclass x, (mask0 & mask1) +/// xor (is_fpclass x, mask0), (is_fpclass x, mask1) +/// -> is_fpclass x, (mask0 ^ mask1) +Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO, + Value *Op0, Value *Op1) { + Value *ClassVal; + uint64_t ClassMask0, ClassMask1; + + if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( + m_Value(ClassVal), m_ConstantInt(ClassMask0)))) && + match(Op1, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( + m_Specific(ClassVal), m_ConstantInt(ClassMask1))))) { + unsigned NewClassMask; + switch (BO.getOpcode()) { + case Instruction::And: + NewClassMask = ClassMask0 & ClassMask1; + 
break; + case Instruction::Or: + NewClassMask = ClassMask0 | ClassMask1; + break; + case Instruction::Xor: + NewClassMask = ClassMask0 ^ ClassMask1; + break; + default: + llvm_unreachable("not a binary logic operator"); + } + + // TODO: Also check for special fcmps + auto *II = cast<IntrinsicInst>(Op0); + II->setArgOperand( + 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask)); + return replaceInstUsesWith(BO, II); + } + + return nullptr; +} + +/// Look for the pattern that conditionally negates a value via math operations: +/// cond.splat = sext i1 cond +/// sub = add cond.splat, x +/// xor = xor sub, cond.splat +/// and rewrite it to do the same, but via logical operations: +/// value.neg = sub 0, value +/// cond = select i1 neg, value.neg, value +Instruction *InstCombinerImpl::canonicalizeConditionalNegationViaMathToSelect( + BinaryOperator &I) { + assert(I.getOpcode() == BinaryOperator::Xor && "Only for xor!"); + Value *Cond, *X; + // As per complexity ordering, `xor` is not commutative here. + if (!match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())) || + !match(I.getOperand(1), m_SExt(m_Value(Cond))) || + !Cond->getType()->isIntOrIntVectorTy(1) || + !match(I.getOperand(0), m_c_Add(m_SExt(m_Deferred(Cond)), m_Value(X)))) + return nullptr; + return SelectInst::Create(Cond, Builder.CreateNeg(X, X->getName() + ".neg"), + X); +} + /// This a limited reassociation for a special case (see above) where we are /// checking if two values are either both NAN (unordered) or not-NAN (ordered). /// This could be handled more generally in '-reassociation', but it seems like @@ -1430,11 +1541,33 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { if (!Cast1) return nullptr; - // Both operands of the logic operation are casts. The casts must be of the - // same type for reduction. - auto CastOpcode = Cast0->getOpcode(); - if (CastOpcode != Cast1->getOpcode() || SrcTy != Cast1->getSrcTy()) + // Both operands of the logic operation are casts. The casts must be the + // same kind for reduction. + Instruction::CastOps CastOpcode = Cast0->getOpcode(); + if (CastOpcode != Cast1->getOpcode()) + return nullptr; + + // If the source types do not match, but the casts are matching extends, we + // can still narrow the logic op. + if (SrcTy != Cast1->getSrcTy()) { + Value *X, *Y; + if (match(Cast0, m_OneUse(m_ZExtOrSExt(m_Value(X)))) && + match(Cast1, m_OneUse(m_ZExtOrSExt(m_Value(Y))))) { + // Cast the narrower source to the wider source type. + unsigned XNumBits = X->getType()->getScalarSizeInBits(); + unsigned YNumBits = Y->getType()->getScalarSizeInBits(); + if (XNumBits < YNumBits) + X = Builder.CreateCast(CastOpcode, X, Y->getType()); + else + Y = Builder.CreateCast(CastOpcode, Y, X->getType()); + // Do the logic op in the intermediate width, then widen more. + Value *NarrowLogic = Builder.CreateBinOp(LogicOpc, X, Y); + return CastInst::Create(CastOpcode, NarrowLogic, DestTy); + } + + // Give up for other cast opcodes. return nullptr; + } Value *Cast0Src = Cast0->getOperand(0); Value *Cast1Src = Cast1->getOperand(0); @@ -1722,6 +1855,77 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, return nullptr; } +/// Try to reassociate a pair of binops so that values with one use only are +/// part of the same instruction. This may enable folds that are limited with +/// multi-use restrictions and makes it more likely to match other patterns that +/// are looking for a common operand. 
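The conditional-negation canonicalization above turns the sext/add/xor idiom into a select. The underlying two's-complement identity is that adding -1 and then xor'ing with -1 computes ~(x - 1), which equals -x, while adding and xor'ing with 0 is the identity. A small sketch using unsigned arithmetic to model the wrapping IR semantics (names invented for illustration):

#include <cassert>
#include <cstdint>

// Pattern matched by the combine: splat = sext(cond); (x + splat) ^ splat.
static uint32_t mathForm(bool Cond, uint32_t X) {
  uint32_t Splat = Cond ? 0xffffffffu : 0u; // sext i1 -> 0 or -1
  return (X + Splat) ^ Splat;
}

// Canonical form it is rewritten to: cond ? -x : x.
static uint32_t selectForm(bool Cond, uint32_t X) { return Cond ? 0u - X : X; }

int main() {
  for (uint32_t X : {0u, 1u, 42u, 0x80000000u, 0xffffffffu}) {
    assert(mathForm(true, X) == selectForm(true, X));
    assert(mathForm(false, X) == selectForm(false, X));
  }
  return 0;
}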
+static Instruction *reassociateForUses(BinaryOperator &BO, + InstCombinerImpl::BuilderTy &Builder) { + Instruction::BinaryOps Opcode = BO.getOpcode(); + Value *X, *Y, *Z; + if (match(&BO, + m_c_BinOp(Opcode, m_OneUse(m_BinOp(Opcode, m_Value(X), m_Value(Y))), + m_OneUse(m_Value(Z))))) { + if (!isa<Constant>(X) && !isa<Constant>(Y) && !isa<Constant>(Z)) { + // (X op Y) op Z --> (Y op Z) op X + if (!X->hasOneUse()) { + Value *YZ = Builder.CreateBinOp(Opcode, Y, Z); + return BinaryOperator::Create(Opcode, YZ, X); + } + // (X op Y) op Z --> (X op Z) op Y + if (!Y->hasOneUse()) { + Value *XZ = Builder.CreateBinOp(Opcode, X, Z); + return BinaryOperator::Create(Opcode, XZ, Y); + } + } + } + + return nullptr; +} + +// Match +// (X + C2) | C +// (X + C2) ^ C +// (X + C2) & C +// and convert to do the bitwise logic first: +// (X | C) + C2 +// (X ^ C) + C2 +// (X & C) + C2 +// iff bits affected by logic op are lower than last bit affected by math op +static Instruction *canonicalizeLogicFirst(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Type *Ty = I.getType(); + Instruction::BinaryOps OpC = I.getOpcode(); + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *X; + const APInt *C, *C2; + + if (!(match(Op0, m_OneUse(m_Add(m_Value(X), m_APInt(C2)))) && + match(Op1, m_APInt(C)))) + return nullptr; + + unsigned Width = Ty->getScalarSizeInBits(); + unsigned LastOneMath = Width - C2->countTrailingZeros(); + + switch (OpC) { + case Instruction::And: + if (C->countLeadingOnes() < LastOneMath) + return nullptr; + break; + case Instruction::Xor: + case Instruction::Or: + if (C->countLeadingZeros() < LastOneMath) + return nullptr; + break; + default: + llvm_unreachable("Unexpected BinaryOp!"); + } + + Value *NewBinOp = Builder.CreateBinOp(OpC, X, ConstantInt::get(Ty, *C)); + return BinaryOperator::CreateAdd(NewBinOp, ConstantInt::get(Ty, *C2)); +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -1754,7 +1958,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { return X; // (A|B)&(A|C) -> A|(B&C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); if (Value *V = SimplifyBSwap(I, Builder)) @@ -2156,24 +2360,36 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, Op0, Constant::getNullValue(Ty)); - // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 - unsigned FullShift = Ty->getScalarSizeInBits() - 1; - if (match(&I, m_c_And(m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))), - m_Value(Y)))) { + // Similarly, a 'not' of the bool translates to a swap of the select arms: + // ~sext(A) & Op1 --> A ? 0 : Op1 + // Op0 & ~sext(A) --> A ? 0 : Op0 + if (match(Op0, m_Not(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Constant::getNullValue(Ty), Op1); + if (match(Op1, m_Not(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Constant::getNullValue(Ty), Op0); + + // (iN X s>> (N-1)) & Y --> (X s< 0) ? 
Y : 0 -- with optional sext + if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf( + m_AShr(m_Value(X), m_APIntAllowUndef(C)))), + m_Value(Y))) && + *C == X->getType()->getScalarSizeInBits() - 1) { Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); return SelectInst::Create(IsNeg, Y, ConstantInt::getNullValue(Ty)); } // If there's a 'not' of the shifted value, swap the select operands: - // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y - if (match(&I, m_c_And(m_OneUse(m_Not( - m_AShr(m_Value(X), m_SpecificInt(FullShift)))), - m_Value(Y)))) { + // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y -- with optional sext + if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf( + m_Not(m_AShr(m_Value(X), m_APIntAllowUndef(C))))), + m_Value(Y))) && + *C == X->getType()->getScalarSizeInBits() - 1) { Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); return SelectInst::Create(IsNeg, ConstantInt::getNullValue(Ty), Y); } // (~x) & y --> ~(x | (~y)) iff that gets rid of inversions - if (sinkNotIntoOtherHandOfAndOrOr(I)) + if (sinkNotIntoOtherHandOfLogicalOp(I)) return &I; // An and recurrence w/loop invariant step is equivelent to (and start, step) @@ -2182,6 +2398,15 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (matchSimpleRecurrence(&I, PN, Start, Step) && DT.dominates(Step, PN)) return replaceInstUsesWith(I, Builder.CreateAnd(Start, Step)); + if (Instruction *R = reassociateForUses(I, Builder)) + return R; + + if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder)) + return Canonicalized; + + if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) + return Folded; + return nullptr; } @@ -2375,7 +2600,9 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { /// We have an expression of the form (A & C) | (B & D). If A is a scalar or /// vector composed of all-zeros or all-ones values and is the bitwise 'not' of /// B, it can be used as the condition operand of a select instruction. -Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { +/// We will detect (A & C) | ~(B | D) when the flag ABIsTheSame enabled. +Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B, + bool ABIsTheSame) { // We may have peeked through bitcasts in the caller. // Exit immediately if we don't have (vector) integer types. Type *Ty = A->getType(); @@ -2383,7 +2610,7 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { return nullptr; // If A is the 'not' operand of B and has enough signbits, we have our answer. - if (match(B, m_Not(m_Specific(A)))) { + if (ABIsTheSame ? (A == B) : match(B, m_Not(m_Specific(A)))) { // If these are scalars or vectors of i1, A can be used directly. if (Ty->isIntOrIntVectorTy(1)) return A; @@ -2403,6 +2630,10 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { return nullptr; } + // TODO: add support for sext and constant case + if (ABIsTheSame) + return nullptr; + // If both operands are constants, see if the constants are inverse bitmasks. Constant *AConst, *BConst; if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst))) @@ -2451,14 +2682,17 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { /// We have an expression of the form (A & C) | (B & D). Try to simplify this /// to "A' ? C : D", where A' is a boolean or vector of booleans. +/// When InvertFalseVal is set to true, we try to match the pattern +/// where we have peeked through a 'not' op and A and B are the same: +/// (A & C) | ~(A | D) --> (A & C) | (~A & ~D) --> A' ? 
C : ~D Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, - Value *D) { + Value *D, bool InvertFalseVal) { // The potential condition of the select may be bitcasted. In that case, look // through its bitcast and the corresponding bitcast of the 'not' condition. Type *OrigType = A->getType(); A = peekThroughBitcast(A, true); B = peekThroughBitcast(B, true); - if (Value *Cond = getSelectCondition(A, B)) { + if (Value *Cond = getSelectCondition(A, B, InvertFalseVal)) { // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) // If this is a vector, we may need to cast to match the condition's length. // The bitcasts will either all exist or all not exist. The builder will @@ -2469,11 +2703,13 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, unsigned Elts = VecTy->getElementCount().getKnownMinValue(); // For a fixed or scalable vector, get the size in bits of N x iM; for a // scalar this is just M. - unsigned SelEltSize = SelTy->getPrimitiveSizeInBits().getKnownMinSize(); + unsigned SelEltSize = SelTy->getPrimitiveSizeInBits().getKnownMinValue(); Type *EltTy = Builder.getIntNTy(SelEltSize / Elts); SelTy = VectorType::get(EltTy, VecTy->getElementCount()); } Value *BitcastC = Builder.CreateBitCast(C, SelTy); + if (InvertFalseVal) + D = Builder.CreateNot(D); Value *BitcastD = Builder.CreateBitCast(D, SelTy); Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); return Builder.CreateBitCast(Select, OrigType); @@ -2484,8 +2720,9 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, // (icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1) // (icmp ne X, 0) & (icmp uge Other, X) -> (icmp ugt Other, X-1) -Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, - IRBuilderBase &Builder) { +static Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, + bool IsAnd, bool IsLogical, + IRBuilderBase &Builder) { ICmpInst::Predicate LPred = IsAnd ? LHS->getInversePredicate() : LHS->getPredicate(); ICmpInst::Predicate RPred = @@ -2504,6 +2741,8 @@ Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else return nullptr; + if (IsLogical) + Other = Builder.CreateFreeze(Other); return Builder.CreateICmp( IsAnd ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE, Builder.CreateAdd(LHS0, Constant::getAllOnesValue(LHS0->getType())), @@ -2552,22 +2791,23 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder)) return V; - // TODO: One of these directions is fine with logical and/or, the other could - // be supported by inserting freeze. - if (!IsLogical) { - if (Value *V = foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, Builder)) - return V; - if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, Builder)) - return V; - } + if (Value *V = + foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, IsLogical, Builder)) + return V; + // We can treat logical like bitwise here, because both operands are used on + // the LHS, and as such poison from both will propagate. + if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, + /*IsLogical*/ false, Builder)) + return V; - // TODO: Verify whether this is safe for logical and/or. 
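The icmp folds above can now run for logical (select-form) and/or as well; because the rewritten compare uses the other operand unconditionally, the select form freezes it first. The unsigned identity behind the eq-zero case, (X == 0) | (Other u< X) == Other u<= X - 1 with a wrapping subtract, can be checked in isolation (sketch only, names invented):

#include <cassert>
#include <cstdint>

static bool orForm(uint32_t X, uint32_t Other) { return X == 0 || Other < X; }
// X - 1 wraps to 0xffffffff when X == 0, which makes the compare always true.
static bool foldedForm(uint32_t X, uint32_t Other) { return Other <= X - 1; }

int main() {
  const uint32_t Vals[] = {0u, 1u, 2u, 100u, 0xfffffffeu, 0xffffffffu};
  for (uint32_t X : Vals)
    for (uint32_t Other : Vals)
      assert(orForm(X, Other) == foldedForm(X, Other));
  return 0;
}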
- if (!IsLogical) { - if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, IsAnd, Builder, Q)) - return V; - if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, IsAnd, Builder, Q)) - return V; - } + if (Value *V = + foldAndOrOfICmpsWithConstEq(LHS, RHS, IsAnd, IsLogical, Builder, Q)) + return V; + // We can convert this case to bitwise and, because both operands are used + // on the LHS, and as such poison from both will propagate. + if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, IsAnd, + /*IsLogical*/ false, Builder, Q)) + return V; if (Value *V = foldIsPowerOf2OrZero(LHS, RHS, IsAnd, Builder)) return V; @@ -2724,7 +2964,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { return X; // (A&B)|(A&C) -> A&(B|C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); if (Value *V = SimplifyBSwap(I, Builder)) @@ -2777,6 +3017,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { return BinaryOperator::CreateMul(X, IncrementY); } + // X | (X ^ Y) --> X | Y (4 commuted patterns) + if (match(&I, m_c_Or(m_Value(X), m_c_Xor(m_Deferred(X), m_Value(Y))))) + return BinaryOperator::CreateOr(X, Y); + // (A & C) | (B & D) Value *A, *B, *C, *D; if (match(Op0, m_And(m_Value(A), m_Value(C))) && @@ -2854,6 +3098,20 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { } } + if (match(Op0, m_And(m_Value(A), m_Value(C))) && + match(Op1, m_Not(m_Or(m_Value(B), m_Value(D)))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + // (Cond & C) | ~(Cond | D) -> Cond ? C : ~D + if (Value *V = matchSelectFromAndOr(A, C, B, D, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(A, C, D, B, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, B, D, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, D, B, true)) + return replaceInstUsesWith(I, V); + } + // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) @@ -2886,30 +3144,58 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { SwappedForXor = true; } - // A | ( A ^ B) -> A | B - // A | (~A ^ B) -> A | ~B - // (A & B) | (A ^ B) - // ~A | (A ^ B) -> ~(A & B) - // The swap above should always make Op0 the 'not' for the last case. if (match(Op1, m_Xor(m_Value(A), m_Value(B)))) { - if (Op0 == A || Op0 == B) - return BinaryOperator::CreateOr(A, B); - + // (A | ?) | (A ^ B) --> (A | ?) | B + // (B | ?) | (A ^ B) --> (B | ?) | A + if (match(Op0, m_c_Or(m_Specific(A), m_Value()))) + return BinaryOperator::CreateOr(Op0, B); + if (match(Op0, m_c_Or(m_Specific(B), m_Value()))) + return BinaryOperator::CreateOr(Op0, A); + + // (A & B) | (A ^ B) --> A | B + // (B & A) | (A ^ B) --> A | B if (match(Op0, m_And(m_Specific(A), m_Specific(B))) || match(Op0, m_And(m_Specific(B), m_Specific(A)))) return BinaryOperator::CreateOr(A, B); + // ~A | (A ^ B) --> ~(A & B) + // ~B | (A ^ B) --> ~(A & B) + // The swap above should always make Op0 the 'not'. if ((Op0->hasOneUse() || Op1->hasOneUse()) && (match(Op0, m_Not(m_Specific(A))) || match(Op0, m_Not(m_Specific(B))))) return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + // Same as above, but peek through an 'and' to the common operand: + // ~(A & ?) | (A ^ B) --> ~((A & ?) & B) + // ~(B & ?) | (A ^ B) --> ~((B & ?) 
& A) + Instruction *And; + if ((Op0->hasOneUse() || Op1->hasOneUse()) && + match(Op0, m_Not(m_CombineAnd(m_Instruction(And), + m_c_And(m_Specific(A), m_Value()))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(And, B)); + if ((Op0->hasOneUse() || Op1->hasOneUse()) && + match(Op0, m_Not(m_CombineAnd(m_Instruction(And), + m_c_And(m_Specific(B), m_Value()))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(And, A)); + + // (~A | C) | (A ^ B) --> ~(A & B) | C + // (~B | C) | (A ^ B) --> ~(A & B) | C + if (Op0->hasOneUse() && Op1->hasOneUse() && + (match(Op0, m_c_Or(m_Not(m_Specific(A)), m_Value(C))) || + match(Op0, m_c_Or(m_Not(m_Specific(B)), m_Value(C))))) { + Value *Nand = Builder.CreateNot(Builder.CreateAnd(A, B), "nand"); + return BinaryOperator::CreateOr(Nand, C); + } + + // A | (~A ^ B) --> ~B | A + // B | (A ^ ~B) --> ~A | B if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) { - Value *Not = Builder.CreateNot(B, B->getName() + ".not"); - return BinaryOperator::CreateOr(Not, Op0); + Value *NotB = Builder.CreateNot(B, B->getName() + ".not"); + return BinaryOperator::CreateOr(NotB, Op0); } if (Op1->hasOneUse() && match(B, m_Not(m_Specific(Op0)))) { - Value *Not = Builder.CreateNot(A, A->getName() + ".not"); - return BinaryOperator::CreateOr(Not, Op0); + Value *NotA = Builder.CreateNot(A, A->getName() + ".not"); + return BinaryOperator::CreateOr(NotA, Op0); } } @@ -3072,7 +3358,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { } // (~x) | y --> ~(x & (~y)) iff that gets rid of inversions - if (sinkNotIntoOtherHandOfAndOrOr(I)) + if (sinkNotIntoOtherHandOfLogicalOp(I)) return &I; // Improve "get low bit mask up to and including bit X" pattern: @@ -3121,6 +3407,15 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { Builder.CreateOr(C, Builder.CreateAnd(A, B)), D); } + if (Instruction *R = reassociateForUses(I, Builder)) + return R; + + if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder)) + return Canonicalized; + + if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) + return Folded; + return nullptr; } @@ -3338,14 +3633,8 @@ static Instruction *visitMaskedMerge(BinaryOperator &I, // (~x) ^ y // or into // x ^ (~y) -static Instruction *sinkNotIntoXor(BinaryOperator &I, +static Instruction *sinkNotIntoXor(BinaryOperator &I, Value *X, Value *Y, InstCombiner::BuilderTy &Builder) { - Value *X, *Y; - // FIXME: one-use check is not needed in general, but currently we are unable - // to fold 'not' into 'icmp', if that 'icmp' has multiple uses. (D35182) - if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y)))))) - return nullptr; - // We only want to do the transform if it is free to do. if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) { // Ok, good. @@ -3358,6 +3647,41 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I, return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan"); } +static Instruction *foldNotXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y; + // FIXME: one-use check is not needed in general, but currently we are unable + // to fold 'not' into 'icmp', if that 'icmp' has multiple uses. 
(D35182) + if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y)))))) + return nullptr; + + if (Instruction *NewXor = sinkNotIntoXor(I, X, Y, Builder)) + return NewXor; + + auto hasCommonOperand = [](Value *A, Value *B, Value *C, Value *D) { + return A == C || A == D || B == C || B == D; + }; + + Value *A, *B, *C, *D; + // Canonicalize ~((A & B) ^ (A | ?)) -> (A & B) | ~(A | ?) + // 4 commuted variants + if (match(X, m_And(m_Value(A), m_Value(B))) && + match(Y, m_Or(m_Value(C), m_Value(D))) && hasCommonOperand(A, B, C, D)) { + Value *NotY = Builder.CreateNot(Y); + return BinaryOperator::CreateOr(X, NotY); + }; + + // Canonicalize ~((A | ?) ^ (A & B)) -> (A & B) | ~(A | ?) + // 4 commuted variants + if (match(Y, m_And(m_Value(A), m_Value(B))) && + match(X, m_Or(m_Value(C), m_Value(D))) && hasCommonOperand(A, B, C, D)) { + Value *NotX = Builder.CreateNot(X); + return BinaryOperator::CreateOr(Y, NotX); + }; + + return nullptr; +} + /// Canonicalize a shifty way to code absolute value to the more common pattern /// that uses negation and select. static Instruction *canonicalizeAbs(BinaryOperator &Xor, @@ -3392,39 +3716,127 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor, } // Transform -// z = (~x) &/| y +// z = ~(x &/| y) // into: -// z = ~(x |/& (~y)) -// iff y is free to invert and all uses of z can be freely updated. -bool InstCombinerImpl::sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I) { - Instruction::BinaryOps NewOpc; - switch (I.getOpcode()) { - case Instruction::And: - NewOpc = Instruction::Or; - break; - case Instruction::Or: - NewOpc = Instruction::And; - break; - default: +// z = ((~x) |/& (~y)) +// iff both x and y are free to invert and all uses of z can be freely updated. +bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) { + Value *Op0, *Op1; + if (!match(&I, m_LogicalOp(m_Value(Op0), m_Value(Op1)))) return false; - }; - Value *X, *Y; - if (!match(&I, m_c_BinOp(m_Not(m_Value(X)), m_Value(Y)))) + // If this logic op has not been simplified yet, just bail out and let that + // happen first. Otherwise, the code below may wrongly invert. + if (Op0 == Op1) return false; - // Will we be able to fold the `not` into Y eventually? - if (!InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) + Instruction::BinaryOps NewOpc = + match(&I, m_LogicalAnd()) ? Instruction::Or : Instruction::And; + bool IsBinaryOp = isa<BinaryOperator>(I); + + // Can our users be adapted? + if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) + return false; + + // And can the operands be adapted? 
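The foldNotXor canonicalization above rests on a containment argument: with the shared operand A, every bit set in (A & B) is also set in (A | C), and whenever p implies q bitwise, ~(p ^ q) == p | ~q. A self-contained check of that identity (illustrative sketch, names invented):

#include <cassert>
#include <cstdint>

static uint32_t notXorForm(uint32_t A, uint32_t B, uint32_t C) {
  return ~((A & B) ^ (A | C));
}
static uint32_t canonicalForm(uint32_t A, uint32_t B, uint32_t C) {
  return (A & B) | ~(A | C);
}

int main() {
  const uint32_t Vals[] = {0u, 1u, 0x0f0f0f0fu, 0xdeadbeefu, 0xffffffffu};
  for (uint32_t A : Vals)
    for (uint32_t B : Vals)
      for (uint32_t C : Vals)
        assert(notXorForm(A, B, C) == canonicalForm(A, B, C));
  return 0;
}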
+ for (Value *Op : {Op0, Op1}) + if (!(InstCombiner::isFreeToInvert(Op, /*WillInvertAllUses=*/true) && + (match(Op, m_ImmConstant()) || + (isa<Instruction>(Op) && + InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op), + /*IgnoredUser=*/&I))))) + return false; + + for (Value **Op : {&Op0, &Op1}) { + Value *NotOp; + if (auto *C = dyn_cast<Constant>(*Op)) { + NotOp = ConstantExpr::getNot(C); + } else { + Builder.SetInsertPoint( + &*cast<Instruction>(*Op)->getInsertionPointAfterDef()); + NotOp = Builder.CreateNot(*Op, (*Op)->getName() + ".not"); + (*Op)->replaceUsesWithIf( + NotOp, [NotOp](Use &U) { return U.getUser() != NotOp; }); + freelyInvertAllUsersOf(NotOp, /*IgnoredUser=*/&I); + } + *Op = NotOp; + } + + Builder.SetInsertPoint(I.getInsertionPointAfterDef()); + Value *NewLogicOp; + if (IsBinaryOp) + NewLogicOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not"); + else + NewLogicOp = + Builder.CreateLogicalOp(NewOpc, Op0, Op1, I.getName() + ".not"); + + replaceInstUsesWith(I, NewLogicOp); + // We can not just create an outer `not`, it will most likely be immediately + // folded back, reconstructing our initial pattern, and causing an + // infinite combine loop, so immediately manually fold it away. + freelyInvertAllUsersOf(NewLogicOp); + return true; +} + +// Transform +// z = (~x) &/| y +// into: +// z = ~(x |/& (~y)) +// iff y is free to invert and all uses of z can be freely updated. +bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) { + Value *Op0, *Op1; + if (!match(&I, m_LogicalOp(m_Value(Op0), m_Value(Op1)))) + return false; + Instruction::BinaryOps NewOpc = + match(&I, m_LogicalAnd()) ? Instruction::Or : Instruction::And; + bool IsBinaryOp = isa<BinaryOperator>(I); + + Value *NotOp0 = nullptr; + Value *NotOp1 = nullptr; + Value **OpToInvert = nullptr; + if (match(Op0, m_Not(m_Value(NotOp0))) && + InstCombiner::isFreeToInvert(Op1, /*WillInvertAllUses=*/true) && + (match(Op1, m_ImmConstant()) || + (isa<Instruction>(Op1) && + InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op1), + /*IgnoredUser=*/&I)))) { + Op0 = NotOp0; + OpToInvert = &Op1; + } else if (match(Op1, m_Not(m_Value(NotOp1))) && + InstCombiner::isFreeToInvert(Op0, /*WillInvertAllUses=*/true) && + (match(Op0, m_ImmConstant()) || + (isa<Instruction>(Op0) && + InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op0), + /*IgnoredUser=*/&I)))) { + Op1 = NotOp1; + OpToInvert = &Op0; + } else return false; // And can our users be adapted? 
if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) return false; - Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); - Value *NewBinOp = - BinaryOperator::Create(NewOpc, X, NotY, I.getName() + ".not"); - Builder.Insert(NewBinOp); + if (auto *C = dyn_cast<Constant>(*OpToInvert)) { + *OpToInvert = ConstantExpr::getNot(C); + } else { + Builder.SetInsertPoint( + &*cast<Instruction>(*OpToInvert)->getInsertionPointAfterDef()); + Value *NotOpToInvert = + Builder.CreateNot(*OpToInvert, (*OpToInvert)->getName() + ".not"); + (*OpToInvert)->replaceUsesWithIf(NotOpToInvert, [NotOpToInvert](Use &U) { + return U.getUser() != NotOpToInvert; + }); + freelyInvertAllUsersOf(NotOpToInvert, /*IgnoredUser=*/&I); + *OpToInvert = NotOpToInvert; + } + + Builder.SetInsertPoint(&*I.getInsertionPointAfterDef()); + Value *NewBinOp; + if (IsBinaryOp) + NewBinOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not"); + else + NewBinOp = Builder.CreateLogicalOp(NewOpc, Op0, Op1, I.getName() + ".not"); replaceInstUsesWith(I, NewBinOp); // We can not just create an outer `not`, it will most likely be immediately // folded back, reconstructing our initial pattern, and causing an @@ -3472,23 +3884,6 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { // Is this a 'not' (~) fed by a binary operator? BinaryOperator *NotVal; if (match(NotOp, m_BinOp(NotVal))) { - if (NotVal->getOpcode() == Instruction::And || - NotVal->getOpcode() == Instruction::Or) { - // Apply DeMorgan's Law when inverts are free: - // ~(X & Y) --> (~X | ~Y) - // ~(X | Y) --> (~X & ~Y) - if (isFreeToInvert(NotVal->getOperand(0), - NotVal->getOperand(0)->hasOneUse()) && - isFreeToInvert(NotVal->getOperand(1), - NotVal->getOperand(1)->hasOneUse())) { - Value *NotX = Builder.CreateNot(NotVal->getOperand(0), "notlhs"); - Value *NotY = Builder.CreateNot(NotVal->getOperand(1), "notrhs"); - if (NotVal->getOpcode() == Instruction::And) - return BinaryOperator::CreateOr(NotX, NotY); - return BinaryOperator::CreateAnd(NotX, NotY); - } - } - // ~((-X) | Y) --> (X - 1) & (~Y) if (match(NotVal, m_OneUse(m_c_Or(m_OneUse(m_Neg(m_Value(X))), m_Value(Y))))) { @@ -3501,6 +3896,14 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y)))) return BinaryOperator::CreateAShr(X, Y); + // Bit-hack form of a signbit test: + // iN ~X >>s (N-1) --> sext i1 (X > -1) to iN + unsigned FullShift = Ty->getScalarSizeInBits() - 1; + if (match(NotVal, m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))))) { + Value *IsNotNeg = Builder.CreateIsNotNeg(X, "isnotneg"); + return new SExtInst(IsNotNeg, Ty); + } + // If we are inverting a right-shifted constant, we may be able to eliminate // the 'not' by inverting the constant and using the opposite shift type. 
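The sign-bit bit-hack recognized in foldNot above reduces to a comparison: arithmetically shifting ~X right by N-1 yields all ones exactly when X is non-negative. A sketch of the scalar identity, assuming arithmetic right shift of signed integers (guaranteed since C++20 and matching LLVM's ashr):

#include <cassert>
#include <cstdint>

static int32_t bitHack(int32_t X) { return ~X >> 31; }            // iN ~X >>s (N-1)
static int32_t sextCompare(int32_t X) { return X > -1 ? -1 : 0; } // sext i1 (X > -1)

int main() {
  const int32_t Vals[] = {INT32_MIN, -2, -1, 0, 1, 12345, INT32_MAX};
  for (int32_t X : Vals)
    assert(bitHack(X) == sextCompare(X));
  return 0;
}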
// Canonicalization rules ensure that only a negative constant uses 'ashr', @@ -3545,11 +3948,28 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { // not (cmp A, B) = !cmp A, B CmpInst::Predicate Pred; - if (match(NotOp, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) { + if (match(NotOp, m_Cmp(Pred, m_Value(), m_Value())) && + (NotOp->hasOneUse() || + InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(NotOp), + /*IgnoredUser=*/nullptr))) { cast<CmpInst>(NotOp)->setPredicate(CmpInst::getInversePredicate(Pred)); - return replaceInstUsesWith(I, NotOp); + freelyInvertAllUsersOf(NotOp); + return &I; + } + + // Move a 'not' ahead of casts of a bool to enable logic reduction: + // not (bitcast (sext i1 X)) --> bitcast (sext (not i1 X)) + if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) && X->getType()->isIntOrIntVectorTy(1)) { + Type *SextTy = cast<BitCastOperator>(NotOp)->getSrcTy(); + Value *NotX = Builder.CreateNot(X); + Value *Sext = Builder.CreateSExt(NotX, SextTy); + return CastInst::CreateBitOrPointerCast(Sext, Ty); } + if (auto *NotOpI = dyn_cast<Instruction>(NotOp)) + if (sinkNotIntoLogicalOp(*NotOpI)) + return &I; + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: // ~min(~X, ~Y) --> max(X, Y) // ~max(~X, Y) --> min(X, ~Y) @@ -3570,6 +3990,14 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotY); return replaceInstUsesWith(I, InvMaxMin); } + + if (II->getIntrinsicID() == Intrinsic::is_fpclass) { + ConstantInt *ClassMask = cast<ConstantInt>(II->getArgOperand(1)); + II->setArgOperand( + 1, ConstantInt::get(ClassMask->getType(), + ~ClassMask->getZExtValue() & fcAllFlags)); + return replaceInstUsesWith(I, II); + } } if (NotOp->hasOneUse()) { @@ -3602,7 +4030,7 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { } } - if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) + if (Instruction *NewXor = foldNotXor(I, Builder)) return NewXor; return nullptr; @@ -3629,7 +4057,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { return NewXor; // (A&B)^(A&C) -> A&(B^C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole @@ -3718,6 +4146,21 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { MaskedValueIsZero(X, *C, 0, &I)) return BinaryOperator::CreateXor(X, ConstantInt::get(Ty, *C ^ *RHSC)); + // When X is a power-of-two or zero and zero input is poison: + // ctlz(i32 X) ^ 31 --> cttz(X) + // cttz(i32 X) ^ 31 --> ctlz(X) + auto *II = dyn_cast<IntrinsicInst>(Op0); + if (II && II->hasOneUse() && *RHSC == Ty->getScalarSizeInBits() - 1) { + Intrinsic::ID IID = II->getIntrinsicID(); + if ((IID == Intrinsic::ctlz || IID == Intrinsic::cttz) && + match(II->getArgOperand(1), m_One()) && + isKnownToBeAPowerOfTwo(II->getArgOperand(0), /*OrZero */ true)) { + IID = (IID == Intrinsic::ctlz) ? 
Intrinsic::cttz : Intrinsic::ctlz; + Function *F = Intrinsic::getDeclaration(II->getModule(), IID, Ty); + return CallInst::Create(F, {II->getArgOperand(0), Builder.getTrue()}); + } + } + // If RHSC is inverting the remaining bits of shifted X, // canonicalize to a 'not' before the shift to help SCEV and codegen: // (X << C) ^ RHSC --> ~X << C @@ -3858,5 +4301,17 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { m_Value(Y)))) return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1); + if (Instruction *R = reassociateForUses(I, Builder)) + return R; + + if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder)) + return Canonicalized; + + if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) + return Folded; + + if (Instruction *Folded = canonicalizeConditionalNegationViaMathToSelect(I)) + return Folded; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp index 0327efbf9614..e73667f9c02e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -128,10 +128,9 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) { if (Ordering != AtomicOrdering::Release && Ordering != AtomicOrdering::Monotonic) return nullptr; - auto *SI = new StoreInst(RMWI.getValOperand(), - RMWI.getPointerOperand(), &RMWI); - SI->setAtomic(Ordering, RMWI.getSyncScopeID()); - SI->setAlignment(DL.getABITypeAlign(RMWI.getType())); + new StoreInst(RMWI.getValOperand(), RMWI.getPointerOperand(), + /*isVolatile*/ false, RMWI.getAlign(), Ordering, + RMWI.getSyncScopeID(), &RMWI); return eraseInstFromFunction(RMWI); } @@ -152,13 +151,5 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) { return replaceOperand(RMWI, 1, ConstantFP::getNegativeZero(RMWI.getType())); } - // Check if the required ordering is compatible with an atomic load. - if (Ordering != AtomicOrdering::Acquire && - Ordering != AtomicOrdering::Monotonic) - return nullptr; - - LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand(), "", - false, DL.getABITypeAlign(RMWI.getType()), - Ordering, RMWI.getSyncScopeID()); - return Load; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index bc01d2ef7fe2..fbf1327143a8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -15,8 +15,6 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" @@ -34,6 +32,7 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" @@ -71,6 +70,7 @@ #include <algorithm> #include <cassert> #include <cstdint> +#include <optional> #include <utility> #include <vector> @@ -135,7 +135,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { // If we have a store to a location which is known constant, we can conclude // that the store must be storing the constant value (else the memory // wouldn't be constant), and this must be a noop. 
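The xor fold completed above swaps ctlz and cttz when the operand is a power of two (zero being poison for these calls): for X = 2^k in i32, ctlz(X) = 31 - k and cttz(X) = k, and 31 - k equals 31 ^ k for k in [0, 31]. A C++20 sketch of that arithmetic using the standard bit library (illustrative only):

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned K = 0; K < 32; ++K) {
    uint32_t X = uint32_t(1) << K; // a power of two
    assert((std::countl_zero(X) ^ 31) == std::countr_zero(X));
    assert((std::countr_zero(X) ^ 31) == std::countl_zero(X));
  }
  return 0;
}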
- if (AA->pointsToConstantMemory(MI->getDest())) { + if (!isModSet(AA->getModRefInfoMask(MI->getDest()))) { // Set the size of the copy to 0, it will be deleted on the next iteration. MI->setLength(Constant::getNullValue(MI->getLength()->getType())); return MI; @@ -223,6 +223,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); if (AccessGroupMD) S->setMetadata(LLVMContext::MD_access_group, AccessGroupMD); + S->copyMetadata(*MI, LLVMContext::MD_DIAssignID); if (auto *MT = dyn_cast<MemTransferInst>(MI)) { // non-atomics can be volatile @@ -252,7 +253,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { // If we have a store to a location which is known constant, we can conclude // that the store must be storing the constant value (else the memory // wouldn't be constant), and this must be a noop. - if (AA->pointsToConstantMemory(MI->getDest())) { + if (!isModSet(AA->getModRefInfoMask(MI->getDest()))) { // Set the size of the copy to 0, it will be deleted on the next iteration. MI->setLength(Constant::getNullValue(MI->getLength()->getType())); return MI; @@ -294,9 +295,15 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { Dest = Builder.CreateBitCast(Dest, NewDstPtrTy); // Extract the fill value and store. - uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; - StoreInst *S = Builder.CreateStore(ConstantInt::get(ITy, Fill), Dest, - MI->isVolatile()); + const uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; + Constant *FillVal = ConstantInt::get(ITy, Fill); + StoreInst *S = Builder.CreateStore(FillVal, Dest, MI->isVolatile()); + S->copyMetadata(*MI, LLVMContext::MD_DIAssignID); + for (auto *DAI : at::getAssignmentMarkers(S)) { + if (any_of(DAI->location_ops(), [&](Value *V) { return V == FillC; })) + DAI->replaceVariableLocationOp(FillC, FillVal); + } + S->setAlignment(Alignment); if (isa<AtomicMemSetInst>(MI)) S->setOrdering(AtomicOrdering::Unordered); @@ -328,7 +335,7 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { // If we can unconditionally load from this address, replace with a // load/select idiom. TODO: use DT for context sensitive query if (isDereferenceablePointer(LoadPtr, II.getType(), - II.getModule()->getDataLayout(), &II, nullptr)) { + II.getModule()->getDataLayout(), &II, &AC)) { LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); LI->copyMetadata(II); @@ -661,10 +668,21 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { // If all bits are zero except for exactly one fixed bit, then the result // must be 0 or 1, and we can get that answer by shifting to LSB: // ctpop (X & 32) --> (X & 32) >> 5 + // TODO: Investigate removing this as its likely unnecessary given the below + // `isKnownToBeAPowerOfTwo` check. if ((~Known.Zero).isPowerOf2()) return BinaryOperator::CreateLShr( Op0, ConstantInt::get(Ty, (~Known.Zero).exactLogBase2())); + // More generally we can also handle non-constant power of 2 patterns such as + // shl/shr(Pow2, X), (X & -X), etc... by transforming: + // ctpop(Pow2OrZero) --> icmp ne X, 0 + if (IC.isKnownToBeAPowerOfTwo(Op0, /* OrZero */ true)) + return CastInst::Create(Instruction::ZExt, + IC.Builder.CreateICmp(ICmpInst::ICMP_NE, Op0, + Constant::getNullValue(Ty)), + Ty); + // FIXME: Try to simplify vectors of integers. 
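The generalized ctpop fold above covers any value known to be a power of two or zero, not just a single fixed bit: such a value has at most one bit set, so its popcount is exactly (X != 0). A small check of both the general case and the pre-existing single-bit case (sketch, C++20 <bit>):

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Vals[] = {0u, 3u, 7u, 0x12345678u, 0x80000000u, 0xffffffffu};
  for (uint32_t V : Vals) {
    uint32_t X = V & (0u - V); // X & -X isolates the lowest set bit: pow2 or zero
    assert(std::popcount(X) == (X != 0 ? 1 : 0));            // ctpop(Pow2OrZero) --> X != 0
    assert(std::popcount(V & 32u) == int((V & 32u) >> 5));   // ctpop (X & 32) --> (X & 32) >> 5
  }
  return 0;
}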
auto *IT = dyn_cast<IntegerType>(Ty); if (!IT) @@ -720,7 +738,7 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II, auto *V1 = II.getArgOperand(0); auto *V2 = Constant::getNullValue(V1->getType()); - return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes)); + return Builder.CreateShuffleVector(V1, V2, ArrayRef(Indexes)); } // Returns true iff the 2 intrinsics have the same operands, limiting the @@ -812,9 +830,10 @@ InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { return nullptr; } -static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI, - const DataLayout &DL, AssumptionCache *AC, - DominatorTree *DT) { +static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, + const DataLayout &DL, + AssumptionCache *AC, + DominatorTree *DT) { KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT); if (Known.isNonNegative()) return false; @@ -1266,7 +1285,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X)))) return replaceOperand(*II, 0, X); - if (Optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) { + if (std::optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) { // abs(x) -> x if x >= 0 if (!*Sign) return replaceInstUsesWith(*II, IIOperand); @@ -1297,11 +1316,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); // umin(x, 1) == zext(x != 0) if (match(I1, m_One())) { + assert(II->getType()->getScalarSizeInBits() != 1 && + "Expected simplify of umin with max constant"); Value *Zero = Constant::getNullValue(I0->getType()); Value *Cmp = Builder.CreateICmpNE(I0, Zero); return CastInst::Create(Instruction::ZExt, Cmp, II->getType()); } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::umax: { Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); @@ -1322,7 +1343,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } // If both operands of unsigned min/max are sign-extended, it is still ok // to narrow the operation. - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::smax: case Intrinsic::smin: { @@ -1431,6 +1452,18 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } + case Intrinsic::bitreverse: { + // bitrev (zext i1 X to ?) --> X ? SignBitC : 0 + Value *X; + if (match(II->getArgOperand(0), m_ZExt(m_Value(X))) && + X->getType()->isIntOrIntVectorTy(1)) { + Type *Ty = II->getType(); + APInt SignBit = APInt::getSignMask(Ty->getScalarSizeInBits()); + return SelectInst::Create(X, ConstantInt::get(Ty, SignBit), + ConstantInt::getNullValue(Ty)); + } + break; + } case Intrinsic::bswap: { Value *IIOperand = II->getArgOperand(0); @@ -1829,6 +1862,63 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } + case Intrinsic::matrix_multiply: { + // Optimize negation in matrix multiplication. 
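The umin(x, 1) fold above materializes the compare directly; the added assert documents that the i1 case must already have been simplified, presumably because for i1 the constant 1 is the unsigned maximum and umin with the maximum is the identity. The scalar identity itself is simply that clamping to 1 is a zero-extended non-zero test (illustrative sketch):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Vals[] = {0u, 1u, 2u, 0xffffffffu};
  for (uint32_t X : Vals)
    assert(std::min<uint32_t>(X, 1u) == uint32_t(X != 0)); // umin(x, 1) == zext(x != 0)
  return 0;
}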
+ + // -A * -B -> A * B + Value *A, *B; + if (match(II->getArgOperand(0), m_FNeg(m_Value(A))) && + match(II->getArgOperand(1), m_FNeg(m_Value(B)))) { + replaceOperand(*II, 0, A); + replaceOperand(*II, 1, B); + return II; + } + + Value *Op0 = II->getOperand(0); + Value *Op1 = II->getOperand(1); + Value *OpNotNeg, *NegatedOp; + unsigned NegatedOpArg, OtherOpArg; + if (match(Op0, m_FNeg(m_Value(OpNotNeg)))) { + NegatedOp = Op0; + NegatedOpArg = 0; + OtherOpArg = 1; + } else if (match(Op1, m_FNeg(m_Value(OpNotNeg)))) { + NegatedOp = Op1; + NegatedOpArg = 1; + OtherOpArg = 0; + } else + // Multiplication doesn't have a negated operand. + break; + + // Only optimize if the negated operand has only one use. + if (!NegatedOp->hasOneUse()) + break; + + Value *OtherOp = II->getOperand(OtherOpArg); + VectorType *RetTy = cast<VectorType>(II->getType()); + VectorType *NegatedOpTy = cast<VectorType>(NegatedOp->getType()); + VectorType *OtherOpTy = cast<VectorType>(OtherOp->getType()); + ElementCount NegatedCount = NegatedOpTy->getElementCount(); + ElementCount OtherCount = OtherOpTy->getElementCount(); + ElementCount RetCount = RetTy->getElementCount(); + // (-A) * B -> A * (-B), if it is cheaper to negate B and vice versa. + if (ElementCount::isKnownGT(NegatedCount, OtherCount) && + ElementCount::isKnownLT(OtherCount, RetCount)) { + Value *InverseOtherOp = Builder.CreateFNeg(OtherOp); + replaceOperand(*II, NegatedOpArg, OpNotNeg); + replaceOperand(*II, OtherOpArg, InverseOtherOp); + return II; + } + // (-A) * B -> -(A * B), if it is cheaper to negate the result + if (ElementCount::isKnownGT(NegatedCount, RetCount)) { + SmallVector<Value *, 5> NewArgs(II->args()); + NewArgs[NegatedOpArg] = OpNotNeg; + Instruction *NewMul = + Builder.CreateIntrinsic(II->getType(), IID, NewArgs, II); + return replaceInstUsesWith(*II, Builder.CreateFNegFMF(NewMul, II)); + } + break; + } case Intrinsic::fmuladd: { // Canonicalize fast fmuladd to the separate fmul + fadd. if (II->isFast()) { @@ -1850,7 +1940,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return FAdd; } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::fma: { // fma fneg(x), fneg(y), z -> fma x, y, z @@ -1940,7 +2030,17 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return replaceOperand(*II, 0, TVal); } - LLVM_FALLTHROUGH; + Value *Magnitude, *Sign; + if (match(II->getArgOperand(0), + m_CopySign(m_Value(Magnitude), m_Value(Sign)))) { + // fabs (copysign x, y) -> (fabs x) + CallInst *AbsSign = + Builder.CreateCall(II->getCalledFunction(), {Magnitude}); + AbsSign->copyFastMathFlags(II); + return replaceInstUsesWith(*II, AbsSign); + } + + [[fallthrough]]; } case Intrinsic::ceil: case Intrinsic::floor: @@ -1979,7 +2079,64 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::ptrauth_auth: + case Intrinsic::ptrauth_resign: { + // (sign|resign) + (auth|resign) can be folded by omitting the middle + // sign+auth component if the key and discriminator match. + bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign; + Value *Key = II->getArgOperand(1); + Value *Disc = II->getArgOperand(2); + + // AuthKey will be the key we need to end up authenticating against in + // whatever we replace this sequence with. 
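The fabs-of-copysign fold above drops the sign operand entirely: copysign only rewrites the sign bit of the magnitude operand, and fabs clears it again. The same holds for the C math library analogues, which a quick sketch can confirm (NaN is skipped only because NaN compares unequal to itself):

#include <cassert>
#include <cmath>

int main() {
  const double Xs[] = {0.0, -0.0, 1.5, -3.25, INFINITY, -INFINITY};
  const double Ys[] = {1.0, -1.0, 0.0, -0.0};
  for (double X : Xs)
    for (double Y : Ys)
      assert(std::fabs(std::copysign(X, Y)) == std::fabs(X)); // fabs (copysign x, y) -> fabs x
  return 0;
}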
+ Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr; + if (auto CI = dyn_cast<CallBase>(II->getArgOperand(0))) { + BasePtr = CI->getArgOperand(0); + if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) { + if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc) + break; + } else if (CI->getIntrinsicID() == Intrinsic::ptrauth_resign) { + if (CI->getArgOperand(3) != Key || CI->getArgOperand(4) != Disc) + break; + AuthKey = CI->getArgOperand(1); + AuthDisc = CI->getArgOperand(2); + } else + break; + } else + break; + + unsigned NewIntrin; + if (AuthKey && NeedSign) { + // resign(0,1) + resign(1,2) = resign(0, 2) + NewIntrin = Intrinsic::ptrauth_resign; + } else if (AuthKey) { + // resign(0,1) + auth(1) = auth(0) + NewIntrin = Intrinsic::ptrauth_auth; + } else if (NeedSign) { + // sign(0) + resign(0, 1) = sign(1) + NewIntrin = Intrinsic::ptrauth_sign; + } else { + // sign(0) + auth(0) = nop + replaceInstUsesWith(*II, BasePtr); + eraseInstFromFunction(*II); + return nullptr; + } + + SmallVector<Value *, 4> CallArgs; + CallArgs.push_back(BasePtr); + if (AuthKey) { + CallArgs.push_back(AuthKey); + CallArgs.push_back(AuthDisc); + } + if (NeedSign) { + CallArgs.push_back(II->getArgOperand(3)); + CallArgs.push_back(II->getArgOperand(4)); + } + + Function *NewFn = Intrinsic::getDeclaration(II->getModule(), NewIntrin); + return CallInst::Create(NewFn, CallArgs); + } case Intrinsic::arm_neon_vtbl1: case Intrinsic::aarch64_neon_tbl1: if (Value *V = simplifyNeonTbl1(*II, Builder)) @@ -2221,7 +2378,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Pred == ICmpInst::ICMP_NE && LHS->getOpcode() == Instruction::Load && LHS->getType()->isPointerTy() && isValidAssumeForContext(II, LHS, &DT)) { - MDNode *MD = MDNode::get(II->getContext(), None); + MDNode *MD = MDNode::get(II->getContext(), std::nullopt); LHS->setMetadata(LLVMContext::MD_nonnull, MD); return RemoveConditionFromAssume(II); @@ -2288,7 +2445,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { llvm::getKnowledgeFromBundle(cast<AssumeInst>(*II), BOI); if (BOI.End - BOI.Begin > 2) continue; // Prevent reducing knowledge in an align with offset since - // extracting a RetainedKnowledge form them looses offset + // extracting a RetainedKnowledge from them looses offset // information RetainedKnowledge CanonRK = llvm::simplifyRetainedKnowledge(cast<AssumeInst>(II), RK, @@ -2409,7 +2566,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Vec = II->getArgOperand(0); Value *Idx = II->getArgOperand(1); - auto *DstTy = dyn_cast<FixedVectorType>(II->getType()); + Type *ReturnType = II->getType(); + // (extract_vector (insert_vector InsertTuple, InsertValue, InsertIdx), + // ExtractIdx) + unsigned ExtractIdx = cast<ConstantInt>(Idx)->getZExtValue(); + Value *InsertTuple, *InsertIdx, *InsertValue; + if (match(Vec, m_Intrinsic<Intrinsic::vector_insert>(m_Value(InsertTuple), + m_Value(InsertValue), + m_Value(InsertIdx))) && + InsertValue->getType() == ReturnType) { + unsigned Index = cast<ConstantInt>(InsertIdx)->getZExtValue(); + // Case where we get the same index right after setting it. + // extract.vector(insert.vector(InsertTuple, InsertValue, Idx), Idx) --> + // InsertValue + if (ExtractIdx == Index) + return replaceInstUsesWith(CI, InsertValue); + // If we are getting a different index than what was set in the + // insert.vector intrinsic. We can just set the input tuple to the one up + // in the chain. 
extract.vector(insert.vector(InsertTuple, InsertValue, + // InsertIndex), ExtractIndex) + // --> extract.vector(InsertTuple, ExtractIndex) + else + return replaceOperand(CI, 0, InsertTuple); + } + + auto *DstTy = dyn_cast<FixedVectorType>(ReturnType); auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); // Only canonicalize if the the destination vector and Vec are fixed @@ -2439,11 +2620,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Vec = II->getArgOperand(0); if (match(Vec, m_OneUse(m_BinOp(m_Value(BO0), m_Value(BO1))))) { auto *OldBinOp = cast<BinaryOperator>(Vec); - if (match(BO0, m_Intrinsic<Intrinsic::experimental_vector_reverse>( - m_Value(X)))) { + if (match(BO0, m_VecReverse(m_Value(X)))) { // rev(binop rev(X), rev(Y)) --> binop X, Y - if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( - m_Value(Y)))) + if (match(BO1, m_VecReverse(m_Value(Y)))) return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( OldBinOp->getOpcode(), X, Y, OldBinOp, @@ -2456,17 +2635,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { OldBinOp, OldBinOp->getName(), II)); } // rev(binop BO0Splat, rev(Y)) --> binop BO0Splat, Y - if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( - m_Value(Y))) && - isSplatValue(BO0)) + if (match(BO1, m_VecReverse(m_Value(Y))) && isSplatValue(BO0)) return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( OldBinOp->getOpcode(), BO0, Y, OldBinOp, OldBinOp->getName(), II)); } // rev(unop rev(X)) --> unop X - if (match(Vec, m_OneUse(m_UnOp( - m_Intrinsic<Intrinsic::experimental_vector_reverse>( - m_Value(X)))))) { + if (match(Vec, m_OneUse(m_UnOp(m_VecReverse(m_Value(X)))))) { auto *OldUnOp = cast<UnaryOperator>(Vec); auto *NewUnOp = UnaryOperator::CreateWithCopiedFlags( OldUnOp->getOpcode(), X, OldUnOp, OldUnOp->getName(), II); @@ -2504,7 +2679,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return replaceInstUsesWith(CI, Res); } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_add: { if (IID == Intrinsic::vector_reduce_add) { @@ -2531,7 +2706,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_xor: { if (IID == Intrinsic::vector_reduce_xor) { @@ -2555,7 +2730,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_mul: { if (IID == Intrinsic::vector_reduce_mul) { @@ -2577,7 +2752,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_umin: case Intrinsic::vector_reduce_umax: { @@ -2604,7 +2779,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_smin: case Intrinsic::vector_reduce_smax: { @@ -2642,7 +2817,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_fmax: case Intrinsic::vector_reduce_fmin: @@ -2679,9 +2854,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } default: { // Handle target specific intrinsics - Optional<Instruction *> V = targetInstCombineIntrinsic(*II); + std::optional<Instruction *> V = targetInstCombineIntrinsic(*II); if (V) - return V.value(); + return *V; break; } } @@ -2887,7 +3062,7 @@ bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, if 
(!Call.getType()->isPointerTy()) return Changed; - Optional<APInt> Size = getAllocSize(&Call, TLI); + std::optional<APInt> Size = getAllocSize(&Call, TLI); if (Size && *Size != 0) { // TODO: We really should just emit deref_or_null here and then // let the generic inference code combine that with nonnull. @@ -3078,6 +3253,30 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy)); } + // Drop unnecessary kcfi operand bundles from calls that were converted + // into direct calls. + auto Bundle = Call.getOperandBundle(LLVMContext::OB_kcfi); + if (Bundle && !Call.isIndirectCall()) { + DEBUG_WITH_TYPE(DEBUG_TYPE "-kcfi", { + if (CalleeF) { + ConstantInt *FunctionType = nullptr; + ConstantInt *ExpectedType = cast<ConstantInt>(Bundle->Inputs[0]); + + if (MDNode *MD = CalleeF->getMetadata(LLVMContext::MD_kcfi_type)) + FunctionType = mdconst::extract<ConstantInt>(MD->getOperand(0)); + + if (FunctionType && + FunctionType->getZExtValue() != ExpectedType->getZExtValue()) + dbgs() << Call.getModule()->getName() + << ": warning: kcfi: " << Call.getCaller()->getName() + << ": call to " << CalleeF->getName() + << " using a mismatching function pointer type\n"; + } + }); + + return CallBase::removeOperandBundle(&Call, LLVMContext::OB_kcfi); + } + if (isRemovableAlloc(&Call, &TLI)) return visitAllocSite(Call); @@ -3140,7 +3339,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { LiveGcValues.insert(BasePtr); LiveGcValues.insert(DerivedPtr); } - Optional<OperandBundleUse> Bundle = + std::optional<OperandBundleUse> Bundle = GCSP.getOperandBundle(LLVMContext::OB_gc_live); unsigned NumOfGCLives = LiveGcValues.size(); if (!Bundle || NumOfGCLives == Bundle->Inputs.size()) @@ -3148,8 +3347,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { // We can reduce the size of gc live bundle. DenseMap<Value *, unsigned> Val2Idx; std::vector<Value *> NewLiveGc; - for (unsigned I = 0, E = Bundle->Inputs.size(); I < E; ++I) { - Value *V = Bundle->Inputs[I]; + for (Value *V : Bundle->Inputs) { if (Val2Idx.count(V)) continue; if (LiveGcValues.count(V)) { @@ -3289,6 +3487,10 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (CallerPAL.hasParamAttr(i, Attribute::SwiftError)) return false; + if (CallerPAL.hasParamAttr(i, Attribute::ByVal) != + Callee->getAttributes().hasParamAttr(i, Attribute::ByVal)) + return false; // Cannot transform to or from byval. + // If the parameter is passed as a byval argument, then we have to have a // sized type and the sized type has to have the same size as the old type. if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) { @@ -3447,21 +3649,12 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy); NC->setDebugLoc(Caller->getDebugLoc()); - // If this is an invoke/callbr instruction, we should insert it after the - // first non-phi instruction in the normal successor block. - if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { - BasicBlock::iterator I = II->getNormalDest()->getFirstInsertionPt(); - InsertNewInstBefore(NC, *I); - } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) { - BasicBlock::iterator I = CBI->getDefaultDest()->getFirstInsertionPt(); - InsertNewInstBefore(NC, *I); - } else { - // Otherwise, it's a call, just insert cast right after the call. 
- InsertNewInstBefore(NC, *Caller); - } + Instruction *InsertPt = NewCall->getInsertionPointAfterDef(); + assert(InsertPt && "No place to insert cast"); + InsertNewInstBefore(NC, *InsertPt); Worklist.pushUsersToWorkList(*Caller); } else { - NV = UndefValue::get(Caller->getType()); + NV = PoisonValue::get(Caller->getType()); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index a9a930555b3c..3f851a2b2182 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -14,9 +14,12 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <optional> + using namespace llvm; using namespace PatternMatch; @@ -118,14 +121,15 @@ Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI, if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr; // The alloc and cast types should be either both fixed or both scalable. - uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinSize(); - uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinSize(); + uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinValue(); + uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinValue(); if (CastElTySize == 0 || AllocElTySize == 0) return nullptr; // If the allocation has multiple uses, only promote it if we're not // shrinking the amount of memory being allocated. - uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy).getKnownMinSize(); - uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinSize(); + uint64_t AllocElTyStoreSize = + DL.getTypeStoreSize(AllocElTy).getKnownMinValue(); + uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinValue(); if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr; // See if we can satisfy the modulus by pulling a scale out of the array @@ -163,6 +167,10 @@ Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI, New->setAlignment(AI.getAlign()); New->takeName(&AI); New->setUsedWithInAlloca(AI.isUsedWithInAlloca()); + New->setMetadata(LLVMContext::MD_DIAssignID, + AI.getMetadata(LLVMContext::MD_DIAssignID)); + + replaceAllDbgUsesWith(AI, *New, *New, DT); // If the allocation has multiple real uses, insert a cast and change all // things that used it to use the new cast. This will also hack on CI, but it @@ -239,6 +247,11 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, Res = NPN; break; } + case Instruction::FPToUI: + case Instruction::FPToSI: + Res = CastInst::Create( + static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty); + break; default: // TODO: Can handle more cases here. llvm_unreachable("Unreachable!"); @@ -483,6 +496,22 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, return false; return true; } + case Instruction::FPToUI: + case Instruction::FPToSI: { + // If the integer type can hold the max FP value, it is safe to cast + // directly to that type. Otherwise, we may create poison via overflow + // that did not exist in the original code. + // + // The max FP value is pow(2, MaxExponent) * (1 + MaxFraction), so we need + // at least one more bit than the MaxExponent to hold the max FP value. 
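The bit-width bound in the comment above can be spot-checked for IEEE single precision: every finite float is strictly below 2^FLT_MAX_EXP, so 128 unsigned bits (one more once a sign bit is needed) always hold the converted value. A standalone check, not part of the patch:

#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
  // FLT_MAX_EXP is 128 for IEEE single precision, so any float-to-integer
  // conversion result has magnitude below 2^128 and fits in a 128-bit
  // unsigned integer (129 bits when a sign bit is required).
  double Bound = std::ldexp(1.0, FLT_MAX_EXP); // 2^128, exact as a double
  std::printf("FLT_MAX = %g < 2^%d = %g : %s\n", (double)FLT_MAX, FLT_MAX_EXP,
              Bound, FLT_MAX < Bound ? "yes" : "no");
  return 0;
}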
+ Type *InputTy = I->getOperand(0)->getType()->getScalarType(); + const fltSemantics &Semantics = InputTy->getFltSemantics(); + uint32_t MinBitWidth = APFloatBase::semanticsMaxExponent(Semantics); + // Extra sign bit needed. + if (I->getOpcode() == Instruction::FPToSI) + ++MinBitWidth; + return Ty->getScalarSizeInBits() > MinBitWidth; + } default: // TODO: Can handle more cases here. break; @@ -726,7 +755,7 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc, InstCombiner::BuilderTy &Builder) { auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0)); if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) && - is_splat(Shuf->getShuffleMask()) && + all_equal(Shuf->getShuffleMask()) && Shuf->getType() == Shuf->getOperand(0)->getType()) { // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask @@ -974,7 +1003,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange); - if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { if (Log2_32(*MaxVScale) < DestWidth) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); return replaceInstUsesWith(Trunc, VScale); @@ -986,7 +1015,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { return nullptr; } -Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) { +Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, + ZExtInst &Zext) { // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. @@ -1014,28 +1044,20 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - // zext (X == 1) to i32 --> X iff X has only the low bit set. - // zext (X == 2) to i32 --> X>>1 iff X has only the 2nd bit set. // zext (X != 0) to i32 --> X iff X has only the low bit set. // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. - // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. - // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV->isZero() || Op1CV->isPowerOf2()) && - // This only works for EQ and NE - Cmp->isEquality()) { + if (Op1CV->isZero() && Cmp->isEquality() && + (Cmp->getOperand(0)->getType() == Zext.getType() || + Cmp->getPredicate() == ICmpInst::ICMP_NE)) { // If Op1C some other power of two, convert: KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext); + // Exactly 1 possible 1? But not the high-bit because that is + // canonicalized to this form. APInt KnownZeroMask(~Known.Zero); - if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? 
- bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE; - if (!Op1CV->isZero() && (*Op1CV != KnownZeroMask)) { - // (X&4) == 2 --> false - // (X&4) != 2 --> true - Constant *Res = ConstantInt::get(Zext.getType(), isNE); - return replaceInstUsesWith(Zext, Res); - } - + if (KnownZeroMask.isPowerOf2() && + (Zext.getType()->getScalarSizeInBits() != + KnownZeroMask.logBase2() + 1)) { uint32_t ShAmt = KnownZeroMask.logBase2(); Value *In = Cmp->getOperand(0); if (ShAmt) { @@ -1045,10 +1067,9 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) In->getName() + ".lobit"); } - if (!Op1CV->isZero() == isNE) { // Toggle the low bit. - Constant *One = ConstantInt::get(In->getType(), 1); - In = Builder.CreateXor(In, One); - } + // Toggle the low bit for "X == 0". + if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) + In = Builder.CreateXor(In, ConstantInt::get(In->getType(), 1)); if (Zext.getType() == In->getType()) return replaceInstUsesWith(Zext, In); @@ -1073,39 +1094,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) Value *And1 = Builder.CreateAnd(Lshr, ConstantInt::get(X->getType(), 1)); return replaceInstUsesWith(Zext, And1); } - - // icmp ne A, B is equal to xor A, B when A and B only really have one bit. - // It is also profitable to transform icmp eq into not(xor(A, B)) because - // that may lead to additional simplifications. - if (IntegerType *ITy = dyn_cast<IntegerType>(Zext.getType())) { - Value *LHS = Cmp->getOperand(0); - Value *RHS = Cmp->getOperand(1); - - KnownBits KnownLHS = computeKnownBits(LHS, 0, &Zext); - KnownBits KnownRHS = computeKnownBits(RHS, 0, &Zext); - - if (KnownLHS == KnownRHS) { - APInt KnownBits = KnownLHS.Zero | KnownLHS.One; - APInt UnknownBit = ~KnownBits; - if (UnknownBit.countPopulation() == 1) { - Value *Result = Builder.CreateXor(LHS, RHS); - - // Mask off any bits that are set and won't be shifted away. - if (KnownLHS.One.uge(UnknownBit)) - Result = Builder.CreateAnd(Result, - ConstantInt::get(ITy, UnknownBit)); - - // Shift the bit we're testing down to the lsb. - Result = Builder.CreateLShr( - Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros())); - - if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) - Result = Builder.CreateXor(Result, ConstantInt::get(ITy, 1)); - Result->takeName(Cmp); - return replaceInstUsesWith(Zext, Result); - } - } - } } return nullptr; @@ -1235,23 +1223,23 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, } } -Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { +Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { // If this zero extend is only used by a truncate, let the truncate be // eliminated before we try to optimize this zext. - if (CI.hasOneUse() && isa<TruncInst>(CI.user_back())) + if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back())) return nullptr; // If one of the common conversion will work, do it. - if (Instruction *Result = commonCastTransforms(CI)) + if (Instruction *Result = commonCastTransforms(Zext)) return Result; - Value *Src = CI.getOperand(0); - Type *SrcTy = Src->getType(), *DestTy = CI.getType(); + Value *Src = Zext.getOperand(0); + Type *SrcTy = Src->getType(), *DestTy = Zext.getType(); // Try to extend the entire expression tree to the wide destination type. 
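The single-set-bit reasoning above is easy to verify exhaustively: when X is either 0 or 1 << k, the compare-with-zero plus zext reduces to a shift, with an extra xor for the equality form. A standalone sketch (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // With X restricted to {0, 1 << k}:
  //   zext(X != 0) == X >> k   and   zext(X == 0) == (X >> k) ^ 1
  // The patch additionally restricts which bit positions it rewrites; the
  // arithmetic identity itself holds for every k.
  for (unsigned k = 0; k < 8; ++k)
    for (uint8_t X : {uint8_t(0), uint8_t(1u << k)}) {
      uint8_t Ne = (X != 0) ? 1 : 0;
      uint8_t Eq = (X == 0) ? 1 : 0;
      assert(Ne == (X >> k));
      assert(Eq == ((X >> k) ^ 1));
    }
  return 0;
}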
unsigned BitsToClear; if (shouldChangeType(SrcTy, DestTy) && - canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { + canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &Zext)) { assert(BitsToClear <= SrcTy->getScalarSizeInBits() && "Can't clear more bits than in SrcTy"); @@ -1259,25 +1247,25 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" " to avoid zero extend: " - << CI << '\n'); + << Zext << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); // Preserve debug values referring to Src if the zext is its last use. if (auto *SrcOp = dyn_cast<Instruction>(Src)) if (SrcOp->hasOneUse()) - replaceAllDbgUsesWith(*SrcOp, *Res, CI, DT); + replaceAllDbgUsesWith(*SrcOp, *Res, Zext, DT); - uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits()-BitsToClear; + uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits() - BitsToClear; uint32_t DestBitSize = DestTy->getScalarSizeInBits(); // If the high bits are already filled with zeros, just replace this // cast with the result. if (MaskedValueIsZero(Res, APInt::getHighBitsSet(DestBitSize, - DestBitSize-SrcBitsKept), - 0, &CI)) - return replaceInstUsesWith(CI, Res); + DestBitSize - SrcBitsKept), + 0, &Zext)) + return replaceInstUsesWith(Zext, Res); // We need to emit an AND to clear the high bits. Constant *C = ConstantInt::get(Res->getType(), @@ -1288,7 +1276,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { // If this is a TRUNC followed by a ZEXT then we are dealing with integral // types and if the sizes are just right we can convert this into a logical // 'and' which will be much cheaper than the pair of casts. - if (TruncInst *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast + if (auto *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast // TODO: Subsume this into EvaluateInDifferentType. // Get the sizes of the types involved. We know that the intermediate type @@ -1296,7 +1284,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { Value *A = CSrc->getOperand(0); unsigned SrcSize = A->getType()->getScalarSizeInBits(); unsigned MidSize = CSrc->getType()->getScalarSizeInBits(); - unsigned DstSize = CI.getType()->getScalarSizeInBits(); + unsigned DstSize = DestTy->getScalarSizeInBits(); // If we're actually extending zero bits, then if // SrcSize < DstSize: zext(a & mask) // SrcSize == DstSize: a & mask @@ -1305,7 +1293,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize)); Constant *AndConst = ConstantInt::get(A->getType(), AndValue); Value *And = Builder.CreateAnd(A, AndConst, CSrc->getName() + ".mask"); - return new ZExtInst(And, CI.getType()); + return new ZExtInst(And, DestTy); } if (SrcSize == DstSize) { @@ -1314,7 +1302,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { AndValue)); } if (SrcSize > DstSize) { - Value *Trunc = Builder.CreateTrunc(A, CI.getType()); + Value *Trunc = Builder.CreateTrunc(A, DestTy); APInt AndValue(APInt::getLowBitsSet(DstSize, MidSize)); return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Trunc->getType(), @@ -1322,34 +1310,46 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { } } - if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src)) - return transformZExtICmp(Cmp, CI); + if (auto *Cmp = dyn_cast<ICmpInst>(Src)) + return transformZExtICmp(Cmp, Zext); // zext(trunc(X) & C) -> (X & zext(C)). 
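That mask identity holds because the constant mask already clears every bit the truncate would drop. A standalone exhaustive check for the i16/i8 case (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // zext16(trunc8(X) & C) == X & zext16(C) for every 16-bit X and 8-bit C.
  for (uint32_t X = 0; X <= 0xFFFF; ++X)
    for (uint32_t C = 0; C <= 0xFF; ++C) {
      uint16_t Narrow = uint16_t(uint8_t(uint8_t(X) & C)); // trunc, and, zext
      uint16_t Wide = uint16_t(X & C);                     // and in the wide type
      assert(Narrow == Wide);
    }
  return 0;
}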
Constant *C; Value *X; if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && - X->getType() == CI.getType()) - return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, CI.getType())); + X->getType() == DestTy) + return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, DestTy)); // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)). Value *And; if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) && - X->getType() == CI.getType()) { - Constant *ZC = ConstantExpr::getZExt(C, CI.getType()); + X->getType() == DestTy) { + Constant *ZC = ConstantExpr::getZExt(C, DestTy); return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC); } + // If we are truncating, masking, and then zexting back to the original type, + // that's just a mask. This is not handled by canEvaluateZextd if the + // intermediate values have extra uses. This could be generalized further for + // a non-constant mask operand. + // zext (and (trunc X), C) --> and X, (zext C) + if (match(Src, m_And(m_Trunc(m_Value(X)), m_Constant(C))) && + X->getType() == DestTy) { + Constant *ZextC = ConstantExpr::getZExt(C, DestTy); + return BinaryOperator::CreateAnd(X, ZextC); + } + if (match(Src, m_VScale(DL))) { - if (CI.getFunction() && - CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); - if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + if (Zext.getFunction() && + Zext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + Attribute Attr = + Zext.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); if (Log2_32(*MaxVScale) < TypeWidth) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(CI, VScale); + return replaceInstUsesWith(Zext, VScale); } } } @@ -1359,48 +1359,44 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { } /// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp. -Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI, - Instruction &CI) { - Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1); - ICmpInst::Predicate Pred = ICI->getPredicate(); +Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp, + SExtInst &Sext) { + Value *Op0 = Cmp->getOperand(0), *Op1 = Cmp->getOperand(1); + ICmpInst::Predicate Pred = Cmp->getPredicate(); // Don't bother if Op1 isn't of vector or integer type. if (!Op1->getType()->isIntOrIntVectorTy()) return nullptr; - if ((Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) || - (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))) { - // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if negative - // (x >s -1) ? 
-1 : 0 -> not (ashr x, 31) -> all ones if positive + if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) { + // sext (x <s 0) --> ashr x, 31 (all ones if negative) Value *Sh = ConstantInt::get(Op0->getType(), Op0->getType()->getScalarSizeInBits() - 1); Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit"); - if (In->getType() != CI.getType()) - In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/); + if (In->getType() != Sext.getType()) + In = Builder.CreateIntCast(In, Sext.getType(), true /*SExt*/); - if (Pred == ICmpInst::ICMP_SGT) - In = Builder.CreateNot(In, In->getName() + ".not"); - return replaceInstUsesWith(CI, In); + return replaceInstUsesWith(Sext, In); } if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { // If we know that only one bit of the LHS of the icmp can be set and we // have an equality comparison with zero or a power of 2, we can transform // the icmp and sext into bitwise/integer operations. - if (ICI->hasOneUse() && - ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ - KnownBits Known = computeKnownBits(Op0, 0, &CI); + if (Cmp->hasOneUse() && + Cmp->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ + KnownBits Known = computeKnownBits(Op0, 0, &Sext); APInt KnownZeroMask(~Known.Zero); if (KnownZeroMask.isPowerOf2()) { - Value *In = ICI->getOperand(0); + Value *In = Cmp->getOperand(0); // If the icmp tests for a known zero bit we can constant fold it. if (!Op1C->isZero() && Op1C->getValue() != KnownZeroMask) { Value *V = Pred == ICmpInst::ICMP_NE ? - ConstantInt::getAllOnesValue(CI.getType()) : - ConstantInt::getNullValue(CI.getType()); - return replaceInstUsesWith(CI, V); + ConstantInt::getAllOnesValue(Sext.getType()) : + ConstantInt::getNullValue(Sext.getType()); + return replaceInstUsesWith(Sext, V); } if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) { @@ -1431,9 +1427,9 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI, KnownZeroMask.getBitWidth() - 1), "sext"); } - if (CI.getType() == In->getType()) - return replaceInstUsesWith(CI, In); - return CastInst::CreateIntegerCast(In, CI.getType(), true/*SExt*/); + if (Sext.getType() == In->getType()) + return replaceInstUsesWith(Sext, In); + return CastInst::CreateIntegerCast(In, Sext.getType(), true/*SExt*/); } } } @@ -1496,22 +1492,22 @@ static bool canEvaluateSExtd(Value *V, Type *Ty) { return false; } -Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { +Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { // If this sign extend is only used by a truncate, let the truncate be // eliminated before we try to optimize this sext. - if (CI.hasOneUse() && isa<TruncInst>(CI.user_back())) + if (Sext.hasOneUse() && isa<TruncInst>(Sext.user_back())) return nullptr; - if (Instruction *I = commonCastTransforms(CI)) + if (Instruction *I = commonCastTransforms(Sext)) return I; - Value *Src = CI.getOperand(0); - Type *SrcTy = Src->getType(), *DestTy = CI.getType(); + Value *Src = Sext.getOperand(0); + Type *SrcTy = Src->getType(), *DestTy = Sext.getType(); unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); unsigned DestBitSize = DestTy->getScalarSizeInBits(); // If the value being extended is zero or positive, use a zext instead. - if (isKnownNonNegative(Src, DL, 0, &AC, &CI, &DT)) + if (isKnownNonNegative(Src, DL, 0, &AC, &Sext, &DT)) return CastInst::Create(Instruction::ZExt, Src, DestTy); // Try to extend the entire expression tree to the wide destination type. 
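The sext-of-sign-test rewrite above is the usual arithmetic-shift trick; a standalone exhaustive check over 16-bit values (signed right shift is guaranteed to be arithmetic from C++20 onward, and is so on all common targets before that):

#include <cassert>
#include <cstdint>

int main() {
  // sext(x <s 0) == x ashr (bitwidth - 1): all ones when x is negative,
  // zero otherwise.
  for (int32_t v = INT16_MIN; v <= INT16_MAX; ++v) {
    int16_t x = (int16_t)v;
    int16_t ext = (x < 0) ? -1 : 0;
    assert(ext == (int16_t)(x >> 15));
  }
  return 0;
}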
@@ -1520,14 +1516,14 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" " to avoid sign extend: " - << CI << '\n'); + << Sext << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, true); assert(Res->getType() == DestTy); // If the high bits are already filled with sign bit, just replace this // cast with the result. - if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize) - return replaceInstUsesWith(CI, Res); + if (ComputeNumSignBits(Res, 0, &Sext) > DestBitSize - SrcBitSize) + return replaceInstUsesWith(Sext, Res); // We need to emit a shl + ashr to do the sign extend. Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); @@ -1540,7 +1536,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { // If the input has more sign bits than bits truncated, then convert // directly to final type. unsigned XBitSize = X->getType()->getScalarSizeInBits(); - if (ComputeNumSignBits(X, 0, &CI) > XBitSize - SrcBitSize) + if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize) return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true); // If input is a trunc from the destination type, then convert into shifts. @@ -1563,8 +1559,8 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { } } - if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) - return transformSExtICmp(ICI, CI); + if (auto *Cmp = dyn_cast<ICmpInst>(Src)) + return transformSExtICmp(Cmp, Sext); // If the input is a shl/ashr pair of a same constant, then this is a sign // extension from a smaller value. If we could trust arbitrary bitwidth @@ -1593,7 +1589,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { NumLowbitsLeft); NewShAmt = Constant::mergeUndefsWith(Constant::mergeUndefsWith(NewShAmt, BA), CA); - A = Builder.CreateShl(A, NewShAmt, CI.getName()); + A = Builder.CreateShl(A, NewShAmt, Sext.getName()); return BinaryOperator::CreateAShr(A, NewShAmt); } @@ -1616,13 +1612,14 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { } if (match(Src, m_VScale(DL))) { - if (CI.getFunction() && - CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); - if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + if (Sext.getFunction() && + Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + Attribute Attr = + Sext.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(CI, VScale); + return replaceInstUsesWith(Sext, VScale); } } } @@ -1659,7 +1656,6 @@ static Type *shrinkFPConstant(ConstantFP *CFP) { // Determine if this is a vector of ConstantFPs and if so, return the minimal // type we can safely truncate all elements to. -// TODO: Make these support undef elements. static Type *shrinkFPConstantVector(Value *V) { auto *CV = dyn_cast<Constant>(V); auto *CVVTy = dyn_cast<FixedVectorType>(V->getType()); @@ -1673,6 +1669,9 @@ static Type *shrinkFPConstantVector(Value *V) { // For fixed-width vectors we find the minimal type by looking // through the constant values of the vector. 
for (unsigned i = 0; i != NumElts; ++i) { + if (isa<UndefValue>(CV->getAggregateElement(i))) + continue; + auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); if (!CFP) return nullptr; @@ -1688,7 +1687,7 @@ static Type *shrinkFPConstantVector(Value *V) { } // Make a vector type from the minimal type. - return FixedVectorType::get(MinType, NumElts); + return MinType ? FixedVectorType::get(MinType, NumElts) : nullptr; } /// Find the minimum FP type we can safely truncate to. @@ -2862,21 +2861,27 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { } } - // A bitcasted-to-scalar and byte-reversing shuffle is better recognized as - // a byte-swap: - // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) --> bswap (bitcast X) - // TODO: We should match the related pattern for bitreverse. - if (DestTy->isIntegerTy() && - DL.isLegalInteger(DestTy->getScalarSizeInBits()) && - SrcTy->getScalarSizeInBits() == 8 && - ShufElts.getKnownMinValue() % 2 == 0 && Shuf->hasOneUse() && - Shuf->isReverse()) { - assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask"); - assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op"); - Function *Bswap = - Intrinsic::getDeclaration(CI.getModule(), Intrinsic::bswap, DestTy); - Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy); - return CallInst::Create(Bswap, { ScalarX }); + // A bitcasted-to-scalar and byte/bit reversing shuffle is better recognized + // as a byte/bit swap: + // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) -> bswap (bitcast X) + // bitcast <N x i1> (shuf X, undef, <N, N-1,...0>) -> bitreverse (bitcast X) + if (DestTy->isIntegerTy() && ShufElts.getKnownMinValue() % 2 == 0 && + Shuf->hasOneUse() && Shuf->isReverse()) { + unsigned IntrinsicNum = 0; + if (DL.isLegalInteger(DestTy->getScalarSizeInBits()) && + SrcTy->getScalarSizeInBits() == 8) { + IntrinsicNum = Intrinsic::bswap; + } else if (SrcTy->getScalarSizeInBits() == 1) { + IntrinsicNum = Intrinsic::bitreverse; + } + if (IntrinsicNum != 0) { + assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask"); + assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op"); + Function *BswapOrBitreverse = + Intrinsic::getDeclaration(CI.getModule(), IntrinsicNum, DestTy); + Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy); + return CallInst::Create(BswapOrBitreverse, {ScalarX}); + } } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 158d2e8289e0..1480a0ff9e2f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" @@ -281,7 +282,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( if (!GEP->isInBounds()) { Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); - if (Idx->getType()->getPrimitiveSizeInBits().getFixedSize() > PtrSize) + if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > PtrSize) Idx = Builder.CreateTrunc(Idx, IntPtrTy); } @@ -403,108 +404,6 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( return nullptr; } -/// Return a value that can be used to compare the *offset* implied by a GEP to -/// 
zero. For example, if we have &A[i], we want to return 'i' for -/// "icmp ne i, 0". Note that, in general, indices can be complex, and scales -/// are involved. The above expression would also be legal to codegen as -/// "icmp ne (i*4), 0" (assuming A is a pointer to i32). -/// This latter form is less amenable to optimization though, and we are allowed -/// to generate the first by knowing that pointer arithmetic doesn't overflow. -/// -/// If we can't emit an optimized form for this expression, this returns null. -/// -static Value *evaluateGEPOffsetExpression(User *GEP, InstCombinerImpl &IC, - const DataLayout &DL) { - gep_type_iterator GTI = gep_type_begin(GEP); - - // Check to see if this gep only has a single variable index. If so, and if - // any constant indices are a multiple of its scale, then we can compute this - // in terms of the scale of the variable index. For example, if the GEP - // implies an offset of "12 + i*4", then we can codegen this as "3 + i", - // because the expression will cross zero at the same point. - unsigned i, e = GEP->getNumOperands(); - int64_t Offset = 0; - for (i = 1; i != e; ++i, ++GTI) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) { - // Compute the aggregate offset of constant indices. - if (CI->isZero()) continue; - - // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = GTI.getStructTypeOrNull()) { - Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); - } else { - uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); - Offset += Size*CI->getSExtValue(); - } - } else { - // Found our variable index. - break; - } - } - - // If there are no variable indices, we must have a constant offset, just - // evaluate it the general way. - if (i == e) return nullptr; - - Value *VariableIdx = GEP->getOperand(i); - // Determine the scale factor of the variable element. For example, this is - // 4 if the variable index is into an array of i32. - uint64_t VariableScale = DL.getTypeAllocSize(GTI.getIndexedType()); - - // Verify that there are no other variable indices. If so, emit the hard way. - for (++i, ++GTI; i != e; ++i, ++GTI) { - ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i)); - if (!CI) return nullptr; - - // Compute the aggregate offset of constant indices. - if (CI->isZero()) continue; - - // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = GTI.getStructTypeOrNull()) { - Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); - } else { - uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); - Offset += Size*CI->getSExtValue(); - } - } - - // Okay, we know we have a single variable index, which must be a - // pointer/array/vector index. If there is no offset, life is simple, return - // the index. - Type *IntPtrTy = DL.getIntPtrType(GEP->getOperand(0)->getType()); - unsigned IntPtrWidth = IntPtrTy->getIntegerBitWidth(); - if (Offset == 0) { - // Cast to intptrty in case a truncation occurs. If an extension is needed, - // we don't need to bother extending: the extension won't affect where the - // computation crosses zero. - if (VariableIdx->getType()->getPrimitiveSizeInBits().getFixedSize() > - IntPtrWidth) { - VariableIdx = IC.Builder.CreateTrunc(VariableIdx, IntPtrTy); - } - return VariableIdx; - } - - // Otherwise, there is an index. The computation we will do will be modulo - // the pointer size. 
- Offset = SignExtend64(Offset, IntPtrWidth); - VariableScale = SignExtend64(VariableScale, IntPtrWidth); - - // To do this transformation, any constant index must be a multiple of the - // variable scale factor. For example, we can evaluate "12 + 4*i" as "3 + i", - // but we can't evaluate "10 + 3*i" in terms of i. Check that the offset is a - // multiple of the variable scale. - int64_t NewOffs = Offset / (int64_t)VariableScale; - if (Offset != NewOffs*(int64_t)VariableScale) - return nullptr; - - // Okay, we can do this evaluation. Start by converting the index to intptr. - if (VariableIdx->getType() != IntPtrTy) - VariableIdx = IC.Builder.CreateIntCast(VariableIdx, IntPtrTy, - true /*Signed*/); - Constant *OffsetVal = ConstantInt::get(IntPtrTy, NewOffs); - return IC.Builder.CreateAdd(VariableIdx, OffsetVal, "offset"); -} - /// Returns true if we can rewrite Start as a GEP with pointer Base /// and some integer offset. The nodes that need to be re-written /// for this transformation will be added to Explored. @@ -732,8 +631,8 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, // Cast base to the expected type. Value *NewVal = Builder.CreateBitOrPointerCast( Base, PtrTy, Start->getName() + "to.ptr"); - NewVal = Builder.CreateInBoundsGEP( - ElemTy, NewVal, makeArrayRef(NewInsts[Val]), Val->getName() + ".ptr"); + NewVal = Builder.CreateInBoundsGEP(ElemTy, NewVal, ArrayRef(NewInsts[Val]), + Val->getName() + ".ptr"); NewVal = Builder.CreateBitOrPointerCast( NewVal, Val->getType(), Val->getName() + ".conv"); Val->replaceAllUsesWith(NewVal); @@ -841,18 +740,9 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, RHS = RHS->stripPointerCasts(); Value *PtrBase = GEPLHS->getOperand(0); - // FIXME: Support vector pointer GEPs. - if (PtrBase == RHS && GEPLHS->isInBounds() && - !GEPLHS->getType()->isVectorTy()) { + if (PtrBase == RHS && GEPLHS->isInBounds()) { // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). - // This transformation (ignoring the base and scales) is valid because we - // know pointers can't overflow since the gep is inbounds. See if we can - // output an optimized form. - Value *Offset = evaluateGEPOffsetExpression(GEPLHS, *this, DL); - - // If not, synthesize the offset the hard way. - if (!Offset) - Offset = EmitGEPOffset(GEPLHS); + Value *Offset = EmitGEPOffset(GEPLHS); return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, Constant::getNullValue(Offset->getType())); } @@ -926,8 +816,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Type *LHSIndexTy = LOffset->getType(); Type *RHSIndexTy = ROffset->getType(); if (LHSIndexTy != RHSIndexTy) { - if (LHSIndexTy->getPrimitiveSizeInBits().getFixedSize() < - RHSIndexTy->getPrimitiveSizeInBits().getFixedSize()) { + if (LHSIndexTy->getPrimitiveSizeInBits().getFixedValue() < + RHSIndexTy->getPrimitiveSizeInBits().getFixedValue()) { ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy); } else LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy); @@ -1480,7 +1370,8 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { return nullptr; // Try to simplify this compare to T/F based on the dominating condition. 
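The inbounds GEP-versus-base fold above mirrors ordinary in-bounds pointer arithmetic: comparing a pointer advanced by an in-bounds offset against its base is the same as comparing the offset against zero. A standalone C++ analogy (illustration only, not part of the patch):

#include <cassert>
#include <cstddef>

int main() {
  // For offsets that stay within the array (including one past the end),
  // the pointer comparison and the signed offset comparison agree.
  int A[16] = {};
  for (ptrdiff_t I = 0; I <= 16; ++I) {
    assert((A + I > A) == (I > 0));
    assert((A + I == A) == (I == 0));
  }
  return 0;
}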
- Optional<bool> Imp = isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB); + std::optional<bool> Imp = + isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB); if (Imp) return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp)); @@ -1548,16 +1439,34 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, ConstantInt::get(V->getType(), 1)); } + Type *SrcTy = X->getType(); unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), - SrcBits = X->getType()->getScalarSizeInBits(); + SrcBits = SrcTy->getScalarSizeInBits(); + + // TODO: Handle any shifted constant by subtracting trailing zeros. + // TODO: Handle non-equality predicates. + Value *Y; + if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) { + // (trunc (1 << Y) to iN) == 0 --> Y u>= N + // (trunc (1 << Y) to iN) != 0 --> Y u< N + if (C.isZero()) { + auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT; + return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits)); + } + // (trunc (1 << Y) to iN) == 2**C --> Y == C + // (trunc (1 << Y) to iN) != 2**C --> Y != C + if (C.isPowerOf2()) + return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2())); + } + if (Cmp.isEquality() && Trunc->hasOneUse()) { // Canonicalize to a mask and wider compare if the wide type is suitable: // (trunc X to i8) == C --> (X & 0xff) == (zext C) - if (!X->getType()->isVectorTy() && shouldChangeType(DstBits, SrcBits)) { - Constant *Mask = ConstantInt::get(X->getType(), - APInt::getLowBitsSet(SrcBits, DstBits)); + if (!SrcTy->isVectorTy() && shouldChangeType(DstBits, SrcBits)) { + Constant *Mask = + ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcBits, DstBits)); Value *And = Builder.CreateAnd(X, Mask); - Constant *WideC = ConstantInt::get(X->getType(), C.zext(SrcBits)); + Constant *WideC = ConstantInt::get(SrcTy, C.zext(SrcBits)); return new ICmpInst(Pred, And, WideC); } @@ -1570,7 +1479,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, // Pull in the high bits from known-ones set. APInt NewRHS = C.zext(SrcBits); NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); + return new ICmpInst(Pred, X, ConstantInt::get(SrcTy, NewRHS)); } } @@ -1583,11 +1492,10 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, if (isSignBitCheck(Pred, C, TrueIfSigned) && match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) && DstBits == SrcBits - ShAmtC->getZExtValue()) { - return TrueIfSigned - ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp, - ConstantInt::getNullValue(X->getType())) - : new ICmpInst(ICmpInst::ICMP_SGT, ShOp, - ConstantInt::getAllOnesValue(X->getType())); + return TrueIfSigned ? 
new ICmpInst(ICmpInst::ICMP_SLT, ShOp, + ConstantInt::getNullValue(SrcTy)) + : new ICmpInst(ICmpInst::ICMP_SGT, ShOp, + ConstantInt::getAllOnesValue(SrcTy)); } return nullptr; @@ -1597,6 +1505,9 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, const APInt &C) { + if (Instruction *I = foldICmpXorShiftConst(Cmp, Xor, C)) + return I; + Value *X = Xor->getOperand(0); Value *Y = Xor->getOperand(1); const APInt *XorC; @@ -1660,6 +1571,37 @@ Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp, return nullptr; } +/// For power-of-2 C: +/// ((X s>> ShiftC) ^ X) u< C --> (X + C) u< (C << 1) +/// ((X s>> ShiftC) ^ X) u> (C - 1) --> (X + C) u> ((C << 1) - 1) +Instruction *InstCombinerImpl::foldICmpXorShiftConst(ICmpInst &Cmp, + BinaryOperator *Xor, + const APInt &C) { + CmpInst::Predicate Pred = Cmp.getPredicate(); + APInt PowerOf2; + if (Pred == ICmpInst::ICMP_ULT) + PowerOf2 = C; + else if (Pred == ICmpInst::ICMP_UGT && !C.isMaxValue()) + PowerOf2 = C + 1; + else + return nullptr; + if (!PowerOf2.isPowerOf2()) + return nullptr; + Value *X; + const APInt *ShiftC; + if (!match(Xor, m_OneUse(m_c_Xor(m_Value(X), + m_AShr(m_Deferred(X), m_APInt(ShiftC)))))) + return nullptr; + uint64_t Shift = ShiftC->getLimitedValue(); + Type *XType = X->getType(); + if (Shift == 0 || PowerOf2.isMinSignedValue()) + return nullptr; + Value *Add = Builder.CreateAdd(X, ConstantInt::get(XType, PowerOf2)); + APInt Bound = + Pred == ICmpInst::ICMP_ULT ? PowerOf2 << 1 : ((PowerOf2 << 1) - 1); + return new ICmpInst(Pred, Add, ConstantInt::get(XType, Bound)); +} + /// Fold icmp (and (sh X, Y), C2), C1. Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, @@ -1780,7 +1722,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, APInt NewC2 = *C2; KnownBits Know = computeKnownBits(And->getOperand(0), 0, And); // Set high zeros of C2 to allow matching negated power-of-2. - NewC2 = *C2 + APInt::getHighBitsSet(C2->getBitWidth(), + NewC2 = *C2 | APInt::getHighBitsSet(C2->getBitWidth(), Know.countMinLeadingZeros()); // Restrict this fold only for single-use 'and' (PR10267). 
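The power-of-2 range-check rewrite added above for the xor-of-arithmetic-shift pattern can be verified exhaustively at a small bit width; the sketch below checks the u< form for every 8-bit value, every shift amount in [1, 7], and every admissible power of 2 (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // For power-of-2 C below the sign bit and ShiftC in [1, 7]:
  //   ((X s>> ShiftC) ^ X) u< C   <=>   (X + C) u< (C << 1)
  // Signed right shift is arithmetic (guaranteed from C++20 onward).
  for (unsigned Xi = 0; Xi < 256; ++Xi) {
    uint8_t X = (uint8_t)Xi;
    int8_t SX = (int8_t)X; // two's-complement reinterpretation
    for (unsigned ShiftC = 1; ShiftC < 8; ++ShiftC) {
      uint8_t AShr = (uint8_t)(SX >> ShiftC);
      for (unsigned C = 1; C < 128; C <<= 1) {
        bool Original = (uint8_t)(AShr ^ X) < C;
        bool Rewritten = (uint8_t)(X + C) < (C << 1);
        assert(Original == Rewritten);
      }
    }
  }
  return 0;
}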
@@ -1904,6 +1846,20 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); } + // ((zext i1 X) & Y) == 0 --> !((trunc Y) & X) + // ((zext i1 X) & Y) != 0 --> ((trunc Y) & X) + // ((zext i1 X) & Y) == 1 --> ((trunc Y) & X) + // ((zext i1 X) & Y) != 1 --> !((trunc Y) & X) + if (match(And, m_OneUse(m_c_And(m_OneUse(m_ZExt(m_Value(X))), m_Value(Y)))) && + X->getType()->isIntOrIntVectorTy(1) && (C.isZero() || C.isOne())) { + Value *TruncY = Builder.CreateTrunc(Y, X->getType()); + if (C.isZero() ^ (Pred == CmpInst::ICMP_NE)) { + Value *And = Builder.CreateAnd(TruncY, X); + return BinaryOperator::CreateNot(And); + } + return BinaryOperator::CreateAnd(TruncY, X); + } + return nullptr; } @@ -1988,21 +1944,32 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Type *MulTy = Mul->getType(); + Value *X = Mul->getOperand(0); + + // If there's no overflow: + // X * X == 0 --> X == 0 + // X * X != 0 --> X != 0 + if (Cmp.isEquality() && C.isZero() && X == Mul->getOperand(1) && + (Mul->hasNoUnsignedWrap() || Mul->hasNoSignedWrap())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); + const APInt *MulC; if (!match(Mul->getOperand(1), m_APInt(MulC))) return nullptr; // If this is a test of the sign bit and the multiply is sign-preserving with - // a constant operand, use the multiply LHS operand instead. - ICmpInst::Predicate Pred = Cmp.getPredicate(); + // a constant operand, use the multiply LHS operand instead: + // (X * +MulC) < 0 --> X < 0 + // (X * -MulC) < 0 --> X > 0 if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) { if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); - return new ICmpInst(Pred, Mul->getOperand(0), - Constant::getNullValue(Mul->getType())); + return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); } - if (MulC->isZero() || !(Mul->hasNoSignedWrap() || Mul->hasNoUnsignedWrap())) + if (MulC->isZero() || (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap())) return nullptr; // If the multiply does not wrap, try to divide the compare constant by the @@ -2010,50 +1977,45 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, if (Cmp.isEquality()) { // (mul nsw X, MulC) == C --> X == C /s MulC if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC)); - return new ICmpInst(Pred, Mul->getOperand(0), NewC); + Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC)); + return new ICmpInst(Pred, X, NewC); } // (mul nuw X, MulC) == C --> X == C /u MulC if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC)); - return new ICmpInst(Pred, Mul->getOperand(0), NewC); + Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); + return new ICmpInst(Pred, X, NewC); } } + // With a matching no-overflow guarantee, fold the constants: + // (X * MulC) < C --> X < (C / MulC) + // (X * MulC) > C --> X > (C / MulC) + // TODO: Assert that Pred is not equal to SGE, SLE, UGE, ULE? Constant *NewC = nullptr; - - // FIXME: Add assert that Pred is not equal to ICMP_SGE, ICMP_SLE, - // ICMP_UGE, ICMP_ULE. - if (Mul->hasNoSignedWrap()) { - if (MulC->isNegative()) { - // MININT / -1 --> overflow. 
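The constant folding of the no-overflow multiply compare depends on rounding the divided constant toward the correct side; restricting to a positive multiplier for brevity, the s< case can be checked exhaustively in 8 bits (standalone sketch, not part of the patch; the negative-multiplier case additionally swaps the predicate, as the code above does):

#include <cassert>
#include <cstdint>

// Signed ceiling division for a positive divisor, i.e. what rounding the
// quotient up produces for this case.
static int ceilDiv(int A, int B) {
  int Q = A / B;           // C++ integer division truncates toward zero
  if (A % B != 0 && A > 0) // adjust only when the exact quotient is positive
    ++Q;
  return Q;
}

int main() {
  // With a positive multiplier M and no signed overflow in X * M:
  //   (X * M) slt C   <=>   X slt ceil(C / M)
  for (int X = -128; X <= 127; ++X)
    for (int M = 1; M <= 127; ++M) {
      int Prod = X * M;
      if (Prod < -128 || Prod > 127)
        continue; // would not be a 'mul nsw' in 8 bits
      for (int C = -128; C <= 127; ++C)
        assert((Prod < C) == (X < ceilDiv(C, M)));
    }
  return 0;
}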
- if (C.isMinSignedValue() && MulC->isAllOnes()) - return nullptr; + // MININT / -1 --> overflow. + if (C.isMinSignedValue() && MulC->isAllOnes()) + return nullptr; + if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); - } + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP)); + MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP)); if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN)); - } - - if (Mul->hasNoUnsignedWrap()) { + MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN)); + } else { + assert(Mul->hasNoUnsignedWrap() && "Expected mul nuw"); if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP)); + MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP)); if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN)); + MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN)); } - return NewC ? new ICmpInst(Pred, Mul->getOperand(0), NewC) : nullptr; + return NewC ? new ICmpInst(Pred, X, NewC) : nullptr; } /// Fold icmp (shl 1, Y), C. @@ -2080,39 +2042,21 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, Pred = ICmpInst::ICMP_UGT; } - // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 - // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 unsigned CLog2 = C.logBase2(); - if (CLog2 == TypeBits - 1) { - if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_NE; - } return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); } else if (Cmp.isSigned()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); - if (C.isAllOnes()) { - // (1 << Y) <= -1 -> Y == 31 - if (Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + // (1 << Y) > 0 -> Y != 31 + // (1 << Y) > C -> Y != 31 if C is negative. + if (Pred == ICmpInst::ICMP_SGT && C.sle(0)) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - // (1 << Y) > -1 -> Y != 31 - if (Pred == ICmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } else if (!C) { - // (1 << Y) < 0 -> Y == 31 - // (1 << Y) <= 0 -> Y == 31 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); - - // (1 << Y) >= 0 -> Y != 31 - // (1 << Y) > 0 -> Y != 31 - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } - } else if (Cmp.isEquality() && CIsPowerOf2) { - return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2())); + // (1 << Y) < 0 -> Y == 31 + // (1 << Y) < 1 -> Y == 31 + // (1 << Y) < C -> Y == 31 if C is negative and not signed min. + // Exclude signed min by subtracting 1 and lower the upper bound to 0. 
+ if (Pred == ICmpInst::ICMP_SLT && (C-1).sle(0)) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); } return nullptr; @@ -2833,6 +2777,13 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, if (Pred == CmpInst::ICMP_SLT && C == *C2) return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(Ty, C ^ SMax)); + // (X + -1) <u C --> X <=u C (if X is never null) + if (Pred == CmpInst::ICMP_ULT && C2->isAllOnes()) { + const SimplifyQuery Q = SQ.getWithInstruction(&Cmp); + if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C)); + } + if (!Add->hasOneUse()) return nullptr; @@ -3095,7 +3046,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { ArrayRef<int> Mask; if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) { // Check whether every element of Mask is the same constant - if (is_splat(Mask)) { + if (all_equal(Mask)) { auto *VecTy = cast<VectorType>(SrcType); auto *EltTy = cast<IntegerType>(VecTy->getElementType()); if (C->isSplat(EltTy->getBitWidth())) { @@ -3139,6 +3090,20 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) { if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C)) return I; + + // (extractval ([s/u]subo X, Y), 0) == 0 --> X == Y + // (extractval ([s/u]subo X, Y), 0) != 0 --> X != Y + // TODO: This checks one-use, but that is not strictly necessary. + Value *Cmp0 = Cmp.getOperand(0); + Value *X, *Y; + if (C->isZero() && Cmp.isEquality() && Cmp0->hasOneUse() && + (match(Cmp0, + m_ExtractValue<0>(m_Intrinsic<Intrinsic::ssub_with_overflow>( + m_Value(X), m_Value(Y)))) || + match(Cmp0, + m_ExtractValue<0>(m_Intrinsic<Intrinsic::usub_with_overflow>( + m_Value(X), m_Value(Y)))))) + return new ICmpInst(Cmp.getPredicate(), X, Y); } if (match(Cmp.getOperand(1), m_APIntAllowUndef(C))) @@ -3174,10 +3139,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( } break; case Instruction::Add: { - // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. - if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + // (A + C2) == C --> A == (C - C2) + // (A + C2) != C --> A != (C - C2) + // TODO: Remove the one-use limitation? See discussion in D58633. + if (Constant *C2 = dyn_cast<Constant>(BOp1)) { if (BO->hasOneUse()) - return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC)); + return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, C2)); } else if (C.isZero()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. 
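The decrement fold above is sound because a known-non-zero X cannot wrap past zero when -1 is added, so the strict compare simply becomes non-strict. An exhaustive 8-bit check (standalone sketch, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // For X != 0:  (X + -1) u< C   <=>   X u<= C
  for (unsigned X = 1; X <= 255; ++X)
    for (unsigned C = 0; C <= 255; ++C) {
      uint8_t Dec = (uint8_t)(X + 0xFF); // X + (-1) in 8-bit arithmetic
      assert((Dec < C) == (X <= C));
    }
  return 0;
}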
@@ -3433,7 +3400,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, case Instruction::UDiv: if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) return I; - LLVM_FALLTHROUGH; + [[fallthrough]]; case Instruction::SDiv: if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) return I; @@ -3580,8 +3547,8 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred, auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * { if (Value *Res = simplifyICmpInst(Pred, Op, RHS, SQ)) return Res; - if (Optional<bool> Impl = isImpliedCondition(SI->getCondition(), Pred, Op, - RHS, DL, SelectCondIsTrue)) + if (std::optional<bool> Impl = isImpliedCondition( + SI->getCondition(), Pred, Op, RHS, DL, SelectCondIsTrue)) return ConstantInt::get(I.getType(), *Impl); return nullptr; }; @@ -4488,6 +4455,18 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, } } + // For unsigned predicates / eq / ne: + // icmp pred (x << 1), x --> icmp getSignedPredicate(pred) x, 0 + // icmp pred x, (x << 1) --> icmp getSignedPredicate(pred) 0, x + if (!ICmpInst::isSigned(Pred)) { + if (match(Op0, m_Shl(m_Specific(Op1), m_One()))) + return new ICmpInst(ICmpInst::getSignedPredicate(Pred), Op1, + Constant::getNullValue(Op1->getType())); + else if (match(Op1, m_Shl(m_Specific(Op0), m_One()))) + return new ICmpInst(ICmpInst::getSignedPredicate(Pred), + Constant::getNullValue(Op0->getType()), Op0); + } + if (Value *V = foldMultiplicationOverflowCheck(I)) return replaceInstUsesWith(I, V); @@ -4674,17 +4653,29 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } - // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) - // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) - ConstantInt *Cst1; - if ((Op0->hasOneUse() && match(Op0, m_ZExt(m_Value(A))) && - match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || - (Op1->hasOneUse() && match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && - match(Op1, m_ZExt(m_Value(A))))) { - APInt Pow2 = Cst1->getValue() + 1; - if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && - Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) + if (match(Op1, m_ZExt(m_Value(A))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + // (B & (Pow2C-1)) == zext A --> A == trunc B + // (B & (Pow2C-1)) != zext A --> A != trunc B + const APInt *MaskC; + if (match(Op0, m_And(m_Value(B), m_LowBitMask(MaskC))) && + MaskC->countTrailingOnes() == A->getType()->getScalarSizeInBits()) return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType())); + + // Test if 2 values have different or same signbits: + // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0 + // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1 + unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); + Value *X, *Y; + ICmpInst::Predicate Pred2; + if (match(Op0, m_LShr(m_Value(X), m_SpecificIntAllowUndef(OpWidth - 1))) && + match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) && + Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) { + Value *Xor = Builder.CreateXor(X, Y, "xor.signbits"); + Value *R = (Pred == ICmpInst::ICMP_EQ) ? 
Builder.CreateIsNeg(Xor) : + Builder.CreateIsNotNeg(Xor); + return replaceInstUsesWith(I, R); + } } // (A >> C) == (B >> C) --> (A^B) u< (1 << C) @@ -4708,6 +4699,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 + ConstantInt *Cst1; if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { unsigned TypeBits = Cst1->getBitWidth(); @@ -4788,6 +4780,20 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { Add, ConstantInt::get(A->getType(), C.shl(1))); } + // Canonicalize: + // Assume B_Pow2 != 0 + // 1. A & B_Pow2 != B_Pow2 -> A & B_Pow2 == 0 + // 2. A & B_Pow2 == B_Pow2 -> A & B_Pow2 != 0 + if (match(Op0, m_c_And(m_Specific(Op1), m_Value())) && + isKnownToBeAPowerOfTwo(Op1, /* OrZero */ false, 0, &I)) + return new ICmpInst(CmpInst::getInversePredicate(Pred), Op0, + ConstantInt::getNullValue(Op0->getType())); + + if (match(Op1, m_c_And(m_Specific(Op0), m_Value())) && + isKnownToBeAPowerOfTwo(Op0, /* OrZero */ false, 0, &I)) + return new ICmpInst(CmpInst::getInversePredicate(Pred), Op1, + ConstantInt::getNullValue(Op1->getType())); + return nullptr; } @@ -4993,7 +4999,7 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { return foldICmpWithZextOrSext(ICmp); } -static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { +static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS, bool IsSigned) { switch (BinaryOp) { default: llvm_unreachable("Unsupported binary op"); @@ -5001,7 +5007,8 @@ static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { case Instruction::Sub: return match(RHS, m_Zero()); case Instruction::Mul: - return match(RHS, m_One()); + return !(RHS->getType()->isIntOrIntVectorTy(1) && IsSigned) && + match(RHS, m_One()); } } @@ -5048,7 +5055,7 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp, if (auto *LHSTy = dyn_cast<VectorType>(LHS->getType())) OverflowTy = VectorType::get(OverflowTy, LHSTy->getElementCount()); - if (isNeutralValue(BinaryOp, RHS)) { + if (isNeutralValue(BinaryOp, RHS, IsSigned)) { Result = LHS; Overflow = ConstantInt::getFalse(OverflowTy); return true; @@ -5746,7 +5753,7 @@ static Instruction *foldICmpUsingBoolRange(ICmpInst &I, return nullptr; } -llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> +std::optional<std::pair<CmpInst::Predicate, Constant *>> InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, Constant *C) { assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && @@ -5769,13 +5776,13 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, if (auto *CI = dyn_cast<ConstantInt>(C)) { // Bail out if the constant can't be safely incremented/decremented. if (!ConstantIsOk(CI)) - return llvm::None; + return std::nullopt; } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) { unsigned NumElts = FVTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = C->getAggregateElement(i); if (!Elt) - return llvm::None; + return std::nullopt; if (isa<UndefValue>(Elt)) continue; @@ -5784,14 +5791,14 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, // know that this constant is min/max. auto *CI = dyn_cast<ConstantInt>(Elt); if (!CI || !ConstantIsOk(CI)) - return llvm::None; + return std::nullopt; if (!SafeReplacementConstant) SafeReplacementConstant = CI; } } else { // ConstantExpr? 
- return llvm::None; + return std::nullopt; } // It may not be safe to change a compare predicate in the presence of @@ -5901,7 +5908,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, case ICmpInst::ICMP_UGT: // icmp ugt -> icmp ult std::swap(A, B); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_ULT: // icmp ult i1 A, B -> ~A & B return BinaryOperator::CreateAnd(Builder.CreateNot(A), B); @@ -5909,7 +5916,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, case ICmpInst::ICMP_SGT: // icmp sgt -> icmp slt std::swap(A, B); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_SLT: // icmp slt i1 A, B -> A & ~B return BinaryOperator::CreateAnd(Builder.CreateNot(B), A); @@ -5917,7 +5924,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, case ICmpInst::ICMP_UGE: // icmp uge -> icmp ule std::swap(A, B); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_ULE: // icmp ule i1 A, B -> ~A | B return BinaryOperator::CreateOr(Builder.CreateNot(A), B); @@ -5925,7 +5932,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, case ICmpInst::ICMP_SGE: // icmp sge -> icmp sle std::swap(A, B); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_SLE: // icmp sle i1 A, B -> A | ~B return BinaryOperator::CreateOr(Builder.CreateNot(B), A); @@ -5986,6 +5993,31 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, const CmpInst::Predicate Pred = Cmp.getPredicate(); Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1); Value *V1, *V2; + + auto createCmpReverse = [&](CmpInst::Predicate Pred, Value *X, Value *Y) { + Value *V = Builder.CreateCmp(Pred, X, Y, Cmp.getName()); + if (auto *I = dyn_cast<Instruction>(V)) + I->copyIRFlags(&Cmp); + Module *M = Cmp.getModule(); + Function *F = Intrinsic::getDeclaration( + M, Intrinsic::experimental_vector_reverse, V->getType()); + return CallInst::Create(F, V); + }; + + if (match(LHS, m_VecReverse(m_Value(V1)))) { + // cmp Pred, rev(V1), rev(V2) --> rev(cmp Pred, V1, V2) + if (match(RHS, m_VecReverse(m_Value(V2))) && + (LHS->hasOneUse() || RHS->hasOneUse())) + return createCmpReverse(Pred, V1, V2); + + // cmp Pred, rev(V1), RHSSplat --> rev(cmp Pred, V1, RHSSplat) + if (LHS->hasOneUse() && isSplatValue(RHS)) + return createCmpReverse(Pred, V1, RHS); + } + // cmp Pred, LHSSplat, rev(V2) --> rev(cmp Pred, LHSSplat, V2) + else if (isSplatValue(LHS) && match(RHS, m_OneUse(m_VecReverse(m_Value(V2))))) + return createCmpReverse(Pred, LHS, V2); + ArrayRef<int> M; if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(M)))) return nullptr; @@ -6318,11 +6350,11 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { } // (zext a) * (zext b) --> llvm.umul.with.overflow. - if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { + if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this)) return R; } - if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { + if (match(Op1, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this)) return R; } @@ -6668,10 +6700,48 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI, /// Optimize fabs(X) compared with zero. 
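The [[fallthrough]] updates above land in canonicalizeICmpBool, whose existing cases rewrite i1 compares into and/or/not logic. A small standalone check of those boolean identities (treating i1 true as the signed value -1, as LLVM does) could look like:

#include <cassert>

// i1 true is the signed value -1 in LLVM, so model the signed predicates that way.
static int asSigned(bool V) { return V ? -1 : 0; }

int main() {
  for (bool A : {false, true}) {
    for (bool B : {false, true}) {
      assert(((unsigned)A <  (unsigned)B) == (!A && B)); // icmp ult i1 A, B -> ~A & B
      assert((asSigned(A) <  asSigned(B)) == (A && !B)); // icmp slt i1 A, B -> A & ~B
      assert(((unsigned)A <= (unsigned)B) == (!A || B)); // icmp ule i1 A, B -> ~A | B
      assert((asSigned(A) <= asSigned(B)) == (A || !B)); // icmp sle i1 A, B -> A | ~B
    }
  }
  return 0;
}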
static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) { Value *X; - if (!match(I.getOperand(0), m_FAbs(m_Value(X))) || - !match(I.getOperand(1), m_PosZeroFP())) + if (!match(I.getOperand(0), m_FAbs(m_Value(X)))) return nullptr; + const APFloat *C; + if (!match(I.getOperand(1), m_APFloat(C))) + return nullptr; + + if (!C->isPosZero()) { + if (!C->isSmallestNormalized()) + return nullptr; + + const Function *F = I.getFunction(); + DenormalMode Mode = F->getDenormalMode(C->getSemantics()); + if (Mode.Input == DenormalMode::PreserveSign || + Mode.Input == DenormalMode::PositiveZero) { + + auto replaceFCmp = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) { + Constant *Zero = ConstantFP::getNullValue(X->getType()); + return new FCmpInst(P, X, Zero, "", I); + }; + + switch (I.getPredicate()) { + case FCmpInst::FCMP_OLT: + // fcmp olt fabs(x), smallest_normalized_number -> fcmp oeq x, 0.0 + return replaceFCmp(&I, FCmpInst::FCMP_OEQ, X); + case FCmpInst::FCMP_UGE: + // fcmp uge fabs(x), smallest_normalized_number -> fcmp une x, 0.0 + return replaceFCmp(&I, FCmpInst::FCMP_UNE, X); + case FCmpInst::FCMP_OGE: + // fcmp oge fabs(x), smallest_normalized_number -> fcmp one x, 0.0 + return replaceFCmp(&I, FCmpInst::FCMP_ONE, X); + case FCmpInst::FCMP_ULT: + // fcmp ult fabs(x), smallest_normalized_number -> fcmp ueq x, 0.0 + return replaceFCmp(&I, FCmpInst::FCMP_UEQ, X); + default: + break; + } + } + + return nullptr; + } + auto replacePredAndOp0 = [&IC](FCmpInst *I, FCmpInst::Predicate P, Value *X) { I->setPredicate(P); return IC.replaceOperand(*I, 0, X); @@ -6828,6 +6898,26 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) return replaceOperand(I, 1, ConstantFP::getNullValue(OpType)); + // Ignore signbit of bitcasted int when comparing equality to FP 0.0: + // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0 + if (match(Op1, m_PosZeroFP()) && + match(Op0, m_OneUse(m_BitCast(m_Value(X)))) && + X->getType()->isVectorTy() == OpType->isVectorTy() && + X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) { + ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE; + if (Pred == FCmpInst::FCMP_OEQ) + IntPred = ICmpInst::ICMP_EQ; + else if (Pred == FCmpInst::FCMP_UNE) + IntPred = ICmpInst::ICMP_NE; + + if (IntPred != ICmpInst::BAD_ICMP_PREDICATE) { + Type *IntTy = X->getType(); + const APInt &SignMask = ~APInt::getSignMask(IntTy->getScalarSizeInBits()); + Value *MaskX = Builder.CreateAnd(X, ConstantInt::get(IntTy, SignMask)); + return new ICmpInst(IntPred, MaskX, ConstantInt::getNullValue(IntTy)); + } + } + // Handle fcmp with instruction LHS and constant RHS. 
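The second hunk above replaces an fcmp against +0.0 with an integer test that masks off the sign bit of the bitcast source. A standalone illustration of that identity in plain C++ (not the pass itself; it assumes default IEEE float semantics, no flush-to-zero) is:

#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// fcmp oeq (bitcast X to float), 0.0 only depends on the non-sign bits of X.
static bool oeqZeroViaBits(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  return (Bits & 0x7fffffffu) == 0; // clear the sign bit, then test for zero
}

int main() {
  const float Vals[] = {0.0f, -0.0f, 1.0f, -1.0f, 1e-45f,
                        std::numeric_limits<float>::infinity(),
                        std::numeric_limits<float>::quiet_NaN()};
  for (float F : Vals)
    assert((F == 0.0f) == oeqZeroViaBits(F)); // NaN compares false on both sides
  return 0;
}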
Instruction *LHSI; Constant *RHSC; @@ -6866,10 +6956,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op0, m_FNeg(m_Value(X)))) { // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C Constant *C; - if (match(Op1, m_Constant(C))) { - Constant *NegC = ConstantExpr::getFNeg(C); - return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I); - } + if (match(Op1, m_Constant(C))) + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I); } if (match(Op0, m_FPExt(m_Value(X)))) { @@ -6915,7 +7004,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { APFloat Fabs = TruncC; Fabs.clearSign(); if (!Lossy && - (!(Fabs < APFloat::getSmallestNormalized(FPSem)) || Fabs.isZero())) { + (Fabs.isZero() || !(Fabs < APFloat::getSmallestNormalized(FPSem)))) { Constant *NewC = ConstantFP::get(X->getType(), TruncC); return new FCmpInst(Pred, X, NewC, "", &I); } @@ -6942,6 +7031,24 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { } } + { + Value *CanonLHS = nullptr, *CanonRHS = nullptr; + match(Op0, m_Intrinsic<Intrinsic::canonicalize>(m_Value(CanonLHS))); + match(Op1, m_Intrinsic<Intrinsic::canonicalize>(m_Value(CanonRHS))); + + // (canonicalize(x) == x) => (x == x) + if (CanonLHS == Op1) + return new FCmpInst(Pred, Op1, Op1, "", &I); + + // (x == canonicalize(x)) => (x == x) + if (CanonRHS == Op0) + return new FCmpInst(Pred, Op0, Op0, "", &I); + + // (canonicalize(x) == canonicalize(y)) => (x == y) + if (CanonLHS && CanonRHS) + return new FCmpInst(Pred, CanonLHS, CanonRHS, "", &I); + } + if (I.getType()->isVectorTy()) if (Instruction *Res = foldVectorCmp(I, Builder)) return Res; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 664226ec187b..f4e88b122383 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -106,7 +106,8 @@ public: Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted); Instruction *visitAnd(BinaryOperator &I); Instruction *visitOr(BinaryOperator &I); - bool sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I); + bool sinkNotIntoLogicalOp(Instruction &I); + bool sinkNotIntoOtherHandOfLogicalOp(Instruction &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); Value *reassociateShiftAmtsOfTwoSameDirectionShifts( @@ -127,8 +128,8 @@ public: Instruction *commonCastTransforms(CastInst &CI); Instruction *commonPointerCastTransforms(CastInst &CI); Instruction *visitTrunc(TruncInst &CI); - Instruction *visitZExt(ZExtInst &CI); - Instruction *visitSExt(SExtInst &CI); + Instruction *visitZExt(ZExtInst &Zext); + Instruction *visitSExt(SExtInst &Sext); Instruction *visitFPTrunc(FPTruncInst &CI); Instruction *visitFPExt(CastInst &CI); Instruction *visitFPToUI(FPToUIInst &FI); @@ -167,6 +168,7 @@ public: Instruction *visitInsertValueInst(InsertValueInst &IV); Instruction *visitInsertElementInst(InsertElementInst &IE); Instruction *visitExtractElementInst(ExtractElementInst &EI); + Instruction *simplifyBinOpSplats(ShuffleVectorInst &SVI); Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI); Instruction *visitExtractValueInst(ExtractValueInst &EV); Instruction *visitLandingPadInst(LandingPadInst &LI); @@ -247,9 +249,9 @@ private: /// \return null if the transformation cannot be performed. 
If the /// transformation can be performed the new instruction that replaces the /// (zext icmp) pair will be returned. - Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI); + Instruction *transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext); - Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); + Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext); bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { @@ -329,7 +331,7 @@ private: Instruction *matchSAddSubSat(IntrinsicInst &MinMax1); Instruction *foldNot(BinaryOperator &I); - void freelyInvertAllUsersOf(Value *V); + void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr); /// Determine if a pair of casts can be replaced by a single cast. /// @@ -360,14 +362,24 @@ private: Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd, bool IsLogicalSelect = false); + Instruction *foldLogicOfIsFPClass(BinaryOperator &Operator, Value *LHS, + Value *RHS); + + Instruction * + canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i); + Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI, bool IsAnd, bool IsLogical = false); - Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D); - Value *getSelectCondition(Value *A, Value *B); + Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D, + bool InvertFalseVal = false); + Value *getSelectCondition(Value *A, Value *B, bool ABIsTheSame); + Instruction *foldLShrOverflowBit(BinaryOperator &I); + Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV); Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II); Instruction *foldFPSignBitOps(BinaryOperator &I); + Instruction *foldFDivConstantDivisor(BinaryOperator &I); // Optimize one of these forms: // and i1 Op, SI / select i1 Op, i1 SI, i1 false (if IsAnd = true) @@ -377,64 +389,6 @@ private: bool IsAnd); public: - /// Inserts an instruction \p New before instruction \p Old - /// - /// Also adds the new instruction to the worklist and returns \p New so that - /// it is suitable for use as the return from the visitation patterns. - Instruction *InsertNewInstBefore(Instruction *New, Instruction &Old) { - assert(New && !New->getParent() && - "New instruction already inserted into a basic block!"); - BasicBlock *BB = Old.getParent(); - BB->getInstList().insert(Old.getIterator(), New); // Insert inst - Worklist.add(New); - return New; - } - - /// Same as InsertNewInstBefore, but also sets the debug loc. - Instruction *InsertNewInstWith(Instruction *New, Instruction &Old) { - New->setDebugLoc(Old.getDebugLoc()); - return InsertNewInstBefore(New, Old); - } - - /// A combiner-aware RAUW-like routine. - /// - /// This method is to be used when an instruction is found to be dead, - /// replaceable with another preexisting expression. Here we add all uses of - /// I to the worklist, replace all uses of I with the new value, then return - /// I, so that the inst combiner will know that I was modified. - Instruction *replaceInstUsesWith(Instruction &I, Value *V) { - // If there are no uses to replace, then we return nullptr to indicate that - // no changes were made to the program. - if (I.use_empty()) return nullptr; - - Worklist.pushUsersToWorkList(I); // Add all modified instrs to worklist. - - // If we are replacing the instruction with itself, this must be in a - // segment of unreachable code, so just clobber the instruction. 
- if (&I == V) - V = PoisonValue::get(I.getType()); - - LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n" - << " with " << *V << '\n'); - - I.replaceAllUsesWith(V); - MadeIRChange = true; - return &I; - } - - /// Replace operand of instruction and add old operand to the worklist. - Instruction *replaceOperand(Instruction &I, unsigned OpNum, Value *V) { - Worklist.addValue(I.getOperand(OpNum)); - I.setOperand(OpNum, V); - return &I; - } - - /// Replace use and add the previously used value to the worklist. - void replaceUse(Use &U, Value *NewValue) { - Worklist.addValue(U); - U = NewValue; - } - /// Create and insert the idiom we use to indicate a block is unreachable /// without having to rewrite the CFG from within InstCombine. void CreateNonTerminatorUnreachable(Instruction *InsertAt) { @@ -467,67 +421,6 @@ public: return nullptr; // Don't do anything with FI } - void computeKnownBits(const Value *V, KnownBits &Known, - unsigned Depth, const Instruction *CxtI) const { - llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT); - } - - KnownBits computeKnownBits(const Value *V, unsigned Depth, - const Instruction *CxtI) const { - return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT); - } - - bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false, - unsigned Depth = 0, - const Instruction *CxtI = nullptr) { - return llvm::isKnownToBeAPowerOfTwo(V, DL, OrZero, Depth, &AC, CxtI, &DT); - } - - bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth = 0, - const Instruction *CxtI = nullptr) const { - return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT); - } - - unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0, - const Instruction *CxtI = nullptr) const { - return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForUnsignedMul(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForSignedMul(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForSignedAdd(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForUnsignedSub(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT); - } - OverflowResult computeOverflow( Instruction::BinaryOps BinaryOp, bool IsSigned, Value *LHS, Value *RHS, Instruction *CxtI) const; @@ -543,7 +436,7 @@ public: /// -> "A*(B+C)") or expanding out if this results in simplifications (eg: "A /// & (B | C) -> (A&B) | (A&C)" if this is a win). Returns the simplified /// value, or null if it didn't simplify. - Value *SimplifyUsingDistributiveLaws(BinaryOperator &I); + Value *foldUsingDistributiveLaws(BinaryOperator &I); /// Tries to simplify add operations using the definition of remainder. 
/// @@ -559,8 +452,7 @@ public: /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). - Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *, - Value *, Value *, Value *); + Value *tryFactorizationFolds(BinaryOperator &I); /// Match a select chain which produces one of three values based on whether /// the LHS is less than, equal to, or greater than RHS respectively. @@ -647,7 +539,7 @@ public: /// If an integer typed PHI has only one use which is an IntToPtr operation, /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise /// insert a new pointer typed PHI and replace the original one. - Instruction *foldIntegerTypedPHI(PHINode &PN); + bool foldIntegerTypedPHI(PHINode &PN); /// Helper function for FoldPHIArgXIntoPHI() to set debug location for the /// folded operation. @@ -716,6 +608,8 @@ public: const APInt &C1); Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, const APInt &C1, const APInt &C2); + Instruction *foldICmpXorShiftConst(ICmpInst &Cmp, BinaryOperator *Xor, + const APInt &C); Instruction *foldICmpShrConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, const APInt &C2); Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, @@ -731,6 +625,7 @@ public: Instruction *foldICmpBitCast(ICmpInst &Cmp); // Helpers of visitSelectInst(). + Instruction *foldSelectOfBools(SelectInst &SI); Instruction *foldSelectExtConst(SelectInst &Sel); Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *); @@ -790,13 +685,13 @@ class Negator final { std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I); - LLVM_NODISCARD Value *visitImpl(Value *V, unsigned Depth); + [[nodiscard]] Value *visitImpl(Value *V, unsigned Depth); - LLVM_NODISCARD Value *negate(Value *V, unsigned Depth); + [[nodiscard]] Value *negate(Value *V, unsigned Depth); /// Recurse depth-first and attempt to sink the negation. /// FIXME: use worklist? - LLVM_NODISCARD Optional<Result> run(Value *Root); + [[nodiscard]] std::optional<Result> run(Value *Root); Negator(const Negator &) = delete; Negator(Negator &&) = delete; @@ -806,8 +701,8 @@ class Negator final { public: /// Attempt to negate \p Root. Retuns nullptr if negation can't be performed, /// otherwise returns negated value. 
- LLVM_NODISCARD static Value *Negate(bool LHSIsZero, Value *Root, - InstCombinerImpl &IC); + [[nodiscard]] static Value *Negate(bool LHSIsZero, Value *Root, + InstCombinerImpl &IC); }; } // end namespace llvm diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index e03b7026f802..41bc65620ff6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -28,30 +28,42 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -STATISTIC(NumDeadStore, "Number of dead stores eliminated"); +STATISTIC(NumDeadStore, "Number of dead stores eliminated"); STATISTIC(NumGlobalCopies, "Number of allocas copied from constant global"); -/// isOnlyCopiedFromConstantGlobal - Recursively walk the uses of a (derived) +static cl::opt<unsigned> MaxCopiedFromConstantUsers( + "instcombine-max-copied-from-constant-users", cl::init(128), + cl::desc("Maximum users to visit in copy from constant transform"), + cl::Hidden); + +/// isOnlyCopiedFromConstantMemory - Recursively walk the uses of a (derived) /// pointer to an alloca. Ignore any reads of the pointer, return false if we /// see any stores or other unknown uses. If we see pointer arithmetic, keep /// track of whether it moves the pointer (with IsOffset) but otherwise traverse /// the uses. If we see a memcpy/memmove that targets an unoffseted pointer to -/// the alloca, and if the source pointer is a pointer to a constant global, we -/// can optimize this. +/// the alloca, and if the source pointer is a pointer to a constant memory +/// location, we can optimize this. static bool -isOnlyCopiedFromConstantMemory(AAResults *AA, - Value *V, MemTransferInst *&TheCopy, +isOnlyCopiedFromConstantMemory(AAResults *AA, AllocaInst *V, + MemTransferInst *&TheCopy, SmallVectorImpl<Instruction *> &ToDelete) { // We track lifetime intrinsics as we encounter them. If we decide to go - // ahead and replace the value with the global, this lets the caller quickly - // eliminate the markers. - - SmallVector<std::pair<Value *, bool>, 35> ValuesToInspect; - ValuesToInspect.emplace_back(V, false); - while (!ValuesToInspect.empty()) { - auto ValuePair = ValuesToInspect.pop_back_val(); - const bool IsOffset = ValuePair.second; - for (auto &U : ValuePair.first->uses()) { + // ahead and replace the value with the memory location, this lets the caller + // quickly eliminate the markers. + + using ValueAndIsOffset = PointerIntPair<Value *, 1, bool>; + SmallVector<ValueAndIsOffset, 32> Worklist; + SmallPtrSet<ValueAndIsOffset, 32> Visited; + Worklist.emplace_back(V, false); + while (!Worklist.empty()) { + ValueAndIsOffset Elem = Worklist.pop_back_val(); + if (!Visited.insert(Elem).second) + continue; + if (Visited.size() > MaxCopiedFromConstantUsers) + return false; + + const auto [Value, IsOffset] = Elem; + for (auto &U : Value->uses()) { auto *I = cast<Instruction>(U.getUser()); if (auto *LI = dyn_cast<LoadInst>(I)) { @@ -60,15 +72,22 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, continue; } - if (isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I)) { + if (isa<PHINode, SelectInst>(I)) { + // We set IsOffset=true, to forbid the memcpy from occurring after the + // phi: If one of the phi operands is not based on the alloca, we + // would incorrectly omit a write. + Worklist.emplace_back(I, true); + continue; + } + if (isa<BitCastInst, AddrSpaceCastInst>(I)) { // If uses of the bitcast are ok, we are ok. 
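The rewritten walk above keys its visited set on (value, is-offset) pairs and gives up once more than -instcombine-max-copied-from-constant-users distinct users have been seen. A rough sketch of that traversal shape, using made-up Node types rather than LLVM IR, is:

#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// Stand-in for an IR value; 'MovesPointer' models e.g. a GEP with a non-zero index.
struct Node {
  std::vector<Node *> Users;
  bool MovesPointer = false;
};

// Mirrors the default of -instcombine-max-copied-from-constant-users.
constexpr std::size_t MaxUsers = 128;

static bool visitAllUsers(Node *Root) {
  std::set<std::pair<Node *, bool>> Visited;
  std::vector<std::pair<Node *, bool>> Worklist{{Root, false}};
  while (!Worklist.empty()) {
    auto Item = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(Item).second)
      continue;                 // phis/selects can create cycles; visit each pair once
    if (Visited.size() > MaxUsers)
      return false;             // too many users: give up rather than scan forever
    for (Node *User : Item.first->Users)
      Worklist.push_back({User, Item.second || User->MovesPointer});
  }
  return true;
}

int main() {
  Node A, B;
  A.Users.push_back(&B);
  return visitAllUsers(&A) ? 0 : 1;
}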
- ValuesToInspect.emplace_back(I, IsOffset); + Worklist.emplace_back(I, IsOffset); continue; } if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { // If the GEP has all zero indices, it doesn't offset the pointer. If it // doesn't, it does. - ValuesToInspect.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices()); + Worklist.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices()); continue; } @@ -85,11 +104,12 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, if (IsArgOperand && Call->isInAllocaArgument(DataOpNo)) return false; - // If this is a readonly/readnone call site, then we know it is just a - // load (but one that potentially returns the value itself), so we can + // If this call site doesn't modify the memory, then we know it is just + // a load (but one that potentially returns the value itself), so we can // ignore it if we know that the value isn't captured. - if (Call->onlyReadsMemory() && - (Call->use_empty() || Call->doesNotCapture(DataOpNo))) + bool NoCapture = Call->doesNotCapture(DataOpNo); + if ((Call->onlyReadsMemory() && (Call->use_empty() || NoCapture)) || + (Call->onlyReadsMemory(DataOpNo) && NoCapture)) continue; // If this is being passed as a byval argument, the caller is making a @@ -111,12 +131,14 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, if (!MI) return false; + // If the transfer is volatile, reject it. + if (MI->isVolatile()) + return false; + // If the transfer is using the alloca as a source of the transfer, then // ignore it since it is a load (unless the transfer is volatile). - if (U.getOperandNo() == 1) { - if (MI->isVolatile()) return false; + if (U.getOperandNo() == 1) continue; - } // If we already have seen a copy, reject the second one. if (TheCopy) return false; @@ -128,8 +150,8 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, // If the memintrinsic isn't using the alloca as the dest, reject it. if (U.getOperandNo() != 0) return false; - // If the source of the memcpy/move is not a constant global, reject it. - if (!AA->pointsToConstantMemory(MI->getSource())) + // If the source of the memcpy/move is not constant, reject it. + if (isModSet(AA->getModRefInfoMask(MI->getSource()))) return false; // Otherwise, the transform is safe. Remember the copy instruction. @@ -139,9 +161,10 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, return true; } -/// isOnlyCopiedFromConstantGlobal - Return true if the specified alloca is only -/// modified by a copy from a constant global. If we can prove this, we can -/// replace any uses of the alloca with uses of the global directly. +/// isOnlyCopiedFromConstantMemory - Return true if the specified alloca is only +/// modified by a copy from a constant memory location. If we can prove this, we +/// can replace any uses of the alloca with uses of the memory location +/// directly. static MemTransferInst * isOnlyCopiedFromConstantMemory(AAResults *AA, AllocaInst *AI, @@ -165,7 +188,7 @@ static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI, } static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, - AllocaInst &AI) { + AllocaInst &AI, DominatorTree &DT) { // Check for array size of 1 (scalar allocation). if (!AI.isArrayAllocation()) { // i32 1 is the canonical array size for scalar allocations. 
@@ -184,6 +207,8 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, nullptr, AI.getName()); New->setAlignment(AI.getAlign()); + replaceAllDbgUsesWith(AI, *New, *New, DT); + // Scan to the end of the allocation instructions, to skip over a block of // allocas if possible...also skip interleaved debug info // @@ -234,31 +259,83 @@ namespace { // instruction. class PointerReplacer { public: - PointerReplacer(InstCombinerImpl &IC) : IC(IC) {} + PointerReplacer(InstCombinerImpl &IC, Instruction &Root) + : IC(IC), Root(Root) {} - bool collectUsers(Instruction &I); - void replacePointer(Instruction &I, Value *V); + bool collectUsers(); + void replacePointer(Value *V); private: + bool collectUsersRecursive(Instruction &I); void replace(Instruction *I); Value *getReplacement(Value *I); + bool isAvailable(Instruction *I) const { + return I == &Root || Worklist.contains(I); + } + SmallPtrSet<Instruction *, 32> ValuesToRevisit; SmallSetVector<Instruction *, 4> Worklist; MapVector<Value *, Value *> WorkMap; InstCombinerImpl &IC; + Instruction &Root; }; } // end anonymous namespace -bool PointerReplacer::collectUsers(Instruction &I) { - for (auto U : I.users()) { +bool PointerReplacer::collectUsers() { + if (!collectUsersRecursive(Root)) + return false; + + // Ensure that all outstanding (indirect) users of I + // are inserted into the Worklist. Return false + // otherwise. + for (auto *Inst : ValuesToRevisit) + if (!Worklist.contains(Inst)) + return false; + return true; +} + +bool PointerReplacer::collectUsersRecursive(Instruction &I) { + for (auto *U : I.users()) { auto *Inst = cast<Instruction>(&*U); if (auto *Load = dyn_cast<LoadInst>(Inst)) { if (Load->isVolatile()) return false; Worklist.insert(Load); - } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) { + } else if (auto *PHI = dyn_cast<PHINode>(Inst)) { + // All incoming values must be instructions for replacability + if (any_of(PHI->incoming_values(), + [](Value *V) { return !isa<Instruction>(V); })) + return false; + + // If at least one incoming value of the PHI is not in Worklist, + // store the PHI for revisiting and skip this iteration of the + // loop. 
+ if (any_of(PHI->incoming_values(), [this](Value *V) { + return !isAvailable(cast<Instruction>(V)); + })) { + ValuesToRevisit.insert(Inst); + continue; + } + + Worklist.insert(PHI); + if (!collectUsersRecursive(*PHI)) + return false; + } else if (auto *SI = dyn_cast<SelectInst>(Inst)) { + if (!isa<Instruction>(SI->getTrueValue()) || + !isa<Instruction>(SI->getFalseValue())) + return false; + + if (!isAvailable(cast<Instruction>(SI->getTrueValue())) || + !isAvailable(cast<Instruction>(SI->getFalseValue()))) { + ValuesToRevisit.insert(Inst); + continue; + } + Worklist.insert(SI); + if (!collectUsersRecursive(*SI)) + return false; + } else if (isa<GetElementPtrInst, BitCastInst>(Inst)) { Worklist.insert(Inst); - if (!collectUsers(*Inst)) + if (!collectUsersRecursive(*Inst)) return false; } else if (auto *MI = dyn_cast<MemTransferInst>(Inst)) { if (MI->isVolatile()) @@ -293,6 +370,14 @@ void PointerReplacer::replace(Instruction *I) { IC.InsertNewInstWith(NewI, *LT); IC.replaceInstUsesWith(*LT, NewI); WorkMap[LT] = NewI; + } else if (auto *PHI = dyn_cast<PHINode>(I)) { + Type *NewTy = getReplacement(PHI->getIncomingValue(0))->getType(); + auto *NewPHI = PHINode::Create(NewTy, PHI->getNumIncomingValues(), + PHI->getName(), PHI); + for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I) + NewPHI->addIncoming(getReplacement(PHI->getIncomingValue(I)), + PHI->getIncomingBlock(I)); + WorkMap[PHI] = NewPHI; } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { auto *V = getReplacement(GEP->getPointerOperand()); assert(V && "Operand not replaced"); @@ -313,6 +398,13 @@ void PointerReplacer::replace(Instruction *I) { IC.InsertNewInstWith(NewI, *BC); NewI->takeName(BC); WorkMap[BC] = NewI; + } else if (auto *SI = dyn_cast<SelectInst>(I)) { + auto *NewSI = SelectInst::Create( + SI->getCondition(), getReplacement(SI->getTrueValue()), + getReplacement(SI->getFalseValue()), SI->getName(), nullptr, SI); + IC.InsertNewInstWith(NewSI, *SI); + NewSI->takeName(SI); + WorkMap[SI] = NewSI; } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) { auto *SrcV = getReplacement(MemCpy->getRawSource()); // The pointer may appear in the destination of a copy, but we don't want to @@ -339,27 +431,27 @@ void PointerReplacer::replace(Instruction *I) { } } -void PointerReplacer::replacePointer(Instruction &I, Value *V) { +void PointerReplacer::replacePointer(Value *V) { #ifndef NDEBUG - auto *PT = cast<PointerType>(I.getType()); + auto *PT = cast<PointerType>(Root.getType()); auto *NT = cast<PointerType>(V->getType()); assert(PT != NT && PT->hasSameElementTypeAs(NT) && "Invalid usage"); #endif - WorkMap[&I] = V; + WorkMap[&Root] = V; for (Instruction *Workitem : Worklist) replace(Workitem); } Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { - if (auto *I = simplifyAllocaArraySize(*this, AI)) + if (auto *I = simplifyAllocaArraySize(*this, AI, DT)) return I; if (AI.getAllocatedType()->isSized()) { // Move all alloca's of zero byte objects to the entry block and merge them // together. Note that we only do this for alloca's, because malloc should // allocate and return a unique pointer, even for a zero byte allocation. - if (DL.getTypeAllocSize(AI.getAllocatedType()).getKnownMinSize() == 0) { + if (DL.getTypeAllocSize(AI.getAllocatedType()).getKnownMinValue() == 0) { // For a zero sized alloca there is no point in doing an array allocation. // This is helpful if the array size is a complicated expression not used // elsewhere. 
@@ -377,7 +469,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { AllocaInst *EntryAI = dyn_cast<AllocaInst>(FirstInst); if (!EntryAI || !EntryAI->getAllocatedType()->isSized() || DL.getTypeAllocSize(EntryAI->getAllocatedType()) - .getKnownMinSize() != 0) { + .getKnownMinValue() != 0) { AI.moveBefore(FirstInst); return &AI; } @@ -395,11 +487,11 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { } // Check to see if this allocation is only modified by a memcpy/memmove from - // a constant whose alignment is equal to or exceeds that of the allocation. - // If this is the case, we can change all users to use the constant global - // instead. This is commonly produced by the CFE by constructs like "void - // foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' is only subsequently - // read. + // a memory location whose alignment is equal to or exceeds that of the + // allocation. If this is the case, we can change all users to use the + // constant memory location instead. This is commonly produced by the CFE by + // constructs like "void foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' + // is only subsequently read. SmallVector<Instruction *, 4> ToDelete; if (MemTransferInst *Copy = isOnlyCopiedFromConstantMemory(AA, &AI, ToDelete)) { Value *TheSrc = Copy->getSource(); @@ -415,7 +507,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace(); auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace); - if (AI.getType()->getAddressSpace() == SrcAddrSpace) { + if (AI.getAddressSpace() == SrcAddrSpace) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); @@ -426,13 +518,13 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { return NewI; } - PointerReplacer PtrReplacer(*this); - if (PtrReplacer.collectUsers(AI)) { + PointerReplacer PtrReplacer(*this, AI); + if (PtrReplacer.collectUsers()) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); - PtrReplacer.replacePointer(AI, Cast); + PtrReplacer.replacePointer(Cast); ++NumGlobalCopies; } } @@ -507,6 +599,7 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI, // here. switch (ID) { case LLVMContext::MD_dbg: + case LLVMContext::MD_DIAssignID: case LLVMContext::MD_tbaa: case LLVMContext::MD_prof: case LLVMContext::MD_fpmath: @@ -575,43 +668,43 @@ static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) { /// later. However, it is risky in case some backend or other part of LLVM is /// relying on the exact type loaded to select appropriate atomic operations. static Instruction *combineLoadToOperationType(InstCombinerImpl &IC, - LoadInst &LI) { + LoadInst &Load) { // FIXME: We could probably with some care handle both volatile and ordered // atomic loads here but it isn't clear that this is important. - if (!LI.isUnordered()) + if (!Load.isUnordered()) return nullptr; - if (LI.use_empty()) + if (Load.use_empty()) return nullptr; // swifterror values can't be bitcasted. - if (LI.getPointerOperand()->isSwiftError()) + if (Load.getPointerOperand()->isSwiftError()) return nullptr; - const DataLayout &DL = IC.getDataLayout(); - // Fold away bit casts of the loaded value by loading the desired type. // Note that we should not do this for pointer<->integer casts, // because that would result in type punning. 
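The updated comment in visitAllocaInst above describes the motivating frontend pattern: a local array initialized by a copy from constant data and never stored to afterwards. An illustrative input (names are invented; frontends typically lower the initializer as a memcpy from a constant global) is:

#include <cstdio>

static int sum9() {
  int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; // initialized by a copy from constant data
  int S = 0;
  for (int V : A)                        // only reads; no stores to A after the copy
    S += V;
  return S;
}

int main() { std::printf("%d\n", sum9()); }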
- if (LI.hasOneUse()) { + if (Load.hasOneUse()) { // Don't transform when the type is x86_amx, it makes the pass that lower // x86_amx type happy. - if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) { - assert(!LI.getType()->isX86_AMXTy() && - "load from x86_amx* should not happen!"); + Type *LoadTy = Load.getType(); + if (auto *BC = dyn_cast<BitCastInst>(Load.user_back())) { + assert(!LoadTy->isX86_AMXTy() && "Load from x86_amx* should not happen!"); if (BC->getType()->isX86_AMXTy()) return nullptr; } - if (auto* CI = dyn_cast<CastInst>(LI.user_back())) - if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() == - CI->getDestTy()->isPtrOrPtrVectorTy()) - if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) { - LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy()); - CI->replaceAllUsesWith(NewLoad); - IC.eraseInstFromFunction(*CI); - return &LI; - } + if (auto *CastUser = dyn_cast<CastInst>(Load.user_back())) { + Type *DestTy = CastUser->getDestTy(); + if (CastUser->isNoopCast(IC.getDataLayout()) && + LoadTy->isPtrOrPtrVectorTy() == DestTy->isPtrOrPtrVectorTy() && + (!Load.isAtomic() || isSupportedAtomicType(DestTy))) { + LoadInst *NewLoad = IC.combineLoadToNewType(Load, DestTy); + CastUser->replaceAllUsesWith(NewLoad); + IC.eraseInstFromFunction(*CastUser); + return &Load; + } + } } // FIXME: We should also canonicalize loads of vectors when their elements are @@ -639,7 +732,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { ".unpack"); NewLoad->setAAMetadata(LI.getAAMetadata()); return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( - UndefValue::get(T), NewLoad, 0, Name)); + PoisonValue::get(T), NewLoad, 0, Name)); } // We don't want to break loads with padding here as we'd loose @@ -654,13 +747,13 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *IdxType = Type::getInt32Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - Value *V = UndefValue::get(T); + Value *V = PoisonValue::get(T); for (unsigned i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, ArrayRef(Indices), Name + ".elt"); auto *L = IC.Builder.CreateAlignedLoad( ST->getElementType(i), Ptr, @@ -681,7 +774,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack"); NewLoad->setAAMetadata(LI.getAAMetadata()); return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( - UndefValue::get(T), NewLoad, 0, Name)); + PoisonValue::get(T), NewLoad, 0, Name)); } // Bail out if the array is too large. 
Ideally we would like to optimize @@ -699,14 +792,14 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *IdxType = Type::getInt64Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - Value *V = UndefValue::get(T); + Value *V = PoisonValue::get(T); uint64_t Offset = 0; for (uint64_t i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), Name + ".elt"); auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr, commonAlignment(Align, Offset), @@ -769,10 +862,13 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, if (!CS) return false; - uint64_t TypeSize = DL.getTypeAllocSize(AI->getAllocatedType()); + TypeSize TS = DL.getTypeAllocSize(AI->getAllocatedType()); + if (TS.isScalable()) + return false; // Make sure that, even if the multiplication below would wrap as an // uint64_t, we still do the right thing. - if ((CS->getValue().zext(128) * APInt(128, TypeSize)).ugt(MaxSize)) + if ((CS->getValue().zext(128) * APInt(128, TS.getFixedValue())) + .ugt(MaxSize)) return false; continue; } @@ -849,7 +945,7 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC, if (!AllocTy || !AllocTy->isSized()) return false; const DataLayout &DL = IC.getDataLayout(); - uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedSize(); + uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedValue(); // If there are more indices after the one we might replace with a zero, make // sure they're all non-negative. If any of them are negative, the overall @@ -1183,8 +1279,8 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), - AddrName); + auto *Ptr = + IC.Builder.CreateInBoundsGEP(ST, Addr, ArrayRef(Indices), AddrName); auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = commonAlignment(Align, SL->getElementOffset(i)); llvm::Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); @@ -1229,8 +1325,8 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), - AddrName); + auto *Ptr = + IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), AddrName); auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = commonAlignment(Align, Offset); Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); @@ -1372,7 +1468,7 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { // If we have a store to a location which is known constant, we can conclude // that the store must be storing the constant value (else the memory // wouldn't be constant), and this must be a noop. - if (AA->pointsToConstantMemory(Ptr)) + if (!isModSet(AA->getModRefInfoMask(Ptr))) return eraseInstFromFunction(SI); // Do really simple DSE, to catch cases where there are several consecutive @@ -1547,6 +1643,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { SI.getOrdering(), SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); NewSI->setDebugLoc(MergedLoc); + NewSI->mergeDIAssignID({&SI, OtherStore}); // If the two stores had AA tags, merge them. 
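unpackLoadToAggregate, shown above, splits one aggregate load into per-element loads reassembled with insertvalue (now seeded from poison instead of undef). In source terms the effect is roughly the following, given here only as an illustration, not as the pass's code:

#include <cstdio>

struct Pair {
  int First;
  float Second;
};

// Conceptually, "P = *Ptr" as one aggregate load becomes the per-element form
// below: load each field through its own element address and reassemble.
static Pair loadUnpacked(const Pair *Ptr) {
  Pair P;
  P.First = Ptr->First;   // load of element 0
  P.Second = Ptr->Second; // load of element 1
  return P;
}

int main() {
  Pair Src{42, 1.5f};
  Pair P = loadUnpacked(&Src);
  std::printf("%d %f\n", P.First, P.Second);
}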
AAMDNodes AATags = SI.getAAMetadata(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 8cb09cbac86f..97f129e200de 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -139,9 +140,56 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, return nullptr; } +/// Reduce integer multiplication patterns that contain a (+/-1 << Z) factor. +/// Callers are expected to call this twice to handle commuted patterns. +static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, + InstCombiner::BuilderTy &Builder) { + Value *X = Mul.getOperand(0), *Y = Mul.getOperand(1); + if (CommuteOperands) + std::swap(X, Y); + + const bool HasNSW = Mul.hasNoSignedWrap(); + const bool HasNUW = Mul.hasNoUnsignedWrap(); + + // X * (1 << Z) --> X << Z + Value *Z; + if (match(Y, m_Shl(m_One(), m_Value(Z)))) { + bool PropagateNSW = HasNSW && cast<ShlOperator>(Y)->hasNoSignedWrap(); + return Builder.CreateShl(X, Z, Mul.getName(), HasNUW, PropagateNSW); + } + + // Similar to above, but an increment of the shifted value becomes an add: + // X * ((1 << Z) + 1) --> (X * (1 << Z)) + X --> (X << Z) + X + // This increases uses of X, so it may require a freeze, but that is still + // expected to be an improvement because it removes the multiply. + BinaryOperator *Shift; + if (match(Y, m_OneUse(m_Add(m_BinOp(Shift), m_One()))) && + match(Shift, m_OneUse(m_Shl(m_One(), m_Value(Z))))) { + bool PropagateNSW = HasNSW && Shift->hasNoSignedWrap(); + Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *Shl = Builder.CreateShl(FrX, Z, "mulshl", HasNUW, PropagateNSW); + return Builder.CreateAdd(Shl, FrX, Mul.getName(), HasNUW, PropagateNSW); + } + + // Similar to above, but a decrement of the shifted value is disguised as + // 'not' and becomes a sub: + // X * (~(-1 << Z)) --> X * ((1 << Z) - 1) --> (X << Z) - X + // This increases uses of X, so it may require a freeze, but that is still + // expected to be an improvement because it removes the multiply. 
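The three patterns documented above all reduce a multiply to a shift plus an add or sub; the freeze is only needed because the rewrites duplicate uses of X. The underlying identities can be brute-forced over 8-bit values with a standalone check such as:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned x = 0; x < 256; ++x) {
    for (unsigned z = 0; z < 8; ++z) {
      uint8_t X = x;
      // X * (1 << Z)        == X << Z
      assert(uint8_t(X * (1u << z)) == uint8_t(X << z));
      // X * ((1 << Z) + 1)  == (X << Z) + X
      assert(uint8_t(X * ((1u << z) + 1)) == uint8_t((X << z) + X));
      // X * (~(-1 << Z))    == (X << Z) - X,  since ~(-1 << Z) == (1 << Z) - 1
      assert(uint8_t(X * ((1u << z) - 1)) == uint8_t((X << z) - X));
    }
  }
  return 0;
}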
+ if (match(Y, m_OneUse(m_Not(m_OneUse(m_Shl(m_AllOnes(), m_Value(Z))))))) { + Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *Shl = Builder.CreateShl(FrX, Z, "mulshl"); + return Builder.CreateSub(Shl, FrX, Mul.getName()); + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { - if (Value *V = simplifyMulInst(I.getOperand(0), I.getOperand(1), - SQ.getWithInstruction(&I))) + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (Value *V = + simplifyMulInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (SimplifyAssociativeOrCommutative(I)) @@ -153,18 +201,18 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Instruction *Phi = foldBinopWithPhiOperands(I)) return Phi; - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - unsigned BitWidth = I.getType()->getScalarSizeInBits(); + Type *Ty = I.getType(); + const unsigned BitWidth = Ty->getScalarSizeInBits(); + const bool HasNSW = I.hasNoSignedWrap(); + const bool HasNUW = I.hasNoUnsignedWrap(); - // X * -1 == 0 - X + // X * -1 --> 0 - X if (match(Op1, m_AllOnes())) { - BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName()); - if (I.hasNoSignedWrap()) - BO->setHasNoSignedWrap(); - return BO; + return HasNSW ? BinaryOperator::CreateNSWNeg(Op0) + : BinaryOperator::CreateNeg(Op0); } // Also allow combining multiply instructions on vectors. @@ -179,10 +227,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Constant *Shl = ConstantExpr::getShl(C1, C2); BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0)); BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl); - if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap()) + if (HasNUW && Mul->hasNoUnsignedWrap()) BO->setHasNoUnsignedWrap(); - if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() && - Shl->isNotMinSignedValue()) + if (HasNSW && Mul->hasNoSignedWrap() && Shl->isNotMinSignedValue()) BO->setHasNoSignedWrap(); return BO; } @@ -192,9 +239,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Constant *NewCst = ConstantExpr::getExactLogBase2(C1)) { BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); - if (I.hasNoUnsignedWrap()) + if (HasNUW) Shl->setHasNoUnsignedWrap(); - if (I.hasNoSignedWrap()) { + if (HasNSW) { const APInt *V; if (match(NewCst, m_APInt(V)) && *V != V->getBitWidth() - 1) Shl->setHasNoSignedWrap(); @@ -211,6 +258,25 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this)) return BinaryOperator::CreateMul( NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName()); + + // Try to convert multiply of extended operand to narrow negate and shift + // for better analysis. 
+ // This is valid if the shift amount (trailing zeros in the multiplier + // constant) clears more high bits than the bitwidth difference between + // source and destination types: + // ({z/s}ext X) * (-1<<C) --> (zext (-X)) << C + const APInt *NegPow2C; + Value *X; + if (match(Op0, m_ZExtOrSExt(m_Value(X))) && + match(Op1, m_APIntAllowUndef(NegPow2C))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned ShiftAmt = NegPow2C->countTrailingZeros(); + if (ShiftAmt >= BitWidth - SrcWidth) { + Value *N = Builder.CreateNeg(X, X->getName() + ".neg"); + Value *Z = Builder.CreateZExt(N, Ty, N->getName() + ".z"); + return BinaryOperator::CreateShl(Z, ConstantInt::get(Ty, ShiftAmt)); + } + } } if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) @@ -220,16 +286,29 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { return replaceInstUsesWith(I, FoldedMul); // Simplify mul instructions with a constant RHS. - if (isa<Constant>(Op1)) { - // Canonicalize (X+C1)*CI -> X*CI+C1*CI. + Constant *MulC; + if (match(Op1, m_ImmConstant(MulC))) { + // Canonicalize (X+C1)*MulC -> X*MulC+C1*MulC. + // Canonicalize (X|C1)*MulC -> X*MulC+C1*MulC. Value *X; Constant *C1; - if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) { - Value *Mul = Builder.CreateMul(C1, Op1); - // Only go forward with the transform if C1*CI simplifies to a tidier - // constant. - if (!match(Mul, m_Mul(m_Value(), m_Value()))) - return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul); + if ((match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(C1))))) || + (match(Op0, m_OneUse(m_Or(m_Value(X), m_ImmConstant(C1)))) && + haveNoCommonBitsSet(X, C1, DL, &AC, &I, &DT))) { + // C1*MulC simplifies to a tidier constant. + Value *NewC = Builder.CreateMul(C1, MulC); + auto *BOp0 = cast<BinaryOperator>(Op0); + bool Op0NUW = + (BOp0->getOpcode() == Instruction::Or || BOp0->hasNoUnsignedWrap()); + Value *NewMul = Builder.CreateMul(X, MulC); + auto *BO = BinaryOperator::CreateAdd(NewMul, NewC); + if (HasNUW && Op0NUW) { + // If NewMulBO is constant we also can set BO to nuw. + if (auto *NewMulBO = dyn_cast<BinaryOperator>(NewMul)) + NewMulBO->setHasNoUnsignedWrap(); + BO->setHasNoUnsignedWrap(); + } + return BO; } } @@ -254,8 +333,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // -X * -Y --> X * Y if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Neg(m_Value(Y)))) { auto *NewMul = BinaryOperator::CreateMul(X, Y); - if (I.hasNoSignedWrap() && - cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() && + if (HasNSW && cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() && cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap()) NewMul->setHasNoSignedWrap(); return NewMul; @@ -306,33 +384,15 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // 2) X * Y --> X & Y, iff X, Y can be only {0,1}. 
// Note: We could use known bits to generalize this and related patterns with // shifts/truncs - Type *Ty = I.getType(); if (Ty->isIntOrIntVectorTy(1) || (match(Op0, m_And(m_Value(), m_One())) && match(Op1, m_And(m_Value(), m_One())))) return BinaryOperator::CreateAnd(Op0, Op1); - // X*(1 << Y) --> X << Y - // (1 << Y)*X --> X << Y - { - Value *Y; - BinaryOperator *BO = nullptr; - bool ShlNSW = false; - if (match(Op0, m_Shl(m_One(), m_Value(Y)))) { - BO = BinaryOperator::CreateShl(Op1, Y); - ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap(); - } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) { - BO = BinaryOperator::CreateShl(Op0, Y); - ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap(); - } - if (BO) { - if (I.hasNoUnsignedWrap()) - BO->setHasNoUnsignedWrap(); - if (I.hasNoSignedWrap() && ShlNSW) - BO->setHasNoSignedWrap(); - return BO; - } - } + if (Value *R = foldMulShl1(I, /* CommuteOperands */ false, Builder)) + return replaceInstUsesWith(I, R); + if (Value *R = foldMulShl1(I, /* CommuteOperands */ true, Builder)) + return replaceInstUsesWith(I, R); // (zext bool X) * (zext bool Y) --> zext (and X, Y) // (sext bool X) * (sext bool Y) --> zext (and X, Y) @@ -403,8 +463,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { m_One()), m_Deferred(X)))) { Value *Abs = Builder.CreateBinaryIntrinsic( - Intrinsic::abs, X, - ConstantInt::getBool(I.getContext(), I.hasNoSignedWrap())); + Intrinsic::abs, X, ConstantInt::getBool(I.getContext(), HasNSW)); Abs->takeName(&I); return replaceInstUsesWith(I, Abs); } @@ -413,12 +472,12 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { return Ext; bool Changed = false; - if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) { + if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedMul(Op0, Op1, I)) { + if (!HasNUW && willNotOverflowUnsignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -488,11 +547,19 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (match(Op1, m_SpecificFP(-1.0))) return UnaryOperator::CreateFNegFMF(Op0, &I); + // With no-nans: X * 0.0 --> copysign(0.0, X) + if (I.hasNoNaNs() && match(Op1, m_PosZeroFP())) { + CallInst *CopySign = Builder.CreateIntrinsic(Intrinsic::copysign, + {I.getType()}, {Op1, Op0}, &I); + return replaceInstUsesWith(I, CopySign); + } + // -X * C --> X * -C Value *X, *Y; Constant *C; if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFMulFMF(X, NegC, &I); // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E) if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) @@ -596,14 +663,32 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { } } + // pow(X, Y) * X --> pow(X, Y+1) + // X * pow(X, Y) --> pow(X, Y+1) + if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X), + m_Value(Y))), + m_Deferred(X)))) { + Value *Y1 = + Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I); + Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I); + return replaceInstUsesWith(I, Pow); + } + if (I.isOnlyUserOfAnyOperand()) { - // pow(x, y) * pow(x, z) -> pow(x, y + z) + // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z) if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && match(Op1, 
m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) { auto *YZ = Builder.CreateFAddFMF(Y, Z, &I); auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I); return replaceInstUsesWith(I, NewPow); } + // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y) + if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) { + auto *XZ = Builder.CreateFMulFMF(X, Z, &I); + auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I); + return replaceInstUsesWith(I, NewPow); + } // powi(x, y) * powi(x, z) -> powi(x, y + z) if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && @@ -671,6 +756,15 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { } } + // Simplify FMUL recurrences starting with 0.0 to 0.0 if nnan and nsz are set. + // Given a phi node with entry value as 0 and it used in fmul operation, + // we can replace fmul with 0 safely and eleminate loop operation. + PHINode *PN = nullptr; + Value *Start = nullptr, *Step = nullptr; + if (matchSimpleRecurrence(&I, PN, Start, Step) && I.hasNoNaNs() && + I.hasNoSignedZeros() && match(Start, m_Zero())) + return replaceInstUsesWith(I, Start); + return nullptr; } @@ -773,6 +867,70 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, return Remainder.isMinValue(); } +static Instruction *foldIDivShl(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert((I.getOpcode() == Instruction::SDiv || + I.getOpcode() == Instruction::UDiv) && + "Expected integer divide"); + + bool IsSigned = I.getOpcode() == Instruction::SDiv; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); + + Instruction *Ret = nullptr; + Value *X, *Y, *Z; + + // With appropriate no-wrap constraints, remove a common factor in the + // dividend and divisor that is disguised as a left-shifted value. + if (match(Op1, m_Shl(m_Value(X), m_Value(Z))) && + match(Op0, m_c_Mul(m_Specific(X), m_Value(Y)))) { + // Both operands must have the matching no-wrap for this kind of division. + auto *Mul = cast<OverflowingBinaryOperator>(Op0); + auto *Shl = cast<OverflowingBinaryOperator>(Op1); + bool HasNUW = Mul->hasNoUnsignedWrap() && Shl->hasNoUnsignedWrap(); + bool HasNSW = Mul->hasNoSignedWrap() && Shl->hasNoSignedWrap(); + + // (X * Y) u/ (X << Z) --> Y u>> Z + if (!IsSigned && HasNUW) + Ret = BinaryOperator::CreateLShr(Y, Z); + + // (X * Y) s/ (X << Z) --> Y s/ (1 << Z) + if (IsSigned && HasNSW && (Op0->hasOneUse() || Op1->hasOneUse())) { + Value *Shl = Builder.CreateShl(ConstantInt::get(Ty, 1), Z); + Ret = BinaryOperator::CreateSDiv(Y, Shl); + } + } + + // With appropriate no-wrap constraints, remove a common factor in the + // dividend and divisor that is disguised as a left-shift amount. + if (match(Op0, m_Shl(m_Value(X), m_Value(Z))) && + match(Op1, m_Shl(m_Value(Y), m_Specific(Z)))) { + auto *Shl0 = cast<OverflowingBinaryOperator>(Op0); + auto *Shl1 = cast<OverflowingBinaryOperator>(Op1); + + // For unsigned div, we need 'nuw' on both shifts or + // 'nsw' on both shifts + 'nuw' on the dividend. + // (X << Z) / (Y << Z) --> X / Y + if (!IsSigned && + ((Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap()) || + (Shl0->hasNoUnsignedWrap() && Shl0->hasNoSignedWrap() && + Shl1->hasNoSignedWrap()))) + Ret = BinaryOperator::CreateUDiv(X, Y); + + // For signed div, we need 'nsw' on both shifts + 'nuw' on the divisor. 
+ // (X << Z) / (Y << Z) --> X / Y + if (IsSigned && Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap() && + Shl1->hasNoUnsignedWrap()) + Ret = BinaryOperator::CreateSDiv(X, Y); + } + + if (!Ret) + return nullptr; + + Ret->setIsExact(I.isExact()); + return Ret; +} + /// This function implements the transforms common to both integer division /// instructions (udiv and sdiv). It is called by the visitors to those integer /// division instructions. @@ -919,6 +1077,41 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { } } + // (X << Z) / (X * Y) -> (1 << Z) / Y + // TODO: Handle sdiv. + if (!IsSigned && Op1->hasOneUse() && + match(Op0, m_NUWShl(m_Value(X), m_Value(Z))) && + match(Op1, m_c_Mul(m_Specific(X), m_Value(Y)))) + if (cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap()) { + Instruction *NewDiv = BinaryOperator::CreateUDiv( + Builder.CreateShl(ConstantInt::get(Ty, 1), Z, "", /*NUW*/ true), Y); + NewDiv->setIsExact(I.isExact()); + return NewDiv; + } + + if (Instruction *R = foldIDivShl(I, Builder)) + return R; + + // With the appropriate no-wrap constraint, remove a multiply by the divisor + // after peeking through another divide: + // ((Op1 * X) / Y) / Op1 --> X / Y + if (match(Op0, m_BinOp(I.getOpcode(), m_c_Mul(m_Specific(Op1), m_Value(X)), + m_Value(Y)))) { + auto *InnerDiv = cast<PossiblyExactOperator>(Op0); + auto *Mul = cast<OverflowingBinaryOperator>(InnerDiv->getOperand(0)); + Instruction *NewDiv = nullptr; + if (!IsSigned && Mul->hasNoUnsignedWrap()) + NewDiv = BinaryOperator::CreateUDiv(X, Y); + else if (IsSigned && Mul->hasNoSignedWrap()) + NewDiv = BinaryOperator::CreateSDiv(X, Y); + + // Exact propagates only if both of the original divides are exact. + if (NewDiv) { + NewDiv->setIsExact(I.isExact() && InnerDiv->isExact()); + return NewDiv; + } + } + return nullptr; } @@ -1007,8 +1200,8 @@ static Instruction *narrowUDivURem(BinaryOperator &I, } Constant *C; - if ((match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) || - (match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C)))) { + if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) && + match(D, m_Constant(C))) { // If the constant is the same in the smaller type, use the narrow version. Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); if (ConstantExpr::getZExt(TruncC, Ty) != C) @@ -1016,18 +1209,25 @@ static Instruction *narrowUDivURem(BinaryOperator &I, // udiv (zext X), C --> zext (udiv X, C') // urem (zext X), C --> zext (urem X, C') + return new ZExtInst(Builder.CreateBinOp(Opcode, X, TruncC), Ty); + } + if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) && + match(N, m_Constant(C))) { + // If the constant is the same in the smaller type, use the narrow version. + Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); + if (ConstantExpr::getZExt(TruncC, Ty) != C) + return nullptr; + // udiv C, (zext X) --> zext (udiv C', X) // urem C, (zext X) --> zext (urem C', X) - Value *NarrowOp = isa<Constant>(D) ? 
Builder.CreateBinOp(Opcode, X, TruncC) - : Builder.CreateBinOp(Opcode, TruncC, X); - return new ZExtInst(NarrowOp, Ty); + return new ZExtInst(Builder.CreateBinOp(Opcode, TruncC, X), Ty); } return nullptr; } Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { - if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1086,6 +1286,16 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { return BinaryOperator::CreateUDiv(A, X); } + // Look through a right-shift to find the common factor: + // ((Op1 *nuw A) >> B) / Op1 --> A >> B + if (match(Op0, m_LShr(m_NUWMul(m_Specific(Op1), m_Value(A)), m_Value(B))) || + match(Op0, m_LShr(m_NUWMul(m_Value(A), m_Specific(Op1)), m_Value(B)))) { + Instruction *Lshr = BinaryOperator::CreateLShr(A, B); + if (I.isExact() && cast<PossiblyExactOperator>(Op0)->isExact()) + Lshr->setIsExact(); + return Lshr; + } + // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) { Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true); @@ -1097,7 +1307,7 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { } Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { - if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1121,20 +1331,25 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { if (match(Op1, m_SignMask())) return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), Ty); - // sdiv exact X, 1<<C --> ashr exact X, C iff 1<<C is non-negative - // sdiv exact X, -1<<C --> -(ashr exact X, C) - if (I.isExact() && ((match(Op1, m_Power2()) && match(Op1, m_NonNegative())) || - match(Op1, m_NegatedPower2()))) { - bool DivisorWasNegative = match(Op1, m_NegatedPower2()); - if (DivisorWasNegative) - Op1 = ConstantExpr::getNeg(cast<Constant>(Op1)); - auto *AShr = BinaryOperator::CreateExactAShr( - Op0, ConstantExpr::getExactLogBase2(cast<Constant>(Op1)), I.getName()); - if (!DivisorWasNegative) - return AShr; - Builder.Insert(AShr); - AShr->setName(I.getName() + ".neg"); - return BinaryOperator::CreateNeg(AShr, I.getName()); + if (I.isExact()) { + // sdiv exact X, 1<<C --> ashr exact X, C iff 1<<C is non-negative + if (match(Op1, m_Power2()) && match(Op1, m_NonNegative())) { + Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(Op1)); + return BinaryOperator::CreateExactAShr(Op0, C); + } + + // sdiv exact X, (1<<ShAmt) --> ashr exact X, ShAmt (if shl is non-negative) + Value *ShAmt; + if (match(Op1, m_NSWShl(m_One(), m_Value(ShAmt)))) + return BinaryOperator::CreateExactAShr(Op0, ShAmt); + + // sdiv exact X, -1<<C --> -(ashr exact X, C) + if (match(Op1, m_NegatedPower2())) { + Constant *NegPow2C = ConstantExpr::getNeg(cast<Constant>(Op1)); + Constant *C = ConstantExpr::getExactLogBase2(NegPow2C); + Value *Ashr = Builder.CreateAShr(Op0, C, I.getName() + ".neg", true); + return BinaryOperator::CreateNeg(Ashr); + } } const APInt *Op1C; @@ -1184,12 +1399,17 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { ConstantInt::getAllOnesValue(Ty)); } - // If the sign bits of both operands are zero (i.e. we can prove they are - // unsigned inputs), turn this into a udiv. 
- APInt Mask(APInt::getSignMask(Ty->getScalarSizeInBits())); - if (MaskedValueIsZero(Op0, Mask, 0, &I)) { - if (MaskedValueIsZero(Op1, Mask, 0, &I)) { - // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set + KnownBits KnownDividend = computeKnownBits(Op0, 0, &I); + if (!I.isExact() && + (match(Op1, m_Power2(Op1C)) || match(Op1, m_NegatedPower2(Op1C))) && + KnownDividend.countMinTrailingZeros() >= Op1C->countTrailingZeros()) { + I.setIsExact(); + return &I; + } + + if (KnownDividend.isNonNegative()) { + // If both operands are unsigned, turn this into a udiv. + if (isKnownNonNegative(Op1, DL, 0, &AC, &I, &DT)) { auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); BO->setIsExact(I.isExact()); return BO; @@ -1219,15 +1439,28 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { } /// Remove negation and try to convert division into multiplication. -static Instruction *foldFDivConstantDivisor(BinaryOperator &I) { +Instruction *InstCombinerImpl::foldFDivConstantDivisor(BinaryOperator &I) { Constant *C; if (!match(I.getOperand(1), m_Constant(C))) return nullptr; // -X / C --> X / -C Value *X; + const DataLayout &DL = I.getModule()->getDataLayout(); if (match(I.getOperand(0), m_FNeg(m_Value(X)))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFDivFMF(X, NegC, &I); + + // nnan X / +0.0 -> copysign(inf, X) + if (I.hasNoNaNs() && match(I.getOperand(1), m_Zero())) { + IRBuilder<> B(&I); + // TODO: nnan nsz X / -0.0 -> copysign(inf, X) + CallInst *CopySign = B.CreateIntrinsic( + Intrinsic::copysign, {C->getType()}, + {ConstantFP::getInfinity(I.getType()), I.getOperand(0)}, &I); + CopySign->takeName(&I); + return replaceInstUsesWith(I, CopySign); + } // If the constant divisor has an exact inverse, this is always safe. If not, // then we can still create a reciprocal if fast-math-flags allow it and the @@ -1239,7 +1472,6 @@ static Instruction *foldFDivConstantDivisor(BinaryOperator &I) { // on all targets. // TODO: Use Intrinsic::canonicalize or let function attributes tell us that // denorms are flushed? - const DataLayout &DL = I.getModule()->getDataLayout(); auto *RecipC = ConstantFoldBinaryOpOperands( Instruction::FDiv, ConstantFP::get(I.getType(), 1.0), C, DL); if (!RecipC || !RecipC->isNormalFP()) @@ -1257,15 +1489,16 @@ static Instruction *foldFDivConstantDividend(BinaryOperator &I) { // C / -X --> -C / X Value *X; + const DataLayout &DL = I.getModule()->getDataLayout(); if (match(I.getOperand(1), m_FNeg(m_Value(X)))) - return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFDivFMF(NegC, X, &I); if (!I.hasAllowReassoc() || !I.hasAllowReciprocal()) return nullptr; // Try to reassociate C / X expressions where X includes another constant. 
Constant *C2, *NewC = nullptr; - const DataLayout &DL = I.getModule()->getDataLayout(); if (match(I.getOperand(1), m_FMul(m_Value(X), m_Constant(C2)))) { // C / (X * C2) --> (C / C2) / X NewC = ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C2, DL); @@ -1435,6 +1668,16 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { if (Instruction *Mul = foldFDivPowDivisor(I, Builder)) return Mul; + // pow(X, Y) / X --> pow(X, Y-1) + if (I.hasAllowReassoc() && + match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1), + m_Value(Y))))) { + Value *Y1 = + Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), -1.0), &I); + Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, Op1, Y1, &I); + return replaceInstUsesWith(I, Pow); + } + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index c573b03f31a6..e24abc48424d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -15,8 +15,6 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -130,7 +128,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // FIXME: can this be reworked into a worklist-based algorithm while preserving // the depth-first, early bailout traversal? -LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { +[[nodiscard]] Value *Negator::visitImpl(Value *V, unsigned Depth) { // -(undef) -> undef. if (match(V, m_Undef())) return V; @@ -248,6 +246,19 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { return nullptr; switch (I->getOpcode()) { + case Instruction::ZExt: { + // Negation of zext of signbit is signbit splat: + // 0 - (zext (i8 X u>> 7) to iN) --> sext (i8 X s>> 7) to iN + Value *SrcOp = I->getOperand(0); + unsigned SrcWidth = SrcOp->getType()->getScalarSizeInBits(); + const APInt &FullShift = APInt(SrcWidth, SrcWidth - 1); + if (IsTrulyNegation && + match(SrcOp, m_LShr(m_Value(X), m_SpecificIntAllowUndef(FullShift)))) { + Value *Ashr = Builder.CreateAShr(X, FullShift); + return Builder.CreateSExt(Ashr, I->getType()); + } + break; + } case Instruction::And: { Constant *ShAmt; // sub(y,and(lshr(x,C),1)) --> add(ashr(shl(x,(BW-1)-C),BW-1),y) @@ -382,7 +393,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); // Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`. auto *Op1C = dyn_cast<Constant>(I->getOperand(1)); - if (!Op1C) // Early return. + if (!Op1C || !IsTrulyNegation) return nullptr; return Builder.CreateMul( I->getOperand(0), @@ -399,7 +410,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { if (match(Ops[1], m_One())) return Builder.CreateNot(Ops[0], I->getName() + ".neg"); // Else, just defer to Instruction::Add handling. - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Instruction::Add: { // `add` is negatible if both of its operands are negatible. @@ -465,7 +476,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { llvm_unreachable("Can't get here. 
We always return from switch."); } -LLVM_NODISCARD Value *Negator::negate(Value *V, unsigned Depth) { +[[nodiscard]] Value *Negator::negate(Value *V, unsigned Depth) { NegatorMaxDepthVisited.updateMax(Depth); ++NegatorNumValuesVisited; @@ -502,20 +513,20 @@ LLVM_NODISCARD Value *Negator::negate(Value *V, unsigned Depth) { return NegatedV; } -LLVM_NODISCARD Optional<Negator::Result> Negator::run(Value *Root) { +[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root) { Value *Negated = negate(Root, /*Depth=*/0); if (!Negated) { // We must cleanup newly-inserted instructions, to avoid any potential // endless combine looping. for (Instruction *I : llvm::reverse(NewInstructions)) I->eraseFromParent(); - return llvm::None; + return std::nullopt; } return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated); } -LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root, - InstCombinerImpl &IC) { +[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, Value *Root, + InstCombinerImpl &IC) { ++NegatorTotalNegationsAttempted; LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root << "\n"); @@ -525,7 +536,7 @@ LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root, Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(), IC.getDominatorTree(), LHSIsZero); - Optional<Result> Res = N.run(Root); + std::optional<Result> Res = N.run(Root); if (!Res) { // Negation failed. LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root << "\n"); diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 90a796a0939e..7f59729f0085 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -20,6 +20,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/Local.h" +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -102,15 +103,15 @@ void InstCombinerImpl::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) { // ptr_val_inc = ... // ... 
// -Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { +bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { if (!PN.getType()->isIntegerTy()) - return nullptr; + return false; if (!PN.hasOneUse()) - return nullptr; + return false; auto *IntToPtr = dyn_cast<IntToPtrInst>(PN.user_back()); if (!IntToPtr) - return nullptr; + return false; // Check if the pointer is actually used as pointer: auto HasPointerUse = [](Instruction *IIP) { @@ -131,11 +132,11 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { }; if (!HasPointerUse(IntToPtr)) - return nullptr; + return false; if (DL.getPointerSizeInBits(IntToPtr->getAddressSpace()) != DL.getTypeSizeInBits(IntToPtr->getOperand(0)->getType())) - return nullptr; + return false; SmallVector<Value *, 4> AvailablePtrVals; for (auto Incoming : zip(PN.blocks(), PN.incoming_values())) { @@ -174,10 +175,10 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { // For a single use integer load: auto *LoadI = dyn_cast<LoadInst>(Arg); if (!LoadI) - return nullptr; + return false; if (!LoadI->hasOneUse()) - return nullptr; + return false; // Push the integer typed Load instruction into the available // value set, and fix it up later when the pointer typed PHI @@ -194,7 +195,7 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { for (PHINode &PtrPHI : BB->phis()) { // FIXME: consider handling this in AggressiveInstCombine if (NumPhis++ > MaxNumPhis) - return nullptr; + return false; if (&PtrPHI == &PN || PtrPHI.getType() != IntToPtr->getType()) continue; if (any_of(zip(PN.blocks(), AvailablePtrVals), @@ -211,16 +212,19 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { if (MatchingPtrPHI) { assert(MatchingPtrPHI->getType() == IntToPtr->getType() && "Phi's Type does not match with IntToPtr"); - // The PtrToCast + IntToPtr will be simplified later - return CastInst::CreateBitOrPointerCast(MatchingPtrPHI, - IntToPtr->getOperand(0)->getType()); + // Explicitly replace the inttoptr (rather than inserting a ptrtoint) here, + // to make sure another transform can't undo it in the meantime. + replaceInstUsesWith(*IntToPtr, MatchingPtrPHI); + eraseInstFromFunction(*IntToPtr); + eraseInstFromFunction(PN); + return true; } // If it requires a conversion for every PHI operand, do not do it. if (all_of(AvailablePtrVals, [&](Value *V) { return (V->getType() != IntToPtr->getType()) || isa<IntToPtrInst>(V); })) - return nullptr; + return false; // If any of the operand that requires casting is a terminator // instruction, do not do it. Similarly, do not do the transform if the value @@ -239,7 +243,7 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { return true; return false; })) - return nullptr; + return false; PHINode *NewPtrPHI = PHINode::Create( IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr"); @@ -290,9 +294,12 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { NewPtrPHI->addIncoming(CI, IncomingBB); } - // The PtrToCast + IntToPtr will be simplified later - return CastInst::CreateBitOrPointerCast(NewPtrPHI, - IntToPtr->getOperand(0)->getType()); + // Explicitly replace the inttoptr (rather than inserting a ptrtoint) here, + // to make sure another transform can't undo it in the meantime. 
+ replaceInstUsesWith(*IntToPtr, NewPtrPHI); + eraseInstFromFunction(*IntToPtr); + eraseInstFromFunction(PN); + return true; } // Remove RoundTrip IntToPtr/PtrToInt Cast on PHI-Operand and @@ -598,7 +605,7 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { Value *Base = FixedOperands[0]; GetElementPtrInst *NewGEP = GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base, - makeArrayRef(FixedOperands).slice(1)); + ArrayRef(FixedOperands).slice(1)); if (AllInBounds) NewGEP->setIsInBounds(); PHIArgMergedDebugLoc(NewGEP, PN); return NewGEP; @@ -1322,7 +1329,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, // Check that edges outgoing from the idom's terminators dominate respective // inputs of the Phi. - Optional<bool> Invert; + std::optional<bool> Invert; for (auto Pair : zip(PN.incoming_values(), PN.blocks())) { auto *Input = cast<ConstantInt>(std::get<0>(Pair)); BasicBlock *Pred = std::get<1>(Pair); @@ -1412,8 +1419,8 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // this PHI only has a single use (a PHI), and if that PHI only has one use (a // PHI)... break the cycle. if (PN.hasOneUse()) { - if (Instruction *Result = foldIntegerTypedPHI(PN)) - return Result; + if (foldIntegerTypedPHI(PN)) + return nullptr; Instruction *PHIUser = cast<Instruction>(PN.user_back()); if (PHINode *PU = dyn_cast<PHINode>(PHIUser)) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index ad96a5f475f1..e7d8208f94fd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -12,7 +12,6 @@ #include "InstCombineInternal.h" #include "llvm/ADT/APInt.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AssumptionCache.h" @@ -20,6 +19,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/OverflowInstAnalysis.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" @@ -314,47 +314,95 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, TI->getType()); } - // Cond ? -X : -Y --> -(Cond ? X : Y) - Value *X, *Y; - if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) && - (TI->hasOneUse() || FI->hasOneUse())) { - // Intersect FMF from the fneg instructions and union those with the select. - FastMathFlags FMF = TI->getFastMathFlags(); - FMF &= FI->getFastMathFlags(); - FMF |= SI.getFastMathFlags(); - Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); - if (auto *NewSelI = dyn_cast<Instruction>(NewSel)) - NewSelI->setFastMathFlags(FMF); - Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel); - NewFNeg->setFastMathFlags(FMF); - return NewFNeg; - } - - // Min/max intrinsic with a common operand can have the common operand pulled - // after the select. This is the same transform as below for binops, but - // specialized for intrinsic matching and without the restrictive uses clause. 
- auto *TII = dyn_cast<IntrinsicInst>(TI); - auto *FII = dyn_cast<IntrinsicInst>(FI); - if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID() && - (TII->hasOneUse() || FII->hasOneUse())) { - Value *T0, *T1, *F0, *F1; - if (match(TII, m_MaxOrMin(m_Value(T0), m_Value(T1))) && - match(FII, m_MaxOrMin(m_Value(F0), m_Value(F1)))) { - if (T0 == F0) { - Value *NewSel = Builder.CreateSelect(Cond, T1, F1, "minmaxop", &SI); - return CallInst::Create(TII->getCalledFunction(), {NewSel, T0}); - } - if (T0 == F1) { - Value *NewSel = Builder.CreateSelect(Cond, T1, F0, "minmaxop", &SI); - return CallInst::Create(TII->getCalledFunction(), {NewSel, T0}); + Value *OtherOpT, *OtherOpF; + bool MatchIsOpZero; + auto getCommonOp = [&](Instruction *TI, Instruction *FI, bool Commute, + bool Swapped = false) -> Value * { + assert(!(Commute && Swapped) && + "Commute and Swapped can't set at the same time"); + if (!Swapped) { + if (TI->getOperand(0) == FI->getOperand(0)) { + OtherOpT = TI->getOperand(1); + OtherOpF = FI->getOperand(1); + MatchIsOpZero = true; + return TI->getOperand(0); + } else if (TI->getOperand(1) == FI->getOperand(1)) { + OtherOpT = TI->getOperand(0); + OtherOpF = FI->getOperand(0); + MatchIsOpZero = false; + return TI->getOperand(1); } - if (T1 == F0) { - Value *NewSel = Builder.CreateSelect(Cond, T0, F1, "minmaxop", &SI); - return CallInst::Create(TII->getCalledFunction(), {NewSel, T1}); + } + + if (!Commute && !Swapped) + return nullptr; + + // If we are allowing commute or swap of operands, then + // allow a cross-operand match. In that case, MatchIsOpZero + // means that TI's operand 0 (FI's operand 1) is the common op. + if (TI->getOperand(0) == FI->getOperand(1)) { + OtherOpT = TI->getOperand(1); + OtherOpF = FI->getOperand(0); + MatchIsOpZero = true; + return TI->getOperand(0); + } else if (TI->getOperand(1) == FI->getOperand(0)) { + OtherOpT = TI->getOperand(0); + OtherOpF = FI->getOperand(1); + MatchIsOpZero = false; + return TI->getOperand(1); + } + return nullptr; + }; + + if (TI->hasOneUse() || FI->hasOneUse()) { + // Cond ? -X : -Y --> -(Cond ? X : Y) + Value *X, *Y; + if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y)))) { + // Intersect FMF from the fneg instructions and union those with the + // select. + FastMathFlags FMF = TI->getFastMathFlags(); + FMF &= FI->getFastMathFlags(); + FMF |= SI.getFastMathFlags(); + Value *NewSel = + Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); + if (auto *NewSelI = dyn_cast<Instruction>(NewSel)) + NewSelI->setFastMathFlags(FMF); + Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel); + NewFNeg->setFastMathFlags(FMF); + return NewFNeg; + } + + // Min/max intrinsic with a common operand can have the common operand + // pulled after the select. This is the same transform as below for binops, + // but specialized for intrinsic matching and without the restrictive uses + // clause. 
+ auto *TII = dyn_cast<IntrinsicInst>(TI); + auto *FII = dyn_cast<IntrinsicInst>(FI); + if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID()) { + if (match(TII, m_MaxOrMin(m_Value(), m_Value()))) { + if (Value *MatchOp = getCommonOp(TI, FI, true)) { + Value *NewSel = + Builder.CreateSelect(Cond, OtherOpT, OtherOpF, "minmaxop", &SI); + return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp}); + } } - if (T1 == F1) { - Value *NewSel = Builder.CreateSelect(Cond, T0, F0, "minmaxop", &SI); - return CallInst::Create(TII->getCalledFunction(), {NewSel, T1}); + } + + // icmp with a common operand also can have the common operand + // pulled after the select. + ICmpInst::Predicate TPred, FPred; + if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) && + match(FI, m_ICmp(FPred, m_Value(), m_Value()))) { + if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) { + bool Swapped = TPred != FPred; + if (Value *MatchOp = + getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) { + Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF, + SI.getName() + ".v", &SI); + return new ICmpInst( + MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred), + MatchOp, NewSel); + } } } } @@ -370,33 +418,9 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, return nullptr; // Figure out if the operations have any operands in common. - Value *MatchOp, *OtherOpT, *OtherOpF; - bool MatchIsOpZero; - if (TI->getOperand(0) == FI->getOperand(0)) { - MatchOp = TI->getOperand(0); - OtherOpT = TI->getOperand(1); - OtherOpF = FI->getOperand(1); - MatchIsOpZero = true; - } else if (TI->getOperand(1) == FI->getOperand(1)) { - MatchOp = TI->getOperand(1); - OtherOpT = TI->getOperand(0); - OtherOpF = FI->getOperand(0); - MatchIsOpZero = false; - } else if (!TI->isCommutative()) { - return nullptr; - } else if (TI->getOperand(0) == FI->getOperand(1)) { - MatchOp = TI->getOperand(0); - OtherOpT = TI->getOperand(1); - OtherOpF = FI->getOperand(0); - MatchIsOpZero = true; - } else if (TI->getOperand(1) == FI->getOperand(0)) { - MatchOp = TI->getOperand(1); - OtherOpT = TI->getOperand(0); - OtherOpF = FI->getOperand(1); - MatchIsOpZero = true; - } else { + Value *MatchOp = getCommonOp(TI, FI, TI->isCommutative()); + if (!MatchOp) return nullptr; - } // If the select condition is a vector, the operands of the original select's // operands also must be vectors. This may not be the case for getelementptr @@ -442,44 +466,44 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal, Value *FalseVal, bool Swapped) -> Instruction * { - if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { - if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) { - if (unsigned SFO = getSelectFoldableOperands(TVI)) { - unsigned OpToFold = 0; - if ((SFO & 1) && FalseVal == TVI->getOperand(0)) - OpToFold = 1; - else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) - OpToFold = 2; - - if (OpToFold) { - FastMathFlags FMF; - // TODO: We probably ought to revisit cases where the select and FP - // instructions have different flags and add tests to ensure the - // behaviour is correct. - if (isa<FPMathOperator>(&SI)) - FMF = SI.getFastMathFlags(); - Constant *C = ConstantExpr::getBinOpIdentity( - TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); - Value *OOp = TVI->getOperand(2 - OpToFold); - // Avoid creating select between 2 constants unless it's selecting - // between 0, 1 and -1. 
- const APInt *OOpC; - bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || - (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { - Value *NewSel = Builder.CreateSelect( - SI.getCondition(), Swapped ? C : OOp, Swapped ? OOp : C); - if (isa<FPMathOperator>(&SI)) - cast<Instruction>(NewSel)->setFastMathFlags(FMF); - NewSel->takeName(TVI); - BinaryOperator *BO = - BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); - BO->copyIRFlags(TVI); - return BO; - } - } - } - } + auto *TVI = dyn_cast<BinaryOperator>(TrueVal); + if (!TVI || !TVI->hasOneUse() || isa<Constant>(FalseVal)) + return nullptr; + + unsigned SFO = getSelectFoldableOperands(TVI); + unsigned OpToFold = 0; + if ((SFO & 1) && FalseVal == TVI->getOperand(0)) + OpToFold = 1; + else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) + OpToFold = 2; + + if (!OpToFold) + return nullptr; + + // TODO: We probably ought to revisit cases where the select and FP + // instructions have different flags and add tests to ensure the + // behaviour is correct. + FastMathFlags FMF; + if (isa<FPMathOperator>(&SI)) + FMF = SI.getFastMathFlags(); + Constant *C = ConstantExpr::getBinOpIdentity( + TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); + Value *OOp = TVI->getOperand(2 - OpToFold); + // Avoid creating select between 2 constants unless it's selecting + // between 0, 1 and -1. + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { + Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, + Swapped ? OOp : C); + if (isa<FPMathOperator>(&SI)) + cast<Instruction>(NewSel)->setFastMathFlags(FMF); + NewSel->takeName(TVI); + BinaryOperator *BO = + BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); + BO->copyIRFlags(TVI); + return BO; } return nullptr; }; @@ -779,19 +803,31 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, const Value *FalseVal, InstCombiner::BuilderTy &Builder) { ICmpInst::Predicate Pred = ICI->getPredicate(); - if (!ICmpInst::isUnsigned(Pred)) - return nullptr; + Value *A = ICI->getOperand(0); + Value *B = ICI->getOperand(1); // (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0 + // (a == 0) ? 0 : a - 1 -> (a != 0) ? a - 1 : 0 if (match(TrueVal, m_Zero())) { Pred = ICmpInst::getInversePredicate(Pred); std::swap(TrueVal, FalseVal); } + if (!match(FalseVal, m_Zero())) return nullptr; - Value *A = ICI->getOperand(0); - Value *B = ICI->getOperand(1); + // ugt 0 is canonicalized to ne 0 and requires special handling + // (a != 0) ? a + -1 : 0 -> usub.sat(a, 1) + if (Pred == ICmpInst::ICMP_NE) { + if (match(B, m_Zero()) && match(TrueVal, m_Add(m_Specific(A), m_AllOnes()))) + return Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, + ConstantInt::get(A->getType(), 1)); + return nullptr; + } + + if (!ICmpInst::isUnsigned(Pred)) + return nullptr; + if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) { // (b < a) ? a - b : 0 -> (a > b) ? a - b : 0 std::swap(A, B); @@ -952,8 +988,8 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - // Check if the condition value compares a value for equality against zero. - if (!ICI->isEquality() || !match(CmpRHS, m_Zero())) + // Check if the select condition compares a value for equality. 
+ if (!ICI->isEquality()) return nullptr; Value *SelectArg = FalseVal; @@ -969,8 +1005,15 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the // input to the cttz/ctlz is used as LHS for the compare instruction. - if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) && - !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) + Value *X; + if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Value(X))) && + !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Value(X)))) + return nullptr; + + // (X == 0) ? BitWidth : ctz(X) + // (X == -1) ? BitWidth : ctz(~X) + if ((X != CmpLHS || !match(CmpRHS, m_Zero())) && + (!match(X, m_Not(m_Specific(CmpLHS))) || !match(CmpRHS, m_AllOnes()))) return nullptr; IntrinsicInst *II = cast<IntrinsicInst>(Count); @@ -1139,6 +1182,28 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, return nullptr; } +static bool replaceInInstruction(Value *V, Value *Old, Value *New, + InstCombiner &IC, unsigned Depth = 0) { + // Conservatively limit replacement to two instructions upwards. + if (Depth == 2) + return false; + + auto *I = dyn_cast<Instruction>(V); + if (!I || !I->hasOneUse() || !isSafeToSpeculativelyExecute(I)) + return false; + + bool Changed = false; + for (Use &U : I->operands()) { + if (U == Old) { + IC.replaceUse(U, New); + Changed = true; + } else { + Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1); + } + } + return Changed; +} + /// If we have a select with an equality comparison, then we know the value in /// one of the arms of the select. See if substituting this value into an arm /// and simplifying the result yields the same value as the other arm. @@ -1157,10 +1222,7 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, /// TODO: Wrapping flags could be preserved in some cases with better analysis. Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp) { - // Value equivalence substitution requires an all-or-nothing replacement. - // It does not make sense for a vector compare where each lane is chosen - // independently. - if (!Cmp.isEquality() || Cmp.getType()->isVectorTy()) + if (!Cmp.isEquality()) return nullptr; // Canonicalize the pattern to ICMP_EQ by swapping the select operands. @@ -1189,15 +1251,11 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // with different operands, which should not cause side-effects or trigger // undefined behavior). Only do this if CmpRHS is a constant, as // profitability is not clear for other cases. - // FIXME: The replacement could be performed recursively. - if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant())) - if (auto *I = dyn_cast<Instruction>(TrueVal)) - if (I->hasOneUse() && isSafeToSpeculativelyExecute(I)) - for (Use &U : I->operands()) - if (U == CmpLHS) { - replaceUse(U, CmpRHS); - return &Sel; - } + // FIXME: Support vectors. + if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) && + !Cmp.getType()->isVectorTy()) + if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this)) + return &Sel; } if (TrueVal != CmpRHS && isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) @@ -1371,7 +1429,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, C2->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have signed max element[s]. 
C2 = InstCombiner::AddOne(C2); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::Predicate::ICMP_SGE: // Also non-canonical, but here we don't need to change C2, // so we don't have any restrictions on C2, so we can just handle it. @@ -2307,6 +2365,41 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, } Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { + if (!isa<VectorType>(Sel.getType())) + return nullptr; + + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + Value *C, *X, *Y; + + if (match(Cond, m_VecReverse(m_Value(C)))) { + auto createSelReverse = [&](Value *C, Value *X, Value *Y) { + Value *V = Builder.CreateSelect(C, X, Y, Sel.getName(), &Sel); + if (auto *I = dyn_cast<Instruction>(V)) + I->copyIRFlags(&Sel); + Module *M = Sel.getModule(); + Function *F = Intrinsic::getDeclaration( + M, Intrinsic::experimental_vector_reverse, V->getType()); + return CallInst::Create(F, V); + }; + + if (match(TVal, m_VecReverse(m_Value(X)))) { + // select rev(C), rev(X), rev(Y) --> rev(select C, X, Y) + if (match(FVal, m_VecReverse(m_Value(Y))) && + (Cond->hasOneUse() || TVal->hasOneUse() || FVal->hasOneUse())) + return createSelReverse(C, X, Y); + + // select rev(C), rev(X), FValSplat --> rev(select C, X, FValSplat) + if ((Cond->hasOneUse() || TVal->hasOneUse()) && isSplatValue(FVal)) + return createSelReverse(C, X, FVal); + } + // select rev(C), TValSplat, rev(Y) --> rev(select C, TValSplat, Y) + else if (isSplatValue(TVal) && match(FVal, m_VecReverse(m_Value(Y))) && + (Cond->hasOneUse() || FVal->hasOneUse())) + return createSelReverse(C, TVal, Y); + } + auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType()); if (!VecTy) return nullptr; @@ -2323,10 +2416,6 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { // A select of a "select shuffle" with a common operand can be rearranged // to select followed by "select shuffle". Because of poison, this only works // in the case of a shuffle with no undefined mask elements. - Value *Cond = Sel.getCondition(); - Value *TVal = Sel.getTrueValue(); - Value *FVal = Sel.getFalseValue(); - Value *X, *Y; ArrayRef<int> Mask; if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && !is_contained(Mask, UndefMaskElem) && @@ -2472,7 +2561,7 @@ Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, assert(Op->getType()->isIntOrIntVectorTy(1) && "Op must be either i1 or vector of i1."); - Optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); + std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); if (!Res) return nullptr; @@ -2510,6 +2599,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, InstCombinerImpl &IC) { Value *CondVal = SI.getCondition(); + bool ChangedFMF = false; for (bool Swap : {false, true}) { Value *TrueVal = SI.getTrueValue(); Value *X = SI.getFalseValue(); @@ -2534,13 +2624,33 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, } } + if (!match(TrueVal, m_FNeg(m_Specific(X)))) + return nullptr; + + // Forward-propagate nnan and ninf from the fneg to the select. + // If all inputs are not those values, then the select is not either. + // Note: nsz is defined differently, so it may not be correct to propagate. 
+ FastMathFlags FMF = cast<FPMathOperator>(TrueVal)->getFastMathFlags(); + if (FMF.noNaNs() && !SI.hasNoNaNs()) { + SI.setHasNoNaNs(true); + ChangedFMF = true; + } + if (FMF.noInfs() && !SI.hasNoInfs()) { + SI.setHasNoInfs(true); + ChangedFMF = true; + } + // With nsz, when 'Swap' is false: // fold (X < +/-0.0) ? -X : X or (X <= +/-0.0) ? -X : X to fabs(X) // fold (X > +/-0.0) ? -X : X or (X >= +/-0.0) ? -X : X to -fabs(x) // when 'Swap' is true: // fold (X > +/-0.0) ? X : -X or (X >= +/-0.0) ? X : -X to fabs(X) // fold (X < +/-0.0) ? X : -X or (X <= +/-0.0) ? X : -X to -fabs(X) - if (!match(TrueVal, m_FNeg(m_Specific(X))) || !SI.hasNoSignedZeros()) + // + // Note: We require "nnan" for this fold because fcmp ignores the signbit + // of NAN, but IEEE-754 specifies the signbit of NAN values with + // fneg/fabs operations. + if (!SI.hasNoSignedZeros() || !SI.hasNoNaNs()) return nullptr; if (Swap) @@ -2563,7 +2673,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, } } - return nullptr; + return ChangedFMF ? &SI : nullptr; } // Match the following IR pattern: @@ -2602,10 +2712,14 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst)))) return nullptr; + // Match even if the AND and ADD are swapped. const APInt *BiasCst, *HighBitMaskCst; if (!match(XBiasedHighBits, m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)), - m_APIntAllowUndef(HighBitMaskCst)))) + m_APIntAllowUndef(HighBitMaskCst))) && + !match(XBiasedHighBits, + m_Add(m_And(m_Specific(X), m_APIntAllowUndef(HighBitMaskCst)), + m_APIntAllowUndef(BiasCst)))) return nullptr; if (!LowBitMaskCst->isMask()) @@ -2635,200 +2749,392 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, return R; } -Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { +namespace { +struct DecomposedSelect { + Value *Cond = nullptr; + Value *TrueVal = nullptr; + Value *FalseVal = nullptr; +}; +} // namespace + +/// Look for patterns like +/// %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false +/// %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f +/// %outer.sel = select i1 %outer.cond, i8 %outer.sel.t, i8 %inner.sel +/// and rewrite it as +/// %inner.sel = select i1 %cond.alternative, i8 %sel.outer.t, i8 %sel.inner.t +/// %sel.outer = select i1 %cond.inner, i8 %inner.sel, i8 %sel.inner.f +static Instruction *foldNestedSelects(SelectInst &OuterSelVal, + InstCombiner::BuilderTy &Builder) { + // We must start with a `select`. + DecomposedSelect OuterSel; + match(&OuterSelVal, + m_Select(m_Value(OuterSel.Cond), m_Value(OuterSel.TrueVal), + m_Value(OuterSel.FalseVal))); + + // Canonicalize inversion of the outermost `select`'s condition. + if (match(OuterSel.Cond, m_Not(m_Value(OuterSel.Cond)))) + std::swap(OuterSel.TrueVal, OuterSel.FalseVal); + + // The condition of the outermost select must be an `and`/`or`. + if (!match(OuterSel.Cond, m_c_LogicalOp(m_Value(), m_Value()))) + return nullptr; + + // Depending on the logical op, inner select might be in different hand. + bool IsAndVariant = match(OuterSel.Cond, m_LogicalAnd()); + Value *InnerSelVal = IsAndVariant ? OuterSel.FalseVal : OuterSel.TrueVal; + + // Profitability check - avoid increasing instruction count. + if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}), + [](Value *V) { return V->hasOneUse(); })) + return nullptr; + + // The appropriate hand of the outermost `select` must be a select itself. 
+ DecomposedSelect InnerSel; + if (!match(InnerSelVal, + m_Select(m_Value(InnerSel.Cond), m_Value(InnerSel.TrueVal), + m_Value(InnerSel.FalseVal)))) + return nullptr; + + // Canonicalize inversion of the innermost `select`'s condition. + if (match(InnerSel.Cond, m_Not(m_Value(InnerSel.Cond)))) + std::swap(InnerSel.TrueVal, InnerSel.FalseVal); + + Value *AltCond = nullptr; + auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) { + return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond))); + }; + + // Finally, match the condition that was driving the outermost `select`, + // it should be a logical operation between the condition that was driving + // the innermost `select` (after accounting for the possible inversions + // of the condition), and some other condition. + if (matchOuterCond(m_Specific(InnerSel.Cond))) { + // Done! + } else if (Value * NotInnerCond; matchOuterCond(m_CombineAnd( + m_Not(m_Specific(InnerSel.Cond)), m_Value(NotInnerCond)))) { + // Done! + std::swap(InnerSel.TrueVal, InnerSel.FalseVal); + InnerSel.Cond = NotInnerCond; + } else // Not the pattern we were looking for. + return nullptr; + + Value *SelInner = Builder.CreateSelect( + AltCond, IsAndVariant ? OuterSel.TrueVal : InnerSel.FalseVal, + IsAndVariant ? InnerSel.TrueVal : OuterSel.FalseVal); + SelInner->takeName(InnerSelVal); + return SelectInst::Create(InnerSel.Cond, + IsAndVariant ? SelInner : InnerSel.TrueVal, + !IsAndVariant ? SelInner : InnerSel.FalseVal); +} + +Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); Type *SelType = SI.getType(); - if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal, - SQ.getWithInstruction(&SI))) - return replaceInstUsesWith(SI, V); - - if (Instruction *I = canonicalizeSelectToShuffle(SI)) - return I; - - if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this)) - return I; - // Avoid potential infinite loops by checking for non-constant condition. // TODO: Can we assert instead by improving canonicalizeSelectToShuffle()? // Scalar select must have simplified? - if (SelType->isIntOrIntVectorTy(1) && !isa<Constant>(CondVal) && - TrueVal->getType() == CondVal->getType()) { - // Folding select to and/or i1 isn't poison safe in general. impliesPoison - // checks whether folding it does not convert a well-defined value into - // poison. - if (match(TrueVal, m_One())) { - if (impliesPoison(FalseVal, CondVal)) { - // Change: A = select B, true, C --> A = or B, C - return BinaryOperator::CreateOr(CondVal, FalseVal); - } + if (!SelType->isIntOrIntVectorTy(1) || isa<Constant>(CondVal) || + TrueVal->getType() != CondVal->getType()) + return nullptr; + + auto *One = ConstantInt::getTrue(SelType); + auto *Zero = ConstantInt::getFalse(SelType); + Value *A, *B, *C, *D; - if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) - if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) - if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, - /*IsSelectLogical*/ true)) - return replaceInstUsesWith(SI, V); + // Folding select to and/or i1 isn't poison safe in general. impliesPoison + // checks whether folding it does not convert a well-defined value into + // poison. 
+ if (match(TrueVal, m_One())) { + if (impliesPoison(FalseVal, CondVal)) { + // Change: A = select B, true, C --> A = or B, C + return BinaryOperator::CreateOr(CondVal, FalseVal); } - if (match(FalseVal, m_Zero())) { - if (impliesPoison(TrueVal, CondVal)) { - // Change: A = select B, C, false --> A = and B, C - return BinaryOperator::CreateAnd(CondVal, TrueVal); - } - if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) - if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) - if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, - /*IsSelectLogical*/ true)) - return replaceInstUsesWith(SI, V); + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) + if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) + if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, + /*IsSelectLogical*/ true)) + return replaceInstUsesWith(SI, V); + + // (A && B) || (C && B) --> (A || C) && B + if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) && + match(FalseVal, m_LogicalAnd(m_Value(C), m_Value(D))) && + (CondVal->hasOneUse() || FalseVal->hasOneUse())) { + bool CondLogicAnd = isa<SelectInst>(CondVal); + bool FalseLogicAnd = isa<SelectInst>(FalseVal); + auto AndFactorization = [&](Value *Common, Value *InnerCond, + Value *InnerVal, + bool SelFirst = false) -> Instruction * { + Value *InnerSel = Builder.CreateSelect(InnerCond, One, InnerVal); + if (SelFirst) + std::swap(Common, InnerSel); + if (FalseLogicAnd || (CondLogicAnd && Common == A)) + return SelectInst::Create(Common, InnerSel, Zero); + else + return BinaryOperator::CreateAnd(Common, InnerSel); + }; + + if (A == C) + return AndFactorization(A, B, D); + if (A == D) + return AndFactorization(A, B, C); + if (B == C) + return AndFactorization(B, A, D); + if (B == D) + return AndFactorization(B, A, C, CondLogicAnd && FalseLogicAnd); } + } - auto *One = ConstantInt::getTrue(SelType); - auto *Zero = ConstantInt::getFalse(SelType); + if (match(FalseVal, m_Zero())) { + if (impliesPoison(TrueVal, CondVal)) { + // Change: A = select B, C, false --> A = and B, C + return BinaryOperator::CreateAnd(CondVal, TrueVal); + } - // We match the "full" 0 or 1 constant here to avoid a potential infinite - // loop with vectors that may have undefined/poison elements. - // select a, false, b -> select !a, b, false - if (match(TrueVal, m_Specific(Zero))) { - Value *NotCond = Builder.CreateNot(CondVal, "not." 
+ CondVal->getName()); - return SelectInst::Create(NotCond, FalseVal, Zero); + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) + if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) + if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, + /*IsSelectLogical*/ true)) + return replaceInstUsesWith(SI, V); + + // (A || B) && (C || B) --> (A && C) || B + if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) && + match(TrueVal, m_LogicalOr(m_Value(C), m_Value(D))) && + (CondVal->hasOneUse() || TrueVal->hasOneUse())) { + bool CondLogicOr = isa<SelectInst>(CondVal); + bool TrueLogicOr = isa<SelectInst>(TrueVal); + auto OrFactorization = [&](Value *Common, Value *InnerCond, + Value *InnerVal, + bool SelFirst = false) -> Instruction * { + Value *InnerSel = Builder.CreateSelect(InnerCond, InnerVal, Zero); + if (SelFirst) + std::swap(Common, InnerSel); + if (TrueLogicOr || (CondLogicOr && Common == A)) + return SelectInst::Create(Common, One, InnerSel); + else + return BinaryOperator::CreateOr(Common, InnerSel); + }; + + if (A == C) + return OrFactorization(A, B, D); + if (A == D) + return OrFactorization(A, B, C); + if (B == C) + return OrFactorization(B, A, D); + if (B == D) + return OrFactorization(B, A, C, CondLogicOr && TrueLogicOr); } - // select a, b, true -> select !a, true, b - if (match(FalseVal, m_Specific(One))) { - Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return SelectInst::Create(NotCond, One, TrueVal); + } + + // We match the "full" 0 or 1 constant here to avoid a potential infinite + // loop with vectors that may have undefined/poison elements. + // select a, false, b -> select !a, b, false + if (match(TrueVal, m_Specific(Zero))) { + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); + return SelectInst::Create(NotCond, FalseVal, Zero); + } + // select a, b, true -> select !a, true, b + if (match(FalseVal, m_Specific(One))) { + Value *NotCond = Builder.CreateNot(CondVal, "not." 
+ CondVal->getName()); + return SelectInst::Create(NotCond, One, TrueVal); + } + + // DeMorgan in select form: !a && !b --> !(a || b) + // select !a, !b, false --> not (select a, true, b) + if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) && + (CondVal->hasOneUse() || TrueVal->hasOneUse()) && + !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) + return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B)); + + // DeMorgan in select form: !a || !b --> !(a && b) + // select !a, true, !b --> not (select a, b, false) + if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) && + (CondVal->hasOneUse() || FalseVal->hasOneUse()) && + !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) + return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero)); + + // select (select a, true, b), true, b -> select a, true, b + if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && + match(TrueVal, m_One()) && match(FalseVal, m_Specific(B))) + return replaceOperand(SI, 0, A); + // select (select a, b, false), b, false -> select a, b, false + if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && + match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) + return replaceOperand(SI, 0, A); + + // ~(A & B) & (A | B) --> A ^ B + if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))), + m_c_LogicalOr(m_Deferred(A), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); + + // select (~a | c), a, b -> and a, (or c, freeze(b)) + if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && + CondVal->hasOneUse()) { + FalseVal = Builder.CreateFreeze(FalseVal); + return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); + } + // select (~c & b), a, b -> and b, (or freeze(a), c) + if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && + CondVal->hasOneUse()) { + TrueVal = Builder.CreateFreeze(TrueVal); + return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); + } + + if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { + Use *Y = nullptr; + bool IsAnd = match(FalseVal, m_Zero()) ? true : false; + Value *Op1 = IsAnd ? 
TrueVal : FalseVal; + if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) { + auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr"); + InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser())); + replaceUse(*Y, FI); + return replaceInstUsesWith(SI, Op1); } - // select a, a, b -> select a, true, b - if (CondVal == TrueVal) - return replaceOperand(SI, 1, One); - // select a, b, a -> select a, b, false - if (CondVal == FalseVal) - return replaceOperand(SI, 2, Zero); - - // select a, !a, b -> select !a, b, false - if (match(TrueVal, m_Not(m_Specific(CondVal)))) - return SelectInst::Create(TrueVal, FalseVal, Zero); - // select a, b, !a -> select !a, true, b - if (match(FalseVal, m_Not(m_Specific(CondVal)))) - return SelectInst::Create(FalseVal, One, TrueVal); - - Value *A, *B; - - // DeMorgan in select form: !a && !b --> !(a || b) - // select !a, !b, false --> not (select a, true, b) - if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) && - (CondVal->hasOneUse() || TrueVal->hasOneUse()) && - !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) - return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B)); - - // DeMorgan in select form: !a || !b --> !(a && b) - // select !a, true, !b --> not (select a, b, false) - if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) && - (CondVal->hasOneUse() || FalseVal->hasOneUse()) && - !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) - return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero)); - - // select (select a, true, b), true, b -> select a, true, b - if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && - match(TrueVal, m_One()) && match(FalseVal, m_Specific(B))) + if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) + if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, + /* IsAnd */ IsAnd)) + return I; + + if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) + if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) + if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, + /* IsLogical */ true)) + return replaceInstUsesWith(SI, V); + } + + // select (a || b), c, false -> select a, c, false + // select c, (a || b), false -> select c, a, false + // if c implies that b is false. 
+ if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) && + match(FalseVal, m_Zero())) { + std::optional<bool> Res = isImpliedCondition(TrueVal, B, DL); + if (Res && *Res == false) return replaceOperand(SI, 0, A); - // select (select a, b, false), b, false -> select a, b, false - if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && - match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) + } + if (match(TrueVal, m_LogicalOr(m_Value(A), m_Value(B))) && + match(FalseVal, m_Zero())) { + std::optional<bool> Res = isImpliedCondition(CondVal, B, DL); + if (Res && *Res == false) + return replaceOperand(SI, 1, A); + } + // select c, true, (a && b) -> select c, true, a + // select (a && b), true, c -> select a, true, c + // if c = false implies that b = true + if (match(TrueVal, m_One()) && + match(FalseVal, m_LogicalAnd(m_Value(A), m_Value(B)))) { + std::optional<bool> Res = isImpliedCondition(CondVal, B, DL, false); + if (Res && *Res == true) + return replaceOperand(SI, 2, A); + } + if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) && + match(TrueVal, m_One())) { + std::optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false); + if (Res && *Res == true) return replaceOperand(SI, 0, A); + } + if (match(TrueVal, m_One())) { Value *C; - // select (~a | c), a, b -> and a, (or c, freeze(b)) - if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && - CondVal->hasOneUse()) { - FalseVal = Builder.CreateFreeze(FalseVal); - return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); - } - // select (~c & b), a, b -> and b, (or freeze(a), c) - if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && - CondVal->hasOneUse()) { - TrueVal = Builder.CreateFreeze(TrueVal); - return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); + + // (C && A) || (!C && B) --> sel C, A, B + // (A && C) || (!C && B) --> sel C, A, B + // (C && A) || (B && !C) --> sel C, A, B + // (A && C) || (B && !C) --> sel C, A, B (may require freeze) + if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(B))) && + match(CondVal, m_c_LogicalAnd(m_Specific(C), m_Value(A)))) { + auto *SelCond = dyn_cast<SelectInst>(CondVal); + auto *SelFVal = dyn_cast<SelectInst>(FalseVal); + bool MayNeedFreeze = SelCond && SelFVal && + match(SelFVal->getTrueValue(), + m_Not(m_Specific(SelCond->getTrueValue()))); + if (MayNeedFreeze) + C = Builder.CreateFreeze(C); + return SelectInst::Create(C, A, B); } - if (!SelType->isVectorTy()) { - if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ, - /* AllowRefinement */ true)) - return replaceOperand(SI, 1, S); - if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ, - /* AllowRefinement */ true)) - return replaceOperand(SI, 2, S); + // (!C && A) || (C && B) --> sel C, B, A + // (A && !C) || (C && B) --> sel C, B, A + // (!C && A) || (B && C) --> sel C, B, A + // (A && !C) || (B && C) --> sel C, B, A (may require freeze) + if (match(CondVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(A))) && + match(FalseVal, m_c_LogicalAnd(m_Specific(C), m_Value(B)))) { + auto *SelCond = dyn_cast<SelectInst>(CondVal); + auto *SelFVal = dyn_cast<SelectInst>(FalseVal); + bool MayNeedFreeze = SelCond && SelFVal && + match(SelCond->getTrueValue(), + m_Not(m_Specific(SelFVal->getTrueValue()))); + if (MayNeedFreeze) + C = Builder.CreateFreeze(C); + return SelectInst::Create(C, B, A); } + } - if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { - Use *Y = nullptr; - bool IsAnd = match(FalseVal, m_Zero()) ? 
true : false; - Value *Op1 = IsAnd ? TrueVal : FalseVal; - if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) { - auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr"); - InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser())); - replaceUse(*Y, FI); - return replaceInstUsesWith(SI, Op1); - } + return nullptr; +} + +Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + Type *SelType = SI.getType(); - if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) - if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, - /* IsAnd */ IsAnd)) - return I; + if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal, + SQ.getWithInstruction(&SI))) + return replaceInstUsesWith(SI, V); - if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) - if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) - if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, - /* IsLogical */ true)) - return replaceInstUsesWith(SI, V); - } + if (Instruction *I = canonicalizeSelectToShuffle(SI)) + return I; - // select (select a, true, b), c, false -> select a, c, false - // select c, (select a, true, b), false -> select c, a, false - // if c implies that b is false. - if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && - match(FalseVal, m_Zero())) { - Optional<bool> Res = isImpliedCondition(TrueVal, B, DL); - if (Res && *Res == false) - return replaceOperand(SI, 0, A); - } - if (match(TrueVal, m_Select(m_Value(A), m_One(), m_Value(B))) && - match(FalseVal, m_Zero())) { - Optional<bool> Res = isImpliedCondition(CondVal, B, DL); - if (Res && *Res == false) - return replaceOperand(SI, 1, A); - } - // select c, true, (select a, b, false) -> select c, true, a - // select (select a, b, false), true, c -> select a, true, c - // if c = false implies that b = true - if (match(TrueVal, m_One()) && - match(FalseVal, m_Select(m_Value(A), m_Value(B), m_Zero()))) { - Optional<bool> Res = isImpliedCondition(CondVal, B, DL, false); - if (Res && *Res == true) - return replaceOperand(SI, 2, A); - } - if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && - match(TrueVal, m_One())) { - Optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false); - if (Res && *Res == true) - return replaceOperand(SI, 0, A); - } + if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this)) + return I; - // sel (sel c, a, false), true, (sel !c, b, false) -> sel c, a, b - // sel (sel !c, a, false), true, (sel c, b, false) -> sel c, b, a - Value *C1, *C2; - if (match(CondVal, m_Select(m_Value(C1), m_Value(A), m_Zero())) && - match(TrueVal, m_One()) && - match(FalseVal, m_Select(m_Value(C2), m_Value(B), m_Zero()))) { - if (match(C2, m_Not(m_Specific(C1)))) // first case - return SelectInst::Create(C1, A, B); - else if (match(C1, m_Not(m_Specific(C2)))) // second case - return SelectInst::Create(C2, B, A); - } + // If the type of select is not an integer type or if the condition and + // the selection type are not both scalar nor both vector types, there is no + // point in attempting to match these patterns. 
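[Editor's sketch] The guarded block that follows substitutes the known condition value into each select arm (true in the true arm, false in the false arm). As a value-level identity this is just the fact that an arm is only relevant when the condition has that value; a tiny check with made-up integer values:

#include <cassert>
int main() {
  for (int c = 0; c < 2; ++c)
    for (int x = -2; x <= 2; ++x)
      for (int y = -2; y <= 2; ++y) {
        // select c, (select c, x, y), y  ==  select c, x, y
        int nested = c ? (c ? x : y) : y;
        int folded = c ? x : y;
        assert(nested == folded);
      }
  return 0;
}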
+ Type *CondType = CondVal->getType(); + if (!isa<Constant>(CondVal) && SelType->isIntOrIntVectorTy() && + CondType->isVectorTy() == SelType->isVectorTy()) { + if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, + ConstantInt::getTrue(CondType), SQ, + /* AllowRefinement */ true)) + return replaceOperand(SI, 1, S); + + if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, + ConstantInt::getFalse(CondType), SQ, + /* AllowRefinement */ true)) + return replaceOperand(SI, 2, S); + + // Handle patterns involving sext/zext + not explicitly, + // as simplifyWithOpReplaced() only looks past one instruction. + Value *NotCond; + + // select a, sext(!a), b -> select !a, b, 0 + // select a, zext(!a), b -> select !a, b, 0 + if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond), + m_Not(m_Specific(CondVal)))))) + return SelectInst::Create(NotCond, FalseVal, + Constant::getNullValue(SelType)); + + // select a, b, zext(!a) -> select !a, 1, b + if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond), + m_Not(m_Specific(CondVal)))))) + return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal); + + // select a, b, sext(!a) -> select !a, -1, b + if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond), + m_Not(m_Specific(CondVal)))))) + return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType), + TrueVal); } + if (Instruction *R = foldSelectOfBools(SI)) + return R; + // Selecting between two integer or vector splat integer constants? // // Note that we don't handle a scalar select of vectors: @@ -2881,8 +3187,23 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); return replaceInstUsesWith(SI, NewSel); } + } + } + + if (isa<FPMathOperator>(SI)) { + // TODO: Try to forward-propagate FMF from select arms to the select. - // NOTE: if we wanted to, this is where to detect MIN/MAX + // Canonicalize select of FP values where NaN and -0.0 are not valid as + // minnum/maxnum intrinsics. + if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) { + Value *X, *Y; + if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y)))) + return replaceInstUsesWith( + SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI)); + + if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y)))) + return replaceInstUsesWith( + SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI)); } } @@ -2997,19 +3318,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } - // Canonicalize select of FP values where NaN and -0.0 are not valid as - // minnum/maxnum intrinsics. - if (isa<FPMathOperator>(SI) && SI.hasNoNaNs() && SI.hasNoSignedZeros()) { - Value *X, *Y; - if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y)))) - return replaceInstUsesWith( - SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI)); - - if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y)))) - return replaceInstUsesWith( - SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI)); - } - // See if we can fold the select into a phi node if the condition is a select. if (auto *PN = dyn_cast<PHINode>(SI.getCondition())) // The true/false values have to be live in the PHI predecessor's blocks. @@ -3198,5 +3506,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } + if (Instruction *I = foldNestedSelects(SI, Builder)) + return I; + + // Match logical variants of the pattern, + // and transform them iff that gets rid of inversions. 
+ // (~x) | y --> ~(x & (~y)) + // (~x) & y --> ~(x | (~y)) + if (sinkNotIntoOtherHandOfLogicalOp(SI)) + return &SI; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 13c98b935adf..ec505381cc86 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -346,8 +346,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, Value *X, *Y; auto matchFirstShift = [&](Value *V) { APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits()); - return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) && - match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) && + return match(V, + m_OneUse(m_BinOp(ShiftOpcode, m_Value(X), m_Constant(C0)))) && match(ConstantExpr::getAdd(C0, C1), m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); }; @@ -363,7 +363,7 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1); Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); - Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1)); + Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1); return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); } @@ -730,13 +730,34 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, return BinaryOperator::Create( I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); + bool IsLeftShift = I.getOpcode() == Instruction::Shl; + Type *Ty = I.getType(); + unsigned TypeBits = Ty->getScalarSizeInBits(); + + // (X / +DivC) >> (Width - 1) --> ext (X <= -DivC) + // (X / -DivC) >> (Width - 1) --> ext (X >= +DivC) + const APInt *DivC; + if (!IsLeftShift && match(C1, m_SpecificIntAllowUndef(TypeBits - 1)) && + match(Op0, m_SDiv(m_Value(X), m_APInt(DivC))) && !DivC->isZero() && + !DivC->isMinSignedValue()) { + Constant *NegDivC = ConstantInt::get(Ty, -(*DivC)); + ICmpInst::Predicate Pred = + DivC->isNegative() ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLE; + Value *Cmp = Builder.CreateICmp(Pred, X, NegDivC); + auto ExtOpcode = (I.getOpcode() == Instruction::AShr) ? Instruction::SExt + : Instruction::ZExt; + return CastInst::Create(ExtOpcode, Cmp, Ty); + } + const APInt *Op1C; if (!match(C1, m_APInt(Op1C))) return nullptr; + assert(!Op1C->uge(TypeBits) && + "Shift over the type width should have been removed already"); + // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. - bool IsLeftShift = I.getOpcode() == Instruction::Shl; if (I.getOpcode() != Instruction::AShr && canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) { LLVM_DEBUG( @@ -748,14 +769,6 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL)); } - // See if we can simplify any instructions used by the instruction whose sole - // purpose is to compute bits we don't care about. 
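[Editor's sketch] The sign-bit-of-sdiv fold added to FoldShiftByConstant above, (X / +DivC) >> (Width - 1) becoming a compare plus extend, relies on a truncating-division identity. A host-side spot check over all i8 values; the divisor 5 is arbitrary and the program is not part of the patch:

#include <cassert>
#include <cstdint>
int main() {
  const int d = 5; // arbitrary positive divisor for the check
  for (int x = INT8_MIN; x <= INT8_MAX; ++x) {
    bool signBitPos = (x / d) < 0; // sign bit of X / +DivC
    bool cmpPos = (x <= -d);       // icmp sle X, -DivC
    assert(signBitPos == cmpPos);
    bool signBitNeg = (x / -d) < 0; // sign bit of X / -DivC
    bool cmpNeg = (x >= d);         // icmp sge X, +DivC
    assert(signBitNeg == cmpNeg);
  }
  return 0;
}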
- Type *Ty = I.getType(); - unsigned TypeBits = Ty->getScalarSizeInBits(); - assert(!Op1C->uge(TypeBits) && - "Shift over the type width should have been removed already"); - (void)TypeBits; - if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; @@ -826,6 +839,74 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, return nullptr; } +// Tries to perform +// (lshr (add (zext X), (zext Y)), K) +// -> (icmp ult (add X, Y), X) +// where +// - The add's operands are zexts from a K-bits integer to a bigger type. +// - The add is only used by the shr, or by iK (or narrower) truncates. +// - The lshr type has more than 2 bits (other types are boolean math). +// - K > 1 +// note that +// - The resulting add cannot have nuw/nsw, else on overflow we get a +// poison value and the transform isn't legal anymore. +Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) { + assert(I.getOpcode() == Instruction::LShr); + + Value *Add = I.getOperand(0); + Value *ShiftAmt = I.getOperand(1); + Type *Ty = I.getType(); + + if (Ty->getScalarSizeInBits() < 3) + return nullptr; + + const APInt *ShAmtAPInt = nullptr; + Value *X = nullptr, *Y = nullptr; + if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) || + !match(Add, + m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y)))))) + return nullptr; + + const unsigned ShAmt = ShAmtAPInt->getZExtValue(); + if (ShAmt == 1) + return nullptr; + + // X/Y are zexts from `ShAmt`-sized ints. + if (X->getType()->getScalarSizeInBits() != ShAmt || + Y->getType()->getScalarSizeInBits() != ShAmt) + return nullptr; + + // Make sure that `Add` is only used by `I` and `ShAmt`-truncates. + if (!Add->hasOneUse()) { + for (User *U : Add->users()) { + if (U == &I) + continue; + + TruncInst *Trunc = dyn_cast<TruncInst>(U); + if (!Trunc || Trunc->getType()->getScalarSizeInBits() > ShAmt) + return nullptr; + } + } + + // Insert at Add so that the newly created `NarrowAdd` will dominate it's + // users (i.e. `Add`'s users). + Instruction *AddInst = cast<Instruction>(Add); + Builder.SetInsertPoint(AddInst); + + Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed"); + Value *Overflow = + Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow"); + + // Replace the uses of the original add with a zext of the + // NarrowAdd's result. Note that all users at this stage are known to + // be ShAmt-sized truncs, or the lshr itself. + if (!Add->hasOneUse()) + replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty)); + + // Replace the LShr with a zext of the overflow check. + return new ZExtInst(Overflow, Ty); +} + Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -1046,11 +1127,21 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { } } - // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 - if (match(Op0, m_One()) && - match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) - return BinaryOperator::CreateLShr( - ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + if (match(Op0, m_One())) { + // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 + if (match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) + return BinaryOperator::CreateLShr( + ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + + // The only way to shift out the 1 is with an over-shift, so that would + // be poison with or without "nuw". Undef is excluded because (undef << X) + // is not undef (it is zero). 
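[Editor's sketch] foldLShrOverflowBit above rests on the fact that shifting the widened sum of two K-bit zero-extended values right by K yields exactly the carry of the K-bit addition, which is what the unsigned-overflow compare computes. An exhaustive 8-bit host-side check, not part of the patch:

#include <cassert>
#include <cstdint>
int main() {
  for (unsigned x = 0; x <= 0xFF; ++x)
    for (unsigned y = 0; y <= 0xFF; ++y) {
      // (lshr (add (zext X), (zext Y)), 8) computed in 16 bits ...
      uint16_t wide = static_cast<uint16_t>(static_cast<uint16_t>(x) +
                                            static_cast<uint16_t>(y));
      unsigned viaShift = wide >> 8;
      // ... equals zext (icmp ult (add X, Y), X) computed in 8 bits.
      uint8_t narrow = static_cast<uint8_t>(x + y);
      unsigned viaCompare = narrow < static_cast<uint8_t>(x) ? 1u : 0u;
      assert(viaShift == viaCompare);
    }
  return 0;
}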
+ Constant *ConstantOne = cast<Constant>(Op0); + if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) { + I.setHasNoUnsignedWrap(); + return &I; + } + } return nullptr; } @@ -1068,10 +1159,17 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); + Value *X; const APInt *C; + unsigned BitWidth = Ty->getScalarSizeInBits(); + + // (iN (~X) u>> (N - 1)) --> zext (X > -1) + if (match(Op0, m_OneUse(m_Not(m_Value(X)))) && + match(Op1, m_SpecificIntAllowUndef(BitWidth - 1))) + return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty); + if (match(Op1, m_APInt(C))) { unsigned ShAmtC = C->getZExtValue(); - unsigned BitWidth = Ty->getScalarSizeInBits(); auto *II = dyn_cast<IntrinsicInst>(Op0); if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC && (II->getIntrinsicID() == Intrinsic::ctlz || @@ -1276,6 +1374,18 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } } + // Reduce add-carry of bools to logic: + // ((zext BoolX) + (zext BoolY)) >> 1 --> zext (BoolX && BoolY) + Value *BoolX, *BoolY; + if (ShAmtC == 1 && match(Op0, m_Add(m_Value(X), m_Value(Y))) && + match(X, m_ZExt(m_Value(BoolX))) && match(Y, m_ZExt(m_Value(BoolY))) && + BoolX->getType()->isIntOrIntVectorTy(1) && + BoolY->getType()->isIntOrIntVectorTy(1) && + (X->hasOneUse() || Y->hasOneUse() || Op0->hasOneUse())) { + Value *And = Builder.CreateAnd(BoolX, BoolY); + return new ZExtInst(And, Ty); + } + // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { @@ -1285,13 +1395,15 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } // Transform (x << y) >> y to x & (-1 >> y) - Value *X; if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); Value *Mask = Builder.CreateLShr(AllOnes, Op1); return BinaryOperator::CreateAnd(Mask, X); } + if (Instruction *Overflow = foldLShrOverflowBit(I)) + return Overflow; + return nullptr; } @@ -1469,8 +1581,11 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { return R; // See if we can turn a signed shr into an unsigned shr. 
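[Editor's sketch] Earlier in this lshr hunk, ((zext BoolX) + (zext BoolY)) >> 1 is reduced to zext (BoolX && BoolY): the sum of two one-bit values only reaches bit 1 when both are set. The four-case check below is a host-side illustration only:

#include <cassert>
int main() {
  for (int x = 0; x < 2; ++x)
    for (int y = 0; y < 2; ++y) {
      int viaShift = (x + y) >> 1;   // ((zext BoolX) + (zext BoolY)) >> 1
      int viaAnd = (x && y) ? 1 : 0; // zext (BoolX && BoolY)
      assert(viaShift == viaAnd);
    }
  return 0;
}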
- if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) - return BinaryOperator::CreateLShr(Op0, Op1); + if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) { + Instruction *Lshr = BinaryOperator::CreateLShr(Op0, Op1); + Lshr->setIsExact(I.isExact()); + return Lshr; + } // ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1 if (match(Op0, m_OneUse(m_Not(m_Value(X))))) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index febd0f51d25f..77d675422966 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -130,9 +130,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (Depth == MaxAnalysisRecursionDepth) return nullptr; - if (isa<ScalableVectorType>(VTy)) - return nullptr; - Instruction *I = dyn_cast<Instruction>(V); if (!I) { computeKnownBits(V, Known, Depth, CxtI); @@ -154,6 +151,20 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (Depth == 0 && !V->hasOneUse()) DemandedMask.setAllBits(); + // Update flags after simplifying an operand based on the fact that some high + // order bits are not demanded. + auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I, + unsigned NLZ) { + if (NLZ > 0) { + // Disable the nsw and nuw flags here: We can no longer guarantee that + // we won't wrap after simplification. Removing the nsw/nuw flags is + // legal here because the top bit is not demanded. + I->setHasNoSignedWrap(false); + I->setHasNoUnsignedWrap(false); + } + return I; + }; + // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care // about the high bits of the operands. auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { @@ -165,13 +176,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || ShrinkDemandedConstant(I, 1, DemandedFromOps) || SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { - if (NLZ > 0) { - // Disable the nsw and nuw flags here: We can no longer guarantee that - // we won't wrap after simplification. Removing the nsw/nuw flags is - // legal here because the top bit is not demanded. - I->setHasNoSignedWrap(false); - I->setHasNoUnsignedWrap(false); - } + disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); return true; } return false; @@ -397,7 +402,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; case Instruction::ZExt: { unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); @@ -416,7 +421,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (auto *DstVTy = dyn_cast<VectorType>(VTy)) { if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) { - if (cast<FixedVectorType>(DstVTy)->getNumElements() != + if (isa<ScalableVectorType>(DstVTy) || + isa<ScalableVectorType>(SrcVTy) || + cast<FixedVectorType>(DstVTy)->getNumElements() != cast<FixedVectorType>(SrcVTy)->getNumElements()) // Don't touch a bitcast between vectors of different element counts. 
return nullptr; @@ -461,7 +468,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } - case Instruction::Add: + case Instruction::Add: { if ((DemandedMask & 1) == 0) { // If we do not need the low bit, try to convert bool math to logic: // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN @@ -498,26 +505,68 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return Builder.CreateSExt(Or, VTy); } } - LLVM_FALLTHROUGH; - case Instruction::Sub: { - APInt DemandedFromOps; - if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) - return I; - // If we are known to be adding/subtracting zeros to every bit below + // Right fill the mask of bits for the operands to demand the most + // significant bit and all those below it. + unsigned NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) + return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); + + // If low order bits are not demanded and known to be zero in one operand, + // then we don't need to demand them from the other operand, since they + // can't cause overflow into any bits that are demanded in the result. + unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes(); + APInt DemandedFromLHS = DemandedFromOps; + DemandedFromLHS.clearLowBits(NTZ); + if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || + SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1)) + return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); + + // If we are known to be adding zeros to every bit below + // the highest demanded bit, we just return the other side. + if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + // Otherwise just compute the known bits of the result. + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + Known = KnownBits::computeForAddSub(true, NSW, LHSKnown, RHSKnown); + break; + } + case Instruction::Sub: { + // Right fill the mask of bits for the operands to demand the most + // significant bit and all those below it. + unsigned NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) + return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); + + // If low order bits are not demanded and are known to be zero in RHS, + // then we don't need to demand them from LHS, since they can't cause a + // borrow from any bits that are demanded in the result. + unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes(); + APInt DemandedFromLHS = DemandedFromOps; + DemandedFromLHS.clearLowBits(NTZ); + if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || + SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1)) + return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); + + // If we are known to be subtracting zeros from every bit below // the highest demanded bit, we just return the other side. if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); // We can't do this with the LHS for subtraction, unless we are only // demanding the LSB. 
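[Editor's sketch] The reworked Add/Sub handling above right-fills the demanded mask because carries and borrows only propagate upward: operand bits above the highest demanded bit can never affect a demanded result bit. A quick wrap-around check with an assumed 8-bit demanded mask, not taken from the patch:

#include <cassert>
#include <cstdint>
int main() {
  for (unsigned x = 0; x <= 0xFFF; x += 7)
    for (unsigned y = 0; y <= 0xFFF; y += 5) {
      // Only the low 8 bits are demanded; masking the operands' high bits
      // first does not change the demanded bits of the sum or difference.
      assert(((x + y) & 0xFF) == (((x & 0xFF) + (y & 0xFF)) & 0xFF));
      assert(((x - y) & 0xFF) == (((x & 0xFF) - (y & 0xFF)) & 0xFF));
    }
  return 0;
}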
- if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) && - DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); // Otherwise just compute the known bits of the result. bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add, - NSW, LHSKnown, RHSKnown); + Known = KnownBits::computeForAddSub(false, NSW, LHSKnown, RHSKnown); break; } case Instruction::Mul: { @@ -747,18 +796,18 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // UDiv doesn't demand low bits that are zero in the divisor. const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { - // If the shift is exact, then it does demand the low bits. - if (cast<UDivOperator>(I)->isExact()) - break; - - // FIXME: Take the demanded mask of the result into account. + // TODO: Take the demanded mask of the result into account. unsigned RHSTrailingZeros = SA->countTrailingZeros(); APInt DemandedMaskIn = APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); - if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) { + // We can't guarantee that "exact" is still true after changing the + // the dividend. + I->dropPoisonGeneratingFlags(); return I; + } - // Propagate zero bits from the input. + // Increase high zero bits from the input. Known.Zero.setHighBits(std::min( BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros)); } else { @@ -922,10 +971,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } default: { // Handle target specific intrinsics - Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( + std::optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( *II, DemandedMask, Known, KnownBitsComputed); if (V) - return V.value(); + return *V; break; } } @@ -962,11 +1011,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( // this instruction has a simpler value in that context. switch (I->getOpcode()) { case Instruction::And: { - // If either the LHS or the RHS are Zero, the result is zero. computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, - CxtI); - + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown & RHSKnown; // If the client is only demanding bits that we know, return the known @@ -975,8 +1021,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( return Constant::getIntegerValue(ITy, Known.One); // If all of the demanded bits are known 1 on one side, return the other. - // These bits cannot contribute to the result of the 'and' in this - // context. + // These bits cannot contribute to the result of the 'and' in this context. if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) return I->getOperand(0); if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) @@ -985,14 +1030,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } case Instruction::Or: { - // We can simplify (X|Y) -> X or Y in the user's context if we know that - // only bits from X or Y are demanded. - - // If either the LHS or the RHS are One, the result is One. 
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, - CxtI); - + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown | RHSKnown; // If the client is only demanding bits that we know, return the known @@ -1000,9 +1039,10 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) return Constant::getIntegerValue(ITy, Known.One); - // If all of the demanded bits are known zero on one side, return the - // other. These bits cannot contribute to the result of the 'or' in this - // context. + // We can simplify (X|Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + // If all of the demanded bits are known zero on one side, return the other. + // These bits cannot contribute to the result of the 'or' in this context. if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) return I->getOperand(0); if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) @@ -1011,13 +1051,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } case Instruction::Xor: { - // We can simplify (X^Y) -> X or Y in the user's context if we know that - // only bits from X or Y are demanded. - computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, - CxtI); - + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown ^ RHSKnown; // If the client is only demanding bits that we know, return the known @@ -1025,8 +1060,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) return Constant::getIntegerValue(ITy, Known.One); - // If all of the demanded bits are known zero on one side, return the - // other. + // We can simplify (X^Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + // If all of the demanded bits are known zero on one side, return the other. if (DemandedMask.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); if (DemandedMask.isSubsetOf(LHSKnown.Zero)) @@ -1034,6 +1070,34 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } + case Instruction::Add: { + unsigned NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + + // If an operand adds zeros to every bit below the highest demanded bit, + // that operand doesn't change the result. Return the other side. + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + break; + } + case Instruction::Sub: { + unsigned NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + + // If an operand subtracts zeros from every bit below the highest demanded + // bit, that operand doesn't change the result. Return the other side. + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + + break; + } case Instruction::AShr: { // Compute the Known bits to simplify things downstream. 
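[Editor's sketch] The new Add and Sub cases in SimplifyMultipleUseDemandedBits above use the same observation in reverse: an operand that is zero in every bit at or below the highest demanded bit cannot change any demanded bit of the result. A small check with a 0xFF demanded mask and a made-up operand that only has higher bits set:

#include <cassert>
#include <cstdint>
int main() {
  for (unsigned x = 0; x <= 0xFFFF; x += 3) {
    unsigned y = 0x1200; // zero in every bit of the demanded 0xFF mask
    assert(((x + y) & 0xFF) == (x & 0xFF)); // add: the operand can be dropped
    assert(((x - y) & 0xFF) == (x & 0xFF)); // sub: the RHS can be dropped
  }
  return 0;
}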
computeKnownBits(I, Known, Depth, CxtI); @@ -1632,11 +1696,11 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, } default: { // Handle target specific intrinsics - Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( + std::optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( *II, DemandedElts, UndefElts, UndefElts2, UndefElts3, simplifyAndSetOp); if (V) - return V.value(); + return *V; break; } } // switch on IntrinsicID diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index b80c58183dd5..61e62adbe327 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -105,7 +105,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, // 2) Possibly more ExtractElements with the same index. // 3) Another operand, which will feed back into the PHI. Instruction *PHIUser = nullptr; - for (auto U : PN->users()) { + for (auto *U : PN->users()) { if (ExtractElementInst *EU = dyn_cast<ExtractElementInst>(U)) { if (EI.getIndexOperand() == EU->getIndexOperand()) Extracts.push_back(EU); @@ -171,7 +171,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, } } - for (auto E : Extracts) + for (auto *E : Extracts) replaceInstUsesWith(*E, scalarPHI); return &EI; @@ -187,13 +187,12 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { ElementCount NumElts = cast<VectorType>(Ext.getVectorOperandType())->getElementCount(); Type *DestTy = Ext.getType(); + unsigned DestWidth = DestTy->getPrimitiveSizeInBits(); bool IsBigEndian = DL.isBigEndian(); // If we are casting an integer to vector and extracting a portion, that is // a shift-right and truncate. - // TODO: Allow FP dest type by casting the trunc to FP? - if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() && - isDesirableIntType(X->getType()->getPrimitiveSizeInBits())) { + if (X->getType()->isIntegerTy()) { assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) && "Expected fixed vector type for bitcast from scalar integer"); @@ -202,10 +201,18 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { // BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8 if (IsBigEndian) ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC; - unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits(); - if (!ShiftAmountC || Ext.getVectorOperand()->hasOneUse()) { - Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset"); - return new TruncInst(Lshr, DestTy); + unsigned ShiftAmountC = ExtIndexC * DestWidth; + if (!ShiftAmountC || + (isDesirableIntType(X->getType()->getPrimitiveSizeInBits()) && + Ext.getVectorOperand()->hasOneUse())) { + if (ShiftAmountC) + X = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset"); + if (DestTy->isFloatingPointTy()) { + Type *DstIntTy = IntegerType::getIntNTy(X->getContext(), DestWidth); + Value *Trunc = Builder.CreateTrunc(X, DstIntTy); + return new BitCastInst(Trunc, DestTy); + } + return new TruncInst(X, DestTy); } } @@ -278,7 +285,6 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { return nullptr; unsigned SrcWidth = SrcTy->getScalarSizeInBits(); - unsigned DestWidth = DestTy->getPrimitiveSizeInBits(); unsigned ShAmt = Chunk * DestWidth; // TODO: This limitation is more strict than necessary. 
We could sum the @@ -393,6 +399,20 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { SQ.getWithInstruction(&EI))) return replaceInstUsesWith(EI, V); + // extractelt (select %x, %vec1, %vec2), %const -> + // select %x, %vec1[%const], %vec2[%const] + // TODO: Support constant folding of multiple select operands: + // extractelt (select %x, %vec1, %vec2), (select %x, %c1, %c2) + // If the extractelement will for instance try to do out of bounds accesses + // because of the values of %c1 and/or %c2, the sequence could be optimized + // early. This is currently not possible because constant folding will reach + // an unreachable assertion if it doesn't find a constant operand. + if (SelectInst *SI = dyn_cast<SelectInst>(EI.getVectorOperand())) + if (SI->getCondition()->getType()->isIntegerTy() && + isa<Constant>(EI.getIndexOperand())) + if (Instruction *R = FoldOpIntoSelect(EI, SI)) + return R; + // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. auto *IndexC = dyn_cast<ConstantInt>(Index); @@ -850,17 +870,16 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( if (NumAggElts > 2) return nullptr; - static constexpr auto NotFound = None; + static constexpr auto NotFound = std::nullopt; static constexpr auto FoundMismatch = nullptr; // Try to find a value of each element of an aggregate. // FIXME: deal with more complex, not one-dimensional, aggregate types - SmallVector<Optional<Instruction *>, 2> AggElts(NumAggElts, NotFound); + SmallVector<std::optional<Instruction *>, 2> AggElts(NumAggElts, NotFound); // Do we know values for each element of the aggregate? auto KnowAllElts = [&AggElts]() { - return all_of(AggElts, - [](Optional<Instruction *> Elt) { return Elt != NotFound; }); + return !llvm::is_contained(AggElts, NotFound); }; int Depth = 0; @@ -889,7 +908,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // Now, we may have already previously recorded the value for this element // of an aggregate. If we did, that means the CurrIVI will later be // overwritten with the already-recorded value. But if not, let's record it! - Optional<Instruction *> &Elt = AggElts[Indices.front()]; + std::optional<Instruction *> &Elt = AggElts[Indices.front()]; Elt = Elt.value_or(InsertedValue); // FIXME: should we handle chain-terminating undef base operand? @@ -919,7 +938,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( /// or different elements had different source aggregates. FoundMismatch }; - auto Describe = [](Optional<Value *> SourceAggregate) { + auto Describe = [](std::optional<Value *> SourceAggregate) { if (SourceAggregate == NotFound) return AggregateDescription::NotFound; if (*SourceAggregate == FoundMismatch) @@ -933,8 +952,8 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // If found, return the source aggregate from which the extraction was. // If \p PredBB is provided, does PHI translation of an \p Elt first. auto FindSourceAggregate = - [&](Instruction *Elt, unsigned EltIdx, Optional<BasicBlock *> UseBB, - Optional<BasicBlock *> PredBB) -> Optional<Value *> { + [&](Instruction *Elt, unsigned EltIdx, std::optional<BasicBlock *> UseBB, + std::optional<BasicBlock *> PredBB) -> std::optional<Value *> { // For now(?), only deal with, at most, a single level of PHI indirection. 
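[Editor's sketch] The extractelement-of-select fold added in visitExtractElementInst above distributes a constant-index extract into both arms of the select. At the value level it is the plain identity below, shown with made-up data as a host-side illustration only:

#include <array>
#include <cassert>
int main() {
  std::array<int, 4> vec1 = {1, 2, 3, 4}, vec2 = {5, 6, 7, 8};
  for (int cond = 0; cond < 2; ++cond)
    for (unsigned i = 0; i < 4; ++i) {
      int extractOfSelect = (cond ? vec1 : vec2)[i]; // extractelt (select x, v1, v2), i
      int selectOfExtracts = cond ? vec1[i] : vec2[i]; // select x, v1[i], v2[i]
      assert(extractOfSelect == selectOfExtracts);
    }
  return 0;
}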
if (UseBB && PredBB) Elt = dyn_cast<Instruction>(Elt->DoPHITranslation(*UseBB, *PredBB)); @@ -961,9 +980,9 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // see if we can find appropriate source aggregate for each of the elements, // and see it's the same aggregate for each element. If so, return it. auto FindCommonSourceAggregate = - [&](Optional<BasicBlock *> UseBB, - Optional<BasicBlock *> PredBB) -> Optional<Value *> { - Optional<Value *> SourceAggregate; + [&](std::optional<BasicBlock *> UseBB, + std::optional<BasicBlock *> PredBB) -> std::optional<Value *> { + std::optional<Value *> SourceAggregate; for (auto I : enumerate(AggElts)) { assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch && @@ -975,7 +994,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // For this element, is there a plausible source aggregate? // FIXME: we could special-case undef element, IFF we know that in the // source aggregate said element isn't poison. - Optional<Value *> SourceAggregateForElement = + std::optional<Value *> SourceAggregateForElement = FindSourceAggregate(*I.value(), I.index(), UseBB, PredBB); // Okay, what have we found? Does that correlate with previous findings? @@ -1009,10 +1028,11 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( return *SourceAggregate; }; - Optional<Value *> SourceAggregate; + std::optional<Value *> SourceAggregate; // Can we find the source aggregate without looking at predecessors? - SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/None, /*PredBB=*/None); + SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/std::nullopt, + /*PredBB=*/std::nullopt); if (Describe(SourceAggregate) != AggregateDescription::NotFound) { if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch) return nullptr; // Conflicting source aggregates! @@ -1029,7 +1049,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // they all should be defined in the same basic block. BasicBlock *UseBB = nullptr; - for (const Optional<Instruction *> &I : AggElts) { + for (const std::optional<Instruction *> &I : AggElts) { BasicBlock *BB = (*I)->getParent(); // If it's the first instruction we've encountered, record the basic block. if (!UseBB) { @@ -1495,6 +1515,71 @@ static Instruction *narrowInsElt(InsertElementInst &InsElt, return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType()); } +/// If we are inserting 2 halves of a value into adjacent elements of a vector, +/// try to convert to a single insert with appropriate bitcasts. +static Instruction *foldTruncInsEltPair(InsertElementInst &InsElt, + bool IsBigEndian, + InstCombiner::BuilderTy &Builder) { + Value *VecOp = InsElt.getOperand(0); + Value *ScalarOp = InsElt.getOperand(1); + Value *IndexOp = InsElt.getOperand(2); + + // Pattern depends on endian because we expect lower index is inserted first. + // Big endian: + // inselt (inselt BaseVec, (trunc (lshr X, BW/2), Index0), (trunc X), Index1 + // Little endian: + // inselt (inselt BaseVec, (trunc X), Index0), (trunc (lshr X, BW/2)), Index1 + // Note: It is not safe to do this transform with an arbitrary base vector + // because the bitcast of that vector to fewer/larger elements could + // allow poison to spill into an element that was not poison before. + // TODO: Detect smaller fractions of the scalar. + // TODO: One-use checks are conservative. 
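[Editor's sketch] foldTruncInsEltPair, whose header comment ends just above, merges two adjacent half-width inserts of the two halves of one scalar into a single insert through a bitcast. For the little-endian case the equivalence is simply how the halves are laid out in memory; a host-side sketch that assumes a little-endian machine and uses a <2 x i16>-sized buffer for brevity:

#include <cassert>
#include <cstdint>
#include <cstring>
int main() {
  uint32_t x = 0x11223344u;
  // inselt (inselt undef, (trunc X), 0), (trunc (lshr X, 16)), 1 on <2 x i16>
  uint16_t twoInserts[2] = {static_cast<uint16_t>(x),
                            static_cast<uint16_t>(x >> 16)};
  // bitcast (inselt (bitcast undef), X, 0) back to <2 x i16>
  uint16_t oneInsert[2];
  std::memcpy(oneInsert, &x, sizeof(x));
  assert(std::memcmp(twoInserts, oneInsert, sizeof(twoInserts)) == 0);
  return 0;
}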
+ auto *VTy = dyn_cast<FixedVectorType>(InsElt.getType()); + Value *Scalar0, *BaseVec; + uint64_t Index0, Index1; + if (!VTy || (VTy->getNumElements() & 1) || + !match(IndexOp, m_ConstantInt(Index1)) || + !match(VecOp, m_InsertElt(m_Value(BaseVec), m_Value(Scalar0), + m_ConstantInt(Index0))) || + !match(BaseVec, m_Undef())) + return nullptr; + + // The first insert must be to the index one less than this one, and + // the first insert must be to an even index. + if (Index0 + 1 != Index1 || Index0 & 1) + return nullptr; + + // For big endian, the high half of the value should be inserted first. + // For little endian, the low half of the value should be inserted first. + Value *X; + uint64_t ShAmt; + if (IsBigEndian) { + if (!match(ScalarOp, m_Trunc(m_Value(X))) || + !match(Scalar0, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt))))) + return nullptr; + } else { + if (!match(Scalar0, m_Trunc(m_Value(X))) || + !match(ScalarOp, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt))))) + return nullptr; + } + + Type *SrcTy = X->getType(); + unsigned ScalarWidth = SrcTy->getScalarSizeInBits(); + unsigned VecEltWidth = VTy->getScalarSizeInBits(); + if (ScalarWidth != VecEltWidth * 2 || ShAmt != VecEltWidth) + return nullptr; + + // Bitcast the base vector to a vector type with the source element type. + Type *CastTy = FixedVectorType::get(SrcTy, VTy->getNumElements() / 2); + Value *CastBaseVec = Builder.CreateBitCast(BaseVec, CastTy); + + // Scale the insert index for a vector with half as many elements. + // bitcast (inselt (bitcast BaseVec), X, NewIndex) + uint64_t NewIndex = IsBigEndian ? Index1 / 2 : Index0 / 2; + Value *NewInsert = Builder.CreateInsertElement(CastBaseVec, X, NewIndex); + return new BitCastInst(NewInsert, VTy); +} + Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { Value *VecOp = IE.getOperand(0); Value *ScalarOp = IE.getOperand(1); @@ -1505,10 +1590,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { return replaceInstUsesWith(IE, V); // Canonicalize type of constant indices to i64 to simplify CSE - if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp)) + if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp)) { if (auto *NewIdx = getPreferredVectorIndex(IndexC)) return replaceOperand(IE, 2, NewIdx); + Value *BaseVec, *OtherScalar; + uint64_t OtherIndexVal; + if (match(VecOp, m_OneUse(m_InsertElt(m_Value(BaseVec), + m_Value(OtherScalar), + m_ConstantInt(OtherIndexVal)))) && + !isa<Constant>(OtherScalar) && OtherIndexVal > IndexC->getZExtValue()) { + Value *NewIns = Builder.CreateInsertElement(BaseVec, ScalarOp, IdxOp); + return InsertElementInst::Create(NewIns, OtherScalar, + Builder.getInt64(OtherIndexVal)); + } + } + // If the scalar is bitcast and inserted into undef, do the insert in the // source type followed by bitcast. // TODO: Generalize for insert into any constant, not just undef? @@ -1622,6 +1719,9 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *Ext = narrowInsElt(IE, Builder)) return Ext; + if (Instruction *Ext = foldTruncInsEltPair(IE, DL.isBigEndian(), Builder)) + return Ext; + return nullptr; } @@ -1653,7 +1753,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, // from an undefined element in an operand. 
if (llvm::is_contained(Mask, -1)) return false; - LLVM_FALLTHROUGH; + [[fallthrough]]; case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -1700,8 +1800,8 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, // Verify that 'CI' does not occur twice in Mask. A single 'insertelement' // can't put an element into multiple indices. bool SeenOnce = false; - for (int i = 0, e = Mask.size(); i != e; ++i) { - if (Mask[i] == ElementNumber) { + for (int I : Mask) { + if (I == ElementNumber) { if (SeenOnce) return false; SeenOnce = true; @@ -1957,6 +2057,56 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { return {}; } +/// A select shuffle of a select shuffle with a shared operand can be reduced +/// to a single select shuffle. This is an obvious improvement in IR, and the +/// backend is expected to lower select shuffles efficiently. +static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) { + assert(Shuf.isSelect() && "Must have select-equivalent shuffle"); + + Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); + SmallVector<int, 16> Mask; + Shuf.getShuffleMask(Mask); + unsigned NumElts = Mask.size(); + + // Canonicalize a select shuffle with common operand as Op1. + auto *ShufOp = dyn_cast<ShuffleVectorInst>(Op0); + if (ShufOp && ShufOp->isSelect() && + (ShufOp->getOperand(0) == Op1 || ShufOp->getOperand(1) == Op1)) { + std::swap(Op0, Op1); + ShuffleVectorInst::commuteShuffleMask(Mask, NumElts); + } + + ShufOp = dyn_cast<ShuffleVectorInst>(Op1); + if (!ShufOp || !ShufOp->isSelect() || + (ShufOp->getOperand(0) != Op0 && ShufOp->getOperand(1) != Op0)) + return nullptr; + + Value *X = ShufOp->getOperand(0), *Y = ShufOp->getOperand(1); + SmallVector<int, 16> Mask1; + ShufOp->getShuffleMask(Mask1); + assert(Mask1.size() == NumElts && "Vector size changed with select shuffle"); + + // Canonicalize common operand (Op0) as X (first operand of first shuffle). + if (Y == Op0) { + std::swap(X, Y); + ShuffleVectorInst::commuteShuffleMask(Mask1, NumElts); + } + + // If the mask chooses from X (operand 0), it stays the same. + // If the mask chooses from the earlier shuffle, the other mask value is + // transferred to the combined select shuffle: + // shuf X, (shuf X, Y, M1), M --> shuf X, Y, M' + SmallVector<int, 16> NewMask(NumElts); + for (unsigned i = 0; i != NumElts; ++i) + NewMask[i] = Mask[i] < (signed)NumElts ? Mask[i] : Mask1[i]; + + // A select mask with undef elements might look like an identity mask. 
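[Editor's sketch] foldSelectShuffleOfSelectShuffle above composes the two select masks: where the outer mask picks from the inner shuffle, the inner mask's choice is carried over. A small host-side model of shufflevector semantics showing one instance of shuf X, (shuf X, Y, M1), M --> shuf X, Y, M'; the data and masks are made up for illustration:

#include <array>
#include <cassert>
int main() {
  std::array<int, 4> X = {10, 11, 12, 13}, Y = {20, 21, 22, 23};
  // Model of shufflevector: index < 4 selects from A, otherwise from B.
  auto shuffle = [](std::array<int, 4> A, std::array<int, 4> B,
                    std::array<int, 4> M) {
    std::array<int, 4> R{};
    for (int i = 0; i < 4; ++i)
      R[i] = M[i] < 4 ? A[M[i]] : B[M[i] - 4];
    return R;
  };
  std::array<int, 4> M1 = {0, 5, 2, 7}; // inner select mask for shuf X, Y
  std::array<int, 4> M = {4, 1, 6, 3};  // outer select mask for shuf X, Inner
  std::array<int, 4> Inner = shuffle(X, Y, M1);
  std::array<int, 4> Nested = shuffle(X, Inner, M);
  std::array<int, 4> NewMask{}; // NewMask[i] = M[i] < 4 ? M[i] : M1[i]
  for (int i = 0; i < 4; ++i)
    NewMask[i] = M[i] < 4 ? M[i] : M1[i];
  assert(Nested == shuffle(X, Y, NewMask));
  return 0;
}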
+ assert((ShuffleVectorInst::isSelectMask(NewMask) || + ShuffleVectorInst::isIdentityMask(NewMask)) && + "Unexpected shuffle mask"); + return new ShuffleVectorInst(X, Y, NewMask); +} + static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { assert(Shuf.isSelect() && "Must have select-equivalent shuffle"); @@ -2061,6 +2211,9 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { return &Shuf; } + if (Instruction *I = foldSelectShuffleOfSelectShuffle(Shuf)) + return I; + if (Instruction *I = foldSelectShuffleWith1Binop(Shuf)) return I; @@ -2541,6 +2694,35 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { return new ShuffleVectorInst(X, Y, NewMask); } +// Splatting the first element of the result of a BinOp, where any of the +// BinOp's operands are the result of a first element splat can be simplified to +// splatting the first element of the result of the BinOp +Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) { + if (!match(SVI.getOperand(1), m_Undef()) || + !match(SVI.getShuffleMask(), m_ZeroMask())) + return nullptr; + + Value *Op0 = SVI.getOperand(0); + Value *X, *Y; + if (!match(Op0, m_BinOp(m_Shuffle(m_Value(X), m_Undef(), m_ZeroMask()), + m_Value(Y))) && + !match(Op0, m_BinOp(m_Value(X), + m_Shuffle(m_Value(Y), m_Undef(), m_ZeroMask())))) + return nullptr; + if (X->getType() != Y->getType()) + return nullptr; + + auto *BinOp = cast<BinaryOperator>(Op0); + if (!isSafeToSpeculativelyExecute(BinOp)) + return nullptr; + + Value *NewBO = Builder.CreateBinOp(BinOp->getOpcode(), X, Y); + if (auto NewBOI = dyn_cast<Instruction>(NewBO)) + NewBOI->copyIRFlags(BinOp); + + return new ShuffleVectorInst(NewBO, SVI.getShuffleMask()); +} + Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); @@ -2549,7 +2731,9 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { SVI.getType(), ShufQuery)) return replaceInstUsesWith(SVI, V); - // Bail out for scalable vectors + if (Instruction *I = simplifyBinOpSplats(SVI)) + return I; + if (isa<ScalableVectorType>(LHS->getType())) return nullptr; @@ -2694,7 +2878,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *V = LHS; unsigned MaskElems = Mask.size(); auto *SrcTy = cast<FixedVectorType>(V->getType()); - unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize(); + unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedValue(); unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType()); assert(SrcElemBitWidth && "vector elements must have a bitwidth"); unsigned SrcNumElems = SrcTy->getNumElements(); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 71c763de43b4..fb6f4f96ea48 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -38,7 +38,6 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -99,16 +98,19 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> 
#include <cassert> #include <cstdint> #include <memory> +#include <optional> #include <string> #include <utility> #define DEBUG_TYPE "instcombine" #include "llvm/Transforms/Utils/InstructionWorklist.h" +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -167,16 +169,16 @@ MaxArraySize("instcombine-maxarray-size", cl::init(1024), static cl::opt<unsigned> ShouldLowerDbgDeclare("instcombine-lower-dbg-declare", cl::Hidden, cl::init(true)); -Optional<Instruction *> +std::optional<Instruction *> InstCombiner::targetInstCombineIntrinsic(IntrinsicInst &II) { // Handle target specific intrinsics if (II.getCalledFunction()->isTargetIntrinsic()) { return TTI.instCombineIntrinsic(*this, II); } - return None; + return std::nullopt; } -Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic( +std::optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic( IntrinsicInst &II, APInt DemandedMask, KnownBits &Known, bool &KnownBitsComputed) { // Handle target specific intrinsics @@ -184,10 +186,10 @@ Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic( return TTI.simplifyDemandedUseBitsIntrinsic(*this, II, DemandedMask, Known, KnownBitsComputed); } - return None; + return std::nullopt; } -Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic( +std::optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic( IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3, std::function<void(Instruction *, unsigned, APInt, APInt &)> @@ -198,11 +200,11 @@ Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic( *this, II, DemandedElts, UndefElts, UndefElts2, UndefElts3, SimplifyAndSetOp); } - return None; + return std::nullopt; } Value *InstCombinerImpl::EmitGEPOffset(User *GEP) { - return llvm::EmitGEPOffset(&Builder, DL, GEP); + return llvm::emitGEPOffset(&Builder, DL, GEP); } /// Legal integers and common types are considered desirable. This is used to @@ -223,11 +225,12 @@ bool InstCombinerImpl::isDesirableIntType(unsigned BitWidth) const { /// Return true if it is desirable to convert an integer computation from a /// given bit width to a new bit width. -/// We don't want to convert from a legal to an illegal type or from a smaller -/// to a larger illegal type. A width of '1' is always treated as a desirable -/// type because i1 is a fundamental type in IR, and there are many specialized -/// optimizations for i1 types. Common/desirable widths are equally treated as -/// legal to convert to, in order to open up more combining opportunities. +/// We don't want to convert from a legal or desirable type (like i8) to an +/// illegal type or from a smaller to a larger illegal type. A width of '1' +/// is always treated as a desirable type because i1 is a fundamental type in +/// IR, and there are many specialized optimizations for i1 types. +/// Common/desirable widths are equally treated as legal to convert to, in +/// order to open up more combining opportunities. bool InstCombinerImpl::shouldChangeType(unsigned FromWidth, unsigned ToWidth) const { bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); @@ -238,9 +241,9 @@ bool InstCombinerImpl::shouldChangeType(unsigned FromWidth, if (ToWidth < FromWidth && isDesirableIntType(ToWidth)) return true; - // If this is a legal integer from type, and the result would be an illegal - // type, don't do the transformation. 
- if (FromLegal && !ToLegal) + // If this is a legal or desiable integer from type, and the result would be + // an illegal type, don't do the transformation. + if ((FromLegal || isDesirableIntType(FromWidth)) && !ToLegal) return false; // Otherwise, if both are illegal, do not increase the size of the result. We @@ -367,14 +370,14 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, // inttoptr ( ptrtoint (x) ) --> x Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) { auto *IntToPtr = dyn_cast<IntToPtrInst>(Val); - if (IntToPtr && DL.getPointerTypeSizeInBits(IntToPtr->getDestTy()) == + if (IntToPtr && DL.getTypeSizeInBits(IntToPtr->getDestTy()) == DL.getTypeSizeInBits(IntToPtr->getSrcTy())) { auto *PtrToInt = dyn_cast<PtrToIntInst>(IntToPtr->getOperand(0)); Type *CastTy = IntToPtr->getDestTy(); if (PtrToInt && CastTy->getPointerAddressSpace() == PtrToInt->getSrcTy()->getPointerAddressSpace() && - DL.getPointerTypeSizeInBits(PtrToInt->getSrcTy()) == + DL.getTypeSizeInBits(PtrToInt->getSrcTy()) == DL.getTypeSizeInBits(PtrToInt->getDestTy())) { return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy, "", PtrToInt); @@ -632,14 +635,14 @@ getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op, /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). -Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, - Instruction::BinaryOps InnerOpcode, - Value *A, Value *B, Value *C, - Value *D) { +static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ, + InstCombiner::BuilderTy &Builder, + Instruction::BinaryOps InnerOpcode, Value *A, + Value *B, Value *C, Value *D) { assert(A && B && C && D && "All values must be provided"); Value *V = nullptr; - Value *SimplifiedInst = nullptr; + Value *RetVal = nullptr; Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); @@ -647,7 +650,7 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, bool InnerCommutative = Instruction::isCommutative(InnerOpcode); // Does "X op' (Y op Z)" always equal "(X op' Y) op (X op' Z)"? - if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode)) + if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode)) { // Does the instruction have the form "(A op' B) op (A op' D)" or, in the // commutative case, "(A op' B) op (C op' A)"? if (A == C || (InnerCommutative && A == D)) { @@ -656,17 +659,18 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, // Consider forming "A op' (B op D)". // If "B op D" simplifies then it can be formed with no cost. V = simplifyBinOp(TopLevelOpcode, B, D, SQ.getWithInstruction(&I)); - // If "B op D" doesn't simplify then only go on if both of the existing + + // If "B op D" doesn't simplify then only go on if one of the existing // operations "A op' B" and "C op' D" will be zapped as no longer used. - if (!V && LHS->hasOneUse() && RHS->hasOneUse()) + if (!V && (LHS->hasOneUse() || RHS->hasOneUse())) V = Builder.CreateBinOp(TopLevelOpcode, B, D, RHS->getName()); - if (V) { - SimplifiedInst = Builder.CreateBinOp(InnerOpcode, A, V); - } + if (V) + RetVal = Builder.CreateBinOp(InnerOpcode, A, V); } + } // Does "(X op Y) op' Z" always equal "(X op' Z) op (Y op' Z)"? 
- if (!SimplifiedInst && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) + if (!RetVal && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) { // Does the instruction have the form "(A op' B) op (C op' B)" or, in the // commutative case, "(A op' B) op (B op' D)"? if (B == D || (InnerCommutative && B == C)) { @@ -676,61 +680,94 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, // If "A op C" simplifies then it can be formed with no cost. V = simplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); - // If "A op C" doesn't simplify then only go on if both of the existing + // If "A op C" doesn't simplify then only go on if one of the existing // operations "A op' B" and "C op' D" will be zapped as no longer used. - if (!V && LHS->hasOneUse() && RHS->hasOneUse()) + if (!V && (LHS->hasOneUse() || RHS->hasOneUse())) V = Builder.CreateBinOp(TopLevelOpcode, A, C, LHS->getName()); - if (V) { - SimplifiedInst = Builder.CreateBinOp(InnerOpcode, V, B); - } + if (V) + RetVal = Builder.CreateBinOp(InnerOpcode, V, B); } + } - if (SimplifiedInst) { - ++NumFactor; - SimplifiedInst->takeName(&I); - - // Check if we can add NSW/NUW flags to SimplifiedInst. If so, set them. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(SimplifiedInst)) { - if (isa<OverflowingBinaryOperator>(SimplifiedInst)) { - bool HasNSW = false; - bool HasNUW = false; - if (isa<OverflowingBinaryOperator>(&I)) { - HasNSW = I.hasNoSignedWrap(); - HasNUW = I.hasNoUnsignedWrap(); - } + if (!RetVal) + return nullptr; - if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) { - HasNSW &= LOBO->hasNoSignedWrap(); - HasNUW &= LOBO->hasNoUnsignedWrap(); - } + ++NumFactor; + RetVal->takeName(&I); - if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) { - HasNSW &= ROBO->hasNoSignedWrap(); - HasNUW &= ROBO->hasNoUnsignedWrap(); - } + // Try to add no-overflow flags to the final value. + if (isa<OverflowingBinaryOperator>(RetVal)) { + bool HasNSW = false; + bool HasNUW = false; + if (isa<OverflowingBinaryOperator>(&I)) { + HasNSW = I.hasNoSignedWrap(); + HasNUW = I.hasNoUnsignedWrap(); + } + if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) { + HasNSW &= LOBO->hasNoSignedWrap(); + HasNUW &= LOBO->hasNoUnsignedWrap(); + } - if (TopLevelOpcode == Instruction::Add && - InnerOpcode == Instruction::Mul) { - // We can propagate 'nsw' if we know that - // %Y = mul nsw i16 %X, C - // %Z = add nsw i16 %Y, %X - // => - // %Z = mul nsw i16 %X, C+1 - // - // iff C+1 isn't INT_MIN - const APInt *CInt; - if (match(V, m_APInt(CInt))) { - if (!CInt->isMinSignedValue()) - BO->setHasNoSignedWrap(HasNSW); - } + if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) { + HasNSW &= ROBO->hasNoSignedWrap(); + HasNUW &= ROBO->hasNoUnsignedWrap(); + } - // nuw can be propagated with any constant or nuw value. - BO->setHasNoUnsignedWrap(HasNUW); - } - } + if (TopLevelOpcode == Instruction::Add && InnerOpcode == Instruction::Mul) { + // We can propagate 'nsw' if we know that + // %Y = mul nsw i16 %X, C + // %Z = add nsw i16 %Y, %X + // => + // %Z = mul nsw i16 %X, C+1 + // + // iff C+1 isn't INT_MIN + const APInt *CInt; + if (match(V, m_APInt(CInt)) && !CInt->isMinSignedValue()) + cast<Instruction>(RetVal)->setHasNoSignedWrap(HasNSW); + + // nuw can be propagated with any constant or nuw value. 
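[Editor's sketch] The nsw caveat above ("iff C+1 isn't INT_MIN") has a concrete counterexample: in i8, with C = 127 and X = -1, both the nsw mul and the nsw add stay in range, but the factored multiply uses the wrapped constant -128 and its mathematical result, 128, does not fit in i8, so nsw could not be claimed. A tiny check using wide host arithmetic to expose the would-be overflow (illustrative only):

#include <cassert>
#include <cstdint>
int main() {
  const int X = -1, C = 127; // i8 values from the comment's pattern
  int Y = X * C;             // %Y = mul nsw i8 %X, 127  -> -127, in range
  int Z = Y + X;             // %Z = add nsw i8 %Y, %X   -> -128, in range
  assert(Y >= INT8_MIN && Y <= INT8_MAX && Z >= INT8_MIN && Z <= INT8_MAX);
  const int CPlus1 = -128;   // C + 1 == 128 wraps to INT8_MIN in i8
  int Product = X * CPlus1;  // mathematically 128: same low 8 bits as %Z ...
  assert(((Product - Z) & 0xFF) == 0);
  assert(Product > INT8_MAX); // ... but out of i8 range, so nsw cannot be set
  return 0;
}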
+ cast<Instruction>(RetVal)->setHasNoUnsignedWrap(HasNUW); } } - return SimplifiedInst; + return RetVal; +} + +Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); + BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); + Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); + Value *A, *B, *C, *D; + Instruction::BinaryOps LHSOpcode, RHSOpcode; + + if (Op0) + LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); + if (Op1) + RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); + + // The instruction has the form "(A op' B) op (C op' D)". Try to factorize + // a common term. + if (Op0 && Op1 && LHSOpcode == RHSOpcode) + if (Value *V = tryFactorization(I, SQ, Builder, LHSOpcode, A, B, C, D)) + return V; + + // The instruction has the form "(A op' B) op (C)". Try to factorize common + // term. + if (Op0) + if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) + if (Value *V = + tryFactorization(I, SQ, Builder, LHSOpcode, A, B, RHS, Ident)) + return V; + + // The instruction has the form "(B) op (C op' D)". Try to factorize common + // term. + if (Op1) + if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) + if (Value *V = + tryFactorization(I, SQ, Builder, RHSOpcode, LHS, Ident, C, D)) + return V; + + return nullptr; } /// This tries to simplify binary operations which some other binary operation @@ -738,41 +775,15 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, /// (eg "(A*B)+(A*C)" -> "A*(B+C)") or expanding out if this results in /// simplifications (eg: "A & (B | C) -> (A&B) | (A&C)" if this is a win). /// Returns the simplified value, or null if it didn't simplify. -Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) { +Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); - { - // Factorization. - Value *A, *B, *C, *D; - Instruction::BinaryOps LHSOpcode, RHSOpcode; - if (Op0) - LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); - if (Op1) - RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); - - // The instruction has the form "(A op' B) op (C op' D)". Try to factorize - // a common term. - if (Op0 && Op1 && LHSOpcode == RHSOpcode) - if (Value *V = tryFactorization(I, LHSOpcode, A, B, C, D)) - return V; - - // The instruction has the form "(A op' B) op (C)". Try to factorize common - // term. - if (Op0) - if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) - if (Value *V = tryFactorization(I, LHSOpcode, A, B, RHS, Ident)) - return V; - - // The instruction has the form "(B) op (C op' D)". Try to factorize common - // term. - if (Op1) - if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) - if (Value *V = tryFactorization(I, RHSOpcode, LHS, Ident, C, D)) - return V; - } + // Factorization. + if (Value *R = tryFactorizationFolds(I)) + return R; // Expansion. if (Op0 && rightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { @@ -876,6 +887,28 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Cond, *True = nullptr, *False = nullptr; + + // Special-case for add/negate combination. Replace the zero in the negation + // with the trailing add operand: + // (Cond ? 
TVal : -N) + Z --> Cond ? True : (Z - N) + // (Cond ? -N : FVal) + Z --> Cond ? (Z - N) : False + auto foldAddNegate = [&](Value *TVal, Value *FVal, Value *Z) -> Value * { + // We need an 'add' and exactly 1 arm of the select to have been simplified. + if (Opcode != Instruction::Add || (!True && !False) || (True && False)) + return nullptr; + + Value *N; + if (True && match(FVal, m_Neg(m_Value(N)))) { + Value *Sub = Builder.CreateSub(Z, N); + return Builder.CreateSelect(Cond, True, Sub, I.getName()); + } + if (False && match(TVal, m_Neg(m_Value(N)))) { + Value *Sub = Builder.CreateSub(Z, N); + return Builder.CreateSelect(Cond, Sub, False, I.getName()); + } + return nullptr; + }; + if (LHSIsSelect && RHSIsSelect && A == D) { // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F) Cond = A; @@ -893,11 +926,15 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Cond = A; True = simplifyBinOp(Opcode, B, RHS, FMF, Q); False = simplifyBinOp(Opcode, C, RHS, FMF, Q); + if (Value *NewSel = foldAddNegate(B, C, RHS)) + return NewSel; } else if (RHSIsSelect && RHS->hasOneUse()) { // X op (D ? E : F) -> D ? (X op E) : (X op F) Cond = D; True = simplifyBinOp(Opcode, LHS, E, FMF, Q); False = simplifyBinOp(Opcode, LHS, F, FMF, Q); + if (Value *NewSel = foldAddNegate(E, F, LHS)) + return NewSel; } if (!True || !False) @@ -910,8 +947,10 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, /// Freely adapt every user of V as-if V was changed to !V. /// WARNING: only if canFreelyInvertAllUsersOf() said this can be done. -void InstCombinerImpl::freelyInvertAllUsersOf(Value *I) { - for (User *U : I->users()) { +void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) { + for (User *U : make_early_inc_range(I->users())) { + if (U == IgnoredUser) + continue; // Don't consider this user. switch (cast<Instruction>(U)->getOpcode()) { case Instruction::Select: { auto *SI = cast<SelectInst>(U); @@ -1033,6 +1072,9 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1)); } + if (auto *EI = dyn_cast<ExtractElementInst>(&I)) + return Builder.CreateExtractElement(SO, EI->getIndexOperand()); + assert(I.isBinaryOp() && "Unexpected opcode for select folding"); // Figure out if the constant is the left or the right argument. @@ -1133,22 +1175,6 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } -static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV, - InstCombiner::BuilderTy &Builder) { - bool ConstIsRHS = isa<Constant>(I->getOperand(1)); - Constant *C = cast<Constant>(I->getOperand(ConstIsRHS)); - - Value *Op0 = InV, *Op1 = C; - if (!ConstIsRHS) - std::swap(Op0, Op1); - - Value *RI = Builder.CreateBinOp(I->getOpcode(), Op0, Op1, "phi.bo"); - auto *FPInst = dyn_cast<Instruction>(RI); - if (FPInst && isa<FPMathOperator>(FPInst)) - FPInst->copyFastMathFlags(I); - return RI; -} - Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { unsigned NumPHIValues = PN->getNumIncomingValues(); if (NumPHIValues == 0) @@ -1167,48 +1193,69 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { // Otherwise, we can replace *all* users with the new PHI we form. } - // Check to see if all of the operands of the PHI are simple constants - // (constantint/constantfp/undef). 
If there is one non-constant value, - // remember the BB it is in. If there is more than one or if *it* is a PHI, - // bail out. We don't do arbitrary constant expressions here because moving - // their computation can be expensive without a cost model. - BasicBlock *NonConstBB = nullptr; + // Check to see whether the instruction can be folded into each phi operand. + // If there is one operand that does not fold, remember the BB it is in. + // If there is more than one or if *it* is a PHI, bail out. + SmallVector<Value *> NewPhiValues; + BasicBlock *NonSimplifiedBB = nullptr; + Value *NonSimplifiedInVal = nullptr; for (unsigned i = 0; i != NumPHIValues; ++i) { Value *InVal = PN->getIncomingValue(i); - // For non-freeze, require constant operand - // For freeze, require non-undef, non-poison operand - if (!isa<FreezeInst>(I) && match(InVal, m_ImmConstant())) - continue; - if (isa<FreezeInst>(I) && isGuaranteedNotToBeUndefOrPoison(InVal)) + BasicBlock *InBB = PN->getIncomingBlock(i); + + // NB: It is a precondition of this transform that the operands be + // phi translatable! This is usually trivially satisfied by limiting it + // to constant ops, and for selects we do a more sophisticated check. + SmallVector<Value *> Ops; + for (Value *Op : I.operands()) { + if (Op == PN) + Ops.push_back(InVal); + else + Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB)); + } + + // Don't consider the simplification successful if we get back a constant + // expression. That's just an instruction in hiding. + // Also reject the case where we simplify back to the phi node. We wouldn't + // be able to remove it in that case. + Value *NewVal = simplifyInstructionWithOperands( + &I, Ops, SQ.getWithInstruction(InBB->getTerminator())); + if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr())) { + NewPhiValues.push_back(NewVal); continue; + } if (isa<PHINode>(InVal)) return nullptr; // Itself a phi. - if (NonConstBB) return nullptr; // More than one non-const value. + if (NonSimplifiedBB) return nullptr; // More than one non-simplified value. - NonConstBB = PN->getIncomingBlock(i); + NonSimplifiedBB = InBB; + NonSimplifiedInVal = InVal; + NewPhiValues.push_back(nullptr); // If the InVal is an invoke at the end of the pred block, then we can't // insert a computation after it without breaking the edge. if (isa<InvokeInst>(InVal)) - if (cast<Instruction>(InVal)->getParent() == NonConstBB) + if (cast<Instruction>(InVal)->getParent() == NonSimplifiedBB) return nullptr; // If the incoming non-constant value is reachable from the phis block, // we'll push the operation across a loop backedge. This could result in // an infinite combine loop, and is generally non-profitable (especially // if the operation was originally outside the loop). - if (isPotentiallyReachable(PN->getParent(), NonConstBB, nullptr, &DT, LI)) + if (isPotentiallyReachable(PN->getParent(), NonSimplifiedBB, nullptr, &DT, + LI)) return nullptr; } - // If there is exactly one non-constant value, we can insert a copy of the + // If there is exactly one non-simplified value, we can insert a copy of the // operation in that block. However, if this is a critical edge, we would be // inserting the computation on some other paths (e.g. inside a loop). Only // do this if the pred block is unconditionally branching into the phi block. // Also, make sure that the pred block is not dead code. 
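Returning to the foldAddNegate lambda added to SimplifySelectsFeedingBinaryOp a few hunks above: the rewrite replaces the zero implicit in the negated select arm with the other add operand. A scalar C++ check of the equivalence, ignoring the overflow and poison details the real transform handles through the simplified arm:

#include <cstdio>

int main() {
  const int TVal = 7, N = 5, Z = 3;
  for (int Cond = 0; Cond <= 1; ++Cond) {
    int Before = (Cond ? TVal : -N) + Z;        // add fed by the select
    int After = Cond ? (TVal + Z) : (Z - N);    // select of the folded arms
    std::printf("Cond=%d before=%d after=%d\n", Cond, Before, After);
  }
}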
- if (NonConstBB != nullptr) { - BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator()); - if (!BI || !BI->isUnconditional() || !DT.isReachableFromEntry(NonConstBB)) + if (NonSimplifiedBB != nullptr) { + BranchInst *BI = dyn_cast<BranchInst>(NonSimplifiedBB->getTerminator()); + if (!BI || !BI->isUnconditional() || + !DT.isReachableFromEntry(NonSimplifiedBB)) return nullptr; } @@ -1219,83 +1266,23 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { // If we are going to have to insert a new computation, do so right before the // predecessor's terminator. - if (NonConstBB) - Builder.SetInsertPoint(NonConstBB->getTerminator()); - - // Next, add all of the operands to the PHI. - if (SelectInst *SI = dyn_cast<SelectInst>(&I)) { - // We only currently try to fold the condition of a select when it is a phi, - // not the true/false values. - Value *TrueV = SI->getTrueValue(); - Value *FalseV = SI->getFalseValue(); - BasicBlock *PhiTransBB = PN->getParent(); - for (unsigned i = 0; i != NumPHIValues; ++i) { - BasicBlock *ThisBB = PN->getIncomingBlock(i); - Value *TrueVInPred = TrueV->DoPHITranslation(PhiTransBB, ThisBB); - Value *FalseVInPred = FalseV->DoPHITranslation(PhiTransBB, ThisBB); - Value *InV = nullptr; - // Beware of ConstantExpr: it may eventually evaluate to getNullValue, - // even if currently isNullValue gives false. - Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)); - // For vector constants, we cannot use isNullValue to fold into - // FalseVInPred versus TrueVInPred. When we have individual nonzero - // elements in the vector, we will incorrectly fold InC to - // `TrueVInPred`. - if (InC && isa<ConstantInt>(InC)) - InV = InC->isNullValue() ? FalseVInPred : TrueVInPred; - else { - // Generate the select in the same block as PN's current incoming block. - // Note: ThisBB need not be the NonConstBB because vector constants - // which are constants by definition are handled here. - // FIXME: This can lead to an increase in IR generation because we might - // generate selects for vector constant phi operand, that could not be - // folded to TrueVInPred or FalseVInPred as done for ConstantInt. For - // non-vector phis, this transformation was always profitable because - // the select would be generated exactly once in the NonConstBB. 
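The rewritten foldOpIntoPhi above stops special-casing selects, compares, binops and casts; it phi-translates the operands and asks simplifyInstructionWithOperands once per incoming edge instead. A toy model of the end result, with a phi represented as a map from predecessor block to incoming constant (names and values illustrative only):

#include <cstdio>
#include <map>
#include <string>

int main() {
  // A phi modeled as predecessor-block -> incoming constant.
  std::map<std::string, int> Phi = {{"bb.then", 3}, {"bb.else", 10}};
  // Folding "add (phi), 5" pushes the add onto every incoming edge, so the
  // add disappears and only a phi of already-simplified values remains.
  std::map<std::string, int> NewPhi;
  for (const auto &[Pred, Incoming] : Phi)
    NewPhi[Pred] = Incoming + 5;
  for (const auto &[Pred, V] : NewPhi)
    std::printf("%s -> %d\n", Pred.c_str(), V);
}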
- Builder.SetInsertPoint(ThisBB->getTerminator()); - InV = Builder.CreateSelect(PN->getIncomingValue(i), TrueVInPred, - FalseVInPred, "phi.sel"); - } - NewPN->addIncoming(InV, ThisBB); - } - } else if (CmpInst *CI = dyn_cast<CmpInst>(&I)) { - Constant *C = cast<Constant>(I.getOperand(1)); - for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV = nullptr; - if (auto *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) - InV = ConstantExpr::getCompare(CI->getPredicate(), InC, C); - else - InV = Builder.CreateCmp(CI->getPredicate(), PN->getIncomingValue(i), - C, "phi.cmp"); - NewPN->addIncoming(InV, PN->getIncomingBlock(i)); - } - } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) { - for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), - Builder); - NewPN->addIncoming(InV, PN->getIncomingBlock(i)); - } - } else if (isa<FreezeInst>(&I)) { - for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV; - if (NonConstBB == PN->getIncomingBlock(i)) - InV = Builder.CreateFreeze(PN->getIncomingValue(i), "phi.fr"); - else - InV = PN->getIncomingValue(i); - NewPN->addIncoming(InV, PN->getIncomingBlock(i)); - } - } else { - CastInst *CI = cast<CastInst>(&I); - Type *RetTy = CI->getType(); - for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV; - if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) - InV = ConstantExpr::getCast(CI->getOpcode(), InC, RetTy); + Instruction *Clone = nullptr; + if (NonSimplifiedBB) { + Clone = I.clone(); + for (Use &U : Clone->operands()) { + if (U == PN) + U = NonSimplifiedInVal; else - InV = Builder.CreateCast(CI->getOpcode(), PN->getIncomingValue(i), - I.getType(), "phi.cast"); - NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + U = U->DoPHITranslation(PN->getParent(), NonSimplifiedBB); } + InsertNewInstBefore(Clone, *NonSimplifiedBB->getTerminator()); + } + + for (unsigned i = 0; i != NumPHIValues; ++i) { + if (NewPhiValues[i]) + NewPN->addIncoming(NewPhiValues[i], PN->getIncomingBlock(i)); + else + NewPN->addIncoming(Clone, PN->getIncomingBlock(i)); } for (User *U : make_early_inc_range(PN->users())) { @@ -1696,6 +1683,35 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { return new ShuffleVectorInst(NewBO0, NewBO1, Mask); } + auto createBinOpReverse = [&](Value *X, Value *Y) { + Value *V = Builder.CreateBinOp(Opcode, X, Y, Inst.getName()); + if (auto *BO = dyn_cast<BinaryOperator>(V)) + BO->copyIRFlags(&Inst); + Module *M = Inst.getModule(); + Function *F = Intrinsic::getDeclaration( + M, Intrinsic::experimental_vector_reverse, V->getType()); + return CallInst::Create(F, V); + }; + + // NOTE: Reverse shuffles don't require the speculative execution protection + // below because they don't affect which lanes take part in the computation. + + Value *V1, *V2; + if (match(LHS, m_VecReverse(m_Value(V1)))) { + // Op(rev(V1), rev(V2)) -> rev(Op(V1, V2)) + if (match(RHS, m_VecReverse(m_Value(V2))) && + (LHS->hasOneUse() || RHS->hasOneUse() || + (LHS == RHS && LHS->hasNUses(2)))) + return createBinOpReverse(V1, V2); + + // Op(rev(V1), RHSSplat)) -> rev(Op(V1, RHSSplat)) + if (LHS->hasOneUse() && isSplatValue(RHS)) + return createBinOpReverse(V1, RHS); + } + // Op(LHSSplat, rev(V2)) -> rev(Op(LHSSplat, V2)) + else if (isSplatValue(LHS) && match(RHS, m_OneUse(m_VecReverse(m_Value(V2))))) + return createBinOpReverse(LHS, V2); + // It may not be safe to reorder shuffles and things like div, urem, etc. 
// because we may trap when executing those ops on unknown vector elements. // See PR20059. @@ -1711,7 +1727,6 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { // If both arguments of the binary operation are shuffles that use the same // mask and shuffle within a single vector, move the shuffle after the binop. - Value *V1, *V2; if (match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))) && match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(Mask))) && V1->getType() == V2->getType() && @@ -2228,7 +2243,7 @@ Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI, if (Instruction *I = visitBitCast(*BCI)) { if (I != BCI) { I->takeName(BCI); - BCI->getParent()->getInstList().insert(BCI->getIterator(), I); + I->insertInto(BCI->getParent(), BCI->getIterator()); replaceInstUsesWith(*BCI, I); } return &GEP; @@ -2434,10 +2449,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { NewGEP->setOperand(DI, NewPN); } - GEP.getParent()->getInstList().insert( - GEP.getParent()->getFirstInsertionPt(), NewGEP); - replaceOperand(GEP, 0, NewGEP); - PtrOp = NewGEP; + NewGEP->insertInto(GEP.getParent(), GEP.getParent()->getFirstInsertionPt()); + return replaceOperand(GEP, 0, NewGEP); } if (auto *Src = dyn_cast<GEPOperator>(PtrOp)) @@ -2450,7 +2463,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL.getIndexSizeInBits(AS)) { - uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); + uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); bool Matched = false; uint64_t C; @@ -2580,8 +2593,9 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) { // Check that changing the type amounts to dividing the index by a scale // factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); - uint64_t SrcSize = DL.getTypeAllocSize(StrippedPtrEltTy).getFixedSize(); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); + uint64_t SrcSize = + DL.getTypeAllocSize(StrippedPtrEltTy).getFixedValue(); if (ResSize && SrcSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); @@ -2617,10 +2631,10 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { StrippedPtrEltTy->isArrayTy()) { // Check that changing to the array element type amounts to dividing the // index by a scale factor. 
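For the Op(rev(V1), rev(V2)) -> rev(Op(V1, V2)) fold added to foldVectorBinop above, the underlying identity is that a lane-wise binary operation commutes with reversing the lanes. A scalar-loop analogue in plain C++ (not the vector-reverse intrinsic itself):

#include <algorithm>
#include <array>
#include <cstdio>

int main() {
  std::array<int, 4> A{1, 2, 3, 4}, B{10, 20, 30, 40};
  // Form 1: reverse both inputs, then combine lane-wise.
  std::array<int, 4> RA = A, RB = B, Lhs{}, Rhs{};
  std::reverse(RA.begin(), RA.end());
  std::reverse(RB.begin(), RB.end());
  for (int i = 0; i < 4; ++i)
    Lhs[i] = RA[i] + RB[i];
  // Form 2: combine lane-wise first, then reverse the result once.
  for (int i = 0; i < 4; ++i)
    Rhs[i] = A[i] + B[i];
  std::reverse(Rhs.begin(), Rhs.end());
  std::printf("forms agree: %d\n",
              (int)std::equal(Lhs.begin(), Lhs.end(), Rhs.begin()));
}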
- uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); uint64_t ArrayEltSize = DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) - .getFixedSize(); + .getFixedValue(); if (ResSize && ArrayEltSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); @@ -2681,7 +2695,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { BasePtrOffset.isNonNegative()) { APInt AllocSize( IdxWidth, - DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinSize()); + DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinValue()); if (BasePtrOffset.ule(AllocSize)) { return GetElementPtrInst::CreateInBounds( GEP.getSourceElementType(), PtrOp, Indices, GEP.getName()); @@ -2724,7 +2738,7 @@ static bool isRemovableWrite(CallBase &CB, Value *UsedV, // If the only possible side effect of the call is writing to the alloca, // and the result isn't used, we can safely remove any reads implied by the // call including those which might read the alloca itself. - Optional<MemoryLocation> Dest = MemoryLocation::getForDest(&CB, TLI); + std::optional<MemoryLocation> Dest = MemoryLocation::getForDest(&CB, TLI); return Dest && Dest->Ptr == UsedV; } @@ -2732,7 +2746,7 @@ static bool isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakTrackingVH> &Users, const TargetLibraryInfo &TLI) { SmallVector<Instruction*, 4> Worklist; - const Optional<StringRef> Family = getAllocationFamily(AI, &TLI); + const std::optional<StringRef> Family = getAllocationFamily(AI, &TLI); Worklist.push_back(AI); do { @@ -2778,7 +2792,7 @@ static bool isAllocSiteRemovable(Instruction *AI, MemIntrinsic *MI = cast<MemIntrinsic>(II); if (MI->isVolatile() || MI->getRawDest() != PI) return false; - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::assume: case Intrinsic::invariant_start: @@ -2808,7 +2822,7 @@ static bool isAllocSiteRemovable(Instruction *AI, continue; } - if (getReallocatedOperand(cast<CallBase>(I), &TLI) == PI && + if (getReallocatedOperand(cast<CallBase>(I)) == PI && getAllocationFamily(I, &TLI) == Family) { assert(Family); Users.emplace_back(I); @@ -2902,7 +2916,7 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { Module *M = II->getModule(); Function *F = Intrinsic::getDeclaration(M, Intrinsic::donothing); InvokeInst::Create(F, II->getNormalDest(), II->getUnwindDest(), - None, "", II->getParent()); + std::nullopt, "", II->getParent()); } // Remove debug intrinsics which describe the value contained within the @@ -3052,7 +3066,7 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) { // realloc() entirely. CallInst *CI = dyn_cast<CallInst>(Op); if (CI && CI->hasOneUse()) - if (Value *ReallocatedOp = getReallocatedOperand(CI, &TLI)) + if (Value *ReallocatedOp = getReallocatedOperand(CI)) return eraseInstFromFunction(*replaceInstUsesWith(*CI, ReallocatedOp)); // If we optimize for code size, try to move the call to free before the null @@ -3166,31 +3180,41 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { return visitUnconditionalBranchInst(BI); // Change br (not X), label True, label False to: br X, label False, True - Value *X = nullptr; - if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) && - !isa<Constant>(X)) { + Value *Cond = BI.getCondition(); + Value *X; + if (match(Cond, m_Not(m_Value(X))) && !isa<Constant>(X)) { // Swap Destinations and condition... 
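The GEP hunks above keep the check that retyping a pointer only changes the index by a whole scale factor (SrcSize % ResSize == 0) while switching the size queries from getFixedSize() to getFixedValue(). A small sanity check of the byte-offset reasoning, with illustrative element sizes:

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative sizes: viewing an array of 8-byte elements as 4-byte elements.
  const uint64_t SrcSize = 8, ResSize = 4;
  assert(SrcSize % ResSize == 0);       // the divisibility check in the hunk above
  const uint64_t Scale = SrcSize / ResSize;
  for (uint64_t Idx = 0; Idx < 16; ++Idx)
    // The same byte offset is reachable with either element type as long as
    // the index is rescaled by SrcSize / ResSize.
    assert(Idx * SrcSize == Idx * Scale * ResSize);
  std::puts("offsets agree for every index");
}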
BI.swapSuccessors(); return replaceOperand(BI, 0, X); } + // Canonicalize logical-and-with-invert as logical-or-with-invert. + // This is done by inverting the condition and swapping successors: + // br (X && !Y), T, F --> br !(X && !Y), F, T --> br (!X || Y), F, T + Value *Y; + if (isa<SelectInst>(Cond) && + match(Cond, + m_OneUse(m_LogicalAnd(m_Value(X), m_OneUse(m_Not(m_Value(Y))))))) { + Value *NotX = Builder.CreateNot(X, "not." + X->getName()); + Value *Or = Builder.CreateLogicalOr(NotX, Y); + BI.swapSuccessors(); + return replaceOperand(BI, 0, Or); + } + // If the condition is irrelevant, remove the use so that other // transforms on the condition become more effective. - if (!isa<ConstantInt>(BI.getCondition()) && - BI.getSuccessor(0) == BI.getSuccessor(1)) - return replaceOperand( - BI, 0, ConstantInt::getFalse(BI.getCondition()->getType())); + if (!isa<ConstantInt>(Cond) && BI.getSuccessor(0) == BI.getSuccessor(1)) + return replaceOperand(BI, 0, ConstantInt::getFalse(Cond->getType())); // Canonicalize, for example, fcmp_one -> fcmp_oeq. CmpInst::Predicate Pred; - if (match(&BI, m_Br(m_OneUse(m_FCmp(Pred, m_Value(), m_Value())), - m_BasicBlock(), m_BasicBlock())) && + if (match(Cond, m_OneUse(m_FCmp(Pred, m_Value(), m_Value()))) && !isCanonicalPredicate(Pred)) { // Swap destinations and condition. - CmpInst *Cond = cast<CmpInst>(BI.getCondition()); - Cond->setPredicate(CmpInst::getInversePredicate(Pred)); + auto *Cmp = cast<CmpInst>(Cond); + Cmp->setPredicate(CmpInst::getInversePredicate(Pred)); BI.swapSuccessors(); - Worklist.push(Cond); + Worklist.push(Cmp); return &BI; } @@ -3218,7 +3242,7 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { // Compute the number of leading bits we can ignore. // TODO: A better way to determine this would use ComputeNumSignBits(). - for (auto &C : SI.cases()) { + for (const auto &C : SI.cases()) { LeadingKnownZeros = std::min( LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros()); LeadingKnownOnes = std::min( @@ -3247,6 +3271,81 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { return nullptr; } +Instruction * +InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) { + auto *WO = dyn_cast<WithOverflowInst>(EV.getAggregateOperand()); + if (!WO) + return nullptr; + + Intrinsic::ID OvID = WO->getIntrinsicID(); + const APInt *C = nullptr; + if (match(WO->getRHS(), m_APIntAllowUndef(C))) { + if (*EV.idx_begin() == 0 && (OvID == Intrinsic::smul_with_overflow || + OvID == Intrinsic::umul_with_overflow)) { + // extractvalue (any_mul_with_overflow X, -1), 0 --> -X + if (C->isAllOnes()) + return BinaryOperator::CreateNeg(WO->getLHS()); + // extractvalue (any_mul_with_overflow X, 2^n), 0 --> X << n + if (C->isPowerOf2()) { + return BinaryOperator::CreateShl( + WO->getLHS(), + ConstantInt::get(WO->getLHS()->getType(), C->logBase2())); + } + } + } + + // We're extracting from an overflow intrinsic. See if we're the only user. + // That allows us to simplify multiple result intrinsics to simpler things + // that just get one value. + if (!WO->hasOneUse()) + return nullptr; + + // Check if we're grabbing only the result of a 'with overflow' intrinsic + // and replace it with a traditional binary instruction. + if (*EV.idx_begin() == 0) { + Instruction::BinaryOps BinOp = WO->getBinaryOp(); + Value *LHS = WO->getLHS(), *RHS = WO->getRHS(); + // Replace the old instruction's uses with poison. 
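The new visitBranchInst canonicalization above rewrites br (X && !Y), T, F as br (!X || Y), F, T. A plain truth-table check of the De Morgan step; this only verifies the boolean equivalence, while the patch uses CreateLogicalOr because the select-based form also has to respect poison:

#include <cstdio>

int main() {
  for (int X = 0; X <= 1; ++X)
    for (int Y = 0; Y <= 1; ++Y) {
      bool ToTrueDestBefore = X && !Y;     // br (X && !Y), T, F reaches T
      bool ToTrueDestAfter = !(!X || Y);   // br (!X || Y), F, T reaches T only
                                           // when the new condition is false
      std::printf("X=%d Y=%d before->T=%d after->T=%d\n", X, Y,
                  (int)ToTrueDestBefore, (int)ToTrueDestAfter);
    }
}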
+ replaceInstUsesWith(*WO, PoisonValue::get(WO->getType())); + eraseInstFromFunction(*WO); + return BinaryOperator::Create(BinOp, LHS, RHS); + } + + assert(*EV.idx_begin() == 1 && "Unexpected extract index for overflow inst"); + + // (usub LHS, RHS) overflows when LHS is unsigned-less-than RHS. + if (OvID == Intrinsic::usub_with_overflow) + return new ICmpInst(ICmpInst::ICMP_ULT, WO->getLHS(), WO->getRHS()); + + // smul with i1 types overflows when both sides are set: -1 * -1 == +1, but + // +1 is not possible because we assume signed values. + if (OvID == Intrinsic::smul_with_overflow && + WO->getLHS()->getType()->isIntOrIntVectorTy(1)) + return BinaryOperator::CreateAnd(WO->getLHS(), WO->getRHS()); + + // If only the overflow result is used, and the right hand side is a + // constant (or constant splat), we can remove the intrinsic by directly + // checking for overflow. + if (C) { + // Compute the no-wrap range for LHS given RHS=C, then construct an + // equivalent icmp, potentially using an offset. + ConstantRange NWR = ConstantRange::makeExactNoWrapRegion( + WO->getBinaryOp(), *C, WO->getNoWrapKind()); + + CmpInst::Predicate Pred; + APInt NewRHSC, Offset; + NWR.getEquivalentICmp(Pred, NewRHSC, Offset); + auto *OpTy = WO->getRHS()->getType(); + auto *NewLHS = WO->getLHS(); + if (Offset != 0) + NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset)); + return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS, + ConstantInt::get(OpTy, NewRHSC)); + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Value *Agg = EV.getAggregateOperand(); @@ -3294,7 +3393,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Value *NewEV = Builder.CreateExtractValue(IV->getAggregateOperand(), EV.getIndices()); return InsertValueInst::Create(NewEV, IV->getInsertedValueOperand(), - makeArrayRef(insi, inse)); + ArrayRef(insi, inse)); } if (insi == inse) // The insert list is a prefix of the extract list @@ -3306,60 +3405,13 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { // with // %E extractvalue { i32 } { i32 42 }, 0 return ExtractValueInst::Create(IV->getInsertedValueOperand(), - makeArrayRef(exti, exte)); + ArrayRef(exti, exte)); } - if (WithOverflowInst *WO = dyn_cast<WithOverflowInst>(Agg)) { - // extractvalue (any_mul_with_overflow X, -1), 0 --> -X - Intrinsic::ID OvID = WO->getIntrinsicID(); - if (*EV.idx_begin() == 0 && - (OvID == Intrinsic::smul_with_overflow || - OvID == Intrinsic::umul_with_overflow) && - match(WO->getArgOperand(1), m_AllOnes())) { - return BinaryOperator::CreateNeg(WO->getArgOperand(0)); - } - // We're extracting from an overflow intrinsic, see if we're the only user, - // which allows us to simplify multiple result intrinsics to simpler - // things that just get one value. - if (WO->hasOneUse()) { - // Check if we're grabbing only the result of a 'with overflow' intrinsic - // and replace it with a traditional binary instruction. - if (*EV.idx_begin() == 0) { - Instruction::BinaryOps BinOp = WO->getBinaryOp(); - Value *LHS = WO->getLHS(), *RHS = WO->getRHS(); - // Replace the old instruction's uses with poison. 
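Two of the foldExtractOfOverflowIntrinsic rewrites above are easy to sanity-check in plain C++ using the GCC/Clang overflow builtins: the overflow bit of an unsigned subtract is exactly the ult comparison, and the value result of a multiply by a power of two is a shift.

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // The overflow bit of an unsigned subtract is exactly "A u< B".
  for (uint32_t A : {0u, 3u, 100u})
    for (uint32_t B : {0u, 7u, 100u}) {
      uint32_t Diff;
      bool Overflowed = __builtin_sub_overflow(A, B, &Diff);
      assert(Overflowed == (A < B));
    }
  // The value result of multiplying by a power of two is a shift, whether or
  // not the multiplication overflowed.
  for (uint32_t X : {1u, 0x12345678u, 0xffffffffu})
    assert(X * 8u == X << 3);
  std::puts("overflow identities hold");
}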
- replaceInstUsesWith(*WO, PoisonValue::get(WO->getType())); - eraseInstFromFunction(*WO); - return BinaryOperator::Create(BinOp, LHS, RHS); - } + if (Instruction *R = foldExtractOfOverflowIntrinsic(EV)) + return R; - assert(*EV.idx_begin() == 1 && - "unexpected extract index for overflow inst"); - - // If only the overflow result is used, and the right hand side is a - // constant (or constant splat), we can remove the intrinsic by directly - // checking for overflow. - const APInt *C; - if (match(WO->getRHS(), m_APInt(C))) { - // Compute the no-wrap range for LHS given RHS=C, then construct an - // equivalent icmp, potentially using an offset. - ConstantRange NWR = - ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C, - WO->getNoWrapKind()); - - CmpInst::Predicate Pred; - APInt NewRHSC, Offset; - NWR.getEquivalentICmp(Pred, NewRHSC, Offset); - auto *OpTy = WO->getRHS()->getType(); - auto *NewLHS = WO->getLHS(); - if (Offset != 0) - NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset)); - return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS, - ConstantInt::get(OpTy, NewRHSC)); - } - } - } - if (LoadInst *L = dyn_cast<LoadInst>(Agg)) + if (LoadInst *L = dyn_cast<LoadInst>(Agg)) { // If the (non-volatile) load only has one use, we can rewrite this to a // load from a GEP. This reduces the size of the load. If a load is used // only by extractvalue instructions then this either must have been @@ -3386,6 +3438,12 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { // the wrong spot, so use replaceInstUsesWith(). return replaceInstUsesWith(EV, NL); } + } + + if (auto *PN = dyn_cast<PHINode>(Agg)) + if (Instruction *Res = foldOpIntoPhi(EV, PN)) + return Res; + // We could simplify extracts from other values. Note that nested extracts may // already be simplified implicitly by the above: extract (extract (insert) ) // will be translated into extract ( insert ( extract ) ) first and then just @@ -3771,7 +3829,8 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { // poison. If the only source of new poison is flags, we can simply // strip them (since we know the only use is the freeze and nothing can // benefit from them.) - if (canCreateUndefOrPoison(cast<Operator>(OrigOp), /*ConsiderFlags*/ false)) + if (canCreateUndefOrPoison(cast<Operator>(OrigOp), + /*ConsiderFlagsAndMetadata*/ false)) return nullptr; // If operand is guaranteed not to be poison, there is no need to add freeze @@ -3779,7 +3838,8 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { // poison. Use *MaybePoisonOperand = nullptr; for (Use &U : OrigOpInst->operands()) { - if (isGuaranteedNotToBeUndefOrPoison(U.get())) + if (isa<MetadataAsValue>(U.get()) || + isGuaranteedNotToBeUndefOrPoison(U.get())) continue; if (!MaybePoisonOperand) MaybePoisonOperand = &U; @@ -3787,7 +3847,7 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { return nullptr; } - OrigOpInst->dropPoisonGeneratingFlags(); + OrigOpInst->dropPoisonGeneratingFlagsAndMetadata(); // If all operands are guaranteed to be non-poison, we can drop freeze. 
if (!MaybePoisonOperand) @@ -3850,7 +3910,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI, Instruction *I = dyn_cast<Instruction>(V); if (!I || canCreateUndefOrPoison(cast<Operator>(I), - /*ConsiderFlags*/ false)) + /*ConsiderFlagsAndMetadata*/ false)) return nullptr; DropFlags.push_back(I); @@ -3858,7 +3918,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI, } for (Instruction *I : DropFlags) - I->dropPoisonGeneratingFlags(); + I->dropPoisonGeneratingFlagsAndMetadata(); if (StartNeedsFreeze) { Builder.SetInsertPoint(StartBB->getTerminator()); @@ -3880,21 +3940,14 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) { // *all* uses if the operand is an invoke/callbr and the use is in a phi on // the normal/default destination. This is why the domination check in the // replacement below is still necessary. - Instruction *MoveBefore = nullptr; + Instruction *MoveBefore; if (isa<Argument>(Op)) { - MoveBefore = &FI.getFunction()->getEntryBlock().front(); - while (isa<AllocaInst>(MoveBefore)) - MoveBefore = MoveBefore->getNextNode(); - } else if (auto *PN = dyn_cast<PHINode>(Op)) { - MoveBefore = PN->getParent()->getFirstNonPHI(); - } else if (auto *II = dyn_cast<InvokeInst>(Op)) { - MoveBefore = II->getNormalDest()->getFirstNonPHI(); - } else if (auto *CB = dyn_cast<CallBrInst>(Op)) { - MoveBefore = CB->getDefaultDest()->getFirstNonPHI(); + MoveBefore = + &*FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca(); } else { - auto *I = cast<Instruction>(Op); - assert(!I->isTerminator() && "Cannot be a terminator"); - MoveBefore = I->getNextNode(); + MoveBefore = cast<Instruction>(Op)->getInsertionPointAfterDef(); + if (!MoveBefore) + return false; } bool Changed = false; @@ -3987,7 +4040,7 @@ static bool SoleWriteToDeadLocal(Instruction *I, TargetLibraryInfo &TLI) { // to allow reload along used path as described below. Otherwise, this // is simply a store to a dead allocation which will be removed. return false; - Optional<MemoryLocation> Dest = MemoryLocation::getForDest(CB, TLI); + std::optional<MemoryLocation> Dest = MemoryLocation::getForDest(CB, TLI); if (!Dest) return false; auto *AI = dyn_cast<AllocaInst>(getUnderlyingObject(Dest->Ptr)); @@ -4103,7 +4156,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, SmallVector<DbgVariableIntrinsic *, 2> DIIClones; SmallSet<DebugVariable, 4> SunkVariables; - for (auto User : DbgUsersToSink) { + for (auto *User : DbgUsersToSink) { // A dbg.declare instruction should not be cloned, since there can only be // one per variable fragment. It should be left in the original place // because the sunk instruction is not an alloca (otherwise we could not be @@ -4118,6 +4171,11 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, if (!SunkVariables.insert(DbgUserVariable).second) continue; + // Leave dbg.assign intrinsics in their original positions and there should + // be no need to insert a clone. + if (isa<DbgAssignIntrinsic>(User)) + continue; + DIIClones.emplace_back(cast<DbgVariableIntrinsic>(User->clone())); if (isa<DbgDeclareInst>(User) && isa<CastInst>(I)) DIIClones.back()->replaceVariableLocationOp(I, I->getOperand(0)); @@ -4190,9 +4248,9 @@ bool InstCombinerImpl::run() { // prove that the successor is not executed more frequently than our block. // Return the UserBlock if successful. 
auto getOptionalSinkBlockForInst = - [this](Instruction *I) -> Optional<BasicBlock *> { + [this](Instruction *I) -> std::optional<BasicBlock *> { if (!EnableCodeSinking) - return None; + return std::nullopt; BasicBlock *BB = I->getParent(); BasicBlock *UserParent = nullptr; @@ -4202,7 +4260,7 @@ bool InstCombinerImpl::run() { if (U->isDroppable()) continue; if (NumUsers > MaxSinkNumUsers) - return None; + return std::nullopt; Instruction *UserInst = cast<Instruction>(U); // Special handling for Phi nodes - get the block the use occurs in. @@ -4213,14 +4271,14 @@ bool InstCombinerImpl::run() { // sophisticated analysis (i.e finding NearestCommonDominator of // these use blocks). if (UserParent && UserParent != PN->getIncomingBlock(i)) - return None; + return std::nullopt; UserParent = PN->getIncomingBlock(i); } } assert(UserParent && "expected to find user block!"); } else { if (UserParent && UserParent != UserInst->getParent()) - return None; + return std::nullopt; UserParent = UserInst->getParent(); } @@ -4230,7 +4288,7 @@ bool InstCombinerImpl::run() { // Try sinking to another block. If that block is unreachable, then do // not bother. SimplifyCFG should handle it. if (UserParent == BB || !DT.isReachableFromEntry(UserParent)) - return None; + return std::nullopt; auto *Term = UserParent->getTerminator(); // See if the user is one of our successors that has only one @@ -4242,7 +4300,7 @@ bool InstCombinerImpl::run() { // - the User will be executed at most once. // So sinking I down to User is always profitable or neutral. if (UserParent->getUniquePredecessor() != BB && !succ_empty(Term)) - return None; + return std::nullopt; assert(DT.dominates(BB, UserParent) && "Dominance relation broken?"); } @@ -4252,7 +4310,7 @@ bool InstCombinerImpl::run() { // No user or only has droppable users. if (!UserParent) - return None; + return std::nullopt; return UserParent; }; @@ -4312,7 +4370,7 @@ bool InstCombinerImpl::run() { InsertPos = InstParent->getFirstNonPHI()->getIterator(); } - InstParent->getInstList().insert(InsertPos, Result); + Result->insertInto(InstParent, InsertPos); // Push the new instruction and any users onto the worklist. Worklist.pushUsersToWorkList(*Result); @@ -4360,7 +4418,7 @@ public: const auto *MDScopeList = dyn_cast_or_null<MDNode>(ScopeList); if (!MDScopeList || !Container.insert(MDScopeList).second) return; - for (auto &MDOperand : MDScopeList->operands()) + for (const auto &MDOperand : MDScopeList->operands()) if (auto *MDScope = dyn_cast<MDNode>(MDOperand)) Container.insert(MDScope); }; @@ -4543,6 +4601,13 @@ static bool combineInstructionsOverFunction( bool MadeIRChange = false; if (ShouldLowerDbgDeclare) MadeIRChange = LowerDbgDeclare(F); + // LowerDbgDeclare calls RemoveRedundantDbgInstrs, but LowerDbgDeclare will + // almost never return true when running an assignment tracking build. Take + // this opportunity to do some clean up for assignment tracking builds too. + if (!MadeIRChange && isAssignmentTrackingEnabled(*F.getParent())) { + for (auto &BB : F) + RemoveRedundantDbgInstrs(&BB); + } // Iterate while there is work to do. 
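Many hunks in this file are the mechanical migration from Optional/None to std::optional/std::nullopt seen in getOptionalSinkBlockForInst above. A minimal sketch of the target idiom, with illustrative names:

#include <cstdio>
#include <optional>

// Illustrative helper mirroring the shape of getOptionalSinkBlockForInst.
static std::optional<unsigned> pickSinkBlock(bool Enable, unsigned Candidate) {
  if (!Enable)
    return std::nullopt;   // previously: return None;
  return Candidate;
}

int main() {
  if (auto BB = pickSinkBlock(true, 7))
    std::printf("sink into block %u\n", *BB);
  if (!pickSinkBlock(false, 7))
    std::puts("sinking disabled");
}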
unsigned Iteration = 0; diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 3274e36ab71a..599eeeabc143 100644 --- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/StackSafetyAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -105,6 +106,7 @@ static const uint64_t kMIPS_ShadowOffsetN32 = 1ULL << 29; static const uint64_t kMIPS32_ShadowOffset32 = 0x0aaa0000; static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 37; static const uint64_t kAArch64_ShadowOffset64 = 1ULL << 36; +static const uint64_t kLoongArch64_ShadowOffset64 = 1ULL << 46; static const uint64_t kRISCV64_ShadowOffset64 = 0xd55550000; static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; @@ -347,6 +349,13 @@ static cl::opt<bool> ClSkipPromotableAllocas( cl::desc("Do not instrument promotable allocas"), cl::Hidden, cl::init(true)); +static cl::opt<AsanCtorKind> ClConstructorKind( + "asan-constructor-kind", + cl::desc("Sets the ASan constructor kind"), + cl::values(clEnumValN(AsanCtorKind::None, "none", "No constructors"), + clEnumValN(AsanCtorKind::Global, "global", + "Use global constructors")), + cl::init(AsanCtorKind::Global), cl::Hidden); // These flags allow to change the shadow mapping. // The shadow mapping looks like // Shadow = (Mem >> scale) + offset @@ -395,12 +404,12 @@ static cl::opt<uint32_t> ClForceExperiment( static cl::opt<bool> ClUsePrivateAlias("asan-use-private-alias", cl::desc("Use private aliases for global variables"), - cl::Hidden, cl::init(false)); + cl::Hidden, cl::init(true)); static cl::opt<bool> ClUseOdrIndicator("asan-use-odr-indicator", cl::desc("Use odr indicators to improve ODR reporting"), - cl::Hidden, cl::init(false)); + cl::Hidden, cl::init(true)); static cl::opt<bool> ClUseGlobalsGC("asan-globals-live-support", @@ -483,6 +492,7 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, bool IsMIPS64 = TargetTriple.isMIPS64(); bool IsArmOrThumb = TargetTriple.isARM() || TargetTriple.isThumb(); bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64; + bool IsLoongArch64 = TargetTriple.getArch() == Triple::loongarch64; bool IsRISCV64 = TargetTriple.getArch() == Triple::riscv64; bool IsWindows = TargetTriple.isOSWindows(); bool IsFuchsia = TargetTriple.isOSFuchsia(); @@ -554,6 +564,8 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, Mapping.Offset = kDynamicShadowSentinel; else if (IsAArch64) Mapping.Offset = kAArch64_ShadowOffset64; + else if (IsLoongArch64) + Mapping.Offset = kLoongArch64_ShadowOffset64; else if (IsRISCV64) Mapping.Offset = kRISCV64_ShadowOffset64; else if (IsAMDGPU) @@ -572,12 +584,12 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, } // OR-ing shadow offset if more efficient (at least on x86) if the offset - // is a power of two, but on ppc64 we have to use add since the shadow - // offset is not necessary 1/8-th of the address space. 
On SystemZ, - // we could OR the constant in a single instruction, but it's more + // is a power of two, but on ppc64 and loongarch64 we have to use add since + // the shadow offset is not necessarily 1/8-th of the address space. On + // SystemZ, we could OR the constant in a single instruction, but it's more // efficient to load it once and use indexed addressing. Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ && !IsPS && - !IsRISCV64 && + !IsRISCV64 && !IsLoongArch64 && !(Mapping.Offset & (Mapping.Offset - 1)) && Mapping.Offset != kDynamicShadowSentinel; bool IsAndroidWithIfuncSupport = @@ -707,7 +719,7 @@ struct AddressSanitizer { private: friend struct FunctionStackPoisoner; - void initializeCallbacks(Module &M); + void initializeCallbacks(Module &M, const TargetLibraryInfo *TLI); bool LooksLikeCodeInBug11395(Instruction *I); bool GlobalIsLinkerInitialized(GlobalVariable *G); @@ -766,15 +778,20 @@ class ModuleAddressSanitizer { public: ModuleAddressSanitizer(Module &M, bool CompileKernel = false, bool Recover = false, bool UseGlobalsGC = true, - bool UseOdrIndicator = false, - AsanDtorKind DestructorKind = AsanDtorKind::Global) + bool UseOdrIndicator = true, + AsanDtorKind DestructorKind = AsanDtorKind::Global, + AsanCtorKind ConstructorKind = AsanCtorKind::Global) : CompileKernel(ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan : CompileKernel), Recover(ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC && !this->CompileKernel), // Enable aliases as they should have no downside with ODR indicators. - UsePrivateAlias(UseOdrIndicator || ClUsePrivateAlias), - UseOdrIndicator(UseOdrIndicator || ClUseOdrIndicator), + UsePrivateAlias(ClUsePrivateAlias.getNumOccurrences() > 0 + ? ClUsePrivateAlias + : UseOdrIndicator), + UseOdrIndicator(ClUseOdrIndicator.getNumOccurrences() > 0 + ? ClUseOdrIndicator + : UseOdrIndicator), // Not a typo: ClWithComdat is almost completely pointless without // ClUseGlobalsGC (because then it only works on modules without // globals, which are rare); it is a prerequisite for ClUseGlobalsGC; @@ -783,7 +800,8 @@ public: // ClWithComdat and ClUseGlobalsGC unless the frontend says it's ok to // do globals-gc. 
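For the shadow-mapping comment above: the instrumentation computes Shadow = (Mem >> Scale) + Offset, and when the offset is a power of two and the shifted application address never overlaps it, the add can be emitted as an or. A small numeric check with the AArch64 offset; per the updated comment, ppc64 and the newly added loongarch64 keep the plain add because their offsets are not necessarily 1/8-th of the address space.

#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t Scale = 3;               // 8 application bytes per shadow byte
  const uint64_t Offset = 1ULL << 36;     // kAArch64_ShadowOffset64
  const uint64_t Addr = 0x12345678ULL;    // a low application address
  uint64_t ShadowAdd = (Addr >> Scale) + Offset;
  uint64_t ShadowOr = (Addr >> Scale) | Offset;
  // When the shifted address shares no bits with the power-of-two offset, the
  // or and the add agree; the sanitizer's memory layout is what makes that
  // assumption valid on the targets that take the OrShadowOffset path.
  std::printf("add=0x%llx or=0x%llx equal=%d\n", (unsigned long long)ShadowAdd,
              (unsigned long long)ShadowOr, (int)(ShadowAdd == ShadowOr));
}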
UseCtorComdat(UseGlobalsGC && ClWithComdat && !this->CompileKernel), - DestructorKind(DestructorKind) { + DestructorKind(DestructorKind), + ConstructorKind(ConstructorKind) { C = &(M.getContext()); int LongSize = M.getDataLayout().getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); @@ -841,6 +859,7 @@ private: bool UseOdrIndicator; bool UseCtorComdat; AsanDtorKind DestructorKind; + AsanCtorKind ConstructorKind; Type *IntptrTy; LLVMContext *C; Triple TargetTriple; @@ -1110,9 +1129,9 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { } // end anonymous namespace -void ModuleAddressSanitizerPass::printPipeline( +void AddressSanitizerPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { - static_cast<PassInfoMixin<ModuleAddressSanitizerPass> *>(this)->printPipeline( + static_cast<PassInfoMixin<AddressSanitizerPass> *>(this)->printPipeline( OS, MapClassName2PassName); OS << "<"; if (Options.CompileKernel) @@ -1120,17 +1139,20 @@ void ModuleAddressSanitizerPass::printPipeline( OS << ">"; } -ModuleAddressSanitizerPass::ModuleAddressSanitizerPass( +AddressSanitizerPass::AddressSanitizerPass( const AddressSanitizerOptions &Options, bool UseGlobalGC, - bool UseOdrIndicator, AsanDtorKind DestructorKind) + bool UseOdrIndicator, AsanDtorKind DestructorKind, + AsanCtorKind ConstructorKind) : Options(Options), UseGlobalGC(UseGlobalGC), - UseOdrIndicator(UseOdrIndicator), DestructorKind(DestructorKind) {} + UseOdrIndicator(UseOdrIndicator), DestructorKind(DestructorKind), + ConstructorKind(ClConstructorKind) {} -PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M, - ModuleAnalysisManager &MAM) { +PreservedAnalyses AddressSanitizerPass::run(Module &M, + ModuleAnalysisManager &MAM) { ModuleAddressSanitizer ModuleSanitizer(M, Options.CompileKernel, Options.Recover, UseGlobalGC, - UseOdrIndicator, DestructorKind); + UseOdrIndicator, DestructorKind, + ConstructorKind); bool Modified = false; auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); const StackSafetyGlobalInfo *const SSGI = @@ -1143,7 +1165,15 @@ PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M, Modified |= FunctionSanitizer.instrumentFunction(F, &TLI); } Modified |= ModuleSanitizer.instrumentModule(M); - return Modified ? PreservedAnalyses::none() : PreservedAnalyses::all(); + if (!Modified) + return PreservedAnalyses::all(); + + PreservedAnalyses PA = PreservedAnalyses::none(); + // GlobalsAA is considered stateless and does not get invalidated unless + // explicitly invalidated; PreservedAnalyses::none() is not enough. Sanitizers + // make changes that require GlobalsAA to be invalidated. + PA.abandon<GlobalsAA>(); + return PA; } static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { @@ -1241,7 +1271,7 @@ bool AddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { } bool AddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { - // Instrument acesses from different address spaces only for AMDGPU. + // Instrument accesses from different address spaces only for AMDGPU. 
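The constructor changes above flip UsePrivateAlias and UseOdrIndicator to default-on while still letting an explicit -asan-use-private-alias / -asan-use-odr-indicator setting win; that is what the getNumOccurrences() checks implement. A toy model of that resolution order (illustrative only, not the cl::opt machinery):

#include <cstdio>
#include <optional>

// Explicitly passed command-line value wins; otherwise the constructor
// argument (which now defaults to true) is used.
static bool resolve(std::optional<bool> FromCommandLine, bool FromCtor) {
  return FromCommandLine.has_value() ? *FromCommandLine : FromCtor;
}

int main() {
  std::printf("%d %d %d\n",
              (int)resolve(std::nullopt, true),  // new default: enabled
              (int)resolve(false, true),         // explicit =0 still disables
              (int)resolve(true, false));        // explicit =1 still enables
}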
Type *PtrTy = cast<PointerType>(Ptr->getType()->getScalarType()); if (PtrTy->getPointerAddressSpace() != 0 && !(TargetTriple.isAMDGPU() && !isUnsupportedAMDGPUAddrspace(Ptr))) @@ -1288,12 +1318,13 @@ void AddressSanitizer::getInterestingMemoryOperands( if (!ClInstrumentAtomics || ignoreAccess(I, RMW->getPointerOperand())) return; Interesting.emplace_back(I, RMW->getPointerOperandIndex(), true, - RMW->getValOperand()->getType(), None); + RMW->getValOperand()->getType(), std::nullopt); } else if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) { if (!ClInstrumentAtomics || ignoreAccess(I, XCHG->getPointerOperand())) return; Interesting.emplace_back(I, XCHG->getPointerOperandIndex(), true, - XCHG->getCompareOperand()->getType(), None); + XCHG->getCompareOperand()->getType(), + std::nullopt); } else if (auto CI = dyn_cast<CallInst>(I)) { if (CI->getIntrinsicID() == Intrinsic::masked_load || CI->getIntrinsicID() == Intrinsic::masked_store) { @@ -1555,7 +1586,7 @@ Instruction *AddressSanitizer::instrumentAMDGPUAddress( Value *IsShared = IRB.CreateCall(AMDGPUAddressShared, {AddrLong}); Value *IsPrivate = IRB.CreateCall(AMDGPUAddressPrivate, {AddrLong}); Value *IsSharedOrPrivate = IRB.CreateOr(IsShared, IsPrivate); - Value *Cmp = IRB.CreateICmpNE(IRB.getTrue(), IsSharedOrPrivate); + Value *Cmp = IRB.CreateNot(IsSharedOrPrivate); Value *AddrSpaceZeroLanding = SplitBlockAndInsertIfThen(Cmp, InsertBefore, false); InsertBefore = cast<Instruction>(AddrSpaceZeroLanding); @@ -1603,11 +1634,10 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, IntegerType::get(*C, std::max(8U, TypeSize >> Mapping.Scale)); Type *ShadowPtrTy = PointerType::get(ShadowTy, 0); Value *ShadowPtr = memToShadow(AddrLong, IRB); - Value *CmpVal = Constant::getNullValue(ShadowTy); Value *ShadowValue = IRB.CreateLoad(ShadowTy, IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); - Value *Cmp = IRB.CreateICmpNE(ShadowValue, CmpVal); + Value *Cmp = IRB.CreateIsNotNull(ShadowValue); size_t Granularity = 1ULL << Mapping.Scale; Instruction *CrashTerm = nullptr; @@ -1675,7 +1705,7 @@ void ModuleAddressSanitizer::poisonOneInitializer(Function &GlobalInit, IRB.CreateCall(AsanPoisonGlobals, ModuleNameAddr); // Add calls to unpoison all globals before each return instruction. - for (auto &BB : GlobalInit.getBasicBlockList()) + for (auto &BB : GlobalInit) if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) CallInst::Create(AsanUnpoisonGlobals, "", RI); } @@ -1742,7 +1772,7 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const { // - Need to poison all copies, not just the main thread's one. if (G->isThreadLocal()) return false; // For now, just ignore this Global if the alignment is large. - if (G->getAlignment() > getMinRedzoneSizeForGlobal()) return false; + if (G->getAlign() && *G->getAlign() > getMinRedzoneSizeForGlobal()) return false; // For non-COFF targets, only instrument globals known to be defined by this // TU. @@ -2078,7 +2108,8 @@ void ModuleAddressSanitizer::InstrumentGlobalsELF( StopELFMetadata->setVisibility(GlobalVariable::HiddenVisibility); // Create a call to register the globals with the runtime. 
- IRB.CreateCall(AsanRegisterElfGlobals, + if (ConstructorKind == AsanCtorKind::Global) + IRB.CreateCall(AsanRegisterElfGlobals, {IRB.CreatePointerCast(RegisteredFlag, IntptrTy), IRB.CreatePointerCast(StartELFMetadata, IntptrTy), IRB.CreatePointerCast(StopELFMetadata, IntptrTy)}); @@ -2141,7 +2172,8 @@ void ModuleAddressSanitizer::InstrumentGlobalsMachO( ConstantInt::get(IntptrTy, 0), kAsanGlobalsRegisteredFlagName); RegisteredFlag->setVisibility(GlobalVariable::HiddenVisibility); - IRB.CreateCall(AsanRegisterImageGlobals, + if (ConstructorKind == AsanCtorKind::Global) + IRB.CreateCall(AsanRegisterImageGlobals, {IRB.CreatePointerCast(RegisteredFlag, IntptrTy)}); // We also need to unregister globals at the end, e.g., when a shared library @@ -2170,7 +2202,8 @@ void ModuleAddressSanitizer::InstrumentGlobalsWithMetadataArray( if (Mapping.Scale > 3) AllGlobals->setAlignment(Align(1ULL << Mapping.Scale)); - IRB.CreateCall(AsanRegisterGlobals, + if (ConstructorKind == AsanCtorKind::Global) + IRB.CreateCall(AsanRegisterGlobals, {IRB.CreatePointerCast(AllGlobals, IntptrTy), ConstantInt::get(IntptrTy, N)}); @@ -2247,11 +2280,12 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, if (G->hasSanitizerMetadata()) MD = G->getSanitizerMetadata(); - // TODO: Symbol names in the descriptor can be demangled by the runtime - // library. This could save ~0.4% of VM size for a private large binary. - std::string NameForGlobal = llvm::demangle(G->getName().str()); + // The runtime library tries demangling symbol names in the descriptor but + // functionality like __cxa_demangle may be unavailable (e.g. + // -static-libstdc++). So we demangle the symbol names here. + std::string NameForGlobal = G->getName().str(); GlobalVariable *Name = - createPrivateGlobalForString(M, NameForGlobal, + createPrivateGlobalForString(M, llvm::demangle(NameForGlobal), /*AllowMerging*/ true, kAsanGenPrefix); Type *Ty = G->getValueType(); @@ -2398,7 +2432,7 @@ ModuleAddressSanitizer::getRedzoneSizeForGlobal(uint64_t SizeInBytes) const { RZ = MinRZ - SizeInBytes; } else { // Calculate RZ, where MinRZ <= RZ <= MaxRZ, and RZ ~ 1/4 * SizeInBytes. - RZ = std::max(MinRZ, std::min(kMaxRZ, (SizeInBytes / MinRZ / 4) * MinRZ)); + RZ = std::clamp((SizeInBytes / MinRZ / 4) * MinRZ, MinRZ, kMaxRZ); // Round up to multiple of MinRZ. if (SizeInBytes % MinRZ) @@ -2425,24 +2459,32 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) { // Create a module constructor. A destructor is created lazily because not all // platforms, and not all modules need it. - if (CompileKernel) { - // The kernel always builds with its own runtime, and therefore does not - // need the init and version check calls. - AsanCtorFunction = createSanitizerCtor(M, kAsanModuleCtorName); - } else { - std::string AsanVersion = std::to_string(GetAsanVersion(M)); - std::string VersionCheckName = - ClInsertVersionCheck ? (kAsanVersionCheckNamePrefix + AsanVersion) : ""; - std::tie(AsanCtorFunction, std::ignore) = - createSanitizerCtorAndInitFunctions(M, kAsanModuleCtorName, - kAsanInitName, /*InitArgTypes=*/{}, - /*InitArgs=*/{}, VersionCheckName); + if (ConstructorKind == AsanCtorKind::Global) { + if (CompileKernel) { + // The kernel always builds with its own runtime, and therefore does not + // need the init and version check calls. + AsanCtorFunction = createSanitizerCtor(M, kAsanModuleCtorName); + } else { + std::string AsanVersion = std::to_string(GetAsanVersion(M)); + std::string VersionCheckName = + ClInsertVersionCheck ? 
(kAsanVersionCheckNamePrefix + AsanVersion) : ""; + std::tie(AsanCtorFunction, std::ignore) = + createSanitizerCtorAndInitFunctions(M, kAsanModuleCtorName, + kAsanInitName, /*InitArgTypes=*/{}, + /*InitArgs=*/{}, VersionCheckName); + } } bool CtorComdat = true; if (ClGlobals) { - IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator()); - InstrumentGlobals(IRB, M, &CtorComdat); + assert(AsanCtorFunction || ConstructorKind == AsanCtorKind::None); + if (AsanCtorFunction) { + IRBuilder<> IRB(AsanCtorFunction->getEntryBlock().getTerminator()); + InstrumentGlobals(IRB, M, &CtorComdat); + } else { + IRBuilder<> IRB(*C); + InstrumentGlobals(IRB, M, &CtorComdat); + } } const uint64_t Priority = GetCtorAndDtorPriority(TargetTriple); @@ -2451,14 +2493,17 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) { // (1) global instrumentation is not TU-specific // (2) target is ELF. if (UseCtorComdat && TargetTriple.isOSBinFormatELF() && CtorComdat) { - AsanCtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleCtorName)); - appendToGlobalCtors(M, AsanCtorFunction, Priority, AsanCtorFunction); + if (AsanCtorFunction) { + AsanCtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleCtorName)); + appendToGlobalCtors(M, AsanCtorFunction, Priority, AsanCtorFunction); + } if (AsanDtorFunction) { AsanDtorFunction->setComdat(M.getOrInsertComdat(kAsanModuleDtorName)); appendToGlobalDtors(M, AsanDtorFunction, Priority, AsanDtorFunction); } } else { - appendToGlobalCtors(M, AsanCtorFunction, Priority); + if (AsanCtorFunction) + appendToGlobalCtors(M, AsanCtorFunction, Priority); if (AsanDtorFunction) appendToGlobalDtors(M, AsanDtorFunction, Priority); } @@ -2466,7 +2511,7 @@ bool ModuleAddressSanitizer::instrumentModule(Module &M) { return true; } -void AddressSanitizer::initializeCallbacks(Module &M) { +void AddressSanitizer::initializeCallbacks(Module &M, const TargetLibraryInfo *TLI) { IRBuilder<> IRB(*C); // Create __asan_report* callbacks. // IsWrite, TypeSize and Exp are encoded in the function name. 
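The redzone computation in getRedzoneSizeForGlobal a few hunks above now spells the max/min composition as std::clamp. The two forms are interchangeable whenever the lower bound does not exceed the upper bound, which a quick check confirms; the sizes and bounds here are illustrative values, not the constants used by the pass.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t MinRZ = 32, MaxRZ = 1ULL << 18;   // illustrative bounds
  for (uint64_t SizeInBytes : {64ULL, 4096ULL, 1ULL << 22}) {
    uint64_t Scaled = (SizeInBytes / MinRZ / 4) * MinRZ;
    assert(std::clamp(Scaled, MinRZ, MaxRZ) ==
           std::max(MinRZ, std::min(MaxRZ, Scaled)));
  }
  std::puts("clamp matches the old max/min spelling");
}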
@@ -2478,18 +2523,24 @@ void AddressSanitizer::initializeCallbacks(Module &M) { SmallVector<Type *, 3> Args2 = {IntptrTy, IntptrTy}; SmallVector<Type *, 2> Args1{1, IntptrTy}; + AttributeList AL2; + AttributeList AL1; if (Exp) { Type *ExpType = Type::getInt32Ty(*C); Args2.push_back(ExpType); Args1.push_back(ExpType); + if (auto AK = TLI->getExtAttrForI32Param(false)) { + AL2 = AL2.addParamAttribute(*C, 2, AK); + AL1 = AL1.addParamAttribute(*C, 1, AK); + } } AsanErrorCallbackSized[AccessIsWrite][Exp] = M.getOrInsertFunction( kAsanReportErrorTemplate + ExpStr + TypeStr + "_n" + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args2, false)); + FunctionType::get(IRB.getVoidTy(), Args2, false), AL2); AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = M.getOrInsertFunction( ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args2, false)); + FunctionType::get(IRB.getVoidTy(), Args2, false), AL2); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { @@ -2497,12 +2548,12 @@ void AddressSanitizer::initializeCallbacks(Module &M) { AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = M.getOrInsertFunction( kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args1, false)); + FunctionType::get(IRB.getVoidTy(), Args1, false), AL1); AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] = M.getOrInsertFunction( ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args1, false)); + FunctionType::get(IRB.getVoidTy(), Args1, false), AL1); } } } @@ -2518,6 +2569,7 @@ void AddressSanitizer::initializeCallbacks(Module &M) { IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); AsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset", + TLI->getAttrList(C, {1}, /*Signed=*/false), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); @@ -2644,7 +2696,7 @@ bool AddressSanitizer::instrumentFunction(Function &F, LLVM_DEBUG(dbgs() << "ASAN instrumenting:\n" << F << "\n"); - initializeCallbacks(*F.getParent()); + initializeCallbacks(*F.getParent(), TLI); FunctionStateRAII CleanupObj(this); @@ -2733,7 +2785,7 @@ bool AddressSanitizer::instrumentFunction(Function &F, F.getParent()->getDataLayout()); FunctionModified = true; } - for (auto Inst : IntrinToInstrument) { + for (auto *Inst : IntrinToInstrument) { if (!suppressInstrumentationSiteForDebug(NumInstrumented)) instrumentMemIntrinsic(Inst); FunctionModified = true; @@ -2744,12 +2796,12 @@ bool AddressSanitizer::instrumentFunction(Function &F, // We must unpoison the stack before NoReturn calls (throw, _exit, etc). // See e.g. 
https://github.com/google/sanitizers/issues/37 - for (auto CI : NoReturnCalls) { + for (auto *CI : NoReturnCalls) { IRBuilder<> IRB(CI); IRB.CreateCall(AsanHandleNoReturnFunc, {}); } - for (auto Inst : PointerComparisonsOrSubtracts) { + for (auto *Inst : PointerComparisonsOrSubtracts) { instrumentPointerComparisonOrSubtraction(Inst); FunctionModified = true; } @@ -2800,7 +2852,8 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), IntptrTy, IntptrTy); } - for (size_t Val : {0x00, 0xf1, 0xf2, 0xf3, 0xf5, 0xf8}) { + for (size_t Val : {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0xf1, 0xf2, + 0xf3, 0xf5, 0xf8}) { std::ostringstream Name; Name << kAsanSetShadowPrefix; Name << std::setw(2) << std::setfill('0') << std::hex << Val; @@ -3342,7 +3395,8 @@ void FunctionStackPoisoner::processStaticAllocas() { } // We are done. Remove the old unused alloca instructions. - for (auto AI : AllocaVec) AI->eraseFromParent(); + for (auto *AI : AllocaVec) + AI->eraseFromParent(); } void FunctionStackPoisoner::poisonAlloca(Value *V, uint64_t Size, diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index 1eadafb4e4b4..8b1d39ad412f 100644 --- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -146,6 +146,7 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, const DataLayout &DL = F.getParent()->getDataLayout(); ObjectSizeOpts EvalOpts; EvalOpts.RoundToAlign = true; + EvalOpts.EvalMode = ObjectSizeOpts::Mode::ExactUnderlyingSizeAndOffset; ObjectSizeOffsetEvaluator ObjSizeEval(DL, &TLI, F.getContext(), EvalOpts); // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory @@ -221,35 +222,3 @@ PreservedAnalyses BoundsCheckingPass::run(Function &F, FunctionAnalysisManager & return PreservedAnalyses::none(); } - -namespace { -struct BoundsCheckingLegacyPass : public FunctionPass { - static char ID; - - BoundsCheckingLegacyPass() : FunctionPass(ID) { - initializeBoundsCheckingLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - return addBoundsChecking(F, TLI, SE); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - } -}; -} // namespace - -char BoundsCheckingLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(BoundsCheckingLegacyPass, "bounds-checking", - "Run-time bounds checking", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(BoundsCheckingLegacyPass, "bounds-checking", - "Run-time bounds checking", false, false) - -FunctionPass *llvm::createBoundsCheckingLegacyPass() { - return new BoundsCheckingLegacyPass(); -} diff --git a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp index 27107f46ed92..1c630e9ee424 100644 --- a/llvm/lib/Transforms/Instrumentation/CGProfile.cpp +++ b/llvm/lib/Transforms/Instrumentation/CGProfile.cpp @@ -18,6 +18,7 @@ #include "llvm/InitializePasses.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Transforms/Instrumentation.h" +#include <optional> using namespace llvm; @@ -73,7 +74,7 @@ static bool runCGProfilePass( continue; TargetTransformInfo &TTI = GetTTI(F); for 
(auto &BB : F) { - Optional<uint64_t> BBCount = BFI.getBlockProfileCount(&BB); + std::optional<uint64_t> BBCount = BFI.getBlockProfileCount(&BB); if (!BBCount) continue; for (auto &I : BB) { diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index adc007dacae4..a072ba278fce 100644 --- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" @@ -38,6 +39,7 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ValueMapper.h" +#include <optional> #include <set> #include <sstream> @@ -47,6 +49,9 @@ using namespace llvm; #define CHR_DEBUG(X) LLVM_DEBUG(X) +static cl::opt<bool> DisableCHR("disable-chr", cl::init(false), cl::Hidden, + cl::desc("Disable CHR for all functions")); + static cl::opt<bool> ForceCHR("force-chr", cl::init(false), cl::Hidden, cl::desc("Apply CHR for all functions")); @@ -66,6 +71,10 @@ static cl::opt<std::string> CHRFunctionList( "chr-function-list", cl::init(""), cl::Hidden, cl::desc("Specify file to retrieve the list of functions to apply CHR to")); +static cl::opt<unsigned> CHRDupThreshsold( + "chr-dup-threshold", cl::init(3), cl::Hidden, + cl::desc("Max number of duplications by CHR for a region")); + static StringSet<> CHRModules; static StringSet<> CHRFunctions; @@ -339,23 +348,27 @@ class CHR { BasicBlock *EntryBlock, BasicBlock *NewEntryBlock, ValueToValueMapTy &VMap); - void fixupBranchesAndSelects(CHRScope *Scope, - BasicBlock *PreEntryBlock, - BranchInst *MergedBR, - uint64_t ProfileCount); - void fixupBranch(Region *R, - CHRScope *Scope, - IRBuilder<> &IRB, + void fixupBranchesAndSelects(CHRScope *Scope, BasicBlock *PreEntryBlock, + BranchInst *MergedBR, uint64_t ProfileCount); + void fixupBranch(Region *R, CHRScope *Scope, IRBuilder<> &IRB, Value *&MergedCondition, BranchProbability &CHRBranchBias); - void fixupSelect(SelectInst* SI, - CHRScope *Scope, - IRBuilder<> &IRB, + void fixupSelect(SelectInst *SI, CHRScope *Scope, IRBuilder<> &IRB, Value *&MergedCondition, BranchProbability &CHRBranchBias); void addToMergedCondition(bool IsTrueBiased, Value *Cond, - Instruction *BranchOrSelect, - CHRScope *Scope, - IRBuilder<> &IRB, - Value *&MergedCondition); + Instruction *BranchOrSelect, CHRScope *Scope, + IRBuilder<> &IRB, Value *&MergedCondition); + unsigned getRegionDuplicationCount(const Region *R) { + unsigned Count = 0; + // Find out how many times region R is cloned. Note that if the parent + // of R is cloned, R is also cloned, but R's clone count is not updated + // from the clone of the parent. We need to accumulate all the counts + // from the ancestors to get the clone count. + while (R) { + Count += DuplicationCount[R]; + R = R->getParent(); + } + return Count; + } Function &F; BlockFrequencyInfo &BFI; @@ -379,6 +392,8 @@ class CHR { DenseMap<SelectInst *, BranchProbability> SelectBiasMap; // All the scopes. DenseSet<CHRScope *> Scopes; + // This map records how many times this region is cloned.
+ DenseMap<const Region *, unsigned> DuplicationCount; }; } // end anonymous namespace @@ -396,7 +411,10 @@ raw_ostream &operator<<(raw_ostream &OS, const CHRScope &Scope) { return OS; } -static bool shouldApply(Function &F, ProfileSummaryInfo& PSI) { +static bool shouldApply(Function &F, ProfileSummaryInfo &PSI) { + if (DisableCHR) + return false; + if (ForceCHR) return true; @@ -406,7 +424,6 @@ static bool shouldApply(Function &F, ProfileSummaryInfo& PSI) { return CHRFunctions.count(F.getName()); } - assert(PSI.hasProfileSummary() && "Empty PSI?"); return PSI.isFunctionEntryHot(&F); } @@ -462,7 +479,7 @@ static bool isHoistableInstructionType(Instruction *I) { static bool isHoistable(Instruction *I, DominatorTree &DT) { if (!isHoistableInstructionType(I)) return false; - return isSafeToSpeculativelyExecute(I, nullptr, &DT); + return isSafeToSpeculativelyExecute(I, nullptr, nullptr, &DT); } // Recursively traverse the use-def chains of the given value and return a set @@ -559,32 +576,26 @@ checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT, return true; } -// Returns true and sets the true probability and false probability of an -// MD_prof metadata if it's well-formed. -static bool checkMDProf(MDNode *MD, BranchProbability &TrueProb, - BranchProbability &FalseProb) { - if (!MD) return false; - MDString *MDName = cast<MDString>(MD->getOperand(0)); - if (MDName->getString() != "branch_weights" || - MD->getNumOperands() != 3) - return false; - ConstantInt *TrueWeight = mdconst::extract<ConstantInt>(MD->getOperand(1)); - ConstantInt *FalseWeight = mdconst::extract<ConstantInt>(MD->getOperand(2)); - if (!TrueWeight || !FalseWeight) +// Constructs the true and false branch probabilities if the instruction has +// valid branch weights. Returns true when this was successful, false otherwise. +static bool extractBranchProbabilities(Instruction *I, + BranchProbability &TrueProb, + BranchProbability &FalseProb) { + uint64_t TrueWeight; + uint64_t FalseWeight; + if (!extractBranchWeights(*I, TrueWeight, FalseWeight)) return false; - uint64_t TrueWt = TrueWeight->getValue().getZExtValue(); - uint64_t FalseWt = FalseWeight->getValue().getZExtValue(); - uint64_t SumWt = TrueWt + FalseWt; + uint64_t SumWeight = TrueWeight + FalseWeight; - assert(SumWt >= TrueWt && SumWt >= FalseWt && + assert(SumWeight >= TrueWeight && SumWeight >= FalseWeight && "Overflow calculating branch probabilities."); // Guard against 0-to-0 branch weights to avoid a division-by-zero crash.
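// Illustrative sketch, not part of the patch above: a standalone example of the
// weight-to-probability conversion that extractBranchProbabilities() performs,
// assuming hypothetical !prof branch weights of {96, 32}. Only
// BranchProbability::getBranchProbability() is used, as in the code above.
#include "llvm/Support/BranchProbability.h"
#include <cstdint>
static void branchProbabilityExample() {
  uint64_t TrueWeight = 96, FalseWeight = 32;    // hypothetical !prof weights
  uint64_t SumWeight = TrueWeight + FalseWeight; // 128
  // 96/128 -> 75.00%, 32/128 -> 25.00%; CHR treats a branch or select as
  // biased once one side's probability crosses its bias threshold.
  llvm::BranchProbability TrueProb =
      llvm::BranchProbability::getBranchProbability(TrueWeight, SumWeight);
  llvm::BranchProbability FalseProb =
      llvm::BranchProbability::getBranchProbability(FalseWeight, SumWeight);
  (void)TrueProb;
  (void)FalseProb;
}
// End of illustrative sketch.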
- if (SumWt == 0) + if (SumWeight == 0) return false; - TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt); - FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt); + TrueProb = BranchProbability::getBranchProbability(TrueWeight, SumWeight); + FalseProb = BranchProbability::getBranchProbability(FalseWeight, SumWeight); return true; } @@ -623,8 +634,7 @@ static bool checkBiasedBranch(BranchInst *BI, Region *R, if (!BI->isConditional()) return false; BranchProbability ThenProb, ElseProb; - if (!checkMDProf(BI->getMetadata(LLVMContext::MD_prof), - ThenProb, ElseProb)) + if (!extractBranchProbabilities(BI, ThenProb, ElseProb)) return false; BasicBlock *IfThen = BI->getSuccessor(0); BasicBlock *IfElse = BI->getSuccessor(1); @@ -653,8 +663,7 @@ static bool checkBiasedSelect( DenseSet<SelectInst *> &FalseBiasedSelectsGlobal, DenseMap<SelectInst *, BranchProbability> &SelectBiasMap) { BranchProbability TrueProb, FalseProb; - if (!checkMDProf(SI->getMetadata(LLVMContext::MD_prof), - TrueProb, FalseProb)) + if (!extractBranchProbabilities(SI, TrueProb, FalseProb)) return false; CHR_DEBUG(dbgs() << "SI " << *SI << " "); CHR_DEBUG(dbgs() << "TrueProb " << TrueProb << " "); @@ -1667,11 +1676,32 @@ void CHR::transformScopes(CHRScope *Scope, DenseSet<PHINode *> &TrivialPHIs) { CHR_DEBUG(dbgs() << "transformScopes " << *Scope << "\n"); assert(Scope->RegInfos.size() >= 1 && "Should have at least one Region"); + + for (RegInfo &RI : Scope->RegInfos) { + const Region *R = RI.R; + unsigned Duplication = getRegionDuplicationCount(R); + CHR_DEBUG(dbgs() << "Dup count for R=" << R << " is " << Duplication + << "\n"); + if (Duplication >= CHRDupThreshsold) { + CHR_DEBUG(dbgs() << "Reached the dup threshold of " << Duplication + << " for this region"); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "DupThresholdReached", + R->getEntry()->getTerminator()) + << "Reached the duplication threshold for the region"; + }); + return; + } + } + for (RegInfo &RI : Scope->RegInfos) { + DuplicationCount[RI.R]++; + } + Region *FirstRegion = Scope->RegInfos[0].R; BasicBlock *EntryBlock = FirstRegion->getEntry(); Region *LastRegion = Scope->RegInfos[Scope->RegInfos.size() - 1].R; BasicBlock *ExitBlock = LastRegion->getExit(); - Optional<uint64_t> ProfileCount = BFI.getBlockProfileCount(EntryBlock); + std::optional<uint64_t> ProfileCount = BFI.getBlockProfileCount(EntryBlock); if (ExitBlock) { // Insert a trivial phi at the exit block (where the CHR hot path and the @@ -1753,13 +1783,12 @@ void CHR::cloneScopeBlocks(CHRScope *Scope, // Place the cloned blocks right after the original blocks (right before the // exit block of.) if (ExitBlock) - F.getBasicBlockList().splice(ExitBlock->getIterator(), - F.getBasicBlockList(), - NewBlocks[0]->getIterator(), F.end()); + F.splice(ExitBlock->getIterator(), &F, NewBlocks[0]->getIterator(), + F.end()); // Update the cloned blocks/instructions to refer to themselves. 
- for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) - for (Instruction &I : *NewBlocks[i]) + for (BasicBlock *NewBB : NewBlocks) + for (Instruction &I : *NewBB) RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); @@ -1801,7 +1830,7 @@ BranchInst *CHR::createMergedBranch(BasicBlock *PreEntryBlock, BranchInst *NewBR = BranchInst::Create(NewEntryBlock, cast<BasicBlock>(VMap[NewEntryBlock]), ConstantInt::getTrue(F.getContext())); - PreEntryBlock->getInstList().push_back(NewBR); + NewBR->insertInto(PreEntryBlock, PreEntryBlock->end()); assert(NewEntryBlock->getSinglePredecessor() == EntryBlock && "NewEntryBlock's only pred must be EntryBlock"); return NewBR; @@ -1983,7 +2012,7 @@ bool CHR::run() { findScopes(AllScopes); CHR_DEBUG(dumpScopes(AllScopes, "All scopes")); - // Split the scopes if 1) the conditiona values of the biased + // Split the scopes if 1) the conditional values of the biased // branches/selects of the inner/lower scope can't be hoisted up to the // outermost/uppermost scope entry, or 2) the condition values of the biased // branches/selects in a scope (including subscopes) don't share at least diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 6815688827d2..e9614b48fde7 100644 --- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -63,13 +63,14 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/iterator.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" @@ -222,6 +223,14 @@ static cl::opt<bool> ClConditionalCallbacks( cl::desc("Insert calls to callback functions on conditionals."), cl::Hidden, cl::init(false)); +// Experimental feature that inserts callbacks for data reaching a function, +// either via function arguments or loads. +// This must be true for dfsan_set_reaches_function_callback() to have effect. +static cl::opt<bool> ClReachesFunctionCallbacks( + "dfsan-reaches-function-callbacks", + cl::desc("Insert calls to callback functions on data reaching a function."), + cl::Hidden, cl::init(false)); + // Controls whether the pass tracks the control flow of select instructions.
static cl::opt<bool> ClTrackSelectControlFlow( "dfsan-track-select-control-flow", @@ -278,14 +287,23 @@ struct MemoryMapParams { } // end anonymous namespace +// NOLINTBEGIN(readability-identifier-naming) +// aarch64 Linux +const MemoryMapParams Linux_AArch64_MemoryMapParams = { + 0, // AndMask (not used) + 0x0B00000000000, // XorMask + 0, // ShadowBase (not used) + 0x0200000000000, // OriginBase +}; + // x86_64 Linux -// NOLINTNEXTLINE(readability-identifier-naming) -static const MemoryMapParams Linux_X86_64_MemoryMapParams = { +const MemoryMapParams Linux_X86_64_MemoryMapParams = { 0, // AndMask (not used) 0x500000000000, // XorMask 0, // ShadowBase (not used) 0x100000000000, // OriginBase }; +// NOLINTEND(readability-identifier-naming) namespace { @@ -386,7 +404,7 @@ transformFunctionAttributes(const TransformedFunction &TransformedFunction, return AttributeList::get(Ctx, CallSiteAttrs.getFnAttrs(), CallSiteAttrs.getRetAttrs(), - llvm::makeArrayRef(ArgumentAttributes)); + llvm::ArrayRef(ArgumentAttributes)); } class DataFlowSanitizer { @@ -445,12 +463,16 @@ class DataFlowSanitizer { FunctionType *DFSanVarargWrapperFnTy; FunctionType *DFSanConditionalCallbackFnTy; FunctionType *DFSanConditionalCallbackOriginFnTy; + FunctionType *DFSanReachesFunctionCallbackFnTy; + FunctionType *DFSanReachesFunctionCallbackOriginFnTy; FunctionType *DFSanCmpCallbackFnTy; FunctionType *DFSanLoadStoreCallbackFnTy; FunctionType *DFSanMemTransferCallbackFnTy; FunctionType *DFSanChainOriginFnTy; FunctionType *DFSanChainOriginIfTaintedFnTy; FunctionType *DFSanMemOriginTransferFnTy; + FunctionType *DFSanMemShadowOriginTransferFnTy; + FunctionType *DFSanMemShadowOriginConditionalExchangeFnTy; FunctionType *DFSanMaybeStoreOriginFnTy; FunctionCallee DFSanUnionLoadFn; FunctionCallee DFSanLoadLabelAndOriginFn; @@ -464,10 +486,14 @@ class DataFlowSanitizer { FunctionCallee DFSanMemTransferCallbackFn; FunctionCallee DFSanConditionalCallbackFn; FunctionCallee DFSanConditionalCallbackOriginFn; + FunctionCallee DFSanReachesFunctionCallbackFn; + FunctionCallee DFSanReachesFunctionCallbackOriginFn; FunctionCallee DFSanCmpCallbackFn; FunctionCallee DFSanChainOriginFn; FunctionCallee DFSanChainOriginIfTaintedFn; FunctionCallee DFSanMemOriginTransferFn; + FunctionCallee DFSanMemShadowOriginTransferFn; + FunctionCallee DFSanMemShadowOriginConditionalExchangeFn; FunctionCallee DFSanMaybeStoreOriginFn; SmallPtrSet<Value *, 16> DFSanRuntimeFunctions; MDNode *ColdCallWeights; @@ -498,7 +524,6 @@ class DataFlowSanitizer { FunctionType *NewFT); void initializeCallbackFunctions(Module &M); void initializeRuntimeFunctions(Module &M); - void injectMetadataGlobals(Module &M); bool initializeModule(Module &M); /// Advances \p OriginAddr to point to the next 32-bit origin and then loads @@ -539,7 +564,8 @@ class DataFlowSanitizer { public: DataFlowSanitizer(const std::vector<std::string> &ABIListFiles); - bool runImpl(Module &M); + bool runImpl(Module &M, + llvm::function_ref<TargetLibraryInfo &(Function &)> GetTLI); }; struct DFSanFunction { @@ -548,6 +574,7 @@ struct DFSanFunction { DominatorTree DT; bool IsNativeABI; bool IsForceZeroLabels; + TargetLibraryInfo &TLI; AllocaInst *LabelReturnAlloca = nullptr; AllocaInst *OriginReturnAlloca = nullptr; DenseMap<Value *, Value *> ValShadowMap; @@ -579,9 +606,9 @@ struct DFSanFunction { DenseMap<Value *, std::set<Value *>> ShadowElements; DFSanFunction(DataFlowSanitizer &DFS, Function *F, bool IsNativeABI, - bool IsForceZeroLabels) + bool IsForceZeroLabels, TargetLibraryInfo &TLI) : DFS(DFS), 
F(F), IsNativeABI(IsNativeABI), - IsForceZeroLabels(IsForceZeroLabels) { + IsForceZeroLabels(IsForceZeroLabels), TLI(TLI) { DT.recalculate(*F); } @@ -666,6 +693,11 @@ struct DFSanFunction { // branch instruction using the given conditional expression. void addConditionalCallbacksIfEnabled(Instruction &I, Value *Condition); + // If ClReachesFunctionCallbacks is enabled, insert a callback for each + // argument and load instruction. + void addReachesFunctionCallbacksIfEnabled(IRBuilder<> &IRB, Instruction &I, + Value *Data); + bool isLookupTableConstant(Value *P); private: @@ -763,6 +795,10 @@ public: void visitAtomicRMWInst(AtomicRMWInst &I); void visitAtomicCmpXchgInst(AtomicCmpXchgInst &I); void visitReturnInst(ReturnInst &RI); + void visitLibAtomicLoad(CallBase &CB); + void visitLibAtomicStore(CallBase &CB); + void visitLibAtomicExchange(CallBase &CB); + void visitLibAtomicCompareExchange(CallBase &CB); void visitCallBase(CallBase &CB); void visitPHINode(PHINode &PN); void visitExtractElementInst(ExtractElementInst &I); @@ -791,8 +827,31 @@ private: void addOriginArguments(Function &F, CallBase &CB, std::vector<Value *> &Args, IRBuilder<> &IRB); + + Value *makeAddAcquireOrderingTable(IRBuilder<> &IRB); + Value *makeAddReleaseOrderingTable(IRBuilder<> &IRB); }; +bool LibAtomicFunction(const Function &F) { + // This is a bit of a hack because TargetLibraryInfo is a function pass. + // The DFSan pass would need to be refactored to be function pass oriented + // (like MSan is) in order to fit together nicely with TargetLibraryInfo. + // We need this check to prevent them from being instrumented, or wrapped. + // Match on name and number of arguments. + if (!F.hasName() || F.isVarArg()) + return false; + switch (F.arg_size()) { + case 4: + return F.getName() == "__atomic_load" || F.getName() == "__atomic_store"; + case 5: + return F.getName() == "__atomic_exchange"; + case 6: + return F.getName() == "__atomic_compare_exchange"; + default: + return false; + } +} + } // end anonymous namespace DataFlowSanitizer::DataFlowSanitizer( @@ -982,13 +1041,55 @@ void DFSanFunction::addConditionalCallbacksIfEnabled(Instruction &I, } IRBuilder<> IRB(&I); Value *CondShadow = getShadow(Condition); + CallInst *CI; if (DFS.shouldTrackOrigins()) { Value *CondOrigin = getOrigin(Condition); - IRB.CreateCall(DFS.DFSanConditionalCallbackOriginFn, - {CondShadow, CondOrigin}); + CI = IRB.CreateCall(DFS.DFSanConditionalCallbackOriginFn, + {CondShadow, CondOrigin}); + } else { + CI = IRB.CreateCall(DFS.DFSanConditionalCallbackFn, {CondShadow}); + } + CI->addParamAttr(0, Attribute::ZExt); +} + +void DFSanFunction::addReachesFunctionCallbacksIfEnabled(IRBuilder<> &IRB, + Instruction &I, + Value *Data) { + if (!ClReachesFunctionCallbacks) { + return; + } + const DebugLoc &dbgloc = I.getDebugLoc(); + Value *DataShadow = collapseToPrimitiveShadow(getShadow(Data), IRB); + ConstantInt *CILine; + llvm::Value *FilePathPtr; + + if (dbgloc.get() == nullptr) { + CILine = llvm::ConstantInt::get(I.getContext(), llvm::APInt(32, 0)); + FilePathPtr = IRB.CreateGlobalStringPtr( + I.getFunction()->getParent()->getSourceFileName()); } else { - IRB.CreateCall(DFS.DFSanConditionalCallbackFn, {CondShadow}); + CILine = llvm::ConstantInt::get(I.getContext(), + llvm::APInt(32, dbgloc.getLine())); + FilePathPtr = + IRB.CreateGlobalStringPtr(dbgloc->getFilename()); } + + llvm::Value *FunctionNamePtr = + IRB.CreateGlobalStringPtr(I.getFunction()->getName()); + + CallInst *CB; + std::vector<Value *> args; + + if (DFS.shouldTrackOrigins()) { + 
Value *DataOrigin = getOrigin(Data); + args = { DataShadow, DataOrigin, FilePathPtr, CILine, FunctionNamePtr }; + CB = IRB.CreateCall(DFS.DFSanReachesFunctionCallbackOriginFn, args); + } else { + args = { DataShadow, FilePathPtr, CILine, FunctionNamePtr }; + CB = IRB.CreateCall(DFS.DFSanReachesFunctionCallbackFn, args); + } + CB->addParamAttr(0, Attribute::ZExt); + CB->setDebugLoc(dbgloc); } Type *DataFlowSanitizer::getShadowTy(Type *OrigTy) { @@ -1020,9 +1121,16 @@ bool DataFlowSanitizer::initializeModule(Module &M) { if (TargetTriple.getOS() != Triple::Linux) report_fatal_error("unsupported operating system"); - if (TargetTriple.getArch() != Triple::x86_64) + switch (TargetTriple.getArch()) { + case Triple::aarch64: + MapParams = &Linux_AArch64_MemoryMapParams; + break; + case Triple::x86_64: + MapParams = &Linux_X86_64_MemoryMapParams; + break; + default: report_fatal_error("unsupported architecture"); - MapParams = &Linux_X86_64_MemoryMapParams; + } Mod = &M; Ctx = &M.getContext(); @@ -1052,8 +1160,8 @@ bool DataFlowSanitizer::initializeModule(Module &M) { Type::getInt8PtrTy(*Ctx), IntptrTy}; DFSanSetLabelFnTy = FunctionType::get(Type::getVoidTy(*Ctx), DFSanSetLabelArgs, /*isVarArg=*/false); - DFSanNonzeroLabelFnTy = - FunctionType::get(Type::getVoidTy(*Ctx), None, /*isVarArg=*/false); + DFSanNonzeroLabelFnTy = FunctionType::get(Type::getVoidTy(*Ctx), std::nullopt, + /*isVarArg=*/false); DFSanVarargWrapperFnTy = FunctionType::get( Type::getVoidTy(*Ctx), Type::getInt8PtrTy(*Ctx), /*isVarArg=*/false); DFSanConditionalCallbackFnTy = @@ -1063,6 +1171,16 @@ bool DataFlowSanitizer::initializeModule(Module &M) { DFSanConditionalCallbackOriginFnTy = FunctionType::get( Type::getVoidTy(*Ctx), DFSanConditionalCallbackOriginArgs, /*isVarArg=*/false); + Type *DFSanReachesFunctionCallbackArgs[4] = {PrimitiveShadowTy, Int8Ptr, + OriginTy, Int8Ptr}; + DFSanReachesFunctionCallbackFnTy = + FunctionType::get(Type::getVoidTy(*Ctx), DFSanReachesFunctionCallbackArgs, + /*isVarArg=*/false); + Type *DFSanReachesFunctionCallbackOriginArgs[5] = { + PrimitiveShadowTy, OriginTy, Int8Ptr, OriginTy, Int8Ptr}; + DFSanReachesFunctionCallbackOriginFnTy = FunctionType::get( + Type::getVoidTy(*Ctx), DFSanReachesFunctionCallbackOriginArgs, + /*isVarArg=*/false); DFSanCmpCallbackFnTy = FunctionType::get(Type::getVoidTy(*Ctx), PrimitiveShadowTy, /*isVarArg=*/false); @@ -1078,6 +1196,15 @@ bool DataFlowSanitizer::initializeModule(Module &M) { Type *DFSanMemOriginTransferArgs[3] = {Int8Ptr, Int8Ptr, IntptrTy}; DFSanMemOriginTransferFnTy = FunctionType::get( Type::getVoidTy(*Ctx), DFSanMemOriginTransferArgs, /*isVarArg=*/false); + Type *DFSanMemShadowOriginTransferArgs[3] = {Int8Ptr, Int8Ptr, IntptrTy}; + DFSanMemShadowOriginTransferFnTy = + FunctionType::get(Type::getVoidTy(*Ctx), DFSanMemShadowOriginTransferArgs, + /*isVarArg=*/false); + Type *DFSanMemShadowOriginConditionalExchangeArgs[5] = { + IntegerType::get(*Ctx, 8), Int8Ptr, Int8Ptr, Int8Ptr, IntptrTy}; + DFSanMemShadowOriginConditionalExchangeFnTy = FunctionType::get( + Type::getVoidTy(*Ctx), DFSanMemShadowOriginConditionalExchangeArgs, + /*isVarArg=*/false); Type *DFSanLoadStoreCallbackArgs[2] = {PrimitiveShadowTy, Int8Ptr}; DFSanLoadStoreCallbackFnTy = FunctionType::get(Type::getVoidTy(*Ctx), DFSanLoadStoreCallbackArgs, @@ -1146,7 +1273,7 @@ void DataFlowSanitizer::buildExternWeakCheckIfNeeded(IRBuilder<> &IRB, // but replacing with a known-to-not-be-null wrapper can break this check. 
// When replacing uses of the extern weak function with the wrapper we try // to avoid replacing uses in conditionals, but this is not perfect. - // In the case where we fail, and accidentially optimize out a null check + // In the case where we fail, and accidentally optimize out a null check // for a extern weak function, add a check here to help identify the issue. if (GlobalValue::isExternalWeakLinkage(F->getLinkage())) { std::vector<Value *> Args; @@ -1190,19 +1317,22 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName, // Initialize DataFlowSanitizer runtime functions and declare them in the module void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { + LLVMContext &C = M.getContext(); { AttributeList AL; - AL = AL.addFnAttribute(M.getContext(), Attribute::NoUnwind); - AL = AL.addFnAttribute(M.getContext(), Attribute::ReadOnly); - AL = AL.addRetAttribute(M.getContext(), Attribute::ZExt); + AL = AL.addFnAttribute(C, Attribute::NoUnwind); + AL = AL.addFnAttribute( + C, Attribute::getWithMemoryEffects(C, MemoryEffects::readOnly())); + AL = AL.addRetAttribute(C, Attribute::ZExt); DFSanUnionLoadFn = Mod->getOrInsertFunction("__dfsan_union_load", DFSanUnionLoadFnTy, AL); } { AttributeList AL; - AL = AL.addFnAttribute(M.getContext(), Attribute::NoUnwind); - AL = AL.addFnAttribute(M.getContext(), Attribute::ReadOnly); - AL = AL.addRetAttribute(M.getContext(), Attribute::ZExt); + AL = AL.addFnAttribute(C, Attribute::NoUnwind); + AL = AL.addFnAttribute( + C, Attribute::getWithMemoryEffects(C, MemoryEffects::readOnly())); + AL = AL.addRetAttribute(C, Attribute::ZExt); DFSanLoadLabelAndOriginFn = Mod->getOrInsertFunction( "__dfsan_load_label_and_origin", DFSanLoadLabelAndOriginFnTy, AL); } @@ -1239,6 +1369,13 @@ void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { DFSanMemOriginTransferFn = Mod->getOrInsertFunction( "__dfsan_mem_origin_transfer", DFSanMemOriginTransferFnTy); + DFSanMemShadowOriginTransferFn = Mod->getOrInsertFunction( + "__dfsan_mem_shadow_origin_transfer", DFSanMemShadowOriginTransferFnTy); + + DFSanMemShadowOriginConditionalExchangeFn = + Mod->getOrInsertFunction("__dfsan_mem_shadow_origin_conditional_exchange", + DFSanMemShadowOriginConditionalExchangeFnTy); + { AttributeList AL; AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); @@ -1272,6 +1409,10 @@ void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { DFSanRuntimeFunctions.insert( DFSanConditionalCallbackOriginFn.getCallee()->stripPointerCasts()); DFSanRuntimeFunctions.insert( + DFSanReachesFunctionCallbackFn.getCallee()->stripPointerCasts()); + DFSanRuntimeFunctions.insert( + DFSanReachesFunctionCallbackOriginFn.getCallee()->stripPointerCasts()); + DFSanRuntimeFunctions.insert( DFSanCmpCallbackFn.getCallee()->stripPointerCasts()); DFSanRuntimeFunctions.insert( DFSanChainOriginFn.getCallee()->stripPointerCasts()); @@ -1280,48 +1421,67 @@ void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { DFSanRuntimeFunctions.insert( DFSanMemOriginTransferFn.getCallee()->stripPointerCasts()); DFSanRuntimeFunctions.insert( + DFSanMemShadowOriginTransferFn.getCallee()->stripPointerCasts()); + DFSanRuntimeFunctions.insert( + DFSanMemShadowOriginConditionalExchangeFn.getCallee() + ->stripPointerCasts()); + DFSanRuntimeFunctions.insert( DFSanMaybeStoreOriginFn.getCallee()->stripPointerCasts()); } // Initializes event callback functions and declare them in the module void DataFlowSanitizer::initializeCallbackFunctions(Module &M) { - DFSanLoadCallbackFn = 
Mod->getOrInsertFunction("__dfsan_load_callback", - DFSanLoadStoreCallbackFnTy); - DFSanStoreCallbackFn = Mod->getOrInsertFunction("__dfsan_store_callback", - DFSanLoadStoreCallbackFnTy); + { + AttributeList AL; + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + DFSanLoadCallbackFn = Mod->getOrInsertFunction( + "__dfsan_load_callback", DFSanLoadStoreCallbackFnTy, AL); + } + { + AttributeList AL; + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + DFSanStoreCallbackFn = Mod->getOrInsertFunction( + "__dfsan_store_callback", DFSanLoadStoreCallbackFnTy, AL); + } DFSanMemTransferCallbackFn = Mod->getOrInsertFunction( "__dfsan_mem_transfer_callback", DFSanMemTransferCallbackFnTy); - DFSanCmpCallbackFn = - Mod->getOrInsertFunction("__dfsan_cmp_callback", DFSanCmpCallbackFnTy); - - DFSanConditionalCallbackFn = Mod->getOrInsertFunction( - "__dfsan_conditional_callback", DFSanConditionalCallbackFnTy); - DFSanConditionalCallbackOriginFn = - Mod->getOrInsertFunction("__dfsan_conditional_callback_origin", - DFSanConditionalCallbackOriginFnTy); -} - -void DataFlowSanitizer::injectMetadataGlobals(Module &M) { - // These variables can be used: - // - by the runtime (to discover what the shadow width was, during - // compilation) - // - in testing (to avoid hardcoding the shadow width and type but instead - // extract them by pattern matching) - Type *IntTy = Type::getInt32Ty(*Ctx); - (void)Mod->getOrInsertGlobal("__dfsan_shadow_width_bits", IntTy, [&] { - return new GlobalVariable( - M, IntTy, /*isConstant=*/true, GlobalValue::WeakODRLinkage, - ConstantInt::get(IntTy, ShadowWidthBits), "__dfsan_shadow_width_bits"); - }); - (void)Mod->getOrInsertGlobal("__dfsan_shadow_width_bytes", IntTy, [&] { - return new GlobalVariable(M, IntTy, /*isConstant=*/true, - GlobalValue::WeakODRLinkage, - ConstantInt::get(IntTy, ShadowWidthBytes), - "__dfsan_shadow_width_bytes"); - }); + { + AttributeList AL; + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + DFSanCmpCallbackFn = Mod->getOrInsertFunction("__dfsan_cmp_callback", + DFSanCmpCallbackFnTy, AL); + } + { + AttributeList AL; + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + DFSanConditionalCallbackFn = Mod->getOrInsertFunction( + "__dfsan_conditional_callback", DFSanConditionalCallbackFnTy, AL); + } + { + AttributeList AL; + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + DFSanConditionalCallbackOriginFn = + Mod->getOrInsertFunction("__dfsan_conditional_callback_origin", + DFSanConditionalCallbackOriginFnTy, AL); + } + { + AttributeList AL; + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + DFSanReachesFunctionCallbackFn = + Mod->getOrInsertFunction("__dfsan_reaches_function_callback", + DFSanReachesFunctionCallbackFnTy, AL); + } + { + AttributeList AL; + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + DFSanReachesFunctionCallbackOriginFn = + Mod->getOrInsertFunction("__dfsan_reaches_function_callback_origin", + DFSanReachesFunctionCallbackOriginFnTy, AL); + } } -bool DataFlowSanitizer::runImpl(Module &M) { +bool DataFlowSanitizer::runImpl( + Module &M, llvm::function_ref<TargetLibraryInfo &(Function &)> GetTLI) { initializeModule(M); if (ABIList.isIn(M, "skip")) @@ -1362,8 +1522,6 @@ bool DataFlowSanitizer::runImpl(Module &M) { "__dfsan_track_origins"); }); - injectMetadataGlobals(M); - initializeCallbackFunctions(M); initializeRuntimeFunctions(M); @@ -1372,7 +1530,8 @@ bool DataFlowSanitizer::runImpl(Module &M) { SmallPtrSet<Function *, 2> 
FnsWithForceZeroLabel; SmallPtrSet<Constant *, 1> PersonalityFns; for (Function &F : M) - if (!F.isIntrinsic() && !DFSanRuntimeFunctions.contains(&F)) { + if (!F.isIntrinsic() && !DFSanRuntimeFunctions.contains(&F) && + !LibAtomicFunction(F)) { FnsToInstrument.push_back(&F); if (F.hasPersonalityFn()) PersonalityFns.insert(F.getPersonalityFn()->stripPointerCasts()); @@ -1383,9 +1542,7 @@ bool DataFlowSanitizer::runImpl(Module &M) { assert(isa<Function>(C) && "Personality routine is not a function!"); Function *F = cast<Function>(C); if (!isInstrumented(F)) - FnsToInstrument.erase( - std::remove(FnsToInstrument.begin(), FnsToInstrument.end(), F), - FnsToInstrument.end()); + llvm::erase_value(FnsToInstrument, F); } } @@ -1414,8 +1571,8 @@ bool DataFlowSanitizer::runImpl(Module &M) { } } - ReadOnlyNoneAttrs.addAttribute(Attribute::ReadOnly) - .addAttribute(Attribute::ReadNone); + // TODO: This could be more precise. + ReadOnlyNoneAttrs.addAttribute(Attribute::Memory); // First, change the ABI of every function in the module. ABI-listed // functions keep their original ABI and get a wrapper function. @@ -1464,8 +1621,8 @@ bool DataFlowSanitizer::runImpl(Module &M) { // br i1 icmp ne (i8 (i8)* @my_func, i8 (i8)* null), label %use_my_func, // label %avoid_my_func // The @"dfsw$my_func" wrapper is never null, so if we replace this use - // in the comparision, the icmp will simplify to false and we have - // accidentially optimized away a null check that is necessary. + // in the comparison, the icmp will simplify to false and we have + // accidentally optimized away a null check that is necessary. // This can lead to a crash when the null extern_weak my_func is called. // // To prevent (the most common pattern of) this problem, @@ -1525,7 +1682,32 @@ bool DataFlowSanitizer::runImpl(Module &M) { removeUnreachableBlocks(*F); DFSanFunction DFSF(*this, F, FnsWithNativeABI.count(F), - FnsWithForceZeroLabel.count(F)); + FnsWithForceZeroLabel.count(F), GetTLI(*F)); + + if (ClReachesFunctionCallbacks) { + // Add callback for arguments reaching this function. + for (auto &FArg : F->args()) { + Instruction *Next = &F->getEntryBlock().front(); + Value *FArgShadow = DFSF.getShadow(&FArg); + if (isZeroShadow(FArgShadow)) + continue; + if (Instruction *FArgShadowInst = dyn_cast<Instruction>(FArgShadow)) { + Next = FArgShadowInst->getNextNode(); + } + if (shouldTrackOrigins()) { + if (Instruction *Origin = + dyn_cast<Instruction>(DFSF.getOrigin(&FArg))) { + // Ensure IRB insertion point is after loads for shadow and origin. + Instruction *OriginNext = Origin->getNextNode(); + if (Next->comesBefore(OriginNext)) { + Next = OriginNext; + } + } + } + IRBuilder<> IRB(Next); + DFSF.addReachesFunctionCallbacksIfEnabled(IRB, *Next, &FArg); + } + } // DFSanVisitor may create new basic blocks, which confuses df_iterator. // Build a copy of the list before iterating over it. @@ -2209,6 +2391,7 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) { if (LI.isAtomic()) LI.setOrdering(addAcquireOrdering(LI.getOrdering())); + Instruction *AfterLi = LI.getNextNode(); Instruction *Pos = LI.isAtomic() ? 
LI.getNextNode() : &LI; std::vector<Value *> Shadows; std::vector<Value *> Origins; @@ -2244,8 +2427,13 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) { if (ClEventCallbacks) { IRBuilder<> IRB(Pos); Value *Addr8 = IRB.CreateBitCast(LI.getPointerOperand(), DFSF.DFS.Int8Ptr); - IRB.CreateCall(DFSF.DFS.DFSanLoadCallbackFn, {PrimitiveShadow, Addr8}); + CallInst *CI = + IRB.CreateCall(DFSF.DFS.DFSanLoadCallbackFn, {PrimitiveShadow, Addr8}); + CI->addParamAttr(0, Attribute::ZExt); } + + IRBuilder<> IRB(AfterLi); + DFSF.addReachesFunctionCallbacksIfEnabled(IRB, LI, &LI); } Value *DFSanFunction::updateOriginIfTainted(Value *Shadow, Value *Origin, @@ -2406,7 +2594,7 @@ void DFSanFunction::storePrimitiveShadowOrigin(Value *Addr, uint64_t Size, if (LeftSize >= ShadowVecSize) { auto *ShadowVecTy = FixedVectorType::get(DFS.PrimitiveShadowTy, ShadowVecSize); - Value *ShadowVec = UndefValue::get(ShadowVecTy); + Value *ShadowVec = PoisonValue::get(ShadowVecTy); for (unsigned I = 0; I != ShadowVecSize; ++I) { ShadowVec = IRB.CreateInsertElement( ShadowVec, PrimitiveShadow, @@ -2501,7 +2689,9 @@ void DFSanVisitor::visitStoreInst(StoreInst &SI) { if (ClEventCallbacks) { IRBuilder<> IRB(&SI); Value *Addr8 = IRB.CreateBitCast(SI.getPointerOperand(), DFSF.DFS.Int8Ptr); - IRB.CreateCall(DFSF.DFS.DFSanStoreCallbackFn, {PrimitiveShadow, Addr8}); + CallInst *CI = + IRB.CreateCall(DFSF.DFS.DFSanStoreCallbackFn, {PrimitiveShadow, Addr8}); + CI->addParamAttr(0, Attribute::ZExt); } } @@ -2563,7 +2753,9 @@ void DFSanVisitor::visitCmpInst(CmpInst &CI) { if (ClEventCallbacks) { IRBuilder<> IRB(&CI); Value *CombinedShadow = DFSF.getShadow(&CI); - IRB.CreateCall(DFSF.DFS.DFSanCmpCallbackFn, CombinedShadow); + CallInst *CallI = + IRB.CreateCall(DFSF.DFS.DFSanCmpCallbackFn, CombinedShadow); + CallI->addParamAttr(0, Attribute::ZExt); } } @@ -2983,6 +3175,146 @@ bool DFSanVisitor::visitWrappedCallBase(Function &F, CallBase &CB) { return false; } +Value *DFSanVisitor::makeAddAcquireOrderingTable(IRBuilder<> &IRB) { + constexpr int NumOrderings = (int)AtomicOrderingCABI::seq_cst + 1; + uint32_t OrderingTable[NumOrderings] = {}; + + OrderingTable[(int)AtomicOrderingCABI::relaxed] = + OrderingTable[(int)AtomicOrderingCABI::acquire] = + OrderingTable[(int)AtomicOrderingCABI::consume] = + (int)AtomicOrderingCABI::acquire; + OrderingTable[(int)AtomicOrderingCABI::release] = + OrderingTable[(int)AtomicOrderingCABI::acq_rel] = + (int)AtomicOrderingCABI::acq_rel; + OrderingTable[(int)AtomicOrderingCABI::seq_cst] = + (int)AtomicOrderingCABI::seq_cst; + + return ConstantDataVector::get(IRB.getContext(), + ArrayRef(OrderingTable, NumOrderings)); +} + +void DFSanVisitor::visitLibAtomicLoad(CallBase &CB) { + // Since we use getNextNode here, we can't have CB terminate the BB. + assert(isa<CallInst>(CB)); + + IRBuilder<> IRB(&CB); + Value *Size = CB.getArgOperand(0); + Value *SrcPtr = CB.getArgOperand(1); + Value *DstPtr = CB.getArgOperand(2); + Value *Ordering = CB.getArgOperand(3); + // Convert the call to have at least Acquire ordering to make sure + // the shadow operations aren't reordered before it. 
+ Value *NewOrdering = + IRB.CreateExtractElement(makeAddAcquireOrderingTable(IRB), Ordering); + CB.setArgOperand(3, NewOrdering); + + IRBuilder<> NextIRB(CB.getNextNode()); + NextIRB.SetCurrentDebugLocation(CB.getDebugLoc()); + + // TODO: Support ClCombinePointerLabelsOnLoad + // TODO: Support ClEventCallbacks + + NextIRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn, + {NextIRB.CreatePointerCast(DstPtr, NextIRB.getInt8PtrTy()), + NextIRB.CreatePointerCast(SrcPtr, NextIRB.getInt8PtrTy()), + NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); +} + +Value *DFSanVisitor::makeAddReleaseOrderingTable(IRBuilder<> &IRB) { + constexpr int NumOrderings = (int)AtomicOrderingCABI::seq_cst + 1; + uint32_t OrderingTable[NumOrderings] = {}; + + OrderingTable[(int)AtomicOrderingCABI::relaxed] = + OrderingTable[(int)AtomicOrderingCABI::release] = + (int)AtomicOrderingCABI::release; + OrderingTable[(int)AtomicOrderingCABI::consume] = + OrderingTable[(int)AtomicOrderingCABI::acquire] = + OrderingTable[(int)AtomicOrderingCABI::acq_rel] = + (int)AtomicOrderingCABI::acq_rel; + OrderingTable[(int)AtomicOrderingCABI::seq_cst] = + (int)AtomicOrderingCABI::seq_cst; + + return ConstantDataVector::get(IRB.getContext(), + ArrayRef(OrderingTable, NumOrderings)); +} + +void DFSanVisitor::visitLibAtomicStore(CallBase &CB) { + IRBuilder<> IRB(&CB); + Value *Size = CB.getArgOperand(0); + Value *SrcPtr = CB.getArgOperand(1); + Value *DstPtr = CB.getArgOperand(2); + Value *Ordering = CB.getArgOperand(3); + // Convert the call to have at least Release ordering to make sure + // the shadow operations aren't reordered after it. + Value *NewOrdering = + IRB.CreateExtractElement(makeAddReleaseOrderingTable(IRB), Ordering); + CB.setArgOperand(3, NewOrdering); + + // TODO: Support ClCombinePointerLabelsOnStore + // TODO: Support ClEventCallbacks + + IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn, + {IRB.CreatePointerCast(DstPtr, IRB.getInt8PtrTy()), + IRB.CreatePointerCast(SrcPtr, IRB.getInt8PtrTy()), + IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); +} + +void DFSanVisitor::visitLibAtomicExchange(CallBase &CB) { + // void __atomic_exchange(size_t size, void *ptr, void *val, void *ret, int + // ordering) + IRBuilder<> IRB(&CB); + Value *Size = CB.getArgOperand(0); + Value *TargetPtr = CB.getArgOperand(1); + Value *SrcPtr = CB.getArgOperand(2); + Value *DstPtr = CB.getArgOperand(3); + + // This operation is not atomic for the shadow and origin memory. + // This could result in DFSan false positives or false negatives. + // For now we will assume these operations are rare, and + // the additional complexity to address this is not warranted.
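// Illustrative sketch, not part of the patch above: a standalone version of the
// ordering-upgrade tables built by makeAddAcquireOrderingTable() and
// makeAddReleaseOrderingTable(). It assumes the C ABI numbering relaxed=0,
// consume=1, acquire=2, release=3, acq_rel=4, seq_cst=5 (the values
// AtomicOrderingCABI uses) and mirrors the release table: a relaxed
// __atomic_store is upgraded to release so the shadow/origin copy emitted
// before it cannot be reordered past the store itself.
#include <cassert>
#include <cstdint>
static void orderingUpgradeExample() {
  enum CABI { relaxed = 0, consume, acquire, release, acq_rel, seq_cst };
  uint32_t Table[seq_cst + 1] = {};
  Table[relaxed] = Table[release] = release;
  Table[consume] = Table[acquire] = Table[acq_rel] = acq_rel;
  Table[seq_cst] = seq_cst;
  assert(Table[relaxed] == release); // relaxed store -> release store
  assert(Table[acquire] == acq_rel); // acquire store -> acq_rel store
  assert(Table[seq_cst] == seq_cst); // seq_cst is left unchanged
}
// End of illustrative sketch.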
+ + // Current Target to Dest + IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn, + {IRB.CreatePointerCast(DstPtr, IRB.getInt8PtrTy()), + IRB.CreatePointerCast(TargetPtr, IRB.getInt8PtrTy()), + IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); + + // Current Src to Target (overriding) + IRB.CreateCall(DFSF.DFS.DFSanMemShadowOriginTransferFn, + {IRB.CreatePointerCast(TargetPtr, IRB.getInt8PtrTy()), + IRB.CreatePointerCast(SrcPtr, IRB.getInt8PtrTy()), + IRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); +} + +void DFSanVisitor::visitLibAtomicCompareExchange(CallBase &CB) { + // bool __atomic_compare_exchange(size_t size, void *ptr, void *expected, void + // *desired, int success_order, int failure_order) + Value *Size = CB.getArgOperand(0); + Value *TargetPtr = CB.getArgOperand(1); + Value *ExpectedPtr = CB.getArgOperand(2); + Value *DesiredPtr = CB.getArgOperand(3); + + // This operation is not atomic for the shadow and origin memory. + // This could result in DFSan false positives or false negatives. + // For now we will assume these operations are rare, and + // the additional complexity to address this is not warranted. + + IRBuilder<> NextIRB(CB.getNextNode()); + NextIRB.SetCurrentDebugLocation(CB.getDebugLoc()); + + DFSF.setShadow(&CB, DFSF.DFS.getZeroShadow(&CB)); + + // If original call returned true, copy Desired to Target. + // If original call returned false, copy Target to Expected. + NextIRB.CreateCall( + DFSF.DFS.DFSanMemShadowOriginConditionalExchangeFn, + {NextIRB.CreateIntCast(&CB, NextIRB.getInt8Ty(), false), + NextIRB.CreatePointerCast(TargetPtr, NextIRB.getInt8PtrTy()), + NextIRB.CreatePointerCast(ExpectedPtr, NextIRB.getInt8PtrTy()), + NextIRB.CreatePointerCast(DesiredPtr, NextIRB.getInt8PtrTy()), + NextIRB.CreateIntCast(Size, DFSF.DFS.IntptrTy, false)}); +} + void DFSanVisitor::visitCallBase(CallBase &CB) { Function *F = CB.getCalledFunction(); if ((F && F->isIntrinsic()) || CB.isInlineAsm()) { @@ -2995,6 +3327,40 @@ void DFSanVisitor::visitCallBase(CallBase &CB) { if (F == DFSF.DFS.DFSanVarargWrapperFn.getCallee()->stripPointerCasts()) return; + LibFunc LF; + if (DFSF.TLI.getLibFunc(CB, LF)) { + // libatomic.a functions need to have special handling because there isn't + // a good way to intercept them or compile the library with + // instrumentation. + switch (LF) { + case LibFunc_atomic_load: + if (!isa<CallInst>(CB)) { + llvm::errs() << "DFSAN -- cannot instrument invoke of libatomic load. " + "Ignoring!\n"; + break; + } + visitLibAtomicLoad(CB); + return; + case LibFunc_atomic_store: + visitLibAtomicStore(CB); + return; + default: + break; + } + } + + // TODO: These are not supported by TLI? They are not in the enum.
+ if (F && F->hasName() && !F->isVarArg()) { + if (F->getName() == "__atomic_exchange") { + visitLibAtomicExchange(CB); + return; + } + if (F->getName() == "__atomic_compare_exchange") { + visitLibAtomicCompareExchange(CB); + return; + } + } + DenseMap<Value *, Function *>::iterator UnwrappedFnIt = DFSF.DFS.UnwrappedFnMap.find(CB.getCalledOperand()); if (UnwrappedFnIt != DFSF.DFS.UnwrappedFnMap.end()) @@ -3071,6 +3437,8 @@ void DFSanVisitor::visitCallBase(CallBase &CB) { DFSF.SkipInsts.insert(LI); DFSF.setOrigin(&CB, LI); } + + DFSF.addReachesFunctionCallbacksIfEnabled(NextIRB, CB, &CB); } } @@ -3099,38 +3467,20 @@ void DFSanVisitor::visitPHINode(PHINode &PN) { DFSF.PHIFixups.push_back({&PN, ShadowPN, OriginPN}); } -namespace { -class DataFlowSanitizerLegacyPass : public ModulePass { -private: - std::vector<std::string> ABIListFiles; - -public: - static char ID; - - DataFlowSanitizerLegacyPass( - const std::vector<std::string> &ABIListFiles = std::vector<std::string>()) - : ModulePass(ID), ABIListFiles(ABIListFiles) {} - - bool runOnModule(Module &M) override { - return DataFlowSanitizer(ABIListFiles).runImpl(M); - } -}; -} // namespace - -char DataFlowSanitizerLegacyPass::ID; - -INITIALIZE_PASS(DataFlowSanitizerLegacyPass, "dfsan", - "DataFlowSanitizer: dynamic data flow analysis.", false, false) - -ModulePass *llvm::createDataFlowSanitizerLegacyPassPass( - const std::vector<std::string> &ABIListFiles) { - return new DataFlowSanitizerLegacyPass(ABIListFiles); -} - PreservedAnalyses DataFlowSanitizerPass::run(Module &M, ModuleAnalysisManager &AM) { - if (DataFlowSanitizer(ABIListFiles).runImpl(M)) { - return PreservedAnalyses::none(); - } - return PreservedAnalyses::all(); + auto GetTLI = [&](Function &F) -> TargetLibraryInfo & { + auto &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + return FAM.getResult<TargetLibraryAnalysis>(F); + }; + if (!DataFlowSanitizer(ABIListFiles).runImpl(M, GetTLI)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA = PreservedAnalyses::none(); + // GlobalsAA is considered stateless and does not get invalidated unless + // explicitly invalidated; PreservedAnalyses::none() is not enough. Sanitizers + // make changes that require GlobalsAA to be invalidated. + PA.abandon<GlobalsAA>(); + return PA; } diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp index ac4a1fd6bb7e..9f3ca8b02fd9 100644 --- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -119,7 +119,8 @@ private: function_ref<BranchProbabilityInfo *(Function &F)> GetBPI, function_ref<const TargetLibraryInfo &(Function &F)> GetTLI); - Function *createInternalFunction(FunctionType *FTy, StringRef Name); + Function *createInternalFunction(FunctionType *FTy, StringRef Name, + StringRef MangledType = ""); void emitGlobalConstructor( SmallVectorImpl<std::pair<GlobalVariable *, MDNode *>> &CountersBySP); @@ -251,8 +252,8 @@ namespace { void writeOut() { write(0); writeString(Filename); - for (int i = 0, e = Lines.size(); i != e; ++i) - write(Lines[i]); + for (uint32_t L : Lines) + write(L); } GCOVLines(GCOVProfiler *P, StringRef F) @@ -595,8 +596,8 @@ static bool functionHasLines(const Function &F, unsigned &EndLine) { // Check whether this function actually has any source lines. Not only // do these waste space, they also can crash gcov. 
EndLine = 0; - for (auto &BB : F) { - for (auto &I : BB) { + for (const auto &BB : F) { + for (const auto &I : BB) { // Debug intrinsic locations correspond to the location of the // declaration, not necessarily any statements or expressions. if (isa<DbgInfoIntrinsic>(&I)) continue; @@ -623,10 +624,11 @@ static bool isUsingScopeBasedEH(Function &F) { } bool GCOVProfiler::AddFlushBeforeForkAndExec() { + const TargetLibraryInfo *TLI = nullptr; SmallVector<CallInst *, 2> Forks; SmallVector<CallInst *, 2> Execs; for (auto &F : M->functions()) { - auto *TLI = &GetTLI(F); + TLI = TLI == nullptr ? &GetTLI(F) : TLI; for (auto &I : instructions(F)) { if (CallInst *CI = dyn_cast<CallInst>(&I)) { if (Function *Callee = CI->getCalledFunction()) { @@ -648,14 +650,16 @@ bool GCOVProfiler::AddFlushBeforeForkAndExec() { } } - for (auto F : Forks) { + for (auto *F : Forks) { IRBuilder<> Builder(F); BasicBlock *Parent = F->getParent(); auto NextInst = ++F->getIterator(); // We've a fork so just reset the counters in the child process FunctionType *FTy = FunctionType::get(Builder.getInt32Ty(), {}, false); - FunctionCallee GCOVFork = M->getOrInsertFunction("__gcov_fork", FTy); + FunctionCallee GCOVFork = M->getOrInsertFunction( + "__gcov_fork", FTy, + TLI->getAttrList(Ctx, {}, /*Signed=*/true, /*Ret=*/true)); F->setCalledFunction(GCOVFork); // We split just after the fork to have a counter for the lines after @@ -673,7 +677,7 @@ bool GCOVProfiler::AddFlushBeforeForkAndExec() { Parent->back().setDebugLoc(Loc); } - for (auto E : Execs) { + for (auto *E : Execs) { IRBuilder<> Builder(E); BasicBlock *Parent = E->getParent(); auto NextInst = ++E->getIterator(); @@ -797,6 +801,8 @@ bool GCOVProfiler::emitProfileNotes( if (isUsingScopeBasedEH(F)) continue; if (F.hasFnAttribute(llvm::Attribute::NoProfile)) continue; + if (F.hasFnAttribute(llvm::Attribute::SkipProfile)) + continue; // Add the function line number to the lines of the entry block // to have a counter for the function definition. @@ -877,7 +883,7 @@ bool GCOVProfiler::emitProfileNotes( while ((Idx >>= 8) > 0); } - for (auto &I : BB) { + for (const auto &I : BB) { // Debug intrinsic locations correspond to the location of the // declaration, not necessarily any statements or expressions. if (isa<DbgInfoIntrinsic>(&I)) continue; @@ -974,13 +980,16 @@ bool GCOVProfiler::emitProfileNotes( } Function *GCOVProfiler::createInternalFunction(FunctionType *FTy, - StringRef Name) { + StringRef Name, + StringRef MangledType /*=""*/) { Function *F = Function::createWithDefaultAttr( FTy, GlobalValue::InternalLinkage, 0, Name, M); F->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); F->addFnAttr(Attribute::NoUnwind); if (Options.NoRedZone) F->addFnAttr(Attribute::NoRedZone); + if (!MangledType.empty()) + setKCFIType(*M, *F, MangledType); return F; } @@ -993,7 +1002,7 @@ void GCOVProfiler::emitGlobalConstructor( // be executed at exit and the "__llvm_gcov_reset" function to be executed // when "__gcov_flush" is called. 
FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); - Function *F = createInternalFunction(FTy, "__llvm_gcov_init"); + Function *F = createInternalFunction(FTy, "__llvm_gcov_init", "_ZTSFvvE"); F->addFnAttr(Attribute::NoInline); BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", F); @@ -1019,11 +1028,8 @@ FunctionCallee GCOVProfiler::getStartFileFunc(const TargetLibraryInfo *TLI) { Type::getInt32Ty(*Ctx), // uint32_t checksum }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - AttributeList AL; - if (auto AK = TLI->getExtAttrForI32Param(false)) - AL = AL.addParamAttribute(*Ctx, 2, AK); - FunctionCallee Res = M->getOrInsertFunction("llvm_gcda_start_file", FTy, AL); - return Res; + return M->getOrInsertFunction("llvm_gcda_start_file", FTy, + TLI->getAttrList(Ctx, {1, 2}, /*Signed=*/false)); } FunctionCallee GCOVProfiler::getEmitFunctionFunc(const TargetLibraryInfo *TLI) { @@ -1033,13 +1039,8 @@ FunctionCallee GCOVProfiler::getEmitFunctionFunc(const TargetLibraryInfo *TLI) { Type::getInt32Ty(*Ctx), // uint32_t cfg_checksum }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - AttributeList AL; - if (auto AK = TLI->getExtAttrForI32Param(false)) { - AL = AL.addParamAttribute(*Ctx, 0, AK); - AL = AL.addParamAttribute(*Ctx, 1, AK); - AL = AL.addParamAttribute(*Ctx, 2, AK); - } - return M->getOrInsertFunction("llvm_gcda_emit_function", FTy); + return M->getOrInsertFunction("llvm_gcda_emit_function", FTy, + TLI->getAttrList(Ctx, {0, 1, 2}, /*Signed=*/false)); } FunctionCallee GCOVProfiler::getEmitArcsFunc(const TargetLibraryInfo *TLI) { @@ -1048,10 +1049,8 @@ FunctionCallee GCOVProfiler::getEmitArcsFunc(const TargetLibraryInfo *TLI) { Type::getInt64PtrTy(*Ctx), // uint64_t *counters }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - AttributeList AL; - if (auto AK = TLI->getExtAttrForI32Param(false)) - AL = AL.addParamAttribute(*Ctx, 0, AK); - return M->getOrInsertFunction("llvm_gcda_emit_arcs", FTy, AL); + return M->getOrInsertFunction("llvm_gcda_emit_arcs", FTy, + TLI->getAttrList(Ctx, {0}, /*Signed=*/false)); } FunctionCallee GCOVProfiler::getSummaryInfoFunc() { @@ -1069,7 +1068,8 @@ Function *GCOVProfiler::insertCounterWriteout( FunctionType *WriteoutFTy = FunctionType::get(Type::getVoidTy(*Ctx), false); Function *WriteoutF = M->getFunction("__llvm_gcov_writeout"); if (!WriteoutF) - WriteoutF = createInternalFunction(WriteoutFTy, "__llvm_gcov_writeout"); + WriteoutF = + createInternalFunction(WriteoutFTy, "__llvm_gcov_writeout", "_ZTSFvvE"); WriteoutF->addFnAttr(Attribute::NoInline); BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", WriteoutF); @@ -1315,7 +1315,7 @@ Function *GCOVProfiler::insertReset( FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); Function *ResetF = M->getFunction("__llvm_gcov_reset"); if (!ResetF) - ResetF = createInternalFunction(FTy, "__llvm_gcov_reset"); + ResetF = createInternalFunction(FTy, "__llvm_gcov_reset", "_ZTSFvvE"); ResetF->addFnAttr(Attribute::NoInline); BasicBlock *Entry = BasicBlock::Create(*Ctx, "entry", ResetF); diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index b01c74320380..34c61f83ad30 100644 --- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" +#include 
"llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/StackSafetyAnalysis.h" #include "llvm/Analysis/ValueTracking.h" @@ -42,6 +43,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/NoFolder.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" @@ -53,6 +55,7 @@ #include "llvm/Transforms/Utils/MemoryTaggingSupport.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include <optional> using namespace llvm; @@ -307,7 +310,6 @@ public: void getInterestingMemoryOperands( Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting); - bool isInterestingAlloca(const AllocaInst &AI); void tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size); Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong); @@ -357,14 +359,14 @@ private: /// If WithFrameRecord is true, then __hwasan_tls will be used to access the /// ring buffer for storing stack allocations on targets that support it. struct ShadowMapping { - int Scale; + uint8_t Scale; uint64_t Offset; bool InGlobal; bool InTls; bool WithFrameRecord; void init(Triple &TargetTriple, bool InstrumentWithCalls); - uint64_t getObjectAlignment() const { return 1ULL << Scale; } + Align getObjectAlignment() const { return Align(1ULL << Scale); } }; ShadowMapping Mapping; @@ -386,8 +388,7 @@ private: bool DetectUseAfterScope; bool UsePageAliases; - bool HasMatchAllTag = false; - uint8_t MatchAllTag = 0; + std::optional<uint8_t> MatchAllTag; unsigned PointerTagShift; uint64_t TagMaskByte; @@ -423,9 +424,15 @@ PreservedAnalyses HWAddressSanitizerPass::run(Module &M, auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); for (Function &F : M) Modified |= HWASan.sanitizeFunction(F, FAM); - if (Modified) - return PreservedAnalyses::none(); - return PreservedAnalyses::all(); + if (!Modified) + return PreservedAnalyses::all(); + + PreservedAnalyses PA = PreservedAnalyses::none(); + // GlobalsAA is considered stateless and does not get invalidated unless + // explicitly invalidated; PreservedAnalyses::none() is not enough. Sanitizers + // make changes that require GlobalsAA to be invalidated. + PA.abandon<GlobalsAA>(); + return PA; } void HWAddressSanitizerPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { @@ -487,16 +494,14 @@ void HWAddressSanitizer::createHwasanCtorComdat() { Comdat *NoteComdat = M.getOrInsertComdat(kHwasanModuleCtorName); Type *Int8Arr0Ty = ArrayType::get(Int8Ty, 0); - auto Start = + auto *Start = new GlobalVariable(M, Int8Arr0Ty, true, GlobalVariable::ExternalLinkage, nullptr, "__start_hwasan_globals"); Start->setVisibility(GlobalValue::HiddenVisibility); - Start->setDSOLocal(true); - auto Stop = + auto *Stop = new GlobalVariable(M, Int8Arr0Ty, true, GlobalVariable::ExternalLinkage, nullptr, "__stop_hwasan_globals"); Stop->setVisibility(GlobalValue::HiddenVisibility); - Stop->setDSOLocal(true); // Null-terminated so actually 8 bytes, which are required in order to align // the note properly. 
@@ -510,7 +515,6 @@ void HWAddressSanitizer::createHwasanCtorComdat() { Note->setSection(".note.hwasan.globals"); Note->setComdat(NoteComdat); Note->setAlignment(Align(4)); - Note->setDSOLocal(true); // The pointers in the note need to be relative so that the note ends up being // placed in rodata, which is the standard location for notes. @@ -529,7 +533,7 @@ void HWAddressSanitizer::createHwasanCtorComdat() { // Create a zero-length global in hwasan_globals so that the linker will // always create start and stop symbols. - auto Dummy = new GlobalVariable( + auto *Dummy = new GlobalVariable( M, Int8Arr0Ty, /*isConstantGlobal*/ true, GlobalVariable::PrivateLinkage, Constant::getNullValue(Int8Arr0Ty), "hwasan.dummy.global"); Dummy->setSection("hwasan_globals"); @@ -579,16 +583,15 @@ void HWAddressSanitizer::initializeModule() { UseShortGranules = ClUseShortGranules.getNumOccurrences() ? ClUseShortGranules : NewRuntime; OutlinedChecks = - TargetTriple.isAArch64() && TargetTriple.isOSBinFormatELF() && + (TargetTriple.isAArch64() || TargetTriple.isRISCV64()) && + TargetTriple.isOSBinFormatELF() && (ClInlineAllChecks.getNumOccurrences() ? !ClInlineAllChecks : !Recover); if (ClMatchAllTag.getNumOccurrences()) { if (ClMatchAllTag != -1) { - HasMatchAllTag = true; MatchAllTag = ClMatchAllTag & 0xFF; } } else if (CompileKernel) { - HasMatchAllTag = true; MatchAllTag = 0xFF; } @@ -697,18 +700,17 @@ Value *HWAddressSanitizer::getShadowNonTls(IRBuilder<> &IRB) { IRB, ConstantExpr::getIntToPtr( ConstantInt::get(IntptrTy, Mapping.Offset), Int8PtrTy)); - if (Mapping.InGlobal) { + if (Mapping.InGlobal) return getDynamicShadowIfunc(IRB); - } else { - Value *GlobalDynamicAddress = - IRB.GetInsertBlock()->getParent()->getParent()->getOrInsertGlobal( - kHwasanShadowMemoryDynamicAddress, Int8PtrTy); - return IRB.CreateLoad(Int8PtrTy, GlobalDynamicAddress); - } + + Value *GlobalDynamicAddress = + IRB.GetInsertBlock()->getParent()->getParent()->getOrInsertGlobal( + kHwasanShadowMemoryDynamicAddress, Int8PtrTy); + return IRB.CreateLoad(Int8PtrTy, GlobalDynamicAddress); } bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { - // Do not instrument acesses from different address spaces; we cannot deal + // Do not instrument accesses from different address spaces; we cannot deal // with them. 
Type *PtrTy = cast<PointerType>(Ptr->getType()->getScalarType()); if (PtrTy->getPointerAddressSpace() != 0) @@ -754,13 +756,14 @@ void HWAddressSanitizer::getInterestingMemoryOperands( if (!ClInstrumentAtomics || ignoreAccess(I, RMW->getPointerOperand())) return; Interesting.emplace_back(I, RMW->getPointerOperandIndex(), true, - RMW->getValOperand()->getType(), None); + RMW->getValOperand()->getType(), std::nullopt); } else if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) { if (!ClInstrumentAtomics || ignoreAccess(I, XCHG->getPointerOperand())) return; Interesting.emplace_back(I, XCHG->getPointerOperandIndex(), true, - XCHG->getCompareOperand()->getType(), None); - } else if (auto CI = dyn_cast<CallInst>(I)) { + XCHG->getCompareOperand()->getType(), + std::nullopt); + } else if (auto *CI = dyn_cast<CallInst>(I)) { for (unsigned ArgNo = 0; ArgNo < CI->arg_size(); ArgNo++) { if (!ClInstrumentByval || !CI->isByValArgument(ArgNo) || ignoreAccess(I, CI->getArgOperand(ArgNo))) @@ -791,7 +794,8 @@ static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { } void HWAddressSanitizer::untagPointerOperand(Instruction *I, Value *Addr) { - if (TargetTriple.isAArch64() || TargetTriple.getArch() == Triple::x86_64) + if (TargetTriple.isAArch64() || TargetTriple.getArch() == Triple::x86_64 || + TargetTriple.isRISCV64()) return; IRBuilder<> IRB(I); @@ -812,11 +816,11 @@ Value *HWAddressSanitizer::memToShadow(Value *Mem, IRBuilder<> &IRB) { int64_t HWAddressSanitizer::getAccessInfo(bool IsWrite, unsigned AccessSizeIndex) { - return (CompileKernel << HWASanAccessInfo::CompileKernelShift) + - (HasMatchAllTag << HWASanAccessInfo::HasMatchAllShift) + - (MatchAllTag << HWASanAccessInfo::MatchAllShift) + - (Recover << HWASanAccessInfo::RecoverShift) + - (IsWrite << HWASanAccessInfo::IsWriteShift) + + return (CompileKernel << HWASanAccessInfo::CompileKernelShift) | + (MatchAllTag.has_value() << HWASanAccessInfo::HasMatchAllShift) | + (MatchAllTag.value_or(0) << HWASanAccessInfo::MatchAllShift) | + (Recover << HWASanAccessInfo::RecoverShift) | + (IsWrite << HWASanAccessInfo::IsWriteShift) | (AccessSizeIndex << HWASanAccessInfo::AccessSizeShift); } @@ -850,9 +854,9 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, Value *MemTag = IRB.CreateLoad(Int8Ty, Shadow); Value *TagMismatch = IRB.CreateICmpNE(PtrTag, MemTag); - if (HasMatchAllTag) { + if (MatchAllTag.has_value()) { Value *TagNotIgnored = IRB.CreateICmpNE( - PtrTag, ConstantInt::get(PtrTag->getType(), MatchAllTag)); + PtrTag, ConstantInt::get(PtrTag->getType(), *MatchAllTag)); TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored); } @@ -909,6 +913,15 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, "{x0}", /*hasSideEffects=*/true); break; + case Triple::riscv64: + // The signal handler will find the data address in x10. 
+ Asm = InlineAsm::get( + FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false), + "ebreak\naddiw x0, x11, " + + itostr(0x40 + (AccessInfo & HWASanAccessInfo::RuntimeMask)), + "{x10}", + /*hasSideEffects=*/true); + break; default: report_fatal_error("unsupported architecture"); } @@ -956,7 +969,7 @@ bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) { IRBuilder<> IRB(O.getInsn()); if (isPowerOf2_64(O.TypeSize) && (O.TypeSize / 8 <= (1ULL << (kNumberOfAccessSizes - 1))) && - (!O.Alignment || *O.Alignment >= (1ULL << Mapping.Scale) || + (!O.Alignment || *O.Alignment >= Mapping.getObjectAlignment() || *O.Alignment >= O.TypeSize / 8)) { size_t AccessSizeIndex = TypeSizeToSizeIndex(O.TypeSize); if (InstrumentWithCalls) { @@ -1000,9 +1013,9 @@ void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, if (ShadowSize) IRB.CreateMemSet(ShadowPtr, JustTag, ShadowSize, Align(1)); if (Size != AlignedSize) { - IRB.CreateStore( - ConstantInt::get(Int8Ty, Size % Mapping.getObjectAlignment()), - IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize)); + const uint8_t SizeRemainder = Size % Mapping.getObjectAlignment().value(); + IRB.CreateStore(ConstantInt::get(Int8Ty, SizeRemainder), + IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize)); IRB.CreateStore(JustTag, IRB.CreateConstGEP1_32( Int8Ty, IRB.CreateBitCast(AI, Int8PtrTy), AlignedSize - 1)); @@ -1028,7 +1041,7 @@ unsigned HWAddressSanitizer::retagMask(unsigned AllocaNo) { 48, 16, 120, 248, 56, 24, 8, 124, 252, 60, 28, 12, 4, 126, 254, 62, 30, 14, 6, 2, 127, 63, 31, 15, 7, 3, 1}; - return FastMasks[AllocaNo % (sizeof(FastMasks) / sizeof(FastMasks[0]))]; + return FastMasks[AllocaNo % std::size(FastMasks)]; } Value *HWAddressSanitizer::applyTagMask(IRBuilder<> &IRB, Value *OldTag) { @@ -1136,8 +1149,7 @@ Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) { Value *HWAddressSanitizer::getPC(IRBuilder<> &IRB) { if (TargetTriple.getArch() == Triple::aarch64) return readRegister(IRB, "pc"); - else - return IRB.CreatePtrToInt(IRB.GetInsertBlock()->getParent(), IntptrTy); + return IRB.CreatePtrToInt(IRB.GetInsertBlock()->getParent(), IntptrTy); } Value *HWAddressSanitizer::getSP(IRBuilder<> &IRB) { @@ -1146,7 +1158,7 @@ Value *HWAddressSanitizer::getSP(IRBuilder<> &IRB) { // first). Function *F = IRB.GetInsertBlock()->getParent(); Module *M = F->getParent(); - auto GetStackPointerFn = Intrinsic::getDeclaration( + auto *GetStackPointerFn = Intrinsic::getDeclaration( M, Intrinsic::frameaddress, IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); CachedSP = IRB.CreatePtrToInt( @@ -1383,31 +1395,13 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, for (auto &II : Info.LifetimeEnd) II->eraseFromParent(); } - memtag::alignAndPadAlloca(Info, Align(Mapping.getObjectAlignment())); + memtag::alignAndPadAlloca(Info, Mapping.getObjectAlignment()); } for (auto &I : SInfo.UnrecognizedLifetimes) I->eraseFromParent(); return true; } -bool HWAddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { - return (AI.getAllocatedType()->isSized() && - // FIXME: instrument dynamic allocas, too - AI.isStaticAlloca() && - // alloca() may be called with 0 size, ignore it. - memtag::getAllocaSizeInBytes(AI) > 0 && - // We are only interested in allocas not promotable to registers. - // Promotable allocas are common under -O0. 
- !isAllocaPromotable(&AI) && - // inalloca allocas are not treated as static, and we don't want - // dynamic alloca instrumentation for them as well. - !AI.isUsedWithInAlloca() && - // swifterror allocas are register promoted by ISel - !AI.isSwiftError()) && - // safe allocas are not interesting - !(SSI && SSI->isSafe(AI)); -} - bool HWAddressSanitizer::sanitizeFunction(Function &F, FunctionAnalysisManager &FAM) { if (&F == HwasanCtorFunction) @@ -1422,8 +1416,7 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F, SmallVector<MemIntrinsic *, 16> IntrinToInstrument; SmallVector<Instruction *, 8> LandingPadVec; - memtag::StackInfoBuilder SIB( - [this](const AllocaInst &AI) { return isInterestingAlloca(AI); }); + memtag::StackInfoBuilder SIB(SSI); for (auto &Inst : instructions(F)) { if (InstrumentStack) { SIB.visit(Inst); @@ -1495,8 +1488,8 @@ bool HWAddressSanitizer::sanitizeFunction(Function &F, instrumentMemAccess(Operand); if (ClInstrumentMemIntrinsics && !IntrinToInstrument.empty()) { - for (auto Inst : IntrinToInstrument) - instrumentMemIntrinsic(cast<MemIntrinsic>(Inst)); + for (auto *Inst : IntrinToInstrument) + instrumentMemIntrinsic(Inst); } ShadowBase = nullptr; @@ -1528,7 +1521,7 @@ void HWAddressSanitizer::instrumentGlobal(GlobalVariable *GV, uint8_t Tag) { NewGV->setLinkage(GlobalValue::PrivateLinkage); NewGV->copyMetadata(GV, 0); NewGV->setAlignment( - MaybeAlign(std::max(GV->getAlignment(), Mapping.getObjectAlignment()))); + std::max(GV->getAlign().valueOrOne(), Mapping.getObjectAlignment())); // It is invalid to ICF two globals that have different tags. In the case // where the size of the global is a multiple of the tag granularity the diff --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 3ef06907dfee..b66e761d53b0 100644 --- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -281,7 +281,7 @@ uint32_t ICallPromotionFunc::tryToPromote( uint64_t &TotalCount) { uint32_t NumPromoted = 0; - for (auto &C : Candidates) { + for (const auto &C : Candidates) { uint64_t Count = C.Count; pgo::promoteIndirectCall(CB, C.TargetFunction, Count, TotalCount, SamplePGO, &ORE); diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index 5b7aa304b987..c0409206216e 100644 --- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -259,7 +259,7 @@ public: // of the loop, the result profile is incomplete. // FIXME: add other heuristics to detect long running loops. if (SkipRetExitBlock) { - for (auto BB : ExitBlocks) + for (auto *BB : ExitBlocks) if (isa<ReturnInst>(BB->getTerminator())) return false; } @@ -525,15 +525,15 @@ bool InstrProfiling::run( TT = Triple(M.getTargetTriple()); bool MadeChange = false; - - // Emit the runtime hook even if no counters are present. - if (needsRuntimeHookUnconditionally(TT)) + bool NeedsRuntimeHook = needsRuntimeHookUnconditionally(TT); + if (NeedsRuntimeHook) MadeChange = emitRuntimeHook(); - // Improve compile time by avoiding linear scans when there is no work. + bool ContainsProfiling = containsProfilingIntrinsics(M); GlobalVariable *CoverageNamesVar = M.getNamedGlobal(getCoverageUnusedNamesVarName()); - if (!containsProfilingIntrinsics(M) && !CoverageNamesVar) + // Improve compile time by avoiding linear scans when there is no work. 
+ if (!ContainsProfiling && !CoverageNamesVar) return MadeChange; // We did not know how many value sites there would be inside @@ -567,7 +567,14 @@ bool InstrProfiling::run( emitVNodes(); emitNameData(); - emitRuntimeHook(); + + // Emit runtime hook for the cases where the target does not unconditionally + // require pulling in profile runtime, and coverage is enabled on code that is + // not eliminated by the front-end, e.g. unused functions with internal + // linkage. + if (!NeedsRuntimeHook && ContainsProfiling) + emitRuntimeHook(); + emitRegistration(); emitUses(); emitInitialization(); @@ -592,7 +599,7 @@ static FunctionCallee getOrInsertValueProfilingCall( #include "llvm/ProfileData/InstrProfData.inc" }; auto *ValueProfilingCallTy = - FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false); + FunctionType::get(ReturnTy, ArrayRef(ParamTypes), false); StringRef FuncName = CallType == ValueProfilingCallType::Default ? getInstrProfValueProfFuncName() : getInstrProfValueProfMemOpFuncName(); @@ -914,6 +921,11 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { if (!NeedComdat) C->setSelectionKind(Comdat::NoDeduplicate); GV->setComdat(C); + // COFF doesn't allow the comdat group leader to have private linkage, so + // upgrade private linkage to internal linkage to produce a symbol table + // entry. + if (TT.isOSBinFormatCOFF() && GV->hasPrivateLinkage()) + GV->setLinkage(GlobalValue::InternalLinkage); } }; @@ -924,8 +936,8 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { CounterPtr->setVisibility(Visibility); CounterPtr->setSection( getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat())); - MaybeSetComdat(CounterPtr); CounterPtr->setLinkage(Linkage); + MaybeSetComdat(CounterPtr); PD.RegionCounters = CounterPtr; if (DebugInfoCorrelate) { if (auto *SP = Fn->getSubprogram()) { @@ -1000,7 +1012,7 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { #define INSTR_PROF_DATA(Type, LLVMType, Name, Init) LLVMType, #include "llvm/ProfileData/InstrProfData.inc" }; - auto *DataTy = StructType::get(Ctx, makeArrayRef(DataTypes)); + auto *DataTy = StructType::get(Ctx, ArrayRef(DataTypes)); Constant *FunctionAddr = shouldRecordFunctionAddr(Fn) ? 
ConstantExpr::getBitCast(Fn, Int8PtrTy) @@ -1045,7 +1057,6 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfInstBase *Inc) { Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat())); Data->setAlignment(Align(INSTR_PROF_DATA_ALIGNMENT)); MaybeSetComdat(Data); - Data->setLinkage(Linkage); PD.DataVar = Data; @@ -1097,7 +1108,7 @@ void InstrProfiling::emitVNodes() { #define INSTR_PROF_VALUE_NODE(Type, LLVMType, Name, Init) LLVMType, #include "llvm/ProfileData/InstrProfData.inc" }; - auto *VNodeTy = StructType::get(Ctx, makeArrayRef(VNodeTypes)); + auto *VNodeTy = StructType::get(Ctx, ArrayRef(VNodeTypes)); ArrayType *VNodesTy = ArrayType::get(VNodeTy, NumCounters); auto *VNodesVar = new GlobalVariable( @@ -1174,7 +1185,7 @@ void InstrProfiling::emitRegistration() { if (NamesVar) { Type *ParamTypes[] = {VoidPtrTy, Int64Ty}; auto *NamesRegisterTy = - FunctionType::get(VoidTy, makeArrayRef(ParamTypes), false); + FunctionType::get(VoidTy, ArrayRef(ParamTypes), false); auto *NamesRegisterF = Function::Create(NamesRegisterTy, GlobalVariable::ExternalLinkage, getInstrProfNamesRegFuncName(), M); @@ -1188,7 +1199,7 @@ void InstrProfiling::emitRegistration() { bool InstrProfiling::emitRuntimeHook() { // We expect the linker to be invoked with -u<hook_var> flag for Linux // in which case there is no need to emit the external variable. - if (TT.isOSLinux()) + if (TT.isOSLinux() || TT.isOSAIX()) return false; // If the module's provided its own runtime, we don't need to do anything. diff --git a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp index bd575b6cf3b0..ab72650ae801 100644 --- a/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -88,17 +88,3 @@ Comdat *llvm::getOrCreateFunctionComdat(Function &F, Triple &T) { return C; } -/// initializeInstrumentation - Initialize all passes in the TransformUtils -/// library. -void llvm::initializeInstrumentation(PassRegistry &Registry) { - initializeMemProfilerLegacyPassPass(Registry); - initializeModuleMemProfilerLegacyPassPass(Registry); - initializeBoundsCheckingLegacyPassPass(Registry); - initializeDataFlowSanitizerLegacyPassPass(Registry); -} - -/// LLVMInitializeInstrumentation - C binding for -/// initializeInstrumentation. -void LLVMInitializeInstrumentation(LLVMPassRegistryRef R) { - initializeInstrumentation(*unwrap(R)); -} diff --git a/llvm/lib/Transforms/Instrumentation/KCFI.cpp b/llvm/lib/Transforms/Instrumentation/KCFI.cpp new file mode 100644 index 000000000000..7978c766f0f0 --- /dev/null +++ b/llvm/lib/Transforms/Instrumentation/KCFI.cpp @@ -0,0 +1,111 @@ +//===-- KCFI.cpp - Generic KCFI operand bundle lowering ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass emits generic KCFI indirect call checks for targets that don't +// support lowering KCFI operand bundles in the back-end. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/KCFI.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalObject.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "kcfi" + +STATISTIC(NumKCFIChecks, "Number of kcfi operands transformed into checks"); + +namespace { +class DiagnosticInfoKCFI : public DiagnosticInfo { + const Twine &Msg; + +public: + DiagnosticInfoKCFI(const Twine &DiagMsg, + DiagnosticSeverity Severity = DS_Error) + : DiagnosticInfo(DK_Linker, Severity), Msg(DiagMsg) {} + void print(DiagnosticPrinter &DP) const override { DP << Msg; } +}; +} // namespace + +PreservedAnalyses KCFIPass::run(Function &F, FunctionAnalysisManager &AM) { + Module &M = *F.getParent(); + if (!M.getModuleFlag("kcfi")) + return PreservedAnalyses::all(); + + // Find call instructions with KCFI operand bundles. + SmallVector<CallInst *> KCFICalls; + for (Instruction &I : instructions(F)) { + if (auto *CI = dyn_cast<CallInst>(&I)) + if (CI->getOperandBundle(LLVMContext::OB_kcfi)) + KCFICalls.push_back(CI); + } + + if (KCFICalls.empty()) + return PreservedAnalyses::all(); + + LLVMContext &Ctx = M.getContext(); + // patchable-function-prefix emits nops between the KCFI type identifier + // and the function start. As we don't know the size of the emitted nops, + // don't allow this attribute with generic lowering. + if (F.hasFnAttribute("patchable-function-prefix")) + Ctx.diagnose( + DiagnosticInfoKCFI("-fpatchable-function-entry=N,M, where M>0 is not " + "compatible with -fsanitize=kcfi on this target")); + + IntegerType *Int32Ty = Type::getInt32Ty(Ctx); + MDNode *VeryUnlikelyWeights = + MDBuilder(Ctx).createBranchWeights(1, (1U << 20) - 1); + + for (CallInst *CI : KCFICalls) { + // Get the expected hash value. + const uint32_t ExpectedHash = + cast<ConstantInt>(CI->getOperandBundle(LLVMContext::OB_kcfi)->Inputs[0]) + ->getZExtValue(); + + // Drop the KCFI operand bundle. + CallBase *Call = + CallBase::removeOperandBundle(CI, LLVMContext::OB_kcfi, CI); + assert(Call != CI); + Call->copyMetadata(*CI); + CI->replaceAllUsesWith(Call); + CI->eraseFromParent(); + + if (!Call->isIndirectCall()) + continue; + + // Emit a check and trap if the target hash doesn't match. 
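// (Editorial sketch, not part of this patch; the helper name and layout details are illustrative only.)
// At the source level, the check emitted by the loop below is roughly equivalent to comparing the
// 32-bit KCFI type hash stored immediately in front of the callee with the hash carried by the kcfi
// operand bundle, and trapping on mismatch:
#include <cstdint>
#include <cstring>
static inline void kcfiCheckSketch(void (*Callee)(), uint32_t ExpectedHash) {
  uint32_t ActualHash;
  // The type identifier is assumed to be emitted just before the function entry point.
  std::memcpy(&ActualHash,
              reinterpret_cast<const char *>(
                  reinterpret_cast<uintptr_t>(Callee) - sizeof(uint32_t)),
              sizeof(uint32_t));
  if (ActualHash != ExpectedHash) // modelled as a very unlikely branch
    __builtin_trap();             // corresponds to the llvm.trap call emitted below
}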
+ IRBuilder<> Builder(Call); + Value *HashPtr = Builder.CreateConstInBoundsGEP1_32( + Int32Ty, Call->getCalledOperand(), -1); + Value *Test = Builder.CreateICmpNE(Builder.CreateLoad(Int32Ty, HashPtr), + ConstantInt::get(Int32Ty, ExpectedHash)); + Instruction *ThenTerm = + SplitBlockAndInsertIfThen(Test, Call, false, VeryUnlikelyWeights); + Builder.SetInsertPoint(ThenTerm); + Builder.CreateCall(Intrinsic::getDeclaration(&M, Intrinsic::trap)); + ++NumKCFIChecks; + } + + return PreservedAnalyses::none(); +} diff --git a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp index 01e3b2c20218..2a1601fab45f 100644 --- a/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp @@ -171,8 +171,8 @@ public: /// If it is an interesting memory access, populate information /// about the access and return a InterestingMemoryAccess struct. - /// Otherwise return None. - Optional<InterestingMemoryAccess> + /// Otherwise return std::nullopt. + std::optional<InterestingMemoryAccess> isInterestingMemoryAccess(Instruction *I) const; void instrumentMop(Instruction *I, const DataLayout &DL, @@ -204,22 +204,6 @@ private: Value *DynamicShadowOffset = nullptr; }; -class MemProfilerLegacyPass : public FunctionPass { -public: - static char ID; - - explicit MemProfilerLegacyPass() : FunctionPass(ID) { - initializeMemProfilerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { return "MemProfilerFunctionPass"; } - - bool runOnFunction(Function &F) override { - MemProfiler Profiler(*F.getParent()); - return Profiler.instrumentFunction(F); - } -}; - class ModuleMemProfiler { public: ModuleMemProfiler(Module &M) { TargetTriple = Triple(M.getTargetTriple()); } @@ -232,24 +216,6 @@ private: Function *MemProfCtorFunction = nullptr; }; -class ModuleMemProfilerLegacyPass : public ModulePass { -public: - static char ID; - - explicit ModuleMemProfilerLegacyPass() : ModulePass(ID) { - initializeModuleMemProfilerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { return "ModuleMemProfiler"; } - - void getAnalysisUsage(AnalysisUsage &AU) const override {} - - bool runOnModule(Module &M) override { - ModuleMemProfiler MemProfiler(M); - return MemProfiler.instrumentModule(M); - } -}; - } // end anonymous namespace MemProfilerPass::MemProfilerPass() = default; @@ -273,30 +239,6 @@ PreservedAnalyses ModuleMemProfilerPass::run(Module &M, return PreservedAnalyses::all(); } -char MemProfilerLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(MemProfilerLegacyPass, "memprof", - "MemProfiler: profile memory allocations and accesses.", - false, false) -INITIALIZE_PASS_END(MemProfilerLegacyPass, "memprof", - "MemProfiler: profile memory allocations and accesses.", - false, false) - -FunctionPass *llvm::createMemProfilerFunctionPass() { - return new MemProfilerLegacyPass(); -} - -char ModuleMemProfilerLegacyPass::ID = 0; - -INITIALIZE_PASS(ModuleMemProfilerLegacyPass, "memprof-module", - "MemProfiler: profile memory allocations and accesses." 
- "ModulePass", - false, false) - -ModulePass *llvm::createModuleMemProfilerLegacyPassPass() { - return new ModuleMemProfilerLegacyPass(); -} - Value *MemProfiler::memToShadow(Value *Shadow, IRBuilder<> &IRB) { // (Shadow & mask) >> scale Shadow = IRB.CreateAnd(Shadow, Mapping.Mask); @@ -325,35 +267,35 @@ void MemProfiler::instrumentMemIntrinsic(MemIntrinsic *MI) { MI->eraseFromParent(); } -Optional<InterestingMemoryAccess> +std::optional<InterestingMemoryAccess> MemProfiler::isInterestingMemoryAccess(Instruction *I) const { // Do not instrument the load fetching the dynamic shadow address. if (DynamicShadowOffset == I) - return None; + return std::nullopt; InterestingMemoryAccess Access; if (LoadInst *LI = dyn_cast<LoadInst>(I)) { if (!ClInstrumentReads) - return None; + return std::nullopt; Access.IsWrite = false; Access.AccessTy = LI->getType(); Access.Addr = LI->getPointerOperand(); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { if (!ClInstrumentWrites) - return None; + return std::nullopt; Access.IsWrite = true; Access.AccessTy = SI->getValueOperand()->getType(); Access.Addr = SI->getPointerOperand(); } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) { if (!ClInstrumentAtomics) - return None; + return std::nullopt; Access.IsWrite = true; Access.AccessTy = RMW->getValOperand()->getType(); Access.Addr = RMW->getPointerOperand(); } else if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) { if (!ClInstrumentAtomics) - return None; + return std::nullopt; Access.IsWrite = true; Access.AccessTy = XCHG->getCompareOperand()->getType(); Access.Addr = XCHG->getPointerOperand(); @@ -364,14 +306,14 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { unsigned OpOffset = 0; if (F->getIntrinsicID() == Intrinsic::masked_store) { if (!ClInstrumentWrites) - return None; + return std::nullopt; // Masked store has an initial operand for the value. OpOffset = 1; Access.AccessTy = CI->getArgOperand(0)->getType(); Access.IsWrite = true; } else { if (!ClInstrumentReads) - return None; + return std::nullopt; Access.AccessTy = CI->getType(); Access.IsWrite = false; } @@ -383,20 +325,20 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { } if (!Access.Addr) - return None; + return std::nullopt; - // Do not instrument acesses from different address spaces; we cannot deal + // Do not instrument accesses from different address spaces; we cannot deal // with them. Type *PtrTy = cast<PointerType>(Access.Addr->getType()->getScalarType()); if (PtrTy->getPointerAddressSpace() != 0) - return None; + return std::nullopt; // Ignore swifterror addresses. // swifterror memory addresses are mem2reg promoted by instruction // selection. As such they cannot have regular uses like an instrumentation // function and it makes no sense to track them as memory. if (Access.Addr->isSwiftError()) - return None; + return std::nullopt; // Peel off GEPs and BitCasts. auto *Addr = Access.Addr->stripInBoundsOffsets(); @@ -409,12 +351,12 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { auto OF = Triple(I->getModule()->getTargetTriple()).getObjectFormat(); if (SectionName.endswith( getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false))) - return None; + return std::nullopt; } // Do not instrument accesses to LLVM internal variables. 
if (GV->getName().startswith("__llvm")) - return None; + return std::nullopt; } const DataLayout &DL = I->getModule()->getDataLayout(); @@ -643,7 +585,7 @@ bool MemProfiler::instrumentFunction(Function &F) { for (auto *Inst : ToInstrument) { if (ClDebugMin < 0 || ClDebugMax < 0 || (NumInstrumented >= ClDebugMin && NumInstrumented <= ClDebugMax)) { - Optional<InterestingMemoryAccess> Access = + std::optional<InterestingMemoryAccess> Access = isInterestingMemoryAccess(Inst); if (Access) instrumentMop(Inst, F.getParent()->getDataLayout(), *Access); diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 4606bd5de6c3..fe8b8ce0dc86 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -145,13 +145,15 @@ #include "llvm/Transforms/Instrumentation/MemorySanitizer.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" @@ -184,6 +186,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" @@ -202,6 +205,9 @@ using namespace llvm; #define DEBUG_TYPE "msan" +DEBUG_COUNTER(DebugInsertCheck, "msan-insert-check", + "Controls which checks to insert"); + static const unsigned kOriginSize = 4; static const Align kMinOriginAlignment = Align(4); static const Align kShadowTLSAlignment = Align(8); @@ -217,37 +223,48 @@ static const size_t kNumberOfAccessSizes = 4; /// /// Adds a section to MemorySanitizer report that points to the allocation /// (stack or heap) the uninitialized bits came from originally. 
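// (Editorial aside, not part of this patch.) A minimal user-level sketch of what origin tracking
// buys: built with, e.g.,
//   clang -g -fsanitize=memory -fsanitize-memory-track-origins=2
// the use-of-uninitialized-value report for the branch below also names the stack allocation of
// 'x' as the origin, not just the use site.
int originExample() {
  int x;     // origin: this uninitialized stack variable
  if (x > 0) // MSan reports the uninitialized use here
    return 1;
  return 0;
}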
-static cl::opt<int> ClTrackOrigins("msan-track-origins", - cl::desc("Track origins (allocation sites) of poisoned memory"), - cl::Hidden, cl::init(0)); +static cl::opt<int> ClTrackOrigins( + "msan-track-origins", + cl::desc("Track origins (allocation sites) of poisoned memory"), cl::Hidden, + cl::init(0)); static cl::opt<bool> ClKeepGoing("msan-keep-going", - cl::desc("keep going after reporting a UMR"), - cl::Hidden, cl::init(false)); + cl::desc("keep going after reporting a UMR"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> + ClPoisonStack("msan-poison-stack", + cl::desc("poison uninitialized stack variables"), cl::Hidden, + cl::init(true)); -static cl::opt<bool> ClPoisonStack("msan-poison-stack", - cl::desc("poison uninitialized stack variables"), - cl::Hidden, cl::init(true)); +static cl::opt<bool> ClPoisonStackWithCall( + "msan-poison-stack-with-call", + cl::desc("poison uninitialized stack variables with a call"), cl::Hidden, + cl::init(false)); -static cl::opt<bool> ClPoisonStackWithCall("msan-poison-stack-with-call", - cl::desc("poison uninitialized stack variables with a call"), - cl::Hidden, cl::init(false)); +static cl::opt<int> ClPoisonStackPattern( + "msan-poison-stack-pattern", + cl::desc("poison uninitialized stack variables with the given pattern"), + cl::Hidden, cl::init(0xff)); -static cl::opt<int> ClPoisonStackPattern("msan-poison-stack-pattern", - cl::desc("poison uninitialized stack variables with the given pattern"), - cl::Hidden, cl::init(0xff)); +static cl::opt<bool> + ClPrintStackNames("msan-print-stack-names", + cl::desc("Print name of local stack variable"), + cl::Hidden, cl::init(true)); static cl::opt<bool> ClPoisonUndef("msan-poison-undef", - cl::desc("poison undef temps"), - cl::Hidden, cl::init(true)); + cl::desc("poison undef temps"), cl::Hidden, + cl::init(true)); -static cl::opt<bool> ClHandleICmp("msan-handle-icmp", - cl::desc("propagate shadow through ICmpEQ and ICmpNE"), - cl::Hidden, cl::init(true)); +static cl::opt<bool> + ClHandleICmp("msan-handle-icmp", + cl::desc("propagate shadow through ICmpEQ and ICmpNE"), + cl::Hidden, cl::init(true)); -static cl::opt<bool> ClHandleICmpExact("msan-handle-icmp-exact", - cl::desc("exact handling of relational integer ICmp"), - cl::Hidden, cl::init(false)); +static cl::opt<bool> + ClHandleICmpExact("msan-handle-icmp-exact", + cl::desc("exact handling of relational integer ICmp"), + cl::Hidden, cl::init(false)); static cl::opt<bool> ClHandleLifetimeIntrinsics( "msan-handle-lifetime-intrinsics", @@ -277,18 +294,20 @@ static cl::opt<bool> ClHandleAsmConservative( // (e.g. only lower bits of address are garbage, or the access happens // early at program startup where malloc-ed memory is more likely to // be zeroed. As of 2012-08-28 this flag adds 20% slowdown. 
-static cl::opt<bool> ClCheckAccessAddress("msan-check-access-address", - cl::desc("report accesses through a pointer which has poisoned shadow"), - cl::Hidden, cl::init(true)); +static cl::opt<bool> ClCheckAccessAddress( + "msan-check-access-address", + cl::desc("report accesses through a pointer which has poisoned shadow"), + cl::Hidden, cl::init(true)); static cl::opt<bool> ClEagerChecks( "msan-eager-checks", cl::desc("check arguments and return values at function call boundaries"), cl::Hidden, cl::init(false)); -static cl::opt<bool> ClDumpStrictInstructions("msan-dump-strict-instructions", - cl::desc("print out instructions with default strict semantics"), - cl::Hidden, cl::init(false)); +static cl::opt<bool> ClDumpStrictInstructions( + "msan-dump-strict-instructions", + cl::desc("print out instructions with default strict semantics"), + cl::Hidden, cl::init(false)); static cl::opt<int> ClInstrumentationWithCallThreshold( "msan-instrumentation-with-call-threshold", @@ -308,18 +327,17 @@ static cl::opt<bool> cl::desc("Apply no_sanitize to the whole file"), cl::Hidden, cl::init(false)); -// This is an experiment to enable handling of cases where shadow is a non-zero -// compile-time constant. For some unexplainable reason they were silently -// ignored in the instrumentation. -static cl::opt<bool> ClCheckConstantShadow("msan-check-constant-shadow", - cl::desc("Insert checks for constant shadow values"), - cl::Hidden, cl::init(false)); +static cl::opt<bool> + ClCheckConstantShadow("msan-check-constant-shadow", + cl::desc("Insert checks for constant shadow values"), + cl::Hidden, cl::init(true)); // This is off by default because of a bug in gold: // https://sourceware.org/bugzilla/show_bug.cgi?id=19002 -static cl::opt<bool> ClWithComdat("msan-with-comdat", - cl::desc("Place MSan constructors in comdat sections"), - cl::Hidden, cl::init(false)); +static cl::opt<bool> + ClWithComdat("msan-with-comdat", + cl::desc("Place MSan constructors in comdat sections"), + cl::Hidden, cl::init(false)); // These options allow to specify custom memory map parameters // See MemoryMapParams for details. 
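// (Editorial aside, not part of this patch.) Sketch of how the MemoryMapParams values reorganized
// below are consumed when computing shadow and origin addresses (cf. getShadowOriginPtrUserspace
// in this file). With the updated Linux x86_64 parameters (AndMask = 0, XorMask = 0x500000000000,
// ShadowBase = 0), the shadow of an address is simply Addr ^ 0x500000000000.
#include <cstdint>
struct MemoryMapParamsSketch {
  uint64_t AndMask, XorMask, ShadowBase, OriginBase;
};
static uint64_t offsetFor(uint64_t Addr, const MemoryMapParamsSketch &P) {
  uint64_t Offset = Addr;
  if (P.AndMask)
    Offset &= ~P.AndMask; // strip the bits selected by AndMask
  if (P.XorMask)
    Offset ^= P.XorMask;  // flip into the shadow region
  return Offset;
}
static uint64_t shadowFor(uint64_t Addr, const MemoryMapParamsSketch &P) {
  return P.ShadowBase + offsetFor(Addr, P);
}
static uint64_t originFor(uint64_t Addr, const MemoryMapParamsSketch &P) {
  // Origins live at OriginBase + offset, aligned down to the 4-byte origin cell.
  return (P.OriginBase + offsetFor(Addr, P)) & ~uint64_t(3);
}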
@@ -339,6 +357,12 @@ static cl::opt<uint64_t> ClOriginBase("msan-origin-base", cl::desc("Define custom MSan OriginBase"), cl::Hidden, cl::init(0)); +static cl::opt<int> + ClDisambiguateWarning("msan-disambiguate-warning-threshold", + cl::desc("Define threshold for number of checks per " + "debug location to force origin update."), + cl::Hidden, cl::init(3)); + const char kMsanModuleCtorName[] = "msan.module_ctor"; const char kMsanInitName[] = "__msan_init"; @@ -364,41 +388,34 @@ struct PlatformMemoryMapParams { // i386 Linux static const MemoryMapParams Linux_I386_MemoryMapParams = { - 0x000080000000, // AndMask - 0, // XorMask (not used) - 0, // ShadowBase (not used) - 0x000040000000, // OriginBase + 0x000080000000, // AndMask + 0, // XorMask (not used) + 0, // ShadowBase (not used) + 0x000040000000, // OriginBase }; // x86_64 Linux static const MemoryMapParams Linux_X86_64_MemoryMapParams = { -#ifdef MSAN_LINUX_X86_64_OLD_MAPPING - 0x400000000000, // AndMask - 0, // XorMask (not used) - 0, // ShadowBase (not used) - 0x200000000000, // OriginBase -#else - 0, // AndMask (not used) - 0x500000000000, // XorMask - 0, // ShadowBase (not used) - 0x100000000000, // OriginBase -#endif + 0, // AndMask (not used) + 0x500000000000, // XorMask + 0, // ShadowBase (not used) + 0x100000000000, // OriginBase }; // mips64 Linux static const MemoryMapParams Linux_MIPS64_MemoryMapParams = { - 0, // AndMask (not used) - 0x008000000000, // XorMask - 0, // ShadowBase (not used) - 0x002000000000, // OriginBase + 0, // AndMask (not used) + 0x008000000000, // XorMask + 0, // ShadowBase (not used) + 0x002000000000, // OriginBase }; // ppc64 Linux static const MemoryMapParams Linux_PowerPC64_MemoryMapParams = { - 0xE00000000000, // AndMask - 0x100000000000, // XorMask - 0x080000000000, // ShadowBase - 0x1C0000000000, // OriginBase + 0xE00000000000, // AndMask + 0x100000000000, // XorMask + 0x080000000000, // ShadowBase + 0x1C0000000000, // OriginBase }; // s390x Linux @@ -411,57 +428,57 @@ static const MemoryMapParams Linux_S390X_MemoryMapParams = { // aarch64 Linux static const MemoryMapParams Linux_AArch64_MemoryMapParams = { - 0, // AndMask (not used) - 0x06000000000, // XorMask - 0, // ShadowBase (not used) - 0x01000000000, // OriginBase + 0, // AndMask (not used) + 0x0B00000000000, // XorMask + 0, // ShadowBase (not used) + 0x0200000000000, // OriginBase }; // aarch64 FreeBSD static const MemoryMapParams FreeBSD_AArch64_MemoryMapParams = { - 0x1800000000000, // AndMask - 0x0400000000000, // XorMask - 0x0200000000000, // ShadowBase - 0x0700000000000, // OriginBase + 0x1800000000000, // AndMask + 0x0400000000000, // XorMask + 0x0200000000000, // ShadowBase + 0x0700000000000, // OriginBase }; // i386 FreeBSD static const MemoryMapParams FreeBSD_I386_MemoryMapParams = { - 0x000180000000, // AndMask - 0x000040000000, // XorMask - 0x000020000000, // ShadowBase - 0x000700000000, // OriginBase + 0x000180000000, // AndMask + 0x000040000000, // XorMask + 0x000020000000, // ShadowBase + 0x000700000000, // OriginBase }; // x86_64 FreeBSD static const MemoryMapParams FreeBSD_X86_64_MemoryMapParams = { - 0xc00000000000, // AndMask - 0x200000000000, // XorMask - 0x100000000000, // ShadowBase - 0x380000000000, // OriginBase + 0xc00000000000, // AndMask + 0x200000000000, // XorMask + 0x100000000000, // ShadowBase + 0x380000000000, // OriginBase }; // x86_64 NetBSD static const MemoryMapParams NetBSD_X86_64_MemoryMapParams = { - 0, // AndMask - 0x500000000000, // XorMask - 0, // ShadowBase - 0x100000000000, // OriginBase + 0, 
// AndMask + 0x500000000000, // XorMask + 0, // ShadowBase + 0x100000000000, // OriginBase }; static const PlatformMemoryMapParams Linux_X86_MemoryMapParams = { - &Linux_I386_MemoryMapParams, - &Linux_X86_64_MemoryMapParams, + &Linux_I386_MemoryMapParams, + &Linux_X86_64_MemoryMapParams, }; static const PlatformMemoryMapParams Linux_MIPS_MemoryMapParams = { - nullptr, - &Linux_MIPS64_MemoryMapParams, + nullptr, + &Linux_MIPS64_MemoryMapParams, }; static const PlatformMemoryMapParams Linux_PowerPC_MemoryMapParams = { - nullptr, - &Linux_PowerPC64_MemoryMapParams, + nullptr, + &Linux_PowerPC64_MemoryMapParams, }; static const PlatformMemoryMapParams Linux_S390_MemoryMapParams = { @@ -470,23 +487,23 @@ static const PlatformMemoryMapParams Linux_S390_MemoryMapParams = { }; static const PlatformMemoryMapParams Linux_ARM_MemoryMapParams = { - nullptr, - &Linux_AArch64_MemoryMapParams, + nullptr, + &Linux_AArch64_MemoryMapParams, }; static const PlatformMemoryMapParams FreeBSD_ARM_MemoryMapParams = { - nullptr, - &FreeBSD_AArch64_MemoryMapParams, + nullptr, + &FreeBSD_AArch64_MemoryMapParams, }; static const PlatformMemoryMapParams FreeBSD_X86_MemoryMapParams = { - &FreeBSD_I386_MemoryMapParams, - &FreeBSD_X86_64_MemoryMapParams, + &FreeBSD_I386_MemoryMapParams, + &FreeBSD_X86_64_MemoryMapParams, }; static const PlatformMemoryMapParams NetBSD_X86_MemoryMapParams = { - nullptr, - &NetBSD_X86_64_MemoryMapParams, + nullptr, + &NetBSD_X86_64_MemoryMapParams, }; namespace { @@ -522,9 +539,9 @@ private: friend struct VarArgSystemZHelper; void initializeModule(Module &M); - void initializeCallbacks(Module &M); - void createKernelApi(Module &M); - void createUserspaceApi(Module &M); + void initializeCallbacks(Module &M, const TargetLibraryInfo &TLI); + void createKernelApi(Module &M, const TargetLibraryInfo &TLI); + void createUserspaceApi(Module &M, const TargetLibraryInfo &TLI); /// True if we're compiling the Linux kernel. bool CompileKernel; @@ -579,7 +596,9 @@ private: /// Run-time helper that generates a new origin value for a stack /// allocation. - FunctionCallee MsanSetAllocaOrigin4Fn; + FunctionCallee MsanSetAllocaOriginWithDescriptionFn; + // No description version + FunctionCallee MsanSetAllocaOriginNoDescriptionFn; /// Run-time helper that poisons stack on function entry. 
FunctionCallee MsanPoisonStackFn; @@ -655,20 +674,32 @@ MemorySanitizerOptions::MemorySanitizerOptions(int TO, bool R, bool K, Recover(getOptOrDefault(ClKeepGoing, Kernel || R)), EagerChecks(getOptOrDefault(ClEagerChecks, EagerChecks)) {} -PreservedAnalyses MemorySanitizerPass::run(Function &F, - FunctionAnalysisManager &FAM) { - MemorySanitizer Msan(*F.getParent(), Options); - if (Msan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F))) - return PreservedAnalyses::none(); - return PreservedAnalyses::all(); -} +PreservedAnalyses MemorySanitizerPass::run(Module &M, + ModuleAnalysisManager &AM) { + bool Modified = false; + if (!Options.Kernel) { + insertModuleCtor(M); + Modified = true; + } -PreservedAnalyses -ModuleMemorySanitizerPass::run(Module &M, ModuleAnalysisManager &AM) { - if (Options.Kernel) + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + for (Function &F : M) { + if (F.empty()) + continue; + MemorySanitizer Msan(*F.getParent(), Options); + Modified |= + Msan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F)); + } + + if (!Modified) return PreservedAnalyses::all(); - insertModuleCtor(M); - return PreservedAnalyses::none(); + + PreservedAnalyses PA = PreservedAnalyses::none(); + // GlobalsAA is considered stateless and does not get invalidated unless + // explicitly invalidated; PreservedAnalyses::none() is not enough. Sanitizers + // make changes that require GlobalsAA to be invalidated. + PA.abandon<GlobalsAA>(); + return PA; } void MemorySanitizerPass::printPipeline( @@ -691,15 +722,15 @@ void MemorySanitizerPass::printPipeline( /// Creates a writable global for Str so that we can pass it to the /// run-time lib. Runtime uses first 4 bytes of the string to store the /// frame ID, so the string needs to be mutable. -static GlobalVariable *createPrivateNonConstGlobalForString(Module &M, - StringRef Str) { +static GlobalVariable *createPrivateConstGlobalForString(Module &M, + StringRef Str) { Constant *StrConst = ConstantDataArray::getString(M.getContext(), Str); - return new GlobalVariable(M, StrConst->getType(), /*isConstant=*/false, + return new GlobalVariable(M, StrConst->getType(), /*isConstant=*/true, GlobalValue::PrivateLinkage, StrConst, ""); } /// Create KMSAN API callbacks. -void MemorySanitizer::createKernelApi(Module &M) { +void MemorySanitizer::createKernelApi(Module &M, const TargetLibraryInfo &TLI) { IRBuilder<> IRB(*C); // These will be initialized in insertKmsanPrologue(). @@ -711,8 +742,10 @@ void MemorySanitizer::createKernelApi(Module &M) { VAArgOriginTLS = nullptr; VAArgOverflowSizeTLS = nullptr; - WarningFn = M.getOrInsertFunction("__msan_warning", IRB.getVoidTy(), - IRB.getInt32Ty()); + WarningFn = M.getOrInsertFunction("__msan_warning", + TLI.getAttrList(C, {0}, /*Signed=*/false), + IRB.getVoidTy(), IRB.getInt32Ty()); + // Requests the per-task context state (kmsan_context_state*) from the // runtime library. MsanContextStateTy = StructType::get( @@ -763,16 +796,23 @@ static Constant *getOrInsertGlobal(Module &M, StringRef Name, Type *Ty) { } /// Insert declarations for userspace-specific functions and globals. -void MemorySanitizer::createUserspaceApi(Module &M) { +void MemorySanitizer::createUserspaceApi(Module &M, const TargetLibraryInfo &TLI) { IRBuilder<> IRB(*C); // Create the callback. // FIXME: this function should have "Cold" calling conv, // which is not yet implemented. - StringRef WarningFnName = Recover ? 
"__msan_warning_with_origin" - : "__msan_warning_with_origin_noreturn"; - WarningFn = - M.getOrInsertFunction(WarningFnName, IRB.getVoidTy(), IRB.getInt32Ty()); + if (TrackOrigins) { + StringRef WarningFnName = Recover ? "__msan_warning_with_origin" + : "__msan_warning_with_origin_noreturn"; + WarningFn = M.getOrInsertFunction(WarningFnName, + TLI.getAttrList(C, {0}, /*Signed=*/false), + IRB.getVoidTy(), IRB.getInt32Ty()); + } else { + StringRef WarningFnName = + Recover ? "__msan_warning" : "__msan_warning_noreturn"; + WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy()); + } // Create the global TLS variables. RetvalTLS = @@ -804,37 +844,29 @@ void MemorySanitizer::createUserspaceApi(Module &M) { AccessSizeIndex++) { unsigned AccessSize = 1 << AccessSizeIndex; std::string FunctionName = "__msan_maybe_warning_" + itostr(AccessSize); - SmallVector<std::pair<unsigned, Attribute>, 2> MaybeWarningFnAttrs; - MaybeWarningFnAttrs.push_back(std::make_pair( - AttributeList::FirstArgIndex, Attribute::get(*C, Attribute::ZExt))); - MaybeWarningFnAttrs.push_back(std::make_pair( - AttributeList::FirstArgIndex + 1, Attribute::get(*C, Attribute::ZExt))); MaybeWarningFn[AccessSizeIndex] = M.getOrInsertFunction( - FunctionName, AttributeList::get(*C, MaybeWarningFnAttrs), + FunctionName, TLI.getAttrList(C, {0, 1}, /*Signed=*/false), IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), IRB.getInt32Ty()); FunctionName = "__msan_maybe_store_origin_" + itostr(AccessSize); - SmallVector<std::pair<unsigned, Attribute>, 2> MaybeStoreOriginFnAttrs; - MaybeStoreOriginFnAttrs.push_back(std::make_pair( - AttributeList::FirstArgIndex, Attribute::get(*C, Attribute::ZExt))); - MaybeStoreOriginFnAttrs.push_back(std::make_pair( - AttributeList::FirstArgIndex + 2, Attribute::get(*C, Attribute::ZExt))); MaybeStoreOriginFn[AccessSizeIndex] = M.getOrInsertFunction( - FunctionName, AttributeList::get(*C, MaybeStoreOriginFnAttrs), + FunctionName, TLI.getAttrList(C, {0, 2}, /*Signed=*/false), IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), IRB.getInt8PtrTy(), IRB.getInt32Ty()); } - MsanSetAllocaOrigin4Fn = M.getOrInsertFunction( - "__msan_set_alloca_origin4", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy, - IRB.getInt8PtrTy(), IntptrTy); - MsanPoisonStackFn = - M.getOrInsertFunction("__msan_poison_stack", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy); + MsanSetAllocaOriginWithDescriptionFn = M.getOrInsertFunction( + "__msan_set_alloca_origin_with_descr", IRB.getVoidTy(), + IRB.getInt8PtrTy(), IntptrTy, IRB.getInt8PtrTy(), IRB.getInt8PtrTy()); + MsanSetAllocaOriginNoDescriptionFn = M.getOrInsertFunction( + "__msan_set_alloca_origin_no_descr", IRB.getVoidTy(), IRB.getInt8PtrTy(), + IntptrTy, IRB.getInt8PtrTy()); + MsanPoisonStackFn = M.getOrInsertFunction( + "__msan_poison_stack", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy); } /// Insert extern declaration of runtime-provided functions and globals. -void MemorySanitizer::initializeCallbacks(Module &M) { +void MemorySanitizer::initializeCallbacks(Module &M, const TargetLibraryInfo &TLI) { // Only do this once. if (CallbacksInitialized) return; @@ -843,28 +875,30 @@ void MemorySanitizer::initializeCallbacks(Module &M) { // Initialize callbacks that are common for kernel and userspace // instrumentation. 
MsanChainOriginFn = M.getOrInsertFunction( - "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty()); - MsanSetOriginFn = - M.getOrInsertFunction("__msan_set_origin", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy, IRB.getInt32Ty()); - MemmoveFn = M.getOrInsertFunction( - "__msan_memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy); - MemcpyFn = M.getOrInsertFunction( - "__msan_memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IntptrTy); + "__msan_chain_origin", + TLI.getAttrList(C, {0}, /*Signed=*/false, /*Ret=*/true), IRB.getInt32Ty(), + IRB.getInt32Ty()); + MsanSetOriginFn = M.getOrInsertFunction( + "__msan_set_origin", TLI.getAttrList(C, {2}, /*Signed=*/false), + IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy, IRB.getInt32Ty()); + MemmoveFn = + M.getOrInsertFunction("__msan_memmove", IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); + MemcpyFn = + M.getOrInsertFunction("__msan_memcpy", IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); MemsetFn = M.getOrInsertFunction( - "__msan_memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), - IntptrTy); + "__msan_memset", TLI.getAttrList(C, {1}, /*Signed=*/true), + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); MsanInstrumentAsmStoreFn = M.getOrInsertFunction("__msan_instrument_asm_store", IRB.getVoidTy(), PointerType::get(IRB.getInt8Ty(), 0), IntptrTy); if (CompileKernel) { - createKernelApi(M); + createKernelApi(M, TLI); } else { - createUserspaceApi(M); + createUserspaceApi(M, TLI); } CallbacksInitialized = true; } @@ -905,59 +939,59 @@ void MemorySanitizer::initializeModule(Module &M) { } else { Triple TargetTriple(M.getTargetTriple()); switch (TargetTriple.getOS()) { - case Triple::FreeBSD: - switch (TargetTriple.getArch()) { - case Triple::aarch64: - MapParams = FreeBSD_ARM_MemoryMapParams.bits64; - break; - case Triple::x86_64: - MapParams = FreeBSD_X86_MemoryMapParams.bits64; - break; - case Triple::x86: - MapParams = FreeBSD_X86_MemoryMapParams.bits32; - break; - default: - report_fatal_error("unsupported architecture"); - } + case Triple::FreeBSD: + switch (TargetTriple.getArch()) { + case Triple::aarch64: + MapParams = FreeBSD_ARM_MemoryMapParams.bits64; break; - case Triple::NetBSD: - switch (TargetTriple.getArch()) { - case Triple::x86_64: - MapParams = NetBSD_X86_MemoryMapParams.bits64; - break; - default: - report_fatal_error("unsupported architecture"); - } + case Triple::x86_64: + MapParams = FreeBSD_X86_MemoryMapParams.bits64; break; - case Triple::Linux: - switch (TargetTriple.getArch()) { - case Triple::x86_64: - MapParams = Linux_X86_MemoryMapParams.bits64; - break; - case Triple::x86: - MapParams = Linux_X86_MemoryMapParams.bits32; - break; - case Triple::mips64: - case Triple::mips64el: - MapParams = Linux_MIPS_MemoryMapParams.bits64; - break; - case Triple::ppc64: - case Triple::ppc64le: - MapParams = Linux_PowerPC_MemoryMapParams.bits64; - break; - case Triple::systemz: - MapParams = Linux_S390_MemoryMapParams.bits64; - break; - case Triple::aarch64: - case Triple::aarch64_be: - MapParams = Linux_ARM_MemoryMapParams.bits64; - break; - default: - report_fatal_error("unsupported architecture"); - } + case Triple::x86: + MapParams = FreeBSD_X86_MemoryMapParams.bits32; + break; + default: + report_fatal_error("unsupported architecture"); + } + break; + case Triple::NetBSD: + switch (TargetTriple.getArch()) { + case Triple::x86_64: + MapParams = NetBSD_X86_MemoryMapParams.bits64; break; 
default: - report_fatal_error("unsupported operating system"); + report_fatal_error("unsupported architecture"); + } + break; + case Triple::Linux: + switch (TargetTriple.getArch()) { + case Triple::x86_64: + MapParams = Linux_X86_MemoryMapParams.bits64; + break; + case Triple::x86: + MapParams = Linux_X86_MemoryMapParams.bits32; + break; + case Triple::mips64: + case Triple::mips64el: + MapParams = Linux_MIPS_MemoryMapParams.bits64; + break; + case Triple::ppc64: + case Triple::ppc64le: + MapParams = Linux_PowerPC_MemoryMapParams.bits64; + break; + case Triple::systemz: + MapParams = Linux_S390_MemoryMapParams.bits64; + break; + case Triple::aarch64: + case Triple::aarch64_be: + MapParams = Linux_ARM_MemoryMapParams.bits64; + break; + default: + report_fatal_error("unsupported architecture"); + } + break; + default: + report_fatal_error("unsupported operating system"); } } @@ -983,7 +1017,7 @@ void MemorySanitizer::initializeModule(Module &M) { GlobalValue::WeakODRLinkage, IRB.getInt32(Recover), "__msan_keep_going"); }); -} + } } namespace { @@ -1023,12 +1057,22 @@ static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, MemorySanitizerVisitor &Visitor); static unsigned TypeSizeToSizeIndex(unsigned TypeSize) { - if (TypeSize <= 8) return 0; + if (TypeSize <= 8) + return 0; return Log2_32_Ceil((TypeSize + 7) / 8); } namespace { +/// Helper class to attach debug information of the given instruction onto new +/// instructions inserted after. +class NextNodeIRBuilder : public IRBuilder<> { +public: + explicit NextNodeIRBuilder(Instruction *IP) : IRBuilder<>(IP->getNextNode()) { + SetCurrentDebugLocation(IP->getDebugLoc()); + } +}; + /// This class does all the work for a given function. Store and Load /// instructions store and load corresponding shadow and origin /// values. Most instructions propagate shadow from arguments to their @@ -1039,7 +1083,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Function &F; MemorySanitizer &MS; SmallVector<PHINode *, 16> ShadowPHINodes, OriginPHINodes; - ValueMap<Value*, Value*> ShadowMap, OriginMap; + ValueMap<Value *, Value *> ShadowMap, OriginMap; std::unique_ptr<VarArgHelper> VAHelper; const TargetLibraryInfo *TLI; Instruction *FnPrologueEnd; @@ -1057,13 +1101,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Instruction *OrigIns; ShadowOriginAndInsertPoint(Value *S, Value *O, Instruction *I) - : Shadow(S), Origin(O), OrigIns(I) {} + : Shadow(S), Origin(O), OrigIns(I) {} }; SmallVector<ShadowOriginAndInsertPoint, 16> InstrumentationList; + DenseMap<const DILocation *, int> LazyWarningDebugLocationCount; bool InstrumentLifetimeStart = ClHandleLifetimeIntrinsics; - SmallSet<AllocaInst *, 16> AllocaSet; + SmallSetVector<AllocaInst *, 16> AllocaSet; SmallVector<std::pair<IntrinsicInst *, AllocaInst *>, 16> LifetimeStartList; SmallVector<StoreInst *, 16> StoreList; + int64_t SplittableBlocksCount = 0; MemorySanitizerVisitor(Function &F, MemorySanitizer &MS, const TargetLibraryInfo &TLI) @@ -1081,7 +1127,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // It's easier to remove unreachable blocks than deal with missing shadow. 
removeUnreachableBlocks(F); - MS.initializeCallbacks(*F.getParent()); + MS.initializeCallbacks(*F.getParent(), TLI); FnPrologueEnd = IRBuilder<>(F.getEntryBlock().getFirstNonPHI()) .CreateIntrinsic(Intrinsic::donothing, {}, {}); @@ -1095,20 +1141,36 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { << F.getName() << "'\n"); } + bool instrumentWithCalls(Value *V) { + // Constants likely will be eliminated by follow-up passes. + if (isa<Constant>(V)) + return false; + + ++SplittableBlocksCount; + return ClInstrumentationWithCallThreshold >= 0 && + SplittableBlocksCount > ClInstrumentationWithCallThreshold; + } + bool isInPrologue(Instruction &I) { return I.getParent() == FnPrologueEnd->getParent() && (&I == FnPrologueEnd || I.comesBefore(FnPrologueEnd)); } + // Creates a new origin and records the stack trace. In general we can call + // this function for any origin manipulation we like. However it will cost + // runtime resources. So use this wisely only if it can provide additional + // information helpful to a user. Value *updateOrigin(Value *V, IRBuilder<> &IRB) { - if (MS.TrackOrigins <= 1) return V; + if (MS.TrackOrigins <= 1) + return V; return IRB.CreateCall(MS.MsanChainOriginFn, V); } Value *originToIntptr(IRBuilder<> &IRB, Value *Origin) { const DataLayout &DL = F.getParent()->getDataLayout(); unsigned IntptrSize = DL.getTypeStoreSize(MS.IntptrTy); - if (IntptrSize == kOriginSize) return Origin; + if (IntptrSize == kOriginSize) + return Origin; assert(IntptrSize == kOriginSize * 2); Origin = IRB.CreateIntCast(Origin, MS.IntptrTy, /* isSigned */ false); return IRB.CreateOr(Origin, IRB.CreateShl(Origin, kOriginSize * 8)); @@ -1147,21 +1209,30 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } void storeOrigin(IRBuilder<> &IRB, Value *Addr, Value *Shadow, Value *Origin, - Value *OriginPtr, Align Alignment, bool AsCall) { + Value *OriginPtr, Align Alignment) { const DataLayout &DL = F.getParent()->getDataLayout(); const Align OriginAlignment = std::max(kMinOriginAlignment, Alignment); unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType()); Value *ConvertedShadow = convertShadowToScalar(Shadow, IRB); if (auto *ConstantShadow = dyn_cast<Constant>(ConvertedShadow)) { - if (ClCheckConstantShadow && !ConstantShadow->isZeroValue()) + if (!ClCheckConstantShadow || ConstantShadow->isZeroValue()) { + // Origin is not needed: value is initialized or const shadow is + // ignored. + return; + } + if (llvm::isKnownNonZero(ConvertedShadow, DL)) { + // Copy origin as the value is definitely uninitialized. paintOrigin(IRB, updateOrigin(Origin, IRB), OriginPtr, StoreSize, OriginAlignment); - return; + return; + } + // Fallback to runtime check, which still can be optimized out later. 
} unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); - if (AsCall && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { + if (instrumentWithCalls(ConvertedShadow) && + SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { FunctionCallee Fn = MS.MaybeStoreOriginFn[SizeIndex]; Value *ConvertedShadow2 = IRB.CreateZExt(ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex))); @@ -1180,7 +1251,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } - void materializeStores(bool InstrumentWithCalls) { + void materializeStores() { for (StoreInst *SI : StoreList) { IRBuilder<> IRB(SI); Value *Val = SI->getValueOperand(); @@ -1202,40 +1273,62 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins && !SI->isAtomic()) storeOrigin(IRB, Addr, Shadow, getOrigin(Val), OriginPtr, - OriginAlignment, InstrumentWithCalls); + OriginAlignment); } } + // Returns true if Debug Location corresponds to multiple warnings. + bool shouldDisambiguateWarningLocation(const DebugLoc &DebugLoc) { + if (MS.TrackOrigins < 2) + return false; + + if (LazyWarningDebugLocationCount.empty()) + for (const auto &I : InstrumentationList) + ++LazyWarningDebugLocationCount[I.OrigIns->getDebugLoc()]; + + return LazyWarningDebugLocationCount[DebugLoc] >= ClDisambiguateWarning; + } + /// Helper function to insert a warning at IRB's current insert point. void insertWarningFn(IRBuilder<> &IRB, Value *Origin) { if (!Origin) Origin = (Value *)IRB.getInt32(0); assert(Origin->getType()->isIntegerTy()); - IRB.CreateCall(MS.WarningFn, Origin)->setCannotMerge(); - // FIXME: Insert UnreachableInst if !MS.Recover? - // This may invalidate some of the following checks and needs to be done - // at the very end. - } - - void materializeOneCheck(Instruction *OrigIns, Value *Shadow, Value *Origin, - bool AsCall) { - IRBuilder<> IRB(OrigIns); - LLVM_DEBUG(dbgs() << " SHAD0 : " << *Shadow << "\n"); - Value *ConvertedShadow = convertShadowToScalar(Shadow, IRB); - LLVM_DEBUG(dbgs() << " SHAD1 : " << *ConvertedShadow << "\n"); - if (auto *ConstantShadow = dyn_cast<Constant>(ConvertedShadow)) { - if (ClCheckConstantShadow && !ConstantShadow->isZeroValue()) { - insertWarningFn(IRB, Origin); + if (shouldDisambiguateWarningLocation(IRB.getCurrentDebugLocation())) { + // Try to create additional origin with debug info of the last origin + // instruction. It may provide additional information to the user. + if (Instruction *OI = dyn_cast_or_null<Instruction>(Origin)) { + assert(MS.TrackOrigins); + auto NewDebugLoc = OI->getDebugLoc(); + // Origin update with missing or the same debug location provides no + // additional value. + if (NewDebugLoc && NewDebugLoc != IRB.getCurrentDebugLocation()) { + // Insert update just before the check, so we call runtime only just + // before the report. + IRBuilder<> IRBOrigin(&*IRB.GetInsertPoint()); + IRBOrigin.SetCurrentDebugLocation(NewDebugLoc); + Origin = updateOrigin(Origin, IRBOrigin); + } } - return; } - const DataLayout &DL = OrigIns->getModule()->getDataLayout(); + if (MS.CompileKernel || MS.TrackOrigins) + IRB.CreateCall(MS.WarningFn, Origin)->setCannotMerge(); + else + IRB.CreateCall(MS.WarningFn)->setCannotMerge(); + // FIXME: Insert UnreachableInst if !MS.Recover? + // This may invalidate some of the following checks and needs to be done + // at the very end.
+ } + void materializeOneCheck(IRBuilder<> &IRB, Value *ConvertedShadow, + Value *Origin) { + const DataLayout &DL = F.getParent()->getDataLayout(); unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); - if (AsCall && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { + if (instrumentWithCalls(ConvertedShadow) && + SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { FunctionCallee Fn = MS.MaybeWarningFn[SizeIndex]; Value *ConvertedShadow2 = IRB.CreateZExt(ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex))); @@ -1247,7 +1340,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } else { Value *Cmp = convertToBool(ConvertedShadow, IRB, "_mscmp"); Instruction *CheckTerm = SplitBlockAndInsertIfThen( - Cmp, OrigIns, + Cmp, &*IRB.GetInsertPoint(), /* Unreachable */ !MS.Recover, MS.ColdCallWeights); IRB.SetInsertPoint(CheckTerm); @@ -1256,13 +1349,77 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } - void materializeChecks(bool InstrumentWithCalls) { - for (const auto &ShadowData : InstrumentationList) { - Instruction *OrigIns = ShadowData.OrigIns; - Value *Shadow = ShadowData.Shadow; - Value *Origin = ShadowData.Origin; - materializeOneCheck(OrigIns, Shadow, Origin, InstrumentWithCalls); + void materializeInstructionChecks( + ArrayRef<ShadowOriginAndInsertPoint> InstructionChecks) { + const DataLayout &DL = F.getParent()->getDataLayout(); + // Disable combining in some cases. TrackOrigins checks each shadow to pick + // correct origin. + bool Combine = !MS.TrackOrigins; + Instruction *Instruction = InstructionChecks.front().OrigIns; + Value *Shadow = nullptr; + for (const auto &ShadowData : InstructionChecks) { + assert(ShadowData.OrigIns == Instruction); + IRBuilder<> IRB(Instruction); + + Value *ConvertedShadow = ShadowData.Shadow; + + if (auto *ConstantShadow = dyn_cast<Constant>(ConvertedShadow)) { + if (!ClCheckConstantShadow || ConstantShadow->isZeroValue()) { + // Skip, value is initialized or const shadow is ignored. + continue; + } + if (llvm::isKnownNonZero(ConvertedShadow, DL)) { + // Report as the value is definitely uninitialized. + insertWarningFn(IRB, ShadowData.Origin); + if (!MS.Recover) + return; // Always fail and stop here, not need to check the rest. + // Skip entire instruction, + continue; + } + // Fallback to runtime check, which still can be optimized out later. + } + + if (!Combine) { + materializeOneCheck(IRB, ConvertedShadow, ShadowData.Origin); + continue; + } + + if (!Shadow) { + Shadow = ConvertedShadow; + continue; + } + + Shadow = convertToBool(Shadow, IRB, "_mscmp"); + ConvertedShadow = convertToBool(ConvertedShadow, IRB, "_mscmp"); + Shadow = IRB.CreateOr(Shadow, ConvertedShadow, "_msor"); + } + + if (Shadow) { + assert(Combine); + IRBuilder<> IRB(Instruction); + materializeOneCheck(IRB, Shadow, nullptr); + } + } + + void materializeChecks() { + llvm::stable_sort(InstrumentationList, + [](const ShadowOriginAndInsertPoint &L, + const ShadowOriginAndInsertPoint &R) { + return L.OrigIns < R.OrigIns; + }); + + for (auto I = InstrumentationList.begin(); + I != InstrumentationList.end();) { + auto J = + std::find_if(I + 1, InstrumentationList.end(), + [L = I->OrigIns](const ShadowOriginAndInsertPoint &R) { + return L != R.OrigIns; + }); + // Process all checks of instruction at once. 
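materializeChecks() now stable-sorts the pending checks by the instruction they protect and hands each run to materializeInstructionChecks(), which OR-folds the shadows into a single test whenever origins are not tracked (picking the correct origin requires keeping the shadows separate). A compact model of that grouping, with instructions reduced to integer ids, shadows to booleans, and the constant-shadow fast paths omitted; emitCheck() stands in for materializeOneCheck():

#include <algorithm>
#include <vector>

struct PendingCheck { int InstrId; bool ShadowNonZero; };

template <typename EmitFn>
void materializeGroupedChecks(std::vector<PendingCheck> &Checks, bool Combine,
                              EmitFn emitCheck) {
  std::stable_sort(Checks.begin(), Checks.end(),
                   [](const PendingCheck &L, const PendingCheck &R) {
                     return L.InstrId < R.InstrId;
                   });
  for (auto I = Checks.begin(); I != Checks.end();) {
    auto J = std::find_if(I + 1, Checks.end(),
                          [K = I->InstrId](const PendingCheck &C) {
                            return C.InstrId != K;
                          });
    if (!Combine) {
      for (auto It = I; It != J; ++It)
        emitCheck(It->InstrId, It->ShadowNonZero);  // one check per shadow
    } else {
      bool Folded = false;
      for (auto It = I; It != J; ++It)
        Folded |= It->ShadowNonZero;                // fold shadows with OR
      emitCheck(I->InstrId, Folded);                // single combined check
    }
    I = J;
  }
}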
+ materializeInstructionChecks(ArrayRef<ShadowOriginAndInsertPoint>(I, J)); + I = J; } + LLVM_DEBUG(dbgs() << "DONE:\n" << F); } @@ -1303,7 +1460,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { size_t NumValues = PN->getNumIncomingValues(); for (size_t v = 0; v < NumValues; v++) { PNS->addIncoming(getShadow(PN, v), PN->getIncomingBlock(v)); - if (PNO) PNO->addIncoming(getOrigin(PN, v), PN->getIncomingBlock(v)); + if (PNO) + PNO->addIncoming(getOrigin(PN, v), PN->getIncomingBlock(v)); } } @@ -1314,7 +1472,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (InstrumentLifetimeStart) { for (auto Item : LifetimeStartList) { instrumentAlloca(*Item.second, Item.first); - AllocaSet.erase(Item.second); + AllocaSet.remove(Item.second); } } // Poison the allocas for which we didn't instrument the corresponding @@ -1322,24 +1480,18 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { for (AllocaInst *AI : AllocaSet) instrumentAlloca(*AI); - bool InstrumentWithCalls = ClInstrumentationWithCallThreshold >= 0 && - InstrumentationList.size() + StoreList.size() > - (unsigned)ClInstrumentationWithCallThreshold; - // Insert shadow value checks. - materializeChecks(InstrumentWithCalls); + materializeChecks(); // Delayed instrumentation of StoreInst. // This may not add new address checks. - materializeStores(InstrumentWithCalls); + materializeStores(); return true; } /// Compute the shadow type that corresponds to a given Value. - Type *getShadowTy(Value *V) { - return getShadowTy(V->getType()); - } + Type *getShadowTy(Value *V) { return getShadowTy(V->getType()); } /// Compute the shadow type that corresponds to a given Type. Type *getShadowTy(Type *OrigTy) { @@ -1361,7 +1513,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { AT->getNumElements()); } if (StructType *ST = dyn_cast<StructType>(OrigTy)) { - SmallVector<Type*, 4> Elements; + SmallVector<Type *, 4> Elements; for (unsigned i = 0, n = ST->getNumElements(); i < n; i++) Elements.push_back(getShadowTy(ST->getElementType(i))); StructType *Res = StructType::get(*MS.C, Elements, ST->isPacked()); @@ -1376,7 +1528,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Type *getShadowTyNoVec(Type *ty) { if (VectorType *vt = dyn_cast<VectorType>(ty)) return IntegerType::get(*MS.C, - vt->getPrimitiveSizeInBits().getFixedSize()); + vt->getPrimitiveSizeInBits().getFixedValue()); return ty; } @@ -1428,36 +1580,66 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return collapseArrayShadow(Array, V, IRB); Type *Ty = V->getType(); Type *NoVecTy = getShadowTyNoVec(Ty); - if (Ty == NoVecTy) return V; + if (Ty == NoVecTy) + return V; return IRB.CreateBitCast(V, NoVecTy); } // Convert a scalar value to an i1 by comparing with 0 Value *convertToBool(Value *V, IRBuilder<> &IRB, const Twine &name = "") { Type *VTy = V->getType(); - assert(VTy->isIntegerTy()); + if (!VTy->isIntegerTy()) + return convertToBool(convertShadowToScalar(V, IRB), IRB, name); if (VTy->getIntegerBitWidth() == 1) // Just converting a bool to a bool, so do nothing. 
return V; return IRB.CreateICmpNE(V, ConstantInt::get(VTy, 0), name); } + Type *ptrToIntPtrType(Type *PtrTy) const { + if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(PtrTy)) { + return FixedVectorType::get(ptrToIntPtrType(VectTy->getElementType()), + VectTy->getNumElements()); + } + assert(PtrTy->isIntOrPtrTy()); + return MS.IntptrTy; + } + + Type *getPtrToShadowPtrType(Type *IntPtrTy, Type *ShadowTy) const { + if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(IntPtrTy)) { + return FixedVectorType::get( + getPtrToShadowPtrType(VectTy->getElementType(), ShadowTy), + VectTy->getNumElements()); + } + assert(IntPtrTy == MS.IntptrTy); + return ShadowTy->getPointerTo(); + } + + Constant *constToIntPtr(Type *IntPtrTy, uint64_t C) const { + if (FixedVectorType *VectTy = dyn_cast<FixedVectorType>(IntPtrTy)) { + return ConstantDataVector::getSplat( + VectTy->getNumElements(), constToIntPtr(VectTy->getElementType(), C)); + } + assert(IntPtrTy == MS.IntptrTy); + return ConstantInt::get(MS.IntptrTy, C); + } + /// Compute the integer shadow offset that corresponds to a given /// application address. /// /// Offset = (Addr & ~AndMask) ^ XorMask + /// Addr can be a ptr or <N x ptr>. In both cases ShadowTy the shadow type of + /// a single pointee. + /// Returns <shadow_ptr, origin_ptr> or <<N x shadow_ptr>, <N x origin_ptr>>. Value *getShadowPtrOffset(Value *Addr, IRBuilder<> &IRB) { - Value *OffsetLong = IRB.CreatePointerCast(Addr, MS.IntptrTy); + Type *IntptrTy = ptrToIntPtrType(Addr->getType()); + Value *OffsetLong = IRB.CreatePointerCast(Addr, IntptrTy); - uint64_t AndMask = MS.MapParams->AndMask; - if (AndMask) - OffsetLong = - IRB.CreateAnd(OffsetLong, ConstantInt::get(MS.IntptrTy, ~AndMask)); + if (uint64_t AndMask = MS.MapParams->AndMask) + OffsetLong = IRB.CreateAnd(OffsetLong, constToIntPtr(IntptrTy, ~AndMask)); - uint64_t XorMask = MS.MapParams->XorMask; - if (XorMask) - OffsetLong = - IRB.CreateXor(OffsetLong, ConstantInt::get(MS.IntptrTy, XorMask)); + if (uint64_t XorMask = MS.MapParams->XorMask) + OffsetLong = IRB.CreateXor(OffsetLong, constToIntPtr(IntptrTy, XorMask)); return OffsetLong; } @@ -1466,41 +1648,43 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// /// Shadow = ShadowBase + Offset /// Origin = (OriginBase + Offset) & ~3ULL + /// Addr can be a ptr or <N x ptr>. In both cases ShadowTy the shadow type of + /// a single pointee. + /// Returns <shadow_ptr, origin_ptr> or <<N x shadow_ptr>, <N x origin_ptr>>. 
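The doc comments above spell out the address arithmetic that this diff generalizes from a single pointer to a vector of pointers; the per-lane math is unchanged. For reference, a standalone scalar sketch (the struct and field names are illustrative, and the pass also skips the final & ~3 when the access is already origin-aligned):

#include <cstdint>

// Offset = (Addr & ~AndMask) ^ XorMask
// Shadow = ShadowBase + Offset
// Origin = (OriginBase + Offset) & ~3ULL   (origins are 4-byte records)
struct MappingParams { uint64_t AndMask, XorMask, ShadowBase, OriginBase; };

uint64_t shadowOffset(uint64_t Addr, const MappingParams &MP) {
  uint64_t Off = Addr;
  if (MP.AndMask)
    Off &= ~MP.AndMask;
  if (MP.XorMask)
    Off ^= MP.XorMask;
  return Off;
}

uint64_t shadowAddr(uint64_t Addr, const MappingParams &MP) {
  return MP.ShadowBase + shadowOffset(Addr, MP);
}

uint64_t originAddr(uint64_t Addr, const MappingParams &MP) {
  return (MP.OriginBase + shadowOffset(Addr, MP)) & ~uint64_t(3);
}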
std::pair<Value *, Value *> getShadowOriginPtrUserspace(Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, MaybeAlign Alignment) { + Type *IntptrTy = ptrToIntPtrType(Addr->getType()); Value *ShadowOffset = getShadowPtrOffset(Addr, IRB); Value *ShadowLong = ShadowOffset; - uint64_t ShadowBase = MS.MapParams->ShadowBase; - if (ShadowBase != 0) { + if (uint64_t ShadowBase = MS.MapParams->ShadowBase) { ShadowLong = - IRB.CreateAdd(ShadowLong, - ConstantInt::get(MS.IntptrTy, ShadowBase)); + IRB.CreateAdd(ShadowLong, constToIntPtr(IntptrTy, ShadowBase)); } - Value *ShadowPtr = - IRB.CreateIntToPtr(ShadowLong, PointerType::get(ShadowTy, 0)); + Value *ShadowPtr = IRB.CreateIntToPtr( + ShadowLong, getPtrToShadowPtrType(IntptrTy, ShadowTy)); + Value *OriginPtr = nullptr; if (MS.TrackOrigins) { Value *OriginLong = ShadowOffset; uint64_t OriginBase = MS.MapParams->OriginBase; if (OriginBase != 0) - OriginLong = IRB.CreateAdd(OriginLong, - ConstantInt::get(MS.IntptrTy, OriginBase)); + OriginLong = + IRB.CreateAdd(OriginLong, constToIntPtr(IntptrTy, OriginBase)); if (!Alignment || *Alignment < kMinOriginAlignment) { uint64_t Mask = kMinOriginAlignment.value() - 1; - OriginLong = - IRB.CreateAnd(OriginLong, ConstantInt::get(MS.IntptrTy, ~Mask)); + OriginLong = IRB.CreateAnd(OriginLong, constToIntPtr(IntptrTy, ~Mask)); } - OriginPtr = - IRB.CreateIntToPtr(OriginLong, PointerType::get(MS.OriginTy, 0)); + OriginPtr = IRB.CreateIntToPtr( + OriginLong, getPtrToShadowPtrType(IntptrTy, MS.OriginTy)); } return std::make_pair(ShadowPtr, OriginPtr); } - std::pair<Value *, Value *> getShadowOriginPtrKernel(Value *Addr, - IRBuilder<> &IRB, - Type *ShadowTy, - bool isStore) { + std::pair<Value *, Value *> getShadowOriginPtrKernelNoVec(Value *Addr, + IRBuilder<> &IRB, + Type *ShadowTy, + bool isStore) { Value *ShadowOriginPtrs; const DataLayout &DL = F.getParent()->getDataLayout(); int Size = DL.getTypeStoreSize(ShadowTy); @@ -1523,6 +1707,42 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return std::make_pair(ShadowPtr, OriginPtr); } + /// Addr can be a ptr or <N x ptr>. In both cases ShadowTy the shadow type of + /// a single pointee. + /// Returns <shadow_ptr, origin_ptr> or <<N x shadow_ptr>, <N x origin_ptr>>. + std::pair<Value *, Value *> getShadowOriginPtrKernel(Value *Addr, + IRBuilder<> &IRB, + Type *ShadowTy, + bool isStore) { + FixedVectorType *VectTy = dyn_cast<FixedVectorType>(Addr->getType()); + if (!VectTy) { + assert(Addr->getType()->isPointerTy()); + return getShadowOriginPtrKernelNoVec(Addr, IRB, ShadowTy, isStore); + } + + // TODO: Support callbacs with vectors of addresses. 
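The per-lane loop that follows is the current workaround for that TODO: the KMSAN callbacks take one address at a time, so a vector of addresses is handled by querying the scalar helper per lane and reassembling the results. A minimal non-LLVM model, where scalarQuery plays the role of getShadowOriginPtrKernelNoVec() and pointers are modelled as integers:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

template <typename ScalarFn>
std::pair<std::vector<uint64_t>, std::vector<uint64_t>>
perLaneShadowOrigin(const std::vector<uint64_t> &Addrs, ScalarFn scalarQuery) {
  std::vector<uint64_t> ShadowPtrs(Addrs.size(), 0);
  std::vector<uint64_t> OriginPtrs(Addrs.size(), 0);
  for (size_t i = 0; i < Addrs.size(); ++i) {
    std::pair<uint64_t, uint64_t> SO = scalarQuery(Addrs[i]); // one query per lane
    ShadowPtrs[i] = SO.first;
    OriginPtrs[i] = SO.second;
  }
  return {ShadowPtrs, OriginPtrs};
}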
+ unsigned NumElements = VectTy->getNumElements(); + Value *ShadowPtrs = ConstantInt::getNullValue( + FixedVectorType::get(ShadowTy->getPointerTo(), NumElements)); + Value *OriginPtrs = nullptr; + if (MS.TrackOrigins) + OriginPtrs = ConstantInt::getNullValue( + FixedVectorType::get(MS.OriginTy->getPointerTo(), NumElements)); + for (unsigned i = 0; i < NumElements; ++i) { + Value *OneAddr = + IRB.CreateExtractElement(Addr, ConstantInt::get(IRB.getInt32Ty(), i)); + auto [ShadowPtr, OriginPtr] = + getShadowOriginPtrKernelNoVec(OneAddr, IRB, ShadowTy, isStore); + + ShadowPtrs = IRB.CreateInsertElement( + ShadowPtrs, ShadowPtr, ConstantInt::get(IRB.getInt32Ty(), i)); + if (MS.TrackOrigins) + OriginPtrs = IRB.CreateInsertElement( + OriginPtrs, OriginPtr, ConstantInt::get(IRB.getInt32Ty(), i)); + } + return {ShadowPtrs, OriginPtrs}; + } + std::pair<Value *, Value *> getShadowOriginPtr(Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, MaybeAlign Alignment, @@ -1535,8 +1755,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Compute the shadow address for a given function argument. /// /// Shadow = ParamTLS+ArgOffset. - Value *getShadowPtrForArgument(Value *A, IRBuilder<> &IRB, - int ArgOffset) { + Value *getShadowPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) { Value *Base = IRB.CreatePointerCast(MS.ParamTLS, MS.IntptrTy); if (ArgOffset) Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); @@ -1545,8 +1764,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } /// Compute the origin address for a given function argument. - Value *getOriginPtrForArgument(Value *A, IRBuilder<> &IRB, - int ArgOffset) { + Value *getOriginPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) { if (!MS.TrackOrigins) return nullptr; Value *Base = IRB.CreatePointerCast(MS.ParamOriginTLS, MS.IntptrTy); @@ -1559,8 +1777,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Compute the shadow address for a retval. Value *getShadowPtrForRetval(Value *A, IRBuilder<> &IRB) { return IRB.CreatePointerCast(MS.RetvalTLS, - PointerType::get(getShadowTy(A), 0), - "_msret"); + PointerType::get(getShadowTy(A), 0), "_msret"); } /// Compute the origin address for a retval. @@ -1577,7 +1794,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Set Origin to be the origin value for V. void setOrigin(Value *V, Value *Origin) { - if (!MS.TrackOrigins) return; + if (!MS.TrackOrigins) + return; assert(!OriginMap.count(V) && "Values may only have one origin"); LLVM_DEBUG(dbgs() << "ORIGIN: " << *V << " ==> " << *Origin << "\n"); OriginMap[V] = Origin; @@ -1594,9 +1812,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// /// Clean shadow (all zeroes) means all bits of the value are defined /// (initialized). - Constant *getCleanShadow(Value *V) { - return getCleanShadow(V->getType()); - } + Constant *getCleanShadow(Value *V) { return getCleanShadow(V->getType()); } /// Create a dirty shadow of a given shadow type. Constant *getPoisonedShadow(Type *ShadowTy) { @@ -1626,9 +1842,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } /// Create a clean (zero) origin. - Value *getCleanOrigin() { - return Constant::getNullValue(MS.OriginTy); - } + Value *getCleanOrigin() { return Constant::getNullValue(MS.OriginTy); } /// Get the shadow value for a given Value. 
/// @@ -1680,7 +1894,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // argument shadow to the underlying memory. // Figure out maximal valid memcpy alignment. const Align ArgAlign = DL.getValueOrABITypeAlignment( - MaybeAlign(FArg.getParamAlignment()), FArg.getParamByValType()); + FArg.getParamAlign(), FArg.getParamByValType()); Value *CpShadowPtr, *CpOriginPtr; std::tie(CpShadowPtr, CpOriginPtr) = getShadowOriginPtr(V, EntryIRB, EntryIRB.getInt8Ty(), ArgAlign, @@ -1721,7 +1935,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Shadow over TLS Value *Base = getShadowPtrForArgument(&FArg, EntryIRB, ArgOffset); ShadowPtr = EntryIRB.CreateAlignedLoad(getShadowTy(&FArg), Base, - kShadowTLSAlignment); + kShadowTLSAlignment); if (MS.TrackOrigins) { Value *OriginPtr = getOriginPtrForArgument(&FArg, EntryIRB, ArgOffset); @@ -1749,9 +1963,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Get the origin for a value. Value *getOrigin(Value *V) { - if (!MS.TrackOrigins) return nullptr; - if (!PropagateShadow) return getCleanOrigin(); - if (isa<Constant>(V)) return getCleanOrigin(); + if (!MS.TrackOrigins) + return nullptr; + if (!PropagateShadow || isa<Constant>(V) || isa<InlineAsm>(V)) + return getCleanOrigin(); assert((isa<Instruction>(V) || isa<Argument>(V)) && "Unexpected value type in getOrigin()"); if (Instruction *I = dyn_cast<Instruction>(V)) { @@ -1774,7 +1989,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// UMR warning in runtime if the shadow value is not 0. void insertShadowCheck(Value *Shadow, Value *Origin, Instruction *OrigIns) { assert(Shadow); - if (!InsertChecks) return; + if (!InsertChecks) + return; + + if (!DebugCounter::shouldExecute(DebugInsertCheck)) { + LLVM_DEBUG(dbgs() << "Skipping check of " << *Shadow << " before " + << *OrigIns << "\n"); + return; + } #ifndef NDEBUG Type *ShadowTy = Shadow->getType(); assert((isa<IntegerType>(ShadowTy) || isa<VectorType>(ShadowTy) || @@ -1795,11 +2017,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Shadow, *Origin; if (ClCheckConstantShadow) { Shadow = getShadow(Val); - if (!Shadow) return; + if (!Shadow) + return; Origin = getOrigin(Val); } else { Shadow = dyn_cast_or_null<Instruction>(getShadow(Val)); - if (!Shadow) return; + if (!Shadow) + return; Origin = dyn_cast_or_null<Instruction>(getOrigin(Val)); } insertShadowCheck(Shadow, Origin, OrigIns); @@ -1807,17 +2031,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { AtomicOrdering addReleaseOrdering(AtomicOrdering a) { switch (a) { - case AtomicOrdering::NotAtomic: - return AtomicOrdering::NotAtomic; - case AtomicOrdering::Unordered: - case AtomicOrdering::Monotonic: - case AtomicOrdering::Release: - return AtomicOrdering::Release; - case AtomicOrdering::Acquire: - case AtomicOrdering::AcquireRelease: - return AtomicOrdering::AcquireRelease; - case AtomicOrdering::SequentiallyConsistent: - return AtomicOrdering::SequentiallyConsistent; + case AtomicOrdering::NotAtomic: + return AtomicOrdering::NotAtomic; + case AtomicOrdering::Unordered: + case AtomicOrdering::Monotonic: + case AtomicOrdering::Release: + return AtomicOrdering::Release; + case AtomicOrdering::Acquire: + case AtomicOrdering::AcquireRelease: + return AtomicOrdering::AcquireRelease; + case AtomicOrdering::SequentiallyConsistent: + return AtomicOrdering::SequentiallyConsistent; } llvm_unreachable("Unknown 
ordering"); } @@ -1837,22 +2061,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { (int)AtomicOrderingCABI::seq_cst; return ConstantDataVector::get(IRB.getContext(), - makeArrayRef(OrderingTable, NumOrderings)); + ArrayRef(OrderingTable, NumOrderings)); } AtomicOrdering addAcquireOrdering(AtomicOrdering a) { switch (a) { - case AtomicOrdering::NotAtomic: - return AtomicOrdering::NotAtomic; - case AtomicOrdering::Unordered: - case AtomicOrdering::Monotonic: - case AtomicOrdering::Acquire: - return AtomicOrdering::Acquire; - case AtomicOrdering::Release: - case AtomicOrdering::AcquireRelease: - return AtomicOrdering::AcquireRelease; - case AtomicOrdering::SequentiallyConsistent: - return AtomicOrdering::SequentiallyConsistent; + case AtomicOrdering::NotAtomic: + return AtomicOrdering::NotAtomic; + case AtomicOrdering::Unordered: + case AtomicOrdering::Monotonic: + case AtomicOrdering::Acquire: + return AtomicOrdering::Acquire; + case AtomicOrdering::Release: + case AtomicOrdering::AcquireRelease: + return AtomicOrdering::AcquireRelease; + case AtomicOrdering::SequentiallyConsistent: + return AtomicOrdering::SequentiallyConsistent; } llvm_unreachable("Unknown ordering"); } @@ -1872,7 +2096,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { (int)AtomicOrderingCABI::seq_cst; return ConstantDataVector::get(IRB.getContext(), - makeArrayRef(OrderingTable, NumOrderings)); + ArrayRef(OrderingTable, NumOrderings)); } // ------------------- Visitors. @@ -1893,7 +2117,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitLoadInst(LoadInst &I) { assert(I.getType()->isSized() && "Load type must have size"); assert(!I.getMetadata(LLVMContext::MD_nosanitize)); - IRBuilder<> IRB(I.getNextNode()); + NextNodeIRBuilder IRB(&I); Type *ShadowTy = getShadowTy(&I); Value *Addr = I.getPointerOperand(); Value *ShadowPtr = nullptr, *OriginPtr = nullptr; @@ -1940,7 +2164,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); Value *Addr = I.getOperand(0); Value *Val = I.getOperand(1); - Value *ShadowPtr = getShadowOriginPtr(Addr, IRB, Val->getType(), Align(1), + Value *ShadowPtr = getShadowOriginPtr(Addr, IRB, getShadowTy(Val), Align(1), /*isStore*/ true) .first; @@ -1974,22 +2198,26 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { insertShadowCheck(I.getOperand(1), &I); IRBuilder<> IRB(&I); setShadow(&I, IRB.CreateExtractElement(getShadow(&I, 0), I.getOperand(1), - "_msprop")); + "_msprop")); setOrigin(&I, getOrigin(&I, 0)); } void visitInsertElementInst(InsertElementInst &I) { insertShadowCheck(I.getOperand(2), &I); IRBuilder<> IRB(&I); - setShadow(&I, IRB.CreateInsertElement(getShadow(&I, 0), getShadow(&I, 1), - I.getOperand(2), "_msprop")); + auto *Shadow0 = getShadow(&I, 0); + auto *Shadow1 = getShadow(&I, 1); + setShadow(&I, IRB.CreateInsertElement(Shadow0, Shadow1, I.getOperand(2), + "_msprop")); setOriginForNaryOp(I); } void visitShuffleVectorInst(ShuffleVectorInst &I) { IRBuilder<> IRB(&I); - setShadow(&I, IRB.CreateShuffleVector(getShadow(&I, 0), getShadow(&I, 1), - I.getShuffleMask(), "_msprop")); + auto *Shadow0 = getShadow(&I, 0); + auto *Shadow1 = getShadow(&I, 1); + setShadow(&I, IRB.CreateShuffleVector(Shadow0, Shadow1, I.getShuffleMask(), + "_msprop")); setOriginForNaryOp(I); } @@ -2027,23 +2255,23 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitPtrToIntInst(PtrToIntInst &I) { 
IRBuilder<> IRB(&I); setShadow(&I, IRB.CreateIntCast(getShadow(&I, 0), getShadowTy(&I), false, - "_msprop_ptrtoint")); + "_msprop_ptrtoint")); setOrigin(&I, getOrigin(&I, 0)); } void visitIntToPtrInst(IntToPtrInst &I) { IRBuilder<> IRB(&I); setShadow(&I, IRB.CreateIntCast(getShadow(&I, 0), getShadowTy(&I), false, - "_msprop_inttoptr")); + "_msprop_inttoptr")); setOrigin(&I, getOrigin(&I, 0)); } - void visitFPToSIInst(CastInst& I) { handleShadowOr(I); } - void visitFPToUIInst(CastInst& I) { handleShadowOr(I); } - void visitSIToFPInst(CastInst& I) { handleShadowOr(I); } - void visitUIToFPInst(CastInst& I) { handleShadowOr(I); } - void visitFPExtInst(CastInst& I) { handleShadowOr(I); } - void visitFPTruncInst(CastInst& I) { handleShadowOr(I); } + void visitFPToSIInst(CastInst &I) { handleShadowOr(I); } + void visitFPToUIInst(CastInst &I) { handleShadowOr(I); } + void visitSIToFPInst(CastInst &I) { handleShadowOr(I); } + void visitUIToFPInst(CastInst &I) { handleShadowOr(I); } + void visitFPExtInst(CastInst &I) { handleShadowOr(I); } + void visitFPTruncInst(CastInst &I) { handleShadowOr(I); } /// Propagate shadow for bitwise AND. /// @@ -2109,8 +2337,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// not entirely initialized. If there is more than one such arguments, the /// rightmost of them is picked. It does not matter which one is picked if all /// arguments are initialized. - template <bool CombineShadow> - class Combiner { + template <bool CombineShadow> class Combiner { Value *Shadow = nullptr; Value *Origin = nullptr; IRBuilder<> &IRB; @@ -2177,7 +2404,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Propagate origin for arbitrary operation. void setOriginForNaryOp(Instruction &I) { - if (!MS.TrackOrigins) return; + if (!MS.TrackOrigins) + return; IRBuilder<> IRB(&I); OriginCombiner OC(this, IRB); for (Use &Op : I.operands()) @@ -2211,7 +2439,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return IRB.CreateIntCast(V, dstTy, Signed); Value *V1 = IRB.CreateBitCast(V, Type::getIntNTy(*MS.C, srcSizeInBits)); Value *V2 = - IRB.CreateIntCast(V1, Type::getIntNTy(*MS.C, dstSizeInBits), Signed); + IRB.CreateIntCast(V1, Type::getIntNTy(*MS.C, dstSizeInBits), Signed); return IRB.CreateBitCast(V2, dstTy); // TODO: handle struct types. } @@ -2347,10 +2575,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Si = !(C & ~Sc) && Sc Value *Zero = Constant::getNullValue(Sc->getType()); Value *MinusOne = Constant::getAllOnesValue(Sc->getType()); - Value *Si = - IRB.CreateAnd(IRB.CreateICmpNE(Sc, Zero), - IRB.CreateICmpEQ( - IRB.CreateAnd(IRB.CreateXor(Sc, MinusOne), C), Zero)); + Value *LHS = IRB.CreateICmpNE(Sc, Zero); + Value *RHS = + IRB.CreateICmpEQ(IRB.CreateAnd(IRB.CreateXor(Sc, MinusOne), C), Zero); + Value *Si = IRB.CreateAnd(LHS, RHS); Si->setName("_msprop_icmp"); setShadow(&I, Si); setOriginForNaryOp(I); @@ -2365,8 +2593,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *SaOtherBits = IRB.CreateLShr(IRB.CreateShl(Sa, 1), 1); Value *SaSignBit = IRB.CreateXor(Sa, SaOtherBits); // Maximise the undefined shadow bit, minimize other undefined bits. - return - IRB.CreateOr(IRB.CreateAnd(A, IRB.CreateNot(SaOtherBits)), SaSignBit); + return IRB.CreateOr(IRB.CreateAnd(A, IRB.CreateNot(SaOtherBits)), + SaSignBit); } else { // Minimize undefined bits. 
return IRB.CreateAnd(A, IRB.CreateNot(Sa)); @@ -2376,14 +2604,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Build the highest possible value of V, taking into account V's /// uninitialized bits. Value *getHighestPossibleValue(IRBuilder<> &IRB, Value *A, Value *Sa, - bool isSigned) { + bool isSigned) { if (isSigned) { // Split shadow into sign bit and other bits. Value *SaOtherBits = IRB.CreateLShr(IRB.CreateShl(Sa, 1), 1); Value *SaSignBit = IRB.CreateXor(Sa, SaOtherBits); // Minimise the undefined shadow bit, maximise other undefined bits. - return - IRB.CreateOr(IRB.CreateAnd(A, IRB.CreateNot(SaSignBit)), SaOtherBits); + return IRB.CreateOr(IRB.CreateAnd(A, IRB.CreateNot(SaSignBit)), + SaOtherBits); } else { // Maximize undefined bits. return IRB.CreateOr(A, Sa); @@ -2485,9 +2713,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { handleShadowOr(I); } - void visitFCmpInst(FCmpInst &I) { - handleShadowOr(I); - } + void visitFCmpInst(FCmpInst &I) { handleShadowOr(I); } void handleShift(BinaryOperator &I) { IRBuilder<> IRB(&I); @@ -2495,8 +2721,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Otherwise perform the same shift on S1. Value *S1 = getShadow(&I, 0); Value *S2 = getShadow(&I, 1); - Value *S2Conv = IRB.CreateSExt(IRB.CreateICmpNE(S2, getCleanShadow(S2)), - S2->getType()); + Value *S2Conv = + IRB.CreateSExt(IRB.CreateICmpNE(S2, getCleanShadow(S2)), S2->getType()); Value *V2 = I.getOperand(1); Value *Shift = IRB.CreateBinOp(I.getOpcode(), S1, V2); setShadow(&I, IRB.CreateOr(Shift, S2Conv)); @@ -2545,10 +2771,20 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { I.eraseFromParent(); } - // Similar to memmove: avoid copying shadow twice. - // This is somewhat unfortunate as it may slowdown small constant memcpys. - // FIXME: consider doing manual inline for small constant sizes and proper - // alignment. + /// Instrument memcpy + /// + /// Similar to memmove: avoid copying shadow twice. This is somewhat + /// unfortunate as it may slowdown small constant memcpys. + /// FIXME: consider doing manual inline for small constant sizes and proper + /// alignment. + /// + /// Note: This also handles memcpy.inline, which promises no calls to external + /// functions as an optimization. However, with instrumentation enabled this + /// is difficult to promise; additionally, we know that the MSan runtime + /// exists and provides __msan_memcpy(). Therefore, we assume that with + /// instrumentation it's safe to turn memcpy.inline into a call to + /// __msan_memcpy(). Should this be wrong, such as when implementing memcpy() + /// itself, instrumentation should be disabled with the no_sanitize attribute. void visitMemCpyInst(MemCpyInst &I) { getShadow(I.getArgOperand(1)); // Ensure shadow initialized IRBuilder<> IRB(&I); @@ -2571,13 +2807,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { I.eraseFromParent(); } - void visitVAStartInst(VAStartInst &I) { - VAHelper->visitVAStartInst(I); - } + void visitVAStartInst(VAStartInst &I) { VAHelper->visitVAStartInst(I); } - void visitVACopyInst(VACopyInst &I) { - VAHelper->visitVACopyInst(I); - } + void visitVACopyInst(VACopyInst &I) { VAHelper->visitVACopyInst(I); } /// Handle vector store-like intrinsics. /// @@ -2585,7 +2817,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// has 1 pointer argument and 1 vector argument, returns void. 
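Two of the integer propagation rules touched above are easy to state outside of IR. A hedged standalone model (32-bit unsigned shown; the signed compare variant additionally treats the sign bit in the opposite direction, and the function names are illustrative):

#include <cstdint>

// Relational compares: take the lowest and highest values the operand could
// have given its uninitialized-bit mask Sa; the result is poisoned only where
// comparing both bounds disagrees.
uint32_t lowestPossible(uint32_t A, uint32_t Sa) { return A & ~Sa; }  // undefined bits -> 0
uint32_t highestPossible(uint32_t A, uint32_t Sa) { return A | Sa; }  // undefined bits -> 1

// Shifts: apply the same shift to the first operand's shadow, then poison the
// whole result if any bit of the shift amount is itself uninitialized
// (the OR with sext(S2 != 0)). The "& 31" only keeps the sketch well-defined
// in C++; the emitted IR shifts by the amount operand directly.
uint32_t shlShadow(uint32_t S1, uint32_t V2, uint32_t S2) {
  uint32_t Shifted = S1 << (V2 & 31);
  uint32_t AmountPoison = (S2 != 0) ? ~0u : 0u;   // sext(S2 != 0)
  return Shifted | AmountPoison;
}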
bool handleVectorStoreIntrinsic(IntrinsicInst &I) { IRBuilder<> IRB(&I); - Value* Addr = I.getArgOperand(0); + Value *Addr = I.getArgOperand(0); Value *Shadow = getShadow(&I, 1); Value *ShadowPtr, *OriginPtr; @@ -2599,7 +2831,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { insertShadowCheck(Addr, &I); // FIXME: factor out common code from materializeStores - if (MS.TrackOrigins) IRB.CreateStore(getOrigin(&I, 1), OriginPtr); + if (MS.TrackOrigins) + IRB.CreateStore(getOrigin(&I, 1), OriginPtr); return true; } @@ -2645,8 +2878,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Caller guarantees that this intrinsic does not access memory. bool maybeHandleSimpleNomemIntrinsic(IntrinsicInst &I) { Type *RetTy = I.getType(); - if (!(RetTy->isIntOrIntVectorTy() || - RetTy->isFPOrFPVectorTy() || + if (!(RetTy->isIntOrIntVectorTy() || RetTy->isFPOrFPVectorTy() || RetTy->isX86_MMXTy())) return false; @@ -2681,19 +2913,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (NumArgOperands == 0) return false; - if (NumArgOperands == 2 && - I.getArgOperand(0)->getType()->isPointerTy() && + if (NumArgOperands == 2 && I.getArgOperand(0)->getType()->isPointerTy() && I.getArgOperand(1)->getType()->isVectorTy() && - I.getType()->isVoidTy() && - !I.onlyReadsMemory()) { + I.getType()->isVoidTy() && !I.onlyReadsMemory()) { // This looks like a vector store. return handleVectorStoreIntrinsic(I); } - if (NumArgOperands == 1 && - I.getArgOperand(0)->getType()->isPointerTy() && - I.getType()->isVectorTy() && - I.onlyReadsMemory()) { + if (NumArgOperands == 1 && I.getArgOperand(0)->getType()->isPointerTy() && + I.getType()->isVectorTy() && I.onlyReadsMemory()) { // This looks like a vector load. return handleVectorLoadIntrinsic(I); } @@ -2725,11 +2953,32 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Op = I.getArgOperand(0); Type *OpType = Op->getType(); Function *BswapFunc = Intrinsic::getDeclaration( - F.getParent(), Intrinsic::bswap, makeArrayRef(&OpType, 1)); + F.getParent(), Intrinsic::bswap, ArrayRef(&OpType, 1)); setShadow(&I, IRB.CreateCall(BswapFunc, getShadow(Op))); setOrigin(&I, getOrigin(Op)); } + void handleCountZeroes(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Src = I.getArgOperand(0); + + // Set the Output shadow based on input Shadow + Value *BoolShadow = IRB.CreateIsNotNull(getShadow(Src), "_mscz_bs"); + + // If zero poison is requested, mix in with the shadow + Constant *IsZeroPoison = cast<Constant>(I.getOperand(1)); + if (!IsZeroPoison->isZeroValue()) { + Value *BoolZeroPoison = IRB.CreateIsNull(Src, "_mscz_bzp"); + BoolShadow = IRB.CreateOr(BoolShadow, BoolZeroPoison, "_mscz_bs"); + } + + Value *OutputShadow = + IRB.CreateSExt(BoolShadow, getShadowTy(Src), "_mscz_os"); + + setShadow(&I, OutputShadow); + setOriginForNaryOp(I); + } + // Instrument vector convert intrinsic. // // This function instruments intrinsics like cvtsi2ss: @@ -2873,30 +3122,30 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // intrinsic. 
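handleCountZeroes() above defines the shadow for the new ctlz/cttz cases in terms of two facts: whether any input bit is poisoned, and, when the is_zero_poison flag is set, whether the concrete input is zero (the intrinsic's result is then poison by definition). A scalar model of that rule, with names matching the IR value names used in the hunk:

#include <cstdint>

uint32_t countZeroesShadow(uint32_t Src, uint32_t SrcShadow, bool IsZeroPoison) {
  bool Poisoned = (SrcShadow != 0);        // _mscz_bs: any poisoned input bit?
  if (IsZeroPoison)
    Poisoned = Poisoned || (Src == 0);     // _mscz_bzp: zero input is poison
  return Poisoned ? ~0u : 0u;              // sext of the boolean shadow
}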
Intrinsic::ID getSignedPackIntrinsic(Intrinsic::ID id) { switch (id) { - case Intrinsic::x86_sse2_packsswb_128: - case Intrinsic::x86_sse2_packuswb_128: - return Intrinsic::x86_sse2_packsswb_128; + case Intrinsic::x86_sse2_packsswb_128: + case Intrinsic::x86_sse2_packuswb_128: + return Intrinsic::x86_sse2_packsswb_128; - case Intrinsic::x86_sse2_packssdw_128: - case Intrinsic::x86_sse41_packusdw: - return Intrinsic::x86_sse2_packssdw_128; + case Intrinsic::x86_sse2_packssdw_128: + case Intrinsic::x86_sse41_packusdw: + return Intrinsic::x86_sse2_packssdw_128; - case Intrinsic::x86_avx2_packsswb: - case Intrinsic::x86_avx2_packuswb: - return Intrinsic::x86_avx2_packsswb; + case Intrinsic::x86_avx2_packsswb: + case Intrinsic::x86_avx2_packuswb: + return Intrinsic::x86_avx2_packsswb; - case Intrinsic::x86_avx2_packssdw: - case Intrinsic::x86_avx2_packusdw: - return Intrinsic::x86_avx2_packssdw; + case Intrinsic::x86_avx2_packssdw: + case Intrinsic::x86_avx2_packusdw: + return Intrinsic::x86_avx2_packssdw; - case Intrinsic::x86_mmx_packsswb: - case Intrinsic::x86_mmx_packuswb: - return Intrinsic::x86_mmx_packsswb; + case Intrinsic::x86_mmx_packsswb: + case Intrinsic::x86_mmx_packuswb: + return Intrinsic::x86_mmx_packsswb; - case Intrinsic::x86_mmx_packssdw: - return Intrinsic::x86_mmx_packssdw; - default: - llvm_unreachable("unexpected intrinsic id"); + case Intrinsic::x86_mmx_packssdw: + return Intrinsic::x86_mmx_packssdw; + default: + llvm_unreachable("unexpected intrinsic id"); } } @@ -2923,10 +3172,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { S1 = IRB.CreateBitCast(S1, T); S2 = IRB.CreateBitCast(S2, T); } - Value *S1_ext = IRB.CreateSExt( - IRB.CreateICmpNE(S1, Constant::getNullValue(T)), T); - Value *S2_ext = IRB.CreateSExt( - IRB.CreateICmpNE(S2, Constant::getNullValue(T)), T); + Value *S1_ext = + IRB.CreateSExt(IRB.CreateICmpNE(S1, Constant::getNullValue(T)), T); + Value *S2_ext = + IRB.CreateSExt(IRB.CreateICmpNE(S2, Constant::getNullValue(T)), T); if (isX86_MMX) { Type *X86_MMXTy = Type::getX86_MMXTy(*MS.C); S1_ext = IRB.CreateBitCast(S1_ext, X86_MMXTy); @@ -2938,7 +3187,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *S = IRB.CreateCall(ShadowFn, {S1_ext, S2_ext}, "_msprop_vector_pack"); - if (isX86_MMX) S = IRB.CreateBitCast(S, getShadowTy(&I)); + if (isX86_MMX) + S = IRB.CreateBitCast(S, getShadowTy(&I)); setShadow(&I, S); setOriginForNaryOp(I); } @@ -2952,7 +3202,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { ResTy->getScalarSizeInBits() - SignificantBitsPerResultElement; IRBuilder<> IRB(&I); - Value *S = IRB.CreateOr(getShadow(&I, 0), getShadow(&I, 1)); + auto *Shadow0 = getShadow(&I, 0); + auto *Shadow1 = getShadow(&I, 1); + Value *S = IRB.CreateOr(Shadow0, Shadow1); S = IRB.CreateBitCast(S, ResTy); S = IRB.CreateSExt(IRB.CreateICmpNE(S, Constant::getNullValue(ResTy)), ResTy); @@ -2968,7 +3220,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { bool isX86_MMX = I.getOperand(0)->getType()->isX86_MMXTy(); Type *ResTy = isX86_MMX ? 
getMMXVectorTy(EltSizeInBits * 2) : I.getType(); IRBuilder<> IRB(&I); - Value *S = IRB.CreateOr(getShadow(&I, 0), getShadow(&I, 1)); + auto *Shadow0 = getShadow(&I, 0); + auto *Shadow1 = getShadow(&I, 1); + Value *S = IRB.CreateOr(Shadow0, Shadow1); S = IRB.CreateBitCast(S, ResTy); S = IRB.CreateSExt(IRB.CreateICmpNE(S, Constant::getNullValue(ResTy)), ResTy); @@ -2983,7 +3237,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void handleVectorComparePackedIntrinsic(IntrinsicInst &I) { IRBuilder<> IRB(&I); Type *ResTy = getShadowTy(&I); - Value *S0 = IRB.CreateOr(getShadow(&I, 0), getShadow(&I, 1)); + auto *Shadow0 = getShadow(&I, 0); + auto *Shadow1 = getShadow(&I, 1); + Value *S0 = IRB.CreateOr(Shadow0, Shadow1); Value *S = IRB.CreateSExt( IRB.CreateICmpNE(S0, Constant::getNullValue(ResTy)), ResTy); setShadow(&I, S); @@ -2995,7 +3251,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // element of a vector, and comi* which return the result as i32. void handleVectorCompareScalarIntrinsic(IntrinsicInst &I) { IRBuilder<> IRB(&I); - Value *S0 = IRB.CreateOr(getShadow(&I, 0), getShadow(&I, 1)); + auto *Shadow0 = getShadow(&I, 0); + auto *Shadow1 = getShadow(&I, 1); + Value *S0 = IRB.CreateOr(Shadow0, Shadow1); Value *S = LowerElementShadowExtend(IRB, S0, getShadowTy(&I)); setShadow(&I, S); setOriginForNaryOp(I); @@ -3047,7 +3305,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void handleStmxcsr(IntrinsicInst &I) { IRBuilder<> IRB(&I); - Value* Addr = I.getArgOperand(0); + Value *Addr = I.getArgOperand(0); Type *Ty = IRB.getInt32Ty(); Value *ShadowPtr = getShadowOriginPtr(Addr, IRB, Ty, Align(1), /*isStore*/ true).first; @@ -3060,7 +3318,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } void handleLdmxcsr(IntrinsicInst &I) { - if (!InsertChecks) return; + if (!InsertChecks) + return; IRBuilder<> IRB(&I); Value *Addr = I.getArgOperand(0); @@ -3079,93 +3338,201 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { insertShadowCheck(Shadow, Origin, &I); } + void handleMaskedExpandLoad(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Ptr = I.getArgOperand(0); + Value *Mask = I.getArgOperand(1); + Value *PassThru = I.getArgOperand(2); + + if (ClCheckAccessAddress) { + insertShadowCheck(Ptr, &I); + insertShadowCheck(Mask, &I); + } + + if (!PropagateShadow) { + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); + return; + } + + Type *ShadowTy = getShadowTy(&I); + Type *ElementShadowTy = cast<FixedVectorType>(ShadowTy)->getElementType(); + auto [ShadowPtr, OriginPtr] = + getShadowOriginPtr(Ptr, IRB, ElementShadowTy, {}, /*isStore*/ false); + + Value *Shadow = IRB.CreateMaskedExpandLoad( + ShadowTy, ShadowPtr, Mask, getShadow(PassThru), "_msmaskedexpload"); + + setShadow(&I, Shadow); + + // TODO: Store origins. 
+ setOrigin(&I, getCleanOrigin()); + } + + void handleMaskedCompressStore(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Values = I.getArgOperand(0); + Value *Ptr = I.getArgOperand(1); + Value *Mask = I.getArgOperand(2); + + if (ClCheckAccessAddress) { + insertShadowCheck(Ptr, &I); + insertShadowCheck(Mask, &I); + } + + Value *Shadow = getShadow(Values); + Type *ElementShadowTy = + getShadowTy(cast<FixedVectorType>(Values->getType())->getElementType()); + auto [ShadowPtr, OriginPtrs] = + getShadowOriginPtr(Ptr, IRB, ElementShadowTy, {}, /*isStore*/ true); + + IRB.CreateMaskedCompressStore(Shadow, ShadowPtr, Mask); + + // TODO: Store origins. + } + + void handleMaskedGather(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Ptrs = I.getArgOperand(0); + const Align Alignment( + cast<ConstantInt>(I.getArgOperand(1))->getZExtValue()); + Value *Mask = I.getArgOperand(2); + Value *PassThru = I.getArgOperand(3); + + Type *PtrsShadowTy = getShadowTy(Ptrs); + if (ClCheckAccessAddress) { + insertShadowCheck(Mask, &I); + Value *MaskedPtrShadow = IRB.CreateSelect( + Mask, getShadow(Ptrs), Constant::getNullValue((PtrsShadowTy)), + "_msmaskedptrs"); + insertShadowCheck(MaskedPtrShadow, getOrigin(Ptrs), &I); + } + + if (!PropagateShadow) { + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); + return; + } + + Type *ShadowTy = getShadowTy(&I); + Type *ElementShadowTy = cast<FixedVectorType>(ShadowTy)->getElementType(); + auto [ShadowPtrs, OriginPtrs] = getShadowOriginPtr( + Ptrs, IRB, ElementShadowTy, Alignment, /*isStore*/ false); + + Value *Shadow = + IRB.CreateMaskedGather(ShadowTy, ShadowPtrs, Alignment, Mask, + getShadow(PassThru), "_msmaskedgather"); + + setShadow(&I, Shadow); + + // TODO: Store origins. + setOrigin(&I, getCleanOrigin()); + } + + void handleMaskedScatter(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Values = I.getArgOperand(0); + Value *Ptrs = I.getArgOperand(1); + const Align Alignment( + cast<ConstantInt>(I.getArgOperand(2))->getZExtValue()); + Value *Mask = I.getArgOperand(3); + + Type *PtrsShadowTy = getShadowTy(Ptrs); + if (ClCheckAccessAddress) { + insertShadowCheck(Mask, &I); + Value *MaskedPtrShadow = IRB.CreateSelect( + Mask, getShadow(Ptrs), Constant::getNullValue((PtrsShadowTy)), + "_msmaskedptrs"); + insertShadowCheck(MaskedPtrShadow, getOrigin(Ptrs), &I); + } + + Value *Shadow = getShadow(Values); + Type *ElementShadowTy = + getShadowTy(cast<FixedVectorType>(Values->getType())->getElementType()); + auto [ShadowPtrs, OriginPtrs] = getShadowOriginPtr( + Ptrs, IRB, ElementShadowTy, Alignment, /*isStore*/ true); + + IRB.CreateMaskedScatter(Shadow, ShadowPtrs, Alignment, Mask); + + // TODO: Store origin. + } + void handleMaskedStore(IntrinsicInst &I) { IRBuilder<> IRB(&I); Value *V = I.getArgOperand(0); - Value *Addr = I.getArgOperand(1); + Value *Ptr = I.getArgOperand(1); const Align Alignment( cast<ConstantInt>(I.getArgOperand(2))->getZExtValue()); Value *Mask = I.getArgOperand(3); Value *Shadow = getShadow(V); - Value *ShadowPtr; - Value *OriginPtr; - std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr( - Addr, IRB, Shadow->getType(), Alignment, /*isStore*/ true); - if (ClCheckAccessAddress) { - insertShadowCheck(Addr, &I); - // Uninitialized mask is kind of like uninitialized address, but not as - // scary. 
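In the gather/scatter handlers above, the pointer-operand check is made mask-aware: a lane that is switched off cannot fault, so its (possibly uninitialized) address is zeroed out of the shadow before the check, which is exactly what the select(Mask, shadow(Ptrs), 0) in those hunks computes. A small per-lane model of that masking (names and the bool return are illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true if the pointer operand should be reported as uninitialized;
// assumes PtrShadows and Mask have the same number of lanes.
bool maskedPointerCheck(const std::vector<uint64_t> &PtrShadows,
                        const std::vector<bool> &Mask) {
  bool AnyPoisoned = false;
  for (size_t i = 0; i < PtrShadows.size(); ++i) {
    uint64_t LaneShadow = Mask[i] ? PtrShadows[i] : 0;  // select(Mask, Sptr, 0)
    AnyPoisoned = AnyPoisoned || (LaneShadow != 0);
  }
  return AnyPoisoned;
}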
+ insertShadowCheck(Ptr, &I); insertShadowCheck(Mask, &I); } + Value *ShadowPtr; + Value *OriginPtr; + std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr( + Ptr, IRB, Shadow->getType(), Alignment, /*isStore*/ true); + IRB.CreateMaskedStore(Shadow, ShadowPtr, Alignment, Mask); - if (MS.TrackOrigins) { - auto &DL = F.getParent()->getDataLayout(); - paintOrigin(IRB, getOrigin(V), OriginPtr, - DL.getTypeStoreSize(Shadow->getType()), - std::max(Alignment, kMinOriginAlignment)); - } + if (!MS.TrackOrigins) + return; + + auto &DL = F.getParent()->getDataLayout(); + paintOrigin(IRB, getOrigin(V), OriginPtr, + DL.getTypeStoreSize(Shadow->getType()), + std::max(Alignment, kMinOriginAlignment)); } - bool handleMaskedLoad(IntrinsicInst &I) { + void handleMaskedLoad(IntrinsicInst &I) { IRBuilder<> IRB(&I); - Value *Addr = I.getArgOperand(0); + Value *Ptr = I.getArgOperand(0); const Align Alignment( cast<ConstantInt>(I.getArgOperand(1))->getZExtValue()); Value *Mask = I.getArgOperand(2); Value *PassThru = I.getArgOperand(3); - Type *ShadowTy = getShadowTy(&I); - Value *ShadowPtr, *OriginPtr; - if (PropagateShadow) { - std::tie(ShadowPtr, OriginPtr) = - getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment, /*isStore*/ false); - setShadow(&I, IRB.CreateMaskedLoad(ShadowTy, ShadowPtr, Alignment, Mask, - getShadow(PassThru), "_msmaskedld")); - } else { - setShadow(&I, getCleanShadow(&I)); - } - if (ClCheckAccessAddress) { - insertShadowCheck(Addr, &I); + insertShadowCheck(Ptr, &I); insertShadowCheck(Mask, &I); } - if (MS.TrackOrigins) { - if (PropagateShadow) { - // Choose between PassThru's and the loaded value's origins. - Value *MaskedPassThruShadow = IRB.CreateAnd( - getShadow(PassThru), IRB.CreateSExt(IRB.CreateNeg(Mask), ShadowTy)); - - Value *Acc = IRB.CreateExtractElement( - MaskedPassThruShadow, ConstantInt::get(IRB.getInt32Ty(), 0)); - for (int i = 1, N = cast<FixedVectorType>(PassThru->getType()) - ->getNumElements(); - i < N; ++i) { - Value *More = IRB.CreateExtractElement( - MaskedPassThruShadow, ConstantInt::get(IRB.getInt32Ty(), i)); - Acc = IRB.CreateOr(Acc, More); - } + if (!PropagateShadow) { + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); + return; + } - Value *Origin = IRB.CreateSelect( - IRB.CreateICmpNE(Acc, Constant::getNullValue(Acc->getType())), - getOrigin(PassThru), IRB.CreateLoad(MS.OriginTy, OriginPtr)); + Type *ShadowTy = getShadowTy(&I); + Value *ShadowPtr, *OriginPtr; + std::tie(ShadowPtr, OriginPtr) = + getShadowOriginPtr(Ptr, IRB, ShadowTy, Alignment, /*isStore*/ false); + setShadow(&I, IRB.CreateMaskedLoad(ShadowTy, ShadowPtr, Alignment, Mask, + getShadow(PassThru), "_msmaskedld")); - setOrigin(&I, Origin); - } else { - setOrigin(&I, getCleanOrigin()); - } - } - return true; + if (!MS.TrackOrigins) + return; + + // Choose between PassThru's and the loaded value's origins. + Value *MaskedPassThruShadow = IRB.CreateAnd( + getShadow(PassThru), IRB.CreateSExt(IRB.CreateNeg(Mask), ShadowTy)); + + Value *ConvertedShadow = convertShadowToScalar(MaskedPassThruShadow, IRB); + Value *NotNull = convertToBool(ConvertedShadow, IRB, "_mscmp"); + + Value *PtrOrigin = IRB.CreateLoad(MS.OriginTy, OriginPtr); + Value *Origin = IRB.CreateSelect(NotNull, getOrigin(PassThru), PtrOrigin); + + setOrigin(&I, Origin); } // Instrument BMI / BMI2 intrinsics. // All of these intrinsics are Z = I(X, Y) - // where the types of all operands and the result match, and are either i32 or i64. 
- // The following instrumentation happens to work for all of them: + // where the types of all operands and the result match, and are either i32 or + // i64. The following instrumentation happens to work for all of them: // Sz = I(Sx, Y) | (sext (Sy != 0)) void handleBmiIntrinsic(IntrinsicInst &I) { IRBuilder<> IRB(&I); @@ -3234,6 +3601,19 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOriginForNaryOp(I); } + void handleVtestIntrinsic(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Shadow0 = getShadow(&I, 0); + Value *Shadow1 = getShadow(&I, 1); + Value *Or = IRB.CreateOr(Shadow0, Shadow1); + Value *NZ = IRB.CreateICmpNE(Or, Constant::getNullValue(Or->getType())); + Value *Scalar = convertShadowToScalar(NZ, IRB); + Value *Shadow = IRB.CreateZExt(Scalar, getShadowTy(&I)); + + setShadow(&I, Shadow); + setOriginForNaryOp(I); + } + void handleBinarySdSsIntrinsic(IntrinsicInst &I) { IRBuilder<> IRB(&I); unsigned Width = @@ -3280,6 +3660,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { case Intrinsic::bswap: handleBswap(I); break; + case Intrinsic::ctlz: + case Intrinsic::cttz: + handleCountZeroes(I); + break; + case Intrinsic::masked_compressstore: + handleMaskedCompressStore(I); + break; + case Intrinsic::masked_expandload: + handleMaskedExpandLoad(I); + break; + case Intrinsic::masked_gather: + handleMaskedGather(I); + break; + case Intrinsic::masked_scatter: + handleMaskedScatter(I); + break; case Intrinsic::masked_store: handleMaskedStore(I); break; @@ -3495,11 +3891,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { handleVectorCompareScalarIntrinsic(I); break; - case Intrinsic::x86_sse_cmp_ps: + case Intrinsic::x86_avx_cmp_pd_256: + case Intrinsic::x86_avx_cmp_ps_256: case Intrinsic::x86_sse2_cmp_pd: - // FIXME: For x86_avx_cmp_pd_256 and x86_avx_cmp_ps_256 this function - // generates reasonably looking IR that fails in the backend with "Do not - // know how to split the result of this operator!". 
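handleVtestIntrinsic() above covers the vtest/ptest family, which folds two whole vectors into a single i32 flag; the shadow rule is that the flag is poisoned whenever any bit of either operand's shadow is set. A per-element model of that collapse (shadows as per-lane bitmasks; the real code zero-extends the collapsed comparison into the i32 result shadow rather than returning exactly 1):

#include <cstddef>
#include <cstdint>
#include <vector>

uint32_t vtestResultShadow(const std::vector<uint64_t> &ShadowA,
                           const std::vector<uint64_t> &ShadowB) {
  bool AnyPoisoned = false;
  for (size_t i = 0; i < ShadowA.size(); ++i)   // assumes equal lane counts
    AnyPoisoned = AnyPoisoned || ((ShadowA[i] | ShadowB[i]) != 0);
  return AnyPoisoned ? 1u : 0u;                 // non-zero result shadow
}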
+ case Intrinsic::x86_sse_cmp_ps: handleVectorComparePackedIntrinsic(I); break; @@ -3531,6 +3926,27 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { handleBinarySdSsIntrinsic(I); break; + case Intrinsic::x86_avx_vtestc_pd: + case Intrinsic::x86_avx_vtestc_pd_256: + case Intrinsic::x86_avx_vtestc_ps: + case Intrinsic::x86_avx_vtestc_ps_256: + case Intrinsic::x86_avx_vtestnzc_pd: + case Intrinsic::x86_avx_vtestnzc_pd_256: + case Intrinsic::x86_avx_vtestnzc_ps: + case Intrinsic::x86_avx_vtestnzc_ps_256: + case Intrinsic::x86_avx_vtestz_pd: + case Intrinsic::x86_avx_vtestz_pd_256: + case Intrinsic::x86_avx_vtestz_ps: + case Intrinsic::x86_avx_vtestz_ps_256: + case Intrinsic::x86_avx_ptestc_256: + case Intrinsic::x86_avx_ptestnzc_256: + case Intrinsic::x86_avx_ptestz_256: + case Intrinsic::x86_sse41_ptestc: + case Intrinsic::x86_sse41_ptestnzc: + case Intrinsic::x86_sse41_ptestz: + handleVtestIntrinsic(I); + break; + case Intrinsic::fshl: case Intrinsic::fshr: handleFunnelShift(I); @@ -3564,9 +3980,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRB.CreateExtractElement(makeAddAcquireOrderingTable(IRB), Ordering); CB.setArgOperand(3, NewOrdering); - IRBuilder<> NextIRB(CB.getNextNode()); - NextIRB.SetCurrentDebugLocation(CB.getDebugLoc()); - + NextNodeIRBuilder NextIRB(&CB); Value *SrcShadowPtr, *SrcOriginPtr; std::tie(SrcShadowPtr, SrcOriginPtr) = getShadowOriginPtr(SrcPtr, NextIRB, NextIRB.getInt8Ty(), Align(1), @@ -3648,12 +4062,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // will become a non-readonly function after it is instrumented by us. To // prevent this code from being optimized out, mark that function // non-readonly in advance. + // TODO: We can likely do better than dropping memory() completely here. AttributeMask B; - B.addAttribute(Attribute::ReadOnly) - .addAttribute(Attribute::ReadNone) - .addAttribute(Attribute::WriteOnly) - .addAttribute(Attribute::ArgMemOnly) - .addAttribute(Attribute::Speculatable); + B.addAttribute(Attribute::Memory).addAttribute(Attribute::Speculatable); Call->removeFnAttrs(B); if (Function *Func = Call->getCalledFunction()) { @@ -3672,10 +4083,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { unsigned ArgOffset = 0; LLVM_DEBUG(dbgs() << " CallSite: " << CB << "\n"); - for (auto ArgIt = CB.arg_begin(), End = CB.arg_end(); ArgIt != End; - ++ArgIt) { - Value *A = *ArgIt; - unsigned i = ArgIt - CB.arg_begin(); + for (const auto &[i, A] : llvm::enumerate(CB.args())) { if (!A->getType()->isSized()) { LLVM_DEBUG(dbgs() << "Arg " << i << " is not sized: " << CB << "\n"); continue; @@ -3708,7 +4116,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ArgOffset + Size > kParamTLSSize) break; const MaybeAlign ParamAlignment(CB.getParamAlign(i)); - MaybeAlign Alignment = llvm::None; + MaybeAlign Alignment = std::nullopt; if (ParamAlignment) Alignment = std::min(*ParamAlignment, kShadowTLSAlignment); Value *AShadowPtr, *AOriginPtr; @@ -3794,8 +4202,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(&CB, getCleanOrigin()); return; } - // FIXME: NextInsn is likely in a basic block that has not been visited yet. - // Anything inserted there will be instrumented by MSan later! + // FIXME: NextInsn is likely in a basic block that has not been visited + // yet. Anything inserted there will be instrumented by MSan later! 
NextInsn = NormalDest->getFirstInsertionPt(); assert(NextInsn != NormalDest->end() && "Could not find insertion point for retval shadow load"); @@ -3823,12 +4231,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitReturnInst(ReturnInst &I) { IRBuilder<> IRB(&I); Value *RetVal = I.getReturnValue(); - if (!RetVal) return; + if (!RetVal) + return; // Don't emit the epilogue for musttail call returns. - if (isAMustTailRetVal(RetVal)) return; + if (isAMustTailRetVal(RetVal)) + return; Value *ShadowPtr = getShadowPtrForRetval(RetVal, IRB); - bool HasNoUndef = - F.hasRetAttribute(Attribute::NoUndef); + bool HasNoUndef = F.hasRetAttribute(Attribute::NoUndef); bool StoreShadow = !(MS.EagerChecks && HasNoUndef); // FIXME: Consider using SpecialCaseList to specify a list of functions that // must always return fully initialized values. For now, we hardcode "main". @@ -3863,21 +4272,20 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setShadow(&I, IRB.CreatePHI(getShadowTy(&I), I.getNumIncomingValues(), "_msphi_s")); if (MS.TrackOrigins) - setOrigin(&I, IRB.CreatePHI(MS.OriginTy, I.getNumIncomingValues(), - "_msphi_o")); + setOrigin( + &I, IRB.CreatePHI(MS.OriginTy, I.getNumIncomingValues(), "_msphi_o")); + } + + Value *getLocalVarIdptr(AllocaInst &I) { + ConstantInt *IntConst = + ConstantInt::get(Type::getInt32Ty((*F.getParent()).getContext()), 0); + return new GlobalVariable(*F.getParent(), IntConst->getType(), + /*isConstant=*/false, GlobalValue::PrivateLinkage, + IntConst); } Value *getLocalVarDescription(AllocaInst &I) { - SmallString<2048> StackDescriptionStorage; - raw_svector_ostream StackDescription(StackDescriptionStorage); - // We create a string with a description of the stack allocation and - // pass it into __msan_set_alloca_origin. - // It will be printed by the run-time if stack-originated UMR is found. - // The first 4 bytes of the string are set to '----' and will be replaced - // by __msan_va_arg_overflow_size_tls at the first call. 
- StackDescription << "----" << I.getName() << "@" << F.getName(); - return createPrivateNonConstGlobalForString(*F.getParent(), - StackDescription.str()); + return createPrivateConstGlobalForString(*F.getParent(), I.getName()); } void poisonAllocaUserspace(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { @@ -3894,11 +4302,18 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } if (PoisonStack && MS.TrackOrigins) { - Value *Descr = getLocalVarDescription(I); - IRB.CreateCall(MS.MsanSetAllocaOrigin4Fn, - {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, - IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy()), - IRB.CreatePointerCast(&F, MS.IntptrTy)}); + Value *Idptr = getLocalVarIdptr(I); + if (ClPrintStackNames) { + Value *Descr = getLocalVarDescription(I); + IRB.CreateCall(MS.MsanSetAllocaOriginWithDescriptionFn, + {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, + IRB.CreatePointerCast(Idptr, IRB.getInt8PtrTy()), + IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy())}); + } else { + IRB.CreateCall(MS.MsanSetAllocaOriginNoDescriptionFn, + {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, + IRB.CreatePointerCast(Idptr, IRB.getInt8PtrTy())}); + } } } @@ -3917,12 +4332,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void instrumentAlloca(AllocaInst &I, Instruction *InsPoint = nullptr) { if (!InsPoint) InsPoint = &I; - IRBuilder<> IRB(InsPoint->getNextNode()); + NextNodeIRBuilder IRB(InsPoint); const DataLayout &DL = F.getParent()->getDataLayout(); uint64_t TypeSize = DL.getTypeAllocSize(I.getAllocatedType()); Value *Len = ConstantInt::get(MS.IntptrTy, TypeSize); if (I.isArrayAllocation()) - Len = IRB.CreateMul(Len, I.getArraySize()); + Len = IRB.CreateMul(Len, + IRB.CreateZExtOrTrunc(I.getArraySize(), MS.IntptrTy)); if (MS.CompileKernel) poisonAllocaKmsan(I, IRB, Len); @@ -3938,7 +4354,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { AllocaSet.insert(&I); } - void visitSelectInst(SelectInst& I) { + void visitSelectInst(SelectInst &I) { IRBuilder<> IRB(&I); // a = select b, c, d Value *B = I.getCondition(); @@ -3977,9 +4393,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (B->getType()->isVectorTy()) { Type *FlatTy = getShadowTyNoVec(B->getType()); B = IRB.CreateICmpNE(IRB.CreateBitCast(B, FlatTy), - ConstantInt::getNullValue(FlatTy)); + ConstantInt::getNullValue(FlatTy)); Sb = IRB.CreateICmpNE(IRB.CreateBitCast(Sb, FlatTy), - ConstantInt::getNullValue(FlatTy)); + ConstantInt::getNullValue(FlatTy)); } // a = select b, c, d // Oa = Sb ? Ob : (b ? Oc : Od) @@ -4007,9 +4423,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(&I, getCleanOrigin()); } - void visitGetElementPtrInst(GetElementPtrInst &I) { - handleShadowOr(I); - } + void visitGetElementPtrInst(GetElementPtrInst &I) { handleShadowOr(I); } void visitExtractValueInst(ExtractValueInst &I) { IRBuilder<> IRB(&I); @@ -4177,7 +4591,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { struct VarArgAMD64Helper : public VarArgHelper { // An unfortunate workaround for asymmetric lowering of va_arg stuff. // See a comment in visitCallBase for more details. - static const unsigned AMD64GpEndOffset = 48; // AMD64 ABI Draft 0.99.6 p3.5.7 + static const unsigned AMD64GpEndOffset = 48; // AMD64 ABI Draft 0.99.6 p3.5.7 static const unsigned AMD64FpEndOffsetSSE = 176; // If SSE is disabled, fp_offset in va_list is zero. 
static const unsigned AMD64FpEndOffsetNoSSE = AMD64GpEndOffset; @@ -4190,7 +4604,7 @@ struct VarArgAMD64Helper : public VarArgHelper { Value *VAArgTLSOriginCopy = nullptr; Value *VAArgOverflowSize = nullptr; - SmallVector<CallInst*, 16> VAStartInstrumentationList; + SmallVector<CallInst *, 16> VAStartInstrumentationList; enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory }; @@ -4208,7 +4622,7 @@ struct VarArgAMD64Helper : public VarArgHelper { } } - ArgKind classifyArgument(Value* arg) { + ArgKind classifyArgument(Value *arg) { // A very rough approximation of X86_64 argument classification rules. Type *T = arg->getType(); if (T->isFPOrFPVectorTy() || T->isX86_MMXTy()) @@ -4233,10 +4647,7 @@ struct VarArgAMD64Helper : public VarArgHelper { unsigned FpOffset = AMD64GpEndOffset; unsigned OverflowOffset = AMD64FpEndOffset; const DataLayout &DL = F.getParent()->getDataLayout(); - for (auto ArgIt = CB.arg_begin(), End = CB.arg_end(); ArgIt != End; - ++ArgIt) { - Value *A = *ArgIt; - unsigned ArgNo = CB.getArgOperandNo(ArgIt); + for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); bool IsByVal = CB.paramHasAttr(ArgNo, Attribute::ByVal); if (IsByVal) { @@ -4274,32 +4685,30 @@ struct VarArgAMD64Helper : public VarArgHelper { AK = AK_Memory; Value *ShadowBase, *OriginBase = nullptr; switch (AK) { - case AK_GeneralPurpose: - ShadowBase = - getShadowPtrForVAArgument(A->getType(), IRB, GpOffset, 8); - if (MS.TrackOrigins) - OriginBase = - getOriginPtrForVAArgument(A->getType(), IRB, GpOffset); - GpOffset += 8; - break; - case AK_FloatingPoint: - ShadowBase = - getShadowPtrForVAArgument(A->getType(), IRB, FpOffset, 16); - if (MS.TrackOrigins) - OriginBase = - getOriginPtrForVAArgument(A->getType(), IRB, FpOffset); - FpOffset += 16; - break; - case AK_Memory: - if (IsFixed) - continue; - uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); - ShadowBase = - getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, 8); - if (MS.TrackOrigins) - OriginBase = - getOriginPtrForVAArgument(A->getType(), IRB, OverflowOffset); - OverflowOffset += alignTo(ArgSize, 8); + case AK_GeneralPurpose: + ShadowBase = + getShadowPtrForVAArgument(A->getType(), IRB, GpOffset, 8); + if (MS.TrackOrigins) + OriginBase = getOriginPtrForVAArgument(A->getType(), IRB, GpOffset); + GpOffset += 8; + break; + case AK_FloatingPoint: + ShadowBase = + getShadowPtrForVAArgument(A->getType(), IRB, FpOffset, 16); + if (MS.TrackOrigins) + OriginBase = getOriginPtrForVAArgument(A->getType(), IRB, FpOffset); + FpOffset += 16; + break; + case AK_Memory: + if (IsFixed) + continue; + uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); + ShadowBase = + getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, 8); + if (MS.TrackOrigins) + OriginBase = + getOriginPtrForVAArgument(A->getType(), IRB, OverflowOffset); + OverflowOffset += alignTo(ArgSize, 8); } // Take fixed arguments into account for GpOffset and FpOffset, // but don't actually store shadows for them. 
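The VarArgAMD64Helper code above lays variadic-argument shadows out in a TLS buffer that mirrors the x86-64 register save area: 8-byte general-purpose slots in [0, 48), 16-byte floating-point slots in [48, 176) when SSE is enabled (the FP area is empty otherwise), and an 8-byte-aligned overflow area after that. A hedged standalone sketch of the placement logic; names are illustrative, and the real helper additionally advances the offsets for fixed (non-variadic) arguments without storing their shadow:

#include <cstdint>

enum class ArgKind { GeneralPurpose, FloatingPoint, Memory };

struct VAArgOffsets { uint64_t Gp = 0, Fp = 48, Overflow = 176; };

uint64_t alignTo8(uint64_t N) { return (N + 7) & ~uint64_t(7); }

// Returns the byte offset of this argument's shadow in the va_arg TLS area
// and advances the running offsets, spilling to the overflow area once a
// register area is exhausted.
uint64_t placeVAArgShadow(ArgKind AK, uint64_t ArgSize, VAArgOffsets &Off) {
  if (AK == ArgKind::GeneralPurpose && Off.Gp >= 48)
    AK = ArgKind::Memory;
  if (AK == ArgKind::FloatingPoint && Off.Fp >= 176)
    AK = ArgKind::Memory;
  switch (AK) {
  case ArgKind::GeneralPurpose: { uint64_t R = Off.Gp; Off.Gp += 8;  return R; }
  case ArgKind::FloatingPoint:  { uint64_t R = Off.Fp; Off.Fp += 16; return R; }
  case ArgKind::Memory: {
    uint64_t R = Off.Overflow;
    Off.Overflow += alignTo8(ArgSize);
    return R;
  }
  }
  return 0; // unreachable
}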
@@ -4319,7 +4728,7 @@ struct VarArgAMD64Helper : public VarArgHelper { } } Constant *OverflowSize = - ConstantInt::get(IRB.getInt64Ty(), OverflowOffset - AMD64FpEndOffset); + ConstantInt::get(IRB.getInt64Ty(), OverflowOffset - AMD64FpEndOffset); IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS); } @@ -4371,7 +4780,8 @@ struct VarArgAMD64Helper : public VarArgHelper { } void visitVACopyInst(VACopyInst &I) override { - if (F.getCallingConv() == CallingConv::Win64) return; + if (F.getCallingConv() == CallingConv::Win64) + return; unpoisonVAListTagForInst(I); } @@ -4384,9 +4794,8 @@ struct VarArgAMD64Helper : public VarArgHelper { IRBuilder<> IRB(MSV.FnPrologueEnd); VAArgOverflowSize = IRB.CreateLoad(IRB.getInt64Ty(), MS.VAArgOverflowSizeTLS); - Value *CopySize = - IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, AMD64FpEndOffset), - VAArgOverflowSize); + Value *CopySize = IRB.CreateAdd( + ConstantInt::get(MS.IntptrTy, AMD64FpEndOffset), VAArgOverflowSize); VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize); if (MS.TrackOrigins) { @@ -4400,7 +4809,7 @@ struct VarArgAMD64Helper : public VarArgHelper { // Copy va_list shadow from the backup copy of the TLS contents. for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { CallInst *OrigInst = VAStartInstrumentationList[i]; - IRBuilder<> IRB(OrigInst->getNextNode()); + NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); @@ -4453,24 +4862,23 @@ struct VarArgMIPS64Helper : public VarArgHelper { Value *VAArgTLSCopy = nullptr; Value *VAArgSize = nullptr; - SmallVector<CallInst*, 16> VAStartInstrumentationList; + SmallVector<CallInst *, 16> VAStartInstrumentationList; VarArgMIPS64Helper(Function &F, MemorySanitizer &MS, - MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) {} + MemorySanitizerVisitor &MSV) + : F(F), MS(MS), MSV(MSV) {} void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override { unsigned VAArgOffset = 0; const DataLayout &DL = F.getParent()->getDataLayout(); - for (auto ArgIt = CB.arg_begin() + CB.getFunctionType()->getNumParams(), - End = CB.arg_end(); - ArgIt != End; ++ArgIt) { + for (Value *A : + llvm::drop_begin(CB.args(), CB.getFunctionType()->getNumParams())) { Triple TargetTriple(F.getParent()->getTargetTriple()); - Value *A = *ArgIt; Value *Base; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); if (TargetTriple.getArch() == Triple::mips64) { - // Adjusting the shadow for argument with size < 8 to match the placement - // of bits in big endian system + // Adjusting the shadow for argument with size < 8 to match the + // placement of bits in big endian system if (ArgSize < 8) VAArgOffset += (8 - ArgSize); } @@ -4529,8 +4937,8 @@ struct VarArgMIPS64Helper : public VarArgHelper { "finalizeInstrumentation called twice"); IRBuilder<> IRB(MSV.FnPrologueEnd); VAArgSize = IRB.CreateLoad(IRB.getInt64Ty(), MS.VAArgOverflowSizeTLS); - Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), - VAArgSize); + Value *CopySize = + IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), VAArgSize); if (!VAStartInstrumentationList.empty()) { // If there is a va_start in this function, make a backup copy of @@ -4543,7 +4951,7 @@ struct VarArgMIPS64Helper : public VarArgHelper { // Copy va_list shadow from the backup copy of the TLS contents. 
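The recurring change from IRBuilder<> IRB(OrigInst->getNextNode()) to NextNodeIRBuilder IRB(OrigInst), first visible in the loop that follows, funnels the "insert right after this instruction" pattern through one helper. Its definition is outside this diff; a plausible minimal shape, given only as an assumption, would be:

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Instruction.h"
    #include <iterator>

    // Hypothetical sketch: position the builder immediately after I and carry
    // over its debug location, rather than dereferencing I->getNextNode() at
    // every call site. The in-tree helper may do more than this.
    struct NextNodeIRBuilderSketch : llvm::IRBuilder<> {
      explicit NextNodeIRBuilderSketch(llvm::Instruction *I)
          : llvm::IRBuilder<>(I->getParent(), std::next(I->getIterator())) {
        SetCurrentDebugLocation(I->getDebugLoc());
      }
    };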
for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { CallInst *OrigInst = VAStartInstrumentationList[i]; - IRBuilder<> IRB(OrigInst->getNextNode()); + NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); Value *RegSaveAreaPtrPtr = @@ -4571,8 +4979,8 @@ struct VarArgAArch64Helper : public VarArgHelper { static const unsigned AArch64GrEndOffset = kAArch64GrArgSize; // Make VR space aligned to 16 bytes. static const unsigned AArch64VrBegOffset = AArch64GrEndOffset; - static const unsigned AArch64VrEndOffset = AArch64VrBegOffset - + kAArch64VrArgSize; + static const unsigned AArch64VrEndOffset = + AArch64VrBegOffset + kAArch64VrArgSize; static const unsigned AArch64VAEndOffset = AArch64VrEndOffset; Function &F; @@ -4581,19 +4989,20 @@ struct VarArgAArch64Helper : public VarArgHelper { Value *VAArgTLSCopy = nullptr; Value *VAArgOverflowSize = nullptr; - SmallVector<CallInst*, 16> VAStartInstrumentationList; + SmallVector<CallInst *, 16> VAStartInstrumentationList; enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory }; VarArgAArch64Helper(Function &F, MemorySanitizer &MS, - MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) {} + MemorySanitizerVisitor &MSV) + : F(F), MS(MS), MSV(MSV) {} - ArgKind classifyArgument(Value* arg) { + ArgKind classifyArgument(Value *arg) { Type *T = arg->getType(); if (T->isFPOrFPVectorTy()) return AK_FloatingPoint; - if ((T->isIntegerTy() && T->getPrimitiveSizeInBits() <= 64) - || (T->isPointerTy())) + if ((T->isIntegerTy() && T->getPrimitiveSizeInBits() <= 64) || + (T->isPointerTy())) return AK_GeneralPurpose; return AK_Memory; } @@ -4613,10 +5022,7 @@ struct VarArgAArch64Helper : public VarArgHelper { unsigned OverflowOffset = AArch64VAEndOffset; const DataLayout &DL = F.getParent()->getDataLayout(); - for (auto ArgIt = CB.arg_begin(), End = CB.arg_end(); ArgIt != End; - ++ArgIt) { - Value *A = *ArgIt; - unsigned ArgNo = CB.getArgOperandNo(ArgIt); + for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); ArgKind AK = classifyArgument(A); if (AK == AK_GeneralPurpose && GrOffset >= AArch64GrEndOffset) @@ -4625,24 +5031,24 @@ struct VarArgAArch64Helper : public VarArgHelper { AK = AK_Memory; Value *Base; switch (AK) { - case AK_GeneralPurpose: - Base = getShadowPtrForVAArgument(A->getType(), IRB, GrOffset, 8); - GrOffset += 8; - break; - case AK_FloatingPoint: - Base = getShadowPtrForVAArgument(A->getType(), IRB, VrOffset, 8); - VrOffset += 16; - break; - case AK_Memory: - // Don't count fixed arguments in the overflow area - va_start will - // skip right over them. - if (IsFixed) - continue; - uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); - Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, - alignTo(ArgSize, 8)); - OverflowOffset += alignTo(ArgSize, 8); - break; + case AK_GeneralPurpose: + Base = getShadowPtrForVAArgument(A->getType(), IRB, GrOffset, 8); + GrOffset += 8; + break; + case AK_FloatingPoint: + Base = getShadowPtrForVAArgument(A->getType(), IRB, VrOffset, 8); + VrOffset += 16; + break; + case AK_Memory: + // Don't count fixed arguments in the overflow area - va_start will + // skip right over them. 
+ if (IsFixed) + continue; + uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); + Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, + alignTo(ArgSize, 8)); + OverflowOffset += alignTo(ArgSize, 8); + break; } // Count Gp/Vr fixed arguments to their respective offsets, but don't // bother to actually store a shadow. @@ -4653,7 +5059,7 @@ struct VarArgAArch64Helper : public VarArgHelper { IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } Constant *OverflowSize = - ConstantInt::get(IRB.getInt64Ty(), OverflowOffset - AArch64VAEndOffset); + ConstantInt::get(IRB.getInt64Ty(), OverflowOffset - AArch64VAEndOffset); IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS); } @@ -4694,9 +5100,8 @@ struct VarArgAArch64Helper : public VarArgHelper { } // Retrieve a va_list field of 'void*' size. - Value* getVAField64(IRBuilder<> &IRB, Value *VAListTag, int offset) { - Value *SaveAreaPtrPtr = - IRB.CreateIntToPtr( + Value *getVAField64(IRBuilder<> &IRB, Value *VAListTag, int offset) { + Value *SaveAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, offset)), Type::getInt64PtrTy(*MS.C)); @@ -4704,9 +5109,8 @@ struct VarArgAArch64Helper : public VarArgHelper { } // Retrieve a va_list field of 'int' size. - Value* getVAField32(IRBuilder<> &IRB, Value *VAListTag, int offset) { - Value *SaveAreaPtr = - IRB.CreateIntToPtr( + Value *getVAField32(IRBuilder<> &IRB, Value *VAListTag, int offset) { + Value *SaveAreaPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, offset)), Type::getInt32PtrTy(*MS.C)); @@ -4723,9 +5127,8 @@ struct VarArgAArch64Helper : public VarArgHelper { IRBuilder<> IRB(MSV.FnPrologueEnd); VAArgOverflowSize = IRB.CreateLoad(IRB.getInt64Ty(), MS.VAArgOverflowSizeTLS); - Value *CopySize = - IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, AArch64VAEndOffset), - VAArgOverflowSize); + Value *CopySize = IRB.CreateAdd( + ConstantInt::get(MS.IntptrTy, AArch64VAEndOffset), VAArgOverflowSize); VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); IRB.CreateMemCpy(VAArgTLSCopy, Align(8), MS.VAArgTLS, Align(8), CopySize); } @@ -4737,7 +5140,7 @@ struct VarArgAArch64Helper : public VarArgHelper { // the TLS contents. for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { CallInst *OrigInst = VAStartInstrumentationList[i]; - IRBuilder<> IRB(OrigInst->getNextNode()); + NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); @@ -4774,7 +5177,7 @@ struct VarArgAArch64Helper : public VarArgHelper { // '0 - ((8 - named_gr) * 8)', the idea is to just propagate the variadic // argument by ignoring the bytes of shadow from named arguments. 
Value *GrRegSaveAreaShadowPtrOff = - IRB.CreateAdd(GrArgSize, GrOffSaveArea); + IRB.CreateAdd(GrArgSize, GrOffSaveArea); Value *GrRegSaveAreaShadowPtr = MSV.getShadowOriginPtr(GrRegSaveAreaPtr, IRB, IRB.getInt8Ty(), @@ -4798,10 +5201,10 @@ struct VarArgAArch64Helper : public VarArgHelper { .first; Value *VrSrcPtr = IRB.CreateInBoundsGEP( - IRB.getInt8Ty(), - IRB.CreateInBoundsGEP(IRB.getInt8Ty(), VAArgTLSCopy, - IRB.getInt32(AArch64VrBegOffset)), - VrRegSaveAreaShadowPtrOff); + IRB.getInt8Ty(), + IRB.CreateInBoundsGEP(IRB.getInt8Ty(), VAArgTLSCopy, + IRB.getInt32(AArch64VrBegOffset)), + VrRegSaveAreaShadowPtrOff); Value *VrCopySize = IRB.CreateSub(VrArgSize, VrRegSaveAreaShadowPtrOff); IRB.CreateMemCpy(VrRegSaveAreaShadowPtr, Align(8), VrSrcPtr, Align(8), @@ -4813,9 +5216,8 @@ struct VarArgAArch64Helper : public VarArgHelper { Align(16), /*isStore*/ true) .first; - Value *StackSrcPtr = - IRB.CreateInBoundsGEP(IRB.getInt8Ty(), VAArgTLSCopy, - IRB.getInt32(AArch64VAEndOffset)); + Value *StackSrcPtr = IRB.CreateInBoundsGEP( + IRB.getInt8Ty(), VAArgTLSCopy, IRB.getInt32(AArch64VAEndOffset)); IRB.CreateMemCpy(StackSaveAreaShadowPtr, Align(16), StackSrcPtr, Align(16), VAArgOverflowSize); @@ -4831,10 +5233,11 @@ struct VarArgPowerPC64Helper : public VarArgHelper { Value *VAArgTLSCopy = nullptr; Value *VAArgSize = nullptr; - SmallVector<CallInst*, 16> VAStartInstrumentationList; + SmallVector<CallInst *, 16> VAStartInstrumentationList; VarArgPowerPC64Helper(Function &F, MemorySanitizer &MS, - MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) {} + MemorySanitizerVisitor &MSV) + : F(F), MS(MS), MSV(MSV) {} void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override { // For PowerPC, we need to deal with alignment of stack arguments - @@ -4854,10 +5257,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { VAArgBase = 32; unsigned VAArgOffset = VAArgBase; const DataLayout &DL = F.getParent()->getDataLayout(); - for (auto ArgIt = CB.arg_begin(), End = CB.arg_end(); ArgIt != End; - ++ArgIt) { - Value *A = *ArgIt; - unsigned ArgNo = CB.getArgOperandNo(ArgIt); + for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); bool IsByVal = CB.paramHasAttr(ArgNo, Attribute::ByVal); if (IsByVal) { @@ -4918,8 +5318,8 @@ struct VarArgPowerPC64Helper : public VarArgHelper { VAArgBase = VAArgOffset; } - Constant *TotalVAArgSize = ConstantInt::get(IRB.getInt64Ty(), - VAArgOffset - VAArgBase); + Constant *TotalVAArgSize = + ConstantInt::get(IRB.getInt64Ty(), VAArgOffset - VAArgBase); // Here using VAArgOverflowSizeTLS as VAArgSizeTLS to avoid creation of // a new class member i.e. it is the total size of all VarArgs. IRB.CreateStore(TotalVAArgSize, MS.VAArgOverflowSizeTLS); @@ -4967,8 +5367,8 @@ struct VarArgPowerPC64Helper : public VarArgHelper { "finalizeInstrumentation called twice"); IRBuilder<> IRB(MSV.FnPrologueEnd); VAArgSize = IRB.CreateLoad(IRB.getInt64Ty(), MS.VAArgOverflowSizeTLS); - Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), - VAArgSize); + Value *CopySize = + IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), VAArgSize); if (!VAStartInstrumentationList.empty()) { // If there is a va_start in this function, make a backup copy of @@ -4981,7 +5381,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { // Copy va_list shadow from the backup copy of the TLS contents. 
for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { CallInst *OrigInst = VAStartInstrumentationList[i]; - IRBuilder<> IRB(OrigInst->getNextNode()); + NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); Value *RegSaveAreaPtrPtr = @@ -5082,10 +5482,7 @@ struct VarArgSystemZHelper : public VarArgHelper { unsigned VrIndex = 0; unsigned OverflowOffset = SystemZOverflowOffset; const DataLayout &DL = F.getParent()->getDataLayout(); - for (auto ArgIt = CB.arg_begin(), End = CB.arg_end(); ArgIt != End; - ++ArgIt) { - Value *A = *ArgIt; - unsigned ArgNo = CB.getArgOperandNo(ArgIt); + for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); // SystemZABIInfo does not produce ByVal parameters. assert(!CB.paramHasAttr(ArgNo, Attribute::ByVal)); @@ -5304,7 +5701,7 @@ struct VarArgSystemZHelper : public VarArgHelper { for (size_t VaStartNo = 0, VaStartNum = VAStartInstrumentationList.size(); VaStartNo < VaStartNum; VaStartNo++) { CallInst *OrigInst = VAStartInstrumentationList[VaStartNo]; - IRBuilder<> IRB(OrigInst->getNextNode()); + NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); copyRegSaveArea(IRB, VAListTag); copyOverflowArea(IRB, VAListTag); @@ -5357,13 +5754,9 @@ bool MemorySanitizer::sanitizeFunction(Function &F, TargetLibraryInfo &TLI) { MemorySanitizerVisitor Visitor(F, *this, TLI); - // Clear out readonly/readnone attributes. + // Clear out memory attributes. AttributeMask B; - B.addAttribute(Attribute::ReadOnly) - .addAttribute(Attribute::ReadNone) - .addAttribute(Attribute::WriteOnly) - .addAttribute(Attribute::ArgMemOnly) - .addAttribute(Attribute::Speculatable); + B.addAttribute(Attribute::Memory).addAttribute(Attribute::Speculatable); F.removeFnAttrs(B); return Visitor.runOnFunction(); diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index c4512d0222cd..4d4eb6f8ce80 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -65,6 +65,8 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryProfileInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -91,11 +93,13 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/ProfileSummary.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/ProfileData/InstrProfReader.h" +#include "llvm/Support/BLAKE3.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CRC.h" #include "llvm/Support/Casting.h" @@ -105,6 +109,7 @@ #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GraphWriter.h" +#include "llvm/Support/HashBuilder.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -113,14 +118,18 @@ #include <algorithm> #include <cassert> #include <cstdint> +#include <map> #include <memory> #include <numeric> +#include <optional> +#include <set> #include <string> #include 
<unordered_map> #include <utility> #include <vector> using namespace llvm; +using namespace llvm::memprof; using ProfileCount = Function::ProfileCount; using VPCandidateInfo = ValueProfileCollector::CandidateInfo; @@ -135,6 +144,7 @@ STATISTIC(NumOfPGOSplit, "Number of critical edge splits."); STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts."); STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile."); STATISTIC(NumOfPGOMissing, "Number of functions without profile."); +STATISTIC(NumOfMemProfMissing, "Number of functions without memory profile."); STATISTIC(NumOfPGOICall, "Number of indirect call value instrumentations."); STATISTIC(NumOfCSPGOInstrument, "Number of edges instrumented in CSPGO."); STATISTIC(NumOfCSPGOSelectInsts, @@ -291,6 +301,19 @@ static cl::opt<std::string> PGOTraceFuncHash( cl::value_desc("function name"), cl::desc("Trace the hash of the function with this name.")); +static cl::opt<unsigned> PGOFunctionSizeThreshold( + "pgo-function-size-threshold", cl::Hidden, + cl::desc("Do not instrument functions smaller than this threshold.")); + +static cl::opt<bool> MatchMemProf( + "pgo-match-memprof", cl::init(true), cl::Hidden, + cl::desc("Perform matching and annotation of memprof profiles.")); + +static cl::opt<unsigned> PGOFunctionCriticalEdgeThreshold( + "pgo-critical-edge-threshold", cl::init(20000), cl::Hidden, + cl::desc("Do not instrument functions with the number of critical edges " + " greater than this threshold.")); + namespace llvm { // Command line option to turn on CFG dot dump after profile annotation. // Defined in Analysis/BlockFrequencyInfo.cpp: -pgo-view-counts @@ -363,7 +386,7 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) { auto IRLevelVersionVariable = new GlobalVariable( M, IntTy64, true, GlobalValue::WeakAnyLinkage, Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), VarName); - IRLevelVersionVariable->setVisibility(GlobalValue::DefaultVisibility); + IRLevelVersionVariable->setVisibility(GlobalValue::HiddenVisibility); Triple TT(M.getTargetTriple()); if (TT.supportsCOMDAT()) { IRLevelVersionVariable->setLinkage(GlobalValue::ExternalLinkage); @@ -499,6 +522,7 @@ private: void renameComdatFunction(); public: + const TargetLibraryInfo &TLI; std::vector<std::vector<VPCandidateInfo>> ValueSites; SelectInstVisitor SIVisitor; std::string FuncName; @@ -537,7 +561,7 @@ public: BlockFrequencyInfo *BFI = nullptr, bool IsCS = false, bool InstrumentFuncEntry = true) : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), VPC(Func, TLI), - ValueSites(IPVK_Last + 1), SIVisitor(Func), + TLI(TLI), ValueSites(IPVK_Last + 1), SIVisitor(Func), MST(F, InstrumentFuncEntry, BPI, BFI) { // This should be done before CFG hash computation. SIVisitor.countSelects(Func); @@ -803,7 +827,7 @@ populateEHOperandBundle(VPCandidateInfo &Cand, if (!isa<IntrinsicInst>(OrigCall)) { // The instrumentation call should belong to the same funclet as a // non-intrinsic call, so just copy the operand bundle, if any exists. - Optional<OperandBundleUse> ParentFunclet = + std::optional<OperandBundleUse> ParentFunclet = OrigCall->getOperandBundle(LLVMContext::OB_funclet); if (ParentFunclet) OpBundles.emplace_back(OperandBundleDef(*ParentFunclet)); @@ -991,7 +1015,7 @@ struct UseBBInfo : public BBInfo { // Sum up the count values for all the edges. 
static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) { uint64_t Total = 0; - for (auto &E : Edges) { + for (const auto &E : Edges) { if (E->Removed) continue; Total += E->CountValue; @@ -1014,7 +1038,10 @@ public: // Read counts for the instrumented BB from profile. bool readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, - bool &AllMinusOnes); + InstrProfRecord::CountPseudoKind &PseudoKind); + + // Read memprof data for the instrumented function from profile. + bool readMemprof(IndexedInstrProfReader *PGOReader); // Populate the counts for all BBs. void populateCounters(); @@ -1203,7 +1230,7 @@ static void annotateFunctionWithHashMismatch(Function &F, auto *Existing = F.getMetadata(LLVMContext::MD_annotation); if (Existing) { MDTuple *Tuple = cast<MDTuple>(Existing); - for (auto &N : Tuple->operands()) { + for (const auto &N : Tuple->operands()) { if (cast<MDString>(N.get())->getString() == MetadataName) return; Names.push_back(N.get()); @@ -1216,11 +1243,262 @@ static void annotateFunctionWithHashMismatch(Function &F, F.setMetadata(LLVMContext::MD_annotation, MD); } +static void addCallsiteMetadata(Instruction &I, + std::vector<uint64_t> &InlinedCallStack, + LLVMContext &Ctx) { + I.setMetadata(LLVMContext::MD_callsite, + buildCallstackMetadata(InlinedCallStack, Ctx)); +} + +static uint64_t computeStackId(GlobalValue::GUID Function, uint32_t LineOffset, + uint32_t Column) { + llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::support::endianness::little> + HashBuilder; + HashBuilder.add(Function, LineOffset, Column); + llvm::BLAKE3Result<8> Hash = HashBuilder.final(); + uint64_t Id; + std::memcpy(&Id, Hash.data(), sizeof(Hash)); + return Id; +} + +static uint64_t computeStackId(const memprof::Frame &Frame) { + return computeStackId(Frame.Function, Frame.LineOffset, Frame.Column); +} + +static void addCallStack(CallStackTrie &AllocTrie, + const AllocationInfo *AllocInfo) { + SmallVector<uint64_t> StackIds; + for (auto StackFrame : AllocInfo->CallStack) + StackIds.push_back(computeStackId(StackFrame)); + auto AllocType = getAllocType(AllocInfo->Info.getMaxAccessCount(), + AllocInfo->Info.getMinSize(), + AllocInfo->Info.getMinLifetime()); + AllocTrie.addCallStack(AllocType, StackIds); +} + +// Helper to compare the InlinedCallStack computed from an instruction's debug +// info to a list of Frames from profile data (either the allocation data or a +// callsite). For callsites, the StartIndex to use in the Frame array may be +// non-zero. +static bool +stackFrameIncludesInlinedCallStack(ArrayRef<Frame> ProfileCallStack, + ArrayRef<uint64_t> InlinedCallStack, + unsigned StartIndex = 0) { + auto StackFrame = ProfileCallStack.begin() + StartIndex; + auto InlCallStackIter = InlinedCallStack.begin(); + for (; StackFrame != ProfileCallStack.end() && + InlCallStackIter != InlinedCallStack.end(); + ++StackFrame, ++InlCallStackIter) { + uint64_t StackId = computeStackId(*StackFrame); + if (StackId != *InlCallStackIter) + return false; + } + // Return true if we found and matched all stack ids from the call + // instruction. 
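Everything in the memprof matching above keys off computeStackId(); the same computation can be reproduced outside the pass, for example when checking by hand why a profile frame did or did not match an IR debug location. A standalone restatement, with made-up inputs:

    #include "llvm/Support/BLAKE3.h"
    #include "llvm/Support/HashBuilder.h"
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static uint64_t stackId(uint64_t FunctionGUID, uint32_t LineOffset,
                            uint32_t Column) {
      llvm::HashBuilder<llvm::TruncatedBLAKE3<8>,
                        llvm::support::endianness::little>
          HB;
      HB.add(FunctionGUID, LineOffset, Column);
      llvm::BLAKE3Result<8> Hash = HB.final();
      uint64_t Id;
      // The truncated digest is exactly 8 bytes, so it is reinterpreted as the
      // 64-bit frame id used on both the profile and the IR side.
      std::memcpy(&Id, Hash.data(), sizeof(Id));
      return Id;
    }

    int main() {
      std::printf("0x%016llx\n", (unsigned long long)stackId(0x1234, 5, 7));
      return 0;
    }

Matching then proceeds by comparing these ids, frame by frame, against the inlined-at chain of an instruction's debug location, which is what the stackFrameIncludesInlinedCallStack() helper above does.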
+ return InlCallStackIter == InlinedCallStack.end(); +} + +bool PGOUseFunc::readMemprof(IndexedInstrProfReader *PGOReader) { + if (!MatchMemProf) + return true; + + auto &Ctx = M->getContext(); + + auto FuncGUID = Function::getGUID(FuncInfo.FuncName); + Expected<memprof::MemProfRecord> MemProfResult = + PGOReader->getMemProfRecord(FuncGUID); + if (Error E = MemProfResult.takeError()) { + handleAllErrors(std::move(E), [&](const InstrProfError &IPE) { + auto Err = IPE.get(); + bool SkipWarning = false; + LLVM_DEBUG(dbgs() << "Error in reading profile for Func " + << FuncInfo.FuncName << ": "); + if (Err == instrprof_error::unknown_function) { + NumOfMemProfMissing++; + SkipWarning = !PGOWarnMissing; + LLVM_DEBUG(dbgs() << "unknown function"); + } else if (Err == instrprof_error::hash_mismatch) { + SkipWarning = + NoPGOWarnMismatch || + (NoPGOWarnMismatchComdatWeak && + (F.hasComdat() || + F.getLinkage() == GlobalValue::AvailableExternallyLinkage)); + LLVM_DEBUG(dbgs() << "hash mismatch (skip=" << SkipWarning << ")"); + } + + if (SkipWarning) + return; + + std::string Msg = + (IPE.message() + Twine(" ") + F.getName().str() + Twine(" Hash = ") + + std::to_string(FuncInfo.FunctionHash)) + .str(); + + Ctx.diagnose( + DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); + }); + return false; + } + + // Build maps of the location hash to all profile data with that leaf location + // (allocation info and the callsites). + std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo; + // For the callsites we need to record the index of the associated frame in + // the frame array (see comments below where the map entries are added). + std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, unsigned>>> + LocHashToCallSites; + const auto MemProfRec = std::move(MemProfResult.get()); + for (auto &AI : MemProfRec.AllocSites) { + // Associate the allocation info with the leaf frame. The later matching + // code will match any inlined call sequences in the IR with a longer prefix + // of call stack frames. + uint64_t StackId = computeStackId(AI.CallStack[0]); + LocHashToAllocInfo[StackId].insert(&AI); + } + for (auto &CS : MemProfRec.CallSites) { + // Need to record all frames from leaf up to and including this function, + // as any of these may or may not have been inlined at this point. + unsigned Idx = 0; + for (auto &StackFrame : CS) { + uint64_t StackId = computeStackId(StackFrame); + LocHashToCallSites[StackId].insert(std::make_pair(&CS, Idx++)); + // Once we find this function, we can stop recording. + if (StackFrame.Function == FuncGUID) + break; + } + assert(Idx <= CS.size() && CS[Idx - 1].Function == FuncGUID); + } + + auto GetOffset = [](const DILocation *DIL) { + return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) & + 0xffff; + }; + + // Now walk the instructions, looking up the associated profile data using + // dbug locations. + for (auto &BB : F) { + for (auto &I : BB) { + if (I.isDebugOrPseudoInst()) + continue; + // We are only interested in calls (allocation or interior call stack + // context calls). + auto *CI = dyn_cast<CallBase>(&I); + if (!CI) + continue; + auto *CalledFunction = CI->getCalledFunction(); + if (CalledFunction && CalledFunction->isIntrinsic()) + continue; + // List of call stack ids computed from the location hashes on debug + // locations (leaf to inlined at root). + std::vector<uint64_t> InlinedCallStack; + // Was the leaf location found in one of the profile maps? 
+ bool LeafFound = false; + // If leaf was found in a map, iterators pointing to its location in both + // of the maps. It might exist in neither, one, or both (the latter case + // can happen because we don't currently have discriminators to + // distinguish the case when a single line/col maps to both an allocation + // and another callsite). + std::map<uint64_t, std::set<const AllocationInfo *>>::iterator + AllocInfoIter; + std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, + unsigned>>>::iterator CallSitesIter; + for (const DILocation *DIL = I.getDebugLoc(); DIL != nullptr; + DIL = DIL->getInlinedAt()) { + // Use C++ linkage name if possible. Need to compile with + // -fdebug-info-for-profiling to get linkage name. + StringRef Name = DIL->getScope()->getSubprogram()->getLinkageName(); + if (Name.empty()) + Name = DIL->getScope()->getSubprogram()->getName(); + auto CalleeGUID = Function::getGUID(Name); + auto StackId = + computeStackId(CalleeGUID, GetOffset(DIL), DIL->getColumn()); + // LeafFound will only be false on the first iteration, since we either + // set it true or break out of the loop below. + if (!LeafFound) { + AllocInfoIter = LocHashToAllocInfo.find(StackId); + CallSitesIter = LocHashToCallSites.find(StackId); + // Check if the leaf is in one of the maps. If not, no need to look + // further at this call. + if (AllocInfoIter == LocHashToAllocInfo.end() && + CallSitesIter == LocHashToCallSites.end()) + break; + LeafFound = true; + } + InlinedCallStack.push_back(StackId); + } + // If leaf not in either of the maps, skip inst. + if (!LeafFound) + continue; + + // First add !memprof metadata from allocation info, if we found the + // instruction's leaf location in that map, and if the rest of the + // instruction's locations match the prefix Frame locations on an + // allocation context with the same leaf. + if (AllocInfoIter != LocHashToAllocInfo.end()) { + // Only consider allocations via new, to reduce unnecessary metadata, + // since those are the only allocations that will be targeted initially. + if (!isNewLikeFn(CI, &FuncInfo.TLI)) + continue; + // We may match this instruction's location list to multiple MIB + // contexts. Add them to a Trie specialized for trimming the contexts to + // the minimal needed to disambiguate contexts with unique behavior. + CallStackTrie AllocTrie; + for (auto *AllocInfo : AllocInfoIter->second) { + // Check the full inlined call stack against this one. + // If we found and thus matched all frames on the call, include + // this MIB. + if (stackFrameIncludesInlinedCallStack(AllocInfo->CallStack, + InlinedCallStack)) + addCallStack(AllocTrie, AllocInfo); + } + // We might not have matched any to the full inlined call stack. + // But if we did, create and attach metadata, or a function attribute if + // all contexts have identical profiled behavior. + if (!AllocTrie.empty()) { + // MemprofMDAttached will be false if a function attribute was + // attached. + bool MemprofMDAttached = AllocTrie.buildAndAttachMIBMetadata(CI); + assert(MemprofMDAttached == I.hasMetadata(LLVMContext::MD_memprof)); + if (MemprofMDAttached) { + // Add callsite metadata for the instruction's location list so that + // it simpler later on to identify which part of the MIB contexts + // are from this particular instruction (including during inlining, + // when the callsite metdata will be updated appropriately). 
+ // FIXME: can this be changed to strip out the matching stack + // context ids from the MIB contexts and not add any callsite + // metadata here to save space? + addCallsiteMetadata(I, InlinedCallStack, Ctx); + } + } + continue; + } + + // Otherwise, add callsite metadata. If we reach here then we found the + // instruction's leaf location in the callsites map and not the allocation + // map. + assert(CallSitesIter != LocHashToCallSites.end()); + for (auto CallStackIdx : CallSitesIter->second) { + // If we found and thus matched all frames on the call, create and + // attach call stack metadata. + if (stackFrameIncludesInlinedCallStack( + *CallStackIdx.first, InlinedCallStack, CallStackIdx.second)) { + addCallsiteMetadata(I, InlinedCallStack, Ctx); + // Only need to find one with a matching call stack and add a single + // callsite metadata. + break; + } + } + } + } + + return true; +} + // Read the profile from ProfileFileName and assign the value to the // instrumented BB and the edges. This function also updates ProgramMaxCount. // Return true if the profile are successfully read, and false on errors. bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, - bool &AllMinusOnes) { + InstrProfRecord::CountPseudoKind &PseudoKind) { auto &Ctx = M->getContext(); uint64_t MismatchedFuncSum = 0; Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord( @@ -1265,17 +1543,19 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros, return false; } ProfileRecord = std::move(Result.get()); + PseudoKind = ProfileRecord.getCountPseudoKind(); + if (PseudoKind != InstrProfRecord::NotPseudo) { + return true; + } std::vector<uint64_t> &CountFromProfile = ProfileRecord.Counts; IsCS ? NumOfCSPGOFunc++ : NumOfPGOFunc++; LLVM_DEBUG(dbgs() << CountFromProfile.size() << " counts\n"); - AllMinusOnes = (CountFromProfile.size() > 0); + uint64_t ValueSum = 0; for (unsigned I = 0, S = CountFromProfile.size(); I < S; I++) { LLVM_DEBUG(dbgs() << " " << I << ": " << CountFromProfile[I] << "\n"); ValueSum += CountFromProfile[I]; - if (CountFromProfile[I] != (uint64_t)-1) - AllMinusOnes = false; } AllZeros = (ValueSum == 0); @@ -1391,7 +1671,8 @@ void PGOUseFunc::setBranchWeights() { if (TI->getNumSuccessors() < 2) continue; if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || - isa<IndirectBrInst>(TI) || isa<InvokeInst>(TI))) + isa<IndirectBrInst>(TI) || isa<InvokeInst>(TI) || + isa<CallBrInst>(TI))) continue; if (getBBInfo(&BB).CountValue == 0) @@ -1414,7 +1695,21 @@ void PGOUseFunc::setBranchWeights() { MaxCount = EdgeCount; EdgeCounts[SuccNum] = EdgeCount; } - setProfMetadata(M, TI, EdgeCounts, MaxCount); + + if (MaxCount) + setProfMetadata(M, TI, EdgeCounts, MaxCount); + else { + // A zero MaxCount can come about when we have a BB with a positive + // count, and whose successor blocks all have 0 count. This can happen + // when there is no exit block and the code exits via a noreturn function. + auto &Ctx = M->getContext(); + Ctx.diagnose(DiagnosticInfoPGOProfile( + M->getName().data(), + Twine("Profile in ") + F.getName().str() + + Twine(" partially ignored") + + Twine(", possibly due to the lack of a return path."), + DS_Warning)); + } } } @@ -1557,6 +1852,38 @@ static void collectComdatMembers( ComdatMembers.insert(std::make_pair(C, &GA)); } +// Don't perform PGO instrumeatnion / profile-use. 
+static bool skipPGO(const Function &F) { + if (F.isDeclaration()) + return true; + if (F.hasFnAttribute(llvm::Attribute::NoProfile)) + return true; + if (F.hasFnAttribute(llvm::Attribute::SkipProfile)) + return true; + if (F.getInstructionCount() < PGOFunctionSizeThreshold) + return true; + + // If there are too many critical edges, PGO might cause + // compiler time problem. Skip PGO if the number of + // critical edges execeed the threshold. + unsigned NumCriticalEdges = 0; + for (auto &BB : F) { + const Instruction *TI = BB.getTerminator(); + for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) { + if (isCriticalEdge(TI, I)) + NumCriticalEdges++; + } + } + if (NumCriticalEdges > PGOFunctionCriticalEdgeThreshold) { + LLVM_DEBUG(dbgs() << "In func " << F.getName() + << ", NumCriticalEdges=" << NumCriticalEdges + << " exceed the threshold. Skip PGO.\n"); + return true; + } + + return false; +} + static bool InstrumentAllFunctions( Module &M, function_ref<TargetLibraryInfo &(Function &)> LookupTLI, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, @@ -1569,9 +1896,7 @@ static bool InstrumentAllFunctions( collectComdatMembers(M, ComdatMembers); for (auto &F : M) { - if (F.isDeclaration()) - continue; - if (F.hasFnAttribute(llvm::Attribute::NoProfile)) + if (skipPGO(F)) continue; auto &TLI = LookupTLI(F); auto *BPI = LookupBPI(F); @@ -1762,7 +2087,7 @@ static bool annotateAllFunctions( return false; // TODO: might need to change the warning once the clang option is finalized. - if (!PGOReader->isIRLevelProfile()) { + if (!PGOReader->isIRLevelProfile() && !PGOReader->hasMemoryProfile()) { Ctx.diagnose(DiagnosticInfoPGOProfile( ProfileFileName.data(), "Not an IR level instrumentation profile")); return false; @@ -1799,7 +2124,7 @@ static bool annotateAllFunctions( if (PGOInstrumentEntry.getNumOccurrences() > 0) InstrumentFuncEntry = PGOInstrumentEntry; for (auto &F : M) { - if (F.isDeclaration()) + if (skipPGO(F)) continue; auto &TLI = LookupTLI(F); auto *BPI = LookupBPI(F); @@ -1809,13 +2134,21 @@ static bool annotateAllFunctions( SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI); PGOUseFunc Func(F, &M, TLI, ComdatMembers, BPI, BFI, PSI, IsCS, InstrumentFuncEntry); - // When AllMinusOnes is true, it means the profile for the function - // is unrepresentative and this function is actually hot. Set the - // entry count of the function to be multiple times of hot threshold - // and drop all its internal counters. - bool AllMinusOnes = false; + // Read and match memprof first since we do this via debug info and can + // match even if there is an IR mismatch detected for regular PGO below. + if (PGOReader->hasMemoryProfile()) + Func.readMemprof(PGOReader.get()); + + if (!PGOReader->isIRLevelProfile()) + continue; + + // When PseudoKind is set to a vaule other than InstrProfRecord::NotPseudo, + // it means the profile for the function is unrepresentative and this + // function is actually hot / warm. We will reset the function hot / cold + // attribute and drop all the profile counters. 
+ InstrProfRecord::CountPseudoKind PseudoKind = InstrProfRecord::NotPseudo; bool AllZeros = false; - if (!Func.readCounters(PGOReader.get(), AllZeros, AllMinusOnes)) + if (!Func.readCounters(PGOReader.get(), AllZeros, PseudoKind)) continue; if (AllZeros) { F.setEntryCount(ProfileCount(0, Function::PCT_Real)); @@ -1823,13 +2156,13 @@ static bool annotateAllFunctions( ColdFunctions.push_back(&F); continue; } - const unsigned MultiplyFactor = 3; - if (AllMinusOnes) { - uint64_t HotThreshold = PSI->getHotCountThreshold(); - if (HotThreshold) - F.setEntryCount( - ProfileCount(HotThreshold * MultiplyFactor, Function::PCT_Real)); - HotFunctions.push_back(&F); + if (PseudoKind != InstrProfRecord::NotPseudo) { + // Clear function attribute cold. + if (F.hasFnAttribute(Attribute::Cold)) + F.removeFnAttr(Attribute::Cold); + // Set function attribute as hot. + if (PseudoKind == InstrProfRecord::PseudoHot) + F.addFnAttr(Attribute::Hot); continue; } Func.populateCounters(); @@ -2067,7 +2400,7 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits { // Display scaled counts for SELECT instruction: OS << "SELECT : { T = "; uint64_t TC, FC; - bool HasProf = I.extractProfMetadata(TC, FC); + bool HasProf = extractBranchWeights(I, TC, FC); if (!HasProf) OS << "Unknown, F = Unknown }\\l"; else diff --git a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index b11f16894669..35db8483fc91 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -291,9 +291,9 @@ bool MemOPSizeOpt::perform(MemOp MO) { uint64_t SavedRemainCount = SavedTotalCount; SmallVector<uint64_t, 16> SizeIds; SmallVector<uint64_t, 16> CaseCounts; + SmallDenseSet<uint64_t, 16> SeenSizeId; uint64_t MaxCount = 0; unsigned Version = 0; - int64_t LastV = -1; // Default case is in the front -- save the slot here. 
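The MemOp hunk below drops the "two consecutive identical values" check in favor of the SeenSizeId set declared above; the rejection relies on the usual insert().second idiom, shown here in isolation with made-up size ids:

    #include "llvm/ADT/DenseSet.h"
    #include <cstdint>
    #include <cstdio>

    int main() {
      llvm::SmallDenseSet<uint64_t, 16> SeenSizeId;
      const uint64_t Sizes[] = {8, 16, 8, 32};
      for (uint64_t V : Sizes) {
        // insert() returns {iterator, inserted}; .second is false when V was
        // already present, which is how a duplicate size id is now detected
        // even when the duplicates are not adjacent.
        if (!SeenSizeId.insert(V).second)
          std::printf("duplicate value id: %llu\n", (unsigned long long)V);
      }
      return 0;
    }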
CaseCounts.push_back(0); SmallVector<InstrProfValueData, 24> RemainingVDs; @@ -316,15 +316,12 @@ bool MemOPSizeOpt::perform(MemOp MO) { break; } - if (V == LastV) { - LLVM_DEBUG(dbgs() << "Invalid Profile Data in Function " << Func.getName() - << ": Two consecutive, identical values in MemOp value" - "counts.\n"); + if (!SeenSizeId.insert(V).second) { + errs() << "Invalid Profile Data in Function " << Func.getName() + << ": Two identical values in MemOp value counts.\n"; return false; } - LastV = V; - SizeIds.push_back(V); CaseCounts.push_back(C); if (C > MaxCount) @@ -425,7 +422,7 @@ bool MemOPSizeOpt::perform(MemOp MO) { assert(SizeType && "Expected integer type size argument."); ConstantInt *CaseSizeId = ConstantInt::get(SizeType, SizeId); NewMO.setLength(CaseSizeId); - CaseBB->getInstList().push_back(NewMO.I); + NewMO.I->insertInto(CaseBB, CaseBB->end()); IRBuilder<> IRBCase(CaseBB); IRBCase.CreateBr(MergeBB); SI->addCase(CaseSizeId, CaseBB); @@ -440,7 +437,8 @@ bool MemOPSizeOpt::perform(MemOp MO) { DTU.applyUpdates(Updates); Updates.clear(); - setProfMetadata(Func.getParent(), SI, CaseCounts, MaxCount); + if (MaxCount) + setProfMetadata(Func.getParent(), SI, CaseCounts, MaxCount); LLVM_DEBUG(dbgs() << *BB << "\n"); LLVM_DEBUG(dbgs() << *DefaultBB << "\n"); diff --git a/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp b/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp index 0e39fe266369..42e7cd80374d 100644 --- a/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp +++ b/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp @@ -89,9 +89,9 @@ static Value *buildOrChain(IRBuilder<> &B, ArrayRef<Value*> Ops) { if (i == Ops.size()) return B.getFalse(); Value *Accum = Ops[i++]; - for (; i < Ops.size(); i++) - if (!isConstantFalse(Ops[i])) - Accum = B.CreateOr(Accum, Ops[i]); + for (Value *Op : llvm::drop_begin(Ops, i)) + if (!isConstantFalse(Op)) + Accum = B.CreateOr(Accum, Op); return Accum; } @@ -276,10 +276,13 @@ static bool rewrite(Function &F) { // Note: There are many more sources of documented UB, but this pass only // attempts to find UB triggered by propagation of poison. - SmallPtrSet<const Value *, 4> NonPoisonOps; + SmallVector<const Value *, 4> NonPoisonOps; + SmallPtrSet<const Value *, 4> SeenNonPoisonOps; getGuaranteedNonPoisonOps(&I, NonPoisonOps); for (const Value *Op : NonPoisonOps) - CreateAssertNot(B, getPoisonFor(ValToPoison, const_cast<Value *>(Op))); + if (SeenNonPoisonOps.insert(Op).second) + CreateAssertNot(B, + getPoisonFor(ValToPoison, const_cast<Value *>(Op))); if (LocalCheck) if (auto *RI = dyn_cast<ReturnInst>(&I)) @@ -289,9 +292,10 @@ static bool rewrite(Function &F) { } SmallVector<Value*, 4> Checks; - if (propagatesPoison(cast<Operator>(&I))) - for (Value *V : I.operands()) - Checks.push_back(getPoisonFor(ValToPoison, V)); + for (const Use &U : I.operands()) { + if (ValToPoison.count(U) && propagatesPoison(U)) + Checks.push_back(getPoisonFor(ValToPoison, U)); + } if (canCreatePoison(cast<Operator>(&I))) generateCreationChecks(I, Checks); diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp new file mode 100644 index 000000000000..142b9c38e5fc --- /dev/null +++ b/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp @@ -0,0 +1,408 @@ +//===- SanitizerBinaryMetadata.cpp - binary analysis sanitizers metadata --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is a part of SanitizerBinaryMetadata. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/SanitizerBinaryMetadata.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Triple.h" +#include "llvm/ADT/Twine.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +#include <array> +#include <cstdint> + +using namespace llvm; + +#define DEBUG_TYPE "sanmd" + +namespace { + +//===--- Constants --------------------------------------------------------===// + +constexpr uint32_t kVersionBase = 1; // occupies lower 16 bits +constexpr uint32_t kVersionPtrSizeRel = (1u << 16); // offsets are pointer-sized +constexpr int kCtorDtorPriority = 2; + +// Pairs of names of initialization callback functions and which section +// contains the relevant metadata. +class MetadataInfo { +public: + const StringRef FunctionPrefix; + const StringRef SectionSuffix; + const uint32_t FeatureMask; + + static const MetadataInfo Covered; + static const MetadataInfo Atomics; + +private: + // Forbid construction elsewhere. + explicit constexpr MetadataInfo(StringRef FunctionPrefix, + StringRef SectionSuffix, uint32_t Feature) + : FunctionPrefix(FunctionPrefix), SectionSuffix(SectionSuffix), + FeatureMask(Feature) {} +}; +const MetadataInfo MetadataInfo::Covered{"__sanitizer_metadata_covered", + kSanitizerBinaryMetadataCoveredSection, + kSanitizerBinaryMetadataNone}; +const MetadataInfo MetadataInfo::Atomics{"__sanitizer_metadata_atomics", + kSanitizerBinaryMetadataAtomicsSection, + kSanitizerBinaryMetadataAtomics}; + +// The only instances of MetadataInfo are the constants above, so a set of +// them may simply store pointers to them. To deterministically generate code, +// we need to use a set with stable iteration order, such as SetVector. 
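Because the MetadataInfo set below is a SetVector, iteration order is insertion order, so the constructors and destructors emitted later come out deterministically. A minimal illustration of that property, with made-up values:

    #include "llvm/ADT/SetVector.h"
    #include <cstdio>

    int main() {
      llvm::SetVector<int> SV;
      for (int V : {3, 1, 3, 2})
        SV.insert(V); // duplicates are ignored; order of first insertion kept
      for (int V : SV)
        std::printf("%d ", V); // prints: 3 1 2
      std::printf("\n");
      return 0;
    }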
+using MetadataInfoSet = SetVector<const MetadataInfo *>; + +//===--- Command-line options ---------------------------------------------===// + +cl::opt<bool> ClWeakCallbacks( + "sanitizer-metadata-weak-callbacks", + cl::desc("Declare callbacks extern weak, and only call if non-null."), + cl::Hidden, cl::init(true)); + +cl::opt<bool> ClEmitCovered("sanitizer-metadata-covered", + cl::desc("Emit PCs for covered functions."), + cl::Hidden, cl::init(false)); +cl::opt<bool> ClEmitAtomics("sanitizer-metadata-atomics", + cl::desc("Emit PCs for atomic operations."), + cl::Hidden, cl::init(false)); +cl::opt<bool> ClEmitUAR("sanitizer-metadata-uar", + cl::desc("Emit PCs for start of functions that are " + "subject for use-after-return checking"), + cl::Hidden, cl::init(false)); + +//===--- Statistics -------------------------------------------------------===// + +STATISTIC(NumMetadataCovered, "Metadata attached to covered functions"); +STATISTIC(NumMetadataAtomics, "Metadata attached to atomics"); +STATISTIC(NumMetadataUAR, "Metadata attached to UAR functions"); + +//===----------------------------------------------------------------------===// + +// Apply opt overrides. +SanitizerBinaryMetadataOptions && +transformOptionsFromCl(SanitizerBinaryMetadataOptions &&Opts) { + Opts.Covered |= ClEmitCovered; + Opts.Atomics |= ClEmitAtomics; + Opts.UAR |= ClEmitUAR; + return std::move(Opts); +} + +class SanitizerBinaryMetadata { +public: + SanitizerBinaryMetadata(Module &M, SanitizerBinaryMetadataOptions Opts) + : Mod(M), Options(transformOptionsFromCl(std::move(Opts))), + TargetTriple(M.getTargetTriple()), IRB(M.getContext()) { + // FIXME: Make it work with other formats. + assert(TargetTriple.isOSBinFormatELF() && "ELF only"); + } + + bool run(); + +private: + // Return enabled feature mask of per-instruction metadata. + uint32_t getEnabledPerInstructionFeature() const { + uint32_t FeatureMask = 0; + if (Options.Atomics) + FeatureMask |= MetadataInfo::Atomics.FeatureMask; + return FeatureMask; + } + + uint32_t getVersion() const { + uint32_t Version = kVersionBase; + const auto CM = Mod.getCodeModel(); + if (CM.has_value() && (*CM == CodeModel::Medium || *CM == CodeModel::Large)) + Version |= kVersionPtrSizeRel; + return Version; + } + + void runOn(Function &F, MetadataInfoSet &MIS); + + // Determines which set of metadata to collect for this instruction. + // + // Returns true if covered metadata is required to unambiguously interpret + // other metadata. For example, if we are interested in atomics metadata, any + // function with memory operations (atomic or not) requires covered metadata + // to determine if a memory operation is atomic or not in modules compiled + // with SanitizerBinaryMetadata. + bool runOn(Instruction &I, MetadataInfoSet &MIS, MDBuilder &MDB, + uint32_t &FeatureMask); + + // Get start/end section marker pointer. + GlobalVariable *getSectionMarker(const Twine &MarkerName, Type *Ty); + + // Returns the target-dependent section name. + StringRef getSectionName(StringRef SectionSuffix); + + // Returns the section start marker name. + Twine getSectionStart(StringRef SectionSuffix); + + // Returns the section end marker name. 
+ Twine getSectionEnd(StringRef SectionSuffix); + + Module &Mod; + const SanitizerBinaryMetadataOptions Options; + const Triple TargetTriple; + IRBuilder<> IRB; +}; + +bool SanitizerBinaryMetadata::run() { + MetadataInfoSet MIS; + + for (Function &F : Mod) + runOn(F, MIS); + + if (MIS.empty()) + return false; + + // + // Setup constructors and call all initialization functions for requested + // metadata features. + // + + auto *Int8PtrTy = IRB.getInt8PtrTy(); + auto *Int8PtrPtrTy = PointerType::getUnqual(Int8PtrTy); + auto *Int32Ty = IRB.getInt32Ty(); + const std::array<Type *, 3> InitTypes = {Int32Ty, Int8PtrPtrTy, Int8PtrPtrTy}; + auto *Version = ConstantInt::get(Int32Ty, getVersion()); + + for (const MetadataInfo *MI : MIS) { + const std::array<Value *, InitTypes.size()> InitArgs = { + Version, + getSectionMarker(getSectionStart(MI->SectionSuffix), Int8PtrTy), + getSectionMarker(getSectionEnd(MI->SectionSuffix), Int8PtrTy), + }; + // We declare the _add and _del functions as weak, and only call them if + // there is a valid symbol linked. This allows building binaries with + // semantic metadata, but without having callbacks. When a tool that wants + // the metadata is linked which provides the callbacks, they will be called. + Function *Ctor = + createSanitizerCtorAndInitFunctions( + Mod, (MI->FunctionPrefix + ".module_ctor").str(), + (MI->FunctionPrefix + "_add").str(), InitTypes, InitArgs, + /*VersionCheckName=*/StringRef(), /*Weak=*/ClWeakCallbacks) + .first; + Function *Dtor = + createSanitizerCtorAndInitFunctions( + Mod, (MI->FunctionPrefix + ".module_dtor").str(), + (MI->FunctionPrefix + "_del").str(), InitTypes, InitArgs, + /*VersionCheckName=*/StringRef(), /*Weak=*/ClWeakCallbacks) + .first; + Constant *CtorData = nullptr; + Constant *DtorData = nullptr; + if (TargetTriple.supportsCOMDAT()) { + // Use COMDAT to deduplicate constructor/destructor function. + Ctor->setComdat(Mod.getOrInsertComdat(Ctor->getName())); + Dtor->setComdat(Mod.getOrInsertComdat(Dtor->getName())); + CtorData = Ctor; + DtorData = Dtor; + } + appendToGlobalCtors(Mod, Ctor, kCtorDtorPriority, CtorData); + appendToGlobalDtors(Mod, Dtor, kCtorDtorPriority, DtorData); + } + + return true; +} + +void SanitizerBinaryMetadata::runOn(Function &F, MetadataInfoSet &MIS) { + if (F.empty()) + return; + if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) + return; + // Don't touch available_externally functions, their actual body is elsewhere. + if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) + return; + + MDBuilder MDB(F.getContext()); + + // The metadata features enabled for this function, stored along covered + // metadata (if enabled). + uint32_t FeatureMask = getEnabledPerInstructionFeature(); + // Don't emit unnecessary covered metadata for all functions to save space. + bool RequiresCovered = false; + // We can only understand if we need to set UAR feature after looking + // at the instructions. So we need to check instructions even if FeatureMask + // is empty. + if (FeatureMask || Options.UAR) { + for (BasicBlock &BB : F) + for (Instruction &I : BB) + RequiresCovered |= runOn(I, MIS, MDB, FeatureMask); + } + + if (F.isVarArg()) + FeatureMask &= ~kSanitizerBinaryMetadataUAR; + if (FeatureMask & kSanitizerBinaryMetadataUAR) { + RequiresCovered = true; + NumMetadataUAR++; + } + + // Covered metadata is always emitted if explicitly requested, otherwise only + // if some other metadata requires it to unambiguously interpret it for + // modules compiled with SanitizerBinaryMetadata. 
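The constructors built in run() above call weak "<FunctionPrefix>_add" / "<FunctionPrefix>_del" callbacks with a version word and the begin/end markers of the metadata section. A tool that wants the metadata simply links strong definitions; a stub consumer for the covered-functions section, with the C signature inferred from the (i32, i8**, i8**) argument types above and therefore only an assumption, could be:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical strong definitions of the otherwise-weak callbacks; the
    // names follow the "__sanitizer_metadata_covered" prefix used above.
    extern "C" void __sanitizer_metadata_covered_add(uint32_t Version,
                                                     const char **Start,
                                                     const char **End) {
      // The base version occupies the lower 16 bits (kVersionBase).
      std::fprintf(stderr, "covered metadata v%u: [%p, %p)\n", Version & 0xffffu,
                   (const void *)Start, (const void *)End);
    }

    extern "C" void __sanitizer_metadata_covered_del(uint32_t Version,
                                                     const char **Start,
                                                     const char **End) {
      (void)Version;
      (void)Start;
      (void)End;
    }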
+ if (Options.Covered || (FeatureMask && RequiresCovered)) { + NumMetadataCovered++; + const auto *MI = &MetadataInfo::Covered; + MIS.insert(MI); + const StringRef Section = getSectionName(MI->SectionSuffix); + // The feature mask will be placed after the size (32 bit) of the function, + // so in total one covered entry will use `sizeof(void*) + 4 + 4`. + Constant *CFM = IRB.getInt32(FeatureMask); + F.setMetadata(LLVMContext::MD_pcsections, + MDB.createPCSections({{Section, {CFM}}})); + } +} + +bool isUARSafeCall(CallInst *CI) { + auto *F = CI->getCalledFunction(); + // There are no intrinsic functions that leak arguments. + // If the called function does not return, the current function + // does not return as well, so no possibility of use-after-return. + // Sanitizer function also don't leak or don't return. + // It's safe to both pass pointers to local variables to them + // and to tail-call them. + return F && (F->isIntrinsic() || F->doesNotReturn() || + F->getName().startswith("__asan_") || + F->getName().startswith("__hwsan_") || + F->getName().startswith("__ubsan_") || + F->getName().startswith("__msan_") || + F->getName().startswith("__tsan_")); +} + +bool hasUseAfterReturnUnsafeUses(Value &V) { + for (User *U : V.users()) { + if (auto *I = dyn_cast<Instruction>(U)) { + if (I->isLifetimeStartOrEnd() || I->isDroppable()) + continue; + if (auto *CI = dyn_cast<CallInst>(U)) { + if (isUARSafeCall(CI)) + continue; + } + if (isa<LoadInst>(U)) + continue; + if (auto *SI = dyn_cast<StoreInst>(U)) { + // If storing TO the alloca, then the address isn't taken. + if (SI->getOperand(1) == &V) + continue; + } + if (auto *GEPI = dyn_cast<GetElementPtrInst>(U)) { + if (!hasUseAfterReturnUnsafeUses(*GEPI)) + continue; + } else if (auto *BCI = dyn_cast<BitCastInst>(U)) { + if (!hasUseAfterReturnUnsafeUses(*BCI)) + continue; + } + } + return true; + } + return false; +} + +bool useAfterReturnUnsafe(Instruction &I) { + if (isa<AllocaInst>(I)) + return hasUseAfterReturnUnsafeUses(I); + // Tail-called functions are not necessary intercepted + // at runtime because there is no call instruction. + // So conservatively mark the caller as requiring checking. + else if (auto *CI = dyn_cast<CallInst>(&I)) + return CI->isTailCall() && !isUARSafeCall(CI); + return false; +} + +bool SanitizerBinaryMetadata::runOn(Instruction &I, MetadataInfoSet &MIS, + MDBuilder &MDB, uint32_t &FeatureMask) { + SmallVector<const MetadataInfo *, 1> InstMetadata; + bool RequiresCovered = false; + + if (Options.UAR && !(FeatureMask & kSanitizerBinaryMetadataUAR)) { + if (useAfterReturnUnsafe(I)) + FeatureMask |= kSanitizerBinaryMetadataUAR; + } + + if (Options.Atomics && I.mayReadOrWriteMemory()) { + auto SSID = getAtomicSyncScopeID(&I); + if (SSID.has_value() && *SSID != SyncScope::SingleThread) { + NumMetadataAtomics++; + InstMetadata.push_back(&MetadataInfo::Atomics); + } + RequiresCovered = true; + } + + // Attach MD_pcsections to instruction. 
+ if (!InstMetadata.empty()) { + MIS.insert(InstMetadata.begin(), InstMetadata.end()); + SmallVector<MDBuilder::PCSection, 1> Sections; + for (const auto &MI : InstMetadata) + Sections.push_back({getSectionName(MI->SectionSuffix), {}}); + I.setMetadata(LLVMContext::MD_pcsections, MDB.createPCSections(Sections)); + } + + return RequiresCovered; +} + +GlobalVariable * +SanitizerBinaryMetadata::getSectionMarker(const Twine &MarkerName, Type *Ty) { + // Use ExternalWeak so that if all sections are discarded due to section + // garbage collection, the linker will not report undefined symbol errors. + auto *Marker = new GlobalVariable(Mod, Ty, /*isConstant=*/false, + GlobalVariable::ExternalWeakLinkage, + /*Initializer=*/nullptr, MarkerName); + Marker->setVisibility(GlobalValue::HiddenVisibility); + return Marker; +} + +StringRef SanitizerBinaryMetadata::getSectionName(StringRef SectionSuffix) { + // FIXME: Other TargetTriple (req. string pool) + return SectionSuffix; +} + +Twine SanitizerBinaryMetadata::getSectionStart(StringRef SectionSuffix) { + return "__start_" + SectionSuffix; +} + +Twine SanitizerBinaryMetadata::getSectionEnd(StringRef SectionSuffix) { + return "__stop_" + SectionSuffix; +} + +} // namespace + +SanitizerBinaryMetadataPass::SanitizerBinaryMetadataPass( + SanitizerBinaryMetadataOptions Opts) + : Options(std::move(Opts)) {} + +PreservedAnalyses +SanitizerBinaryMetadataPass::run(Module &M, AnalysisManager<Module> &AM) { + SanitizerBinaryMetadata Pass(M, Options); + if (Pass.run()) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} diff --git a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 3ca476e74953..23a88c3cfba2 100644 --- a/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/EHPersonalities.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DataLayout.h" @@ -75,11 +76,13 @@ const char SanCovTracePCGuardInitName[] = "__sanitizer_cov_trace_pc_guard_init"; const char SanCov8bitCountersInitName[] = "__sanitizer_cov_8bit_counters_init"; const char SanCovBoolFlagInitName[] = "__sanitizer_cov_bool_flag_init"; const char SanCovPCsInitName[] = "__sanitizer_cov_pcs_init"; +const char SanCovCFsInitName[] = "__sanitizer_cov_cfs_init"; const char SanCovGuardsSectionName[] = "sancov_guards"; const char SanCovCountersSectionName[] = "sancov_cntrs"; const char SanCovBoolFlagSectionName[] = "sancov_bools"; const char SanCovPCsSectionName[] = "sancov_pcs"; +const char SanCovCFsSectionName[] = "sancov_cfs"; const char SanCovLowestStackName[] = "__sancov_lowest_stack"; @@ -147,6 +150,11 @@ static cl::opt<bool> ClStackDepth("sanitizer-coverage-stack-depth", cl::desc("max stack depth tracing"), cl::Hidden, cl::init(false)); +static cl::opt<bool> + ClCollectCF("sanitizer-coverage-control-flow", + cl::desc("collect control flow for each function"), cl::Hidden, + cl::init(false)); + namespace { SanitizerCoverageOptions getOptions(int LegacyCoverageLevel) { @@ -193,6 +201,7 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) { !Options.Inline8bitCounters && !Options.StackDepth && !Options.InlineBoolFlag && !Options.TraceLoads && !Options.TraceStores) Options.TracePCGuard = true; // TracePCGuard is default. 
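The control-flow collection wired up below mirrors the existing pcs/guards handling: a per-module constructor hands the __start_/__stop_ bounds of the sancov_cfs section to __sanitizer_cov_cfs_init (see the instrumentModule hunk that follows). A runtime stub for that hook, assuming intptr-sized table entries (the entry layout itself comes from createFunctionControlFlow(), which is not shown here), might look like:

    #include <cstdint>
    #include <cstdio>

    extern "C" void __sanitizer_cov_cfs_init(const uintptr_t *CFsBeg,
                                             const uintptr_t *CFsEnd) {
      // Called once per instrumented module; the table spans [CFsBeg, CFsEnd).
      std::fprintf(stderr, "sancov_cfs: %zu entries\n",
                   static_cast<size_t>(CFsEnd - CFsBeg));
    }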
+ Options.CollectControlFlow |= ClCollectCF; return Options; } @@ -212,6 +221,7 @@ public: PostDomTreeCallback PDTCallback); private: + void createFunctionControlFlow(Function &F); void instrumentFunction(Function &F, DomTreeCallback DTCallback, PostDomTreeCallback PDTCallback); void InjectCoverageForIndirectCalls(Function &F, @@ -241,7 +251,7 @@ private: Type *Ty); void SetNoSanitizeMetadata(Instruction *I) { - I->setMetadata(LLVMContext::MD_nosanitize, MDNode::get(*C, None)); + I->setMetadata(LLVMContext::MD_nosanitize, MDNode::get(*C, std::nullopt)); } std::string getSectionName(const std::string &Section) const; @@ -270,6 +280,7 @@ private: GlobalVariable *Function8bitCounterArray; // for inline-8bit-counters. GlobalVariable *FunctionBoolArray; // for inline-bool-flag. GlobalVariable *FunctionPCsArray; // for pc-table. + GlobalVariable *FunctionCFsArray; // for control flow table SmallVector<GlobalValue *, 20> GlobalsToAppendToUsed; SmallVector<GlobalValue *, 20> GlobalsToAppendToCompilerUsed; @@ -280,8 +291,8 @@ private: }; } // namespace -PreservedAnalyses ModuleSanitizerCoveragePass::run(Module &M, - ModuleAnalysisManager &MAM) { +PreservedAnalyses SanitizerCoveragePass::run(Module &M, + ModuleAnalysisManager &MAM) { ModuleSanitizerCoverage ModuleSancov(Options, Allowlist.get(), Blocklist.get()); auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); @@ -291,9 +302,15 @@ PreservedAnalyses ModuleSanitizerCoveragePass::run(Module &M, auto PDTCallback = [&FAM](Function &F) -> const PostDominatorTree * { return &FAM.getResult<PostDominatorTreeAnalysis>(F); }; - if (ModuleSancov.instrumentModule(M, DTCallback, PDTCallback)) - return PreservedAnalyses::none(); - return PreservedAnalyses::all(); + if (!ModuleSancov.instrumentModule(M, DTCallback, PDTCallback)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA = PreservedAnalyses::none(); + // GlobalsAA is considered stateless and does not get invalidated unless + // explicitly invalidated; PreservedAnalyses::none() is not enough. Sanitizers + // make changes that require GlobalsAA to be invalidated. 
+ PA.abandon<GlobalsAA>(); + return PA; } std::pair<Value *, Value *> @@ -378,6 +395,7 @@ bool ModuleSanitizerCoverage::instrumentModule( Function8bitCounterArray = nullptr; FunctionBoolArray = nullptr; FunctionPCsArray = nullptr; + FunctionCFsArray = nullptr; IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits()); IntptrPtrTy = PointerType::getUnqual(IntptrTy); Type *VoidTy = Type::getVoidTy(*C); @@ -502,6 +520,15 @@ bool ModuleSanitizerCoverage::instrumentModule( IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator()); IRBCtor.CreateCall(InitFunction, {SecStartEnd.first, SecStartEnd.second}); } + + if (Ctor && Options.CollectControlFlow) { + auto SecStartEnd = CreateSecStartEnd(M, SanCovCFsSectionName, IntptrTy); + FunctionCallee InitFunction = declareSanitizerInitFunction( + M, SanCovCFsInitName, {IntptrPtrTy, IntptrPtrTy}); + IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator()); + IRBCtor.CreateCall(InitFunction, {SecStartEnd.first, SecStartEnd.second}); + } + appendToUsed(M, GlobalsToAppendToUsed); appendToCompilerUsed(M, GlobalsToAppendToCompilerUsed); return true; @@ -671,6 +698,9 @@ void ModuleSanitizerCoverage::instrumentFunction( } } + if (Options.CollectControlFlow) + createFunctionControlFlow(F); + InjectCoverage(F, BlocksToInstrument, IsLeafFunc); InjectCoverageForIndirectCalls(F, IndirCalls); InjectTraceForCmp(F, CmpTraceTargets); @@ -692,7 +722,7 @@ GlobalVariable *ModuleSanitizerCoverage::CreateFunctionLocalArrayInSection( if (auto Comdat = getOrCreateFunctionComdat(F, TargetTriple)) Array->setComdat(Comdat); Array->setSection(getSectionName(Section)); - Array->setAlignment(Align(DL->getTypeStoreSize(Ty).getFixedSize())); + Array->setAlignment(Align(DL->getTypeStoreSize(Ty).getFixedValue())); // sancov_pcs parallels the other metadata section(s). Optimizers (e.g. 
// GlobalOpt/ConstantMerge) may not discard sancov_pcs and the other @@ -726,8 +756,7 @@ ModuleSanitizerCoverage::CreatePCArray(Function &F, } else { PCs.push_back((Constant *)IRB.CreatePointerCast( BlockAddress::get(AllBlocks[i]), IntptrPtrTy)); - PCs.push_back((Constant *)IRB.CreateIntToPtr( - ConstantInt::get(IntptrTy, 0), IntptrPtrTy)); + PCs.push_back(Constant::getNullValue(IntptrPtrTy)); } } auto *PCArray = CreateFunctionLocalArrayInSection(N * 2, F, IntptrPtrTy, @@ -779,7 +808,7 @@ void ModuleSanitizerCoverage::InjectCoverageForIndirectCalls( return; assert(Options.TracePC || Options.TracePCGuard || Options.Inline8bitCounters || Options.InlineBoolFlag); - for (auto I : IndirCalls) { + for (auto *I : IndirCalls) { IRBuilder<> IRB(I); CallBase &CB = cast<CallBase>(*I); Value *Callee = CB.getCalledOperand(); @@ -795,7 +824,7 @@ void ModuleSanitizerCoverage::InjectCoverageForIndirectCalls( void ModuleSanitizerCoverage::InjectTraceForSwitch( Function &, ArrayRef<Instruction *> SwitchTraceTargets) { - for (auto I : SwitchTraceTargets) { + for (auto *I : SwitchTraceTargets) { if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { IRBuilder<> IRB(I); SmallVector<Constant *, 16> Initializers; @@ -834,7 +863,7 @@ void ModuleSanitizerCoverage::InjectTraceForSwitch( void ModuleSanitizerCoverage::InjectTraceForDiv( Function &, ArrayRef<BinaryOperator *> DivTraceTargets) { - for (auto BO : DivTraceTargets) { + for (auto *BO : DivTraceTargets) { IRBuilder<> IRB(BO); Value *A1 = BO->getOperand(1); if (isa<ConstantInt>(A1)) continue; @@ -852,7 +881,7 @@ void ModuleSanitizerCoverage::InjectTraceForDiv( void ModuleSanitizerCoverage::InjectTraceForGep( Function &, ArrayRef<GetElementPtrInst *> GepTraceTargets) { - for (auto GEP : GepTraceTargets) { + for (auto *GEP : GepTraceTargets) { IRBuilder<> IRB(GEP); for (Use &Idx : GEP->indices()) if (!isa<ConstantInt>(Idx) && Idx->getType()->isIntegerTy()) @@ -874,7 +903,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores( }; Type *PointerType[5] = {Int8PtrTy, Int16PtrTy, Int32PtrTy, Int64PtrTy, Int128PtrTy}; - for (auto LI : Loads) { + for (auto *LI : Loads) { IRBuilder<> IRB(LI); auto Ptr = LI->getPointerOperand(); int Idx = CallbackIdx(LI->getType()); @@ -883,7 +912,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores( IRB.CreateCall(SanCovLoadFunction[Idx], IRB.CreatePointerCast(Ptr, PointerType[Idx])); } - for (auto SI : Stores) { + for (auto *SI : Stores) { IRBuilder<> IRB(SI); auto Ptr = SI->getPointerOperand(); int Idx = CallbackIdx(SI->getValueOperand()->getType()); @@ -896,7 +925,7 @@ void ModuleSanitizerCoverage::InjectTraceForLoadsAndStores( void ModuleSanitizerCoverage::InjectTraceForCmp( Function &, ArrayRef<Instruction *> CmpTraceTargets) { - for (auto I : CmpTraceTargets) { + for (auto *I : CmpTraceTargets) { if (ICmpInst *ICMP = dyn_cast<ICmpInst>(I)) { IRBuilder<> IRB(ICMP); Value *A0 = ICMP->getOperand(0); @@ -1028,3 +1057,48 @@ ModuleSanitizerCoverage::getSectionEnd(const std::string &Section) const { return "\1section$end$__DATA$__" + Section; return "__stop___" + Section; } + +void ModuleSanitizerCoverage::createFunctionControlFlow(Function &F) { + SmallVector<Constant *, 32> CFs; + IRBuilder<> IRB(&*F.getEntryBlock().getFirstInsertionPt()); + + for (auto &BB : F) { + // blockaddress can not be used on function's entry block. 
+ if (&BB == &F.getEntryBlock()) + CFs.push_back((Constant *)IRB.CreatePointerCast(&F, IntptrPtrTy)); + else + CFs.push_back((Constant *)IRB.CreatePointerCast(BlockAddress::get(&BB), + IntptrPtrTy)); + + for (auto SuccBB : successors(&BB)) { + assert(SuccBB != &F.getEntryBlock()); + CFs.push_back((Constant *)IRB.CreatePointerCast(BlockAddress::get(SuccBB), + IntptrPtrTy)); + } + + CFs.push_back((Constant *)Constant::getNullValue(IntptrPtrTy)); + + for (auto &Inst : BB) { + if (CallBase *CB = dyn_cast<CallBase>(&Inst)) { + if (CB->isIndirectCall()) { + // TODO(navidem): handle indirect calls, for now mark its existence. + CFs.push_back((Constant *)IRB.CreateIntToPtr( + ConstantInt::get(IntptrTy, -1), IntptrPtrTy)); + } else { + auto CalledF = CB->getCalledFunction(); + if (CalledF && !CalledF->isIntrinsic()) + CFs.push_back( + (Constant *)IRB.CreatePointerCast(CalledF, IntptrPtrTy)); + } + } + } + + CFs.push_back((Constant *)Constant::getNullValue(IntptrPtrTy)); + } + + FunctionCFsArray = CreateFunctionLocalArrayInSection( + CFs.size(), F, IntptrPtrTy, SanCovCFsSectionName); + FunctionCFsArray->setInitializer( + ConstantArray::get(ArrayType::get(IntptrPtrTy, CFs.size()), CFs)); + FunctionCFsArray->setConstant(true); +} diff --git a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index d4aa31db8337..a127e81ce643 100644 --- a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -20,7 +20,6 @@ #include "llvm/Transforms/Instrumentation/ThreadSanitizer.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -133,7 +132,7 @@ private: unsigned Flags = 0; }; - void initialize(Module &M); + void initialize(Module &M, const TargetLibraryInfo &TLI); bool instrumentLoadOrStore(const InstructionInfo &II, const DataLayout &DL); bool instrumentAtomic(Instruction *I, const DataLayout &DL); bool instrumentMemIntrinsic(Instruction *I); @@ -196,13 +195,14 @@ PreservedAnalyses ModuleThreadSanitizerPass::run(Module &M, insertModuleCtor(M); return PreservedAnalyses::none(); } -void ThreadSanitizer::initialize(Module &M) { +void ThreadSanitizer::initialize(Module &M, const TargetLibraryInfo &TLI) { const DataLayout &DL = M.getDataLayout(); - IntptrTy = DL.getIntPtrType(M.getContext()); + LLVMContext &Ctx = M.getContext(); + IntptrTy = DL.getIntPtrType(Ctx); - IRBuilder<> IRB(M.getContext()); + IRBuilder<> IRB(Ctx); AttributeList Attr; - Attr = Attr.addFnAttribute(M.getContext(), Attribute::NoUnwind); + Attr = Attr.addFnAttribute(Ctx, Attribute::NoUnwind); // Initialize the callbacks. 
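The table emitted by createFunctionControlFlow (in the SanitizerCoverage hunks above) has a flat per-block layout: the block address (the function address stands in for the entry block), the successor addresses, a null terminator, the directly called non-intrinsic functions with -1 marking an indirect call, and a second null terminator. A minimal sketch of a consumer with that layout in mind; the callback name matches SanCovCFsInitName, everything else is illustrative.

#include <cstdint>
#include <cstdio>

extern "C" void __sanitizer_cov_cfs_init(const uintptr_t *CFsBeg,
                                         const uintptr_t *CFsEnd) {
  const uintptr_t *P = CFsBeg;
  while (P < CFsEnd) {
    std::printf("block %p\n", (void *)*P++); // block (or function) address
    while (P < CFsEnd && *P)                 // successors, 0-terminated
      std::printf("  successor %p\n", (void *)*P++);
    ++P;                                     // skip successor terminator
    while (P < CFsEnd && *P) {               // callees, 0-terminated
      if (*P == (uintptr_t)-1)
        std::printf("  indirect call\n");
      else
        std::printf("  calls %p\n", (void *)*P);
      ++P;
    }
    ++P;                                     // skip callee terminator
  }
}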
TsanFuncEntry = M.getOrInsertFunction("__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); @@ -261,24 +261,24 @@ void ThreadSanitizer::initialize(Module &M) { TsanUnalignedCompoundRW[i] = M.getOrInsertFunction( UnalignedCompoundRWName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); - Type *Ty = Type::getIntNTy(M.getContext(), BitSize); + Type *Ty = Type::getIntNTy(Ctx, BitSize); Type *PtrTy = Ty->getPointerTo(); SmallString<32> AtomicLoadName("__tsan_atomic" + BitSizeStr + "_load"); - { - AttributeList AL = Attr; - AL = AL.addParamAttribute(M.getContext(), 1, Attribute::ZExt); - TsanAtomicLoad[i] = - M.getOrInsertFunction(AtomicLoadName, AL, Ty, PtrTy, OrdTy); - } - + TsanAtomicLoad[i] = + M.getOrInsertFunction(AtomicLoadName, + TLI.getAttrList(&Ctx, {1}, /*Signed=*/true, + /*Ret=*/BitSize <= 32, Attr), + Ty, PtrTy, OrdTy); + + // Args of type Ty need extension only when BitSize is 32 or less. + using Idxs = std::vector<unsigned>; + Idxs Idxs2Or12 ((BitSize <= 32) ? Idxs({1, 2}) : Idxs({2})); + Idxs Idxs34Or1234((BitSize <= 32) ? Idxs({1, 2, 3, 4}) : Idxs({3, 4})); SmallString<32> AtomicStoreName("__tsan_atomic" + BitSizeStr + "_store"); - { - AttributeList AL = Attr; - AL = AL.addParamAttribute(M.getContext(), 1, Attribute::ZExt); - AL = AL.addParamAttribute(M.getContext(), 2, Attribute::ZExt); - TsanAtomicStore[i] = M.getOrInsertFunction( - AtomicStoreName, AL, IRB.getVoidTy(), PtrTy, Ty, OrdTy); - } + TsanAtomicStore[i] = M.getOrInsertFunction( + AtomicStoreName, + TLI.getAttrList(&Ctx, Idxs2Or12, /*Signed=*/true, /*Ret=*/false, Attr), + IRB.getVoidTy(), PtrTy, Ty, OrdTy); for (unsigned Op = AtomicRMWInst::FIRST_BINOP; Op <= AtomicRMWInst::LAST_BINOP; ++Op) { @@ -301,54 +301,46 @@ void ThreadSanitizer::initialize(Module &M) { else continue; SmallString<32> RMWName("__tsan_atomic" + itostr(BitSize) + NamePart); - { - AttributeList AL = Attr; - AL = AL.addParamAttribute(M.getContext(), 1, Attribute::ZExt); - AL = AL.addParamAttribute(M.getContext(), 2, Attribute::ZExt); - TsanAtomicRMW[Op][i] = - M.getOrInsertFunction(RMWName, AL, Ty, PtrTy, Ty, OrdTy); - } + TsanAtomicRMW[Op][i] = M.getOrInsertFunction( + RMWName, + TLI.getAttrList(&Ctx, Idxs2Or12, /*Signed=*/true, + /*Ret=*/BitSize <= 32, Attr), + Ty, PtrTy, Ty, OrdTy); } SmallString<32> AtomicCASName("__tsan_atomic" + BitSizeStr + "_compare_exchange_val"); - { - AttributeList AL = Attr; - AL = AL.addParamAttribute(M.getContext(), 1, Attribute::ZExt); - AL = AL.addParamAttribute(M.getContext(), 2, Attribute::ZExt); - AL = AL.addParamAttribute(M.getContext(), 3, Attribute::ZExt); - AL = AL.addParamAttribute(M.getContext(), 4, Attribute::ZExt); - TsanAtomicCAS[i] = M.getOrInsertFunction(AtomicCASName, AL, Ty, PtrTy, Ty, - Ty, OrdTy, OrdTy); - } + TsanAtomicCAS[i] = M.getOrInsertFunction( + AtomicCASName, + TLI.getAttrList(&Ctx, Idxs34Or1234, /*Signed=*/true, + /*Ret=*/BitSize <= 32, Attr), + Ty, PtrTy, Ty, Ty, OrdTy, OrdTy); } TsanVptrUpdate = M.getOrInsertFunction("__tsan_vptr_update", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy()); TsanVptrLoad = M.getOrInsertFunction("__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); - { - AttributeList AL = Attr; - AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); - TsanAtomicThreadFence = M.getOrInsertFunction("__tsan_atomic_thread_fence", - AL, IRB.getVoidTy(), OrdTy); - } - { - AttributeList AL = Attr; - AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); - TsanAtomicSignalFence = M.getOrInsertFunction("__tsan_atomic_signal_fence", - AL, 
IRB.getVoidTy(), OrdTy); - } + TsanAtomicThreadFence = M.getOrInsertFunction( + "__tsan_atomic_thread_fence", + TLI.getAttrList(&Ctx, {0}, /*Signed=*/true, /*Ret=*/false, Attr), + IRB.getVoidTy(), OrdTy); + + TsanAtomicSignalFence = M.getOrInsertFunction( + "__tsan_atomic_signal_fence", + TLI.getAttrList(&Ctx, {0}, /*Signed=*/true, /*Ret=*/false, Attr), + IRB.getVoidTy(), OrdTy); MemmoveFn = - M.getOrInsertFunction("memmove", Attr, IRB.getInt8PtrTy(), + M.getOrInsertFunction("__tsan_memmove", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); MemcpyFn = - M.getOrInsertFunction("memcpy", Attr, IRB.getInt8PtrTy(), + M.getOrInsertFunction("__tsan_memcpy", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); - MemsetFn = - M.getOrInsertFunction("memset", Attr, IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); + MemsetFn = M.getOrInsertFunction( + "__tsan_memset", + TLI.getAttrList(&Ctx, {1}, /*Signed=*/true, /*Ret=*/false, Attr), + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); } static bool isVtableAccess(Instruction *I) { @@ -379,7 +371,7 @@ static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) { return false; } - // Do not instrument acesses from different address spaces; we cannot deal + // Do not instrument accesses from different address spaces; we cannot deal // with them. if (Addr) { Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType()); @@ -486,7 +478,7 @@ static bool isTsanAtomic(const Instruction *I) { if (!SSID) return false; if (isa<LoadInst>(I) || isa<StoreInst>(I)) - return SSID.value() != SyncScope::SingleThread; + return *SSID != SyncScope::SingleThread; return true; } @@ -517,7 +509,7 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) return false; - initialize(*F.getParent()); + initialize(*F.getParent(), TLI); SmallVector<InstructionInfo, 8> AllLoadsAndStores; SmallVector<Instruction*, 8> LocalLoadsAndStores; SmallVector<Instruction*, 8> AtomicAccesses; @@ -561,12 +553,12 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, // Instrument atomic memory accesses in any case (they can be used to // implement synchronization). 
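With the renaming above, TSan-instrumented code calls __tsan_memmove/__tsan_memcpy/__tsan_memset instead of the plain libc routines, presumably so the instrumentation does not depend on intercepting the libc symbols. Reading the signatures off the getOrInsertFunction calls, the runtime is expected to export entry points roughly equivalent to the following declarations; this is a sketch, not copied from compiler-rt.

#include <cstddef>

extern "C" void *__tsan_memmove(void *dst, void *src, size_t n);
extern "C" void *__tsan_memcpy(void *dst, void *src, size_t n);
extern "C" void *__tsan_memset(void *dst, int c, size_t n);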
if (ClInstrumentAtomics) - for (auto Inst : AtomicAccesses) { + for (auto *Inst : AtomicAccesses) { Res |= instrumentAtomic(Inst, DL); } if (ClInstrumentMemIntrinsics && SanitizeFunction) - for (auto Inst : MemIntrinCalls) { + for (auto *Inst : MemIntrinCalls) { Res |= instrumentMemIntrinsic(Inst); } @@ -676,7 +668,7 @@ static ConstantInt *createOrdering(IRBuilder<> *IRB, AtomicOrdering ord) { switch (ord) { case AtomicOrdering::NotAtomic: llvm_unreachable("unexpected atomic ordering!"); - case AtomicOrdering::Unordered: LLVM_FALLTHROUGH; + case AtomicOrdering::Unordered: [[fallthrough]]; case AtomicOrdering::Monotonic: v = 0; break; // Not specified yet: // case AtomicOrdering::Consume: v = 1; break; @@ -802,7 +794,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { } Value *Res = - IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0); + IRB.CreateInsertValue(PoisonValue::get(CASI->getType()), OldVal, 0); Res = IRB.CreateInsertValue(Res, Success, 1); I->replaceAllUsesWith(Res); diff --git a/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp b/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp index de0f5803b4c7..0fea6bcc4882 100644 --- a/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp +++ b/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp @@ -48,10 +48,10 @@ bool llvm::objcarc::CanAlterRefCount(const Instruction *Inst, const Value *Ptr, const auto *Call = cast<CallBase>(Inst); // See if AliasAnalysis can help us with the call. - FunctionModRefBehavior MRB = PA.getAA()->getModRefBehavior(Call); - if (AliasAnalysis::onlyReadsMemory(MRB)) + MemoryEffects ME = PA.getAA()->getMemoryEffects(Call); + if (ME.onlyReadsMemory()) return false; - if (AliasAnalysis::onlyAccessesArgPointees(MRB)) { + if (ME.onlyAccessesArgPointees()) { for (const Value *Op : Call->args()) { if (IsPotentialRetainableObjPtr(Op, *PA.getAA()) && PA.related(Ptr, Op)) return true; diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp index 70f150c9461a..02f9db719e26 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp +++ b/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp @@ -13,35 +13,14 @@ //===----------------------------------------------------------------------===// #include "ObjCARC.h" -#include "llvm-c/Initialization.h" #include "llvm/Analysis/ObjCARCUtil.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" -#include "llvm/InitializePasses.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -namespace llvm { - class PassRegistry; -} - using namespace llvm; using namespace llvm::objcarc; -/// initializeObjCARCOptsPasses - Initialize all passes linked into the -/// ObjCARCOpts library. 
-void llvm::initializeObjCARCOpts(PassRegistry &Registry) { - initializeObjCARCAAWrapperPassPass(Registry); - initializeObjCARCAPElimPass(Registry); - initializeObjCARCExpandPass(Registry); - initializeObjCARCContractLegacyPassPass(Registry); - initializeObjCARCOptLegacyPassPass(Registry); - initializePAEvalPass(Registry); -} - -void LLVMInitializeObjCARCOpts(LLVMPassRegistryRef R) { - initializeObjCARCOpts(*unwrap(R)); -} - CallInst *objcarc::createCallInstWithColors( FunctionCallee Func, ArrayRef<Value *> Args, const Twine &NameStr, Instruction *InsertBefore, diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARC.h b/llvm/lib/Transforms/ObjCARC/ObjCARC.h index 2bc0c8f87d77..d4570ff908f1 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARC.h +++ b/llvm/lib/Transforms/ObjCARC/ObjCARC.h @@ -132,8 +132,8 @@ public: auto It = RVCalls.find(CI); if (It != RVCalls.end()) { // Remove call to @llvm.objc.clang.arc.noop.use. - for (auto U = It->second->user_begin(), E = It->second->user_end(); U != E; ++U) - if (auto *CI = dyn_cast<CallInst>(*U)) + for (User *U : It->second->users()) + if (auto *CI = dyn_cast<CallInst>(U)) if (CI->getIntrinsicID() == Intrinsic::objc_clang_arc_noop_use) { CI->eraseFromParent(); break; diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp index 03e5fb18d5ac..dceb2ebb1863 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp +++ b/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp @@ -29,8 +29,6 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/PassManager.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/ObjCARC.h" @@ -147,35 +145,8 @@ bool runImpl(Module &M) { return Changed; } -/// Autorelease pool elimination. -class ObjCARCAPElim : public ModulePass { - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnModule(Module &M) override; - -public: - static char ID; - ObjCARCAPElim() : ModulePass(ID) { - initializeObjCARCAPElimPass(*PassRegistry::getPassRegistry()); - } -}; } // namespace -char ObjCARCAPElim::ID = 0; -INITIALIZE_PASS(ObjCARCAPElim, "objc-arc-apelim", - "ObjC ARC autorelease pool elimination", false, false) - -Pass *llvm::createObjCARCAPElimPass() { return new ObjCARCAPElim(); } - -void ObjCARCAPElim::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesCFG(); -} - -bool ObjCARCAPElim::runOnModule(Module &M) { - if (skipModule(M)) - return false; - return runImpl(M); -} - PreservedAnalyses ObjCARCAPElimPass::run(Module &M, ModuleAnalysisManager &AM) { if (!runImpl(M)) return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp index f64c26ef2bed..ab90ef090ae0 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -428,7 +428,7 @@ bool ObjCARCContract::tryToPeepholeInstruction( if (!optimizeRetainCall(F, Inst)) return false; // If we succeed in our optimization, fall through. 
- LLVM_FALLTHROUGH; + [[fallthrough]]; case ARCInstKind::RetainRV: case ARCInstKind::UnsafeClaimRV: { // Return true if this is a bundled retainRV/claimRV call, which is always @@ -472,7 +472,8 @@ bool ObjCARCContract::tryToPeepholeInstruction( RVInstMarker->getString(), /*Constraints=*/"", /*hasSideEffects=*/true); - objcarc::createCallInstWithColors(IA, None, "", Inst, BlockColors); + objcarc::createCallInstWithColors(IA, std::nullopt, "", Inst, + BlockColors); } decline_rv_optimization: return false; diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp index efcdc51ef5e3..bb0a01b78a96 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp +++ b/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp @@ -29,9 +29,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/PassRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -89,31 +86,8 @@ static bool runImpl(Function &F) { return Changed; } -/// Early ARC transformations. -class ObjCARCExpand : public FunctionPass { - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; - -public: - static char ID; - ObjCARCExpand() : FunctionPass(ID) { - initializeObjCARCExpandPass(*PassRegistry::getPassRegistry()); - } -}; } // namespace -char ObjCARCExpand::ID = 0; -INITIALIZE_PASS(ObjCARCExpand, "objc-arc-expand", "ObjC ARC expansion", false, - false) - -Pass *llvm::createObjCARCExpandPass() { return new ObjCARCExpand(); } - -void ObjCARCExpand::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesCFG(); -} - -bool ObjCARCExpand::runOnFunction(Function &F) { return runImpl(F); } - PreservedAnalyses ObjCARCExpandPass::run(Function &F, FunctionAnalysisManager &AM) { if (!runImpl(F)) diff --git a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index e1a000b31cf9..a374958f9707 100644 --- a/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -31,7 +31,6 @@ #include "ProvenanceAnalysis.h" #include "PtrState.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -58,8 +57,6 @@ #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" @@ -483,8 +480,8 @@ namespace { /// The main ARC optimization pass. class ObjCARCOpt { - bool Changed; - bool CFGChanged; + bool Changed = false; + bool CFGChanged = false; ProvenanceAnalysis PA; /// A cache of references to runtime entry point constants. @@ -504,6 +501,8 @@ class ObjCARCOpt { /// is in fact used in the current function. unsigned UsedInThisFunction; + DenseMap<BasicBlock *, ColorVector> BlockEHColors; + bool OptimizeRetainRVCall(Function &F, Instruction *RetainRV); void OptimizeAutoreleaseRVCall(Function &F, Instruction *AutoreleaseRV, ARCInstKind &Class); @@ -511,17 +510,16 @@ class ObjCARCOpt { /// Optimize an individual call, optionally passing the /// GetArgRCIdentityRoot if it has already been computed. 
- void OptimizeIndividualCallImpl( - Function &F, DenseMap<BasicBlock *, ColorVector> &BlockColors, - Instruction *Inst, ARCInstKind Class, const Value *Arg); + void OptimizeIndividualCallImpl(Function &F, Instruction *Inst, + ARCInstKind Class, const Value *Arg); /// Try to optimize an AutoreleaseRV with a RetainRV or UnsafeClaimRV. If the /// optimization occurs, returns true to indicate that the caller should /// assume the instructions are dead. - bool OptimizeInlinedAutoreleaseRVCall( - Function &F, DenseMap<BasicBlock *, ColorVector> &BlockColors, - Instruction *Inst, const Value *&Arg, ARCInstKind Class, - Instruction *AutoreleaseRV, const Value *&AutoreleaseRVArg); + bool OptimizeInlinedAutoreleaseRVCall(Function &F, Instruction *Inst, + const Value *&Arg, ARCInstKind Class, + Instruction *AutoreleaseRV, + const Value *&AutoreleaseRVArg); void CheckForCFGHazards(const BasicBlock *BB, DenseMap<const BasicBlock *, BBState> &BBStates, @@ -569,54 +567,41 @@ class ObjCARCOpt { void OptimizeReturns(Function &F); + template <typename PredicateT> + static void cloneOpBundlesIf(CallBase *CI, + SmallVectorImpl<OperandBundleDef> &OpBundles, + PredicateT Predicate) { + for (unsigned I = 0, E = CI->getNumOperandBundles(); I != E; ++I) { + OperandBundleUse B = CI->getOperandBundleAt(I); + if (Predicate(B)) + OpBundles.emplace_back(B); + } + } + + void addOpBundleForFunclet(BasicBlock *BB, + SmallVectorImpl<OperandBundleDef> &OpBundles) { + if (!BlockEHColors.empty()) { + const ColorVector &CV = BlockEHColors.find(BB)->second; + assert(CV.size() > 0 && "Uncolored block"); + for (BasicBlock *EHPadBB : CV) + if (auto *EHPad = dyn_cast<FuncletPadInst>(EHPadBB->getFirstNonPHI())) { + OpBundles.emplace_back("funclet", EHPad); + return; + } + } + } + #ifndef NDEBUG void GatherStatistics(Function &F, bool AfterOptimization = false); #endif public: - void init(Module &M); + void init(Function &F); bool run(Function &F, AAResults &AA); - void releaseMemory(); bool hasCFGChanged() const { return CFGChanged; } }; - -/// The main ARC optimization pass. -class ObjCARCOptLegacyPass : public FunctionPass { -public: - ObjCARCOptLegacyPass() : FunctionPass(ID) { - initializeObjCARCOptLegacyPassPass(*PassRegistry::getPassRegistry()); - } - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool doInitialization(Module &M) override { - OCAO.init(M); - return false; - } - bool runOnFunction(Function &F) override { - return OCAO.run(F, getAnalysis<AAResultsWrapperPass>().getAAResults()); - } - void releaseMemory() override { OCAO.releaseMemory(); } - static char ID; - -private: - ObjCARCOpt OCAO; -}; } // end anonymous namespace -char ObjCARCOptLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(ObjCARCOptLegacyPass, "objc-arc", "ObjC ARC optimization", - false, false) -INITIALIZE_PASS_DEPENDENCY(ObjCARCAAWrapperPass) -INITIALIZE_PASS_END(ObjCARCOptLegacyPass, "objc-arc", "ObjC ARC optimization", - false, false) - -Pass *llvm::createObjCARCOptPass() { return new ObjCARCOptLegacyPass(); } - -void ObjCARCOptLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<ObjCARCAAWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); -} - /// Turn objc_retainAutoreleasedReturnValue into objc_retain if the operand is /// not a return value. 
bool @@ -664,8 +649,7 @@ ObjCARCOpt::OptimizeRetainRVCall(Function &F, Instruction *RetainRV) { } bool ObjCARCOpt::OptimizeInlinedAutoreleaseRVCall( - Function &F, DenseMap<BasicBlock *, ColorVector> &BlockColors, - Instruction *Inst, const Value *&Arg, ARCInstKind Class, + Function &F, Instruction *Inst, const Value *&Arg, ARCInstKind Class, Instruction *AutoreleaseRV, const Value *&AutoreleaseRVArg) { if (BundledInsts->contains(Inst)) return false; @@ -718,8 +702,7 @@ bool ObjCARCOpt::OptimizeInlinedAutoreleaseRVCall( EraseInstruction(Inst); // Run the normal optimizations on Release. - OptimizeIndividualCallImpl(F, BlockColors, Release, ARCInstKind::Release, - Arg); + OptimizeIndividualCallImpl(F, Release, ARCInstKind::Release, Arg); return true; } @@ -772,31 +755,6 @@ void ObjCARCOpt::OptimizeAutoreleaseRVCall(Function &F, LLVM_DEBUG(dbgs() << "New: " << *AutoreleaseRV << "\n"); } -namespace { -Instruction * -CloneCallInstForBB(CallInst &CI, BasicBlock &BB, - const DenseMap<BasicBlock *, ColorVector> &BlockColors) { - SmallVector<OperandBundleDef, 1> OpBundles; - for (unsigned I = 0, E = CI.getNumOperandBundles(); I != E; ++I) { - auto Bundle = CI.getOperandBundleAt(I); - // Funclets will be reassociated in the future. - if (Bundle.getTagID() == LLVMContext::OB_funclet) - continue; - OpBundles.emplace_back(Bundle); - } - - if (!BlockColors.empty()) { - const ColorVector &CV = BlockColors.find(&BB)->second; - assert(CV.size() == 1 && "non-unique color for block!"); - Instruction *EHPad = CV.front()->getFirstNonPHI(); - if (EHPad->isEHPad()) - OpBundles.emplace_back("funclet", EHPad); - } - - return CallInst::Create(&CI, OpBundles); -} -} - /// Visit each call, one at a time, and make simplifications without doing any /// additional analysis. void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { @@ -804,11 +762,6 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { // Reset all the flags in preparation for recomputing them. UsedInThisFunction = 0; - DenseMap<BasicBlock *, ColorVector> BlockColors; - if (F.hasPersonalityFn() && - isScopedEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) - BlockColors = colorEHFunclets(F); - // Store any delayed AutoreleaseRV intrinsics, so they can be easily paired // with RetainRV and UnsafeClaimRV. Instruction *DelayedAutoreleaseRV = nullptr; @@ -821,7 +774,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { auto optimizeDelayedAutoreleaseRV = [&]() { if (!DelayedAutoreleaseRV) return; - OptimizeIndividualCallImpl(F, BlockColors, DelayedAutoreleaseRV, + OptimizeIndividualCallImpl(F, DelayedAutoreleaseRV, ARCInstKind::AutoreleaseRV, DelayedAutoreleaseRVArg); setDelayedAutoreleaseRV(nullptr); @@ -884,7 +837,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { case ARCInstKind::UnsafeClaimRV: if (DelayedAutoreleaseRV) { // We have a potential RV pair. Check if they cancel out. - if (OptimizeInlinedAutoreleaseRVCall(F, BlockColors, Inst, Arg, Class, + if (OptimizeInlinedAutoreleaseRVCall(F, Inst, Arg, Class, DelayedAutoreleaseRV, DelayedAutoreleaseRVArg)) { setDelayedAutoreleaseRV(nullptr); @@ -895,7 +848,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { break; } - OptimizeIndividualCallImpl(F, BlockColors, Inst, Class, Arg); + OptimizeIndividualCallImpl(F, Inst, Class, Arg); } // Catch the final delayed AutoreleaseRV. 
@@ -929,9 +882,9 @@ static bool isInertARCValue(Value *V, SmallPtrSet<Value *, 1> &VisitedPhis) { return false; } -void ObjCARCOpt::OptimizeIndividualCallImpl( - Function &F, DenseMap<BasicBlock *, ColorVector> &BlockColors, - Instruction *Inst, ARCInstKind Class, const Value *Arg) { +void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, + ARCInstKind Class, + const Value *Arg) { LLVM_DEBUG(dbgs() << "Visiting: Class: " << Class << "; " << *Inst << "\n"); // We can delete this call if it takes an inert value. @@ -1038,7 +991,7 @@ void ObjCARCOpt::OptimizeIndividualCallImpl( CallInst *NewCall = CallInst::Create(Decl, Call->getArgOperand(0), "", Call); NewCall->setMetadata(MDKindCache.get(ARCMDKindID::ImpreciseRelease), - MDNode::get(C, None)); + MDNode::get(C, std::nullopt)); LLVM_DEBUG(dbgs() << "Replacing autorelease{,RV}(x) with objc_release(x) " "since x is otherwise unused.\nOld: " @@ -1189,8 +1142,12 @@ void ObjCARCOpt::OptimizeIndividualCallImpl( continue; Value *Op = PN->getIncomingValue(i); Instruction *InsertPos = &PN->getIncomingBlock(i)->back(); - CallInst *Clone = cast<CallInst>( - CloneCallInstForBB(*CInst, *InsertPos->getParent(), BlockColors)); + SmallVector<OperandBundleDef, 1> OpBundles; + cloneOpBundlesIf(CInst, OpBundles, [](const OperandBundleUse &B) { + return B.getTagID() != LLVMContext::OB_funclet; + }); + addOpBundleForFunclet(InsertPos->getParent(), OpBundles); + CallInst *Clone = CallInst::Create(CInst, OpBundles); if (Op->getType() != ParamTy) Op = new BitCastInst(Op, ParamTy, "", InsertPos); Clone->setArgOperand(0, Op); @@ -1503,7 +1460,7 @@ static void collectReleaseInsertPts( const BlotMapVector<Value *, RRInfo> &Retains, DenseMap<const Instruction *, SmallPtrSet<const Value *, 2>> &ReleaseInsertPtToRCIdentityRoots) { - for (auto &P : Retains) { + for (const auto &P : Retains) { // Retains is a map from an objc_retain call to a RRInfo of the RC identity // root of the call. Get the RC identity root of the objc_retain call. Instruction *Retain = cast<Instruction>(P.first); @@ -1541,7 +1498,7 @@ bool ObjCARCOpt::VisitInstructionTopDown( if (const SmallPtrSet<const Value *, 2> *Roots = getRCIdentityRootsFromReleaseInsertPt( Inst, ReleaseInsertPtToRCIdentityRoots)) - for (auto *Root : *Roots) { + for (const auto *Root : *Roots) { TopDownPtrState &S = MyStates.getPtrTopDownState(Root); // Disable code motion if the current position is S_Retain to prevent // moving the objc_retain call past objc_release calls. If it's @@ -1812,7 +1769,9 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, Value *MyArg = ArgTy == ParamTy ? Arg : new BitCastInst(Arg, ParamTy, "", InsertPt); Function *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); - CallInst *Call = CallInst::Create(Decl, MyArg, "", InsertPt); + SmallVector<OperandBundleDef, 1> BundleList; + addOpBundleForFunclet(InsertPt->getParent(), BundleList); + CallInst *Call = CallInst::Create(Decl, MyArg, BundleList, "", InsertPt); Call->setDoesNotThrow(); Call->setTailCall(); @@ -1825,7 +1784,9 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, Value *MyArg = ArgTy == ParamTy ? 
Arg : new BitCastInst(Arg, ParamTy, "", InsertPt); Function *Decl = EP.get(ARCRuntimeEntryPointKind::Release); - CallInst *Call = CallInst::Create(Decl, MyArg, "", InsertPt); + SmallVector<OperandBundleDef, 1> BundleList; + addOpBundleForFunclet(InsertPt->getParent(), BundleList); + CallInst *Call = CallInst::Create(Decl, MyArg, BundleList, "", InsertPt); // Attach a clang.imprecise_release metadata tag, if appropriate. if (MDNode *M = ReleasesToMove.ReleaseMetadata) Call->setMetadata(MDKindCache.get(ARCMDKindID::ImpreciseRelease), M); @@ -2441,17 +2402,22 @@ ObjCARCOpt::GatherStatistics(Function &F, bool AfterOptimization) { } #endif -void ObjCARCOpt::init(Module &M) { +void ObjCARCOpt::init(Function &F) { if (!EnableARCOpts) return; // Intuitively, objc_retain and others are nocapture, however in practice // they are not, because they return their argument value. And objc_release // calls finalizers which can have arbitrary side effects. - MDKindCache.init(&M); + MDKindCache.init(F.getParent()); // Initialize our runtime entry point cache. - EP.init(&M); + EP.init(F.getParent()); + + // Compute which blocks are in which funclet. + if (F.hasPersonalityFn() && + isScopedEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) + BlockEHColors = colorEHFunclets(F); } bool ObjCARCOpt::run(Function &F, AAResults &AA) { @@ -2521,17 +2487,13 @@ bool ObjCARCOpt::run(Function &F, AAResults &AA) { return Changed; } -void ObjCARCOpt::releaseMemory() { - PA.clear(); -} - /// @} /// PreservedAnalyses ObjCARCOptPass::run(Function &F, FunctionAnalysisManager &AM) { ObjCARCOpt OCAO; - OCAO.init(*F.getParent()); + OCAO.init(F); bool Changed = OCAO.run(F, AM.getResult<AAManager>(F)); bool CFGChanged = OCAO.hasCFGChanged(); diff --git a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp index 6731b841771c..2fa25a79ae9d 100644 --- a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp +++ b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp @@ -42,21 +42,40 @@ bool ProvenanceAnalysis::relatedSelect(const SelectInst *A, const Value *B) { // If the values are Selects with the same condition, we can do a more precise // check: just check for relations between the values on corresponding arms. - if (const SelectInst *SB = dyn_cast<SelectInst>(B)) + if (const SelectInst *SB = dyn_cast<SelectInst>(B)) { if (A->getCondition() == SB->getCondition()) return related(A->getTrueValue(), SB->getTrueValue()) || related(A->getFalseValue(), SB->getFalseValue()); + // Check both arms of B individually. Return false if neither arm is related + // to A. + if (!(related(SB->getTrueValue(), A) || related(SB->getFalseValue(), A))) + return false; + } + // Check both arms of the Select node individually. return related(A->getTrueValue(), B) || related(A->getFalseValue(), B); } bool ProvenanceAnalysis::relatedPHI(const PHINode *A, const Value *B) { - // If the values are PHIs in the same block, we can do a more precise as well - // as efficient check: just check for relations between the values on - // corresponding edges. - if (const PHINode *PNB = dyn_cast<PHINode>(B)) + + auto comparePHISources = [this](const PHINode *PNA, const Value *B) -> bool { + // Check each unique source of the PHI node against B. + SmallPtrSet<const Value *, 4> UniqueSrc; + for (Value *PV1 : PNA->incoming_values()) { + if (UniqueSrc.insert(PV1).second && related(PV1, B)) + return true; + } + + // All of the arms checked out. 
+ return false; + }; + + if (const PHINode *PNB = dyn_cast<PHINode>(B)) { + // If the values are PHIs in the same block, we can do a more precise as + // well as efficient check: just check for relations between the values on + // corresponding edges. if (PNB->getParent() == A->getParent()) { for (unsigned i = 0, e = A->getNumIncomingValues(); i != e; ++i) if (related(A->getIncomingValue(i), @@ -65,15 +84,11 @@ bool ProvenanceAnalysis::relatedPHI(const PHINode *A, return false; } - // Check each unique source of the PHI node against B. - SmallPtrSet<const Value *, 4> UniqueSrc; - for (Value *PV1 : A->incoming_values()) { - if (UniqueSrc.insert(PV1).second && related(PV1, B)) - return true; + if (!comparePHISources(PNB, A)) + return false; } - // All of the arms checked out. - return false; + return comparePHISources(A, B); } /// Test if the value of P, or any value covered by its provenance, is ever @@ -125,22 +140,19 @@ bool ProvenanceAnalysis::relatedCheck(const Value *A, const Value *B) { bool BIsIdentified = IsObjCIdentifiedObject(B); // An ObjC-Identified object can't alias a load if it is never locally stored. - if (AIsIdentified) { - // Check for an obvious escape. - if (isa<LoadInst>(B)) - return IsStoredObjCPointer(A); - if (BIsIdentified) { - // Check for an obvious escape. - if (isa<LoadInst>(A)) - return IsStoredObjCPointer(B); - // Both pointers are identified and escapes aren't an evident problem. - return false; - } - } else if (BIsIdentified) { - // Check for an obvious escape. - if (isa<LoadInst>(A)) - return IsStoredObjCPointer(B); - } + + // Check for an obvious escape. + if ((AIsIdentified && isa<LoadInst>(B) && !IsStoredObjCPointer(A)) || + (BIsIdentified && isa<LoadInst>(A) && !IsStoredObjCPointer(B))) + return false; + + if ((AIsIdentified && isa<LoadInst>(B)) || + (BIsIdentified && isa<LoadInst>(A))) + return true; + + // Both pointers are identified and escapes aren't an evident problem. + if (AIsIdentified && BIsIdentified && !isa<LoadInst>(A) && !isa<LoadInst>(B)) + return false; // Special handling for PHI and Select. 
if (const PHINode *PN = dyn_cast<PHINode>(A)) @@ -174,6 +186,8 @@ bool ProvenanceAnalysis::related(const Value *A, const Value *B) { return Pair.first->second; bool Result = relatedCheck(A, B); + assert(relatedCheck(B, A) == Result && + "relatedCheck result depending on order of parameters!"); CachedResults[ValuePairTy(A, B)] = Result; return Result; } diff --git a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.h b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.h index 1624cf26094a..bc946fac4544 100644 --- a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.h +++ b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysis.h @@ -26,6 +26,7 @@ #define LLVM_LIB_TRANSFORMS_OBJCARC_PROVENANCEANALYSIS_H #include "llvm/ADT/DenseMap.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/ValueHandle.h" #include <utility> diff --git a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp index fe637ee066a4..9f15772f2fa1 100644 --- a/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp +++ b/llvm/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp @@ -7,38 +7,16 @@ //===----------------------------------------------------------------------===// #include "ProvenanceAnalysis.h" +#include "llvm/Transforms/ObjCARC.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/Passes.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; using namespace llvm::objcarc; -namespace { -class PAEval : public FunctionPass { - -public: - static char ID; - PAEval(); - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; -}; -} - -char PAEval::ID = 0; -PAEval::PAEval() : FunctionPass(ID) {} - -void PAEval::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<AAResultsWrapperPass>(); -} - static StringRef getName(Value *V) { StringRef Name = V->getName(); if (Name.startswith("\1")) @@ -52,7 +30,7 @@ static void insertIfNamed(SetVector<Value *> &Values, Value *V) { Values.insert(V); } -bool PAEval::runOnFunction(Function &F) { +PreservedAnalyses PAEvalPass::run(Function &F, FunctionAnalysisManager &AM) { SetVector<Value *> Values; for (auto &Arg : F.args()) @@ -66,7 +44,7 @@ bool PAEval::runOnFunction(Function &F) { } ProvenanceAnalysis PA; - PA.setAA(&getAnalysis<AAResultsWrapperPass>().getAAResults()); + PA.setAA(&AM.getResult<AAManager>(F)); for (Value *V1 : Values) { StringRef NameV1 = getName(V1); @@ -82,13 +60,5 @@ bool PAEval::runOnFunction(Function &F) { } } - return false; + return PreservedAnalyses::all(); } - -FunctionPass *llvm::createPAEvalPass() { return new PAEval(); } - -INITIALIZE_PASS_BEGIN(PAEval, "pa-eval", - "Evaluate ProvenanceAnalysis on all pairs", false, true) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(PAEval, "pa-eval", - "Evaluate ProvenanceAnalysis on all pairs", false, true) diff --git a/llvm/lib/Transforms/ObjCARC/PtrState.cpp b/llvm/lib/Transforms/ObjCARC/PtrState.cpp index d10d5851d5ea..e9b2dbeb62e6 100644 --- a/llvm/lib/Transforms/ObjCARC/PtrState.cpp +++ b/llvm/lib/Transforms/ObjCARC/PtrState.cpp @@ -212,7 +212,7 @@ bool BottomUpPtrState::MatchWithRetain() { // imprecise release, clear our reverse insertion points. 
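The assert added to ProvenanceAnalysis::related above encodes an invariant: relatedCheck is expected to be symmetric in its arguments, since callers may ask about a pair in either order and the result is cached. A self-contained model of that pattern, with a toy parity relation standing in for the real provenance check; all names are illustrative.

#include <cassert>
#include <map>
#include <utility>

using ToyValue = int; // stand-in for const llvm::Value *

// Toy relation: two values are "related" if they have the same parity.
// It is symmetric by construction, mirroring the invariant asserted above.
static bool toyRelatedCheck(ToyValue A, ToyValue B) {
  return (A & 1) == (B & 1);
}

static bool toyRelated(ToyValue A, ToyValue B) {
  static std::map<std::pair<ToyValue, ToyValue>, bool> Cache;
  if (B < A)
    std::swap(A, B); // symmetry lets both query orders share one cache entry
  auto It = Cache.find({A, B});
  if (It != Cache.end())
    return It->second;
  bool Result = toyRelatedCheck(A, B);
  assert(toyRelatedCheck(B, A) == Result && "relation must be symmetric");
  return Cache[{A, B}] = Result;
}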
if (OldSeq != S_Use || IsTrackingImpreciseReleases()) ClearReverseInsertPts(); - LLVM_FALLTHROUGH; + [[fallthrough]]; case S_CanRelease: return true; case S_None: @@ -360,7 +360,7 @@ bool TopDownPtrState::MatchWithRelease(ARCMDKindCache &Cache, case S_CanRelease: if (OldSeq == S_Retain || ReleaseMetadata != nullptr) ClearReverseInsertPts(); - LLVM_FALLTHROUGH; + [[fallthrough]]; case S_Use: SetReleaseMetadata(ReleaseMetadata); SetTailCallRelease(cast<CallInst>(Release)->isTailCall()); diff --git a/llvm/lib/Transforms/Scalar/ADCE.cpp b/llvm/lib/Transforms/Scalar/ADCE.cpp index cdf9de8d78d5..253293582945 100644 --- a/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -29,6 +29,7 @@ #include "llvm/Analysis/PostDominators.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/Dominators.h" @@ -295,7 +296,7 @@ void AggressiveDeadCodeElimination::initialize() { // return of the function. // We do this by seeing which of the postdomtree root children exit the // program, and for all others, mark the subtree live. - for (auto &PDTChild : children<DomTreeNode *>(PDT.getRootNode())) { + for (const auto &PDTChild : children<DomTreeNode *>(PDT.getRootNode())) { auto *BB = PDTChild->getBlock(); auto &Info = BlockInfo[BB]; // Real function return @@ -306,7 +307,7 @@ void AggressiveDeadCodeElimination::initialize() { } // This child is something else, like an infinite loop. - for (auto DFNode : depth_first(PDTChild)) + for (auto *DFNode : depth_first(PDTChild)) markLive(BlockInfo[DFNode->getBlock()].Terminator); } @@ -543,6 +544,11 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { continue; if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&I)) { + // Avoid removing a dbg.assign that is linked to instructions because it + // holds information about an existing store. + if (auto *DAI = dyn_cast<DbgAssignIntrinsic>(DII)) + if (!at::getAssignmentInsts(DAI).empty()) + continue; // Check if the scope of this variable location is alive. 
if (AliveScopes.count(DII->getDebugLoc()->getScope())) continue; diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index 9571e99dfb19..f419f7bd769f 100644 --- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -116,7 +116,7 @@ static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV, return Align(DiffUnitsAbs); } - return None; + return std::nullopt; } // There is an address given by an offset OffSCEV from AASCEV which has an diff --git a/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp b/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp index 155f47b49357..79f7e253d45b 100644 --- a/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp +++ b/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp @@ -16,8 +16,6 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/MemoryOpRemark.h" @@ -82,42 +80,6 @@ static void runImpl(Function &F, const TargetLibraryInfo &TLI) { } } -namespace { - -struct AnnotationRemarksLegacy : public FunctionPass { - static char ID; - - AnnotationRemarksLegacy() : FunctionPass(ID) { - initializeAnnotationRemarksLegacyPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - const TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - runImpl(F, TLI); - return false; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; - -} // end anonymous namespace - -char AnnotationRemarksLegacy::ID = 0; - -INITIALIZE_PASS_BEGIN(AnnotationRemarksLegacy, "annotation-remarks", - "Annotation Remarks", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(AnnotationRemarksLegacy, "annotation-remarks", - "Annotation Remarks", false, false) - -FunctionPass *llvm::createAnnotationRemarksLegacyPass() { - return new AnnotationRemarksLegacy(); -} - PreservedAnalyses AnnotationRemarksPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); diff --git a/llvm/lib/Transforms/Scalar/BDCE.cpp b/llvm/lib/Transforms/Scalar/BDCE.cpp index 6c2467db79f7..187927b3dede 100644 --- a/llvm/lib/Transforms/Scalar/BDCE.cpp +++ b/llvm/lib/Transforms/Scalar/BDCE.cpp @@ -143,9 +143,8 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { clearAssumptionsOfUsers(&I, DB); - // FIXME: In theory we could substitute undef here instead of zero. - // This should be reconsidered once we settle on the semantics of - // undef, poison, etc. + // Substitute all uses with zero. In theory we could use `freeze poison` + // instead, but that seems unlikely to be profitable. 
U.set(ConstantInt::get(U->getType(), 0)); ++NumSimplified; Changed = true; diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp index cc12033fb677..6665a927826d 100644 --- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -165,7 +165,7 @@ static void recordConditions(CallBase &CB, BasicBlock *Pred, } static void addConditions(CallBase &CB, const ConditionsTy &Conditions) { - for (auto &Cond : Conditions) { + for (const auto &Cond : Conditions) { Value *Arg = Cond.first->getOperand(0); Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1)); if (Cond.second == ICmpInst::ICMP_EQ) @@ -364,9 +364,9 @@ static void splitCallSite(CallBase &CB, // attempting removal. SmallVector<BasicBlock *, 2> Splits(predecessors((TailBB))); assert(Splits.size() == 2 && "Expected exactly 2 splits!"); - for (unsigned i = 0; i < Splits.size(); i++) { - Splits[i]->getTerminator()->eraseFromParent(); - DTU.applyUpdatesPermissive({{DominatorTree::Delete, Splits[i], TailBB}}); + for (BasicBlock *BB : Splits) { + BB->getTerminator()->eraseFromParent(); + DTU.applyUpdatesPermissive({{DominatorTree::Delete, BB, TailBB}}); } // Erase the tail block once done with musttail patching diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index fe6f9486ab0c..8858545bbc5d 100644 --- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -35,8 +35,6 @@ #include "llvm/Transforms/Scalar/ConstantHoisting.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -221,7 +219,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, // dominated by any other blocks in set 'BBs', and all nodes in the path // in the dominator tree from Entry to 'BB'. SmallPtrSet<BasicBlock *, 16> Candidates; - for (auto BB : BBs) { + for (auto *BB : BBs) { // Ignore unreachable basic blocks. if (!DT.isReachableFromEntry(BB)) continue; @@ -258,7 +256,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, Orders.push_back(Entry); while (Idx != Orders.size()) { BasicBlock *Node = Orders[Idx++]; - for (auto ChildDomNode : DT.getNode(Node)->children()) { + for (auto *ChildDomNode : DT.getNode(Node)->children()) { if (Candidates.count(ChildDomNode->getBlock())) Orders.push_back(ChildDomNode->getBlock()); } @@ -330,7 +328,7 @@ SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint( if (BFI) { findBestInsertionSet(*DT, *BFI, Entry, BBs); - for (auto BB : BBs) { + for (auto *BB : BBs) { BasicBlock::iterator InsertPt = BB->begin(); for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt) ; @@ -533,8 +531,9 @@ void ConstantHoistingPass::collectConstantCandidates(Function &Fn) { // bit widths (APInt Operator- does not like that). If the value cannot be // represented in uint64 we return an "empty" APInt. This is then interpreted // as the value is not in range. -static Optional<APInt> calculateOffsetDiff(const APInt &V1, const APInt &V2) { - Optional<APInt> Res = None; +static std::optional<APInt> calculateOffsetDiff(const APInt &V1, + const APInt &V2) { + std::optional<APInt> Res; unsigned BW = V1.getBitWidth() > V2.getBitWidth() ? 
V1.getBitWidth() : V2.getBitWidth(); uint64_t LimVal1 = V1.getLimitedValue(); @@ -606,14 +605,13 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S, LLVM_DEBUG(dbgs() << "Cost: " << Cost << "\n"); for (auto C2 = S; C2 != E; ++C2) { - Optional<APInt> Diff = calculateOffsetDiff( - C2->ConstInt->getValue(), - ConstCand->ConstInt->getValue()); + std::optional<APInt> Diff = calculateOffsetDiff( + C2->ConstInt->getValue(), ConstCand->ConstInt->getValue()); if (Diff) { const InstructionCost ImmCosts = - TTI->getIntImmCodeSizeCost(Opcode, OpndIdx, Diff.value(), Ty); + TTI->getIntImmCodeSizeCost(Opcode, OpndIdx, *Diff, Ty); Cost -= ImmCosts; - LLVM_DEBUG(dbgs() << "Offset " << Diff.value() << " " + LLVM_DEBUG(dbgs() << "Offset " << *Diff << " " << "has penalty: " << ImmCosts << "\n" << "Adjusted cost: " << Cost << "\n"); } @@ -724,7 +722,7 @@ void ConstantHoistingPass::findBaseConstants(GlobalVariable *BaseGV) { /// Updates the operand at Idx in instruction Inst with the result of /// instruction Mat. If the instruction is a PHI node then special -/// handling for duplicate values form the same incoming basic block is +/// handling for duplicate values from the same incoming basic block is /// required. /// \return The update will always succeed, but the return value indicated if /// Mat was used for the update or not. diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 6dfa2440023f..12fcb6aa9846 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -19,18 +19,20 @@ #include "llvm/Analysis/ConstraintSystem.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Support/MathExtras.h" -#include "llvm/Transforms/Scalar.h" +#include <cmath> #include <string> using namespace llvm; @@ -42,9 +44,27 @@ STATISTIC(NumCondsRemoved, "Number of instructions removed"); DEBUG_COUNTER(EliminatedCounter, "conds-eliminated", "Controls which conditions are eliminated"); +static cl::opt<unsigned> + MaxRows("constraint-elimination-max-rows", cl::init(500), cl::Hidden, + cl::desc("Maximum number of rows to keep in constraint system")); + static int64_t MaxConstraintValue = std::numeric_limits<int64_t>::max(); static int64_t MinSignedConstraintValue = std::numeric_limits<int64_t>::min(); +// A helper to multiply 2 signed integers where overflowing is allowed. +static int64_t multiplyWithOverflow(int64_t A, int64_t B) { + int64_t Result; + MulOverflow(A, B, Result); + return Result; +} + +// A helper to add 2 signed integers where overflowing is allowed. +static int64_t addWithOverflow(int64_t A, int64_t B) { + int64_t Result; + AddOverflow(A, B, Result); + return Result; +} + namespace { class ConstraintInfo; @@ -52,15 +72,14 @@ class ConstraintInfo; struct StackEntry { unsigned NumIn; unsigned NumOut; - bool IsNot; bool IsSigned = false; /// Variables that can be removed from the system once the stack entry gets /// removed. 
SmallVector<Value *, 2> ValuesToRelease; - StackEntry(unsigned NumIn, unsigned NumOut, bool IsNot, bool IsSigned, + StackEntry(unsigned NumIn, unsigned NumOut, bool IsSigned, SmallVector<Value *, 2> ValuesToRelease) - : NumIn(NumIn), NumOut(NumOut), IsNot(IsNot), IsSigned(IsSigned), + : NumIn(NumIn), NumOut(NumOut), IsSigned(IsSigned), ValuesToRelease(ValuesToRelease) {} }; @@ -78,6 +97,8 @@ struct ConstraintTy { SmallVector<int64_t, 8> Coefficients; SmallVector<PreconditionTy, 2> Preconditions; + SmallVector<SmallVector<int64_t, 8>> ExtraInfo; + bool IsSigned = false; bool IsEq = false; @@ -90,18 +111,6 @@ struct ConstraintTy { unsigned empty() const { return Coefficients.empty(); } - /// Returns true if any constraint has a non-zero coefficient for any of the - /// newly added indices. Zero coefficients for new indices are removed. If it - /// returns true, no new variable need to be added to the system. - bool needsNewIndices(const DenseMap<Value *, unsigned> &NewIndices) { - for (unsigned I = 0; I < NewIndices.size(); ++I) { - int64_t Last = Coefficients.pop_back_val(); - if (Last != 0) - return true; - } - return false; - } - /// Returns true if all preconditions for this list of constraints are /// satisfied given \p CS and the corresponding \p Value2Index mapping. bool isValid(const ConstraintInfo &Info) const; @@ -120,7 +129,11 @@ class ConstraintInfo { ConstraintSystem UnsignedCS; ConstraintSystem SignedCS; + const DataLayout &DL; + public: + ConstraintInfo(const DataLayout &DL) : DL(DL) {} + DenseMap<Value *, unsigned> &getValue2Index(bool Signed) { return Signed ? SignedValue2Index : UnsignedValue2Index; } @@ -142,140 +155,240 @@ public: bool doesHold(CmpInst::Predicate Pred, Value *A, Value *B) const; - void addFact(CmpInst::Predicate Pred, Value *A, Value *B, bool IsNegated, - unsigned NumIn, unsigned NumOut, - SmallVectorImpl<StackEntry> &DFSInStack); + void addFact(CmpInst::Predicate Pred, Value *A, Value *B, unsigned NumIn, + unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack); /// Turn a comparison of the form \p Op0 \p Pred \p Op1 into a vector of /// constraints, using indices from the corresponding constraint system. - /// Additional indices for newly discovered values are added to \p NewIndices. + /// New variables that need to be added to the system are collected in + /// \p NewVariables. ConstraintTy getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, - DenseMap<Value *, unsigned> &NewIndices) const; - - /// Turn a condition \p CmpI into a vector of constraints, using indices from - /// the corresponding constraint system. Additional indices for newly - /// discovered values are added to \p NewIndices. - ConstraintTy getConstraint(CmpInst *Cmp, - DenseMap<Value *, unsigned> &NewIndices) const { - return getConstraint(Cmp->getPredicate(), Cmp->getOperand(0), - Cmp->getOperand(1), NewIndices); - } + SmallVectorImpl<Value *> &NewVariables) const; + + /// Turns a comparison of the form \p Op0 \p Pred \p Op1 into a vector of + /// constraints using getConstraint. Returns an empty constraint if the result + /// cannot be used to query the existing constraint system, e.g. because it + /// would require adding new variables. Also tries to convert signed + /// predicates to unsigned ones if possible to allow using the unsigned system + /// which increases the effectiveness of the signed <-> unsigned transfer + /// logic. 
+ ConstraintTy getConstraintForSolving(CmpInst::Predicate Pred, Value *Op0, + Value *Op1) const; /// Try to add information from \p A \p Pred \p B to the unsigned/signed /// system if \p Pred is signed/unsigned. void transferToOtherSystem(CmpInst::Predicate Pred, Value *A, Value *B, - bool IsNegated, unsigned NumIn, unsigned NumOut, + unsigned NumIn, unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack); }; +/// Represents a (Coefficient * Variable) entry after IR decomposition. +struct DecompEntry { + int64_t Coefficient; + Value *Variable; + /// True if the variable is known positive in the current constraint. + bool IsKnownNonNegative; + + DecompEntry(int64_t Coefficient, Value *Variable, + bool IsKnownNonNegative = false) + : Coefficient(Coefficient), Variable(Variable), + IsKnownNonNegative(IsKnownNonNegative) {} +}; + +/// Represents an Offset + Coefficient1 * Variable1 + ... decomposition. +struct Decomposition { + int64_t Offset = 0; + SmallVector<DecompEntry, 3> Vars; + + Decomposition(int64_t Offset) : Offset(Offset) {} + Decomposition(Value *V, bool IsKnownNonNegative = false) { + Vars.emplace_back(1, V, IsKnownNonNegative); + } + Decomposition(int64_t Offset, ArrayRef<DecompEntry> Vars) + : Offset(Offset), Vars(Vars) {} + + void add(int64_t OtherOffset) { + Offset = addWithOverflow(Offset, OtherOffset); + } + + void add(const Decomposition &Other) { + add(Other.Offset); + append_range(Vars, Other.Vars); + } + + void mul(int64_t Factor) { + Offset = multiplyWithOverflow(Offset, Factor); + for (auto &Var : Vars) + Var.Coefficient = multiplyWithOverflow(Var.Coefficient, Factor); + } +}; + } // namespace -// Decomposes \p V into a vector of pairs of the form { c, X } where c * X. The -// sum of the pairs equals \p V. The first pair is the constant-factor and X -// must be nullptr. If the expression cannot be decomposed, returns an empty -// vector. -static SmallVector<std::pair<int64_t, Value *>, 4> -decompose(Value *V, SmallVector<PreconditionTy, 4> &Preconditions, - bool IsSigned) { - - auto CanUseSExt = [](ConstantInt *CI) { - const APInt &Val = CI->getValue(); - return Val.sgt(MinSignedConstraintValue) && Val.slt(MaxConstraintValue); +static Decomposition decompose(Value *V, + SmallVectorImpl<PreconditionTy> &Preconditions, + bool IsSigned, const DataLayout &DL); + +static bool canUseSExt(ConstantInt *CI) { + const APInt &Val = CI->getValue(); + return Val.sgt(MinSignedConstraintValue) && Val.slt(MaxConstraintValue); +} + +static Decomposition +decomposeGEP(GetElementPtrInst &GEP, + SmallVectorImpl<PreconditionTy> &Preconditions, bool IsSigned, + const DataLayout &DL) { + // Do not reason about pointers where the index size is larger than 64 bits, + // as the coefficients used to encode constraints are 64 bit integers. + if (DL.getIndexTypeSizeInBits(GEP.getPointerOperand()->getType()) > 64) + return &GEP; + + if (!GEP.isInBounds()) + return &GEP; + + assert(!IsSigned && "The logic below only supports decomposition for " + "unsinged predicates at the moment."); + Type *PtrTy = GEP.getType()->getScalarType(); + unsigned BitWidth = DL.getIndexTypeSizeInBits(PtrTy); + MapVector<Value *, APInt> VariableOffsets; + APInt ConstantOffset(BitWidth, 0); + if (!GEP.collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) + return &GEP; + + // Handle the (gep (gep ....), C) case by incrementing the constant + // coefficient of the inner GEP, if C is a constant. 
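// A small usage sketch for the structs above (illustrative; the helper name
// and the concrete values are hypothetical): the expression 5 + 2 * %x + 1 * %y
// and a rescaling of it can be written as
//   Decomposition makeExample(Value *X, Value *Y) {
//     Decomposition D(5, {DecompEntry(2, X), DecompEntry(1, Y)});
//     D.mul(3); // now 15 + 6 * X + 3 * Y; offset and coefficients wrap on overflow
//     return D;
//   }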
+ auto *InnerGEP = dyn_cast<GetElementPtrInst>(GEP.getPointerOperand()); + if (VariableOffsets.empty() && InnerGEP && InnerGEP->getNumOperands() == 2) { + auto Result = decompose(InnerGEP, Preconditions, IsSigned, DL); + Result.add(ConstantOffset.getSExtValue()); + + if (ConstantOffset.isNegative()) { + unsigned Scale = DL.getTypeAllocSize(InnerGEP->getResultElementType()); + int64_t ConstantOffsetI = ConstantOffset.getSExtValue(); + if (ConstantOffsetI % Scale != 0) + return &GEP; + // Add pre-condition ensuring the GEP is increasing monotonically and + // can be de-composed. + // Both sides are normalized by being divided by Scale. + Preconditions.emplace_back( + CmpInst::ICMP_SGE, InnerGEP->getOperand(1), + ConstantInt::get(InnerGEP->getOperand(1)->getType(), + -1 * (ConstantOffsetI / Scale))); + } + return Result; + } + + Decomposition Result(ConstantOffset.getSExtValue(), + DecompEntry(1, GEP.getPointerOperand())); + for (auto [Index, Scale] : VariableOffsets) { + auto IdxResult = decompose(Index, Preconditions, IsSigned, DL); + IdxResult.mul(Scale.getSExtValue()); + Result.add(IdxResult); + + // If Op0 is signed non-negative, the GEP is increasing monotonically and + // can be de-composed. + if (!isKnownNonNegative(Index, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1)) + Preconditions.emplace_back(CmpInst::ICMP_SGE, Index, + ConstantInt::get(Index->getType(), 0)); + } + return Result; +} + +// Decomposes \p V into a constant offset + list of pairs { Coefficient, +// Variable } where Coefficient * Variable. The sum of the constant offset and +// pairs equals \p V. +static Decomposition decompose(Value *V, + SmallVectorImpl<PreconditionTy> &Preconditions, + bool IsSigned, const DataLayout &DL) { + + auto MergeResults = [&Preconditions, IsSigned, &DL](Value *A, Value *B, + bool IsSignedB) { + auto ResA = decompose(A, Preconditions, IsSigned, DL); + auto ResB = decompose(B, Preconditions, IsSignedB, DL); + ResA.add(ResB); + return ResA; }; + // Decompose \p V used with a signed predicate. if (IsSigned) { if (auto *CI = dyn_cast<ConstantInt>(V)) { - if (CanUseSExt(CI)) - return {{CI->getSExtValue(), nullptr}}; + if (canUseSExt(CI)) + return CI->getSExtValue(); } + Value *Op0; + Value *Op1; + if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) + return MergeResults(Op0, Op1, IsSigned); - return {{0, nullptr}, {1, V}}; + return V; } if (auto *CI = dyn_cast<ConstantInt>(V)) { if (CI->uge(MaxConstraintValue)) - return {}; - return {{CI->getZExtValue(), nullptr}}; - } - auto *GEP = dyn_cast<GetElementPtrInst>(V); - if (GEP && GEP->getNumOperands() == 2 && GEP->isInBounds()) { - Value *Op0, *Op1; - ConstantInt *CI; - - // If the index is zero-extended, it is guaranteed to be positive. 
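// Worked example for decomposeGEP (illustrative, assuming a typical data
// layout where i32 has an alloc size of 4 bytes): for
//   %p = getelementptr inbounds i32, ptr %base, i64 %i
// collectOffset() yields ConstantOffset = 0 and VariableOffsets = { %i -> 4 },
// so the result is 0 + 1 * %base + 4 * %i; unless %i is already known
// non-negative, the precondition "%i >=s 0" is recorded so the decomposition
// is only used when the GEP address increases monotonically.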
- if (match(GEP->getOperand(GEP->getNumOperands() - 1), - m_ZExt(m_Value(Op0)))) { - if (match(Op0, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && - CanUseSExt(CI)) - return {{0, nullptr}, - {1, GEP->getPointerOperand()}, - {std::pow(int64_t(2), CI->getSExtValue()), Op1}}; - if (match(Op0, m_NSWAdd(m_Value(Op1), m_ConstantInt(CI))) && - CanUseSExt(CI)) - return {{CI->getSExtValue(), nullptr}, - {1, GEP->getPointerOperand()}, - {1, Op1}}; - return {{0, nullptr}, {1, GEP->getPointerOperand()}, {1, Op0}}; - } - - if (match(GEP->getOperand(GEP->getNumOperands() - 1), m_ConstantInt(CI)) && - !CI->isNegative() && CanUseSExt(CI)) - return {{CI->getSExtValue(), nullptr}, {1, GEP->getPointerOperand()}}; - - SmallVector<std::pair<int64_t, Value *>, 4> Result; - if (match(GEP->getOperand(GEP->getNumOperands() - 1), - m_NUWShl(m_Value(Op0), m_ConstantInt(CI))) && - CanUseSExt(CI)) - Result = {{0, nullptr}, - {1, GEP->getPointerOperand()}, - {std::pow(int64_t(2), CI->getSExtValue()), Op0}}; - else if (match(GEP->getOperand(GEP->getNumOperands() - 1), - m_NSWAdd(m_Value(Op0), m_ConstantInt(CI))) && - CanUseSExt(CI)) - Result = {{CI->getSExtValue(), nullptr}, - {1, GEP->getPointerOperand()}, - {1, Op0}}; - else { - Op0 = GEP->getOperand(GEP->getNumOperands() - 1); - Result = {{0, nullptr}, {1, GEP->getPointerOperand()}, {1, Op0}}; - } - // If Op0 is signed non-negative, the GEP is increasing monotonically and - // can be de-composed. - Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, - ConstantInt::get(Op0->getType(), 0)); - return Result; + return V; + return int64_t(CI->getZExtValue()); } + if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) + return decomposeGEP(*GEP, Preconditions, IsSigned, DL); + Value *Op0; - if (match(V, m_ZExt(m_Value(Op0)))) + bool IsKnownNonNegative = false; + if (match(V, m_ZExt(m_Value(Op0)))) { + IsKnownNonNegative = true; V = Op0; + } Value *Op1; ConstantInt *CI; - if (match(V, m_NUWAdd(m_Value(Op0), m_ConstantInt(CI))) && - !CI->uge(MaxConstraintValue)) - return {{CI->getZExtValue(), nullptr}, {1, Op0}}; + if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1)))) { + return MergeResults(Op0, Op1, IsSigned); + } + if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) { + if (!isKnownNonNegative(Op0, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1)) + Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, + ConstantInt::get(Op0->getType(), 0)); + if (!isKnownNonNegative(Op1, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1)) + Preconditions.emplace_back(CmpInst::ICMP_SGE, Op1, + ConstantInt::get(Op1->getType(), 0)); + + return MergeResults(Op0, Op1, IsSigned); + } + if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && - CanUseSExt(CI)) { + canUseSExt(CI)) { Preconditions.emplace_back( CmpInst::ICMP_UGE, Op0, ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); - return {{CI->getSExtValue(), nullptr}, {1, Op0}}; + return MergeResults(Op0, CI, true); + } + + if (match(V, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI)) { + int64_t Mult = int64_t(std::pow(int64_t(2), CI->getSExtValue())); + auto Result = decompose(Op1, Preconditions, IsSigned, DL); + Result.mul(Mult); + return Result; } - if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1)))) - return {{0, nullptr}, {1, Op0}, {1, Op1}}; - if (match(V, m_NUWSub(m_Value(Op0), m_ConstantInt(CI))) && CanUseSExt(CI)) - return {{-1 * CI->getSExtValue(), nullptr}, {1, Op0}}; + if (match(V, m_NUWMul(m_Value(Op1), m_ConstantInt(CI))) && canUseSExt(CI) && + (!CI->isNegative())) { + auto Result = decompose(Op1, 
Preconditions, IsSigned, DL); + Result.mul(CI->getSExtValue()); + return Result; + } + + if (match(V, m_NUWSub(m_Value(Op0), m_ConstantInt(CI))) && canUseSExt(CI)) + return {-1 * CI->getSExtValue(), {{1, Op0}}}; if (match(V, m_NUWSub(m_Value(Op0), m_Value(Op1)))) - return {{0, nullptr}, {1, Op0}, {-1, Op1}}; + return {0, {{1, Op0}, {-1, Op1}}}; - return {{0, nullptr}, {1, V}}; + return {V, IsKnownNonNegative}; } ConstraintTy ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, - DenseMap<Value *, unsigned> &NewIndices) const { + SmallVectorImpl<Value *> &NewVariables) const { + assert(NewVariables.empty() && "NewVariables must be empty when passed in"); bool IsEq = false; // Try to convert Pred to one of ULE/SLT/SLE/SLT. switch (Pred) { @@ -305,7 +418,6 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, break; } - // Only ULE and ULT predicates are supported at the moment. if (Pred != CmpInst::ICMP_ULE && Pred != CmpInst::ICMP_ULT && Pred != CmpInst::ICMP_SLE && Pred != CmpInst::ICMP_SLT) return {}; @@ -314,49 +426,58 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, bool IsSigned = CmpInst::isSigned(Pred); auto &Value2Index = getValue2Index(IsSigned); auto ADec = decompose(Op0->stripPointerCastsSameRepresentation(), - Preconditions, IsSigned); + Preconditions, IsSigned, DL); auto BDec = decompose(Op1->stripPointerCastsSameRepresentation(), - Preconditions, IsSigned); - // Skip if decomposing either of the values failed. - if (ADec.empty() || BDec.empty()) - return {}; - - int64_t Offset1 = ADec[0].first; - int64_t Offset2 = BDec[0].first; + Preconditions, IsSigned, DL); + int64_t Offset1 = ADec.Offset; + int64_t Offset2 = BDec.Offset; Offset1 *= -1; - // Create iterator ranges that skip the constant-factor. - auto VariablesA = llvm::drop_begin(ADec); - auto VariablesB = llvm::drop_begin(BDec); + auto &VariablesA = ADec.Vars; + auto &VariablesB = BDec.Vars; - // First try to look up \p V in Value2Index and NewIndices. Otherwise add a - // new entry to NewIndices. - auto GetOrAddIndex = [&Value2Index, &NewIndices](Value *V) -> unsigned { + // First try to look up \p V in Value2Index and NewVariables. Otherwise add a + // new entry to NewVariables. + DenseMap<Value *, unsigned> NewIndexMap; + auto GetOrAddIndex = [&Value2Index, &NewVariables, + &NewIndexMap](Value *V) -> unsigned { auto V2I = Value2Index.find(V); if (V2I != Value2Index.end()) return V2I->second; auto Insert = - NewIndices.insert({V, Value2Index.size() + NewIndices.size() + 1}); + NewIndexMap.insert({V, Value2Index.size() + NewVariables.size() + 1}); + if (Insert.second) + NewVariables.push_back(V); return Insert.first->second; }; - // Make sure all variables have entries in Value2Index or NewIndices. - for (const auto &KV : - concat<std::pair<int64_t, Value *>>(VariablesA, VariablesB)) - GetOrAddIndex(KV.second); + // Make sure all variables have entries in Value2Index or NewVariables. + for (const auto &KV : concat<DecompEntry>(VariablesA, VariablesB)) + GetOrAddIndex(KV.Variable); // Build result constraint, by first adding all coefficients from A and then // subtracting all coefficients from B. ConstraintTy Res( - SmallVector<int64_t, 8>(Value2Index.size() + NewIndices.size() + 1, 0), + SmallVector<int64_t, 8>(Value2Index.size() + NewVariables.size() + 1, 0), IsSigned); + // Collect variables that are known to be positive in all uses in the + // constraint. 
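// Another illustrative decomposition, assuming the matchers above apply: for
//   %s = shl nuw i64 %a, 3
//   %v = add nuw i64 %s, 16
// decompose(%v) merges decompose(%s) = 8 * %a with the constant 16, giving
// 16 + 8 * %a; a zero-extended operand would in addition be flagged
// IsKnownNonNegative, which feeds the extra ">= 0" rows built below.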
+ DenseMap<Value *, bool> KnownNonNegativeVariables; Res.IsEq = IsEq; auto &R = Res.Coefficients; - for (const auto &KV : VariablesA) - R[GetOrAddIndex(KV.second)] += KV.first; + for (const auto &KV : VariablesA) { + R[GetOrAddIndex(KV.Variable)] += KV.Coefficient; + auto I = + KnownNonNegativeVariables.insert({KV.Variable, KV.IsKnownNonNegative}); + I.first->second &= KV.IsKnownNonNegative; + } - for (const auto &KV : VariablesB) - R[GetOrAddIndex(KV.second)] -= KV.first; + for (const auto &KV : VariablesB) { + R[GetOrAddIndex(KV.Variable)] -= KV.Coefficient; + auto I = + KnownNonNegativeVariables.insert({KV.Variable, KV.IsKnownNonNegative}); + I.first->second &= KV.IsKnownNonNegative; + } int64_t OffsetSum; if (AddOverflow(Offset1, Offset2, OffsetSum)) @@ -366,9 +487,48 @@ ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, return {}; R[0] = OffsetSum; Res.Preconditions = std::move(Preconditions); + + // Remove any (Coefficient, Variable) entry where the Coefficient is 0 for new + // variables. + while (!NewVariables.empty()) { + int64_t Last = R.back(); + if (Last != 0) + break; + R.pop_back(); + Value *RemovedV = NewVariables.pop_back_val(); + NewIndexMap.erase(RemovedV); + } + + // Add extra constraints for variables that are known positive. + for (auto &KV : KnownNonNegativeVariables) { + if (!KV.second || (Value2Index.find(KV.first) == Value2Index.end() && + NewIndexMap.find(KV.first) == NewIndexMap.end())) + continue; + SmallVector<int64_t, 8> C(Value2Index.size() + NewVariables.size() + 1, 0); + C[GetOrAddIndex(KV.first)] = -1; + Res.ExtraInfo.push_back(C); + } return Res; } +ConstraintTy ConstraintInfo::getConstraintForSolving(CmpInst::Predicate Pred, + Value *Op0, + Value *Op1) const { + // If both operands are known to be non-negative, change signed predicates to + // unsigned ones. This increases the reasoning effectiveness in combination + // with the signed <-> unsigned transfer logic. + if (CmpInst::isSigned(Pred) && + isKnownNonNegative(Op0, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1) && + isKnownNonNegative(Op1, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1)) + Pred = CmpInst::getUnsignedPredicate(Pred); + + SmallVector<Value *> NewVariables; + ConstraintTy R = getConstraint(Pred, Op0, Op1, NewVariables); + if (R.IsEq || !NewVariables.empty()) + return {}; + return R; +} + bool ConstraintTy::isValid(const ConstraintInfo &Info) const { return Coefficients.size() > 0 && all_of(Preconditions, [&Info](const PreconditionTy &C) { @@ -378,20 +538,13 @@ bool ConstraintTy::isValid(const ConstraintInfo &Info) const { bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A, Value *B) const { - DenseMap<Value *, unsigned> NewIndices; - auto R = getConstraint(Pred, A, B, NewIndices); - - if (!NewIndices.empty()) - return false; - - // TODO: properly check NewIndices. - return NewIndices.empty() && R.Preconditions.empty() && !R.IsEq && - !R.empty() && - getCS(CmpInst::isSigned(Pred)).isConditionImplied(R.Coefficients); + auto R = getConstraintForSolving(Pred, A, B); + return R.Preconditions.empty() && !R.empty() && + getCS(R.IsSigned).isConditionImplied(R.Coefficients); } void ConstraintInfo::transferToOtherSystem( - CmpInst::Predicate Pred, Value *A, Value *B, bool IsNegated, unsigned NumIn, + CmpInst::Predicate Pred, Value *A, Value *B, unsigned NumIn, unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack) { // Check if we can combine facts from the signed and unsigned systems to // derive additional facts. 
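// For instance (illustrative): given
//   %a = zext i8 %x to i32
//   %b = zext i8 %y to i32
//   %c = icmp slt i32 %a, %b
// both operands are known non-negative, so getConstraintForSolving rewrites the
// query to the unsigned form "%a u< %b" via CmpInst::getUnsignedPredicate and
// can answer it from the unsigned system together with the transfer logic below.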
@@ -406,53 +559,69 @@ void ConstraintInfo::transferToOtherSystem( case CmpInst::ICMP_ULT: // If B is a signed positive constant, A >=s 0 and A <s B. if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) { - addFact(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0), - IsNegated, NumIn, NumOut, DFSInStack); - addFact(CmpInst::ICMP_SLT, A, B, IsNegated, NumIn, NumOut, DFSInStack); + addFact(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0), NumIn, + NumOut, DFSInStack); + addFact(CmpInst::ICMP_SLT, A, B, NumIn, NumOut, DFSInStack); } break; case CmpInst::ICMP_SLT: if (doesHold(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0))) - addFact(CmpInst::ICMP_ULT, A, B, IsNegated, NumIn, NumOut, DFSInStack); + addFact(CmpInst::ICMP_ULT, A, B, NumIn, NumOut, DFSInStack); break; case CmpInst::ICMP_SGT: if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), -1))) - addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), - IsNegated, NumIn, NumOut, DFSInStack); + addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), NumIn, + NumOut, DFSInStack); break; case CmpInst::ICMP_SGE: if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) { - addFact(CmpInst::ICMP_UGE, A, B, IsNegated, NumIn, NumOut, DFSInStack); + addFact(CmpInst::ICMP_UGE, A, B, NumIn, NumOut, DFSInStack); } break; } } namespace { -/// Represents either a condition that holds on entry to a block or a basic -/// block, with their respective Dominator DFS in and out numbers. -struct ConstraintOrBlock { +/// Represents either +/// * a condition that holds on entry to a block (=conditional fact) +/// * an assume (=assume fact) +/// * an instruction to simplify. +/// It also tracks the Dominator DFS in and out numbers for each entry. +struct FactOrCheck { + Instruction *Inst; unsigned NumIn; unsigned NumOut; - bool IsBlock; + bool IsCheck; bool Not; - union { - BasicBlock *BB; - CmpInst *Condition; - }; - ConstraintOrBlock(DomTreeNode *DTN) - : NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), IsBlock(true), - BB(DTN->getBlock()) {} - ConstraintOrBlock(DomTreeNode *DTN, CmpInst *Condition, bool Not) - : NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), IsBlock(false), - Not(Not), Condition(Condition) {} + FactOrCheck(DomTreeNode *DTN, Instruction *Inst, bool IsCheck, bool Not) + : Inst(Inst), NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), + IsCheck(IsCheck), Not(Not) {} + + static FactOrCheck getFact(DomTreeNode *DTN, Instruction *Inst, + bool Not = false) { + return FactOrCheck(DTN, Inst, false, Not); + } + + static FactOrCheck getCheck(DomTreeNode *DTN, Instruction *Inst) { + return FactOrCheck(DTN, Inst, true, false); + } + + bool isAssumeFact() const { + if (!IsCheck && isa<IntrinsicInst>(Inst)) { + assert(match(Inst, m_Intrinsic<Intrinsic::assume>())); + return true; + } + return false; + } + + bool isConditionFact() const { return !IsCheck && isa<CmpInst>(Inst); } }; /// Keep state required to build worklist. struct State { DominatorTree &DT; - SmallVector<ConstraintOrBlock, 64> WorkList; + SmallVector<FactOrCheck, 64> WorkList; State(DominatorTree &DT) : DT(DT) {} @@ -460,19 +629,9 @@ struct State { void addInfoFor(BasicBlock &BB); /// Returns true if we can add a known condition from BB to its successor - /// block Succ. Each predecessor of Succ can either be BB or be dominated - /// by Succ (e.g. the case when adding a condition from a pre-header to a - /// loop header). + /// block Succ. 
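// Concrete transfer example (illustrative): once the unsigned fact "%a u< %b"
// is added and the signed system already knows "%b >=s 0", the ICMP_ULT case
// above also records "%a >=s 0" and "%a <s %b", so later signed queries about
// %a can be answered even though the original fact was unsigned.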
bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const { - if (BB.getSingleSuccessor()) { - assert(BB.getSingleSuccessor() == Succ); - return DT.properlyDominates(&BB, Succ); - } - return any_of(successors(&BB), - [Succ](const BasicBlock *S) { return S != Succ; }) && - all_of(predecessors(Succ), [&BB, Succ, this](BasicBlock *Pred) { - return Pred == &BB || DT.dominates(Succ, Pred); - }); + return DT.dominates(BasicBlockEdge(&BB, Succ), Succ); } }; @@ -497,16 +656,20 @@ static void dumpWithNames(ArrayRef<int64_t> C, #endif void State::addInfoFor(BasicBlock &BB) { - WorkList.emplace_back(DT.getNode(&BB)); - // True as long as long as the current instruction is guaranteed to execute. bool GuaranteedToExecute = true; - // Scan BB for assume calls. - // TODO: also use this scan to queue conditions to simplify, so we can - // interleave facts from assumes and conditions to simplify in a single - // basic block. And to skip another traversal of each basic block when - // simplifying. + // Queue conditions and assumes. for (Instruction &I : BB) { + if (auto Cmp = dyn_cast<ICmpInst>(&I)) { + WorkList.push_back(FactOrCheck::getCheck(DT.getNode(&BB), Cmp)); + continue; + } + + if (match(&I, m_Intrinsic<Intrinsic::ssub_with_overflow>())) { + WorkList.push_back(FactOrCheck::getCheck(DT.getNode(&BB), &I)); + continue; + } + Value *Cond; // For now, just handle assumes with a single compare as condition. if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) && @@ -514,14 +677,11 @@ void State::addInfoFor(BasicBlock &BB) { if (GuaranteedToExecute) { // The assume is guaranteed to execute when BB is entered, hence Cond // holds on entry to BB. - WorkList.emplace_back(DT.getNode(&BB), cast<ICmpInst>(Cond), false); + WorkList.emplace_back(FactOrCheck::getFact(DT.getNode(I.getParent()), + cast<Instruction>(Cond))); } else { - // Otherwise the condition only holds in the successors. - for (BasicBlock *Succ : successors(&BB)) { - if (!canAddSuccessor(BB, Succ)) - continue; - WorkList.emplace_back(DT.getNode(Succ), cast<ICmpInst>(Cond), false); - } + WorkList.emplace_back( + FactOrCheck::getFact(DT.getNode(I.getParent()), &I)); } } GuaranteedToExecute &= isGuaranteedToTransferExecutionToSuccessor(&I); @@ -531,33 +691,48 @@ void State::addInfoFor(BasicBlock &BB) { if (!Br || !Br->isConditional()) return; - // If the condition is an OR of 2 compares and the false successor only has - // the current block as predecessor, queue both negated conditions for the - // false successor. - Value *Op0, *Op1; - if (match(Br->getCondition(), m_LogicalOr(m_Value(Op0), m_Value(Op1))) && - isa<ICmpInst>(Op0) && isa<ICmpInst>(Op1)) { - BasicBlock *FalseSuccessor = Br->getSuccessor(1); - if (canAddSuccessor(BB, FalseSuccessor)) { - WorkList.emplace_back(DT.getNode(FalseSuccessor), cast<ICmpInst>(Op0), - true); - WorkList.emplace_back(DT.getNode(FalseSuccessor), cast<ICmpInst>(Op1), - true); - } - return; - } + Value *Cond = Br->getCondition(); - // If the condition is an AND of 2 compares and the true successor only has - // the current block as predecessor, queue both conditions for the true - // successor. 
- if (match(Br->getCondition(), m_LogicalAnd(m_Value(Op0), m_Value(Op1))) && - isa<ICmpInst>(Op0) && isa<ICmpInst>(Op1)) { - BasicBlock *TrueSuccessor = Br->getSuccessor(0); - if (canAddSuccessor(BB, TrueSuccessor)) { - WorkList.emplace_back(DT.getNode(TrueSuccessor), cast<ICmpInst>(Op0), - false); - WorkList.emplace_back(DT.getNode(TrueSuccessor), cast<ICmpInst>(Op1), - false); + // If the condition is a chain of ORs/AND and the successor only has the + // current block as predecessor, queue conditions for the successor. + Value *Op0, *Op1; + if (match(Cond, m_LogicalOr(m_Value(Op0), m_Value(Op1))) || + match(Cond, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) { + bool IsOr = match(Cond, m_LogicalOr()); + bool IsAnd = match(Cond, m_LogicalAnd()); + // If there's a select that matches both AND and OR, we need to commit to + // one of the options. Arbitrarily pick OR. + if (IsOr && IsAnd) + IsAnd = false; + + BasicBlock *Successor = Br->getSuccessor(IsOr ? 1 : 0); + if (canAddSuccessor(BB, Successor)) { + SmallVector<Value *> CondWorkList; + SmallPtrSet<Value *, 8> SeenCond; + auto QueueValue = [&CondWorkList, &SeenCond](Value *V) { + if (SeenCond.insert(V).second) + CondWorkList.push_back(V); + }; + QueueValue(Op1); + QueueValue(Op0); + while (!CondWorkList.empty()) { + Value *Cur = CondWorkList.pop_back_val(); + if (auto *Cmp = dyn_cast<ICmpInst>(Cur)) { + WorkList.emplace_back( + FactOrCheck::getFact(DT.getNode(Successor), Cmp, IsOr)); + continue; + } + if (IsOr && match(Cur, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) { + QueueValue(Op1); + QueueValue(Op0); + continue; + } + if (IsAnd && match(Cur, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) { + QueueValue(Op1); + QueueValue(Op0); + continue; + } + } } return; } @@ -566,47 +741,113 @@ void State::addInfoFor(BasicBlock &BB) { if (!CmpI) return; if (canAddSuccessor(BB, Br->getSuccessor(0))) - WorkList.emplace_back(DT.getNode(Br->getSuccessor(0)), CmpI, false); + WorkList.emplace_back( + FactOrCheck::getFact(DT.getNode(Br->getSuccessor(0)), CmpI)); if (canAddSuccessor(BB, Br->getSuccessor(1))) - WorkList.emplace_back(DT.getNode(Br->getSuccessor(1)), CmpI, true); + WorkList.emplace_back( + FactOrCheck::getFact(DT.getNode(Br->getSuccessor(1)), CmpI, true)); +} + +static bool checkAndReplaceCondition(CmpInst *Cmp, ConstraintInfo &Info) { + LLVM_DEBUG(dbgs() << "Checking " << *Cmp << "\n"); + + CmpInst::Predicate Pred = Cmp->getPredicate(); + Value *A = Cmp->getOperand(0); + Value *B = Cmp->getOperand(1); + + auto R = Info.getConstraintForSolving(Pred, A, B); + if (R.empty() || !R.isValid(Info)){ + LLVM_DEBUG(dbgs() << " failed to decompose condition\n"); + return false; + } + + auto &CSToUse = Info.getCS(R.IsSigned); + + // If there was extra information collected during decomposition, apply + // it now and remove it immediately once we are done with reasoning + // about the constraint. 
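// Example of the chain handling above (illustrative): for
//   %c12 = and i1 %c1, %c2
//   %c   = and i1 %c12, %c3
//   br i1 %c, label %then, label %else
// all of %c1, %c2 and %c3 are queued as facts that hold on entry to %then; for
// an 'or' chain the compares are queued for the false successor with the Not
// flag set, i.e. their negations hold there.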
+ for (auto &Row : R.ExtraInfo) + CSToUse.addVariableRow(Row); + auto InfoRestorer = make_scope_exit([&]() { + for (unsigned I = 0; I < R.ExtraInfo.size(); ++I) + CSToUse.popLastConstraint(); + }); + + bool Changed = false; + if (CSToUse.isConditionImplied(R.Coefficients)) { + if (!DebugCounter::shouldExecute(EliminatedCounter)) + return false; + + LLVM_DEBUG({ + dbgs() << "Condition " << *Cmp << " implied by dominating constraints\n"; + dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); + }); + Constant *TrueC = + ConstantInt::getTrue(CmpInst::makeCmpResultType(Cmp->getType())); + Cmp->replaceUsesWithIf(TrueC, [](Use &U) { + // Conditions in an assume trivially simplify to true. Skip uses + // in assume calls to not destroy the available information. + auto *II = dyn_cast<IntrinsicInst>(U.getUser()); + return !II || II->getIntrinsicID() != Intrinsic::assume; + }); + NumCondsRemoved++; + Changed = true; + } + if (CSToUse.isConditionImplied(ConstraintSystem::negate(R.Coefficients))) { + if (!DebugCounter::shouldExecute(EliminatedCounter)) + return false; + + LLVM_DEBUG({ + dbgs() << "Condition !" << *Cmp << " implied by dominating constraints\n"; + dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); + }); + Constant *FalseC = + ConstantInt::getFalse(CmpInst::makeCmpResultType(Cmp->getType())); + Cmp->replaceAllUsesWith(FalseC); + NumCondsRemoved++; + Changed = true; + } + return Changed; } void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B, - bool IsNegated, unsigned NumIn, unsigned NumOut, + unsigned NumIn, unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack) { // If the constraint has a pre-condition, skip the constraint if it does not // hold. - DenseMap<Value *, unsigned> NewIndices; - auto R = getConstraint(Pred, A, B, NewIndices); + SmallVector<Value *> NewVariables; + auto R = getConstraint(Pred, A, B, NewVariables); if (!R.isValid(*this)) return; - //LLVM_DEBUG(dbgs() << "Adding " << *Condition << " " << IsNegated << "\n"); + LLVM_DEBUG(dbgs() << "Adding '" << CmpInst::getPredicateName(Pred) << " "; + A->printAsOperand(dbgs(), false); dbgs() << ", "; + B->printAsOperand(dbgs(), false); dbgs() << "'\n"); bool Added = false; - assert(CmpInst::isSigned(Pred) == R.IsSigned && - "condition and constraint signs must match"); auto &CSToUse = getCS(R.IsSigned); if (R.Coefficients.empty()) return; Added |= CSToUse.addVariableRowFill(R.Coefficients); - // If R has been added to the system, queue it for removal once it goes - // out-of-scope. + // If R has been added to the system, add the new variables and queue it for + // removal once it goes out-of-scope. if (Added) { SmallVector<Value *, 2> ValuesToRelease; - for (auto &KV : NewIndices) { - getValue2Index(R.IsSigned).insert(KV); - ValuesToRelease.push_back(KV.first); + auto &Value2Index = getValue2Index(R.IsSigned); + for (Value *V : NewVariables) { + Value2Index.insert({V, Value2Index.size() + 1}); + ValuesToRelease.push_back(V); } LLVM_DEBUG({ dbgs() << " constraint: "; dumpWithNames(R.Coefficients, getValue2Index(R.IsSigned)); + dbgs() << "\n"; }); - DFSInStack.emplace_back(NumIn, NumOut, IsNegated, R.IsSigned, - ValuesToRelease); + DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned, + std::move(ValuesToRelease)); if (R.IsEq) { // Also add the inverted constraint for equality constraints. 
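// Putting checkAndReplaceCondition together (illustrative):
//   br i1 %cond, label %then, label %exit   ; %cond = icmp ult i32 %x, %y
// inside %then the fact "%x u< %y" is in the system, so a later
//   %t = icmp ult i32 %x, %y
// is implied and replaced by 'true' (uses inside llvm.assume are skipped); if
// instead the negated row were implied, %t would be folded to 'false'.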
@@ -614,26 +855,58 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B, Coeff *= -1; CSToUse.addVariableRowFill(R.Coefficients); - DFSInStack.emplace_back(NumIn, NumOut, IsNegated, R.IsSigned, + DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned, SmallVector<Value *, 2>()); } } } -static void +static bool replaceSubOverflowUses(IntrinsicInst *II, Value *A, Value *B, + SmallVectorImpl<Instruction *> &ToRemove) { + bool Changed = false; + IRBuilder<> Builder(II->getParent(), II->getIterator()); + Value *Sub = nullptr; + for (User *U : make_early_inc_range(II->users())) { + if (match(U, m_ExtractValue<0>(m_Value()))) { + if (!Sub) + Sub = Builder.CreateSub(A, B); + U->replaceAllUsesWith(Sub); + Changed = true; + } else if (match(U, m_ExtractValue<1>(m_Value()))) { + U->replaceAllUsesWith(Builder.getFalse()); + Changed = true; + } else + continue; + + if (U->use_empty()) { + auto *I = cast<Instruction>(U); + ToRemove.push_back(I); + I->setOperand(0, PoisonValue::get(II->getType())); + Changed = true; + } + } + + if (II->use_empty()) { + II->eraseFromParent(); + Changed = true; + } + return Changed; +} + +static bool tryToSimplifyOverflowMath(IntrinsicInst *II, ConstraintInfo &Info, SmallVectorImpl<Instruction *> &ToRemove) { auto DoesConditionHold = [](CmpInst::Predicate Pred, Value *A, Value *B, ConstraintInfo &Info) { - DenseMap<Value *, unsigned> NewIndices; - auto R = Info.getConstraint(Pred, A, B, NewIndices); - if (R.size() < 2 || R.needsNewIndices(NewIndices) || !R.isValid(Info)) + auto R = Info.getConstraintForSolving(Pred, A, B); + if (R.size() < 2 || !R.isValid(Info)) return false; - auto &CSToUse = Info.getCS(CmpInst::isSigned(Pred)); + auto &CSToUse = Info.getCS(R.IsSigned); return CSToUse.isConditionImplied(R.Coefficients); }; + bool Changed = false; if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow) { // If A s>= B && B s>= 0, ssub.with.overflow(a, b) should not overflow and // can be simplified to a regular sub. @@ -642,37 +915,17 @@ tryToSimplifyOverflowMath(IntrinsicInst *II, ConstraintInfo &Info, if (!DoesConditionHold(CmpInst::ICMP_SGE, A, B, Info) || !DoesConditionHold(CmpInst::ICMP_SGE, B, ConstantInt::get(A->getType(), 0), Info)) - return; - - IRBuilder<> Builder(II->getParent(), II->getIterator()); - Value *Sub = nullptr; - for (User *U : make_early_inc_range(II->users())) { - if (match(U, m_ExtractValue<0>(m_Value()))) { - if (!Sub) - Sub = Builder.CreateSub(A, B); - U->replaceAllUsesWith(Sub); - } else if (match(U, m_ExtractValue<1>(m_Value()))) - U->replaceAllUsesWith(Builder.getFalse()); - else - continue; - - if (U->use_empty()) { - auto *I = cast<Instruction>(U); - ToRemove.push_back(I); - I->setOperand(0, PoisonValue::get(II->getType())); - } - } - - if (II->use_empty()) - II->eraseFromParent(); + return false; + Changed = replaceSubOverflowUses(II, A, B, ToRemove); } + return Changed; } static bool eliminateConstraints(Function &F, DominatorTree &DT) { bool Changed = false; DT.updateDFSNumbers(); - ConstraintInfo Info; + ConstraintInfo Info(F.getParent()->getDataLayout()); State S(DT); // First, collect conditions implied by branches and blocks with their @@ -683,19 +936,41 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { S.addInfoFor(BB); } - // Next, sort worklist by dominance, so that dominating blocks and conditions - // come before blocks and conditions dominated by them. If a block and a - // condition have the same numbers, the condition comes before the block, as - // it holds on entry to the block. 
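// Illustrative instance of the overflow simplification above: when the system
// proves "%a >=s %b" and "%b >=s 0",
//   %wo = call { i32, i1 } @llvm.ssub.with.overflow.i32(i32 %a, i32 %b)
// cannot overflow, so extractvalue-0 users are rewritten to "sub i32 %a, %b",
// extractvalue-1 users to 'false', and the then-dead intrinsic is erased.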
- stable_sort(S.WorkList, [](const ConstraintOrBlock &A, const ConstraintOrBlock &B) { - return std::tie(A.NumIn, A.IsBlock) < std::tie(B.NumIn, B.IsBlock); + // Next, sort worklist by dominance, so that dominating conditions to check + // and facts come before conditions and facts dominated by them. If a + // condition to check and a fact have the same numbers, conditional facts come + // first. Assume facts and checks are ordered according to their relative + // order in the containing basic block. Also make sure conditions with + // constant operands come before conditions without constant operands. This + // increases the effectiveness of the current signed <-> unsigned fact + // transfer logic. + stable_sort(S.WorkList, [](const FactOrCheck &A, const FactOrCheck &B) { + auto HasNoConstOp = [](const FactOrCheck &B) { + return !isa<ConstantInt>(B.Inst->getOperand(0)) && + !isa<ConstantInt>(B.Inst->getOperand(1)); + }; + // If both entries have the same In numbers, conditional facts come first. + // Otherwise use the relative order in the basic block. + if (A.NumIn == B.NumIn) { + if (A.isConditionFact() && B.isConditionFact()) { + bool NoConstOpA = HasNoConstOp(A); + bool NoConstOpB = HasNoConstOp(B); + return NoConstOpA < NoConstOpB; + } + if (A.isConditionFact()) + return true; + if (B.isConditionFact()) + return false; + return A.Inst->comesBefore(B.Inst); + } + return A.NumIn < B.NumIn; }); SmallVector<Instruction *> ToRemove; // Finally, process ordered worklist and eliminate implied conditions. SmallVector<StackEntry, 16> DFSInStack; - for (ConstraintOrBlock &CB : S.WorkList) { + for (FactOrCheck &CB : S.WorkList) { // First, pop entries from the stack that are out-of-scope for CB. Remove // the corresponding entry from the constraint system. while (!DFSInStack.empty()) { @@ -724,94 +999,42 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { LLVM_DEBUG({ dbgs() << "Processing "; - if (CB.IsBlock) - dbgs() << *CB.BB; + if (CB.IsCheck) + dbgs() << "condition to simplify: " << *CB.Inst; else - dbgs() << *CB.Condition; + dbgs() << "fact to add to the system: " << *CB.Inst; dbgs() << "\n"; }); // For a block, check if any CmpInsts become known based on the current set // of constraints. - if (CB.IsBlock) { - for (Instruction &I : make_early_inc_range(*CB.BB)) { - if (auto *II = dyn_cast<WithOverflowInst>(&I)) { - tryToSimplifyOverflowMath(II, Info, ToRemove); - continue; - } - auto *Cmp = dyn_cast<ICmpInst>(&I); - if (!Cmp) - continue; - - DenseMap<Value *, unsigned> NewIndices; - auto R = Info.getConstraint(Cmp, NewIndices); - if (R.IsEq || R.empty() || R.needsNewIndices(NewIndices) || - !R.isValid(Info)) - continue; - - auto &CSToUse = Info.getCS(R.IsSigned); - if (CSToUse.isConditionImplied(R.Coefficients)) { - if (!DebugCounter::shouldExecute(EliminatedCounter)) - continue; - - LLVM_DEBUG({ - dbgs() << "Condition " << *Cmp - << " implied by dominating constraints\n"; - dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); - }); - Cmp->replaceUsesWithIf( - ConstantInt::getTrue(F.getParent()->getContext()), [](Use &U) { - // Conditions in an assume trivially simplify to true. Skip uses - // in assume calls to not destroy the available information. 
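// Ordering sketch (illustrative): among worklist entries with the same DFS-in
// number, conditional facts are visited before assume facts and checks, assume
// facts and checks keep their relative order within the basic block, and
// conditional facts with a constant operand are added before those without
// one, which helps the signed <-> unsigned transfer fire more often.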
- auto *II = dyn_cast<IntrinsicInst>(U.getUser()); - return !II || II->getIntrinsicID() != Intrinsic::assume; - }); - NumCondsRemoved++; - Changed = true; - } - if (CSToUse.isConditionImplied( - ConstraintSystem::negate(R.Coefficients))) { - if (!DebugCounter::shouldExecute(EliminatedCounter)) - continue; - - LLVM_DEBUG({ - dbgs() << "Condition !" << *Cmp - << " implied by dominating constraints\n"; - dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); - }); - Cmp->replaceAllUsesWith( - ConstantInt::getFalse(F.getParent()->getContext())); - NumCondsRemoved++; - Changed = true; - } + if (CB.IsCheck) { + if (auto *II = dyn_cast<WithOverflowInst>(CB.Inst)) { + Changed |= tryToSimplifyOverflowMath(II, Info, ToRemove); + } else if (auto *Cmp = dyn_cast<ICmpInst>(CB.Inst)) { + Changed |= checkAndReplaceCondition(Cmp, Info); } continue; } - // Set up a function to restore the predicate at the end of the scope if it - // has been negated. Negate the predicate in-place, if required. - auto *CI = dyn_cast<ICmpInst>(CB.Condition); - auto PredicateRestorer = make_scope_exit([CI, &CB]() { - if (CB.Not && CI) - CI->setPredicate(CI->getInversePredicate()); - }); - if (CB.Not) { - if (CI) { - CI->setPredicate(CI->getInversePredicate()); - } else { - LLVM_DEBUG(dbgs() << "Can only negate compares so far.\n"); + ICmpInst::Predicate Pred; + Value *A, *B; + Value *Cmp = CB.Inst; + match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp))); + if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) { + if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) { + LLVM_DEBUG( + dbgs() + << "Skip adding constraint because system has too many rows.\n"); continue; } - } - ICmpInst::Predicate Pred; - Value *A, *B; - if (match(CB.Condition, m_ICmp(Pred, m_Value(A), m_Value(B)))) { - // Otherwise, add the condition to the system and stack, if we can - // transform it into a constraint. - Info.addFact(Pred, A, B, CB.Not, CB.NumIn, CB.NumOut, DFSInStack); - Info.transferToOtherSystem(Pred, A, B, CB.Not, CB.NumIn, CB.NumOut, - DFSInStack); + // Use the inverse predicate if required. 
+ if (CB.Not) + Pred = CmpInst::getInversePredicate(Pred); + + Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack); + Info.transferToOtherSystem(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack); } } @@ -840,41 +1063,3 @@ PreservedAnalyses ConstraintEliminationPass::run(Function &F, PA.preserveSet<CFGAnalyses>(); return PA; } - -namespace { - -class ConstraintElimination : public FunctionPass { -public: - static char ID; - - ConstraintElimination() : FunctionPass(ID) { - initializeConstraintEliminationPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - return eliminateConstraints(F, DT); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - } -}; - -} // end anonymous namespace - -char ConstraintElimination::ID = 0; - -INITIALIZE_PASS_BEGIN(ConstraintElimination, "constraint-elimination", - "Constraint Elimination", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) -INITIALIZE_PASS_END(ConstraintElimination, "constraint-elimination", - "Constraint Elimination", false, false) - -FunctionPass *llvm::createConstraintEliminationPass() { - return new ConstraintElimination(); -} diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 64bd4241f37c..90b4b521e7de 100644 --- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -12,7 +12,6 @@ #include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h" #include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/DomTreeUpdater.h" @@ -44,6 +43,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> +#include <optional> #include <utility> using namespace llvm; @@ -94,6 +94,8 @@ STATISTIC(NumSaturating, "Number of saturating arithmetics converted to normal arithmetics"); STATISTIC(NumNonNull, "Number of function pointer arguments marked non-null"); STATISTIC(NumMinMax, "Number of llvm.[us]{min,max} intrinsics removed"); +STATISTIC(NumUDivURemsNarrowedExpanded, + "Number of bound udiv's/urem's expanded"); namespace { @@ -340,18 +342,16 @@ static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) { /// exploiting range information. static bool constantFoldCmp(CmpInst *Cmp, LazyValueInfo *LVI) { Value *Op0 = Cmp->getOperand(0); - auto *C = dyn_cast<Constant>(Cmp->getOperand(1)); - if (!C) - return false; - + Value *Op1 = Cmp->getOperand(1); LazyValueInfo::Tristate Result = - LVI->getPredicateAt(Cmp->getPredicate(), Op0, C, Cmp, + LVI->getPredicateAt(Cmp->getPredicate(), Op0, Op1, Cmp, /*UseBlockValue=*/true); if (Result == LazyValueInfo::Unknown) return false; ++NumCmps; - Constant *TorF = ConstantInt::get(Type::getInt1Ty(Cmp->getContext()), Result); + Constant *TorF = + ConstantInt::get(CmpInst::makeCmpResultType(Op0->getType()), Result); Cmp->replaceAllUsesWith(TorF); Cmp->eraseFromParent(); return true; @@ -439,8 +439,8 @@ static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI, // See if we can prove that the given binary op intrinsic will not overflow. 
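// Example for the constantFoldCmp change above (illustrative): the predicate
// query no longer needs a constant right-hand side, so
//   %c = icmp ult i32 %x, %y
// folds to 'true' when block-level ranges prove, say, %x in [0, 8) and
// %y in [8, 16); the replacement constant is built with makeCmpResultType, so
// the same code also produces the right type for vector compares.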
static bool willNotOverflow(BinaryOpIntrinsic *BO, LazyValueInfo *LVI) { - ConstantRange LRange = LVI->getConstantRange(BO->getLHS(), BO); - ConstantRange RRange = LVI->getConstantRange(BO->getRHS(), BO); + ConstantRange LRange = LVI->getConstantRangeAtUse(BO->getOperandUse(0)); + ConstantRange RRange = LVI->getConstantRangeAtUse(BO->getOperandUse(1)); ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( BO->getBinaryOp(), RRange, BO->getNoWrapKind()); return NWRegion.contains(LRange); @@ -693,55 +693,38 @@ static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) { return true; } -static bool isNonNegative(Value *V, LazyValueInfo *LVI, Instruction *CxtI) { - Constant *Zero = ConstantInt::get(V->getType(), 0); - auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SGE, V, Zero, CxtI, - /*UseBlockValue=*/true); - return Result == LazyValueInfo::True; -} - -static bool isNonPositive(Value *V, LazyValueInfo *LVI, Instruction *CxtI) { - Constant *Zero = ConstantInt::get(V->getType(), 0); - auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SLE, V, Zero, CxtI, - /*UseBlockValue=*/true); - return Result == LazyValueInfo::True; -} - enum class Domain { NonNegative, NonPositive, Unknown }; -Domain getDomain(Value *V, LazyValueInfo *LVI, Instruction *CxtI) { - if (isNonNegative(V, LVI, CxtI)) +static Domain getDomain(const ConstantRange &CR) { + if (CR.isAllNonNegative()) return Domain::NonNegative; - if (isNonPositive(V, LVI, CxtI)) + if (CR.icmp(ICmpInst::ICMP_SLE, APInt::getNullValue(CR.getBitWidth()))) return Domain::NonPositive; return Domain::Unknown; } /// Try to shrink a sdiv/srem's width down to the smallest power of two that's /// sufficient to contain its operands. -static bool narrowSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) { +static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR, + const ConstantRange &RCR) { assert(Instr->getOpcode() == Instruction::SDiv || Instr->getOpcode() == Instruction::SRem); - if (Instr->getType()->isVectorTy()) - return false; + assert(!Instr->getType()->isVectorTy()); // Find the smallest power of two bitwidth that's sufficient to hold Instr's // operands. unsigned OrigWidth = Instr->getType()->getIntegerBitWidth(); - // What is the smallest bit width that can accomodate the entire value ranges + // What is the smallest bit width that can accommodate the entire value ranges // of both of the operands? - std::array<Optional<ConstantRange>, 2> CRs; - unsigned MinSignedBits = 0; - for (auto I : zip(Instr->operands(), CRs)) { - std::get<1>(I) = LVI->getConstantRange(std::get<0>(I), Instr); - MinSignedBits = std::max(std::get<1>(I)->getMinSignedBits(), MinSignedBits); - } + std::array<std::optional<ConstantRange>, 2> CRs; + unsigned MinSignedBits = + std::max(LCR.getMinSignedBits(), RCR.getMinSignedBits()); // sdiv/srem is UB if divisor is -1 and divident is INT_MIN, so unless we can // prove that such a combination is impossible, we need to bump the bitwidth. - if (CRs[1]->contains(APInt::getAllOnes(OrigWidth)) && - CRs[0]->contains(APInt::getSignedMinValue(MinSignedBits).sext(OrigWidth))) + if (RCR.contains(APInt::getAllOnes(OrigWidth)) && + LCR.contains(APInt::getSignedMinValue(MinSignedBits).sext(OrigWidth))) ++MinSignedBits; // Don't shrink below 8 bits wide. 
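// Worked example (illustrative): for "%d = sdiv i32 %a, %b" with ranges
// %a in [0, 100] and %b in [1, 10], getMinSignedBits() is 8 for %a and 5 for
// %b, so MinSignedBits = 8; the divisor range cannot be -1, so no extra bit is
// needed, and the division can be carried out in i8 (the minimum width kept
// here), truncating the operands and sign-extending the result.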
@@ -770,24 +753,91 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) { return true; } +static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, + const ConstantRange &YCR) { + Type *Ty = Instr->getType(); + assert(Instr->getOpcode() == Instruction::UDiv || + Instr->getOpcode() == Instruction::URem); + assert(!Ty->isVectorTy()); + bool IsRem = Instr->getOpcode() == Instruction::URem; + + Value *X = Instr->getOperand(0); + Value *Y = Instr->getOperand(1); + + // X u/ Y -> 0 iff X u< Y + // X u% Y -> X iff X u< Y + if (XCR.icmp(ICmpInst::ICMP_ULT, YCR)) { + Instr->replaceAllUsesWith(IsRem ? X : Constant::getNullValue(Ty)); + Instr->eraseFromParent(); + ++NumUDivURemsNarrowedExpanded; + return true; + } + + // Given + // R = X u% Y + // We can represent the modulo operation as a loop/self-recursion: + // urem_rec(X, Y): + // Z = X - Y + // if X u< Y + // ret X + // else + // ret urem_rec(Z, Y) + // which isn't better, but if we only need a single iteration + // to compute the answer, this becomes quite good: + // R = X < Y ? X : X - Y iff X u< 2*Y (w/ unsigned saturation) + // Now, we do not care about all full multiples of Y in X, they do not change + // the answer, thus we could rewrite the expression as: + // X* = X - (Y * |_ X / Y _|) + // R = X* % Y + // so we don't need the *first* iteration to return, we just need to + // know *which* iteration will always return, so we could also rewrite it as: + // X* = X - (Y * |_ X / Y _|) + // R = X* % Y iff X* u< 2*Y (w/ unsigned saturation) + // but that does not seem profitable here. + + // Even if we don't know X's range, the divisor may be so large, X can't ever + // be 2x larger than that. I.e. if divisor is always negative. + if (!XCR.icmp(ICmpInst::ICMP_ULT, + YCR.umul_sat(APInt(YCR.getBitWidth(), 2))) && + !YCR.isAllNegative()) + return false; + + IRBuilder<> B(Instr); + Value *ExpandedOp; + if (IsRem) { + // NOTE: this transformation introduces two uses of X, + // but it may be undef so we must freeze it first. + Value *FrozenX = B.CreateFreeze(X, X->getName() + ".frozen"); + auto *AdjX = B.CreateNUWSub(FrozenX, Y, Instr->getName() + ".urem"); + auto *Cmp = + B.CreateICmp(ICmpInst::ICMP_ULT, FrozenX, Y, Instr->getName() + ".cmp"); + ExpandedOp = B.CreateSelect(Cmp, FrozenX, AdjX); + } else { + auto *Cmp = + B.CreateICmp(ICmpInst::ICMP_UGE, X, Y, Instr->getName() + ".cmp"); + ExpandedOp = B.CreateZExt(Cmp, Ty, Instr->getName() + ".udiv"); + } + ExpandedOp->takeName(Instr); + Instr->replaceAllUsesWith(ExpandedOp); + Instr->eraseFromParent(); + ++NumUDivURemsNarrowedExpanded; + return true; +} + /// Try to shrink a udiv/urem's width down to the smallest power of two that's /// sufficient to contain its operands. -static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) { +static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, + const ConstantRange &YCR) { assert(Instr->getOpcode() == Instruction::UDiv || Instr->getOpcode() == Instruction::URem); - if (Instr->getType()->isVectorTy()) - return false; + assert(!Instr->getType()->isVectorTy()); // Find the smallest power of two bitwidth that's sufficient to hold Instr's // operands. - // What is the smallest bit width that can accomodate the entire value ranges + // What is the smallest bit width that can accommodate the entire value ranges // of both of the operands? 
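// Concrete instance of the expansion described above (illustrative): when the
// ranges prove X u< 2 * Y, "urem i32 %X, %Y" becomes
//   %f   = freeze i32 %X                     ; %X may be undef, so freeze it
//   %sub = sub nuw i32 %f, %Y
//   %cmp = icmp ult i32 %f, %Y
//   %res = select i1 %cmp, i32 %f, i32 %sub
// and "udiv i32 %X, %Y" under the same bound becomes "zext(icmp uge %X, %Y)",
// i.e. 0 or 1.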
- unsigned MaxActiveBits = 0; - for (Value *Operand : Instr->operands()) { - ConstantRange CR = LVI->getConstantRange(Operand, Instr); - MaxActiveBits = std::max(CR.getActiveBits(), MaxActiveBits); - } + unsigned MaxActiveBits = std::max(XCR.getActiveBits(), YCR.getActiveBits()); // Don't shrink below 8 bits wide. unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MaxActiveBits), 8); @@ -814,24 +864,39 @@ static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) { return true; } -static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { - assert(SDI->getOpcode() == Instruction::SRem); - if (SDI->getType()->isVectorTy()) +static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) { + assert(Instr->getOpcode() == Instruction::UDiv || + Instr->getOpcode() == Instruction::URem); + if (Instr->getType()->isVectorTy()) return false; + ConstantRange XCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(0)); + ConstantRange YCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(1)); + if (expandUDivOrURem(Instr, XCR, YCR)) + return true; + + return narrowUDivOrURem(Instr, XCR, YCR); +} + +static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR, + const ConstantRange &RCR, LazyValueInfo *LVI) { + assert(SDI->getOpcode() == Instruction::SRem); + assert(!SDI->getType()->isVectorTy()); + + if (LCR.abs().icmp(CmpInst::ICMP_ULT, RCR.abs())) { + SDI->replaceAllUsesWith(SDI->getOperand(0)); + SDI->eraseFromParent(); + return true; + } + struct Operand { Value *V; Domain D; }; - std::array<Operand, 2> Ops; - - for (const auto I : zip(Ops, SDI->operands())) { - Operand &Op = std::get<0>(I); - Op.V = std::get<1>(I); - Op.D = getDomain(Op.V, LVI, SDI); - if (Op.D == Domain::Unknown) - return false; - } + std::array<Operand, 2> Ops = {{{SDI->getOperand(0), getDomain(LCR)}, + {SDI->getOperand(1), getDomain(RCR)}}}; + if (Ops[0].D == Domain::Unknown || Ops[1].D == Domain::Unknown) + return false; // We know domains of both of the operands! ++NumSRems; @@ -850,11 +915,13 @@ static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { BinaryOperator::CreateURem(Ops[0].V, Ops[1].V, SDI->getName(), SDI); URem->setDebugLoc(SDI->getDebugLoc()); - Value *Res = URem; + auto *Res = URem; // If the divident was non-positive, we need to negate the result. - if (Ops[0].D == Domain::NonPositive) + if (Ops[0].D == Domain::NonPositive) { Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", SDI); + Res->setDebugLoc(SDI->getDebugLoc()); + } SDI->replaceAllUsesWith(Res); SDI->eraseFromParent(); @@ -870,24 +937,19 @@ static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { /// If this is the case, replace the SDiv with a UDiv. Even for local /// conditions, this can sometimes prove conditions instcombine can't by /// exploiting range information. 
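// Example for the new srem fast path above (illustrative): if %x has range
// [-3, 4) and %y has range [10, 20), every |%x| is smaller than every |%y|,
// so "srem i32 %x, %y" always equals %x and is replaced by its first operand;
// otherwise, when both signs (domains) are known, the srem is rewritten in
// terms of urem, negating the result when the dividend is non-positive.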
-static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { +static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR, + const ConstantRange &RCR, LazyValueInfo *LVI) { assert(SDI->getOpcode() == Instruction::SDiv); - if (SDI->getType()->isVectorTy()) - return false; + assert(!SDI->getType()->isVectorTy()); struct Operand { Value *V; Domain D; }; - std::array<Operand, 2> Ops; - - for (const auto I : zip(Ops, SDI->operands())) { - Operand &Op = std::get<0>(I); - Op.V = std::get<1>(I); - Op.D = getDomain(Op.V, LVI, SDI); - if (Op.D == Domain::Unknown) - return false; - } + std::array<Operand, 2> Ops = {{{SDI->getOperand(0), getDomain(LCR)}, + {SDI->getOperand(1), getDomain(RCR)}}}; + if (Ops[0].D == Domain::Unknown || Ops[1].D == Domain::Unknown) + return false; // We know domains of both of the operands! ++NumSDivs; @@ -928,22 +990,25 @@ static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) { if (Instr->getType()->isVectorTy()) return false; + ConstantRange LCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(0)); + ConstantRange RCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(1)); if (Instr->getOpcode() == Instruction::SDiv) - if (processSDiv(Instr, LVI)) + if (processSDiv(Instr, LCR, RCR, LVI)) return true; - if (Instr->getOpcode() == Instruction::SRem) - if (processSRem(Instr, LVI)) + if (Instr->getOpcode() == Instruction::SRem) { + if (processSRem(Instr, LCR, RCR, LVI)) return true; + } - return narrowSDivOrSRem(Instr, LVI); + return narrowSDivOrSRem(Instr, LCR, RCR); } static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { if (SDI->getType()->isVectorTy()) return false; - ConstantRange LRange = LVI->getConstantRange(SDI->getOperand(0), SDI); + ConstantRange LRange = LVI->getConstantRangeAtUse(SDI->getOperandUse(0)); unsigned OrigWidth = SDI->getType()->getIntegerBitWidth(); ConstantRange NegOneOrZero = ConstantRange(APInt(OrigWidth, (uint64_t)-1, true), APInt(OrigWidth, 1)); @@ -955,7 +1020,7 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { return true; } - if (!isNonNegative(SDI->getOperand(0), LVI, SDI)) + if (!LRange.isAllNonNegative()) return false; ++NumAShrsConverted; @@ -974,9 +1039,8 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) { if (SDI->getType()->isVectorTy()) return false; - Value *Base = SDI->getOperand(0); - - if (!isNonNegative(Base, LVI, SDI)) + const Use &Base = SDI->getOperandUse(0); + if (!LVI->getConstantRangeAtUse(Base).isAllNonNegative()) return false; ++NumSExt; @@ -1033,7 +1097,7 @@ static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) { // Pattern match (and lhs, C) where C includes a superset of bits which might // be set in lhs. This is a common truncation idiom created by instcombine. - Value *LHS = BinOp->getOperand(0); + const Use &LHS = BinOp->getOperandUse(0); ConstantInt *RHS = dyn_cast<ConstantInt>(BinOp->getOperand(1)); if (!RHS || !RHS->getValue().isMask()) return false; @@ -1041,7 +1105,7 @@ static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) { // We can only replace the AND with LHS based on range info if the range does // not include undef. 
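// Range-based examples for the rewrites above (illustrative): an 'ashr' whose
// left operand has an all-non-negative range behaves like 'lshr' and is
// replaced accordingly, and
//   %e = sext i8 %x to i32
// becomes the equivalent "zext i8 %x to i32" once the range at that use proves
// %x is non-negative.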
ConstantRange LRange = - LVI->getConstantRange(LHS, BinOp, /*UndefAllowed=*/false); + LVI->getConstantRangeAtUse(LHS, /*UndefAllowed=*/false); if (!LRange.getUnsignedMax().ule(RHS->getValue())) return false; diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index 5667eefabad5..658d0fcb53fa 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -724,7 +724,7 @@ private: // Make DeterminatorBB the first element in Path. PathType Path = TPath.getPath(); - auto ItDet = std::find(Path.begin(), Path.end(), DeterminatorBB); + auto ItDet = llvm::find(Path, DeterminatorBB); std::rotate(Path.begin(), ItDet, Path.end()); bool IsDetBBSeen = false; @@ -798,7 +798,7 @@ private: // Otherwise update Metrics for all blocks that will be cloned. If any // block is already cloned and would be reused, don't double count it. - auto DetIt = std::find(PathBBs.begin(), PathBBs.end(), Determinator); + auto DetIt = llvm::find(PathBBs, Determinator); for (auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) { BB = *BBIt; VisitedBB = getClonedBB(BB, NextState, DuplicateMap); @@ -840,7 +840,7 @@ private: } } - unsigned DuplicationCost = 0; + InstructionCost DuplicationCost = 0; unsigned JumpTableSize = 0; TTI->getEstimatedNumberOfCaseClusters(*Switch, JumpTableSize, nullptr, @@ -851,7 +851,7 @@ private: // using binary search, hence the LogBase2(). unsigned CondBranches = APInt(32, Switch->getNumSuccessors()).ceilLogBase2(); - DuplicationCost = *Metrics.NumInsts.getValue() / CondBranches; + DuplicationCost = Metrics.NumInsts / CondBranches; } else { // Compared with jump tables, the DFA optimizer removes an indirect branch // on each loop iteration, thus making branch prediction more precise. The @@ -859,7 +859,7 @@ private: // predictor to make a mistake, and the more benefit there is in the DFA // optimizer. Thus, the more branch targets there are, the lower is the // cost of the DFA opt. 
- DuplicationCost = *Metrics.NumInsts.getValue() / JumpTableSize; + DuplicationCost = Metrics.NumInsts / JumpTableSize; } LLVM_DEBUG(dbgs() << "\nDFA Jump Threading: Cost to jump thread block " @@ -943,7 +943,7 @@ private: if (PathBBs.front() == Determinator) PathBBs.pop_front(); - auto DetIt = std::find(PathBBs.begin(), PathBBs.end(), Determinator); + auto DetIt = llvm::find(PathBBs, Determinator); auto Prev = std::prev(DetIt); BasicBlock *PrevBB = *Prev; for (auto BBIt = DetIt; BBIt != PathBBs.end(); BBIt++) { diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index 3f0dad7ee769..9c0b4d673145 100644 --- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -56,6 +56,7 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -85,6 +86,7 @@ #include <cstdint> #include <iterator> #include <map> +#include <optional> #include <utility> using namespace llvm; @@ -242,19 +244,30 @@ static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI, const auto *DeadII = dyn_cast<IntrinsicInst>(DeadI); if (KillingII == nullptr || DeadII == nullptr) return OW_Unknown; - if (KillingII->getIntrinsicID() != Intrinsic::masked_store || - DeadII->getIntrinsicID() != Intrinsic::masked_store) + if (KillingII->getIntrinsicID() != DeadII->getIntrinsicID()) return OW_Unknown; - // Pointers. - Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts(); - Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts(); - if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr)) - return OW_Unknown; - // Masks. - // TODO: check that KillingII's mask is a superset of the DeadII's mask. - if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3)) - return OW_Unknown; - return OW_Complete; + if (KillingII->getIntrinsicID() == Intrinsic::masked_store) { + // Type size. + VectorType *KillingTy = + cast<VectorType>(KillingII->getArgOperand(0)->getType()); + VectorType *DeadTy = cast<VectorType>(DeadII->getArgOperand(0)->getType()); + if (KillingTy->getScalarSizeInBits() != DeadTy->getScalarSizeInBits()) + return OW_Unknown; + // Element count. + if (KillingTy->getElementCount() != DeadTy->getElementCount()) + return OW_Unknown; + // Pointers. + Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts(); + Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts(); + if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr)) + return OW_Unknown; + // Masks. + // TODO: check that KillingII's mask is a superset of the DeadII's mask. + if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3)) + return OW_Unknown; + return OW_Complete; + } + return OW_Unknown; } /// Return 'OW_Complete' if a store to the 'KillingLoc' location completely @@ -472,6 +485,45 @@ memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI, return true; } +static void shortenAssignment(Instruction *Inst, uint64_t OldOffsetInBits, + uint64_t OldSizeInBits, uint64_t NewSizeInBits, + bool IsOverwriteEnd) { + DIExpression::FragmentInfo DeadFragment; + DeadFragment.SizeInBits = OldSizeInBits - NewSizeInBits; + DeadFragment.OffsetInBits = + OldOffsetInBits + (IsOverwriteEnd ? 
NewSizeInBits : 0); + + auto CreateDeadFragExpr = [Inst, DeadFragment]() { + // FIXME: This should be using the DIExpression in the Alloca's dbg.assign + // for the variable, since that could also contain a fragment? + return *DIExpression::createFragmentExpression( + DIExpression::get(Inst->getContext(), std::nullopt), + DeadFragment.OffsetInBits, DeadFragment.SizeInBits); + }; + + // A DIAssignID to use so that the inserted dbg.assign intrinsics do not + // link to any instructions. Created in the loop below (once). + DIAssignID *LinkToNothing = nullptr; + + // Insert an unlinked dbg.assign intrinsic for the dead fragment after each + // overlapping dbg.assign intrinsic. + for (auto *DAI : at::getAssignmentMarkers(Inst)) { + if (auto FragInfo = DAI->getExpression()->getFragmentInfo()) { + if (!DIExpression::fragmentsOverlap(*FragInfo, DeadFragment)) + continue; + } + + // Fragments overlap: insert a new dbg.assign for this dead part. + auto *NewAssign = cast<DbgAssignIntrinsic>(DAI->clone()); + NewAssign->insertAfter(DAI); + if (!LinkToNothing) + LinkToNothing = DIAssignID::getDistinct(Inst->getContext()); + NewAssign->setAssignId(LinkToNothing); + NewAssign->setExpression(CreateDeadFragExpr()); + NewAssign->setKillAddress(); + } +} + static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart, uint64_t &DeadSize, int64_t KillingStart, uint64_t KillingSize, bool IsOverwriteEnd) { @@ -563,6 +615,10 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart, DeadIntrinsic->setDest(NewDestGEP); } + // Update attached dbg.assign intrinsics. Assume 8-bit byte. + shortenAssignment(DeadI, DeadStart * 8, DeadSize * 8, NewSize * 8, + IsOverwriteEnd); + // Finally update start and size of dead access. if (!IsOverwriteEnd) DeadStart += ToRemoveSize; @@ -823,6 +879,27 @@ struct DSEState { CodeMetrics::collectEphemeralValues(&F, &AC, EphValues); } + LocationSize strengthenLocationSize(const Instruction *I, + LocationSize Size) const { + if (auto *CB = dyn_cast<CallBase>(I)) { + LibFunc F; + if (TLI.getLibFunc(*CB, F) && TLI.has(F) && + (F == LibFunc_memset_chk || F == LibFunc_memcpy_chk)) { + // Use the precise location size specified by the 3rd argument + // for determining KillingI overwrites DeadLoc if it is a memset_chk + // instruction. memset_chk will write either the amount specified as 3rd + // argument or the function will immediately abort and exit the program. + // NOTE: AA may determine NoAlias if it can prove that the access size + // is larger than the allocation size due to that being UB. To avoid + // returning potentially invalid NoAlias results by AA, limit the use of + // the precise location size to isOverwrite. + if (const auto *Len = dyn_cast<ConstantInt>(CB->getArgOperand(2))) + return LocationSize::precise(Len->getZExtValue()); + } + } + return Size; + } + /// Return 'OW_Complete' if a store to the 'KillingLoc' location (by \p /// KillingI instruction) completely overwrites a store to the 'DeadLoc' /// location (by \p DeadI instruction). 
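The fragment arithmetic in shortenAssignment above, restated as a tiny standalone program. Sizes are in bits, as at the call site in tryToShorten; the struct is a hand-rolled stand-in for DIExpression::FragmentInfo.

#include <cstdint>
#include <iostream>

struct FragmentInfo {
  uint64_t SizeInBits;
  uint64_t OffsetInBits;
};

// A store covering [OldOffset, OldOffset + OldSize) bits is shortened to
// NewSize bits, trimmed at the end when IsOverwriteEnd and at the start
// otherwise. The returned range is the part no longer written, which the new
// unlinked dbg.assign describes as killed.
static FragmentInfo deadFragment(uint64_t OldOffsetInBits,
                                 uint64_t OldSizeInBits,
                                 uint64_t NewSizeInBits, bool IsOverwriteEnd) {
  FragmentInfo Dead;
  Dead.SizeInBits = OldSizeInBits - NewSizeInBits;
  Dead.OffsetInBits = OldOffsetInBits + (IsOverwriteEnd ? NewSizeInBits : 0);
  return Dead;
}

int main() {
  // A 16-byte (128-bit) store shortened to its first 8 bytes because a later
  // store overwrites the tail: the dead fragment is bits [64, 128).
  FragmentInfo F = deadFragment(/*OldOffsetInBits=*/0, /*OldSizeInBits=*/128,
                                /*NewSizeInBits=*/64, /*IsOverwriteEnd=*/true);
  std::cout << F.OffsetInBits << " " << F.SizeInBits << "\n"; // 64 64
  return 0;
}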
@@ -842,6 +919,8 @@ struct DSEState { if (!isGuaranteedLoopIndependent(DeadI, KillingI, DeadLoc)) return OW_Unknown; + LocationSize KillingLocSize = + strengthenLocationSize(KillingI, KillingLoc.Size); const Value *DeadPtr = DeadLoc.Ptr->stripPointerCasts(); const Value *KillingPtr = KillingLoc.Ptr->stripPointerCasts(); const Value *DeadUndObj = getUnderlyingObject(DeadPtr); @@ -849,16 +928,16 @@ struct DSEState { // Check whether the killing store overwrites the whole object, in which // case the size/offset of the dead store does not matter. - if (DeadUndObj == KillingUndObj && KillingLoc.Size.isPrecise()) { + if (DeadUndObj == KillingUndObj && KillingLocSize.isPrecise()) { uint64_t KillingUndObjSize = getPointerSize(KillingUndObj, DL, TLI, &F); if (KillingUndObjSize != MemoryLocation::UnknownSize && - KillingUndObjSize == KillingLoc.Size.getValue()) + KillingUndObjSize == KillingLocSize.getValue()) return OW_Complete; } // FIXME: Vet that this works for size upper-bounds. Seems unlikely that we'll // get imprecise values here, though (except for unknown sizes). - if (!KillingLoc.Size.isPrecise() || !DeadLoc.Size.isPrecise()) { + if (!KillingLocSize.isPrecise() || !DeadLoc.Size.isPrecise()) { // In case no constant size is known, try to an IR values for the number // of bytes written and check if they match. const auto *KillingMemI = dyn_cast<MemIntrinsic>(KillingI); @@ -875,7 +954,7 @@ struct DSEState { return isMaskedStoreOverwrite(KillingI, DeadI, BatchAA); } - const uint64_t KillingSize = KillingLoc.Size.getValue(); + const uint64_t KillingSize = KillingLocSize.getValue(); const uint64_t DeadSize = DeadLoc.Size.getValue(); // Query the alias information @@ -990,9 +1069,9 @@ struct DSEState { return !I.first->second; } - Optional<MemoryLocation> getLocForWrite(Instruction *I) const { + std::optional<MemoryLocation> getLocForWrite(Instruction *I) const { if (!I->mayWriteToMemory()) - return None; + return std::nullopt; if (auto *CB = dyn_cast<CallBase>(I)) return MemoryLocation::getForDest(CB, TLI); @@ -1075,13 +1154,16 @@ struct DSEState { } MemoryAccess *UseAccess = WorkList[I]; - // Simply adding the users of MemoryPhi to the worklist is not enough, - // because we might miss read clobbers in different iterations of a loop, - // for example. - // TODO: Add support for phi translation to handle the loop case. - if (isa<MemoryPhi>(UseAccess)) - return false; + if (isa<MemoryPhi>(UseAccess)) { + // AliasAnalysis does not account for loops. Limit elimination to + // candidates for which we can guarantee they always store to the same + // memory location. + if (!isGuaranteedLoopInvariant(MaybeLoc->Ptr)) + return false; + PushMemUses(cast<MemoryPhi>(UseAccess)); + continue; + } // TODO: Checking for aliasing is expensive. Consider reducing the amount // of times this is called and/or caching it. Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst(); @@ -1099,7 +1181,7 @@ struct DSEState { /// If \p I is a memory terminator like llvm.lifetime.end or free, return a /// pair with the MemoryLocation terminated by \p I and a boolean flag /// indicating whether \p I is a free-like call. 
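A simplified picture of the strengthenLocationSize helper used above: __memset_chk/__memcpy_chk either write exactly their length argument or abort, so a constant length argument can serve as a precise killing size, but only for the overwrite check itself. The std::optional/std::string types below are stand-ins for LibFunc and LocationSize, not the real interfaces.

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

static std::optional<uint64_t>
preciseChkSize(const std::string &Callee,
               std::optional<uint64_t> ConstantLenArg) {
  // Only the _chk variants guarantee "write exactly N bytes or abort".
  if ((Callee == "__memset_chk" || Callee == "__memcpy_chk") && ConstantLenArg)
    return *ConstantLenArg;
  return std::nullopt; // keep the original, possibly imprecise size
}

int main() {
  if (auto S = preciseChkSize("__memset_chk", 32))
    std::cout << "killing size strengthened to " << *S << " bytes\n";
  if (!preciseChkSize("memset", 32))
    std::cout << "plain memset: original LocationSize kept\n";
  return 0;
}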
- Optional<std::pair<MemoryLocation, bool>> + std::optional<std::pair<MemoryLocation, bool>> getLocForTerminator(Instruction *I) const { uint64_t Len; Value *Ptr; @@ -1112,7 +1194,7 @@ struct DSEState { return {std::make_pair(MemoryLocation::getAfter(FreedOp), true)}; } - return None; + return std::nullopt; } /// Returns true if \p I is a memory terminator instruction like @@ -1127,7 +1209,7 @@ struct DSEState { /// instruction \p AccessI. bool isMemTerminator(const MemoryLocation &Loc, Instruction *AccessI, Instruction *MaybeTerm) { - Optional<std::pair<MemoryLocation, bool>> MaybeTermLoc = + std::optional<std::pair<MemoryLocation, bool>> MaybeTermLoc = getLocForTerminator(MaybeTerm); if (!MaybeTermLoc) @@ -1201,25 +1283,27 @@ struct DSEState { if (GEP->hasAllConstantIndices()) Ptr = GEP->getPointerOperand()->stripPointerCasts(); - if (auto *I = dyn_cast<Instruction>(Ptr)) - return I->getParent()->isEntryBlock(); + if (auto *I = dyn_cast<Instruction>(Ptr)) { + return I->getParent()->isEntryBlock() || + (!ContainsIrreducibleLoops && !LI.getLoopFor(I->getParent())); + } return true; } // Find a MemoryDef writing to \p KillingLoc and dominating \p StartAccess, // with no read access between them or on any other path to a function exit // block if \p KillingLoc is not accessible after the function returns. If - // there is no such MemoryDef, return None. The returned value may not + // there is no such MemoryDef, return std::nullopt. The returned value may not // (completely) overwrite \p KillingLoc. Currently we bail out when we // encounter an aliasing MemoryUse (read). - Optional<MemoryAccess *> + std::optional<MemoryAccess *> getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess, const MemoryLocation &KillingLoc, const Value *KillingUndObj, unsigned &ScanLimit, unsigned &WalkerStepLimit, bool IsMemTerm, unsigned &PartialLimit) { if (ScanLimit == 0 || WalkerStepLimit == 0) { LLVM_DEBUG(dbgs() << "\n ... hit scan limit\n"); - return None; + return std::nullopt; } MemoryAccess *Current = StartAccess; @@ -1236,7 +1320,7 @@ struct DSEState { !KillingI->mayReadFromMemory(); // Find the next clobbering Mod access for DefLoc, starting at StartAccess. - Optional<MemoryLocation> CurrentLoc; + std::optional<MemoryLocation> CurrentLoc; for (;; Current = cast<MemoryDef>(Current)->getDefiningAccess()) { LLVM_DEBUG({ dbgs() << " visiting " << *Current; @@ -1252,7 +1336,7 @@ struct DSEState { if (CanOptimize && Current != KillingDef->getDefiningAccess()) // The first clobbering def is... none. KillingDef->setOptimized(Current); - return None; + return std::nullopt; } // Cost of a step. Accesses in the same block are more likely to be valid @@ -1262,7 +1346,7 @@ struct DSEState { : MemorySSAOtherBBStepCost; if (WalkerStepLimit <= StepCost) { LLVM_DEBUG(dbgs() << " ... hit walker step limit\n"); - return None; + return std::nullopt; } WalkerStepLimit -= StepCost; @@ -1287,14 +1371,14 @@ struct DSEState { // instructions that block us from DSEing if (mayThrowBetween(KillingI, CurrentI, KillingUndObj)) { LLVM_DEBUG(dbgs() << " ... skip, may throw!\n"); - return None; + return std::nullopt; } // Check for anything that looks like it will be a barrier to further // removal if (isDSEBarrier(KillingUndObj, CurrentI)) { LLVM_DEBUG(dbgs() << " ... skip, barrier\n"); - return None; + return std::nullopt; } // If Current is known to be on path that reads DefLoc or is a read @@ -1302,7 +1386,7 @@ struct DSEState { // for intrinsic calls, because the code knows how to handle memcpy // intrinsics. 
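A rough model of the walker budget consumed while searching for a dominating MemoryDef, as in the step-cost code above: visiting an access in another block is more expensive than one in the killing def's block, and the walk bails out once the budget is spent. The concrete numbers are illustrative; the pass reads them from its own cl::opt defaults.

#include <iostream>

struct WalkBudget {
  unsigned Remaining;
  unsigned SameBlockCost;
  unsigned OtherBlockCost;

  // Returns false when the budget is exhausted and the search must give up.
  bool step(bool SameBlock) {
    unsigned Cost = SameBlock ? SameBlockCost : OtherBlockCost;
    if (Remaining <= Cost)
      return false;
    Remaining -= Cost;
    return true;
  }
};

int main() {
  WalkBudget B{/*Remaining=*/90, /*SameBlockCost=*/1, /*OtherBlockCost=*/5};
  unsigned Visited = 0;
  // Alternate same-block and cross-block steps until the budget runs out.
  while (B.step(Visited % 2 == 0))
    ++Visited;
  std::cout << "visited " << Visited << " accesses before the limit\n";
  return 0;
}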
if (!isa<IntrinsicInst>(CurrentI) && isReadClobber(KillingLoc, CurrentI)) - return None; + return std::nullopt; // Quick check if there are direct uses that are read-clobbers. if (any_of(Current->uses(), [this, &KillingLoc, StartAccess](Use &U) { @@ -1312,7 +1396,7 @@ struct DSEState { return false; })) { LLVM_DEBUG(dbgs() << " ... found a read clobber\n"); - return None; + return std::nullopt; } // If Current does not have an analyzable write location or is not @@ -1406,7 +1490,7 @@ struct DSEState { // Bail out if the number of accesses to check exceeds the scan limit. if (ScanLimit < (WorkList.size() - I)) { LLVM_DEBUG(dbgs() << "\n ... hit scan limit\n"); - return None; + return std::nullopt; } --ScanLimit; NumDomMemDefChecks++; @@ -1451,14 +1535,14 @@ struct DSEState { if (UseInst->mayThrow() && !isInvisibleToCallerOnUnwind(KillingUndObj)) { LLVM_DEBUG(dbgs() << " ... found throwing instruction\n"); - return None; + return std::nullopt; } // Uses which may read the original MemoryDef mean we cannot eliminate the // original MD. Stop walk. if (isReadClobber(MaybeDeadLoc, UseInst)) { LLVM_DEBUG(dbgs() << " ... found read clobber\n"); - return None; + return std::nullopt; } // If this worklist walks back to the original memory access (and the @@ -1467,7 +1551,7 @@ struct DSEState { if (MaybeDeadAccess == UseAccess && !isGuaranteedLoopInvariant(MaybeDeadLoc.Ptr)) { LLVM_DEBUG(dbgs() << " ... found not loop invariant self access\n"); - return None; + return std::nullopt; } // Otherwise, for the KillingDef and MaybeDeadAccess we only have to check // if it reads the memory location. @@ -1501,7 +1585,7 @@ struct DSEState { } else { LLVM_DEBUG(dbgs() << " ... found preceeding def " << *UseInst << "\n"); - return None; + return std::nullopt; } } else PushMemUses(UseDef); @@ -1531,7 +1615,7 @@ struct DSEState { // killing block. if (!PDT.dominates(CommonPred, MaybeDeadAccess->getBlock())) { if (!AnyUnreachableExit) - return None; + return std::nullopt; // Fall back to CFG scan starting at all non-unreachable roots if not // all paths to the exit go through CommonPred. @@ -1562,7 +1646,7 @@ struct DSEState { if (KillingBlocks.count(Current)) continue; if (Current == MaybeDeadAccess->getBlock()) - return None; + return std::nullopt; // MaybeDeadAccess is reachable from the entry, so we don't have to // explore unreachable blocks further. @@ -1573,7 +1657,7 @@ struct DSEState { WorkList.insert(Pred); if (WorkList.size() >= MemorySSAPathCheckLimit) - return None; + return std::nullopt; } NumCFGSuccess++; } @@ -1774,10 +1858,9 @@ struct DSEState { !memoryIsNotModifiedBetween(Malloc, MemSet, BatchAA, DL, &DT)) return false; IRBuilder<> IRB(Malloc); - const auto &DL = Malloc->getModule()->getDataLayout(); - auto *Calloc = - emitCalloc(ConstantInt::get(IRB.getIntPtrTy(DL), 1), - Malloc->getArgOperand(0), IRB, TLI); + Type *SizeTTy = Malloc->getArgOperand(0)->getType(); + auto *Calloc = emitCalloc(ConstantInt::get(SizeTTy, 1), + Malloc->getArgOperand(0), IRB, TLI); if (!Calloc) return false; MemorySSAUpdater Updater(&MSSA); @@ -1818,7 +1901,7 @@ struct DSEState { // can modify the memory location. if (InitC && InitC == StoredConstant) return MSSA.isLiveOnEntryDef( - MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(Def)); + MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(Def, BatchAA)); } if (!Store) @@ -1837,7 +1920,7 @@ struct DSEState { // does not match LoadAccess. 
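The malloc+memset(0) to calloc rewrite touched in the hunk above, in plain C terms. The only change in the diff is where the constant '1' element count gets its type (from malloc's size argument instead of an intptr type derived from the DataLayout); the functions below just illustrate the equivalence.

#include <cstdio>
#include <cstdlib>
#include <cstring>

void *allocThenZero(std::size_t N) {
  void *P = std::malloc(N);
  if (P)
    std::memset(P, 0, N); // becomes dead once folded into the allocation
  return P;
}

void *zeroAlloc(std::size_t N) {
  return std::calloc(1, N); // same zero-initialized allocation, one call
}

int main() {
  unsigned char *P = static_cast<unsigned char *>(zeroAlloc(16));
  std::printf("%d\n", P ? P[0] : -1); // 0: calloc returns zeroed memory
  std::free(P);
  return 0;
}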
SetVector<MemoryAccess *> ToCheck; MemoryAccess *Current = - MSSA.getWalker()->getClobberingMemoryAccess(Def); + MSSA.getWalker()->getClobberingMemoryAccess(Def, BatchAA); // We don't want to bail when we run into the store memory def. But, // the phi access may point to it. So, pretend like we've already // checked it. @@ -1965,12 +2048,13 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, continue; Instruction *KillingI = KillingDef->getMemoryInst(); - Optional<MemoryLocation> MaybeKillingLoc; - if (State.isMemTerminatorInst(KillingI)) - MaybeKillingLoc = State.getLocForTerminator(KillingI).map( - [](const std::pair<MemoryLocation, bool> &P) { return P.first; }); - else + std::optional<MemoryLocation> MaybeKillingLoc; + if (State.isMemTerminatorInst(KillingI)) { + if (auto KillingLoc = State.getLocForTerminator(KillingI)) + MaybeKillingLoc = KillingLoc->first; + } else { MaybeKillingLoc = State.getLocForWrite(KillingI); + } if (!MaybeKillingLoc) { LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for " @@ -1998,7 +2082,7 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, if (State.SkipStores.count(Current)) continue; - Optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef( + std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef( KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit, WalkerStepLimit, IsMemTerm, PartialLimit); diff --git a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp index 66c9d9f0902a..303951643a0b 100644 --- a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp @@ -26,6 +26,7 @@ #include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BypassSlowDivision.h" +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -49,10 +50,10 @@ struct ExpandedMatch { /// X - ((X ?/ Y) * Y) /// which is equivalent to: /// X ?% Y -static llvm::Optional<ExpandedMatch> matchExpandedRem(Instruction &I) { +static std::optional<ExpandedMatch> matchExpandedRem(Instruction &I) { Value *Dividend, *XroundedDownToMultipleOfY; if (!match(&I, m_Sub(m_Value(Dividend), m_Value(XroundedDownToMultipleOfY)))) - return llvm::None; + return std::nullopt; Value *Divisor; Instruction *Div; @@ -62,7 +63,7 @@ static llvm::Optional<ExpandedMatch> matchExpandedRem(Instruction &I) { m_c_Mul(m_CombineAnd(m_IDiv(m_Specific(Dividend), m_Value(Divisor)), m_Instruction(Div)), m_Deferred(Divisor)))) - return llvm::None; + return std::nullopt; ExpandedMatch M; M.Key.SignedOp = Div->getOpcode() == Instruction::SDiv; @@ -266,12 +267,32 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, // DivBB will always reach the Div/Rem, we can hoist Div to PredBB. If // we have a DivRem operation we can also hoist Rem. Otherwise we'll leave // Rem where it is and rewrite it to mul/sub. - // FIXME: We could handle more hoisting cases. - if (RemBB->getSingleSuccessor() == DivBB) + if (RemBB->getSingleSuccessor() == DivBB) { PredBB = RemBB->getUniquePredecessor(); - if (PredBB && IsSafeToHoist(RemInst, RemBB) && - IsSafeToHoist(DivInst, DivBB) && + // Look for something like this + // PredBB + // / \ + // Div Rem + // + // If the Rem and Din blocks share a unique predecessor, and all + // paths from PredBB go to either RemBB or DivBB, and execution of RemBB + // and DivBB will always reach the Div/Rem, we can hoist Div to PredBB. 
+ // If we have a DivRem operation we can also hoist Rem. By hoisting both + // ops to the same block, we reduce code size and allow the DivRem to + // issue sooner. Without a DivRem op, this transformation is + // unprofitable because we would end up performing an extra Mul+Sub on + // the Rem path. + } else if (BasicBlock *RemPredBB = RemBB->getUniquePredecessor()) { + // This hoist is only profitable when the target has a DivRem op. + if (HasDivRemOp && RemPredBB == DivBB->getUniquePredecessor()) + PredBB = RemPredBB; + } + // FIXME: We could handle more hoisting cases. + + if (PredBB && !isa<CatchSwitchInst>(PredBB->getTerminator()) && + isGuaranteedToTransferExecutionToSuccessor(PredBB->getTerminator()) && + IsSafeToHoist(RemInst, RemBB) && IsSafeToHoist(DivInst, DivBB) && all_of(successors(PredBB), [&](BasicBlock *BB) { return BB == DivBB || BB == RemBB; }) && all_of(predecessors(DivBB), diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index cf2824954122..26821c7ee81e 100644 --- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -120,11 +120,27 @@ struct SimpleValue { case Intrinsic::experimental_constrained_fcmp: case Intrinsic::experimental_constrained_fcmps: { auto *CFP = cast<ConstrainedFPIntrinsic>(CI); - return CFP->isDefaultFPEnvironment(); + if (CFP->getExceptionBehavior() && + CFP->getExceptionBehavior() == fp::ebStrict) + return false; + // Since we CSE across function calls we must not allow + // the rounding mode to change. + if (CFP->getRoundingMode() && + CFP->getRoundingMode() == RoundingMode::Dynamic) + return false; + return true; } } } - return CI->doesNotAccessMemory() && !CI->getType()->isVoidTy(); + return CI->doesNotAccessMemory() && !CI->getType()->isVoidTy() && + // FIXME: Currently the calls which may access the thread id may + // be considered as not accessing the memory. But this is + // problematic for coroutines, since coroutines may resume in a + // different thread. So we disable the optimization here for the + // correctness. However, it may block many other correct + // optimizations. Revert this one when we detect the memory + // accessing kind more precisely. + !CI->getFunction()->isPresplitCoroutine(); } return isa<CastInst>(Inst) || isa<UnaryOperator>(Inst) || isa<BinaryOperator>(Inst) || isa<GetElementPtrInst>(Inst) || @@ -455,7 +471,15 @@ struct CallValue { return false; CallInst *CI = dyn_cast<CallInst>(Inst); - if (!CI || !CI->onlyReadsMemory()) + if (!CI || !CI->onlyReadsMemory() || + // FIXME: Currently the calls which may access the thread id may + // be considered as not accessing the memory. But this is + // problematic for coroutines, since coroutines may resume in a + // different thread. So we disable the optimization here for the + // correctness. However, it may block many other correct + // optimizations. Revert this one when we detect the memory + // accessing kind more precisely. + CI->getFunction()->isPresplitCoroutine()) return false; return true; } @@ -840,7 +864,7 @@ private: // TODO: We could insert relevant casts on type mismatch here. if (auto *LI = dyn_cast<LoadInst>(Inst)) return LI->getType() == ExpectedType ? LI : nullptr; - else if (auto *SI = dyn_cast<StoreInst>(Inst)) { + if (auto *SI = dyn_cast<StoreInst>(Inst)) { Value *V = SI->getValueOperand(); return V->getType() == ExpectedType ? 
V : nullptr; } @@ -853,11 +877,14 @@ private: Value *getOrCreateResultNonTargetMemIntrinsic(IntrinsicInst *II, Type *ExpectedType) const { + // TODO: We could insert relevant casts on type mismatch here. switch (II->getIntrinsicID()) { case Intrinsic::masked_load: - return II; - case Intrinsic::masked_store: - return II->getOperand(0); + return II->getType() == ExpectedType ? II : nullptr; + case Intrinsic::masked_store: { + Value *V = II->getOperand(0); + return V->getType() == ExpectedType ? V : nullptr; + } } return nullptr; } @@ -881,8 +908,8 @@ private: auto *Vec1 = dyn_cast<ConstantVector>(Mask1); if (!Vec0 || !Vec1) return false; - assert(Vec0->getType() == Vec1->getType() && - "Masks should have the same type"); + if (Vec0->getType() != Vec1->getType()) + return false; for (int i = 0, e = Vec0->getNumOperands(); i != e; ++i) { Constant *Elem0 = Vec0->getOperand(i); Constant *Elem1 = Vec1->getOperand(i); @@ -1106,7 +1133,7 @@ bool EarlyCSE::handleBranchCondition(Instruction *CondInst, Value *LHS, *RHS; if (MatchBinOp(Curr, PropagateOpcode, LHS, RHS)) - for (auto &Op : { LHS, RHS }) + for (auto *Op : { LHS, RHS }) if (Instruction *OPI = dyn_cast<Instruction>(Op)) if (SimpleValue::canHandle(OPI) && Visited.insert(OPI).second) WorkList.push_back(OPI); @@ -1234,7 +1261,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // See if any instructions in the block can be eliminated. If so, do it. If // not, add them to AvailableValues. - for (Instruction &Inst : make_early_inc_range(BB->getInstList())) { + for (Instruction &Inst : make_early_inc_range(*BB)) { // Dead instructions should just be removed. if (isInstructionTriviallyDead(&Inst, &TLI)) { LLVM_DEBUG(dbgs() << "EarlyCSE DCE: " << Inst << '\n'); @@ -1374,6 +1401,13 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // If this is a simple instruction that we can value number, process it. if (SimpleValue::canHandle(&Inst)) { + if (auto *CI = dyn_cast<ConstrainedFPIntrinsic>(&Inst)) { + assert(CI->getExceptionBehavior() != fp::ebStrict && + "Unexpected ebStrict from SimpleValue::canHandle()"); + assert((!CI->getRoundingMode() || + CI->getRoundingMode() != RoundingMode::Dynamic) && + "Unexpected dynamic rounding from SimpleValue::canHandle()"); + } // See if the instruction has an available value. If so, use it. if (Value *V = AvailableValues.lookup(&Inst)) { LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << Inst << " to: " << *V diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp index 56f2a3b3004d..f66d1b914b0b 100644 --- a/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -235,15 +235,15 @@ void Float2IntPass::walkBackwards() { } // Calculate result range from operand ranges. -// Return None if the range cannot be calculated yet. -Optional<ConstantRange> Float2IntPass::calcRange(Instruction *I) { +// Return std::nullopt if the range cannot be calculated yet. +std::optional<ConstantRange> Float2IntPass::calcRange(Instruction *I) { SmallVector<ConstantRange, 4> OpRanges; for (Value *O : I->operands()) { if (Instruction *OI = dyn_cast<Instruction>(O)) { auto OpIt = SeenInsts.find(OI); assert(OpIt != SeenInsts.end() && "def not seen before use!"); if (OpIt->second == unknownRange()) - return None; // Wait until operand range has been calculated. + return std::nullopt; // Wait until operand range has been calculated. 
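Back in the EarlyCSE hunk above, SimpleValue::canHandle now spells out when a constrained FP intrinsic may be value-numbered. A minimal sketch of that legality rule; the enums only loosely mirror fp::ExceptionBehavior and RoundingMode.

#include <iostream>
#include <optional>

enum class ExceptionBehavior { Ignore, MayTrap, Strict };
enum class Rounding { NearestTiesToEven, TowardZero, Dynamic };

// CSE across calls is only safe when FP exceptions are not strictly observed
// and the rounding mode cannot change between the two call sites.
static bool canCSEConstrainedFP(std::optional<ExceptionBehavior> Except,
                                std::optional<Rounding> Round) {
  if (Except && *Except == ExceptionBehavior::Strict)
    return false; // raising/observing FP exceptions is a visible side effect
  if (Round && *Round == Rounding::Dynamic)
    return false; // the environment's rounding mode may differ at each call
  return true;
}

int main() {
  std::cout << canCSEConstrainedFP(ExceptionBehavior::Ignore,
                                   Rounding::NearestTiesToEven)
            << "\n"; // 1
  std::cout << canCSEConstrainedFP(ExceptionBehavior::Strict, std::nullopt)
            << "\n"; // 0
  return 0;
}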
OpRanges.push_back(OpIt->second); } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) { // Work out if the floating point number can be losslessly represented @@ -335,7 +335,7 @@ void Float2IntPass::walkForwards() { Instruction *I = Worklist.back(); Worklist.pop_back(); - if (Optional<ConstantRange> Range = calcRange(I)) + if (std::optional<ConstantRange> Range = calcRange(I)) seen(I, *Range); else Worklist.push_front(I); // Reprocess later. diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index b460637b7d88..6158894e3437 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -76,6 +76,7 @@ #include <algorithm> #include <cassert> #include <cstdint> +#include <optional> #include <utility> using namespace llvm; @@ -121,6 +122,11 @@ static cl::opt<uint32_t> MaxBBSpeculations( "into) when deducing if a value is fully available or not in GVN " "(default = 600)")); +static cl::opt<uint32_t> MaxNumVisitedInsts( + "gvn-max-num-visited-insts", cl::Hidden, cl::init(100), + cl::desc("Max number of visited instructions when trying to find " + "dominating value of select dependency (default = 100)")); + struct llvm::GVNPass::Expression { uint32_t opcode; bool commutative = false; @@ -192,6 +198,8 @@ struct llvm::gvn::AvailableValue { /// Offset - The byte offset in Val that is interesting for the load query. unsigned Offset = 0; + /// V1, V2 - The dominating non-clobbered values of SelectVal. + Value *V1 = nullptr, *V2 = nullptr; static AvailableValue get(Value *V, unsigned Offset = 0) { AvailableValue Res; @@ -225,11 +233,13 @@ struct llvm::gvn::AvailableValue { return Res; } - static AvailableValue getSelect(SelectInst *Sel) { + static AvailableValue getSelect(SelectInst *Sel, Value *V1, Value *V2) { AvailableValue Res; Res.Val = Sel; Res.Kind = ValType::SelectVal; Res.Offset = 0; + Res.V1 = V1; + Res.V2 = V2; return Res; } @@ -290,8 +300,9 @@ struct llvm::gvn::AvailableValueInBlock { return get(BB, AvailableValue::getUndef()); } - static AvailableValueInBlock getSelect(BasicBlock *BB, SelectInst *Sel) { - return get(BB, AvailableValue::getSelect(Sel)); + static AvailableValueInBlock getSelect(BasicBlock *BB, SelectInst *Sel, + Value *V1, Value *V2) { + return get(BB, AvailableValue::getSelect(Sel, V1, V2)); } /// Emit code at the end of this block to adjust the value defined here to @@ -450,12 +461,28 @@ void GVNPass::ValueTable::add(Value *V, uint32_t num) { } uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) { - if (AA->doesNotAccessMemory(C)) { + if (AA->doesNotAccessMemory(C) && + // FIXME: Currently the calls which may access the thread id may + // be considered as not accessing the memory. But this is + // problematic for coroutines, since coroutines may resume in a + // different thread. So we disable the optimization here for the + // correctness. However, it may block many other correct + // optimizations. Revert this one when we detect the memory + // accessing kind more precisely. + !C->getFunction()->isPresplitCoroutine()) { Expression exp = createExpr(C); uint32_t e = assignExpNewValueNum(exp).first; valueNumbering[C] = e; return e; - } else if (MD && AA->onlyReadsMemory(C)) { + } else if (MD && AA->onlyReadsMemory(C) && + // FIXME: Currently the calls which may access the thread id may + // be considered as not accessing the memory. But this is + // problematic for coroutines, since coroutines may resume in a + // different thread. So we disable the optimization here for the + // correctness. 
However, it may block many other correct + // optimizations. Revert this one when we detect the memory + // accessing kind more precisely. + !C->getFunction()->isPresplitCoroutine()) { Expression exp = createExpr(C); auto ValNum = assignExpNewValueNum(exp); if (ValNum.second) { @@ -471,7 +498,7 @@ uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) { } if (local_dep.isDef()) { - // For masked load/store intrinsics, the local_dep may actully be + // For masked load/store intrinsics, the local_dep may actually be // a normal load or store instruction. CallInst *local_cdep = dyn_cast<CallInst>(local_dep.getInst()); @@ -502,21 +529,20 @@ uint32_t GVNPass::ValueTable::lookupOrAddCall(CallInst *C) { // Check to see if we have a single dominating call instruction that is // identical to C. - for (unsigned i = 0, e = deps.size(); i != e; ++i) { - const NonLocalDepEntry *I = &deps[i]; - if (I->getResult().isNonLocal()) + for (const NonLocalDepEntry &I : deps) { + if (I.getResult().isNonLocal()) continue; // We don't handle non-definitions. If we already have a call, reject // instruction dependencies. - if (!I->getResult().isDef() || cdep != nullptr) { + if (!I.getResult().isDef() || cdep != nullptr) { cdep = nullptr; break; } - CallInst *NonLocalDepCall = dyn_cast<CallInst>(I->getResult().getInst()); + CallInst *NonLocalDepCall = dyn_cast<CallInst>(I.getResult().getInst()); // FIXME: All duplicated with non-local case. - if (NonLocalDepCall && DT->properlyDominates(I->getBB(), C->getParent())){ + if (NonLocalDepCall && DT->properlyDominates(I.getBB(), C->getParent())) { cdep = NonLocalDepCall; continue; } @@ -564,12 +590,12 @@ uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) { if (VI != valueNumbering.end()) return VI->second; - if (!isa<Instruction>(V)) { + auto *I = dyn_cast<Instruction>(V); + if (!I) { valueNumbering[V] = nextValueNumber; return nextValueNumber++; } - Instruction* I = cast<Instruction>(V); Expression exp; switch (I->getOpcode()) { case Instruction::Call: @@ -747,15 +773,15 @@ void GVNPass::printPipeline( OS, MapClassName2PassName); OS << "<"; - if (Options.AllowPRE != None) - OS << (Options.AllowPRE.value() ? "" : "no-") << "pre;"; - if (Options.AllowLoadPRE != None) - OS << (Options.AllowLoadPRE.value() ? "" : "no-") << "load-pre;"; - if (Options.AllowLoadPRESplitBackedge != None) - OS << (Options.AllowLoadPRESplitBackedge.value() ? "" : "no-") + if (Options.AllowPRE != std::nullopt) + OS << (*Options.AllowPRE ? "" : "no-") << "pre;"; + if (Options.AllowLoadPRE != std::nullopt) + OS << (*Options.AllowLoadPRE ? "" : "no-") << "load-pre;"; + if (Options.AllowLoadPRESplitBackedge != std::nullopt) + OS << (*Options.AllowLoadPRESplitBackedge ? "" : "no-") << "split-backedge-load-pre;"; - if (Options.AllowMemDep != None) - OS << (Options.AllowMemDep.value() ? "" : "no-") << "memdep"; + if (Options.AllowMemDep != std::nullopt) + OS << (*Options.AllowMemDep ? "" : "no-") << "memdep"; OS << ">"; } @@ -794,7 +820,7 @@ static bool IsValueFullyAvailableInBlock( BasicBlock *BB, DenseMap<BasicBlock *, AvailabilityState> &FullyAvailableBlocks) { SmallVector<BasicBlock *, 32> Worklist; - Optional<BasicBlock *> UnavailableBB; + std::optional<BasicBlock *> UnavailableBB; // The number of times we didn't find an entry for a block in a map and // optimistically inserted an entry marking block as speculatively available. 
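Source-level intuition for the repeated FIXME above about pre-split coroutines: a call that only reads the thread id looks like it does not access memory, but a coroutine may resume on another thread, so folding the second call into the first is not sound there. This is only an illustration using std::this_thread::get_id() in an ordinary function, not an actual coroutine.

#include <iostream>
#include <thread>

int main() {
  auto Before = std::this_thread::get_id();
  // ... imagine a co_await here, after which the frame may resume on a
  // different thread ...
  auto After = std::this_thread::get_id();
  // CSE/GVN treating the second call as a repeat of the first would pin
  // After to Before, which can be wrong across a real suspension point.
  std::cout << (Before == After) << "\n";
  return 0;
}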
@@ -951,17 +977,6 @@ ConstructSSAForLoadSet(LoadInst *Load, return SSAUpdate.GetValueInMiddleOfBlock(Load->getParent()); } -static LoadInst *findDominatingLoad(Value *Ptr, Type *LoadTy, SelectInst *Sel, - DominatorTree &DT) { - for (Value *U : Ptr->users()) { - auto *LI = dyn_cast<LoadInst>(U); - if (LI && LI->getType() == LoadTy && LI->getParent() == Sel->getParent() && - DT.dominates(LI, Sel)) - return LI; - } - return nullptr; -} - Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, Instruction *InsertPt, GVNPass &gvn) const { @@ -1005,14 +1020,8 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, } else if (isSelectValue()) { // Introduce a new value select for a load from an eligible pointer select. SelectInst *Sel = getSelectValue(); - LoadInst *L1 = findDominatingLoad(Sel->getOperand(1), LoadTy, Sel, - gvn.getDominatorTree()); - LoadInst *L2 = findDominatingLoad(Sel->getOperand(2), LoadTy, Sel, - gvn.getDominatorTree()); - assert(L1 && L2 && - "must be able to obtain dominating loads for both value operands of " - "the select"); - Res = SelectInst::Create(Sel->getCondition(), L1, L2, "", Sel); + assert(V1 && V2 && "both value operands of the select must be present"); + Res = SelectInst::Create(Sel->getCondition(), V1, V2, "", Sel); } else { llvm_unreachable("Should not materialize value from dead block"); } @@ -1044,25 +1053,25 @@ static void reportMayClobberedLoad(LoadInst *Load, MemDepResult DepInfo, OptimizationRemarkEmitter *ORE) { using namespace ore; - User *OtherAccess = nullptr; + Instruction *OtherAccess = nullptr; OptimizationRemarkMissed R(DEBUG_TYPE, "LoadClobbered", Load); R << "load of type " << NV("Type", Load->getType()) << " not eliminated" << setExtraArgs(); for (auto *U : Load->getPointerOperand()->users()) { - if (U != Load && (isa<LoadInst>(U) || isa<StoreInst>(U)) && - cast<Instruction>(U)->getFunction() == Load->getFunction() && - DT->dominates(cast<Instruction>(U), Load)) { - // Use the most immediately dominating value - if (OtherAccess) { - if (DT->dominates(cast<Instruction>(OtherAccess), cast<Instruction>(U))) - OtherAccess = U; - else - assert(U == OtherAccess || DT->dominates(cast<Instruction>(U), - cast<Instruction>(OtherAccess))); - } else - OtherAccess = U; + if (U != Load && (isa<LoadInst>(U) || isa<StoreInst>(U))) { + auto *I = cast<Instruction>(U); + if (I->getFunction() == Load->getFunction() && DT->dominates(I, Load)) { + // Use the most immediately dominating value + if (OtherAccess) { + if (DT->dominates(OtherAccess, I)) + OtherAccess = I; + else + assert(U == OtherAccess || DT->dominates(I, OtherAccess)); + } else + OtherAccess = I; + } } } @@ -1070,22 +1079,22 @@ static void reportMayClobberedLoad(LoadInst *Load, MemDepResult DepInfo, // There is no dominating use, check if we can find a closest non-dominating // use that lies between any other potentially available use and Load. for (auto *U : Load->getPointerOperand()->users()) { - if (U != Load && (isa<LoadInst>(U) || isa<StoreInst>(U)) && - cast<Instruction>(U)->getFunction() == Load->getFunction() && - isPotentiallyReachable(cast<Instruction>(U), Load, nullptr, DT)) { - if (OtherAccess) { - if (liesBetween(cast<Instruction>(OtherAccess), cast<Instruction>(U), - Load, DT)) { - OtherAccess = U; - } else if (!liesBetween(cast<Instruction>(U), - cast<Instruction>(OtherAccess), Load, DT)) { - // These uses are both partially available at Load were it not for - // the clobber, but neither lies strictly after the other. 
- OtherAccess = nullptr; - break; - } // else: keep current OtherAccess since it lies between U and Load - } else { - OtherAccess = U; + if (U != Load && (isa<LoadInst>(U) || isa<StoreInst>(U))) { + auto *I = cast<Instruction>(U); + if (I->getFunction() == Load->getFunction() && + isPotentiallyReachable(I, Load, nullptr, DT)) { + if (OtherAccess) { + if (liesBetween(OtherAccess, I, Load, DT)) { + OtherAccess = I; + } else if (!liesBetween(I, OtherAccess, Load, DT)) { + // These uses are both partially available at Load were it not for + // the clobber, but neither lies strictly after the other. + OtherAccess = nullptr; + break; + } // else: keep current OtherAccess since it lies between U and Load + } else { + OtherAccess = I; + } } } } @@ -1099,61 +1108,39 @@ static void reportMayClobberedLoad(LoadInst *Load, MemDepResult DepInfo, ORE->emit(R); } -/// Check if a load from pointer-select \p Address in \p DepBB can be converted -/// to a value select. The following conditions need to be satisfied: -/// 1. The pointer select (\p Address) must be defined in \p DepBB. -/// 2. Both value operands of the pointer select must be loaded in the same -/// basic block, before the pointer select. -/// 3. There must be no instructions between the found loads and \p End that may -/// clobber the loads. -static Optional<AvailableValue> -tryToConvertLoadOfPtrSelect(BasicBlock *DepBB, BasicBlock::iterator End, - Value *Address, Type *LoadTy, DominatorTree &DT, - AAResults *AA) { - - auto *Sel = dyn_cast_or_null<SelectInst>(Address); - if (!Sel || DepBB != Sel->getParent()) - return None; - - LoadInst *L1 = findDominatingLoad(Sel->getOperand(1), LoadTy, Sel, DT); - LoadInst *L2 = findDominatingLoad(Sel->getOperand(2), LoadTy, Sel, DT); - if (!L1 || !L2) - return None; - - // Ensure there are no accesses that may modify the locations referenced by - // either L1 or L2 between L1, L2 and the specified End iterator. - Instruction *EarlierLoad = L1->comesBefore(L2) ? L1 : L2; - MemoryLocation L1Loc = MemoryLocation::get(L1); - MemoryLocation L2Loc = MemoryLocation::get(L2); - if (any_of(make_range(EarlierLoad->getIterator(), End), [&](Instruction &I) { - return isModSet(AA->getModRefInfo(&I, L1Loc)) || - isModSet(AA->getModRefInfo(&I, L2Loc)); - })) - return None; - - return AvailableValue::getSelect(Sel); -} - -bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, - Value *Address, AvailableValue &Res) { - if (!DepInfo.isDef() && !DepInfo.isClobber()) { - assert(isa<SelectInst>(Address)); - if (auto R = tryToConvertLoadOfPtrSelect( - Load->getParent(), Load->getIterator(), Address, Load->getType(), - getDominatorTree(), getAliasAnalysis())) { - Res = *R; - return true; +// Find non-clobbered value for Loc memory location in extended basic block +// (chain of basic blocks with single predecessors) starting From instruction. +static Value *findDominatingValue(const MemoryLocation &Loc, Type *LoadTy, + Instruction *From, AAResults *AA) { + uint32_t NumVisitedInsts = 0; + BasicBlock *FromBB = From->getParent(); + BatchAAResults BatchAA(*AA); + for (BasicBlock *BB = FromBB; BB; BB = BB->getSinglePredecessor()) + for (auto I = BB == FromBB ? From->getReverseIterator() : BB->rbegin(), + E = BB->rend(); + I != E; ++I) { + // Stop the search if limit is reached. 
+ if (++NumVisitedInsts > MaxNumVisitedInsts) + return nullptr; + Instruction *Inst = &*I; + if (isModSet(BatchAA.getModRefInfo(Inst, Loc))) + return nullptr; + if (auto *LI = dyn_cast<LoadInst>(Inst)) + if (LI->getPointerOperand() == Loc.Ptr && LI->getType() == LoadTy) + return LI; } - return false; - } + return nullptr; +} - assert((DepInfo.isDef() || DepInfo.isClobber()) && - "expected a local dependence"); +std::optional<AvailableValue> +GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, + Value *Address) { assert(Load->isUnordered() && "rules below are incorrect for ordered access"); - - const DataLayout &DL = Load->getModule()->getDataLayout(); + assert(DepInfo.isLocal() && "expected a local dependence"); Instruction *DepInst = DepInfo.getInst(); + + const DataLayout &DL = Load->getModule()->getDataLayout(); if (DepInfo.isClobber()) { // If the dependence is to a store that writes to a superset of the bits // read by the load, we can extract the bits we need for the load from the @@ -1163,10 +1150,8 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, if (Address && Load->isAtomic() <= DepSI->isAtomic()) { int Offset = analyzeLoadFromClobberingStore(Load->getType(), Address, DepSI, DL); - if (Offset != -1) { - Res = AvailableValue::get(DepSI->getValueOperand(), Offset); - return true; - } + if (Offset != -1) + return AvailableValue::get(DepSI->getValueOperand(), Offset); } } @@ -1188,15 +1173,15 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, canCoerceMustAliasedValueToLoad(DepLoad, LoadType, DL)) { const auto ClobberOff = MD->getClobberOffset(DepLoad); // GVN has no deal with a negative offset. - Offset = (ClobberOff == None || *ClobberOff < 0) ? -1 : *ClobberOff; + Offset = (ClobberOff == std::nullopt || *ClobberOff < 0) + ? -1 + : *ClobberOff; } if (Offset == -1) Offset = analyzeLoadFromClobberingLoad(LoadType, Address, DepLoad, DL); - if (Offset != -1) { - Res = AvailableValue::getLoad(DepLoad, Offset); - return true; - } + if (Offset != -1) + return AvailableValue::getLoad(DepLoad, Offset); } } @@ -1206,10 +1191,8 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, if (Address && !Load->isAtomic()) { int Offset = analyzeLoadFromClobberingMemInst(Load->getType(), Address, DepMI, DL); - if (Offset != -1) { - Res = AvailableValue::getMI(DepMI, Offset); - return true; - } + if (Offset != -1) + return AvailableValue::getMI(DepMI, Offset); } } @@ -1221,22 +1204,18 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, if (ORE->allowExtraAnalysis(DEBUG_TYPE)) reportMayClobberedLoad(Load, DepInfo, DT, ORE); - return false; + return std::nullopt; } assert(DepInfo.isDef() && "follows from above"); // Loading the alloca -> undef. // Loading immediately after lifetime begin -> undef. 
- if (isa<AllocaInst>(DepInst) || isLifetimeStart(DepInst)) { - Res = AvailableValue::get(UndefValue::get(Load->getType())); - return true; - } + if (isa<AllocaInst>(DepInst) || isLifetimeStart(DepInst)) + return AvailableValue::get(UndefValue::get(Load->getType())); if (Constant *InitVal = - getInitialValueOfAllocation(DepInst, TLI, Load->getType())) { - Res = AvailableValue::get(InitVal); - return true; - } + getInitialValueOfAllocation(DepInst, TLI, Load->getType())) + return AvailableValue::get(InitVal); if (StoreInst *S = dyn_cast<StoreInst>(DepInst)) { // Reject loads and stores that are to the same address but are of @@ -1244,14 +1223,13 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, // the loaded value, we can reuse it. if (!canCoerceMustAliasedValueToLoad(S->getValueOperand(), Load->getType(), DL)) - return false; + return std::nullopt; // Can't forward from non-atomic to atomic without violating memory model. if (S->isAtomic() < Load->isAtomic()) - return false; + return std::nullopt; - Res = AvailableValue::get(S->getValueOperand()); - return true; + return AvailableValue::get(S->getValueOperand()); } if (LoadInst *LD = dyn_cast<LoadInst>(DepInst)) { @@ -1259,14 +1237,32 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, // If the stored value is larger or equal to the loaded value, we can reuse // it. if (!canCoerceMustAliasedValueToLoad(LD, Load->getType(), DL)) - return false; + return std::nullopt; // Can't forward from non-atomic to atomic without violating memory model. if (LD->isAtomic() < Load->isAtomic()) - return false; - - Res = AvailableValue::getLoad(LD); - return true; + return std::nullopt; + + return AvailableValue::getLoad(LD); + } + + // Check if load with Addr dependent from select can be converted to select + // between load values. There must be no instructions between the found + // loads and DepInst that may clobber the loads. + if (auto *Sel = dyn_cast<SelectInst>(DepInst)) { + assert(Sel->getType() == Load->getPointerOperandType()); + auto Loc = MemoryLocation::get(Load); + Value *V1 = + findDominatingValue(Loc.getWithNewPtr(Sel->getTrueValue()), + Load->getType(), DepInst, getAliasAnalysis()); + if (!V1) + return std::nullopt; + Value *V2 = + findDominatingValue(Loc.getWithNewPtr(Sel->getFalseValue()), + Load->getType(), DepInst, getAliasAnalysis()); + if (!V2) + return std::nullopt; + return AvailableValue::getSelect(Sel, V1, V2); } // Unknown def - must be conservative @@ -1274,7 +1270,7 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, // fast print dep, using operator<< on instruction is too slow. dbgs() << "GVN: load "; Load->printAsOperand(dbgs()); dbgs() << " has unknown def " << *DepInst << '\n';); - return false; + return std::nullopt; } void GVNPass::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps, @@ -1284,10 +1280,9 @@ void GVNPass::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps, // where we have a value available in repl, also keep track of whether we see // dependencies that produce an unknown value for the load (such as a call // that could potentially clobber the load). 
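The load-of-pointer-select conversion added above, as a hypothetical source-level before/after. In the pass, V1 and V2 come from findDominatingValue, i.e. existing unclobbered loads/values in the same extended basic block; here both pointers are simply known-dereferenceable so the rewrite can be shown directly.

#include <iostream>

int loadOfSelect(bool Cond, int *P, int *Q) {
  int *Addr = Cond ? P : Q; // pointer select feeding the load
  return *Addr;
}

int selectOfLoads(bool Cond, int *P, int *Q) {
  int V1 = *P; // stands in for the dominating value found for the true arm
  int V2 = *Q; // stands in for the dominating value found for the false arm
  return Cond ? V1 : V2; // value select replaces the load through the select
}

int main() {
  int A = 1, B = 2;
  std::cout << loadOfSelect(true, &A, &B) << " "
            << selectOfLoads(true, &A, &B) << "\n"; // 1 1
  return 0;
}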
- unsigned NumDeps = Deps.size(); - for (unsigned i = 0, e = NumDeps; i != e; ++i) { - BasicBlock *DepBB = Deps[i].getBB(); - MemDepResult DepInfo = Deps[i].getResult(); + for (const auto &Dep : Deps) { + BasicBlock *DepBB = Dep.getBB(); + MemDepResult DepInfo = Dep.getResult(); if (DeadBlocks.count(DepBB)) { // Dead dependent mem-op disguise as a load evaluating the same value @@ -1296,36 +1291,26 @@ void GVNPass::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps, continue; } - // The address being loaded in this non-local block may not be the same as - // the pointer operand of the load if PHI translation occurs. Make sure - // to consider the right address. - Value *Address = Deps[i].getAddress(); - - if (!DepInfo.isDef() && !DepInfo.isClobber()) { - if (auto R = tryToConvertLoadOfPtrSelect( - DepBB, DepBB->end(), Address, Load->getType(), getDominatorTree(), - getAliasAnalysis())) { - ValuesPerBlock.push_back( - AvailableValueInBlock::get(DepBB, std::move(*R))); - continue; - } + if (!DepInfo.isLocal()) { UnavailableBlocks.push_back(DepBB); continue; } - AvailableValue AV; - if (AnalyzeLoadAvailability(Load, DepInfo, Address, AV)) { + // The address being loaded in this non-local block may not be the same as + // the pointer operand of the load if PHI translation occurs. Make sure + // to consider the right address. + if (auto AV = AnalyzeLoadAvailability(Load, DepInfo, Dep.getAddress())) { // subtlety: because we know this was a non-local dependency, we know // it's safe to materialize anywhere between the instruction within // DepInfo and the end of it's block. - ValuesPerBlock.push_back(AvailableValueInBlock::get(DepBB, - std::move(AV))); + ValuesPerBlock.push_back( + AvailableValueInBlock::get(DepBB, std::move(*AV))); } else { UnavailableBlocks.push_back(DepBB); } } - assert(NumDeps == ValuesPerBlock.size() + UnavailableBlocks.size() && + assert(Deps.size() == ValuesPerBlock.size() + UnavailableBlocks.size() && "post condition violation"); } @@ -1534,10 +1519,11 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, // to speculatively execute the load at that points. if (MustEnsureSafetyOfSpeculativeExecution) { if (CriticalEdgePred.size()) - if (!isSafeToSpeculativelyExecute(Load, LoadBB->getFirstNonPHI(), DT)) + if (!isSafeToSpeculativelyExecute(Load, LoadBB->getFirstNonPHI(), AC, DT)) return false; for (auto &PL : PredLoads) - if (!isSafeToSpeculativelyExecute(Load, PL.first->getTerminator(), DT)) + if (!isSafeToSpeculativelyExecute(Load, PL.first->getTerminator(), AC, + DT)) return false; } @@ -1871,11 +1857,10 @@ static bool impliesEquivalanceIfFalse(CmpInst* Cmp) { static bool hasUsersIn(Value *V, BasicBlock *BB) { - for (User *U : V->users()) - if (isa<Instruction>(U) && - cast<Instruction>(U)->getParent() == BB) - return true; - return false; + return llvm::any_of(V->users(), [BB](User *U) { + auto *I = dyn_cast<Instruction>(U); + return I && I->getParent() == BB; + }); } bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { @@ -1900,7 +1885,7 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { // after the found access or before the terminator if no such access is // found. 
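The hasUsersIn rewrite above is the usual raw-loop-to-any_of cleanup; the same shape in standard C++, with a plain struct standing in for llvm::User.

#include <algorithm>
#include <iostream>
#include <vector>

struct UserInfo {
  int Block; // stand-in for "the basic block this user lives in"
};

static bool hasUsersIn(const std::vector<UserInfo> &Users, int BB) {
  return std::any_of(Users.begin(), Users.end(),
                     [BB](const UserInfo &U) { return U.Block == BB; });
}

int main() {
  std::vector<UserInfo> Users{{1}, {2}, {3}};
  std::cout << hasUsersIn(Users, 2) << " " << hasUsersIn(Users, 5) << "\n"; // 1 0
  return 0;
}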
if (AL) { - for (auto &Acc : *AL) { + for (const auto &Acc : *AL) { if (auto *Current = dyn_cast<MemoryUseOrDef>(&Acc)) if (!Current->getMemoryInst()->comesBefore(NewS)) { FirstNonDom = Current; @@ -2042,9 +2027,8 @@ bool GVNPass::processLoad(LoadInst *L) { if (Dep.isNonLocal()) return processNonLocalLoad(L); - Value *Address = L->getPointerOperand(); // Only handle the local case below - if (!Dep.isDef() && !Dep.isClobber() && !isa<SelectInst>(Address)) { + if (!Dep.isLocal()) { // This might be a NonFuncLocal or an Unknown LLVM_DEBUG( // fast print dep, using operator<< on instruction is too slow. @@ -2053,25 +2037,24 @@ bool GVNPass::processLoad(LoadInst *L) { return false; } - AvailableValue AV; - if (AnalyzeLoadAvailability(L, Dep, Address, AV)) { - Value *AvailableValue = AV.MaterializeAdjustedValue(L, L, *this); + auto AV = AnalyzeLoadAvailability(L, Dep, L->getPointerOperand()); + if (!AV) + return false; - // Replace the load! - patchAndReplaceAllUsesWith(L, AvailableValue); - markInstructionForDeletion(L); - if (MSSAU) - MSSAU->removeMemoryAccess(L); - ++NumGVNLoad; - reportLoadElim(L, AvailableValue, ORE); - // Tell MDA to reexamine the reused pointer since we might have more - // information after forwarding it. - if (MD && AvailableValue->getType()->isPtrOrPtrVectorTy()) - MD->invalidateCachedPointerInfo(AvailableValue); - return true; - } + Value *AvailableValue = AV->MaterializeAdjustedValue(L, L, *this); - return false; + // Replace the load! + patchAndReplaceAllUsesWith(L, AvailableValue); + markInstructionForDeletion(L); + if (MSSAU) + MSSAU->removeMemoryAccess(L); + ++NumGVNLoad; + reportLoadElim(L, AvailableValue, ORE); + // Tell MDA to reexamine the reused pointer since we might have more + // information after forwarding it. + if (MD && AvailableValue->getType()->isPtrOrPtrVectorTy()) + MD->invalidateCachedPointerInfo(AvailableValue); + return true; } /// Return a pair the first field showing the value number of \p Exp and the @@ -2812,17 +2795,10 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) { NumWithout = 2; break; } - // It is not safe to do PRE when P->CurrentBlock is a loop backedge, and - // when CurInst has operand defined in CurrentBlock (so it may be defined - // by phi in the loop header). + // It is not safe to do PRE when P->CurrentBlock is a loop backedge. assert(BlockRPONumber.count(P) && BlockRPONumber.count(CurrentBlock) && "Invalid BlockRPONumber map."); - if (BlockRPONumber[P] >= BlockRPONumber[CurrentBlock] && - llvm::any_of(CurInst->operands(), [&](const Use &U) { - if (auto *Inst = dyn_cast<Instruction>(U.get())) - return Inst->getParent() == CurrentBlock; - return false; - })) { + if (BlockRPONumber[P] >= BlockRPONumber[CurrentBlock]) { NumWithout = 2; break; } diff --git a/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/llvm/lib/Transforms/Scalar/GVNHoist.cpp index 6cdc671ddb64..bbff497b7d92 100644 --- a/llvm/lib/Transforms/Scalar/GVNHoist.cpp +++ b/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -379,7 +379,7 @@ private: if (!Root) return; // Depth first walk on PDom tree to fill the CHIargs at each PDF. 
- for (auto Node : depth_first(Root)) { + for (auto *Node : depth_first(Root)) { BasicBlock *BB = Node->getBlock(); if (!BB) continue; @@ -435,7 +435,7 @@ private: continue; const VNType &VN = R; SmallPtrSet<BasicBlock *, 2> VNBlocks; - for (auto &I : V) { + for (const auto &I : V) { BasicBlock *BBI = I->getParent(); if (!hasEH(BBI)) VNBlocks.insert(BBI); @@ -563,7 +563,7 @@ bool GVNHoist::run(Function &F) { for (const BasicBlock *BB : depth_first(&F.getEntryBlock())) { DFSNumber[BB] = ++BBI; unsigned I = 0; - for (auto &Inst : *BB) + for (const auto &Inst : *BB) DFSNumber[&Inst] = ++I; } @@ -842,7 +842,7 @@ void GVNHoist::fillRenameStack(BasicBlock *BB, InValuesType &ValueBBs, void GVNHoist::fillChiArgs(BasicBlock *BB, OutValuesType &CHIBBs, GVNHoist::RenameStackType &RenameStack) { // For each *predecessor* (because Post-DOM) of BB check if it has a CHI - for (auto Pred : predecessors(BB)) { + for (auto *Pred : predecessors(BB)) { auto P = CHIBBs.find(Pred); if (P == CHIBBs.end()) { continue; diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp index 720b8e71fd56..5fb8a77051fb 100644 --- a/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -37,8 +37,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -288,7 +286,7 @@ public: ArrayRef<Value *> getValues() const { return Values; } bool areAllIncomingValuesSame() const { - return llvm::all_of(Values, [&](Value *V) { return V == Values[0]; }); + return llvm::all_equal(Values); } bool areAllIncomingValuesSameType() const { @@ -599,8 +597,8 @@ private: /// The main heuristic function. Analyze the set of instructions pointed to by /// LRI and return a candidate solution if these instructions can be sunk, or - /// None otherwise. - Optional<SinkingInstructionCandidate> analyzeInstructionForSinking( + /// std::nullopt otherwise. 
+ std::optional<SinkingInstructionCandidate> analyzeInstructionForSinking( LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum, ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents); @@ -634,15 +632,18 @@ private: if (PN->getIncomingValue(0) != PN) PN->replaceAllUsesWith(PN->getIncomingValue(0)); else - PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + PN->replaceAllUsesWith(PoisonValue::get(PN->getType())); PN->eraseFromParent(); } } }; -Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( - LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum, - ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents) { +std::optional<SinkingInstructionCandidate> +GVNSink::analyzeInstructionForSinking(LockstepReverseIterator &LRI, + unsigned &InstNum, + unsigned &MemoryInstNum, + ModelledPHISet &NeededPHIs, + SmallPtrSetImpl<Value *> &PHIContents) { auto Insts = *LRI; LLVM_DEBUG(dbgs() << " -- Analyzing instruction set: [\n"; for (auto *I : Insts) { @@ -654,7 +655,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( uint32_t N = VN.lookupOrAdd(I); LLVM_DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n"); if (N == ~0U) - return None; + return std::nullopt; VNums[N]++; } unsigned VNumToSink = @@ -662,7 +663,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( if (VNums[VNumToSink] == 1) // Can't sink anything! - return None; + return std::nullopt; // Now restrict the number of incoming blocks down to only those with // VNumToSink. @@ -677,7 +678,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( } for (auto *I : NewInsts) if (shouldAvoidSinkingInstruction(I)) - return None; + return std::nullopt; // If we've restricted the incoming blocks, restrict all needed PHIs also // to that set. @@ -715,7 +716,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( // V exists in this PHI, but the whole PHI is different to NewPHI // (else it would have been removed earlier). We cannot continue // because this isn't representable. - return None; + return std::nullopt; // Which operands need PHIs? // FIXME: If any of these fail, we should partition up the candidates to @@ -728,7 +729,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( return I->getNumOperands() != I0->getNumOperands(); }; if (any_of(NewInsts, hasDifferentNumOperands)) - return None; + return std::nullopt; for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) { ModelledPHI PHI(NewInsts, OpNum, ActivePreds); @@ -736,15 +737,15 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( continue; if (!canReplaceOperandWithVariable(I0, OpNum)) // We can 't create a PHI from this instruction! - return None; + return std::nullopt; if (NeededPHIs.count(PHI)) continue; if (!PHI.areAllIncomingValuesSameType()) - return None; + return std::nullopt; // Don't create indirect calls! The called value is the final operand. 
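The core counting step of analyzeInstructionForSinking above, with plain ints standing in for GVN value numbers: among the candidate instructions (one per incoming block) pick the value number shared by the most blocks, and give up if nothing is shared by at least two.

#include <iostream>
#include <map>
#include <utility>
#include <vector>

static std::pair<int, unsigned>
pickValueNumberToSink(const std::vector<int> &VNPerPred) {
  std::map<int, unsigned> Counts;
  for (int VN : VNPerPred)
    ++Counts[VN];
  std::pair<int, unsigned> Best{0, 0};
  for (const auto &KV : Counts)
    if (KV.second > Best.second)
      Best = KV;
  return Best; // {value number, number of predecessors sharing it}
}

int main() {
  auto Best = pickValueNumberToSink({42, 42, 7}); // two blocks agree on VN 42
  std::cout << "VN " << Best.first << " shared by " << Best.second
            << " blocks\n";
  auto None = pickValueNumberToSink({1, 2, 3}); // all distinct
  std::cout << (None.second <= 1 ? "nothing to sink\n" : "sinkable\n");
  return 0;
}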
if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OpNum == E - 1 && PHI.areAnyIncomingValuesConstant()) - return None; + return std::nullopt; NeededPHIs.reserve(NeededPHIs.size()); NeededPHIs.insert(PHI); diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp index af6062d142f0..abe0babc3f12 100644 --- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -42,6 +42,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -93,7 +94,7 @@ static Value *getCondition(Instruction *I) { } // Set the condition for \p I to \p NewCond. \p I can either be a guard or a -// conditional branch. +// conditional branch. static void setCondition(Instruction *I, Value *NewCond) { if (IntrinsicInst *GI = dyn_cast<IntrinsicInst>(I)) { assert(GI->getIntrinsicID() == Intrinsic::experimental_guard && @@ -116,6 +117,7 @@ class GuardWideningImpl { DominatorTree &DT; PostDominatorTree *PDT; LoopInfo &LI; + AssumptionCache &AC; MemorySSAUpdater *MSSAU; /// Together, these describe the region of interest. This might be all of @@ -261,7 +263,7 @@ class GuardWideningImpl { void widenGuard(Instruction *ToWiden, Value *NewCondition, bool InvertCondition) { Value *Result; - + widenCondCommon(getCondition(ToWiden), NewCondition, ToWiden, Result, InvertCondition); if (isGuardAsWidenableBranch(ToWiden)) { @@ -273,10 +275,10 @@ class GuardWideningImpl { public: explicit GuardWideningImpl(DominatorTree &DT, PostDominatorTree *PDT, - LoopInfo &LI, MemorySSAUpdater *MSSAU, - DomTreeNode *Root, - std::function<bool(BasicBlock*)> BlockFilter) - : DT(DT), PDT(PDT), LI(LI), MSSAU(MSSAU), Root(Root), + LoopInfo &LI, AssumptionCache &AC, + MemorySSAUpdater *MSSAU, DomTreeNode *Root, + std::function<bool(BasicBlock *)> BlockFilter) + : DT(DT), PDT(PDT), LI(LI), AC(AC), MSSAU(MSSAU), Root(Root), BlockFilter(BlockFilter) {} /// The entry point for this pass. @@ -468,7 +470,7 @@ bool GuardWideningImpl::isAvailableAt( if (!Inst || DT.dominates(Inst, Loc) || Visited.count(Inst)) return true; - if (!isSafeToSpeculativelyExecute(Inst, Loc, &DT) || + if (!isSafeToSpeculativelyExecute(Inst, Loc, &AC, &DT) || Inst->mayReadFromMemory()) return false; @@ -488,7 +490,7 @@ void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const { if (!Inst || DT.dominates(Inst, Loc)) return; - assert(isSafeToSpeculativelyExecute(Inst, Loc, &DT) && + assert(isSafeToSpeculativelyExecute(Inst, Loc, &AC, &DT) && !Inst->mayReadFromMemory() && "Should've checked with isAvailableAt!"); for (Value *Op : Inst->operands()) @@ -522,7 +524,8 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, // Given what we're doing here and the semantics of guards, it would // be correct to use a subset intersection, but that may be too // aggressive in cases we care about. 
- if (Optional<ConstantRange> Intersect = CR0.exactIntersectWith(CR1)) { + if (std::optional<ConstantRange> Intersect = + CR0.exactIntersectWith(CR1)) { APInt NewRHSAP; CmpInst::Predicate Pred; if (Intersect->getEquivalentICmp(Pred, NewRHSAP)) { @@ -764,11 +767,12 @@ PreservedAnalyses GuardWideningPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); auto *MSSAA = AM.getCachedResult<MemorySSAAnalysis>(F); std::unique_ptr<MemorySSAUpdater> MSSAU; if (MSSAA) MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAA->getMSSA()); - if (!GuardWideningImpl(DT, &PDT, LI, MSSAU ? MSSAU.get() : nullptr, + if (!GuardWideningImpl(DT, &PDT, LI, AC, MSSAU ? MSSAU.get() : nullptr, DT.getRootNode(), [](BasicBlock *) { return true; }) .run()) return PreservedAnalyses::all(); @@ -791,8 +795,10 @@ PreservedAnalyses GuardWideningPass::run(Loop &L, LoopAnalysisManager &AM, std::unique_ptr<MemorySSAUpdater> MSSAU; if (AR.MSSA) MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA); - if (!GuardWideningImpl(AR.DT, nullptr, AR.LI, MSSAU ? MSSAU.get() : nullptr, - AR.DT.getNode(RootBB), BlockFilter).run()) + if (!GuardWideningImpl(AR.DT, nullptr, AR.LI, AR.AC, + MSSAU ? MSSAU.get() : nullptr, AR.DT.getNode(RootBB), + BlockFilter) + .run()) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); @@ -814,12 +820,13 @@ struct GuardWideningLegacyPass : public FunctionPass { return false; auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); std::unique_ptr<MemorySSAUpdater> MSSAU; if (MSSAWP) MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); - return GuardWideningImpl(DT, &PDT, LI, MSSAU ? MSSAU.get() : nullptr, + return GuardWideningImpl(DT, &PDT, LI, AC, MSSAU ? MSSAU.get() : nullptr, DT.getRootNode(), [](BasicBlock *) { return true; }) .run(); @@ -848,6 +855,8 @@ struct LoopGuardWideningLegacyPass : public LoopPass { return false; auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); @@ -861,8 +870,9 @@ struct LoopGuardWideningLegacyPass : public LoopPass { auto BlockFilter = [&](BasicBlock *BB) { return BB == RootBB || L->contains(BB); }; - return GuardWideningImpl(DT, PDT, LI, MSSAU ? MSSAU.get() : nullptr, - DT.getNode(RootBB), BlockFilter).run(); + return GuardWideningImpl(DT, PDT, LI, AC, MSSAU ? 
MSSAU.get() : nullptr, + DT.getNode(RootBB), BlockFilter) + .run(); } void getAnalysisUsage(AnalysisUsage &AU) const override { diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 9698ed97379e..c834e51b5f29 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -26,7 +26,6 @@ #include "llvm/Transforms/Scalar/IndVarSimplify.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" @@ -196,7 +195,7 @@ static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) { bool isExact = false; // See if we can convert this to an int64_t uint64_t UIntVal; - if (APF.convertToInteger(makeMutableArrayRef(UIntVal), 64, true, + if (APF.convertToInteger(MutableArrayRef(UIntVal), 64, true, APFloat::rmTowardZero, &isExact) != APFloat::opOK || !isExact) return false; @@ -675,7 +674,7 @@ static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L) { // An IV counter must preserve its type. if (IncI->getNumOperands() == 2) break; - LLVM_FALLTHROUGH; + [[fallthrough]]; default: return nullptr; } @@ -789,7 +788,9 @@ static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, // If we can't analyze propagation through this instruction, just skip it // and transitive users. Safe as false is a conservative result. - if (!propagatesPoison(cast<Operator>(I)) && I != Root) + if (I != Root && !any_of(I->operands(), [&KnownPoison](const Use &U) { + return KnownPoison.contains(U) && propagatesPoison(U); + })) continue; if (KnownPoison.insert(I).second) @@ -1281,6 +1282,7 @@ bool IndVarSimplify::sinkUnusedInvariants(Loop *L) { MadeAnyChanges = true; ToMove->moveBefore(*ExitBlock, InsertPt); + SE->forgetValue(ToMove); if (Done) break; InsertPt = ToMove->getIterator(); } @@ -1291,23 +1293,32 @@ bool IndVarSimplify::sinkUnusedInvariants(Loop *L) { static void replaceExitCond(BranchInst *BI, Value *NewCond, SmallVectorImpl<WeakTrackingVH> &DeadInsts) { auto *OldCond = BI->getCondition(); + LLVM_DEBUG(dbgs() << "Replacing condition of loop-exiting branch " << *BI + << " with " << *NewCond << "\n"); BI->setCondition(NewCond); if (OldCond->use_empty()) DeadInsts.emplace_back(OldCond); } -static void foldExit(const Loop *L, BasicBlock *ExitingBB, bool IsTaken, - SmallVectorImpl<WeakTrackingVH> &DeadInsts) { +static Constant *createFoldedExitCond(const Loop *L, BasicBlock *ExitingBB, + bool IsTaken) { BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); auto *OldCond = BI->getCondition(); - auto *NewCond = - ConstantInt::get(OldCond->getType(), IsTaken ? ExitIfTrue : !ExitIfTrue); + return ConstantInt::get(OldCond->getType(), + IsTaken ? 
ExitIfTrue : !ExitIfTrue); +} + +static void foldExit(const Loop *L, BasicBlock *ExitingBB, bool IsTaken, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) { + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); + auto *NewCond = createFoldedExitCond(L, ExitingBB, IsTaken); replaceExitCond(BI, NewCond, DeadInsts); } static void replaceLoopPHINodesWithPreheaderValues( - LoopInfo *LI, Loop *L, SmallVectorImpl<WeakTrackingVH> &DeadInsts) { + LoopInfo *LI, Loop *L, SmallVectorImpl<WeakTrackingVH> &DeadInsts, + ScalarEvolution &SE) { assert(L->isLoopSimplifyForm() && "Should only do it in simplify form!"); auto *LoopPreheader = L->getLoopPreheader(); auto *LoopHeader = L->getHeader(); @@ -1316,6 +1327,7 @@ static void replaceLoopPHINodesWithPreheaderValues( auto *PreheaderIncoming = PN.getIncomingValueForBlock(LoopPreheader); for (User *U : PN.users()) Worklist.push_back(cast<Instruction>(U)); + SE.forgetValue(&PN); PN.replaceAllUsesWith(PreheaderIncoming); DeadInsts.emplace_back(&PN); } @@ -1342,56 +1354,41 @@ static void replaceLoopPHINodesWithPreheaderValues( } } -static void replaceWithInvariantCond( - const Loop *L, BasicBlock *ExitingBB, ICmpInst::Predicate InvariantPred, - const SCEV *InvariantLHS, const SCEV *InvariantRHS, SCEVExpander &Rewriter, - SmallVectorImpl<WeakTrackingVH> &DeadInsts) { +static Value * +createInvariantCond(const Loop *L, BasicBlock *ExitingBB, + const ScalarEvolution::LoopInvariantPredicate &LIP, + SCEVExpander &Rewriter) { + ICmpInst::Predicate InvariantPred = LIP.Pred; BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); Rewriter.setInsertPoint(BI); - auto *LHSV = Rewriter.expandCodeFor(InvariantLHS); - auto *RHSV = Rewriter.expandCodeFor(InvariantRHS); + auto *LHSV = Rewriter.expandCodeFor(LIP.LHS); + auto *RHSV = Rewriter.expandCodeFor(LIP.RHS); bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); if (ExitIfTrue) InvariantPred = ICmpInst::getInversePredicate(InvariantPred); IRBuilder<> Builder(BI); - auto *NewCond = Builder.CreateICmp(InvariantPred, LHSV, RHSV, - BI->getCondition()->getName()); - replaceExitCond(BI, NewCond, DeadInsts); + return Builder.CreateICmp(InvariantPred, LHSV, RHSV, + BI->getCondition()->getName()); } -static bool optimizeLoopExitWithUnknownExitCount( - const Loop *L, BranchInst *BI, BasicBlock *ExitingBB, - const SCEV *MaxIter, bool Inverted, bool SkipLastIter, - ScalarEvolution *SE, SCEVExpander &Rewriter, - SmallVectorImpl<WeakTrackingVH> &DeadInsts) { - ICmpInst::Predicate Pred; - Value *LHS, *RHS; - BasicBlock *TrueSucc, *FalseSucc; - if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), - m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) - return false; - - assert((L->contains(TrueSucc) != L->contains(FalseSucc)) && - "Not a loop exit!"); +static std::optional<Value *> +createReplacement(ICmpInst *ICmp, const Loop *L, BasicBlock *ExitingBB, + const SCEV *MaxIter, bool Inverted, bool SkipLastIter, + ScalarEvolution *SE, SCEVExpander &Rewriter) { + ICmpInst::Predicate Pred = ICmp->getPredicate(); + Value *LHS = ICmp->getOperand(0); + Value *RHS = ICmp->getOperand(1); // 'LHS pred RHS' should now mean that we stay in loop. - if (L->contains(FalseSucc)) - Pred = CmpInst::getInversePredicate(Pred); - - // If we are proving loop exit, invert the predicate. + auto *BI = cast<BranchInst>(ExitingBB->getTerminator()); if (Inverted) Pred = CmpInst::getInversePredicate(Pred); const SCEV *LHSS = SE->getSCEVAtScope(LHS, L); const SCEV *RHSS = SE->getSCEVAtScope(RHS, L); - // Can we prove it to be trivially true? 
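Aside on the SkipLastIter hunk above: it rewrites umin(a, b) - 1 into umin(a - 1, b - 1) because SCEV tends to leave the former unsimplified while it can still reason about the latter. The two forms agree whenever no operand is zero, which is exactly the "do not bother about unsigned wrap" caveat in the comment. A minimal plain-C++ check of that arithmetic (illustrative only, no SCEV involved):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  uint64_t a = 8, b = 5;
  // Non-zero operands: both forms give 4.
  assert(std::min(a, b) - 1 == std::min(a - 1, b - 1));
  // A zero operand wraps: min(0,5)-1 is UINT64_MAX, while min(~0ull,4) is 4.
  uint64_t z = 0;
  assert(std::min(z, b) - 1 != std::min(z - 1, b - 1));
  return 0;
}
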
- if (SE->isKnownPredicateAt(Pred, LHSS, RHSS, BI)) { - foldExit(L, ExitingBB, Inverted, DeadInsts); - return true; - } - // Further logic works for non-inverted condition only. - if (Inverted) - return false; + // Can we prove it to be trivially true or false? + if (auto EV = SE->evaluatePredicateAt(Pred, LHSS, RHSS, BI)) + return createFoldedExitCond(L, ExitingBB, /*IsTaken*/ !*EV); auto *ARTy = LHSS->getType(); auto *MaxIterTy = MaxIter->getType(); @@ -1406,24 +1403,135 @@ static bool optimizeLoopExitWithUnknownExitCount( } if (SkipLastIter) { - const SCEV *One = SE->getOne(MaxIter->getType()); - MaxIter = SE->getMinusSCEV(MaxIter, One); + // Semantically skip last iter is "subtract 1, do not bother about unsigned + // wrap". getLoopInvariantExitCondDuringFirstIterations knows how to deal + // with umin in a smart way, but umin(a, b) - 1 will likely not simplify. + // So we manually construct umin(a - 1, b - 1). + SmallVector<const SCEV *, 4> Elements; + if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter)) { + for (auto *Op : UMin->operands()) + Elements.push_back(SE->getMinusSCEV(Op, SE->getOne(Op->getType()))); + MaxIter = SE->getUMinFromMismatchedTypes(Elements); + } else + MaxIter = SE->getMinusSCEV(MaxIter, SE->getOne(MaxIter->getType())); } // Check if there is a loop-invariant predicate equivalent to our check. auto LIP = SE->getLoopInvariantExitCondDuringFirstIterations(Pred, LHSS, RHSS, L, BI, MaxIter); if (!LIP) - return false; + return std::nullopt; // Can we prove it to be trivially true? if (SE->isKnownPredicateAt(LIP->Pred, LIP->LHS, LIP->RHS, BI)) - foldExit(L, ExitingBB, Inverted, DeadInsts); + return createFoldedExitCond(L, ExitingBB, /*IsTaken*/ false); else - replaceWithInvariantCond(L, ExitingBB, LIP->Pred, LIP->LHS, LIP->RHS, - Rewriter, DeadInsts); + return createInvariantCond(L, ExitingBB, *LIP, Rewriter); +} - return true; +static bool optimizeLoopExitWithUnknownExitCount( + const Loop *L, BranchInst *BI, BasicBlock *ExitingBB, const SCEV *MaxIter, + bool SkipLastIter, ScalarEvolution *SE, SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) { + assert( + (L->contains(BI->getSuccessor(0)) != L->contains(BI->getSuccessor(1))) && + "Not a loop exit!"); + + // For branch that stays in loop by TRUE condition, go through AND. For branch + // that stays in loop by FALSE condition, go through OR. Both gives the + // similar logic: "stay in loop iff all conditions are true(false)". + bool Inverted = L->contains(BI->getSuccessor(1)); + SmallVector<ICmpInst *, 4> LeafConditions; + SmallVector<Value *, 4> Worklist; + SmallPtrSet<Value *, 4> Visited; + Value *OldCond = BI->getCondition(); + Visited.insert(OldCond); + Worklist.push_back(OldCond); + + auto GoThrough = [&](Value *V) { + Value *LHS = nullptr, *RHS = nullptr; + if (Inverted) { + if (!match(V, m_LogicalOr(m_Value(LHS), m_Value(RHS)))) + return false; + } else { + if (!match(V, m_LogicalAnd(m_Value(LHS), m_Value(RHS)))) + return false; + } + if (Visited.insert(LHS).second) + Worklist.push_back(LHS); + if (Visited.insert(RHS).second) + Worklist.push_back(RHS); + return true; + }; + + do { + Value *Curr = Worklist.pop_back_val(); + // Go through AND/OR conditions. Collect leaf ICMPs. We only care about + // those with one use, to avoid instruction duplication. 
+ if (Curr->hasOneUse()) + if (!GoThrough(Curr)) + if (auto *ICmp = dyn_cast<ICmpInst>(Curr)) + LeafConditions.push_back(ICmp); + } while (!Worklist.empty()); + + // If the current basic block has the same exit count as the whole loop, and + // it consists of multiple icmp's, try to collect all icmp's that give exact + // same exit count. For all other icmp's, we could use one less iteration, + // because their value on the last iteration doesn't really matter. + SmallPtrSet<ICmpInst *, 4> ICmpsFailingOnLastIter; + if (!SkipLastIter && LeafConditions.size() > 1 && + SE->getExitCount(L, ExitingBB, + ScalarEvolution::ExitCountKind::SymbolicMaximum) == + MaxIter) + for (auto *ICmp : LeafConditions) { + auto EL = SE->computeExitLimitFromCond(L, ICmp, Inverted, + /*ControlsExit*/ false); + auto *ExitMax = EL.SymbolicMaxNotTaken; + if (isa<SCEVCouldNotCompute>(ExitMax)) + continue; + // They could be of different types (specifically this happens after + // IV widening). + auto *WiderType = + SE->getWiderType(ExitMax->getType(), MaxIter->getType()); + auto *WideExitMax = SE->getNoopOrZeroExtend(ExitMax, WiderType); + auto *WideMaxIter = SE->getNoopOrZeroExtend(MaxIter, WiderType); + if (WideExitMax == WideMaxIter) + ICmpsFailingOnLastIter.insert(ICmp); + } + + bool Changed = false; + for (auto *OldCond : LeafConditions) { + // Skip last iteration for this icmp under one of two conditions: + // - We do it for all conditions; + // - There is another ICmp that would fail on last iter, so this one doesn't + // really matter. + bool OptimisticSkipLastIter = SkipLastIter; + if (!OptimisticSkipLastIter) { + if (ICmpsFailingOnLastIter.size() > 1) + OptimisticSkipLastIter = true; + else if (ICmpsFailingOnLastIter.size() == 1) + OptimisticSkipLastIter = !ICmpsFailingOnLastIter.count(OldCond); + } + if (auto Replaced = + createReplacement(OldCond, L, ExitingBB, MaxIter, Inverted, + OptimisticSkipLastIter, SE, Rewriter)) { + Changed = true; + auto *NewCond = *Replaced; + if (auto *NCI = dyn_cast<Instruction>(NewCond)) { + NCI->setName(OldCond->getName() + ".first_iter"); + NCI->moveBefore(cast<Instruction>(OldCond)); + } + LLVM_DEBUG(dbgs() << "Unknown exit count: Replacing " << *OldCond + << " with " << *NewCond << "\n"); + assert(OldCond->hasOneUse() && "Must be!"); + OldCond->replaceAllUsesWith(NewCond); + DeadInsts.push_back(OldCond); + // Make sure we no longer consider this condition as failing on last + // iteration. + ICmpsFailingOnLastIter.erase(OldCond); + } + } + return Changed; } bool IndVarSimplify::canonicalizeExitCondition(Loop *L) { @@ -1587,7 +1695,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { // unconditional exit, we can still replace header phis with their // preheader value. if (!L->contains(BI->getSuccessor(CI->isNullValue()))) - replaceLoopPHINodesWithPreheaderValues(LI, L, DeadInsts); + replaceLoopPHINodesWithPreheaderValues(LI, L, DeadInsts, *SE); return true; } @@ -1598,8 +1706,8 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { return false; // Get a symbolic upper bound on the loop backedge taken count. - const SCEV *MaxExitCount = SE->getSymbolicMaxBackedgeTakenCount(L); - if (isa<SCEVCouldNotCompute>(MaxExitCount)) + const SCEV *MaxBECount = SE->getSymbolicMaxBackedgeTakenCount(L); + if (isa<SCEVCouldNotCompute>(MaxBECount)) return false; // Visit our exit blocks in order of dominance. 
We know from the fact that @@ -1625,22 +1733,37 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { bool Changed = false; bool SkipLastIter = false; - SmallSet<const SCEV*, 8> DominatingExitCounts; + const SCEV *CurrMaxExit = SE->getCouldNotCompute(); + auto UpdateSkipLastIter = [&](const SCEV *MaxExitCount) { + if (SkipLastIter || isa<SCEVCouldNotCompute>(MaxExitCount)) + return; + if (isa<SCEVCouldNotCompute>(CurrMaxExit)) + CurrMaxExit = MaxExitCount; + else + CurrMaxExit = SE->getUMinFromMismatchedTypes(CurrMaxExit, MaxExitCount); + // If the loop has more than 1 iteration, all further checks will be + // executed 1 iteration less. + if (CurrMaxExit == MaxBECount) + SkipLastIter = true; + }; + SmallSet<const SCEV *, 8> DominatingExactExitCounts; for (BasicBlock *ExitingBB : ExitingBlocks) { - const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); - if (isa<SCEVCouldNotCompute>(ExitCount)) { + const SCEV *ExactExitCount = SE->getExitCount(L, ExitingBB); + const SCEV *MaxExitCount = SE->getExitCount( + L, ExitingBB, ScalarEvolution::ExitCountKind::SymbolicMaximum); + if (isa<SCEVCouldNotCompute>(ExactExitCount)) { // Okay, we do not know the exit count here. Can we at least prove that it // will remain the same within iteration space? auto *BI = cast<BranchInst>(ExitingBB->getTerminator()); - auto OptimizeCond = [&](bool Inverted, bool SkipLastIter) { - return optimizeLoopExitWithUnknownExitCount( - L, BI, ExitingBB, MaxExitCount, Inverted, SkipLastIter, SE, - Rewriter, DeadInsts); + auto OptimizeCond = [&](bool SkipLastIter) { + return optimizeLoopExitWithUnknownExitCount(L, BI, ExitingBB, + MaxBECount, SkipLastIter, + SE, Rewriter, DeadInsts); }; // TODO: We might have proved that we can skip the last iteration for // this check. In this case, we only want to check the condition on the - // pre-last iteration (MaxExitCount - 1). However, there is a nasty + // pre-last iteration (MaxBECount - 1). However, there is a nasty // corner case: // // for (i = len; i != 0; i--) { ... check (i ult X) ... } @@ -1652,47 +1775,44 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { // // As a temporary solution, we query both last and pre-last iterations in // hope that we will be able to prove triviality for at least one of - // them. We can stop querying MaxExitCount for this case once SCEV - // understands that (MaxExitCount - 1) will not overflow here. - if (OptimizeCond(false, false) || OptimizeCond(true, false)) + // them. We can stop querying MaxBECount for this case once SCEV + // understands that (MaxBECount - 1) will not overflow here. + if (OptimizeCond(false)) + Changed = true; + else if (SkipLastIter && OptimizeCond(true)) Changed = true; - else if (SkipLastIter) - if (OptimizeCond(false, true) || OptimizeCond(true, true)) - Changed = true; + UpdateSkipLastIter(MaxExitCount); continue; } - if (MaxExitCount == ExitCount) - // If the loop has more than 1 iteration, all further checks will be - // executed 1 iteration less. - SkipLastIter = true; + UpdateSkipLastIter(ExactExitCount); // If we know we'd exit on the first iteration, rewrite the exit to // reflect this. This does not imply the loop must exit through this // exit; there may be an earlier one taken on the first iteration. // We know that the backedge can't be taken, so we replace all // the header PHIs with values coming from the preheader. 
- if (ExitCount->isZero()) { + if (ExactExitCount->isZero()) { foldExit(L, ExitingBB, true, DeadInsts); - replaceLoopPHINodesWithPreheaderValues(LI, L, DeadInsts); + replaceLoopPHINodesWithPreheaderValues(LI, L, DeadInsts, *SE); Changed = true; continue; } - assert(ExitCount->getType()->isIntegerTy() && - MaxExitCount->getType()->isIntegerTy() && + assert(ExactExitCount->getType()->isIntegerTy() && + MaxBECount->getType()->isIntegerTy() && "Exit counts must be integers"); Type *WiderType = - SE->getWiderType(MaxExitCount->getType(), ExitCount->getType()); - ExitCount = SE->getNoopOrZeroExtend(ExitCount, WiderType); - MaxExitCount = SE->getNoopOrZeroExtend(MaxExitCount, WiderType); - assert(MaxExitCount->getType() == ExitCount->getType()); + SE->getWiderType(MaxBECount->getType(), ExactExitCount->getType()); + ExactExitCount = SE->getNoopOrZeroExtend(ExactExitCount, WiderType); + MaxBECount = SE->getNoopOrZeroExtend(MaxBECount, WiderType); + assert(MaxBECount->getType() == ExactExitCount->getType()); // Can we prove that some other exit must be taken strictly before this // one? - if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT, - MaxExitCount, ExitCount)) { + if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT, MaxBECount, + ExactExitCount)) { foldExit(L, ExitingBB, false, DeadInsts); Changed = true; continue; @@ -1702,7 +1822,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { // find a duplicate, we've found an exit which would have exited on the // exiting iteration, but (from the visit order) strictly follows another // which does the same and is thus dead. - if (!DominatingExitCounts.insert(ExitCount).second) { + if (!DominatingExactExitCounts.insert(ExactExitCount).second) { foldExit(L, ExitingBB, false, DeadInsts); Changed = true; continue; diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 328615011ceb..52a4bc8a9f24 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -45,8 +45,6 @@ #include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -93,6 +91,7 @@ #include <cassert> #include <iterator> #include <limits> +#include <optional> #include <utility> #include <vector> @@ -211,9 +210,9 @@ public: /// Computes a range for the induction variable (IndVar) in which the range /// check is redundant and can be constant-folded away. The induction /// variable is not required to be the canonical {0,+,1} induction variable. - Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE, - const SCEVAddRecExpr *IndVar, - bool IsLatchSigned) const; + std::optional<Range> computeSafeIterationSpace(ScalarEvolution &SE, + const SCEVAddRecExpr *IndVar, + bool IsLatchSigned) const; /// Parse out a set of inductive range checks from \p BI and append them to \p /// Checks. 
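Much of the churn in these files (GVNSink, GuardWidening, IndVarSimplify, and the IRCE hunks that follow) is the mechanical port from llvm::Optional to std::optional. A minimal sketch of the mapping, using only the standard library and made-up names:

#include <optional>

// was: Optional<unsigned> parseWidth(bool Known);
std::optional<unsigned> parseWidth(bool Known) {
  if (!Known)
    return std::nullopt;            // was: return None;
  return 32u;
}

bool hasWideWidth() {
  std::optional<unsigned> W = parseWidth(true);
  if (!W)                           // has_value() checks still work unchanged
    return false;
  return *W > 16;                   // was: W.value() / W.getValue()
}

The None -> std::nullopt and .value() -> operator* rewrites throughout these hunks are this substitution applied file by file.
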
@@ -235,7 +234,7 @@ class InductiveRangeCheckElimination { LoopInfo &LI; using GetBFIFunc = - llvm::Optional<llvm::function_ref<llvm::BlockFrequencyInfo &()> >; + std::optional<llvm::function_ref<llvm::BlockFrequencyInfo &()>>; GetBFIFunc GetBFI; // Returns true if it is profitable to do a transform basing on estimation of @@ -245,7 +244,7 @@ class InductiveRangeCheckElimination { public: InductiveRangeCheckElimination(ScalarEvolution &SE, BranchProbabilityInfo *BPI, DominatorTree &DT, - LoopInfo &LI, GetBFIFunc GetBFI = None) + LoopInfo &LI, GetBFIFunc GetBFI = std::nullopt) : SE(SE), BPI(BPI), DT(DT), LI(LI), GetBFI(GetBFI) {} bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop); @@ -307,7 +306,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, case ICmpInst::ICMP_SLE: std::swap(LHS, RHS); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_SGE: IsSigned = true; if (match(RHS, m_ConstantInt<0>())) { @@ -318,7 +317,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, case ICmpInst::ICMP_SLT: std::swap(LHS, RHS); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_SGT: IsSigned = true; if (match(RHS, m_ConstantInt<-1>())) { @@ -335,7 +334,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, case ICmpInst::ICMP_ULT: std::swap(LHS, RHS); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_UGT: IsSigned = false; if (IsLoopInvariant(LHS)) { @@ -503,8 +502,8 @@ struct LoopStructure { return Result; } - static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &, Loop &, - const char *&); + static std::optional<LoopStructure> parseLoopStructure(ScalarEvolution &, + Loop &, const char *&); }; /// This class is used to constrain loops to run within a given iteration space. @@ -541,20 +540,20 @@ class LoopConstrainer { // Calculated subranges we restrict the iteration space of the main loop to. // See the implementation of `calculateSubRanges' for more details on how - // these fields are computed. `LowLimit` is None if there is no restriction - // on low end of the restricted iteration space of the main loop. `HighLimit` - // is None if there is no restriction on high end of the restricted iteration - // space of the main loop. + // these fields are computed. `LowLimit` is std::nullopt if there is no + // restriction on low end of the restricted iteration space of the main loop. + // `HighLimit` is std::nullopt if there is no restriction on high end of the + // restricted iteration space of the main loop. struct SubRanges { - Optional<const SCEV *> LowLimit; - Optional<const SCEV *> HighLimit; + std::optional<const SCEV *> LowLimit; + std::optional<const SCEV *> HighLimit; }; // Compute a safe set of limits for the main loop to run in -- effectively the // intersection of `Range' and the iteration space of the original loop. - // Return None if unable to compute the set of subranges. - Optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const; + // Return std::nullopt if unable to compute the set of subranges. + std::optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const; // Clone `OriginalLoop' and return the result in CLResult. 
The IR after // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- @@ -747,12 +746,12 @@ static bool isSafeIncreasingBound(const SCEV *Start, SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); } -Optional<LoopStructure> +std::optional<LoopStructure> LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, const char *&FailureReason) { if (!L.isLoopSimplifyForm()) { FailureReason = "loop not in LoopSimplify form"; - return None; + return std::nullopt; } BasicBlock *Latch = L.getLoopLatch(); @@ -760,25 +759,25 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) { FailureReason = "loop has already been cloned"; - return None; + return std::nullopt; } if (!L.isLoopExiting(Latch)) { FailureReason = "no loop latch"; - return None; + return std::nullopt; } BasicBlock *Header = L.getHeader(); BasicBlock *Preheader = L.getLoopPreheader(); if (!Preheader) { FailureReason = "no preheader"; - return None; + return std::nullopt; } BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); if (!LatchBr || LatchBr->isUnconditional()) { FailureReason = "latch terminator not conditional branch"; - return None; + return std::nullopt; } unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; @@ -786,13 +785,13 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition()); if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) { FailureReason = "latch terminator branch not conditional on integral icmp"; - return None; + return std::nullopt; } const SCEV *LatchCount = SE.getExitCount(&L, Latch); if (isa<SCEVCouldNotCompute>(LatchCount)) { FailureReason = "could not compute latch count"; - return None; + return std::nullopt; } ICmpInst::Predicate Pred = ICI->getPredicate(); @@ -811,7 +810,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, Pred = ICmpInst::getSwappedPredicate(Pred); } else { FailureReason = "no add recurrences in the icmp"; - return None; + return std::nullopt; } } @@ -845,20 +844,24 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, // induction variable satisfies some constraint. 
const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); + if (IndVarBase->getLoop() != &L) { + FailureReason = "LHS in cmp is not an AddRec for this loop"; + return std::nullopt; + } if (!IndVarBase->isAffine()) { FailureReason = "LHS in icmp not induction variable"; - return None; + return std::nullopt; } const SCEV* StepRec = IndVarBase->getStepRecurrence(SE); if (!isa<SCEVConstant>(StepRec)) { FailureReason = "LHS in icmp not induction variable"; - return None; + return std::nullopt; } ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue(); if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) { FailureReason = "LHS in icmp needs nsw for equality predicates"; - return None; + return std::nullopt; } assert(!StepCI->isZero() && "Zero step?"); @@ -921,19 +924,19 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, if (!FoundExpectedPred) { FailureReason = "expected icmp slt semantically, found something else"; - return None; + return std::nullopt; } IsSignedPredicate = ICmpInst::isSigned(Pred); if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { FailureReason = "unsigned latch conditions are explicitly prohibited"; - return None; + return std::nullopt; } if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred, LatchBrExitIdx, &L, SE)) { FailureReason = "Unsafe loop bounds"; - return None; + return std::nullopt; } if (LatchBrExitIdx == 0) { // We need to increase the right value unless we have already decreased @@ -984,7 +987,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, if (!FoundExpectedPred) { FailureReason = "expected icmp sgt semantically, found something else"; - return None; + return std::nullopt; } IsSignedPredicate = @@ -992,13 +995,13 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { FailureReason = "unsigned latch conditions are explicitly prohibited"; - return None; + return std::nullopt; } if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred, LatchBrExitIdx, &L, SE)) { FailureReason = "Unsafe bounds"; - return None; + return std::nullopt; } if (LatchBrExitIdx == 0) { @@ -1057,7 +1060,7 @@ static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty); } -Optional<LoopConstrainer::SubRanges> +std::optional<LoopConstrainer::SubRanges> LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); @@ -1065,9 +1068,9 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { // We only support wide range checks and narrow latches. 
if (!AllowNarrowLatchCondition && RTy != Ty) - return None; + return std::nullopt; if (RTy->getBitWidth() < Ty->getBitWidth()) - return None; + return std::nullopt; LoopConstrainer::SubRanges Result; @@ -1184,6 +1187,7 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, for (PHINode &PN : SBB->phis()) { Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB); PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB); + SE.forgetValue(&PN); } } } @@ -1408,7 +1412,7 @@ bool LoopConstrainer::run() { MainLoopPreheader = Preheader; bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; - Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); + std::optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); if (!MaybeSR) { LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n"); return false; @@ -1423,7 +1427,7 @@ bool LoopConstrainer::run() { Instruction *InsertPt = OriginalPreheader->getTerminator(); // It would have been better to make `PreLoop' and `PostLoop' - // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy + // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy // constructor. ClonedLoop PreLoop, PostLoop; bool NeedsPreLoop = @@ -1534,7 +1538,7 @@ bool LoopConstrainer::run() { auto NewBlocksEnd = std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); - addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd)); + addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd)); DT.recalculate(F); @@ -1575,17 +1579,20 @@ bool LoopConstrainer::run() { /// Computes and returns a range of values for the induction variable (IndVar) /// in which the range check can be safely elided. If it cannot compute such a -/// range, returns None. -Optional<InductiveRangeCheck::Range> -InductiveRangeCheck::computeSafeIterationSpace( - ScalarEvolution &SE, const SCEVAddRecExpr *IndVar, - bool IsLatchSigned) const { +/// range, returns std::nullopt. +std::optional<InductiveRangeCheck::Range> +InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, + const SCEVAddRecExpr *IndVar, + bool IsLatchSigned) const { // We can deal when types of latch check and range checks don't match in case // if latch check is more narrow. - auto *IVType = cast<IntegerType>(IndVar->getType()); - auto *RCType = cast<IntegerType>(getBegin()->getType()); + auto *IVType = dyn_cast<IntegerType>(IndVar->getType()); + auto *RCType = dyn_cast<IntegerType>(getBegin()->getType()); + // Do not work with pointer types. + if (!IVType || !RCType) + return std::nullopt; if (IVType->getBitWidth() > RCType->getBitWidth()) - return None; + return std::nullopt; // IndVar is of the form "A + B * I" (where "I" is the canonical induction // variable, that may or may not exist as a real llvm::Value in the loop) and // this inductive range check is a range check on the "C + D * I" ("C" is @@ -1607,19 +1614,19 @@ InductiveRangeCheck::computeSafeIterationSpace( // to deal with overflown values. 
if (!IndVar->isAffine()) - return None; + return std::nullopt; const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned); const SCEVConstant *B = dyn_cast<SCEVConstant>( NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned)); if (!B) - return None; + return std::nullopt; assert(!B->isZero() && "Recurrence with zero step?"); const SCEV *C = getBegin(); const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep()); if (D != B) - return None; + return std::nullopt; assert(!D->getValue()->isZero() && "Recurrence with zero step?"); unsigned BitWidth = RCType->getBitWidth(); @@ -1702,15 +1709,15 @@ InductiveRangeCheck::computeSafeIterationSpace( return InductiveRangeCheck::Range(Begin, End); } -static Optional<InductiveRangeCheck::Range> +static std::optional<InductiveRangeCheck::Range> IntersectSignedRange(ScalarEvolution &SE, - const Optional<InductiveRangeCheck::Range> &R1, + const std::optional<InductiveRangeCheck::Range> &R1, const InductiveRangeCheck::Range &R2) { if (R2.isEmpty(SE, /* IsSigned */ true)) - return None; + return std::nullopt; if (!R1) return R2; - auto &R1Value = R1.value(); + auto &R1Value = *R1; // We never return empty ranges from this function, and R1 is supposed to be // a result of intersection. Thus, R1 is never empty. assert(!R1Value.isEmpty(SE, /* IsSigned */ true) && @@ -1719,27 +1726,27 @@ IntersectSignedRange(ScalarEvolution &SE, // TODO: we could widen the smaller range and have this work; but for now we // bail out to keep things simple. if (R1Value.getType() != R2.getType()) - return None; + return std::nullopt; const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin()); const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd()); - // If the resulting range is empty, just return None. + // If the resulting range is empty, just return std::nullopt. auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd); if (Ret.isEmpty(SE, /* IsSigned */ true)) - return None; + return std::nullopt; return Ret; } -static Optional<InductiveRangeCheck::Range> +static std::optional<InductiveRangeCheck::Range> IntersectUnsignedRange(ScalarEvolution &SE, - const Optional<InductiveRangeCheck::Range> &R1, + const std::optional<InductiveRangeCheck::Range> &R1, const InductiveRangeCheck::Range &R2) { if (R2.isEmpty(SE, /* IsSigned */ false)) - return None; + return std::nullopt; if (!R1) return R2; - auto &R1Value = R1.value(); + auto &R1Value = *R1; // We never return empty ranges from this function, and R1 is supposed to be // a result of intersection. Thus, R1 is never empty. assert(!R1Value.isEmpty(SE, /* IsSigned */ false) && @@ -1748,15 +1755,15 @@ IntersectUnsignedRange(ScalarEvolution &SE, // TODO: we could widen the smaller range and have this work; but for now we // bail out to keep things simple. if (R1Value.getType() != R2.getType()) - return None; + return std::nullopt; const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin()); const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd()); - // If the resulting range is empty, just return None. + // If the resulting range is empty, just return std::nullopt. 
auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd); if (Ret.isEmpty(SE, /* IsSigned */ false)) - return None; + return std::nullopt; return Ret; } @@ -1898,7 +1905,7 @@ bool InductiveRangeCheckElimination::run( LLVMContext &Context = Preheader->getContext(); SmallVector<InductiveRangeCheck, 16> RangeChecks; - for (auto BBI : L->getBlocks()) + for (auto *BBI : L->getBlocks()) if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI, RangeChecks); @@ -1920,7 +1927,7 @@ bool InductiveRangeCheckElimination::run( PrintRecognizedRangeChecks(errs()); const char *FailureReason = nullptr; - Optional<LoopStructure> MaybeLoopStructure = + std::optional<LoopStructure> MaybeLoopStructure = LoopStructure::parseLoopStructure(SE, *L, FailureReason); if (!MaybeLoopStructure) { LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: " @@ -1933,7 +1940,7 @@ bool InductiveRangeCheckElimination::run( const SCEVAddRecExpr *IndVar = cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep))); - Optional<InductiveRangeCheck::Range> SafeIterRange; + std::optional<InductiveRangeCheck::Range> SafeIterRange; Instruction *ExprInsertPt = Preheader->getTerminator(); SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate; @@ -1949,13 +1956,12 @@ bool InductiveRangeCheckElimination::run( auto Result = IRC.computeSafeIterationSpace(SE, IndVar, LS.IsSignedPredicate); if (Result) { - auto MaybeSafeIterRange = - IntersectRange(SE, SafeIterRange, Result.value()); + auto MaybeSafeIterRange = IntersectRange(SE, SafeIterRange, *Result); if (MaybeSafeIterRange) { - assert(!MaybeSafeIterRange.value().isEmpty(SE, LS.IsSignedPredicate) && + assert(!MaybeSafeIterRange->isEmpty(SE, LS.IsSignedPredicate) && "We should never return empty ranges!"); RangeChecksToEliminate.push_back(IRC); - SafeIterRange = MaybeSafeIterRange.value(); + SafeIterRange = *MaybeSafeIterRange; } } } @@ -1963,7 +1969,7 @@ bool InductiveRangeCheckElimination::run( if (!SafeIterRange) return false; - LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, SafeIterRange.value()); + LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, *SafeIterRange); bool Changed = LC.run(); if (Changed) { diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index 5eefde2e37a1..114738a35fd1 100644 --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -547,6 +547,7 @@ static Value *operandWithNewAddressSpaceOrCreateUndef( cast<PointerType>(Operand->getType()), NewAS); auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy); NewI->insertBefore(Inst); + NewI->setDebugLoc(Inst->getDebugLoc()); return NewI; } @@ -774,6 +775,7 @@ Value *InferAddressSpacesImpl::cloneValueWithNewAddressSpace( if (NewI->getParent() == nullptr) { NewI->insertBefore(I); NewI->takeName(I); + NewI->setDebugLoc(I->getDebugLoc()); } } return NewV; diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index b31eab50c5ec..f41eaed2e3e7 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -14,7 +14,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -54,6 +53,7 @@ #include "llvm/IR/Module.h" 
#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" @@ -99,6 +99,11 @@ ImplicationSearchThreshold( "condition to use to thread over a weaker condition"), cl::init(3), cl::Hidden); +static cl::opt<unsigned> PhiDuplicateThreshold( + "jump-threading-phi-threshold", + cl::desc("Max PHIs in BB to duplicate for jump threading"), cl::init(76), + cl::Hidden); + static cl::opt<bool> PrintLVIAfterJumpThreading( "print-lvi-after-jump-threading", cl::desc("Print the LazyValueInfo cache after JumpThreading"), cl::init(false), @@ -216,7 +221,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { return; uint64_t TrueWeight, FalseWeight; - if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight)) + if (!extractBranchWeights(*CondBr, TrueWeight, FalseWeight)) return; if (TrueWeight + FalseWeight == 0) @@ -279,7 +284,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { // With PGO, this can be used to refine even existing profile data with // context information. This needs to be done after more performance // testing. - if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight)) + if (extractBranchWeights(*PredBr, PredTrueWeight, PredFalseWeight)) continue; // We can not infer anything useful when BP >= 50%, because BP is the @@ -346,7 +351,7 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, std::unique_ptr<BlockFrequencyInfo> BFI; std::unique_ptr<BranchProbabilityInfo> BPI; if (F.hasProfileData()) { - LoopInfo LI{DominatorTree(F)}; + LoopInfo LI{DT}; BPI.reset(new BranchProbabilityInfo(F, LI, &TLI)); BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } @@ -517,8 +522,23 @@ static unsigned getJumpThreadDuplicationCost(const TargetTransformInfo *TTI, Instruction *StopAt, unsigned Threshold) { assert(StopAt->getParent() == BB && "Not an instruction from proper BB?"); + + // Do not duplicate the BB if it has a lot of PHI nodes. + // If a threadable chain is too long then the number of PHI nodes can add up, + // leading to a substantial increase in compile time when rewriting the SSA. + unsigned PhiCount = 0; + Instruction *FirstNonPHI = nullptr; + for (Instruction &I : *BB) { + if (!isa<PHINode>(&I)) { + FirstNonPHI = &I; + break; + } + if (++PhiCount > PhiDuplicateThreshold) + return ~0U; + } + /// Ignore PHI nodes, these will be flattened when duplication happens. - BasicBlock::const_iterator I(BB->getFirstNonPHI()); + BasicBlock::const_iterator I(FirstNonPHI); // FIXME: THREADING will delete values that are just used to compute the // branch, so they shouldn't count against the duplication cost. @@ -560,8 +580,8 @@ static unsigned getJumpThreadDuplicationCost(const TargetTransformInfo *TTI, if (CI->cannotDuplicate() || CI->isConvergent()) return ~0U; - if (TTI->getUserCost(&*I, TargetTransformInfo::TCK_SizeAndLatency) - == TargetTransformInfo::TCC_Free) + if (TTI->getInstructionCost(&*I, TargetTransformInfo::TCK_SizeAndLatency) == + TargetTransformInfo::TCC_Free) continue; // All other instructions count for at least one unit. @@ -653,22 +673,25 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( Instruction *I = dyn_cast<Instruction>(V); if (!I || I->getParent() != BB) { - // Okay, if this is a live-in value, see if it has a known value at the end - // of any of our predecessors. - // - // FIXME: This should be an edge property, not a block end property. 
- /// TODO: Per PR2563, we could infer value range information about a - /// predecessor based on its terminator. - // - // FIXME: change this to use the more-rich 'getPredicateOnEdge' method if - // "I" is a non-local compare-with-a-constant instruction. This would be - // able to handle value inequalities better, for example if the compare is - // "X < 4" and "X < 3" is known true but "X < 4" itself is not available. - // Perhaps getConstantOnEdge should be smart enough to do this? + // Okay, if this is a live-in value, see if it has a known value at the any + // edge from our predecessors. for (BasicBlock *P : predecessors(BB)) { + using namespace PatternMatch; // If the value is known by LazyValueInfo to be a constant in a // predecessor, use that information to try to thread this block. Constant *PredCst = LVI->getConstantOnEdge(V, P, BB, CxtI); + // If I is a non-local compare-with-constant instruction, use more-rich + // 'getPredicateOnEdge' method. This would be able to handle value + // inequalities better, for example if the compare is "X < 4" and "X < 3" + // is known true but "X < 4" itself is not available. + CmpInst::Predicate Pred; + Value *Val; + Constant *Cst; + if (!PredCst && match(V, m_Cmp(Pred, m_Value(Val), m_Constant(Cst)))) { + auto Res = LVI->getPredicateOnEdge(Pred, Val, Cst, P, BB, CxtI); + if (Res != LazyValueInfo::Unknown) + PredCst = ConstantInt::getBool(V->getContext(), Res); + } if (Constant *KC = getKnownConstant(PredCst, Preference)) Result.emplace_back(KC, P); } @@ -1250,7 +1273,7 @@ bool JumpThreadingPass::processImpliedCondition(BasicBlock *BB) { return false; bool CondIsTrue = PBI->getSuccessor(0) == CurrentBB; - Optional<bool> Implication = + std::optional<bool> Implication = isImpliedCondition(PBI->getCondition(), Cond, DL, CondIsTrue); // If the branch condition of BB (which is Cond) and CurrentPred are @@ -1908,7 +1931,7 @@ bool JumpThreadingPass::processBranchOnXOR(BinaryOperator *BO) { // If all preds provide undef, just nuke the xor, because it is undef too. BO->replaceAllUsesWith(UndefValue::get(BO->getType())); BO->eraseFromParent(); - } else if (SplitVal->isZero()) { + } else if (SplitVal->isZero() && BO != BO->getOperand(isLHS)) { // If all preds provide 0, replace the xor with the other input. BO->replaceAllUsesWith(BO->getOperand(isLHS)); BO->eraseFromParent(); @@ -2060,6 +2083,30 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, // block, evaluate them to account for entry from PredBB. DenseMap<Instruction *, Value *> ValueMapping; + // Retargets llvm.dbg.value to any renamed variables. + auto RetargetDbgValueIfPossible = [&](Instruction *NewInst) -> bool { + auto DbgInstruction = dyn_cast<DbgValueInst>(NewInst); + if (!DbgInstruction) + return false; + + SmallSet<std::pair<Value *, Value *>, 16> OperandsToRemap; + for (auto DbgOperand : DbgInstruction->location_ops()) { + auto DbgOperandInstruction = dyn_cast<Instruction>(DbgOperand); + if (!DbgOperandInstruction) + continue; + + auto I = ValueMapping.find(DbgOperandInstruction); + if (I != ValueMapping.end()) { + OperandsToRemap.insert( + std::pair<Value *, Value *>(DbgOperand, I->second)); + } + } + + for (auto &[OldOp, MappedOp] : OperandsToRemap) + DbgInstruction->replaceVariableLocationOp(OldOp, MappedOp); + return true; + }; + // Clone the phi nodes of the source basic block into NewBB. The resulting // phi nodes are trivial since NewBB only has one predecessor, but SSAUpdater // might need to rewrite the operand of the cloned phi. 
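The JumpThreading change above turns the old FIXME into an actual query: when the live-in value is a compare against a constant, LazyValueInfo::getPredicateOnEdge can decide it per predecessor edge even though no constant for the compare itself is available. The worked case from the deleted comment, reduced to plain range arithmetic (toy code, not the LVI API):

#include <cassert>
#include <cstdint>
#include <optional>

// If X is known to lie in the half-open range [Lo, Hi) on an edge, then the
// unsigned compare "X < C" folds to true when Hi <= C and to false when
// Lo >= C; otherwise it stays unknown, as LazyValueInfo would report.
std::optional<bool> evalULTOnEdge(uint64_t Lo, uint64_t Hi, uint64_t C) {
  if (Hi <= C)
    return true;
  if (Lo >= C)
    return false;
  return std::nullopt;
}

int main() {
  // "X < 3" known true on the edge means X is in [0, 3), so "X < 4" folds
  // to true there even though "X < 4" itself was never computed.
  assert(evalULTOnEdge(0, 3, 4) == std::optional<bool>(true));
  assert(!evalULTOnEdge(0, 8, 4).has_value());
  return 0;
}
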
@@ -2084,10 +2131,13 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, for (; BI != BE; ++BI) { Instruction *New = BI->clone(); New->setName(BI->getName()); - NewBB->getInstList().push_back(New); + New->insertInto(NewBB, NewBB->end()); ValueMapping[&*BI] = New; adaptNoAliasScopes(New, ClonedScopes, Context); + if (RetargetDbgValueIfPossible(New)) + continue; + // Remap operands to patch up intra-block references. for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) { @@ -2437,7 +2487,7 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB, // update the edge weight of the result of splitting predecessors. DenseMap<BasicBlock *, BlockFrequency> FreqMap; if (HasProfileData) - for (auto Pred : Preds) + for (auto *Pred : Preds) FreqMap.insert(std::make_pair( Pred, BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB))); @@ -2452,10 +2502,10 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB, std::vector<DominatorTree::UpdateType> Updates; Updates.reserve((2 * Preds.size()) + NewBBs.size()); - for (auto NewBB : NewBBs) { + for (auto *NewBB : NewBBs) { BlockFrequency NewBBFreq(0); Updates.push_back({DominatorTree::Insert, NewBB, BB}); - for (auto Pred : predecessors(NewBB)) { + for (auto *Pred : predecessors(NewBB)) { Updates.push_back({DominatorTree::Delete, Pred, BB}); Updates.push_back({DominatorTree::Insert, Pred, NewBB}); if (HasProfileData) // Update frequencies between Pred -> NewBB. @@ -2472,18 +2522,7 @@ BasicBlock *JumpThreadingPass::splitBlockPreds(BasicBlock *BB, bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) { const Instruction *TI = BB->getTerminator(); assert(TI->getNumSuccessors() > 1 && "not a split"); - - MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); - if (!WeightsNode) - return false; - - MDString *MDName = cast<MDString>(WeightsNode->getOperand(0)); - if (MDName->getString() != "branch_weights") - return false; - - // Ensure there are weights for all of the successors. Note that the first - // operand to the metadata node is a name, not a weight. - return WeightsNode->getNumOperands() == TI->getNumSuccessors() + 1; + return hasValidBranchWeightMD(*TI); } /// Update the block frequency of BB and branch weight and the metadata on the @@ -2677,7 +2716,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( if (New) { // Otherwise, insert the new instruction into the block. New->setName(BI->getName()); - PredBB->getInstList().insert(OldPredBranch->getIterator(), New); + New->insertInto(PredBB, OldPredBranch->getIterator()); // Update Dominance from simplified New instruction operands. for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) if (BasicBlock *SuccBB = dyn_cast<BasicBlock>(New->getOperand(i))) @@ -2731,12 +2770,30 @@ void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB, BB->getParent(), BB); // Move the unconditional branch to NewBB. PredTerm->removeFromParent(); - NewBB->getInstList().insert(NewBB->end(), PredTerm); + PredTerm->insertInto(NewBB, NewBB->end()); // Create a conditional branch and update PHI nodes. auto *BI = BranchInst::Create(NewBB, BB, SI->getCondition(), Pred); BI->applyMergedLocation(PredTerm->getDebugLoc(), SI->getDebugLoc()); + BI->copyMetadata(*SI, {LLVMContext::MD_prof}); SIUse->setIncomingValue(Idx, SI->getFalseValue()); SIUse->addIncoming(SI->getTrueValue(), NewBB); + // Set the block frequency of NewBB. 
+ if (HasProfileData) { + uint64_t TrueWeight, FalseWeight; + if (extractBranchWeights(*SI, TrueWeight, FalseWeight) && + (TrueWeight + FalseWeight) != 0) { + SmallVector<BranchProbability, 2> BP; + BP.emplace_back(BranchProbability::getBranchProbability( + TrueWeight, TrueWeight + FalseWeight)); + BP.emplace_back(BranchProbability::getBranchProbability( + FalseWeight, TrueWeight + FalseWeight)); + BPI->setEdgeProbability(Pred, BP); + } + + auto NewBBFreq = + BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, NewBB); + BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + } // The select is now dead. SI->eraseFromParent(); diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index f54264b1dca6..2865dece8723 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -42,6 +42,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GuardUtils.h" @@ -75,6 +76,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -86,7 +88,6 @@ using namespace llvm; namespace llvm { -class BlockFrequencyInfo; class LPMUpdater; } // namespace llvm @@ -98,7 +99,9 @@ STATISTIC(NumSunk, "Number of instructions sunk out of loop"); STATISTIC(NumHoisted, "Number of instructions hoisted out of loop"); STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk"); STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk"); -STATISTIC(NumPromoted, "Number of memory locations promoted to registers"); +STATISTIC(NumPromotionCandidates, "Number of promotion candidates"); +STATISTIC(NumLoadPromoted, "Number of load-only promotions"); +STATISTIC(NumLoadStorePromoted, "Number of load and store promotions"); /// Memory promotion is enabled by default. 
static cl::opt<bool> @@ -109,6 +112,10 @@ static cl::opt<bool> ControlFlowHoisting( "licm-control-flow-hoisting", cl::Hidden, cl::init(false), cl::desc("Enable control flow (and PHI) hoisting in LICM")); +static cl::opt<bool> + SingleThread("licm-force-thread-model-single", cl::Hidden, cl::init(false), + cl::desc("Force thread model single in LICM pass")); + static cl::opt<uint32_t> MaxNumUsesTraversed( "licm-max-num-uses-traversed", cl::Hidden, cl::init(8), cl::desc("Max num uses visited for identifying load " @@ -147,14 +154,13 @@ static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, MemorySSAUpdater &MSSAU, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE); static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, - BlockFrequencyInfo *BFI, const Loop *CurLoop, - ICFLoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU, - OptimizationRemarkEmitter *ORE); + const Loop *CurLoop, ICFLoopSafetyInfo *SafetyInfo, + MemorySSAUpdater &MSSAU, OptimizationRemarkEmitter *ORE); static bool isSafeToExecuteUnconditionally( Instruction &Inst, const DominatorTree *DT, const TargetLibraryInfo *TLI, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE, const Instruction *CtxI, - bool AllowSpeculation); + AssumptionCache *AC, bool AllowSpeculation); static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU, Loop *CurLoop, Instruction &I, SinkAndHoistLICMFlags &Flags); @@ -173,13 +179,15 @@ static void moveInstructionBefore(Instruction &I, Instruction &Dest, static void foreachMemoryAccess(MemorySSA *MSSA, Loop *L, function_ref<void(Instruction *)> Fn); -static SmallVector<SmallSetVector<Value *, 8>, 0> +using PointersAndHasReadsOutsideSet = + std::pair<SmallSetVector<Value *, 8>, bool>; +static SmallVector<PointersAndHasReadsOutsideSet, 0> collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L); namespace { struct LoopInvariantCodeMotion { bool runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, DominatorTree *DT, - BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, + AssumptionCache *AC, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, ScalarEvolution *SE, MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, bool LoopNestMode = false); @@ -214,12 +222,10 @@ struct LegacyLICMPass : public LoopPass { LLVM_DEBUG(dbgs() << "Perform LICM on Loop with header at block " << L->getHeader()->getNameOrAsOperand() << "\n"); + Function *F = L->getHeader()->getParent(); + auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - bool hasProfileData = L->getHeader()->getParent()->hasProfileData(); - BlockFrequencyInfo *BFI = - hasProfileData ? &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() - : nullptr; // For the old PM, we can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations // but ORE cannot be preserved (see comment before the pass definition). 
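One structural change running through the LICM hunks: collectPromotionCandidates now returns each must-alias pointer set paired with a flag recording whether the location is also read outside that set, and runOnLoop unpacks the pair with a structured binding before handing both pieces to promoteLoopAccessesToScalars. A toy sketch of that shape, with invented stand-in types rather than LLVM's:

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Stand-in for std::pair<SmallSetVector<Value *, 8>, bool>.
using PointersAndHasReadsOutsideSet = std::pair<std::vector<std::string>, bool>;

std::vector<PointersAndHasReadsOutsideSet> collectCandidates() {
  return {{{"%p", "%q"}, /*HasReadsOutsideSet=*/false},
          {{"%r"}, /*HasReadsOutsideSet=*/true}};
}

std::size_t visitCandidates() {
  std::size_t Visited = 0;
  // Mirrors the structured-binding loop in runOnLoop: both pieces of each
  // candidate travel together into the promotion routine.
  for (const auto &[Pointers, HasReadsOutsideSet] : collectCandidates()) {
    (void)HasReadsOutsideSet; // forwarded to promotion in the real code
    Visited += Pointers.size();
  }
  return Visited;
}
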
@@ -227,11 +233,10 @@ struct LegacyLICMPass : public LoopPass { return LICM.runOnLoop( L, &getAnalysis<AAResultsWrapperPass>().getAAResults(), &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), - &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), BFI, - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( - *L->getHeader()->getParent()), - &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *L->getHeader()->getParent()), + &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(*F), + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(*F), + &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*F), SE ? &SE->getSE() : nullptr, MSSA, &ORE); } @@ -245,6 +250,7 @@ struct LegacyLICMPass : public LoopPass { AU.addRequired<MemorySSAWrapperPass>(); AU.addPreserved<MemorySSAWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<AssumptionCacheTracker>(); getLoopAnalysisUsage(AU); LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); AU.addPreserved<LazyBlockFrequencyInfoPass>(); @@ -259,7 +265,8 @@ private: PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { if (!AR.MSSA) - report_fatal_error("LICM requires MemorySSA (loop-mssa)"); + report_fatal_error("LICM requires MemorySSA (loop-mssa)", + /*GenCrashDiag*/false); // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations @@ -268,7 +275,7 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, LoopInvariantCodeMotion LICM(Opts.MssaOptCap, Opts.MssaNoAccForPromotionCap, Opts.AllowSpeculation); - if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, AR.BFI, &AR.TLI, &AR.TTI, + if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.AC, &AR.TLI, &AR.TTI, &AR.SE, AR.MSSA, &ORE)) return PreservedAnalyses::all(); @@ -295,7 +302,8 @@ PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { if (!AR.MSSA) - report_fatal_error("LNICM requires MemorySSA (loop-mssa)"); + report_fatal_error("LNICM requires MemorySSA (loop-mssa)", + /*GenCrashDiag*/false); // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations @@ -306,7 +314,7 @@ PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM, Opts.AllowSpeculation); Loop &OutermostLoop = LN.getOutermostLoop(); - bool Changed = LICM.runOnLoop(&OutermostLoop, &AR.AA, &AR.LI, &AR.DT, AR.BFI, + bool Changed = LICM.runOnLoop(&OutermostLoop, &AR.AA, &AR.LI, &AR.DT, &AR.AC, &AR.TLI, &AR.TTI, &AR.SE, AR.MSSA, &ORE, true); if (!Changed) @@ -382,11 +390,13 @@ llvm::SinkAndHoistLICMFlags::SinkAndHoistLICMFlags( /// Hoist expressions out of the specified loop. Note, alias info for inner /// loop is not preserved so it is not a good idea to run LICM multiple /// times on one loop. 
-bool LoopInvariantCodeMotion::runOnLoop( - Loop *L, AAResults *AA, LoopInfo *LI, DominatorTree *DT, - BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, - ScalarEvolution *SE, MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, - bool LoopNestMode) { +bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, + DominatorTree *DT, AssumptionCache *AC, + TargetLibraryInfo *TLI, + TargetTransformInfo *TTI, + ScalarEvolution *SE, MemorySSA *MSSA, + OptimizationRemarkEmitter *ORE, + bool LoopNestMode) { bool Changed = false; assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); @@ -435,15 +445,15 @@ bool LoopInvariantCodeMotion::runOnLoop( // us to sink instructions in one pass, without iteration. After sinking // instructions, we perform another pass to hoist them out of the loop. if (L->hasDedicatedExits()) - Changed |= LoopNestMode - ? sinkRegionForLoopNest(DT->getNode(L->getHeader()), AA, LI, - DT, BFI, TLI, TTI, L, MSSAU, - &SafetyInfo, Flags, ORE) - : sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, - TLI, TTI, L, MSSAU, &SafetyInfo, Flags, ORE); + Changed |= + LoopNestMode + ? sinkRegionForLoopNest(DT->getNode(L->getHeader()), AA, LI, DT, + TLI, TTI, L, MSSAU, &SafetyInfo, Flags, ORE) + : sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, TTI, L, + MSSAU, &SafetyInfo, Flags, ORE); Flags.setIsSink(false); if (Preheader) - Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, TLI, L, + Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, AC, TLI, L, MSSAU, SE, &SafetyInfo, Flags, ORE, LoopNestMode, LicmAllowSpeculation); @@ -483,11 +493,12 @@ bool LoopInvariantCodeMotion::runOnLoop( bool LocalPromoted; do { LocalPromoted = false; - for (const SmallSetVector<Value *, 8> &PointerMustAliases : + for (auto [PointerMustAliases, HasReadsOutsideSet] : collectPromotionCandidates(MSSA, AA, L)) { LocalPromoted |= promoteLoopAccessesToScalars( PointerMustAliases, ExitBlocks, InsertPts, MSSAInsertPts, PIC, LI, - DT, TLI, L, MSSAU, &SafetyInfo, ORE, LicmAllowSpeculation); + DT, AC, TLI, TTI, L, MSSAU, &SafetyInfo, ORE, + LicmAllowSpeculation, HasReadsOutsideSet); } Promoted |= LocalPromoted; } while (LocalPromoted); @@ -516,7 +527,7 @@ bool LoopInvariantCodeMotion::runOnLoop( MSSA->verifyMemorySSA(); if (Changed && SE) - SE->forgetLoopDispositions(L); + SE->forgetLoopDispositions(); return Changed; } @@ -526,10 +537,9 @@ bool LoopInvariantCodeMotion::runOnLoop( /// definitions, allowing us to sink a loop body in one pass without iteration. /// bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, - DominatorTree *DT, BlockFrequencyInfo *BFI, - TargetLibraryInfo *TLI, TargetTransformInfo *TTI, - Loop *CurLoop, MemorySSAUpdater &MSSAU, - ICFLoopSafetyInfo *SafetyInfo, + DominatorTree *DT, TargetLibraryInfo *TLI, + TargetTransformInfo *TTI, Loop *CurLoop, + MemorySSAUpdater &MSSAU, ICFLoopSafetyInfo *SafetyInfo, SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE, Loop *OutermostLoop) { @@ -577,7 +587,7 @@ bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, isNotUsedOrFreeInLoop(I, LoopNestMode ? 
OutermostLoop : CurLoop, SafetyInfo, TTI, FreeInLoop, LoopNestMode) && canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE)) { - if (sink(I, LI, DT, BFI, CurLoop, SafetyInfo, MSSAU, ORE)) { + if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE)) { if (!FreeInLoop) { ++II; salvageDebugInfo(I); @@ -593,11 +603,13 @@ bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, return Changed; } -bool llvm::sinkRegionForLoopNest( - DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, - BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, - Loop *CurLoop, MemorySSAUpdater &MSSAU, ICFLoopSafetyInfo *SafetyInfo, - SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE) { +bool llvm::sinkRegionForLoopNest(DomTreeNode *N, AAResults *AA, LoopInfo *LI, + DominatorTree *DT, TargetLibraryInfo *TLI, + TargetTransformInfo *TTI, Loop *CurLoop, + MemorySSAUpdater &MSSAU, + ICFLoopSafetyInfo *SafetyInfo, + SinkAndHoistLICMFlags &Flags, + OptimizationRemarkEmitter *ORE) { bool Changed = false; SmallPriorityWorklist<Loop *, 4> Worklist; @@ -605,8 +617,8 @@ bool llvm::sinkRegionForLoopNest( appendLoopsToWorklist(*CurLoop, Worklist); while (!Worklist.empty()) { Loop *L = Worklist.pop_back_val(); - Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, TLI, - TTI, L, MSSAU, SafetyInfo, Flags, ORE, CurLoop); + Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, TTI, L, + MSSAU, SafetyInfo, Flags, ORE, CurLoop); } return Changed; } @@ -845,7 +857,7 @@ public: /// uses, allowing us to hoist a loop body in one pass without iteration. /// bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, - DominatorTree *DT, BlockFrequencyInfo *BFI, + DominatorTree *DT, AssumptionCache *AC, TargetLibraryInfo *TLI, Loop *CurLoop, MemorySSAUpdater &MSSAU, ScalarEvolution *SE, ICFLoopSafetyInfo *SafetyInfo, @@ -902,7 +914,8 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE) && isSafeToExecuteUnconditionally( I, DT, TLI, CurLoop, SafetyInfo, ORE, - CurLoop->getLoopPreheader()->getTerminator(), AllowSpeculation)) { + CurLoop->getLoopPreheader()->getTerminator(), AC, + AllowSpeculation)) { hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, MSSAU, SE, ORE); HoistedInstructions.push_back(&I); @@ -1086,7 +1099,7 @@ static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, // in bits. Also, the invariant.start should dominate the load, and we // should not hoist the load out of a loop that contains this dominating // invariant.start. - if (LocSizeInBits.getFixedSize() <= InvariantSizeInBits && + if (LocSizeInBits.getFixedValue() <= InvariantSizeInBits && DT->properlyDominates(II->getParent(), CurLoop->getHeader())) return true; } @@ -1151,7 +1164,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // Loads from constant memory are always safe to move, even if they end up // in the same alias set as something that ends up being modified. - if (AA->pointsToConstantMemory(LI->getOperand(0))) + if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0)))) return true; if (LI->hasMetadata(LLVMContext::MD_invariant_load)) return true; @@ -1202,14 +1215,14 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, return true; // Handle simple cases by querying alias analysis. 
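The logic around this point (loads from constant or invariant memory, and the MemoryEffects-based checks for calls that follow) decides when a load or read-only call may be moved out of the loop. A hand-written source-level illustration of the targeted case, assuming the callee is known to read memory only through its argument:

int sum(const int *p, int n);   // assumed: only reads memory through p

// Before: the read-only call is re-executed every iteration even though
// neither p, n, nor the memory it reads changes inside the loop.
void before(int *out, const int *p, int n, int m) {
  for (int i = 0; i < m; ++i)
    out[i] = sum(p, n) + i;
}

// After hoisting: legal when alias analysis proves the stores to out cannot
// modify the memory read through p (exactly what the checks above establish).
void after(int *out, const int *p, int n, int m) {
  int s = sum(p, n);
  for (int i = 0; i < m; ++i)
    out[i] = s + i;
}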
- FunctionModRefBehavior Behavior = AA->getModRefBehavior(CI); - if (Behavior == FMRB_DoesNotAccessMemory) + MemoryEffects Behavior = AA->getMemoryEffects(CI); + if (Behavior.doesNotAccessMemory()) return true; - if (AAResults::onlyReadsMemory(Behavior)) { + if (Behavior.onlyReadsMemory()) { // A readonly argmemonly function only reads from memory pointed to by // it's arguments with arbitrary offsets. If we can prove there are no // writes to this memory in the loop, we can hoist or sink. - if (AAResults::onlyAccessesArgPointees(Behavior)) { + if (Behavior.onlyAccessesArgPointees()) { // TODO: expand to writeable arguments for (Value *Op : CI->args()) if (Op->getType()->isPointerTy() && @@ -1316,13 +1329,14 @@ static bool isTriviallyReplaceablePHI(const PHINode &PN, const Instruction &I) { /// Return true if the instruction is free in the loop. static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, const TargetTransformInfo *TTI) { + InstructionCost CostI = + TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); - if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&I)) { - if (TTI->getUserCost(GEP, TargetTransformInfo::TCK_SizeAndLatency) != - TargetTransformInfo::TCC_Free) + if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { + if (CostI != TargetTransformInfo::TCC_Free) return false; - // For a GEP, we cannot simply use getUserCost because currently it - // optimistically assumes that a GEP will fold into addressing mode + // For a GEP, we cannot simply use getInstructionCost because currently + // it optimistically assumes that a GEP will fold into addressing mode // regardless of its users. const BasicBlock *BB = GEP->getParent(); for (const User *U : GEP->users()) { @@ -1333,9 +1347,9 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, return false; } return true; - } else - return TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency) == - TargetTransformInfo::TCC_Free; + } + + return CostI == TargetTransformInfo::TCC_Free; } /// Return true if the only users of this instruction are outside of @@ -1420,7 +1434,7 @@ static Instruction *cloneInstructionInExitBlock( New = I.clone(); } - ExitBlock.getInstList().insert(ExitBlock.getFirstInsertionPt(), New); + New->insertInto(&ExitBlock, ExitBlock.getFirstInsertionPt()); if (!I.getName().empty()) New->setName(I.getName() + ".le"); @@ -1587,9 +1601,8 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, /// position, and may either delete it or move it to outside of the loop. 
/// static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, - BlockFrequencyInfo *BFI, const Loop *CurLoop, - ICFLoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU, - OptimizationRemarkEmitter *ORE) { + const Loop *CurLoop, ICFLoopSafetyInfo *SafetyInfo, + MemorySSAUpdater &MSSAU, OptimizationRemarkEmitter *ORE) { bool Changed = false; LLVM_DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); @@ -1741,8 +1754,9 @@ static bool isSafeToExecuteUnconditionally( Instruction &Inst, const DominatorTree *DT, const TargetLibraryInfo *TLI, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE, const Instruction *CtxI, - bool AllowSpeculation) { - if (AllowSpeculation && isSafeToSpeculativelyExecute(&Inst, CtxI, DT, TLI)) + AssumptionCache *AC, bool AllowSpeculation) { + if (AllowSpeculation && + isSafeToSpeculativelyExecute(&Inst, CtxI, AC, DT, TLI)) return true; bool GuaranteedToExecute = @@ -1765,7 +1779,6 @@ static bool isSafeToExecuteUnconditionally( namespace { class LoopPromoter : public LoadAndStorePromoter { Value *SomePtr; // Designated pointer to store to. - const SmallSetVector<Value *, 8> &PointerMustAliases; SmallVectorImpl<BasicBlock *> &LoopExitBlocks; SmallVectorImpl<Instruction *> &LoopInsertPts; SmallVectorImpl<MemoryAccess *> &MSSAInsertPts; @@ -1778,6 +1791,7 @@ class LoopPromoter : public LoadAndStorePromoter { AAMDNodes AATags; ICFLoopSafetyInfo &SafetyInfo; bool CanInsertStoresInExitBlocks; + ArrayRef<const Instruction *> Uses; // We're about to add a use of V in a loop exit block. Insert an LCSSA phi // (if legal) if doing so would add an out-of-loop use to an instruction @@ -1798,35 +1812,25 @@ class LoopPromoter : public LoadAndStorePromoter { public: LoopPromoter(Value *SP, ArrayRef<const Instruction *> Insts, SSAUpdater &S, - const SmallSetVector<Value *, 8> &PMA, SmallVectorImpl<BasicBlock *> &LEB, SmallVectorImpl<Instruction *> &LIP, SmallVectorImpl<MemoryAccess *> &MSSAIP, PredIteratorCache &PIC, MemorySSAUpdater &MSSAU, LoopInfo &li, DebugLoc dl, Align Alignment, bool UnorderedAtomic, const AAMDNodes &AATags, ICFLoopSafetyInfo &SafetyInfo, bool CanInsertStoresInExitBlocks) - : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), - LoopExitBlocks(LEB), LoopInsertPts(LIP), MSSAInsertPts(MSSAIP), - PredCache(PIC), MSSAU(MSSAU), LI(li), DL(std::move(dl)), - Alignment(Alignment), UnorderedAtomic(UnorderedAtomic), AATags(AATags), + : LoadAndStorePromoter(Insts, S), SomePtr(SP), LoopExitBlocks(LEB), + LoopInsertPts(LIP), MSSAInsertPts(MSSAIP), PredCache(PIC), MSSAU(MSSAU), + LI(li), DL(std::move(dl)), Alignment(Alignment), + UnorderedAtomic(UnorderedAtomic), AATags(AATags), SafetyInfo(SafetyInfo), - CanInsertStoresInExitBlocks(CanInsertStoresInExitBlocks) {} - - bool isInstInList(Instruction *I, - const SmallVectorImpl<Instruction *> &) const override { - Value *Ptr; - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - Ptr = LI->getOperand(0); - else - Ptr = cast<StoreInst>(I)->getPointerOperand(); - return PointerMustAliases.count(Ptr); - } + CanInsertStoresInExitBlocks(CanInsertStoresInExitBlocks), Uses(Insts) {} void insertStoresInLoopExitBlocks() { // Insert stores after in the loop exit blocks. Each exit block gets a // store of the live-out values that feed them. Since we've already told // the SSA updater about the defs in the loop and the preheader // definition, it is all set and we can start using it. 
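The LoopPromoter above rewrites in-loop loads and stores of a promotable location into SSA values, with a single load in the preheader and one store per exit block. At the source level the transformation looks roughly like this hand-written sketch:

// Before promotion: *p is loaded and stored on every iteration.
void before(int *p, int n) {
  for (int i = 0; i < n; ++i)
    *p += i;
}

// After promotion: the location lives in a register inside the loop, with one
// load in the preheader and one store in the (single) exit block.
void after(int *p, int n) {
  int t = *p;              // preheader load
  for (int i = 0; i < n; ++i)
    t += i;
  *p = t;                  // store inserted in the loop exit block
  // Making this store unconditional is only legal under the safety conditions
  // checked in the surrounding code: a store guaranteed to execute in the
  // loop, or a writable, thread-local underlying object.
}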
+ DIAssignID *NewID = nullptr; for (unsigned i = 0, e = LoopExitBlocks.size(); i != e; ++i) { BasicBlock *ExitBlock = LoopExitBlocks[i]; Value *LiveInValue = SSA.GetValueInMiddleOfBlock(ExitBlock); @@ -1838,6 +1842,21 @@ public: NewSI->setOrdering(AtomicOrdering::Unordered); NewSI->setAlignment(Alignment); NewSI->setDebugLoc(DL); + // Attach DIAssignID metadata to the new store, generating it on the + // first loop iteration. + if (i == 0) { + // NewSI will have its DIAssignID set here if there are any stores in + // Uses with a DIAssignID attachment. This merged ID will then be + // attached to the other inserted stores (in the branch below). + NewSI->mergeDIAssignID(Uses); + NewID = cast_or_null<DIAssignID>( + NewSI->getMetadata(LLVMContext::MD_DIAssignID)); + } else { + // Attach the DIAssignID (or nullptr) merged from Uses in the branch + // above. + NewSI->setMetadata(LLVMContext::MD_DIAssignID, NewID); + } + if (AATags) NewSI->setAAMetadata(AATags); @@ -1896,6 +1915,33 @@ bool isNotVisibleOnUnwindInLoop(const Value *Object, const Loop *L, isNotCapturedBeforeOrInLoop(Object, L, DT); } +bool isWritableObject(const Value *Object) { + // TODO: Alloca might not be writable after its lifetime ends. + // See https://github.com/llvm/llvm-project/issues/51838. + if (isa<AllocaInst>(Object)) + return true; + + // TODO: Also handle sret. + if (auto *A = dyn_cast<Argument>(Object)) + return A->hasByValAttr(); + + if (auto *G = dyn_cast<GlobalVariable>(Object)) + return !G->isConstant(); + + // TODO: Noalias has nothing to do with writability, this should check for + // an allocator function. + return isNoAliasCall(Object); +} + +bool isThreadLocalObject(const Value *Object, const Loop *L, DominatorTree *DT, + TargetTransformInfo *TTI) { + // The object must be function-local to start with, and then not captured + // before/in the loop. + return (isIdentifiedFunctionLocal(Object) && + isNotCapturedBeforeOrInLoop(Object, L, DT)) || + (TTI->isSingleThreaded() || SingleThread); +} + } // namespace /// Try to promote memory values to scalars by sinking stores out of the @@ -1908,14 +1954,23 @@ bool llvm::promoteLoopAccessesToScalars( SmallVectorImpl<BasicBlock *> &ExitBlocks, SmallVectorImpl<Instruction *> &InsertPts, SmallVectorImpl<MemoryAccess *> &MSSAInsertPts, PredIteratorCache &PIC, - LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - Loop *CurLoop, MemorySSAUpdater &MSSAU, ICFLoopSafetyInfo *SafetyInfo, - OptimizationRemarkEmitter *ORE, bool AllowSpeculation) { + LoopInfo *LI, DominatorTree *DT, AssumptionCache *AC, + const TargetLibraryInfo *TLI, TargetTransformInfo *TTI, Loop *CurLoop, + MemorySSAUpdater &MSSAU, ICFLoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE, bool AllowSpeculation, + bool HasReadsOutsideSet) { // Verify inputs. assert(LI != nullptr && DT != nullptr && CurLoop != nullptr && SafetyInfo != nullptr && "Unexpected Input to promoteLoopAccessesToScalars"); + LLVM_DEBUG({ + dbgs() << "Trying to promote set of must-aliased pointers:\n"; + for (Value *Ptr : PointerMustAliases) + dbgs() << " " << *Ptr << "\n"; + }); + ++NumPromotionCandidates; + Value *SomePtr = *PointerMustAliases.begin(); BasicBlock *Preheader = CurLoop->getLoopPreheader(); @@ -1957,9 +2012,14 @@ bool llvm::promoteLoopAccessesToScalars( // store is never executed, but the exit blocks are not executed either. 
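The `isWritableObject` and `isThreadLocalObject` helpers above capture the two properties that make it safe to introduce stores on paths that previously had none: the object must actually be writable, and no other thread may observe the extra store. A hand-written example of the case the thread-locality check rules out:

int g;                       // global: potentially visible to other threads

// The store to g is conditional.  Promoting it would insert an unconditional
// store after the loop; if another thread writes g concurrently, that
// inserted store could write back a stale value, a behaviour the original
// program (which stores nothing when flag is false) could not exhibit.
void racy_if_promoted(int n, bool flag) {
  for (int i = 0; i < n; ++i)
    if (flag)
      g = i;
}

// A function-local, non-escaping object is different: no other thread can
// observe it, so the extra store is invisible and promotion is allowed.
// This is what isThreadLocalObject / isNotCapturedBeforeOrInLoop establish.
int safe_to_promote(int n, bool flag) {
  int local = 0;
  for (int i = 0; i < n; ++i)
    if (flag)
      local = i;
  return local;
}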
bool DereferenceableInPH = false; - bool SafeToInsertStore = false; bool StoreIsGuanteedToExecute = false; bool FoundLoadToPromote = false; + // Goes from Unknown to either Safe or Unsafe, but can't switch between them. + enum { + StoreSafe, + StoreUnsafe, + StoreSafetyUnknown, + } StoreSafety = StoreSafetyUnknown; SmallVector<Instruction *, 64> LoopUses; @@ -1973,8 +2033,12 @@ bool llvm::promoteLoopAccessesToScalars( const DataLayout &MDL = Preheader->getModule()->getDataLayout(); - bool IsKnownThreadLocalObject = false; - if (SafetyInfo->anyBlockMayThrow()) { + // If there are reads outside the promoted set, then promoting stores is + // definitely not safe. + if (HasReadsOutsideSet) + StoreSafety = StoreUnsafe; + + if (StoreSafety == StoreSafetyUnknown && SafetyInfo->anyBlockMayThrow()) { // If a loop can throw, we have to insert a store along each unwind edge. // That said, we can't actually make the unwind edge explicit. Therefore, // we have to prove that the store is dead along the unwind edge. We do @@ -1982,13 +2046,10 @@ bool llvm::promoteLoopAccessesToScalars( // after return and thus can't possibly load from the object. Value *Object = getUnderlyingObject(SomePtr); if (!isNotVisibleOnUnwindInLoop(Object, CurLoop, DT)) - return false; - // Subtlety: Alloca's aren't visible to callers, but *are* potentially - // visible to other threads if captured and used during their lifetimes. - IsKnownThreadLocalObject = !isa<AllocaInst>(Object); + StoreSafety = StoreUnsafe; } - // Check that all accesses to pointers in the aliass set use the same type. + // Check that all accesses to pointers in the alias set use the same type. // We cannot (yet) promote a memory location that is loaded and stored in // different sizes. While we are at it, collect alignment and AA info. Type *AccessTy = nullptr; @@ -2018,7 +2079,7 @@ bool llvm::promoteLoopAccessesToScalars( if (!DereferenceableInPH || (InstAlignment > Alignment)) if (isSafeToExecuteUnconditionally( *Load, DT, TLI, CurLoop, SafetyInfo, ORE, - Preheader->getTerminator(), AllowSpeculation)) { + Preheader->getTerminator(), AC, AllowSpeculation)) { DereferenceableInPH = true; Alignment = std::max(Alignment, InstAlignment); } @@ -2042,13 +2103,11 @@ bool llvm::promoteLoopAccessesToScalars( bool GuaranteedToExecute = SafetyInfo->isGuaranteedToExecute(*UI, DT, CurLoop); StoreIsGuanteedToExecute |= GuaranteedToExecute; - if (!DereferenceableInPH || !SafeToInsertStore || - (InstAlignment > Alignment)) { - if (GuaranteedToExecute) { - DereferenceableInPH = true; - SafeToInsertStore = true; - Alignment = std::max(Alignment, InstAlignment); - } + if (GuaranteedToExecute) { + DereferenceableInPH = true; + if (StoreSafety == StoreSafetyUnknown) + StoreSafety = StoreSafe; + Alignment = std::max(Alignment, InstAlignment); } // If a store dominates all exit blocks, it is safe to sink. @@ -2057,20 +2116,21 @@ bool llvm::promoteLoopAccessesToScalars( // introducing stores on paths that did not have them. // Note that this only looks at explicit exit blocks. If we ever // start sinking stores into unwind edges (see above), this will break. - if (!SafeToInsertStore) - SafeToInsertStore = llvm::all_of(ExitBlocks, [&](BasicBlock *Exit) { - return DT->dominates(Store->getParent(), Exit); - }); + if (StoreSafety == StoreSafetyUnknown && + llvm::all_of(ExitBlocks, [&](BasicBlock *Exit) { + return DT->dominates(Store->getParent(), Exit); + })) + StoreSafety = StoreSafe; // If the store is not guaranteed to execute, we may still get // deref info through it. 
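The three-state `StoreSafety` value above replaces the old boolean and, as its comment says, is deliberately monotone: it starts Unknown and may be resolved once to Safe or Unsafe, after which later checks must not overwrite the decision. A minimal stand-alone sketch of that discipline (names illustrative, not from the patch):

#include <cassert>

enum class StoreSafety { Unknown, Safe, Unsafe };

// Resolve the state at most once; later evidence never overrides an earlier
// definitive answer.  This mirrors the guarded assignments in the patch:
//   if (StoreSafety == StoreSafetyUnknown) StoreSafety = StoreSafe/StoreUnsafe;
void resolve(StoreSafety &S, bool Evidence, StoreSafety Verdict) {
  assert(Verdict != StoreSafety::Unknown && "verdict must be definitive");
  if (S == StoreSafety::Unknown && Evidence)
    S = Verdict;
}

int main() {
  StoreSafety S = StoreSafety::Unknown;
  resolve(S, /*reads outside the promoted set*/ true, StoreSafety::Unsafe);
  resolve(S, /*store guaranteed to execute*/ true, StoreSafety::Safe);
  assert(S == StoreSafety::Unsafe); // the earlier Unsafe decision sticks
  return 0;
}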
if (!DereferenceableInPH) { DereferenceableInPH = isDereferenceableAndAlignedPointer( Store->getPointerOperand(), Store->getValueOperand()->getType(), - Store->getAlign(), MDL, Preheader->getTerminator(), DT, TLI); + Store->getAlign(), MDL, Preheader->getTerminator(), AC, DT, TLI); } } else - return false; // Not a load or store. + continue; // Not a load or store. if (!AccessTy) AccessTy = getLoadStoreType(UI); @@ -2103,58 +2163,58 @@ bool llvm::promoteLoopAccessesToScalars( return false; // If we couldn't prove we can hoist the load, bail. - if (!DereferenceableInPH) + if (!DereferenceableInPH) { + LLVM_DEBUG(dbgs() << "Not promoting: Not dereferenceable in preheader\n"); return false; + } // We know we can hoist the load, but don't have a guaranteed store. - // Check whether the location is thread-local. If it is, then we can insert - // stores along paths which originally didn't have them without violating the - // memory model. - if (!SafeToInsertStore) { - if (IsKnownThreadLocalObject) - SafeToInsertStore = true; - else { - Value *Object = getUnderlyingObject(SomePtr); - SafeToInsertStore = - (isNoAliasCall(Object) || isa<AllocaInst>(Object)) && - isNotCapturedBeforeOrInLoop(Object, CurLoop, DT); - } + // Check whether the location is writable and thread-local. If it is, then we + // can insert stores along paths which originally didn't have them without + // violating the memory model. + if (StoreSafety == StoreSafetyUnknown) { + Value *Object = getUnderlyingObject(SomePtr); + if (isWritableObject(Object) && + isThreadLocalObject(Object, CurLoop, DT, TTI)) + StoreSafety = StoreSafe; } // If we've still failed to prove we can sink the store, hoist the load // only, if possible. - if (!SafeToInsertStore && !FoundLoadToPromote) + if (StoreSafety != StoreSafe && !FoundLoadToPromote) // If we cannot hoist the load either, give up. return false; // Lets do the promotion! - if (SafeToInsertStore) + if (StoreSafety == StoreSafe) { LLVM_DEBUG(dbgs() << "LICM: Promoting load/store of the value: " << *SomePtr << '\n'); - else + ++NumLoadStorePromoted; + } else { LLVM_DEBUG(dbgs() << "LICM: Promoting load of the value: " << *SomePtr << '\n'); + ++NumLoadPromoted; + } ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "PromoteLoopAccessesToScalar", LoopUses[0]) << "Moving accesses to memory location out of the loop"; }); - ++NumPromoted; // Look at all the loop uses, and try to merge their locations. std::vector<const DILocation *> LoopUsesLocs; - for (auto U : LoopUses) + for (auto *U : LoopUses) LoopUsesLocs.push_back(U->getDebugLoc().get()); auto DL = DebugLoc(DILocation::getMergedLocations(LoopUsesLocs)); // We use the SSAUpdater interface to insert phi nodes as required. SmallVector<PHINode *, 16> NewPHIs; SSAUpdater SSA(&NewPHIs); - LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, - InsertPts, MSSAInsertPts, PIC, MSSAU, *LI, DL, - Alignment, SawUnorderedAtomic, AATags, *SafetyInfo, - SafeToInsertStore); + LoopPromoter Promoter(SomePtr, LoopUses, SSA, ExitBlocks, InsertPts, + MSSAInsertPts, PIC, MSSAU, *LI, DL, Alignment, + SawUnorderedAtomic, AATags, *SafetyInfo, + StoreSafety == StoreSafe); // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. 
@@ -2203,9 +2263,12 @@ static void foreachMemoryAccess(MemorySSA *MSSA, Loop *L, Fn(MUD->getMemoryInst()); } -static SmallVector<SmallSetVector<Value *, 8>, 0> +// The bool indicates whether there might be reads outside the set, in which +// case only loads may be promoted. +static SmallVector<PointersAndHasReadsOutsideSet, 0> collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) { - AliasSetTracker AST(*AA); + BatchAAResults BatchAA(*AA); + AliasSetTracker AST(BatchAA); auto IsPotentiallyPromotable = [L](const Instruction *I) { if (const auto *SI = dyn_cast<StoreInst>(I)) @@ -2225,10 +2288,10 @@ collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) { }); // We're only interested in must-alias sets that contain a mod. - SmallVector<const AliasSet *, 8> Sets; + SmallVector<PointerIntPair<const AliasSet *, 1, bool>, 8> Sets; for (AliasSet &AS : AST) if (!AS.isForwardingAliasSet() && AS.isMod() && AS.isMustAlias()) - Sets.push_back(&AS); + Sets.push_back({&AS, false}); if (Sets.empty()) return {}; // Nothing to promote... @@ -2238,17 +2301,28 @@ collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) { if (AttemptingPromotion.contains(I)) return; - llvm::erase_if(Sets, [&](const AliasSet *AS) { - return AS->aliasesUnknownInst(I, *AA); + llvm::erase_if(Sets, [&](PointerIntPair<const AliasSet *, 1, bool> &Pair) { + ModRefInfo MR = Pair.getPointer()->aliasesUnknownInst(I, BatchAA); + // Cannot promote if there are writes outside the set. + if (isModSet(MR)) + return true; + if (isRefSet(MR)) { + // Remember reads outside the set. + Pair.setInt(true); + // If this is a mod-only set and there are reads outside the set, + // we will not be able to promote, so bail out early. + return !Pair.getPointer()->isRef(); + } + return false; }); }); - SmallVector<SmallSetVector<Value *, 8>, 0> Result; - for (const AliasSet *Set : Sets) { + SmallVector<std::pair<SmallSetVector<Value *, 8>, bool>, 0> Result; + for (auto [Set, HasReadsOutsideSet] : Sets) { SmallSetVector<Value *, 8> PointerMustAliases; for (const auto &ASI : *Set) PointerMustAliases.insert(ASI.getValue()); - Result.push_back(std::move(PointerMustAliases)); + Result.emplace_back(std::move(PointerMustAliases), HasReadsOutsideSet); } return Result; diff --git a/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp b/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp index c063c0d3c88a..9ae55b9018da 100644 --- a/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp +++ b/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp @@ -7,19 +7,27 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopAccessAnalysisPrinter.h" +#include "llvm/ADT/PriorityWorklist.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + using namespace llvm; #define DEBUG_TYPE "loop-accesses" -PreservedAnalyses -LoopAccessInfoPrinterPass::run(Loop &L, LoopAnalysisManager &AM, - LoopStandardAnalysisResults &AR, LPMUpdater &) { - Function &F = *L.getHeader()->getParent(); - auto &LAI = AM.getResult<LoopAccessAnalysis>(L, AR); +PreservedAnalyses LoopAccessInfoPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &LAIs = AM.getResult<LoopAccessAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); OS << "Loop access info in function '" << F.getName() << "':\n"; - OS.indent(2) << L.getHeader()->getName() << ":\n"; - LAI.print(OS, 4); + + 
SmallPriorityWorklist<Loop *, 4> Worklist; + appendLoopsToWorklist(LI, Worklist); + while (!Worklist.empty()) { + Loop *L = Worklist.pop_back_val(); + OS.indent(2) << L->getHeader()->getName() << ":\n"; + LAIs.getInfo(*L).print(OS, 4); + } return PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index 013a119c5096..7c2770979a90 100644 --- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -338,7 +338,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { } else continue; unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace(); - if (PtrAddrSpace) + if (!TTI->shouldPrefetchAddressSpace(PtrAddrSpace)) continue; NumMemAccesses++; if (L->isLoopInvariant(PtrValue)) @@ -398,7 +398,8 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { if (!SCEVE.isSafeToExpand(NextLSCEV)) continue; - Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/); + unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace(); + Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace); Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt); IRBuilder<> Builder(P.InsertPt); diff --git a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index 93f3cd704196..7e4dbace043a 100644 --- a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -82,31 +82,30 @@ static bool isLoopDead(Loop *L, ScalarEvolution &SE, // blocks, then it is impossible to statically determine which value // should be used. AllOutgoingValuesSame = - all_of(makeArrayRef(ExitingBlocks).slice(1), [&](BasicBlock *BB) { + all_of(ArrayRef(ExitingBlocks).slice(1), [&](BasicBlock *BB) { return incoming == P.getIncomingValueForBlock(BB); }); if (!AllOutgoingValuesSame) break; - if (Instruction *I = dyn_cast<Instruction>(incoming)) - if (!L->makeLoopInvariant(I, Changed, Preheader->getTerminator())) { + if (Instruction *I = dyn_cast<Instruction>(incoming)) { + if (!L->makeLoopInvariant(I, Changed, Preheader->getTerminator(), + /*MSSAU=*/nullptr, &SE)) { AllEntriesInvariant = false; break; } + } } } - if (Changed) - SE.forgetLoopDispositions(L); - if (!AllEntriesInvariant || !AllOutgoingValuesSame) return false; // Make sure that no instructions in the block have potential side-effects. // This includes instructions that could write to memory, and loads that are // marked volatile. - for (auto &I : L->blocks()) + for (const auto &I : L->blocks()) if (any_of(*I, [](Instruction &I) { return I.mayHaveSideEffects() && !I.isDroppable(); })) @@ -456,7 +455,7 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, BasicBlock *ExitBlock = L->getUniqueExitBlock(); if (ExitBlock && isLoopNeverExecuted(L)) { - LLVM_DEBUG(dbgs() << "Loop is proven to never execute, delete it!"); + LLVM_DEBUG(dbgs() << "Loop is proven to never execute, delete it!\n"); // We need to forget the loop before setting the incoming values of the exit // phis to poison, so we properly invalidate the SCEV expressions for those // phis. 
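The LoopDeletion changes above are debug and invalidation tweaks, but the pass itself removes loops whose results are unused and which have no side effects. A hand-written illustration of the two situations the debug messages refer to (what real IR the pass fires on depends on earlier simplification):

// A loop that provably never executes: the entry guard is always false.
void never_executes(int *a) {
  for (int i = 0; i < 0; ++i)
    a[i] = i;            // the whole loop can be deleted outright
}

// A "dead" loop: it terminates, writes no memory, and every value flowing out
// of it is loop-invariant, so only that invariant value needs to survive.
int dead_loop(int x, int n) {
  int r = x;
  for (int i = 0; i < n; ++i)
    r = x;               // same invariant value on every exit path
  return r;              // reduces to 'return x;' once the loop is removed
}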
@@ -497,7 +496,7 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, : LoopDeletionResult::Unmodified; } - LLVM_DEBUG(dbgs() << "Loop is invariant, delete it!"); + LLVM_DEBUG(dbgs() << "Loop is invariant, delete it!\n"); ORE.emit([&]() { return OptimizationRemark(DEBUG_TYPE, "Invariant", L->getStartLoc(), L->getHeader()) diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index b178bcae3b0e..7b52b7dca85f 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -25,7 +25,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/EquivalenceClasses.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -397,7 +396,7 @@ public: continue; auto PartI = I->getData(); - for (auto PartJ : make_range(std::next(ToBeMerged.member_begin(I)), + for (auto *PartJ : make_range(std::next(ToBeMerged.member_begin(I)), ToBeMerged.member_end())) { PartJ->moveTo(*PartI); } @@ -461,16 +460,14 @@ public: // update PH to point to the newly added preheader. BasicBlock *TopPH = OrigPH; unsigned Index = getSize() - 1; - for (auto I = std::next(PartitionContainer.rbegin()), - E = PartitionContainer.rend(); - I != E; ++I, --Index, TopPH = NewLoop->getLoopPreheader()) { - auto *Part = &*I; - - NewLoop = Part->cloneLoopWithPreheader(TopPH, Pred, Index, LI, DT); - - Part->getVMap()[ExitBlock] = TopPH; - Part->remapInstructions(); - setNewLoopID(OrigLoopID, Part); + for (auto &Part : llvm::drop_begin(llvm::reverse(PartitionContainer))) { + NewLoop = Part.cloneLoopWithPreheader(TopPH, Pred, Index, LI, DT); + + Part.getVMap()[ExitBlock] = TopPH; + Part.remapInstructions(); + setNewLoopID(OrigLoopID, &Part); + --Index; + TopPH = NewLoop->getLoopPreheader(); } Pred->getTerminator()->replaceUsesOfWith(OrigPH, TopPH); @@ -595,14 +592,14 @@ private: /// Assign new LoopIDs for the partition's cloned loop. void setNewLoopID(MDNode *OrigLoopID, InstPartition *Part) { - Optional<MDNode *> PartitionID = makeFollowupLoopID( + std::optional<MDNode *> PartitionID = makeFollowupLoopID( OrigLoopID, {LLVMLoopDistributeFollowupAll, Part->hasDepCycle() ? LLVMLoopDistributeFollowupSequential : LLVMLoopDistributeFollowupCoincident}); if (PartitionID) { Loop *NewLoop = Part->getDistributedLoop(); - NewLoop->setLoopID(PartitionID.value()); + NewLoop->setLoopID(*PartitionID); } } }; @@ -635,7 +632,7 @@ public: Accesses.append(Instructions.begin(), Instructions.end()); LLVM_DEBUG(dbgs() << "Backward dependences:\n"); - for (auto &Dep : Dependences) + for (const auto &Dep : Dependences) if (Dep.isPossiblyBackward()) { // Note that the designations source and destination follow the program // order, i.e. source is always first. (The direction is given by the @@ -655,13 +652,14 @@ private: class LoopDistributeForLoop { public: LoopDistributeForLoop(Loop *L, Function *F, LoopInfo *LI, DominatorTree *DT, - ScalarEvolution *SE, OptimizationRemarkEmitter *ORE) - : L(L), F(F), LI(LI), DT(DT), SE(SE), ORE(ORE) { + ScalarEvolution *SE, LoopAccessInfoManager &LAIs, + OptimizationRemarkEmitter *ORE) + : L(L), F(F), LI(LI), DT(DT), SE(SE), LAIs(LAIs), ORE(ORE) { setForced(); } /// Try to distribute an inner-most loop. 
- bool processLoop(std::function<const LoopAccessInfo &(Loop &)> &GetLAA) { + bool processLoop() { assert(L->isInnermost() && "Only process inner loops."); LLVM_DEBUG(dbgs() << "\nLDist: In \"" @@ -679,7 +677,7 @@ public: BasicBlock *PH = L->getLoopPreheader(); - LAI = &GetLAA(*L); + LAI = &LAIs.getInfo(*L); // Currently, we only distribute to isolate the part of the loop with // dependence cycles to enable partial vectorization. @@ -717,7 +715,7 @@ public: *Dependences); int NumUnsafeDependencesActive = 0; - for (auto &InstDep : MID) { + for (const auto &InstDep : MID) { Instruction *I = InstDep.Inst; // We update NumUnsafeDependencesActive post-instruction, catch the // start of a dependence directly via NumUnsafeDependencesStartOrEnd. @@ -821,12 +819,10 @@ public: // The unversioned loop will not be changed, so we inherit all attributes // from the original loop, but remove the loop distribution metadata to // avoid to distribute it again. - MDNode *UnversionedLoopID = - makeFollowupLoopID(OrigLoopID, - {LLVMLoopDistributeFollowupAll, - LLVMLoopDistributeFollowupFallback}, - "llvm.loop.distribute.", true) - .value(); + MDNode *UnversionedLoopID = *makeFollowupLoopID( + OrigLoopID, + {LLVMLoopDistributeFollowupAll, LLVMLoopDistributeFollowupFallback}, + "llvm.loop.distribute.", true); LVer.getNonVersionedLoop()->setLoopID(UnversionedLoopID); } @@ -893,7 +889,7 @@ public: /// If the optional has a value, it indicates whether distribution was forced /// to be enabled (true) or disabled (false). If the optional has no value /// distribution was not forced either way. - const Optional<bool> &isForced() const { return IsForced; } + const std::optional<bool> &isForced() const { return IsForced; } private: /// Filter out checks between pointers from the same partition. @@ -937,7 +933,7 @@ private: /// Check whether the loop metadata is forcing distribution to be /// enabled/disabled. void setForced() { - Optional<const MDOperand *> Value = + std::optional<const MDOperand *> Value = findStringMetadataForLoop(L, "llvm.loop.distribute.enable"); if (!Value) return; @@ -955,6 +951,7 @@ private: const LoopAccessInfo *LAI = nullptr; DominatorTree *DT; ScalarEvolution *SE; + LoopAccessInfoManager &LAIs; OptimizationRemarkEmitter *ORE; /// Indicates whether distribution is forced to be enabled/disabled for @@ -963,7 +960,7 @@ private: /// If the optional has a value, it indicates whether distribution was forced /// to be enabled (true) or disabled (false). If the optional has no value /// distribution was not forced either way. - Optional<bool> IsForced; + std::optional<bool> IsForced; }; } // end anonymous namespace @@ -971,7 +968,7 @@ private: /// Shared implementation between new and old PMs. static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, - std::function<const LoopAccessInfo &(Loop &)> &GetLAA) { + LoopAccessInfoManager &LAIs) { // Build up a worklist of inner-loops to vectorize. This is necessary as the // act of distributing a loop creates new loops and can invalidate iterators // across the loops. @@ -986,12 +983,12 @@ static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, // Now walk the identified inner loops. bool Changed = false; for (Loop *L : Worklist) { - LoopDistributeForLoop LDL(L, &F, LI, DT, SE, ORE); + LoopDistributeForLoop LDL(L, &F, LI, DT, SE, LAIs, ORE); // If distribution was forced for the specific loop to be // enabled/disabled, follow that. Otherwise use the global flag. 
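As the comments above note, distribution is used to isolate the part of a loop that carries a dependence cycle so the remainder can be vectorized. A hand-written before/after sketch of the shape of the transformation, assuming the pointers do not alias (in practice LoopDistribute versions the loop with runtime checks, as the LoopVersioning code above shows):

// Before: one loop mixing an independent statement with a recurrence.
void before(int *a, int *b, int *c, int n) {
  for (int i = 1; i < n; ++i) {
    a[i] = b[i] + 1;         // no loop-carried dependence: vectorizable
    c[i] = c[i - 1] + a[i];  // dependence cycle through c: not vectorizable
  }
}

// After distribution: the vectorizable part runs in its own loop.
void after(int *a, int *b, int *c, int n) {
  for (int i = 1; i < n; ++i)
    a[i] = b[i] + 1;         // can now be vectorized on its own
  for (int i = 1; i < n; ++i)
    c[i] = c[i - 1] + a[i];  // sequential part kept separate
}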
if (LDL.isForced().value_or(EnableLoopDistribute)) - Changed |= LDL.processLoop(GetLAA); + Changed |= LDL.processLoop(); } // Process each loop nest in the function. @@ -1015,14 +1012,12 @@ public: return false; auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - std::function<const LoopAccessInfo &(Loop &)> GetLAA = - [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; + auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); - return runImpl(F, LI, DT, SE, ORE, GetLAA); + return runImpl(F, LI, DT, SE, ORE, LAIs); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -1046,22 +1041,8 @@ PreservedAnalyses LoopDistributePass::run(Function &F, auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); - // We don't directly need these analyses but they're required for loop - // analyses so provide them below. - auto &AA = AM.getResult<AAManager>(F); - auto &AC = AM.getResult<AssumptionAnalysis>(F); - auto &TTI = AM.getResult<TargetIRAnalysis>(F); - auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - - auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); - std::function<const LoopAccessInfo &(Loop &)> GetLAA = - [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, - TLI, TTI, nullptr, nullptr, nullptr}; - return LAM.getResult<LoopAccessAnalysis>(L, AR); - }; - - bool Changed = runImpl(F, &LI, &DT, &SE, &ORE, GetLAA); + LoopAccessInfoManager &LAIs = AM.getResult<LoopAccessAnalysis>(F); + bool Changed = runImpl(F, &LI, &DT, &SE, &ORE, LAIs); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index f36193fc468e..7d9ce8d35e0b 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -75,6 +75,7 @@ #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -99,6 +100,7 @@ static cl::opt<bool> cl::desc("Widen the loop induction variables, if possible, so " "overflow checks won't reject flattening")); +namespace { // We require all uses of both induction variables to match this pattern: // // (OuterPHI * InnerTripCount) + InnerPHI @@ -139,7 +141,7 @@ struct FlattenInfo { PHINode *NarrowInnerInductionPHI = nullptr; // Holds the old/narrow induction PHINode *NarrowOuterInductionPHI = nullptr; // phis, i.e. the Phis before IV - // has been apllied. Used to skip + // has been applied. Used to skip // checks on phi nodes. 
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){}; @@ -191,7 +193,7 @@ struct FlattenInfo { bool matchLinearIVUser(User *U, Value *InnerTripCount, SmallPtrSet<Value *, 4> &ValidOuterPHIUses) { - LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); + LLVM_DEBUG(dbgs() << "Checking linear i*M+j expression for: "; U->dump()); Value *MatchedMul = nullptr; Value *MatchedItCount = nullptr; @@ -211,6 +213,18 @@ struct FlattenInfo { if (!MatchedItCount) return false; + LLVM_DEBUG(dbgs() << "Matched multiplication: "; MatchedMul->dump()); + LLVM_DEBUG(dbgs() << "Matched iteration count: "; MatchedItCount->dump()); + + // The mul should not have any other uses. Widening may leave trivially dead + // uses, which can be ignored. + if (count_if(MatchedMul->users(), [](User *U) { + return !isInstructionTriviallyDead(cast<Instruction>(U)); + }) > 1) { + LLVM_DEBUG(dbgs() << "Multiply has more than one use\n"); + return false; + } + // Look through extends if the IV has been widened. Don't look through // extends if we already looked through a trunc. if (Widened && IsAdd && @@ -222,8 +236,11 @@ struct FlattenInfo { : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0); } + LLVM_DEBUG(dbgs() << "Looking for inner trip count: "; + InnerTripCount->dump()); + if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) { - LLVM_DEBUG(dbgs() << "Use is optimisable\n"); + LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n"); ValidOuterPHIUses.insert(MatchedMul); LinearIVUses.insert(U); return true; @@ -240,8 +257,11 @@ struct FlattenInfo { SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0); for (User *U : InnerInductionPHI->users()) { - if (isInnerLoopIncrement(U)) + LLVM_DEBUG(dbgs() << "Checking User: "; U->dump()); + if (isInnerLoopIncrement(U)) { + LLVM_DEBUG(dbgs() << "Use is inner loop increment, continuing\n"); continue; + } // After widening the IVs, a trunc instruction might have been introduced, // so look through truncs. @@ -255,15 +275,21 @@ struct FlattenInfo { // branch) then the compare has been altered by another transformation e.g // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, where tripcount is // a constant. Ignore this use as the compare gets removed later anyway. - if (isInnerLoopTest(U)) + if (isInnerLoopTest(U)) { + LLVM_DEBUG(dbgs() << "Use is the inner loop test, continuing\n"); continue; + } - if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) + if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) { + LLVM_DEBUG(dbgs() << "Not a linear IV user\n"); return false; + } + LLVM_DEBUG(dbgs() << "Linear IV users found!\n"); } return true; } }; +} // namespace static bool setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, @@ -413,7 +439,8 @@ static bool findLoopComponents( // increment variable. Increment = cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch)); - if (Increment->hasNUsesOrMore(3)) { + if ((Compare->getOperand(0) != Increment || !Increment->hasNUses(2)) && + !Increment->hasNUses(1)) { LLVM_DEBUG(dbgs() << "Could not find valid increment\n"); return false; } @@ -540,7 +567,7 @@ checkOuterLoopInsts(FlattenInfo &FI, // they make a net difference of zero. if (IterationInstructions.count(&I)) continue; - // The uncoditional branch to the inner loop's header will turn into + // The unconditional branch to the inner loop's header will turn into // a fall-through, so adds no cost. 
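The pattern matched above is the classic linearised index OuterIV * InnerTripCount + InnerIV (i*M+j). When every use of the two induction variables has that form, the nest can be collapsed into a single loop. A hand-written before/after sketch of what LoopFlatten aims for:

// Before: a rectangular nest whose only IV uses are the linear form i*M + j.
void before(int *a, unsigned N, unsigned M) {
  for (unsigned i = 0; i < N; ++i)
    for (unsigned j = 0; j < M; ++j)
      a[i * M + j] = 0;
}

// After flattening: one loop over the product trip count, using the flattened
// induction variable directly.  This assumes N*M does not overflow, which is
// exactly what the overflow checks and IV widening in the pass establish.
void after(int *a, unsigned N, unsigned M) {
  for (unsigned k = 0; k < N * M; ++k)
    a[k] = 0;
}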
BranchInst *Br = dyn_cast<BranchInst>(&I); if (Br && Br->isUnconditional() && @@ -552,7 +579,7 @@ checkOuterLoopInsts(FlattenInfo &FI, m_Specific(FI.InnerTripCount)))) continue; InstructionCost Cost = - TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + TTI->getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump()); RepeatedInstrCost += Cost; } @@ -759,9 +786,9 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, } // Tell LoopInfo, SCEV and the pass manager that the inner loop has been - // deleted, and any information that have about the outer loop invalidated. + // deleted, and invalidate any outer loop information. SE->forgetLoop(FI.OuterLoop); - SE->forgetLoop(FI.InnerLoop); + SE->forgetBlockAndLoopDispositions(); if (U) U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName()); LI->erase(FI.InnerLoop); @@ -911,7 +938,7 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, bool Changed = false; - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (AR.MSSA) { MSSAU = MemorySSAUpdater(AR.MSSA); if (VerifyMemorySSA) @@ -923,7 +950,7 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, - MSSAU ? MSSAU.getPointer() : nullptr); + MSSAU ? &*MSSAU : nullptr); if (!Changed) return PreservedAnalyses::all(); @@ -981,15 +1008,15 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) { auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (MSSA) MSSAU = MemorySSAUpdater(&MSSA->getMSSA()); bool Changed = false; for (Loop *L : *LI) { auto LN = LoopNest::getLoopNest(*L, *SE); - Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, - MSSAU ? MSSAU.getPointer() : nullptr); + Changed |= + Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, MSSAU ? 
&*MSSAU : nullptr); } return Changed; } diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp index d94b767c7b63..0eecec373736 100644 --- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -67,6 +67,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CodeMoverUtils.h" #include "llvm/Transforms/Utils/LoopPeel.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" using namespace llvm; @@ -101,6 +102,8 @@ STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block with " STATISTIC(NotRotated, "Candidate is not rotated"); STATISTIC(OnlySecondCandidateIsGuarded, "The second candidate is guarded while the first one is not"); +STATISTIC(NumHoistedInsts, "Number of hoisted preheader instructions."); +STATISTIC(NumSunkInsts, "Number of hoisted preheader instructions."); enum FusionDependenceAnalysisChoice { FUSION_DEPENDENCE_ANALYSIS_SCEV, @@ -183,9 +186,8 @@ struct FusionCandidate { OptimizationRemarkEmitter &ORE; - FusionCandidate(Loop *L, DominatorTree &DT, - const PostDominatorTree *PDT, OptimizationRemarkEmitter &ORE, - TTI::PeelingPreferences PP) + FusionCandidate(Loop *L, DominatorTree &DT, const PostDominatorTree *PDT, + OptimizationRemarkEmitter &ORE, TTI::PeelingPreferences PP) : Preheader(L->getLoopPreheader()), Header(L->getHeader()), ExitingBlock(L->getExitingBlock()), ExitBlock(L->getExitBlock()), Latch(L->getLoopLatch()), L(L), Valid(true), @@ -387,7 +389,13 @@ struct FusionCandidateCompare { /// Comparison functor to sort two Control Flow Equivalent fusion candidates /// into dominance order. /// If LHS dominates RHS and RHS post-dominates LHS, return true; - /// IF RHS dominates LHS and LHS post-dominates RHS, return false; + /// If RHS dominates LHS and LHS post-dominates RHS, return false; + /// If both LHS and RHS are not dominating each other then, non-strictly + /// post dominate check will decide the order of candidates. If RHS + /// non-strictly post dominates LHS then, return true. If LHS non-strictly + /// post dominates RHS then, return false. If both are non-strictly post + /// dominate each other then, level in the post dominator tree will decide + /// the order of candidates. bool operator()(const FusionCandidate &LHS, const FusionCandidate &RHS) const { const DominatorTree *DT = &(LHS.DT); @@ -413,9 +421,29 @@ struct FusionCandidateCompare { return true; } - // If LHS does not dominate RHS and RHS does not dominate LHS then there is - // no dominance relationship between the two FusionCandidates. Thus, they - // should not be in the same set together. + // If two FusionCandidates are in the same level of dominator tree, + // they will not dominate each other, but may still be control flow + // equivalent. To sort those FusionCandidates, nonStrictlyPostDominate() + // function is needed. + bool WrongOrder = + nonStrictlyPostDominate(LHSEntryBlock, RHSEntryBlock, DT, LHS.PDT); + bool RightOrder = + nonStrictlyPostDominate(RHSEntryBlock, LHSEntryBlock, DT, LHS.PDT); + if (WrongOrder && RightOrder) { + // If common predecessor of LHS and RHS post dominates both + // FusionCandidates then, Order of FusionCandidate can be + // identified by its level in post dominator tree. 
+ DomTreeNode *LNode = LHS.PDT->getNode(LHSEntryBlock); + DomTreeNode *RNode = LHS.PDT->getNode(RHSEntryBlock); + return LNode->getLevel() > RNode->getLevel(); + } else if (WrongOrder) + return false; + else if (RightOrder) + return true; + + // If LHS does not non-strict Postdominate RHS and RHS does not non-strict + // Postdominate LHS then, there is no dominance relationship between the + // two FusionCandidates. Thus, they should not be in the same set together. llvm_unreachable( "No dominance relationship between these fusion candidates!"); } @@ -427,7 +455,7 @@ using LoopVector = SmallVector<Loop *, 4>; // order. Thus, if FC0 comes *before* FC1 in a FusionCandidateSet, then FC0 // dominates FC1 and FC1 post-dominates FC0. // std::set was chosen because we want a sorted data structure with stable -// iterators. A subsequent patch to loop fusion will enable fusing non-ajdacent +// iterators. A subsequent patch to loop fusion will enable fusing non-adjacent // loops by moving intervening code around. When this intervening code contains // loops, those loops will be moved also. The corresponding FusionCandidates // will also need to be moved accordingly. As this is done, having stable @@ -528,7 +556,7 @@ private: #ifndef NDEBUG static void printLoopVector(const LoopVector &LV) { dbgs() << "****************************\n"; - for (auto L : LV) + for (auto *L : LV) printLoop(*L, dbgs()); dbgs() << "****************************\n"; } @@ -549,7 +577,6 @@ private: PostDominatorTree &PDT; OptimizationRemarkEmitter &ORE; AssumptionCache &AC; - const TargetTransformInfo &TTI; public: @@ -644,7 +671,7 @@ private: void collectFusionCandidates(const LoopVector &LV) { for (Loop *L : LV) { TTI::PeelingPreferences PP = - gatherPeelingPreferences(L, SE, TTI, None, None); + gatherPeelingPreferences(L, SE, TTI, std::nullopt, std::nullopt); FusionCandidate CurrCand(L, DT, &PDT, ORE, PP); if (!CurrCand.isEligibleForFusion(SE)) continue; @@ -699,23 +726,22 @@ private: /// stating whether or not the two candidates are known at compile time to /// have the same TripCount. The second is the difference in the two /// TripCounts. This information can be used later to determine whether or not - /// peeling can be performed on either one of the candiates. - std::pair<bool, Optional<unsigned>> + /// peeling can be performed on either one of the candidates. + std::pair<bool, std::optional<unsigned>> haveIdenticalTripCounts(const FusionCandidate &FC0, const FusionCandidate &FC1) const { - const SCEV *TripCount0 = SE.getBackedgeTakenCount(FC0.L); if (isa<SCEVCouldNotCompute>(TripCount0)) { UncomputableTripCount++; LLVM_DEBUG(dbgs() << "Trip count of first loop could not be computed!"); - return {false, None}; + return {false, std::nullopt}; } const SCEV *TripCount1 = SE.getBackedgeTakenCount(FC1.L); if (isa<SCEVCouldNotCompute>(TripCount1)) { UncomputableTripCount++; LLVM_DEBUG(dbgs() << "Trip count of second loop could not be computed!"); - return {false, None}; + return {false, std::nullopt}; } LLVM_DEBUG(dbgs() << "\tTrip counts: " << *TripCount0 << " & " @@ -740,10 +766,10 @@ private: LLVM_DEBUG(dbgs() << "Loop(s) do not have a single exit point or do not " "have a constant number of iterations. 
Peeling " "is not benefical\n"); - return {false, None}; + return {false, std::nullopt}; } - Optional<unsigned> Difference = None; + std::optional<unsigned> Difference; int Diff = TC0 - TC1; if (Diff > 0) @@ -767,7 +793,8 @@ private: LLVM_DEBUG(dbgs() << "Attempting to peel first " << PeelCount << " iterations of the first loop. \n"); - FC0.Peeled = peelLoop(FC0.L, PeelCount, &LI, &SE, DT, &AC, true); + ValueToValueMapTy VMap; + FC0.Peeled = peelLoop(FC0.L, PeelCount, &LI, &SE, DT, &AC, true, VMap); if (FC0.Peeled) { LLVM_DEBUG(dbgs() << "Done Peeling\n"); @@ -807,7 +834,7 @@ private: } // Cannot modify the predecessors inside the above loop as it will cause // the iterators to be nullptrs, causing memory errors. - for (Instruction *CurrentBranch: WorkList) { + for (Instruction *CurrentBranch : WorkList) { BasicBlock *Succ = CurrentBranch->getSuccessor(0); if (Succ == BB) Succ = CurrentBranch->getSuccessor(1); @@ -858,12 +885,12 @@ private: // Check if the candidates have identical tripcounts (first value of // pair), and if not check the difference in the tripcounts between // the loops (second value of pair). The difference is not equal to - // None iff the loops iterate a constant number of times, and have a - // single exit. - std::pair<bool, Optional<unsigned>> IdenticalTripCountRes = + // std::nullopt iff the loops iterate a constant number of times, and + // have a single exit. + std::pair<bool, std::optional<unsigned>> IdenticalTripCountRes = haveIdenticalTripCounts(*FC0, *FC1); bool SameTripCount = IdenticalTripCountRes.first; - Optional<unsigned> TCDifference = IdenticalTripCountRes.second; + std::optional<unsigned> TCDifference = IdenticalTripCountRes.second; // Here we are checking that FC0 (the first loop) can be peeled, and // both loops have different tripcounts. @@ -895,9 +922,10 @@ private: continue; } - if (!FC0->GuardBranch && FC1->GuardBranch) { - LLVM_DEBUG(dbgs() << "The second candidate is guarded while the " - "first one is not. Not fusing.\n"); + if ((!FC0->GuardBranch && FC1->GuardBranch) || + (FC0->GuardBranch && !FC1->GuardBranch)) { + LLVM_DEBUG(dbgs() << "The one of candidate is guarded while the " + "another one is not. Not fusing.\n"); reportLoopFusion<OptimizationRemarkMissed>( *FC0, *FC1, OnlySecondCandidateIsGuarded); continue; @@ -914,16 +942,6 @@ private: continue; } - if (!isSafeToMoveBefore(*FC1->Preheader, - *FC0->Preheader->getTerminator(), DT, &PDT, - &DI)) { - LLVM_DEBUG(dbgs() << "Fusion candidate contains unsafe " - "instructions in preheader. Not fusing.\n"); - reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, - NonEmptyPreheader); - continue; - } - if (FC0->GuardBranch) { assert(FC1->GuardBranch && "Expecting valid FC1 guard branch"); @@ -959,6 +977,31 @@ private: continue; } + // If the second loop has instructions in the pre-header, attempt to + // hoist them up to the first loop's pre-header or sink them into the + // body of the second loop. + SmallVector<Instruction *, 4> SafeToHoist; + SmallVector<Instruction *, 4> SafeToSink; + // At this point, this is the last remaining legality check. + // Which means if we can make this pre-header empty, we can fuse + // these loops + if (!isEmptyPreheader(*FC1)) { + LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty " + "preheader.\n"); + + // If it is not safe to hoist/sink all instructions in the + // pre-header, we cannot fuse these loops. 
+ if (!collectMovablePreheaderInsts(*FC0, *FC1, SafeToHoist, + SafeToSink)) { + LLVM_DEBUG(dbgs() << "Could not hoist/sink all instructions in " + "Fusion Candidate Pre-header.\n" + << "Not Fusing.\n"); + reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, + NonEmptyPreheader); + continue; + } + } + bool BeneficialToFuse = isBeneficialFusion(*FC0, *FC1); LLVM_DEBUG(dbgs() << "\tFusion appears to be " @@ -972,6 +1015,9 @@ private: // and profitable. At this point, start transforming the code and // perform fusion. + // Execute the hoist/sink operations on preheader instructions + movePreheaderInsts(*FC0, *FC1, SafeToHoist, SafeToSink); + LLVM_DEBUG(dbgs() << "\tFusion is performed: " << *FC0 << " and " << *FC1 << "\n"); @@ -1022,6 +1068,170 @@ private: return Fused; } + // Returns true if the instruction \p I can be hoisted to the end of the + // preheader of \p FC0. \p SafeToHoist contains the instructions that are + // known to be safe to hoist. The instructions encountered that cannot be + // hoisted are in \p NotHoisting. + // TODO: Move functionality into CodeMoverUtils + bool canHoistInst(Instruction &I, + const SmallVector<Instruction *, 4> &SafeToHoist, + const SmallVector<Instruction *, 4> &NotHoisting, + const FusionCandidate &FC0) const { + const BasicBlock *FC0PreheaderTarget = FC0.Preheader->getSingleSuccessor(); + assert(FC0PreheaderTarget && + "Expected single successor for loop preheader."); + + for (Use &Op : I.operands()) { + if (auto *OpInst = dyn_cast<Instruction>(Op)) { + bool OpHoisted = is_contained(SafeToHoist, OpInst); + // Check if we have already decided to hoist this operand. In this + // case, it does not dominate FC0 *yet*, but will after we hoist it. + if (!(OpHoisted || DT.dominates(OpInst, FC0PreheaderTarget))) { + return false; + } + } + } + + // PHIs in FC1's header only have FC0 blocks as predecessors. PHIs + // cannot be hoisted and should be sunk to the exit of the fused loop. + if (isa<PHINode>(I)) + return false; + + // If this isn't a memory inst, hoisting is safe + if (!I.mayReadOrWriteMemory()) + return true; + + LLVM_DEBUG(dbgs() << "Checking if this mem inst can be hoisted.\n"); + for (Instruction *NotHoistedInst : NotHoisting) { + if (auto D = DI.depends(&I, NotHoistedInst, true)) { + // Dependency is not read-before-write, write-before-read or + // write-before-write + if (D->isFlow() || D->isAnti() || D->isOutput()) { + LLVM_DEBUG(dbgs() << "Inst depends on an instruction in FC1's " + "preheader that is not being hoisted.\n"); + return false; + } + } + } + + for (Instruction *ReadInst : FC0.MemReads) { + if (auto D = DI.depends(ReadInst, &I, true)) { + // Dependency is not read-before-write + if (D->isAnti()) { + LLVM_DEBUG(dbgs() << "Inst depends on a read instruction in FC0.\n"); + return false; + } + } + } + + for (Instruction *WriteInst : FC0.MemWrites) { + if (auto D = DI.depends(WriteInst, &I, true)) { + // Dependency is not write-before-read or write-before-write + if (D->isFlow() || D->isOutput()) { + LLVM_DEBUG(dbgs() << "Inst depends on a write instruction in FC0.\n"); + return false; + } + } + } + return true; + } + + // Returns true if the instruction \p I can be sunk to the top of the exit + // block of \p FC1. + // TODO: Move functionality into CodeMoverUtils + bool canSinkInst(Instruction &I, const FusionCandidate &FC1) const { + for (User *U : I.users()) { + if (auto *UI{dyn_cast<Instruction>(U)}) { + // Cannot sink if user in loop + // If FC1 has phi users of this value, we cannot sink it into FC1. 
+ if (FC1.L->contains(UI)) { + // Cannot hoist or sink this instruction. No hoisting/sinking + // should take place, loops should not fuse + return false; + } + } + } + + // If this isn't a memory inst, sinking is safe + if (!I.mayReadOrWriteMemory()) + return true; + + for (Instruction *ReadInst : FC1.MemReads) { + if (auto D = DI.depends(&I, ReadInst, true)) { + // Dependency is not write-before-read + if (D->isFlow()) { + LLVM_DEBUG(dbgs() << "Inst depends on a read instruction in FC1.\n"); + return false; + } + } + } + + for (Instruction *WriteInst : FC1.MemWrites) { + if (auto D = DI.depends(&I, WriteInst, true)) { + // Dependency is not write-before-write or read-before-write + if (D->isOutput() || D->isAnti()) { + LLVM_DEBUG(dbgs() << "Inst depends on a write instruction in FC1.\n"); + return false; + } + } + } + + return true; + } + + /// Collect instructions in the \p FC1 Preheader that can be hoisted + /// to the \p FC0 Preheader or sunk into the \p FC1 Body + bool collectMovablePreheaderInsts( + const FusionCandidate &FC0, const FusionCandidate &FC1, + SmallVector<Instruction *, 4> &SafeToHoist, + SmallVector<Instruction *, 4> &SafeToSink) const { + BasicBlock *FC1Preheader = FC1.Preheader; + // Save the instructions that are not being hoisted, so we know not to hoist + // mem insts that they dominate. + SmallVector<Instruction *, 4> NotHoisting; + + for (Instruction &I : *FC1Preheader) { + // Can't move a branch + if (&I == FC1Preheader->getTerminator()) + continue; + // If the instruction has side-effects, give up. + // TODO: The case of mayReadFromMemory we can handle but requires + // additional work with a dependence analysis so for now we give + // up on memory reads. + if (I.mayThrow() || !I.willReturn()) { + LLVM_DEBUG(dbgs() << "Inst: " << I << " may throw or won't return.\n"); + return false; + } + + LLVM_DEBUG(dbgs() << "Checking Inst: " << I << "\n"); + + if (I.isAtomic() || I.isVolatile()) { + LLVM_DEBUG( + dbgs() << "\tInstruction is volatile or atomic. Cannot move it.\n"); + return false; + } + + if (canHoistInst(I, SafeToHoist, NotHoisting, FC0)) { + SafeToHoist.push_back(&I); + LLVM_DEBUG(dbgs() << "\tSafe to hoist.\n"); + } else { + LLVM_DEBUG(dbgs() << "\tCould not hoist. Trying to sink...\n"); + NotHoisting.push_back(&I); + + if (canSinkInst(I, FC1)) { + SafeToSink.push_back(&I); + LLVM_DEBUG(dbgs() << "\tSafe to sink.\n"); + } else { + LLVM_DEBUG(dbgs() << "\tCould not sink.\n"); + return false; + } + } + } + LLVM_DEBUG( + dbgs() << "All preheader instructions could be sunk or hoisted!\n"); + return true; + } + /// Rewrite all additive recurrences in a SCEV to use a new loop. class AddRecLoopReplacer : public SCEVRewriteVisitor<AddRecLoopReplacer> { public: @@ -1034,7 +1244,7 @@ private: const Loop *ExprL = Expr->getLoop(); SmallVector<const SCEV *, 2> Operands; if (ExprL == &OldL) { - Operands.append(Expr->op_begin(), Expr->op_end()); + append_range(Operands, Expr->operands()); return SE.getAddRecExpr(Operands, &NewL, Expr->getNoWrapFlags()); } @@ -1235,6 +1445,46 @@ private: return FC0.ExitBlock == FC1.getEntryBlock(); } + bool isEmptyPreheader(const FusionCandidate &FC) const { + return FC.Preheader->size() == 1; + } + + /// Hoist \p FC1 Preheader instructions to \p FC0 Preheader + /// and sink others into the body of \p FC1. 
+ void movePreheaderInsts(const FusionCandidate &FC0, + const FusionCandidate &FC1, + SmallVector<Instruction *, 4> &HoistInsts, + SmallVector<Instruction *, 4> &SinkInsts) const { + // All preheader instructions except the branch must be hoisted or sunk. + assert(HoistInsts.size() + SinkInsts.size() == FC1.Preheader->size() - 1 && + "Attempting to sink and hoist preheader instructions, but not all " + "the preheader instructions are accounted for."); + + NumHoistedInsts += HoistInsts.size(); + NumSunkInsts += SinkInsts.size(); + + LLVM_DEBUG(if (VerboseFusionDebugging) { + if (!HoistInsts.empty()) + dbgs() << "Hoisting: \n"; + for (Instruction *I : HoistInsts) + dbgs() << *I << "\n"; + if (!SinkInsts.empty()) + dbgs() << "Sinking: \n"; + for (Instruction *I : SinkInsts) + dbgs() << *I << "\n"; + }); + + for (Instruction *I : HoistInsts) { + assert(I->getParent() == FC1.Preheader); + I->moveBefore(FC0.Preheader->getTerminator()); + } + // Insert instructions in reverse order to maintain the dominance relationship. + for (Instruction *I : reverse(SinkInsts)) { + assert(I->getParent() == FC1.Preheader); + I->moveBefore(&*FC1.ExitBlock->getFirstInsertionPt()); + } + } + /// Determine if two fusion candidates have identical guards /// /// This method will determine if two fusion candidates have the same guards. @@ -1480,6 +1730,7 @@ private: // mergeLatch may remove the only block in FC1. SE.forgetLoop(FC1.L); SE.forgetLoop(FC0.L); + SE.forgetLoopDispositions(); // Move instructions from FC0.Latch to FC1.Latch. // Note: mergeLatch requires an updated DT. @@ -1772,6 +2023,7 @@ private: // mergeLatch may remove the only block in FC1. SE.forgetLoop(FC1.L); SE.forgetLoop(FC0.L); + SE.forgetLoopDispositions(); // Move instructions from FC0.Latch to FC1.Latch. // Note: mergeLatch requires an updated DT. @@ -1838,6 +2090,7 @@ struct LoopFuseLegacy : public FunctionPass { bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI(); @@ -1866,8 +2119,19 @@ PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) { const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); const DataLayout &DL = F.getParent()->getDataLayout(); + // Ensure loops are in simplified form, which is a prerequisite for the loop + // fusion pass. Added only for the new PM since the legacy PM has already added + // LoopSimplify pass as a dependency. + bool Changed = false; + for (auto &L : LI) { + Changed |= + simplifyLoop(L, &DT, &LI, &SE, &AC, nullptr, false /* PreserveLCSSA */); + } + if (Changed) + PDT.recalculate(F); + LoopFuser LF(LI, DT, DI, SE, PDT, ORE, DL, AC, TTI); - bool Changed = LF.fuseLoops(F); + Changed |= LF.fuseLoops(F); if (!Changed) return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 3ed022f65d9a..035cbdf595a8 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -441,7 +441,7 @@ static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) { // array. We could theoretically do a store to an alloca or something, but // that doesn't seem worthwhile. Constant *C = dyn_cast<Constant>(V); - if (!C) + if (!C || isa<ConstantExpr>(C)) return nullptr; // Only handle simple values that are a power of two bytes in size.
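As context for the getMemSetPatternValue hunk above: a stored constant is only usable as a repeating memset pattern when it has a simple, power-of-two byte size (and, with this change, is not a ConstantExpr). Below is a minimal standalone sketch of that size test only; the 16-byte bound is an assumption for illustration, not taken from the patch.

#include <cassert>
#include <cstdint>

// Returns true if a stored value of SizeInBits could plausibly serve as a
// repeating memset pattern: a whole number of bytes, a power of two, and no
// larger than the assumed 16-byte pattern buffer.
static bool isSimplePatternSize(uint64_t SizeInBits) {
  if (SizeInBits == 0 || (SizeInBits & 7) != 0) // not a whole byte count
    return false;
  uint64_t Bytes = SizeInBits / 8;
  if ((Bytes & (Bytes - 1)) != 0)               // not a power of two
    return false;
  return Bytes <= 16;                           // assumed pattern width
}

int main() {
  assert(isSimplePatternSize(8));    // i8
  assert(isSimplePatternSize(32));   // i32
  assert(!isSimplePatternSize(24));  // 3 bytes: not a power of two
  assert(!isSimplePatternSize(256)); // 32 bytes: larger than the buffer
}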
@@ -496,8 +496,8 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) { // When storing out scalable vectors we bail out for now, since the code // below currently only works for constant strides. TypeSize SizeInBits = DL->getTypeSizeInBits(StoredVal->getType()); - if (SizeInBits.isScalable() || (SizeInBits.getFixedSize() & 7) || - (SizeInBits.getFixedSize() >> 32) != 0) + if (SizeInBits.isScalable() || (SizeInBits.getFixedValue() & 7) || + (SizeInBits.getFixedValue() >> 32) != 0) return LegalStoreKind::None; // See if the pointer expression is an AddRec like {base,+,1} on the current @@ -1028,8 +1028,7 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, for (BasicBlock *B : L->blocks()) for (Instruction &I : *B) if (!IgnoredInsts.contains(&I) && - isModOrRefSet( - intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access))) + isModOrRefSet(AA.getModRefInfo(&I, StoreLoc) & Access)) return true; return false; } @@ -1273,6 +1272,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, StoreEv, LoadEv, BECount); } +namespace { class MemmoveVerifier { public: explicit MemmoveVerifier(const Value &LoadBasePtr, const Value &StoreBasePtr, @@ -1296,7 +1296,7 @@ public: // Ensure that LoadBasePtr is after StoreBasePtr or before StoreBasePtr // for negative stride. LoadBasePtr shouldn't overlap with StoreBasePtr. int64_t LoadSize = - DL.getTypeSizeInBits(TheLoad.getType()).getFixedSize() / 8; + DL.getTypeSizeInBits(TheLoad.getType()).getFixedValue() / 8; if (BP1 != BP2 || LoadSize != int64_t(StoreSize)) return false; if ((!IsNegStride && LoadOff < StoreOff + int64_t(StoreSize)) || @@ -1316,6 +1316,7 @@ private: public: const bool IsSameObject; }; +} // namespace bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( Value *DestPtr, Value *SourcePtr, const SCEV *StoreSizeSCEV, @@ -1483,7 +1484,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( // anything where the alignment isn't at least the element size. assert((StoreAlign && LoadAlign) && "Expect unordered load/store to have align."); - if (StoreAlign.value() < StoreSize || LoadAlign.value() < StoreSize) + if (*StoreAlign < StoreSize || *LoadAlign < StoreSize) return Changed; // If the element.atomic memcpy is not lowered into explicit @@ -1497,9 +1498,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( // Note that unordered atomic loads/stores are *required* by the spec to // have an alignment but non-atomic loads/stores may not. 
NewCall = Builder.CreateElementUnorderedAtomicMemCpy( - StoreBasePtr, StoreAlign.value(), LoadBasePtr, LoadAlign.value(), - NumBytes, StoreSize, AATags.TBAA, AATags.TBAAStruct, AATags.Scope, - AATags.NoAlias); + StoreBasePtr, *StoreAlign, LoadBasePtr, *LoadAlign, NumBytes, StoreSize, + AATags.TBAA, AATags.TBAAStruct, AATags.Scope, AATags.NoAlias); } NewCall->setDebugLoc(TheStore->getDebugLoc()); diff --git a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index 4249512ea0f8..c9798a80978d 100644 --- a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -35,6 +35,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include <optional> #include <utility> using namespace llvm; @@ -214,14 +215,14 @@ public: PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (AR.MSSA) { MSSAU = MemorySSAUpdater(AR.MSSA); if (VerifyMemorySSA) AR.MSSA->verifyMemorySSA(); } if (!simplifyLoopInst(L, AR.DT, AR.LI, AR.AC, AR.TLI, - MSSAU ? MSSAU.getPointer() : nullptr)) + MSSAU ? &*MSSAU : nullptr)) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 18daa4295224..0a7c62113c7f 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -44,6 +44,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <cassert> @@ -86,7 +87,8 @@ static void printDepMatrix(CharMatrix &DepMatrix) { #endif static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, - Loop *L, DependenceInfo *DI) { + Loop *L, DependenceInfo *DI, + ScalarEvolution *SE) { using ValueVector = SmallVector<Value *, 16>; ValueVector MemInstr; @@ -125,6 +127,10 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, // Track Output, Flow, and Anti dependencies. if (auto D = DI->depends(Src, Dst, true)) { assert(D->isOrdered() && "Expected an output, flow or anti dep."); + // If the direction vector is negative, normalize it to + // make it non-negative. + if (D->normalize(SE)) + LLVM_DEBUG(dbgs() << "Negative dependence vector normalized.\n"); LLVM_DEBUG(StringRef DepType = D->isFlow() ? "flow" : D->isAnti() ? 
"anti" : "output"; dbgs() << "Found " << DepType @@ -133,19 +139,7 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, unsigned Levels = D->getLevels(); char Direction; for (unsigned II = 1; II <= Levels; ++II) { - const SCEV *Distance = D->getDistance(II); - const SCEVConstant *SCEVConst = - dyn_cast_or_null<SCEVConstant>(Distance); - if (SCEVConst) { - const ConstantInt *CI = SCEVConst->getValue(); - if (CI->isNegative()) - Direction = '<'; - else if (CI->isZero()) - Direction = '='; - else - Direction = '>'; - Dep.push_back(Direction); - } else if (D->isScalar(II)) { + if (D->isScalar(II)) { Direction = 'S'; Dep.push_back(Direction); } else { @@ -188,80 +182,36 @@ static void interChangeDependencies(CharMatrix &DepMatrix, unsigned FromIndx, std::swap(DepMatrix[I][ToIndx], DepMatrix[I][FromIndx]); } -// Checks if outermost non '=','S'or'I' dependence in the dependence matrix is -// '>' -static bool isOuterMostDepPositive(CharMatrix &DepMatrix, unsigned Row, - unsigned Column) { - for (unsigned i = 0; i <= Column; ++i) { - if (DepMatrix[Row][i] == '<') - return false; - if (DepMatrix[Row][i] == '>') +// After interchanging, check if the direction vector is valid. +// [Theorem] A permutation of the loops in a perfect nest is legal if and only +// if the direction matrix, after the same permutation is applied to its +// columns, has no ">" direction as the leftmost non-"=" direction in any row. +static bool isLexicographicallyPositive(std::vector<char> &DV) { + for (unsigned Level = 0; Level < DV.size(); ++Level) { + unsigned char Direction = DV[Level]; + if (Direction == '<') return true; - } - // All dependencies were '=','S' or 'I' - return false; -} - -// Checks if no dependence exist in the dependency matrix in Row before Column. -static bool containsNoDependence(CharMatrix &DepMatrix, unsigned Row, - unsigned Column) { - for (unsigned i = 0; i < Column; ++i) { - if (DepMatrix[Row][i] != '=' && DepMatrix[Row][i] != 'S' && - DepMatrix[Row][i] != 'I') + if (Direction == '>' || Direction == '*') return false; } return true; } -static bool validDepInterchange(CharMatrix &DepMatrix, unsigned Row, - unsigned OuterLoopId, char InnerDep, - char OuterDep) { - if (isOuterMostDepPositive(DepMatrix, Row, OuterLoopId)) - return false; - - if (InnerDep == OuterDep) - return true; - - // It is legal to interchange if and only if after interchange no row has a - // '>' direction as the leftmost non-'='. - - if (InnerDep == '=' || InnerDep == 'S' || InnerDep == 'I') - return true; - - if (InnerDep == '<') - return true; - - if (InnerDep == '>') { - // If OuterLoopId represents outermost loop then interchanging will make the - // 1st dependency as '>' - if (OuterLoopId == 0) - return false; - - // If all dependencies before OuterloopId are '=','S'or 'I'. Then - // interchanging will result in this row having an outermost non '=' - // dependency of '>' - if (!containsNoDependence(DepMatrix, Row, OuterLoopId)) - return true; - } - - return false; -} - // Checks if it is legal to interchange 2 loops. -// [Theorem] A permutation of the loops in a perfect nest is legal if and only -// if the direction matrix, after the same permutation is applied to its -// columns, has no ">" direction as the leftmost non-"=" direction in any row. static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, unsigned InnerLoopId, unsigned OuterLoopId) { unsigned NumRows = DepMatrix.size(); + std::vector<char> Cur; // For each row check if it is valid to interchange. 
for (unsigned Row = 0; Row < NumRows; ++Row) { - char InnerDep = DepMatrix[Row][InnerLoopId]; - char OuterDep = DepMatrix[Row][OuterLoopId]; - if (InnerDep == '*' || OuterDep == '*') + // Create a temporary DepVector and check its lexicographical order + // before and after swapping OuterLoop vs InnerLoop. + Cur = DepMatrix[Row]; + if (!isLexicographicallyPositive(Cur)) return false; - if (!validDepInterchange(DepMatrix, Row, OuterLoopId, InnerDep, OuterDep)) + std::swap(Cur[InnerLoopId], Cur[OuterLoopId]); + if (!isLexicographicallyPositive(Cur)) return false; } return true; @@ -361,11 +311,18 @@ public: bool isProfitable(const Loop *InnerLoop, const Loop *OuterLoop, unsigned InnerLoopId, unsigned OuterLoopId, CharMatrix &DepMatrix, - const DenseMap<const Loop *, unsigned> &CostMap); + const DenseMap<const Loop *, unsigned> &CostMap, + std::unique_ptr<CacheCost> &CC); private: int getInstrOrderCost(); - + std::optional<bool> isProfitablePerLoopCacheAnalysis( + const DenseMap<const Loop *, unsigned> &CostMap, + std::unique_ptr<CacheCost> &CC); + std::optional<bool> isProfitablePerInstrOrderCost(); + std::optional<bool> isProfitableForVectorization(unsigned InnerLoopId, + unsigned OuterLoopId, + CharMatrix &DepMatrix); Loop *OuterLoop; Loop *InnerLoop; @@ -486,7 +443,7 @@ struct LoopInterchange { CharMatrix DependencyMatrix; Loop *OuterMostLoop = *(LoopList.begin()); if (!populateDependencyMatrix(DependencyMatrix, LoopNestDepth, - OuterMostLoop, DI)) { + OuterMostLoop, DI, SE)) { LLVM_DEBUG(dbgs() << "Populating dependency matrix failed\n"); return false; } @@ -562,7 +519,7 @@ struct LoopInterchange { LLVM_DEBUG(dbgs() << "Loops are legal to interchange\n"); LoopInterchangeProfitability LIP(OuterLoop, InnerLoop, SE, ORE); if (!LIP.isProfitable(InnerLoop, OuterLoop, InnerLoopId, OuterLoopId, - DependencyMatrix, CostMap)) { + DependencyMatrix, CostMap, CC)) { LLVM_DEBUG(dbgs() << "Interchanging loops not profitable.\n"); return false; } @@ -579,11 +536,7 @@ struct LoopInterchange { LLVM_DEBUG(dbgs() << "Loops interchanged.\n"); LoopsInterchanged++; - assert(InnerLoop->isLCSSAForm(*DT) && - "Inner loop not left in LCSSA form after loop interchange!"); - assert(OuterLoop->isLCSSAForm(*DT) && - "Outer loop not left in LCSSA form after loop interchange!"); - + llvm::formLCSSARecursively(*OuterLoop, *DT, LI, SE); return true; } }; @@ -858,18 +811,26 @@ bool LoopInterchangeLegality::currentLimitations() { } Inductions.clear(); - if (!findInductionAndReductions(InnerLoop, Inductions, nullptr)) { - LLVM_DEBUG( - dbgs() << "Only inner loops with induction or reduction PHI nodes " - << "are supported currently.\n"); - ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Only inner loops with induction or reduction PHI nodes can be" - " interchange currently."; - }); - return true; + // For multi-level loop nests, make sure that all phi nodes for inner loops + // at all levels can be recognized as an induction or reduction phi. Bail out + // if a phi node at a certain nesting level cannot be properly recognized. + Loop *CurLevelLoop = OuterLoop; + while (!CurLevelLoop->getSubLoops().empty()) { + // We already made sure that the loop nest is tightly nested.
+ CurLevelLoop = CurLevelLoop->getSubLoops().front(); + if (!findInductionAndReductions(CurLevelLoop, Inductions, nullptr)) { + LLVM_DEBUG( + dbgs() << "Only inner loops with induction or reduction PHI nodes " + << "are supported currently.\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIInner", + CurLevelLoop->getStartLoc(), + CurLevelLoop->getHeader()) + << "Only inner loops with induction or reduction PHI nodes can be" + " interchange currently."; + }); + return true; + } } // TODO: Triangular loops are not handled for now. @@ -1137,31 +1098,10 @@ int LoopInterchangeProfitability::getInstrOrderCost() { return GoodOrder - BadOrder; } -static bool isProfitableForVectorization(unsigned InnerLoopId, - unsigned OuterLoopId, - CharMatrix &DepMatrix) { - // TODO: Improve this heuristic to catch more cases. - // If the inner loop is loop independent or doesn't carry any dependency it is - // profitable to move this to outer position. - for (auto &Row : DepMatrix) { - if (Row[InnerLoopId] != 'S' && Row[InnerLoopId] != 'I') - return false; - // TODO: We need to improve this heuristic. - if (Row[OuterLoopId] != '=') - return false; - } - // If outer loop has dependence and inner loop is loop independent then it is - // profitable to interchange to enable parallelism. - // If there are no dependences, interchanging will not improve anything. - return !DepMatrix.empty(); -} - -bool LoopInterchangeProfitability::isProfitable( - const Loop *InnerLoop, const Loop *OuterLoop, unsigned InnerLoopId, - unsigned OuterLoopId, CharMatrix &DepMatrix, - const DenseMap<const Loop *, unsigned> &CostMap) { - // TODO: Remove the legacy cost model. - +std::optional<bool> +LoopInterchangeProfitability::isProfitablePerLoopCacheAnalysis( + const DenseMap<const Loop *, unsigned> &CostMap, + std::unique_ptr<CacheCost> &CC) { // This is the new cost model returned from loop cache analysis. // A smaller index means the loop should be placed an outer loop, and vice // versa. @@ -1173,30 +1113,91 @@ bool LoopInterchangeProfitability::isProfitable( LLVM_DEBUG(dbgs() << "InnerIndex = " << InnerIndex << ", OuterIndex = " << OuterIndex << "\n"); if (InnerIndex < OuterIndex) - return true; - } else { - // Legacy cost model: this is rough cost estimation algorithm. It counts the - // good and bad order of induction variables in the instruction and allows - // reordering if number of bad orders is more than good. - int Cost = getInstrOrderCost(); - LLVM_DEBUG(dbgs() << "Cost = " << Cost << "\n"); - if (Cost < -LoopInterchangeCostThreshold) - return true; + return std::optional<bool>(true); + assert(InnerIndex != OuterIndex && "CostMap should assign unique " + "numbers to each loop"); + if (CC->getLoopCost(*OuterLoop) == CC->getLoopCost(*InnerLoop)) + return std::nullopt; + return std::optional<bool>(false); } + return std::nullopt; +} - // It is not profitable as per current cache profitability model. But check if - // we can move this loop outside to improve parallelism. - if (isProfitableForVectorization(InnerLoopId, OuterLoopId, DepMatrix)) - return true; +std::optional<bool> +LoopInterchangeProfitability::isProfitablePerInstrOrderCost() { + // Legacy cost model: this is rough cost estimation algorithm. It counts the + // good and bad order of induction variables in the instruction and allows + // reordering if number of bad orders is more than good. 
+ int Cost = getInstrOrderCost(); + LLVM_DEBUG(dbgs() << "Cost = " << Cost << "\n"); + if (Cost < 0 && Cost < LoopInterchangeCostThreshold) + return std::optional<bool>(true); + + return std::nullopt; +} - ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "InterchangeNotProfitable", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Interchanging loops is too costly and it does not improve " - "parallelism."; - }); - return false; +std::optional<bool> LoopInterchangeProfitability::isProfitableForVectorization( + unsigned InnerLoopId, unsigned OuterLoopId, CharMatrix &DepMatrix) { + for (auto &Row : DepMatrix) { + // If the inner loop is loop independent or doesn't carry any dependency + // it is not profitable to move this to outer position, since we are + // likely able to do inner loop vectorization already. + if (Row[InnerLoopId] == 'I' || Row[InnerLoopId] == '=') + return std::optional<bool>(false); + + // If the outer loop is not loop independent it is not profitable to move + // this to inner position, since doing so would not enable inner loop + // parallelism. + if (Row[OuterLoopId] != 'I' && Row[OuterLoopId] != '=') + return std::optional<bool>(false); + } + // If the inner loop has a dependence and the outer loop is loop independent, it + // is profitable to interchange to enable inner loop parallelism. + // If there are no dependences, interchanging will not improve anything. + return std::optional<bool>(!DepMatrix.empty()); +} + +bool LoopInterchangeProfitability::isProfitable( + const Loop *InnerLoop, const Loop *OuterLoop, unsigned InnerLoopId, + unsigned OuterLoopId, CharMatrix &DepMatrix, + const DenseMap<const Loop *, unsigned> &CostMap, + std::unique_ptr<CacheCost> &CC) { + // isProfitable() is structured to avoid endless loop interchange. + // If loop cache analysis could decide the profitability, then the + // profitability check will stop and return the analysis result. + // If cache analysis failed to analyze the loopnest (e.g., + // due to delinearization issues), then only check whether it is + // profitable per InstrOrderCost. Likewise, if InstrOrderCost failed to + // analyze the profitability, only then will isProfitableForVectorization + // decide. + std::optional<bool> shouldInterchange = + isProfitablePerLoopCacheAnalysis(CostMap, CC); + if (!shouldInterchange.has_value()) { + shouldInterchange = isProfitablePerInstrOrderCost(); + if (!shouldInterchange.has_value()) + shouldInterchange = + isProfitableForVectorization(InnerLoopId, OuterLoopId, DepMatrix); + } + if (!shouldInterchange.has_value()) { + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "InterchangeNotProfitable", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Insufficient information to calculate the cost of the loop for " + "interchange."; + }); + return false; + } else if (!shouldInterchange.value()) { + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "InterchangeNotProfitable", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Interchanging loops is not considered to improve cache " + "locality or vectorization."; + }); + return false; + } + return true; } void LoopInterchangeTransform::removeChildLoop(Loop *OuterLoop, @@ -1286,7 +1287,6 @@ void LoopInterchangeTransform::restructureLoops( // Tell SE that we move the loops around.
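The isProfitable() rewrite above chains three verdicts: loop cache analysis first, then the instruction-order cost, then the vectorization heuristic, with std::nullopt meaning "no opinion" and a final miss treated as not profitable. A compact standalone model of that fallback pattern is sketched below; the stage lambdas are stand-ins for illustration, not the real analyses.

#include <functional>
#include <iostream>
#include <optional>
#include <vector>

// Ask each analysis in turn; the first one that has an opinion decides.
// If none of them can tell, treat the transform as not profitable.
static bool decideProfitability(
    const std::vector<std::function<std::optional<bool>()>> &Stages) {
  for (const auto &Stage : Stages) {
    if (std::optional<bool> Verdict = Stage())
      return *Verdict;
  }
  return false; // insufficient information
}

int main() {
  auto CacheAnalysis = [] { return std::optional<bool>{}; };     // no verdict
  auto InstrOrderCost = [] { return std::optional<bool>{}; };    // no verdict
  auto Vectorization = [] { return std::optional<bool>{true}; }; // profitable

  std::cout << decideProfitability({CacheAnalysis, InstrOrderCost, Vectorization})
            << '\n'; // prints 1
}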
SE->forgetLoop(NewOuter); - SE->forgetLoop(NewInner); } bool LoopInterchangeTransform::transform() { @@ -1360,9 +1360,11 @@ bool LoopInterchangeTransform::transform() { for (Instruction *InnerIndexVar : InnerIndexVarList) WorkList.insert(cast<Instruction>(InnerIndexVar)); MoveInstructions(); + } - // Splits the inner loops phi nodes out into a separate basic block. - BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); + // Ensure the inner loop phi nodes have a separate basic block. + BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); + if (InnerLoopHeader->getFirstNonPHI() != InnerLoopHeader->getTerminator()) { SplitBlock(InnerLoopHeader, InnerLoopHeader->getFirstNonPHI(), DT, LI); LLVM_DEBUG(dbgs() << "splitting InnerLoopHeader done\n"); } @@ -1394,11 +1396,10 @@ bool LoopInterchangeTransform::transform() { /// \brief Move all instructions except the terminator from FromBB right before /// InsertBefore static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { - auto &ToList = InsertBefore->getParent()->getInstList(); - auto &FromList = FromBB->getInstList(); + BasicBlock *ToBB = InsertBefore->getParent(); - ToList.splice(InsertBefore->getIterator(), FromList, FromList.begin(), - FromBB->getTerminator()->getIterator()); + ToBB->splice(InsertBefore->getIterator(), FromBB, FromBB->begin(), + FromBB->getTerminator()->getIterator()); } /// Swap instructions between \p BB1 and \p BB2 but keep terminators intact. @@ -1773,5 +1774,6 @@ PreservedAnalyses LoopInterchangePass::run(LoopNest &LN, OptimizationRemarkEmitter ORE(&F); if (!LoopInterchange(&AR.SE, &AR.LI, &DI, &AR.DT, CC, &ORE).run(LN)) return PreservedAnalyses::all(); + U.markLoopNestChanged(true); return getLoopPassPreservedAnalyses(); } diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 1877ac1dfd08..b615a0a0a9c0 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -98,20 +98,21 @@ struct StoreToLoadForwardingCandidate { Value *LoadPtr = Load->getPointerOperand(); Value *StorePtr = Store->getPointerOperand(); Type *LoadType = getLoadStoreType(Load); + auto &DL = Load->getParent()->getModule()->getDataLayout(); assert(LoadPtr->getType()->getPointerAddressSpace() == StorePtr->getType()->getPointerAddressSpace() && - LoadType == getLoadStoreType(Store) && + DL.getTypeSizeInBits(LoadType) == + DL.getTypeSizeInBits(getLoadStoreType(Store)) && "Should be a known dependence"); // Currently we only support accesses with unit stride. FIXME: we should be // able to handle non unit stirde as well as long as the stride is equal to // the dependence distance. - if (getPtrStride(PSE, LoadType, LoadPtr, L) != 1 || - getPtrStride(PSE, LoadType, StorePtr, L) != 1) + if (getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0) != 1 || + getPtrStride(PSE, LoadType, StorePtr, L).value_or(0) != 1) return false; - auto &DL = Load->getParent()->getModule()->getDataLayout(); unsigned TypeByteSize = DL.getTypeAllocSize(const_cast<Type *>(LoadType)); auto *LoadPtrSCEV = cast<SCEVAddRecExpr>(PSE.getSCEV(LoadPtr)); @@ -211,9 +212,10 @@ public: if (!Load) continue; - // Only progagate the value if they are of the same type. - if (Store->getPointerOperandType() != Load->getPointerOperandType() || - getLoadStoreType(Store) != getLoadStoreType(Load)) + // Only propagate if the stored values are bit/pointer castable. 
+ if (!CastInst::isBitOrNoopPointerCastable( + getLoadStoreType(Store), getLoadStoreType(Load), + Store->getParent()->getModule()->getDataLayout())) continue; Candidates.emplace_front(Load, Store); @@ -438,7 +440,21 @@ public: PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded", &L->getHeader()->front()); PHI->addIncoming(Initial, PH); - PHI->addIncoming(Cand.Store->getOperand(0), L->getLoopLatch()); + + Type *LoadType = Initial->getType(); + Type *StoreType = Cand.Store->getValueOperand()->getType(); + auto &DL = Cand.Load->getParent()->getModule()->getDataLayout(); + (void)DL; + + assert(DL.getTypeSizeInBits(LoadType) == DL.getTypeSizeInBits(StoreType) && + "The type sizes should match!"); + + Value *StoreValue = Cand.Store->getValueOperand(); + if (LoadType != StoreType) + StoreValue = CastInst::CreateBitOrPointerCast( + StoreValue, LoadType, "store_forward_cast", Cand.Store); + + PHI->addIncoming(StoreValue, L->getLoopLatch()); Cand.Load->replaceAllUsesWith(PHI); } @@ -605,11 +621,12 @@ private: } // end anonymous namespace -static bool -eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, - BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, - ScalarEvolution *SE, AssumptionCache *AC, - function_ref<const LoopAccessInfo &(Loop &)> GetLAI) { +static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, + DominatorTree &DT, + BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, + ScalarEvolution *SE, AssumptionCache *AC, + LoopAccessInfoManager &LAIs) { // Build up a worklist of inner-loops to transform to avoid iterator // invalidation. // FIXME: This logic comes from other passes that actually change the loop @@ -633,8 +650,10 @@ eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, if (!L->isRotatedForm() || !L->getExitingBlock()) continue; // The actual work is performed by LoadEliminationForLoop. - LoadEliminationForLoop LEL(L, &LI, GetLAI(*L), &DT, BFI, PSI); + LoadEliminationForLoop LEL(L, &LI, LAIs.getInfo(*L), &DT, BFI, PSI); Changed |= LEL.processLoop(); + if (Changed) + LAIs.clear(); } return Changed; } @@ -656,7 +675,7 @@ public: return false; auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto &LAA = getAnalysis<LoopAccessLegacyAnalysis>(); + auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); auto *BFI = (PSI && PSI->hasProfileSummary()) ? @@ -665,9 +684,8 @@ public: auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); // Process each loop nest in the function. 
- return eliminateLoadsAcrossLoops( - F, LI, DT, BFI, PSI, SE, /*AC*/ nullptr, - [&LAA](Loop &L) -> const LoopAccessInfo & { return LAA.getInfo(&L); }); + return eliminateLoadsAcrossLoops(F, LI, DT, BFI, PSI, SE, /*AC*/ nullptr, + LAIs); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -712,23 +730,15 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F, if (LI.empty()) return PreservedAnalyses::all(); auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); - auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - auto &AA = AM.getResult<AAManager>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); auto *BFI = (PSI && PSI->hasProfileSummary()) ? &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; + LoopAccessInfoManager &LAIs = AM.getResult<LoopAccessAnalysis>(F); - auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); - bool Changed = eliminateLoadsAcrossLoops( - F, LI, DT, BFI, PSI, &SE, &AC, [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, - TLI, TTI, nullptr, nullptr, nullptr}; - return LAM.getResult<LoopAccessAnalysis>(L, AR); - }); + bool Changed = eliminateLoadsAcrossLoops(F, LI, DT, BFI, PSI, &SE, &AC, LAIs); if (!Changed) return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp index d20d275ea60c..c98b94b56e48 100644 --- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -84,9 +84,10 @@ LoopPassManager::runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM, // invalid when encountering a loop-nest pass. std::unique_ptr<LoopNest> LoopNestPtr; bool IsLoopNestPtrValid = false; + Loop *OuterMostLoop = &L; for (size_t I = 0, E = IsLoopNestPass.size(); I != E; ++I) { - Optional<PreservedAnalyses> PassPA; + std::optional<PreservedAnalyses> PassPA; if (!IsLoopNestPass[I]) { // The `I`-th pass is a loop pass. auto &Pass = LoopPasses[LoopPassIndex++]; @@ -97,10 +98,18 @@ LoopPassManager::runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM, // If the loop-nest object calculated before is no longer valid, // re-calculate it here before running the loop-nest pass. - if (!IsLoopNestPtrValid) { - LoopNestPtr = LoopNest::getLoopNest(L, AR.SE); + // + // FIXME: PreservedAnalysis should not be abused to tell if the + // status of loopnest has been changed. We should use and only + // use LPMUpdater for this purpose. + if (!IsLoopNestPtrValid || U.isLoopNestChanged()) { + while (auto *ParentLoop = OuterMostLoop->getParentLoop()) + OuterMostLoop = ParentLoop; + LoopNestPtr = LoopNest::getLoopNest(*OuterMostLoop, AR.SE); IsLoopNestPtrValid = true; + U.markLoopNestChanged(false); } + PassPA = runSinglePass(*LoopNestPtr, Pass, AM, AR, U, PI); } @@ -118,7 +127,7 @@ LoopPassManager::runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM, // Update the analysis manager as each pass runs and potentially // invalidates analyses. - AM.invalidate(L, *PassPA); + AM.invalidate(IsLoopNestPass[I] ? *OuterMostLoop : L, *PassPA); // Finally, we intersect the final preserved analyses to compute the // aggregate preserved set for this pass manager. 
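In the runWithLoopNestPasses hunk above, the cached loop nest is now recomputed from the outermost loop whenever a pass reports through LPMUpdater that the nest changed. The only non-obvious step is walking up to that outermost loop; a toy stand-in is sketched here (ToyLoop and findOutermost are illustrative names, not LLVM API).

#include <cassert>

// Minimal stand-in for a loop with a parent pointer, just to show how the
// outermost loop of a nest is found before the LoopNest is recomputed.
struct ToyLoop {
  ToyLoop *Parent = nullptr;
  ToyLoop *getParentLoop() const { return Parent; }
};

static ToyLoop *findOutermost(ToyLoop *L) {
  while (ToyLoop *P = L->getParentLoop())
    L = P;
  return L;
}

int main() {
  ToyLoop Outer, Middle, Inner;
  Middle.Parent = &Outer;
  Inner.Parent = &Middle;
  assert(findOutermost(&Inner) == &Outer);
  assert(findOutermost(&Outer) == &Outer);
}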
@@ -130,7 +139,7 @@ LoopPassManager::runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM, // After running the loop pass, the parent loop might change and we need to // notify the updater, otherwise U.ParentL might gets outdated and triggers // assertion failures in addSiblingLoops and addChildLoops. - U.setParentLoop(L.getParentLoop()); + U.setParentLoop((IsLoopNestPass[I] ? *OuterMostLoop : L).getParentLoop()); } return PA; } @@ -148,7 +157,8 @@ LoopPassManager::runWithoutLoopNestPasses(Loop &L, LoopAnalysisManager &AM, // instrumenting callbacks for the passes later. PassInstrumentation PI = AM.getResult<PassInstrumentationAnalysis>(L, AR); for (auto &Pass : LoopPasses) { - Optional<PreservedAnalyses> PassPA = runSinglePass(L, Pass, AM, AR, U, PI); + std::optional<PreservedAnalyses> PassPA = + runSinglePass(L, Pass, AM, AR, U, PI); // `PassPA` is `None` means that the before-pass callbacks in // `PassInstrumentation` return false. The pass does not run in this case, @@ -259,10 +269,11 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, PI.pushBeforeNonSkippedPassCallback([&LAR, &LI](StringRef PassID, Any IR) { if (isSpecialPass(PassID, {"PassManager"})) return; - assert(any_isa<const Loop *>(IR) || any_isa<const LoopNest *>(IR)); - const Loop *L = any_isa<const Loop *>(IR) - ? any_cast<const Loop *>(IR) - : &any_cast<const LoopNest *>(IR)->getOutermostLoop(); + assert(any_cast<const Loop *>(&IR) || any_cast<const LoopNest *>(&IR)); + const Loop **LPtr = any_cast<const Loop *>(&IR); + const Loop *L = LPtr ? *LPtr : nullptr; + if (!L) + L = &any_cast<const LoopNest *>(IR)->getOutermostLoop(); assert(L && "Loop should be valid for printing"); // Verify the loop structure and LCSSA form before visiting the loop. @@ -291,11 +302,7 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, if (!PI.runBeforePass<Loop>(*Pass, *L)) continue; - PreservedAnalyses PassPA; - { - TimeTraceScope TimeScope(Pass->name()); - PassPA = Pass->run(*L, LAM, LAR, Updater); - } + PreservedAnalyses PassPA = Pass->run(*L, LAM, LAR, Updater); // Do not pass deleted Loop into the instrumentation. 
if (Updater.skipCurrentLoop()) diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp index b327d38d2a84..49c0fff84d81 100644 --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -191,6 +191,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -200,6 +201,7 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" +#include <optional> #define DEBUG_TYPE "loop-predication" @@ -233,6 +235,13 @@ static cl::opt<bool> PredicateWidenableBranchGuards( "expressed as widenable branches to deoptimize blocks"), cl::init(true)); +static cl::opt<bool> InsertAssumesOfPredicatedGuardsConditions( + "loop-predication-insert-assumes-of-predicated-guards-conditions", + cl::Hidden, + cl::desc("Whether or not we should insert assumes of conditions of " + "predicated guards"), + cl::init(true)); + namespace { /// Represents an induction variable check: /// icmp Pred, <induction variable>, <loop invariant limit> @@ -263,8 +272,8 @@ class LoopPredication { LoopICmp LatchCheck; bool isSupportedStep(const SCEV* Step); - Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); - Optional<LoopICmp> parseLoopLatchICmp(); + std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); + std::optional<LoopICmp> parseLoopLatchICmp(); /// Return an insertion point suitable for inserting a safe to speculate /// instruction whose only user will be 'User' which has operands 'Ops'. A @@ -287,16 +296,17 @@ class LoopPredication { ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); - Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, - Instruction *Guard); - Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, - LoopICmp RangeCheck, - SCEVExpander &Expander, - Instruction *Guard); - Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, - LoopICmp RangeCheck, - SCEVExpander &Expander, - Instruction *Guard); + std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, + SCEVExpander &Expander, + Instruction *Guard); + std::optional<Value *> + widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, + SCEVExpander &Expander, + Instruction *Guard); + std::optional<Value *> + widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, + SCEVExpander &Expander, + Instruction *Guard); unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition, SCEVExpander &Expander, Instruction *Guard); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); @@ -376,18 +386,17 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, return PA; } -Optional<LoopICmp> -LoopPredication::parseLoopICmp(ICmpInst *ICI) { +std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) { auto Pred = ICI->getPredicate(); auto *LHS = ICI->getOperand(0); auto *RHS = ICI->getOperand(1); const SCEV *LHSS = SE->getSCEV(LHS); if (isa<SCEVCouldNotCompute>(LHSS)) - return None; + return std::nullopt; const SCEV *RHSS = SE->getSCEV(RHS); if (isa<SCEVCouldNotCompute>(RHSS)) - return None; + return std::nullopt; // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV if (SE->isLoopInvariant(LHSS, L)) { @@ -398,7 +407,7 
@@ LoopPredication::parseLoopICmp(ICmpInst *ICI) { const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS); if (!AR || AR->getLoop() != L) - return None; + return std::nullopt; return LoopICmp(Pred, AR, RHSS); } @@ -446,8 +455,8 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL, Type *RangeCheckType) { if (!EnableIVTruncation) return false; - assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedSize() > - DL.getTypeSizeInBits(RangeCheckType).getFixedSize() && + assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() > + DL.getTypeSizeInBits(RangeCheckType).getFixedValue() && "Expected latch check IV type to be larger than range check operand " "type!"); // The start and end values of the IV should be known. This is to guarantee @@ -467,7 +476,7 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL, // guarantees that truncating the latch check to RangeCheckType is a safe // operation. auto RangeCheckTypeBitSize = - DL.getTypeSizeInBits(RangeCheckType).getFixedSize(); + DL.getTypeSizeInBits(RangeCheckType).getFixedValue(); return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; } @@ -475,20 +484,20 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL, // Return an LoopICmp describing a latch check equivlent to LatchCheck but with // the requested type if safe to do so. May involve the use of a new IV. -static Optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, - ScalarEvolution &SE, - const LoopICmp LatchCheck, - Type *RangeCheckType) { +static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, + ScalarEvolution &SE, + const LoopICmp LatchCheck, + Type *RangeCheckType) { auto *LatchType = LatchCheck.IV->getType(); if (RangeCheckType == LatchType) return LatchCheck; // For now, bail out if latch type is narrower than range type. - if (DL.getTypeSizeInBits(LatchType).getFixedSize() < - DL.getTypeSizeInBits(RangeCheckType).getFixedSize()) - return None; + if (DL.getTypeSizeInBits(LatchType).getFixedValue() < + DL.getTypeSizeInBits(RangeCheckType).getFixedValue()) + return std::nullopt; if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType)) - return None; + return std::nullopt; // We can now safely identify the truncated version of the IV and limit for // RangeCheckType. 
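isSafeToTruncateWideIVType above only allows the latch IV to be truncated to the narrower range-check type when the known start and limit values fit in that width. In plain integer terms a value fits when its active bits (analogous to APInt::getActiveBits()) are strictly fewer than the destination width; a minimal sketch of that test, under that assumption, follows.

#include <cassert>
#include <cstdint>

// Number of significant bits in V (0 for V == 0).
static unsigned activeBits(uint64_t V) {
  unsigned Bits = 0;
  for (; V != 0; V >>= 1)
    ++Bits;
  return Bits;
}

// Truncating to DstBits is assumed safe here when both bounds of the IV are
// representable in strictly fewer than DstBits bits.
static bool fitsAfterTruncation(uint64_t Start, uint64_t Limit, unsigned DstBits) {
  return activeBits(Start) < DstBits && activeBits(Limit) < DstBits;
}

int main() {
  assert(fitsAfterTruncation(0, 1000, 32));        // both fit easily in i32
  assert(!fitsAfterTruncation(0, UINT32_MAX, 32)); // limit uses all 32 bits
}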
LoopICmp NewLatchCheck; @@ -496,7 +505,7 @@ static Optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>( SE.getTruncateExpr(LatchCheck.IV, RangeCheckType)); if (!NewLatchCheck.IV) - return None; + return std::nullopt; NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType); LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType << "can be represented as range check type:" @@ -562,15 +571,15 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) { if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) if (const auto *LI = dyn_cast<LoadInst>(U->getValue())) if (LI->isUnordered() && L->hasLoopInvariantOperands(LI)) - if (AA->pointsToConstantMemory(LI->getOperand(0)) || + if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) || LI->hasMetadata(LLVMContext::MD_invariant_load)) return true; return false; } -Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( - LoopICmp LatchCheck, LoopICmp RangeCheck, - SCEVExpander &Expander, Instruction *Guard) { +std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( + LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, + Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); // Generate the widened condition for the forward loop: // guardStart u< guardLimit && @@ -590,12 +599,12 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( !isLoopInvariantValue(LatchStart) || !isLoopInvariantValue(LatchLimit)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; + return std::nullopt; } if (!Expander.isSafeToExpandAt(LatchStart, Guard) || !Expander.isSafeToExpandAt(LatchLimit, Guard)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; + return std::nullopt; } // guardLimit - guardStart + latchStart - 1 @@ -617,9 +626,9 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } -Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( - LoopICmp LatchCheck, LoopICmp RangeCheck, - SCEVExpander &Expander, Instruction *Guard) { +std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( + LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, + Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); const SCEV *GuardStart = RangeCheck.IV->getStart(); const SCEV *GuardLimit = RangeCheck.Limit; @@ -633,12 +642,12 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( !isLoopInvariantValue(LatchStart) || !isLoopInvariantValue(LatchLimit)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; + return std::nullopt; } if (!Expander.isSafeToExpandAt(LatchStart, Guard) || !Expander.isSafeToExpandAt(LatchLimit, Guard)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; + return std::nullopt; } // The decrement of the latch check IV should be the same as the // rangeCheckIV. @@ -647,7 +656,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( LLVM_DEBUG(dbgs() << "Not the same. 
PostDecLatchCheckIV: " << *PostDecLatchCheckIV << " and RangeCheckIV: " << *RangeCheck.IV << "\n"); - return None; + return std::nullopt; } // Generate the widened condition for CountDownLoop: @@ -676,13 +685,12 @@ static void normalizePredicate(ScalarEvolution *SE, Loop *L, ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; } - /// If ICI can be widened to a loop invariant condition emits the loop /// invariant condition in the loop preheader and return it, otherwise -/// returns None. -Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, - SCEVExpander &Expander, - Instruction *Guard) { +/// returns std::nullopt. +std::optional<Value *> +LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, + Instruction *Guard) { LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); LLVM_DEBUG(ICI->dump()); @@ -693,26 +701,26 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, auto RangeCheck = parseLoopICmp(ICI); if (!RangeCheck) { LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); - return None; + return std::nullopt; } LLVM_DEBUG(dbgs() << "Guard check:\n"); LLVM_DEBUG(RangeCheck->dump()); if (RangeCheck->Pred != ICmpInst::ICMP_ULT) { LLVM_DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred << ")!\n"); - return None; + return std::nullopt; } auto *RangeCheckIV = RangeCheck->IV; if (!RangeCheckIV->isAffine()) { LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n"); - return None; + return std::nullopt; } auto *Step = RangeCheckIV->getStepRecurrence(*SE); // We cannot just compare with latch IV step because the latch and range IVs // may have different types. if (!isSupportedStep(Step)) { LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n"); - return None; + return std::nullopt; } auto *Ty = RangeCheckIV->getType(); auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty); @@ -720,7 +728,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check " "corresponding to range type: " << *Ty << "\n"); - return None; + return std::nullopt; } LoopICmp CurrLatchCheck = *CurrLatchCheckOpt; @@ -731,7 +739,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, "Range and latch steps should be of same type!"); if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) { LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n"); - return None; + return std::nullopt; } if (Step->isOne()) @@ -756,17 +764,17 @@ unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks, // resulting list of subconditions in Checks vector. 
SmallVector<Value *, 4> Worklist(1, Condition); SmallPtrSet<Value *, 4> Visited; + Visited.insert(Condition); Value *WideableCond = nullptr; do { Value *Condition = Worklist.pop_back_val(); - if (!Visited.insert(Condition).second) - continue; - Value *LHS, *RHS; using namespace llvm::PatternMatch; if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) { - Worklist.push_back(LHS); - Worklist.push_back(RHS); + if (Visited.insert(LHS).second) + Worklist.push_back(LHS); + if (Visited.insert(RHS).second) + Worklist.push_back(RHS); continue; } @@ -817,6 +825,10 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, Value *AllChecks = Builder.CreateAnd(Checks); auto *OldCond = Guard->getOperand(0); Guard->setOperand(0, AllChecks); + if (InsertAssumesOfPredicatedGuardsConditions) { + Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard)); + Builder.CreateAssumption(OldCond); + } RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); @@ -829,6 +841,12 @@ bool LoopPredication::widenWidenableBranchGuardConditions( LLVM_DEBUG(dbgs() << "Processing guard:\n"); LLVM_DEBUG(BI->dump()); + Value *Cond, *WC; + BasicBlock *IfTrueBB, *IfFalseBB; + bool Parsed = parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB); + assert(Parsed && "Must be able to parse widenable branch"); + (void)Parsed; + TotalConsidered++; SmallVector<Value *, 4> Checks; unsigned NumWidened = collectChecks(Checks, BI->getCondition(), @@ -843,6 +861,10 @@ bool LoopPredication::widenWidenableBranchGuardConditions( Value *AllChecks = Builder.CreateAnd(Checks); auto *OldCond = BI->getCondition(); BI->setCondition(AllChecks); + if (InsertAssumesOfPredicatedGuardsConditions) { + Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt()); + Builder.CreateAssumption(Cond); + } RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); assert(isGuardAsWidenableBranch(BI) && "Stopped being a guard after transform?"); @@ -851,19 +873,19 @@ bool LoopPredication::widenWidenableBranchGuardConditions( return true; } -Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { +std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { using namespace PatternMatch; BasicBlock *LoopLatch = L->getLoopLatch(); if (!LoopLatch) { LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n"); - return None; + return std::nullopt; } auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); if (!BI || !BI->isConditional()) { LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n"); - return None; + return std::nullopt; } BasicBlock *TrueDest = BI->getSuccessor(0); assert( @@ -873,12 +895,12 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { auto *ICI = dyn_cast<ICmpInst>(BI->getCondition()); if (!ICI) { LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n"); - return None; + return std::nullopt; } auto Result = parseLoopICmp(ICI); if (!Result) { LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); - return None; + return std::nullopt; } if (TrueDest != L->getHeader()) @@ -888,13 +910,13 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { // recurrence. 
if (!Result->IV->isAffine()) { LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n"); - return None; + return std::nullopt; } auto *Step = Result->IV->getStepRecurrence(*SE); if (!isSupportedStep(Step)) { LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n"); - return None; + return std::nullopt; } auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) { @@ -912,13 +934,12 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { if (IsUnsupportedPredicate(Step, Result->Pred)) { LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred << ")!\n"); - return None; + return std::nullopt; } return Result; } - bool LoopPredication::isLoopProfitableToPredicate() { if (SkipProfitabilityChecks) return true; @@ -954,37 +975,24 @@ bool LoopPredication::isLoopProfitableToPredicate() { LatchExitBlock->getTerminatingDeoptimizeCall()) return false; - auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) { - if (!ProfileData || !ProfileData->getOperand(0)) - return false; - if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0))) - if (!MDS->getString().equals("branch_weights")) - return false; - if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors()) - return false; - return true; - }; - MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof); // Latch terminator has no valid profile data, so nothing to check // profitability on. - if (!IsValidProfileData(LatchProfileData, LatchTerm)) + if (!hasValidBranchWeightMD(*LatchTerm)) return true; auto ComputeBranchProbability = [&](const BasicBlock *ExitingBlock, const BasicBlock *ExitBlock) -> BranchProbability { auto *Term = ExitingBlock->getTerminator(); - MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof); unsigned NumSucc = Term->getNumSuccessors(); - if (IsValidProfileData(ProfileData, Term)) { - uint64_t Numerator = 0, Denominator = 0, ProfVal = 0; - for (unsigned i = 0; i < NumSucc; i++) { - ConstantInt *CI = - mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1)); - ProfVal = CI->getValue().getZExtValue(); + if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) { + SmallVector<uint32_t> Weights; + extractBranchWeights(ProfileData, Weights); + uint64_t Numerator = 0, Denominator = 0; + for (auto [i, Weight] : llvm::enumerate(Weights)) { if (Term->getSuccessor(i) == ExitBlock) - Numerator += ProfVal; - Denominator += ProfVal; + Numerator += Weight; + Denominator += Weight; } return BranchProbability::getBranchProbability(Numerator, Denominator); } else { diff --git a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp index f4ef22562341..a0b3189c7e09 100644 --- a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -191,13 +191,14 @@ namespace { using SmallInstructionVector = SmallVector<Instruction *, 16>; using SmallInstructionSet = SmallPtrSet<Instruction *, 16>; + using TinyInstructionVector = SmallVector<Instruction *, 1>; // Map between induction variable and its increment DenseMap<Instruction *, int64_t> IVToIncMap; - // For loop with multiple induction variable, remember the one used only to + // For loop with multiple induction variables, remember the ones used only to // control the loop. - Instruction *LoopControlIV; + TinyInstructionVector LoopControlIVs; // A chain of isomorphic instructions, identified by a single-use PHI // representing a reduction. 
Only the last value may be used outside the @@ -386,10 +387,10 @@ namespace { TargetLibraryInfo *TLI, DominatorTree *DT, LoopInfo *LI, bool PreserveLCSSA, DenseMap<Instruction *, int64_t> &IncrMap, - Instruction *LoopCtrlIV) + TinyInstructionVector LoopCtrlIVs) : Parent(Parent), L(L), SE(SE), AA(AA), TLI(TLI), DT(DT), LI(LI), PreserveLCSSA(PreserveLCSSA), IV(IV), IVToIncMap(IncrMap), - LoopControlIV(LoopCtrlIV) {} + LoopControlIVs(LoopCtrlIVs) {} /// Stage 1: Find all the DAG roots for the induction variable. bool findRoots(); @@ -468,7 +469,7 @@ namespace { // Map between induction variable and its increment DenseMap<Instruction *, int64_t> &IVToIncMap; - Instruction *LoopControlIV; + TinyInstructionVector LoopControlIVs; }; // Check if it is a compare-like instruction whose user is a branch @@ -577,33 +578,28 @@ bool LoopReroll::isLoopControlIV(Loop *L, Instruction *IV) { // be possible to reroll the loop. void LoopReroll::collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs) { - BasicBlock *Header = L->getHeader(); - for (BasicBlock::iterator I = Header->begin(), - IE = Header->getFirstInsertionPt(); I != IE; ++I) { - if (!isa<PHINode>(I)) - continue; - if (!I->getType()->isIntegerTy() && !I->getType()->isPointerTy()) + for (Instruction &IV : L->getHeader()->phis()) { + if (!IV.getType()->isIntegerTy() && !IV.getType()->isPointerTy()) continue; if (const SCEVAddRecExpr *PHISCEV = - dyn_cast<SCEVAddRecExpr>(SE->getSCEV(&*I))) { + dyn_cast<SCEVAddRecExpr>(SE->getSCEV(&IV))) { if (PHISCEV->getLoop() != L) continue; if (!PHISCEV->isAffine()) continue; - auto IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE)); + const auto *IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE)); if (IncSCEV) { - IVToIncMap[&*I] = IncSCEV->getValue()->getSExtValue(); - LLVM_DEBUG(dbgs() << "LRR: Possible IV: " << *I << " = " << *PHISCEV + IVToIncMap[&IV] = IncSCEV->getValue()->getSExtValue(); + LLVM_DEBUG(dbgs() << "LRR: Possible IV: " << IV << " = " << *PHISCEV << "\n"); - if (isLoopControlIV(L, &*I)) { - assert(!LoopControlIV && "Found two loop control only IV"); - LoopControlIV = &(*I); - LLVM_DEBUG(dbgs() << "LRR: Possible loop control only IV: " << *I + if (isLoopControlIV(L, &IV)) { + LoopControlIVs.push_back(&IV); + LLVM_DEBUG(dbgs() << "LRR: Loop control only IV: " << IV << " = " << *PHISCEV << "\n"); } else - PossibleIVs.push_back(&*I); + PossibleIVs.push_back(&IV); } } } @@ -1184,7 +1180,7 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { // Make sure we mark loop-control-only PHIs as used in all iterations. See // comment above LoopReroll::isLoopControlIV for more information. BasicBlock *Header = L->getHeader(); - if (LoopControlIV && LoopControlIV != IV) { + for (Instruction *LoopControlIV : LoopControlIVs) { for (auto *U : LoopControlIV->users()) { Instruction *IVUser = dyn_cast<Instruction>(U); // IVUser could be loop increment or compare @@ -1224,13 +1220,14 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { dbgs() << "LRR: " << KV.second.find_first() << "\t" << *KV.first << "\n"; }); + BatchAAResults BatchAA(*AA); for (unsigned Iter = 1; Iter < Scale; ++Iter) { // In addition to regular aliasing information, we need to look for // instructions from later (future) iterations that have side effects // preventing us from reordering them past other instructions with side // effects. 
bool FutureSideEffects = false; - AliasSetTracker AST(*AA); + AliasSetTracker AST(BatchAA); // The map between instructions in f(%iv.(i+1)) and f(%iv). DenseMap<Value *, Value *> BaseMap; @@ -1326,15 +1323,16 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { // Make sure that we don't alias with any instruction in the alias set // tracker. If we do, then we depend on a future iteration, and we // can't reroll. - if (RootInst->mayReadFromMemory()) + if (RootInst->mayReadFromMemory()) { for (auto &K : AST) { - if (K.aliasesUnknownInst(RootInst, *AA)) { + if (isModOrRefSet(K.aliasesUnknownInst(RootInst, BatchAA))) { LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst << " vs. " << *RootInst << " (depends on future store)\n"); return false; } } + } // If we've past an instruction from a future iteration that may have // side effects, and this instruction might also, then we can't reorder @@ -1631,7 +1629,7 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, const SCEV *BackedgeTakenCount, ReductionTracker &Reductions) { DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI, DT, LI, PreserveLCSSA, - IVToIncMap, LoopControlIV); + IVToIncMap, LoopControlIVs); if (!DAGRoots.findRoots()) return false; @@ -1674,7 +1672,7 @@ bool LoopReroll::runOnLoop(Loop *L) { // reroll (there may be several possible options). SmallInstructionVector PossibleIVs; IVToIncMap.clear(); - LoopControlIV = nullptr; + LoopControlIVs.clear(); collectPossibleIVs(L, PossibleIVs); if (PossibleIVs.empty()) { diff --git a/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/llvm/lib/Transforms/Scalar/LoopRotation.cpp index d9c33b5f335a..ba735adc5b27 100644 --- a/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -25,6 +25,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/LoopRotationUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include <optional> using namespace llvm; #define DEBUG_TYPE "loop-rotate" @@ -55,13 +56,12 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); const SimplifyQuery SQ = getBestSimplifyQuery(AR, DL); - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (AR.MSSA) MSSAU = MemorySSAUpdater(AR.MSSA); - bool Changed = - LoopRotation(&L, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, - MSSAU ? MSSAU.getPointer() : nullptr, SQ, false, Threshold, - false, PrepareForLTO || PrepareForLTOOption); + bool Changed = LoopRotation(&L, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, + MSSAU ? &*MSSAU : nullptr, SQ, false, Threshold, + false, PrepareForLTO || PrepareForLTOOption); if (!Changed) return PreservedAnalyses::all(); @@ -117,7 +117,7 @@ public: auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; // Not requiring MemorySSA and getting it only if available will split // the loop pass pipeline when LoopRotate is being run first. auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); @@ -130,9 +130,9 @@ public: ? DefaultRotationThreshold : MaxHeaderSize; - return LoopRotation(L, LI, TTI, AC, &DT, &SE, - MSSAU ? MSSAU.getPointer() : nullptr, SQ, false, - Threshold, false, PrepareForLTO || PrepareForLTOOption); + return LoopRotation(L, LI, TTI, AC, &DT, &SE, MSSAU ? 
&*MSSAU : nullptr, SQ, + false, Threshold, false, + PrepareForLTO || PrepareForLTOOption); } }; } // end namespace diff --git a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 083f87436acd..8d59fdff9236 100644 --- a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -32,6 +32,7 @@ #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include <optional> using namespace llvm; #define DEBUG_TYPE "loop-simplifycfg" @@ -371,6 +372,7 @@ private: DeadInstructions.emplace_back(LandingPad); for (Instruction *I : DeadInstructions) { + SE.forgetBlockAndLoopDispositions(I); I->replaceAllUsesWith(PoisonValue::get(I->getType())); I->eraseFromParent(); } @@ -416,6 +418,7 @@ private: DTU.applyUpdates(DTUpdates); DTUpdates.clear(); formLCSSARecursively(*FixLCSSALoop, DT, &LI, &SE); + SE.forgetBlockAndLoopDispositions(); } } @@ -474,7 +477,7 @@ private: NumLoopBlocksDeleted += DeadLoopBlocks.size(); } - /// Constant-fold terminators of blocks acculumated in FoldCandidates into the + /// Constant-fold terminators of blocks accumulated in FoldCandidates into the /// unconditional branches. void foldTerminators() { for (BasicBlock *BB : FoldCandidates) { @@ -595,6 +598,9 @@ public: LLVM_DEBUG(dbgs() << "Constant-folding " << FoldCandidates.size() << " terminators in loop " << Header->getName() << "\n"); + if (!DeadLoopBlocks.empty()) + SE.forgetBlockAndLoopDispositions(); + // Make the actual transforms. handleDeadExits(); foldTerminators(); @@ -655,7 +661,8 @@ static bool constantFoldTerminators(Loop &L, DominatorTree &DT, LoopInfo &LI, } static bool mergeBlocksIntoPredecessors(Loop &L, DominatorTree &DT, - LoopInfo &LI, MemorySSAUpdater *MSSAU) { + LoopInfo &LI, MemorySSAUpdater *MSSAU, + ScalarEvolution &SE) { bool Changed = false; DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); // Copy blocks into a temporary array to avoid iterator invalidation issues @@ -682,6 +689,9 @@ static bool mergeBlocksIntoPredecessors(Loop &L, DominatorTree &DT, Changed = true; } + if (Changed) + SE.forgetBlockAndLoopDispositions(); + return Changed; } @@ -697,7 +707,7 @@ static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, return true; // Eliminate unconditional branches by merging blocks into their predecessors. - Changed |= mergeBlocksIntoPredecessors(L, DT, LI, MSSAU); + Changed |= mergeBlocksIntoPredecessors(L, DT, LI, MSSAU, SE); if (Changed) SE.forgetTopmostLoop(&L); @@ -708,12 +718,12 @@ static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &LPMU) { - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (AR.MSSA) MSSAU = MemorySSAUpdater(AR.MSSA); bool DeleteCurrentLoop = false; - if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE, - MSSAU ? MSSAU.getPointer() : nullptr, DeleteCurrentLoop)) + if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE, MSSAU ? 
&*MSSAU : nullptr, + DeleteCurrentLoop)) return PreservedAnalyses::all(); if (DeleteCurrentLoop) @@ -741,15 +751,14 @@ public: LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (MSSAA) MSSAU = MemorySSAUpdater(&MSSAA->getMSSA()); if (MSSAA && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); bool DeleteCurrentLoop = false; - bool Changed = - simplifyLoopCFG(*L, DT, LI, SE, MSSAU ? MSSAU.getPointer() : nullptr, - DeleteCurrentLoop); + bool Changed = simplifyLoopCFG(*L, DT, LI, SE, MSSAU ? &*MSSAU : nullptr, + DeleteCurrentLoop); if (DeleteCurrentLoop) LPM.markLoopAsDeleted(*L); return Changed; diff --git a/llvm/lib/Transforms/Scalar/LoopSink.cpp b/llvm/lib/Transforms/Scalar/LoopSink.cpp index dce1af475fb1..21025b0bdb33 100644 --- a/llvm/lib/Transforms/Scalar/LoopSink.cpp +++ b/llvm/lib/Transforms/Scalar/LoopSink.cpp @@ -215,7 +215,7 @@ static bool sinkInstruction( BasicBlock *MoveBB = *SortedBBsToSinkInto.begin(); // FIXME: Optimize the efficiency for cloned value replacement. The current // implementation is O(SortedBBsToSinkInto.size() * I.num_uses()). - for (BasicBlock *N : makeArrayRef(SortedBBsToSinkInto).drop_front(1)) { + for (BasicBlock *N : ArrayRef(SortedBBsToSinkInto).drop_front(1)) { assert(LoopBlockNumber.find(N)->second > LoopBlockNumber.find(MoveBB)->second && "BBs not sorted!"); @@ -300,8 +300,8 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, return BFI.getBlockFreq(A) < BFI.getBlockFreq(B); }); - // Traverse preheader's instructions in reverse order becaue if A depends - // on B (A appears after B), A needs to be sinked first before B can be + // Traverse preheader's instructions in reverse order because if A depends + // on B (A appears after B), A needs to be sunk first before B can be // sinked. for (Instruction &I : llvm::make_early_inc_range(llvm::reverse(*Preheader))) { if (isa<PHINode>(&I)) @@ -312,12 +312,13 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, if (!canSinkOrHoistInst(I, &AA, &DT, &L, MSSAU, false, LICMFlags)) continue; if (sinkInstruction(L, I, ColdLoopBBs, LoopBlockNumber, LI, DT, BFI, - &MSSAU)) + &MSSAU)) { Changed = true; + if (SE) + SE->forgetBlockAndLoopDispositions(&I); + } } - if (Changed && SE) - SE->forgetLoopDispositions(&L); return Changed; } diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index a3434f8bc46d..4c89f947d7fc 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -64,6 +64,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/IVUsers.h" @@ -123,6 +124,7 @@ #include <limits> #include <map> #include <numeric> +#include <optional> #include <utility> using namespace llvm; @@ -146,7 +148,7 @@ static cl::opt<bool> EnablePhiElim( "enable-lsr-phielim", cl::Hidden, cl::init(true), cl::desc("Enable LSR phi elimination")); -// The flag adds instruction count to solutions cost comparision. +// The flag adds instruction count to solutions cost comparison. 
static cl::opt<bool> InsnsCost( "lsr-insns-cost", cl::Hidden, cl::init(true), cl::desc("Add instruction count to a LSR cost model")); @@ -186,6 +188,17 @@ static cl::opt<unsigned> SetupCostDepthLimit( "lsr-setupcost-depth-limit", cl::Hidden, cl::init(7), cl::desc("The limit on recursion depth for LSRs setup cost")); +static cl::opt<bool> AllowTerminatingConditionFoldingAfterLSR( + "lsr-term-fold", cl::Hidden, cl::init(false), + cl::desc("Attempt to replace primary IV with other IV.")); + +static cl::opt<bool> AllowDropSolutionIfLessProfitable( + "lsr-drop-solution", cl::Hidden, cl::init(false), + cl::desc("Attempt to drop solution if it is less profitable")); + +STATISTIC(NumTermFold, + "Number of terminating condition fold recognized and performed"); + #ifndef NDEBUG // Stress test IV chain generation. static cl::opt<bool> StressIVChain( @@ -1067,7 +1080,7 @@ public: C.ScaleCost = 0; } - bool isLess(const Cost &Other); + bool isLess(const Cost &Other) const; void Lose(); @@ -1255,7 +1268,7 @@ static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) { if (auto S = dyn_cast<SCEVIntegralCastExpr>(Reg)) return getSetupCost(S->getOperand(), Depth - 1); if (auto S = dyn_cast<SCEVNAryExpr>(Reg)) - return std::accumulate(S->op_begin(), S->op_end(), 0, + return std::accumulate(S->operands().begin(), S->operands().end(), 0, [&](unsigned i, const SCEV *Reg) { return i + getSetupCost(Reg, Depth - 1); }); @@ -1466,7 +1479,7 @@ void Cost::Lose() { } /// Choose the lower cost. -bool Cost::isLess(const Cost &Other) { +bool Cost::isLess(const Cost &Other) const { if (InsnsCost.getNumOccurrences() > 0 && InsnsCost && C.Insns != Other.C.Insns) return C.Insns < Other.C.Insns; @@ -1967,6 +1980,10 @@ class LSRInstance { /// SmallDenseSet. SetVector<int64_t, SmallVector<int64_t, 8>, SmallSet<int64_t, 8>> Factors; + /// The cost of the current SCEV, the best solution by LSR will be dropped if + /// the solution is not profitable. + Cost BaselineCost; + /// Interesting use types, to facilitate truncation reuse. SmallSetVector<Type *, 4> Types; @@ -2413,9 +2430,7 @@ LSRInstance::OptimizeLoopTermCond() { BasicBlock *LatchBlock = L->getLoopLatch(); SmallVector<BasicBlock*, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - if (llvm::all_of(ExitingBlocks, [&LatchBlock](const BasicBlock *BB) { - return LatchBlock != BB; - })) { + if (!llvm::is_contained(ExitingBlocks, LatchBlock)) { // The backedge doesn't exit the loop; treat this as a head-tested loop. IVIncInsertPos = LatchBlock->getTerminator(); return; @@ -2520,7 +2535,7 @@ LSRInstance::OptimizeLoopTermCond() { ICmpInst *OldCond = Cond; Cond = cast<ICmpInst>(Cond->clone()); Cond->setName(L->getHeader()->getName() + ".termcond"); - ExitingBlock->getInstList().insert(TermBr->getIterator(), Cond); + Cond->insertInto(ExitingBlock, TermBr->getIterator()); // Clone the IVUse, as the old use still exists! CondUse = &IU.AddUser(Cond, CondUse->getOperandValToReplace()); @@ -2542,15 +2557,8 @@ LSRInstance::OptimizeLoopTermCond() { // must dominate all the post-inc comparisons we just set up, and it must // dominate the loop latch edge. 
IVIncInsertPos = L->getLoopLatch()->getTerminator(); - for (Instruction *Inst : PostIncs) { - BasicBlock *BB = - DT.findNearestCommonDominator(IVIncInsertPos->getParent(), - Inst->getParent()); - if (BB == Inst->getParent()) - IVIncInsertPos = Inst; - else if (BB != IVIncInsertPos->getParent()) - IVIncInsertPos = BB->getTerminator(); - } + for (Instruction *Inst : PostIncs) + IVIncInsertPos = DT.findNearestCommonDominator(IVIncInsertPos, Inst); } /// Determine if the given use can accommodate a fixup at the given offset and @@ -2708,7 +2716,7 @@ void LSRInstance::CollectInterestingTypesAndFactors() { Strides.insert(AR->getStepRecurrence(SE)); Worklist.push_back(AR->getStart()); } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { - Worklist.append(Add->op_begin(), Add->op_end()); + append_range(Worklist, Add->operands()); } } while (!Worklist.empty()); } @@ -3288,6 +3296,11 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { BranchInst *ExitBranch = nullptr; bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &TLI); + // For calculating baseline cost + SmallPtrSet<const SCEV *, 16> Regs; + DenseSet<const SCEV *> VisitedRegs; + DenseSet<size_t> VisitedLSRUse; + for (const IVStrideUse &U : IU) { Instruction *UserInst = U.getUser(); // Skip IV users that are part of profitable IV Chains. @@ -3381,6 +3394,14 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { LF.Offset = Offset; LU.AllFixupsOutsideLoop &= LF.isUseFullyOutsideLoop(L); + // Create SCEV as Formula for calculating baseline cost + if (!VisitedLSRUse.count(LUIdx) && !LF.isUseFullyOutsideLoop(L)) { + Formula F; + F.initialMatch(S, L, SE); + BaselineCost.RateFormula(F, Regs, VisitedRegs, LU); + VisitedLSRUse.insert(LUIdx); + } + if (!LU.WidestFixupType || SE.getTypeSizeInBits(LU.WidestFixupType) < SE.getTypeSizeInBits(LF.OperandValToReplace->getType())) @@ -3462,7 +3483,7 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() { continue; if (const SCEVNAryExpr *N = dyn_cast<SCEVNAryExpr>(S)) - Worklist.append(N->op_begin(), N->op_end()); + append_range(Worklist, N->operands()); else if (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(S)) Worklist.push_back(C->getOperand()); else if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { @@ -4267,8 +4288,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { ImmMapTy::const_iterator OtherImms[] = { Imms.begin(), std::prev(Imms.end()), Imms.lower_bound(Avg)}; - for (size_t i = 0, e = array_lengthof(OtherImms); i != e; ++i) { - ImmMapTy::const_iterator M = OtherImms[i]; + for (const auto &M : OtherImms) { if (M == J || M == JE) continue; // Compute the difference between the two. @@ -5157,6 +5177,20 @@ void LSRInstance::Solve(SmallVectorImpl<const Formula *> &Solution) const { }); assert(Solution.size() == Uses.size() && "Malformed solution!"); + + if (BaselineCost.isLess(SolutionCost)) { + LLVM_DEBUG(dbgs() << "The baseline solution requires "; + BaselineCost.print(dbgs()); dbgs() << "\n"); + if (!AllowDropSolutionIfLessProfitable) + LLVM_DEBUG( + dbgs() << "Baseline is more profitable than chosen solution, " + "add option 'lsr-drop-solution' to drop LSR solution.\n"); + else { + LLVM_DEBUG(dbgs() << "Baseline is more profitable than chosen " + "solution, dropping LSR solution.\n";); + Solution.clear(); + } + } } /// Helper for AdjustInsertPositionForExpand. 
Climb up the dominator tree far as @@ -5701,7 +5735,8 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, MSSAU(MSSAU), AMK(PreferredAddresingMode.getNumOccurrences() > 0 ? PreferredAddresingMode : TTI.getPreferredAddressingMode(L, &SE)), - Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr", false) { + Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr", false), + BaselineCost(L, SE, TTI, AMK) { // If LoopSimplify form is not available, stay out of trouble. if (!L->isLoopSimplifyForm()) return; @@ -5942,7 +5977,7 @@ struct SCEVDbgValueBuilder { /// in the set of values referenced by the expression. void pushLocation(llvm::Value *V) { Expr.push_back(llvm::dwarf::DW_OP_LLVM_arg); - auto *It = std::find(LocationOps.begin(), LocationOps.end(), V); + auto *It = llvm::find(LocationOps, V); unsigned ArgIndex = 0; if (It != LocationOps.end()) { ArgIndex = std::distance(LocationOps.begin(), It); @@ -5980,7 +6015,7 @@ struct SCEVDbgValueBuilder { "Expected arithmetic SCEV type"); bool Success = true; unsigned EmitOperator = 0; - for (auto &Op : CommExpr->operands()) { + for (const auto &Op : CommExpr->operands()) { Success &= pushSCEV(Op); if (EmitOperator >= 1) @@ -6347,7 +6382,7 @@ static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE, llvm::PHINode *LSRInductionVar, DVIRecoveryRec &DVIRec, const SCEV *SCEVInductionVar, SCEVDbgValueBuilder IterCountExpr) { - if (!DVIRec.DVI->isUndef()) + if (!DVIRec.DVI->isKillLocation()) return false; // LSR may have caused several changes to the dbg.value in the failed salvage @@ -6394,11 +6429,10 @@ static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE, // Create an offset-based salvage expression if possible, as it requires // less DWARF ops than an iteration count-based expression. - if (Optional<APInt> Offset = + if (std::optional<APInt> Offset = SE.computeConstantDifference(DVIRec.SCEVs[i], SCEVInductionVar)) { - if (Offset.value().getMinSignedBits() <= 64) - SalvageExpr->createOffsetExpr(Offset.value().getSExtValue(), - LSRInductionVar); + if (Offset->getMinSignedBits() <= 64) + SalvageExpr->createOffsetExpr(Offset->getSExtValue(), LSRInductionVar); } else if (!SalvageExpr->createIterCountExpr(DVIRec.SCEVs[i], IterCountExpr, SE)) return false; @@ -6490,14 +6524,14 @@ static void DbgGatherSalvagableDVI( Loop *L, ScalarEvolution &SE, SmallVector<std::unique_ptr<DVIRecoveryRec>, 2> &SalvageableDVISCEVs, SmallSet<AssertingVH<DbgValueInst>, 2> &DVIHandles) { - for (auto &B : L->getBlocks()) { + for (const auto &B : L->getBlocks()) { for (auto &I : *B) { auto DVI = dyn_cast<DbgValueInst>(&I); if (!DVI) continue; // Ensure that if any location op is undef that the dbg.vlue is not // cached. 
- if (DVI->isUndef()) + if (DVI->isKillLocation()) continue; // Check that the location op SCEVs are suitable for translation to @@ -6573,6 +6607,159 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE, return nullptr; } +static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *>> +canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, + const LoopInfo &LI) { + if (!L->isInnermost()) { + LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n"); + return std::nullopt; + } + // Only inspect on simple loop structure + if (!L->isLoopSimplifyForm()) { + LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n"); + return std::nullopt; + } + + if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { + LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n"); + return std::nullopt; + } + + BasicBlock *LoopLatch = L->getLoopLatch(); + + // TODO: Can we do something for greater than and less than? + // Terminating condition is foldable when it is an eq/ne icmp + BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); + if (BI->isUnconditional()) + return std::nullopt; + Value *TermCond = BI->getCondition(); + if (!isa<ICmpInst>(TermCond) || !cast<ICmpInst>(TermCond)->isEquality()) { + LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an " + "ICmpInst::eq / ICmpInst::ne\n"); + return std::nullopt; + } + if (!TermCond->hasOneUse()) { + LLVM_DEBUG( + dbgs() + << "Cannot replace terminating condition with more than one use\n"); + return std::nullopt; + } + + // For `IsToFold`, a primary IV can be replaced by other affine AddRec when it + // is only used by the terminating condition. To check for this, we may need + // to traverse through a chain of use-def until we can examine the final + // usage. + // *----------------------* + // *---->| LoopHeader: | + // | | PrimaryIV = phi ... | + // | *----------------------* + // | | + // | | + // | chain of + // | single use + // used by | + // phi | + // | Value + // | / \ + // | chain of chain of + // | single use single use + // | / \ + // | / \ + // *- Value Value --> used by terminating condition + auto IsToFold = [&](PHINode &PN) -> bool { + Value *V = &PN; + + while (V->getNumUses() == 1) + V = *V->user_begin(); + + if (V->getNumUses() != 2) + return false; + + Value *VToPN = nullptr; + Value *VToTermCond = nullptr; + for (User *U : V->users()) { + while (U->getNumUses() == 1) { + if (isa<PHINode>(U)) + VToPN = U; + if (U == TermCond) + VToTermCond = U; + U = *U->user_begin(); + } + } + return VToPN && VToTermCond; + }; + + // If this is an IV which we could replace the terminating condition, return + // the final value of the alternative IV on the last iteration. + auto getAlternateIVEnd = [&](PHINode &PN) -> const SCEV * { + // FIXME: This does not properly account for overflow. 
+ const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); + const SCEV *BECount = SE.getBackedgeTakenCount(L); + const SCEV *TermValueS = SE.getAddExpr( + AddRec->getOperand(0), + SE.getTruncateOrZeroExtend( + SE.getMulExpr( + AddRec->getOperand(1), + SE.getTruncateOrZeroExtend( + SE.getAddExpr(BECount, SE.getOne(BECount->getType())), + AddRec->getOperand(1)->getType())), + AddRec->getOperand(0)->getType())); + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); + if (!Expander.isSafeToExpand(TermValueS)) { + LLVM_DEBUG( + dbgs() << "Is not safe to expand terminating value for phi node" << PN + << "\n"); + return nullptr; + } + return TermValueS; + }; + + PHINode *ToFold = nullptr; + PHINode *ToHelpFold = nullptr; + const SCEV *TermValueS = nullptr; + + for (PHINode &PN : L->getHeader()->phis()) { + if (!SE.isSCEVable(PN.getType())) { + LLVM_DEBUG(dbgs() << "IV of phi '" << PN + << "' is not SCEV-able, not qualified for the " + "terminating condition folding.\n"); + continue; + } + const SCEV *S = SE.getSCEV(&PN); + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S); + // Only speculate on affine AddRec + if (!AddRec || !AddRec->isAffine()) { + LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN + << "' is not an affine add recursion, not qualified " + "for the terminating condition folding.\n"); + continue; + } + + if (IsToFold(PN)) + ToFold = &PN; + else if (auto P = getAlternateIVEnd(PN)) { + ToHelpFold = &PN; + TermValueS = P; + } + } + + LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() + << "Cannot find other AddRec IV to help folding\n";); + + LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs() + << "\nFound loop that can fold terminating condition\n" + << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n" + << " TermCond: " << *TermCond << "\n" + << " BrandInst: " << *BI << "\n" + << " ToFold: " << *ToFold << "\n" + << " ToHelpFold: " << *ToHelpFold << "\n"); + + if (!ToFold || !ToHelpFold) + return std::nullopt; + return std::make_tuple(ToFold, ToHelpFold, TermValueS); +} + static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI, @@ -6620,7 +6807,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, if (L->isRecursivelyLCSSAForm(DT, LI) && L->getExitBlock()) { SmallVector<WeakTrackingVH, 16> DeadInsts; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - SCEVExpander Rewriter(SE, DL, "lsr", false); + SCEVExpander Rewriter(SE, DL, "lsr", true); int Rewrites = rewriteLoopExitValues(L, &LI, &TLI, &SE, &TTI, Rewriter, &DT, UnusedIndVarInLoop, DeadInsts); if (Rewrites) { @@ -6631,13 +6818,73 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, } } + if (AllowTerminatingConditionFoldingAfterLSR) { + if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI)) { + auto [ToFold, ToHelpFold, TermValueS] = *Opt; + + Changed = true; + NumTermFold++; + + BasicBlock *LoopPreheader = L->getLoopPreheader(); + BasicBlock *LoopLatch = L->getLoopLatch(); + + (void)ToFold; + LLVM_DEBUG(dbgs() << "To fold phi-node:\n" + << *ToFold << "\n" + << "New term-cond phi-node:\n" + << *ToHelpFold << "\n"); + + Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader); + (void)StartValue; + Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); + + // SCEVExpander for both use in preheader and latch + const DataLayout &DL = 
L->getHeader()->getModule()->getDataLayout(); + SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); + SCEVExpanderCleaner ExpCleaner(Expander); + + assert(Expander.isSafeToExpand(TermValueS) && + "Terminating value was checked safe in canFoldTerminatingCondition"); + + // Create new terminating value at loop header + Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(), + LoopPreheader->getTerminator()); + + LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" + << *StartValue << "\n" + << "Terminating value of new term-cond phi-node:\n" + << *TermValue << "\n"); + + // Create new terminating condition at loop latch + BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); + ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition()); + IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); + // FIXME: We are adding a use of an IV here without account for poison safety. + // This is incorrect. + Value *NewTermCond = LatchBuilder.CreateICmp( + OldTermCond->getPredicate(), LoopValue, TermValue, + "lsr_fold_term_cond.replaced_term_cond"); + + LLVM_DEBUG(dbgs() << "Old term-cond:\n" + << *OldTermCond << "\n" + << "New term-cond:\b" << *NewTermCond << "\n"); + + BI->setCondition(NewTermCond); + + OldTermCond->eraseFromParent(); + DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); + + ExpCleaner.markResultUsed(); + } + } + if (SalvageableDVIRecords.empty()) return Changed; // Obtain relevant IVs and attempt to rewrite the salvageable DVIs with // expressions composed using the derived iteration count. // TODO: Allow for multiple IV references for nested AddRecSCEVs - for (auto &L : LI) { + for (const auto &L : LI) { if (llvm::PHINode *IV = GetInductionVariable(*L, SE, Reducer)) DbgRewriteSalvageableDVIs(L, SE, IV, SalvageableDVIRecords); else { diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 64fcdfa15aa9..0ae26b494c5a 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -12,8 +12,6 @@ #include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringRef.h" @@ -156,7 +154,7 @@ getUnrollAndJammedLoopSize(unsigned LoopSize, // unroll count was set explicitly. 
static bool computeUnrollAndJamCount( Loop *L, Loop *SubLoop, const TargetTransformInfo &TTI, DominatorTree &DT, - LoopInfo *LI, ScalarEvolution &SE, + LoopInfo *LI, AssumptionCache *AC, ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, OptimizationRemarkEmitter *ORE, unsigned OuterTripCount, unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount, @@ -170,7 +168,7 @@ static bool computeUnrollAndJamCount( unsigned MaxTripCount = 0; bool UseUpperBound = false; bool ExplicitUnroll = computeUnrollCount( - L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount, + L, TTI, DT, LI, AC, SE, EphValues, ORE, OuterTripCount, MaxTripCount, /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, PP, UseUpperBound); if (ExplicitUnroll || UseUpperBound) { @@ -284,11 +282,11 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, const TargetTransformInfo &TTI, AssumptionCache &AC, DependenceInfo &DI, OptimizationRemarkEmitter &ORE, int OptLevel) { - TargetTransformInfo::UnrollingPreferences UP = - gatherUnrollingPreferences(L, SE, TTI, nullptr, nullptr, ORE, OptLevel, - None, None, None, None, None, None); + TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( + L, SE, TTI, nullptr, nullptr, ORE, OptLevel, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt); TargetTransformInfo::PeelingPreferences PP = - gatherPeelingPreferences(L, SE, TTI, None, None); + gatherPeelingPreferences(L, SE, TTI, std::nullopt, std::nullopt); TransformationMode EnableMode = hasUnrollAndJamTransformation(L); if (EnableMode & TM_Disable) @@ -369,11 +367,11 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // To assign the loop id of the epilogue, assign it before unrolling it so it // is applied to every inner loop of the epilogue. We later apply the loop ID // for the jammed inner loop. - Optional<MDNode *> NewInnerEpilogueLoopID = makeFollowupLoopID( + std::optional<MDNode *> NewInnerEpilogueLoopID = makeFollowupLoopID( OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupRemainderInner}); if (NewInnerEpilogueLoopID) - SubLoop->setLoopID(NewInnerEpilogueLoopID.value()); + SubLoop->setLoopID(*NewInnerEpilogueLoopID); // Find trip count and trip multiple BasicBlock *Latch = L->getLoopLatch(); @@ -384,7 +382,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // Decide if, and by how much, to unroll bool IsCountSetExplicitly = computeUnrollAndJamCount( - L, SubLoop, TTI, DT, LI, SE, EphValues, &ORE, OuterTripCount, + L, SubLoop, TTI, DT, LI, &AC, SE, EphValues, &ORE, OuterTripCount, OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP, PP); if (UP.Count <= 1) return LoopUnrollResult::Unmodified; @@ -399,27 +397,27 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // Assign new loop attributes. 
if (EpilogueOuterLoop) { - Optional<MDNode *> NewOuterEpilogueLoopID = makeFollowupLoopID( + std::optional<MDNode *> NewOuterEpilogueLoopID = makeFollowupLoopID( OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupRemainderOuter}); if (NewOuterEpilogueLoopID) - EpilogueOuterLoop->setLoopID(NewOuterEpilogueLoopID.value()); + EpilogueOuterLoop->setLoopID(*NewOuterEpilogueLoopID); } - Optional<MDNode *> NewInnerLoopID = + std::optional<MDNode *> NewInnerLoopID = makeFollowupLoopID(OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupInner}); if (NewInnerLoopID) - SubLoop->setLoopID(NewInnerLoopID.value()); + SubLoop->setLoopID(*NewInnerLoopID); else SubLoop->setLoopID(OrigSubLoopID); if (UnrollResult == LoopUnrollResult::PartiallyUnrolled) { - Optional<MDNode *> NewOuterLoopID = makeFollowupLoopID( + std::optional<MDNode *> NewOuterLoopID = makeFollowupLoopID( OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupOuter}); if (NewOuterLoopID) { - L->setLoopID(NewOuterLoopID.value()); + L->setLoopID(*NewOuterLoopID); // Do not setLoopAlreadyUnrolled if a followup was given. return UnrollResult; diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index de5833f60adc..1a6065cb3f1a 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -15,8 +15,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -64,6 +62,7 @@ #include <cassert> #include <cstdint> #include <limits> +#include <optional> #include <string> #include <tuple> #include <utility> @@ -185,9 +184,10 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, OptimizationRemarkEmitter &ORE, int OptLevel, - Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, - Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, - Optional<bool> UserUpperBound, Optional<unsigned> UserFullUnrollMaxCount) { + std::optional<unsigned> UserThreshold, std::optional<unsigned> UserCount, + std::optional<bool> UserAllowPartial, std::optional<bool> UserRuntime, + std::optional<bool> UserUpperBound, + std::optional<unsigned> UserFullUnrollMaxCount) { TargetTransformInfo::UnrollingPreferences UP; // Set up the defaults @@ -342,8 +342,8 @@ struct PragmaInfo { /// cost of the 'false'-block). /// \returns Optional value, holding the RolledDynamicCost and UnrolledCost. If /// the analysis failed (no benefits expected from the unrolling, or the loop is -/// too big to analyze), the returned value is None. -static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( +/// too big to analyze), the returned value is std::nullopt. +static std::optional<EstimatedUnrollCost> analyzeLoopUnrollCost( const Loop *L, unsigned TripCount, DominatorTree &DT, ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, const TargetTransformInfo &TTI, unsigned MaxUnrolledLoopSize, @@ -358,11 +358,11 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( // Only analyze inner loops. We can't properly estimate cost of nested loops // and we won't visit inner loops again anyway. 
if (!L->isInnermost()) - return None; + return std::nullopt; // Don't simulate loops with a big or unknown tripcount if (!TripCount || TripCount > MaxIterationsCountToAnalyze) - return None; + return std::nullopt; SmallSetVector<BasicBlock *, 16> BBWorklist; SmallSetVector<std::pair<BasicBlock *, BasicBlock *>, 4> ExitWorklist; @@ -443,7 +443,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( // First accumulate the cost of this instruction. if (!Cost.IsFree) { - UnrolledCost += TTI.getUserCost(I, CostKind); + UnrolledCost += TTI.getInstructionCost(I, CostKind); LLVM_DEBUG(dbgs() << "Adding cost of instruction (iteration " << Iteration << "): "); LLVM_DEBUG(I->dump()); @@ -537,7 +537,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( // Track this instruction's expected baseline cost when executing the // rolled loop form. - RolledDynamicCost += TTI.getUserCost(&I, CostKind); + RolledDynamicCost += TTI.getInstructionCost(&I, CostKind); // Visit the instruction to analyze its loop cost after unrolling, // and if the visitor returns true, mark the instruction as free after @@ -558,7 +558,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( const Function *Callee = CI->getCalledFunction(); if (!Callee || TTI.isLoweredToCall(Callee)) { LLVM_DEBUG(dbgs() << "Can't analyze cost of loop with call\n"); - return None; + return std::nullopt; } } @@ -573,7 +573,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( << " UnrolledCost: " << UnrolledCost << ", MaxUnrolledLoopSize: " << MaxUnrolledLoopSize << "\n"); - return None; + return std::nullopt; } } @@ -631,7 +631,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( if (UnrolledCost == RolledDynamicCost) { LLVM_DEBUG(dbgs() << " No opportunities found.. exiting.\n" << " UnrolledCost: " << UnrolledCost << "\n"); - return None; + return std::nullopt; } } @@ -682,7 +682,7 @@ InstructionCost llvm::ApproximateLoopSize( // that each loop has at least three instructions (likely a conditional // branch, a comparison feeding that branch, and some kind of loop increment // feeding that comparison instruction). - if (LoopSize.isValid() && *LoopSize.getValue() < BEInsns + 1) + if (LoopSize.isValid() && LoopSize < BEInsns + 1) // This is an open coded max() on InstructionCost LoopSize = BEInsns + 1; @@ -772,7 +772,7 @@ public: } }; -static Optional<unsigned> +static std::optional<unsigned> shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo, const unsigned TripMultiple, const unsigned TripCount, const UnrollCostEstimator UCE, @@ -797,10 +797,10 @@ shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo, return TripCount; // if didn't return until here, should continue to other priorties - return None; + return std::nullopt; } -static Optional<unsigned> shouldFullUnroll( +static std::optional<unsigned> shouldFullUnroll( Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, const unsigned FullUnrollTripCount, const UnrollCostEstimator UCE, @@ -808,7 +808,7 @@ static Optional<unsigned> shouldFullUnroll( assert(FullUnrollTripCount && "should be non-zero!"); if (FullUnrollTripCount > UP.FullUnrollMaxCount) - return None; + return std::nullopt; // When computing the unrolled size, note that BEInsns are not replicated // like the rest of the loop body. 
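The bulk of the LoopUnrollPass.cpp change is the mechanical migration from llvm::Optional/None to std::optional/std::nullopt. A minimal standalone sketch of the pattern, for illustration only (the function names below are invented, not part of the patch):

#include <optional>

// `None` becomes `std::nullopt`, `Opt.value()` becomes `*Opt`, and
// `Opt.getPointer()` becomes `&*Opt`.
static std::optional<unsigned> pickCount(bool Known, unsigned Count) {
  if (!Known)
    return std::nullopt; // previously: return None;
  return Count;
}

static unsigned countOrDefault(std::optional<unsigned> C) {
  return C ? *C : 1u; // previously: C.value()
}
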
@@ -818,7 +818,7 @@ static Optional<unsigned> shouldFullUnroll( // The loop isn't that small, but we still can fully unroll it if that // helps to remove a significant number of instructions. // To check that, run additional analysis on the loop. - if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost( + if (std::optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost( L, FullUnrollTripCount, DT, SE, EphValues, TTI, UP.Threshold * UP.MaxPercentThresholdBoost / 100, UP.MaxIterationsCountToAnalyze)) { @@ -827,16 +827,16 @@ static Optional<unsigned> shouldFullUnroll( if (Cost->UnrolledCost < UP.Threshold * Boost / 100) return FullUnrollTripCount; } - return None; + return std::nullopt; } -static Optional<unsigned> +static std::optional<unsigned> shouldPartialUnroll(const unsigned LoopSize, const unsigned TripCount, const UnrollCostEstimator UCE, const TargetTransformInfo::UnrollingPreferences &UP) { if (!TripCount) - return None; + return std::nullopt; if (!UP.Partial) { LLVM_DEBUG(dbgs() << " will not try to unroll partially because " @@ -888,6 +888,7 @@ shouldPartialUnroll(const unsigned LoopSize, const unsigned TripCount, // refactored into it own function. bool llvm::computeUnrollCount( Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI, + AssumptionCache *AC, ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, OptimizationRemarkEmitter *ORE, unsigned TripCount, unsigned MaxTripCount, bool MaxOrZero, unsigned TripMultiple, unsigned LoopSize, @@ -978,7 +979,7 @@ bool llvm::computeUnrollCount( } // 5th priority is loop peeling. - computePeelCount(L, LoopSize, PP, TripCount, DT, SE, UP.Threshold); + computePeelCount(L, LoopSize, PP, TripCount, DT, SE, AC, UP.Threshold); if (PP.PeelCount) { UP.Runtime = false; UP.Count = 1; @@ -1118,17 +1119,20 @@ bool llvm::computeUnrollCount( return ExplicitUnroll; } -static LoopUnrollResult tryToUnrollLoop( - Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, - const TargetTransformInfo &TTI, AssumptionCache &AC, - OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, - ProfileSummaryInfo *PSI, bool PreserveLCSSA, int OptLevel, - bool OnlyWhenForced, bool ForgetAllSCEV, Optional<unsigned> ProvidedCount, - Optional<unsigned> ProvidedThreshold, Optional<bool> ProvidedAllowPartial, - Optional<bool> ProvidedRuntime, Optional<bool> ProvidedUpperBound, - Optional<bool> ProvidedAllowPeeling, - Optional<bool> ProvidedAllowProfileBasedPeeling, - Optional<unsigned> ProvidedFullUnrollMaxCount) { +static LoopUnrollResult +tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, + const TargetTransformInfo &TTI, AssumptionCache &AC, + OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, bool PreserveLCSSA, int OptLevel, + bool OnlyWhenForced, bool ForgetAllSCEV, + std::optional<unsigned> ProvidedCount, + std::optional<unsigned> ProvidedThreshold, + std::optional<bool> ProvidedAllowPartial, + std::optional<bool> ProvidedRuntime, + std::optional<bool> ProvidedUpperBound, + std::optional<bool> ProvidedAllowPeeling, + std::optional<bool> ProvidedAllowProfileBasedPeeling, + std::optional<unsigned> ProvidedFullUnrollMaxCount) { LLVM_DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() << "] Loop %" << L->getHeader()->getName() << "\n"); @@ -1222,7 +1226,7 @@ static LoopUnrollResult tryToUnrollLoop( // Find the smallest exact trip count for any exit. 
This is an upper bound // on the loop trip count, but an exit at an earlier iteration is still // possible. An unroll by the smallest exact trip count guarantees that all - // brnaches relating to at least one exit can be eliminated. This is unlike + // branches relating to at least one exit can be eliminated. This is unlike // the max trip count, which only guarantees that the backedge can be broken. unsigned TripCount = 0; unsigned TripMultiple = 1; @@ -1272,7 +1276,7 @@ static LoopUnrollResult tryToUnrollLoop( // fully unroll the loop. bool UseUpperBound = false; bool IsCountSetExplicitly = computeUnrollCount( - L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero, + L, TTI, DT, LI, &AC, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero, TripMultiple, LoopSize, UP, PP, UseUpperBound); if (!UP.Count) return LoopUnrollResult::Unmodified; @@ -1288,7 +1292,8 @@ static LoopUnrollResult tryToUnrollLoop( << " iterations"; }); - if (peelLoop(L, PP.PeelCount, LI, &SE, DT, &AC, PreserveLCSSA)) { + ValueToValueMapTy VMap; + if (peelLoop(L, PP.PeelCount, LI, &SE, DT, &AC, PreserveLCSSA, VMap)) { simplifyLoopAfterUnroll(L, true, LI, &SE, &DT, &AC, &TTI); // If the loop was peeled, we already "used up" the profile information // we had, so we don't want to unroll or peel again. @@ -1320,19 +1325,19 @@ static LoopUnrollResult tryToUnrollLoop( return LoopUnrollResult::Unmodified; if (RemainderLoop) { - Optional<MDNode *> RemainderLoopID = + std::optional<MDNode *> RemainderLoopID = makeFollowupLoopID(OrigLoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); if (RemainderLoopID) - RemainderLoop->setLoopID(RemainderLoopID.value()); + RemainderLoop->setLoopID(*RemainderLoopID); } if (UnrollResult != LoopUnrollResult::FullyUnrolled) { - Optional<MDNode *> NewLoopID = + std::optional<MDNode *> NewLoopID = makeFollowupLoopID(OrigLoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupUnrolled}); if (NewLoopID) { - L->setLoopID(NewLoopID.value()); + L->setLoopID(*NewLoopID); // Do not setLoopAlreadyUnrolled if loop attributes have been specified // explicitly. @@ -1366,23 +1371,25 @@ public: /// Otherwise, forgetAllLoops and rebuild when needed next. 
bool ForgetAllSCEV; - Optional<unsigned> ProvidedCount; - Optional<unsigned> ProvidedThreshold; - Optional<bool> ProvidedAllowPartial; - Optional<bool> ProvidedRuntime; - Optional<bool> ProvidedUpperBound; - Optional<bool> ProvidedAllowPeeling; - Optional<bool> ProvidedAllowProfileBasedPeeling; - Optional<unsigned> ProvidedFullUnrollMaxCount; + std::optional<unsigned> ProvidedCount; + std::optional<unsigned> ProvidedThreshold; + std::optional<bool> ProvidedAllowPartial; + std::optional<bool> ProvidedRuntime; + std::optional<bool> ProvidedUpperBound; + std::optional<bool> ProvidedAllowPeeling; + std::optional<bool> ProvidedAllowProfileBasedPeeling; + std::optional<unsigned> ProvidedFullUnrollMaxCount; LoopUnroll(int OptLevel = 2, bool OnlyWhenForced = false, - bool ForgetAllSCEV = false, Optional<unsigned> Threshold = None, - Optional<unsigned> Count = None, - Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, - Optional<bool> UpperBound = None, - Optional<bool> AllowPeeling = None, - Optional<bool> AllowProfileBasedPeeling = None, - Optional<unsigned> ProvidedFullUnrollMaxCount = None) + bool ForgetAllSCEV = false, + std::optional<unsigned> Threshold = std::nullopt, + std::optional<unsigned> Count = std::nullopt, + std::optional<bool> AllowPartial = std::nullopt, + std::optional<bool> Runtime = std::nullopt, + std::optional<bool> UpperBound = std::nullopt, + std::optional<bool> AllowPeeling = std::nullopt, + std::optional<bool> AllowProfileBasedPeeling = std::nullopt, + std::optional<unsigned> ProvidedFullUnrollMaxCount = std::nullopt) : LoopPass(ID), OptLevel(OptLevel), OnlyWhenForced(OnlyWhenForced), ForgetAllSCEV(ForgetAllSCEV), ProvidedCount(std::move(Count)), ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), @@ -1454,12 +1461,12 @@ Pass *llvm::createLoopUnrollPass(int OptLevel, bool OnlyWhenForced, // callers. return new LoopUnroll( OptLevel, OnlyWhenForced, ForgetAllSCEV, - Threshold == -1 ? None : Optional<unsigned>(Threshold), - Count == -1 ? None : Optional<unsigned>(Count), - AllowPartial == -1 ? None : Optional<bool>(AllowPartial), - Runtime == -1 ? None : Optional<bool>(Runtime), - UpperBound == -1 ? None : Optional<bool>(UpperBound), - AllowPeeling == -1 ? None : Optional<bool>(AllowPeeling)); + Threshold == -1 ? std::nullopt : std::optional<unsigned>(Threshold), + Count == -1 ? std::nullopt : std::optional<unsigned>(Count), + AllowPartial == -1 ? std::nullopt : std::optional<bool>(AllowPartial), + Runtime == -1 ? std::nullopt : std::optional<bool>(Runtime), + UpperBound == -1 ? std::nullopt : std::optional<bool>(UpperBound), + AllowPeeling == -1 ? 
std::nullopt : std::optional<bool>(AllowPeeling)); } Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced, @@ -1487,16 +1494,17 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, std::string LoopName = std::string(L.getName()); - bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, ORE, - /*BFI*/ nullptr, /*PSI*/ nullptr, - /*PreserveLCSSA*/ true, OptLevel, - OnlyWhenForced, ForgetSCEV, /*Count*/ None, - /*Threshold*/ None, /*AllowPartial*/ false, - /*Runtime*/ false, /*UpperBound*/ false, - /*AllowPeeling*/ true, - /*AllowProfileBasedPeeling*/ false, - /*FullUnrollMaxCount*/ None) != - LoopUnrollResult::Unmodified; + bool Changed = + tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, ORE, + /*BFI*/ nullptr, /*PSI*/ nullptr, + /*PreserveLCSSA*/ true, OptLevel, OnlyWhenForced, + ForgetSCEV, /*Count*/ std::nullopt, + /*Threshold*/ std::nullopt, /*AllowPartial*/ false, + /*Runtime*/ false, /*UpperBound*/ false, + /*AllowPeeling*/ true, + /*AllowProfileBasedPeeling*/ false, + /*FullUnrollMaxCount*/ std::nullopt) != + LoopUnrollResult::Unmodified; if (!Changed) return PreservedAnalyses::all(); @@ -1583,7 +1591,7 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, // legality and profitability checks. This means running the loop unroller // will simplify all loops, regardless of whether anything end up being // unrolled. - for (auto &L : LI) { + for (const auto &L : LI) { Changed |= simplifyLoop(L, &DT, &LI, &SE, &AC, nullptr, false /* PreserveLCSSA */); Changed |= formLCSSARecursively(*L, DT, &LI, &SE); @@ -1607,7 +1615,7 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, // Check if the profile summary indicates that the profiled application // has a huge working set size, in which case we disable peeling to avoid // bloating it further. - Optional<bool> LocalAllowPeeling = UnrollOpts.AllowPeeling; + std::optional<bool> LocalAllowPeeling = UnrollOpts.AllowPeeling; if (PSI && PSI->hasHugeWorkingSetSize()) LocalAllowPeeling = false; std::string LoopName = std::string(L.getName()); @@ -1616,9 +1624,9 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, LoopUnrollResult Result = tryToUnrollLoop( &L, DT, &LI, SE, TTI, AC, ORE, BFI, PSI, /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, UnrollOpts.OnlyWhenForced, - UnrollOpts.ForgetSCEV, /*Count*/ None, - /*Threshold*/ None, UnrollOpts.AllowPartial, UnrollOpts.AllowRuntime, - UnrollOpts.AllowUpperBound, LocalAllowPeeling, + UnrollOpts.ForgetSCEV, /*Count*/ std::nullopt, + /*Threshold*/ std::nullopt, UnrollOpts.AllowPartial, + UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound, LocalAllowPeeling, UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount); Changed |= Result != LoopUnrollResult::Unmodified; @@ -1644,18 +1652,18 @@ void LoopUnrollPass::printPipeline( static_cast<PassInfoMixin<LoopUnrollPass> *>(this)->printPipeline( OS, MapClassName2PassName); OS << "<"; - if (UnrollOpts.AllowPartial != None) - OS << (UnrollOpts.AllowPartial.value() ? "" : "no-") << "partial;"; - if (UnrollOpts.AllowPeeling != None) - OS << (UnrollOpts.AllowPeeling.value() ? "" : "no-") << "peeling;"; - if (UnrollOpts.AllowRuntime != None) - OS << (UnrollOpts.AllowRuntime.value() ? "" : "no-") << "runtime;"; - if (UnrollOpts.AllowUpperBound != None) - OS << (UnrollOpts.AllowUpperBound.value() ? "" : "no-") << "upperbound;"; - if (UnrollOpts.AllowProfileBasedPeeling != None) - OS << (UnrollOpts.AllowProfileBasedPeeling.value() ? 
"" : "no-") + if (UnrollOpts.AllowPartial != std::nullopt) + OS << (*UnrollOpts.AllowPartial ? "" : "no-") << "partial;"; + if (UnrollOpts.AllowPeeling != std::nullopt) + OS << (*UnrollOpts.AllowPeeling ? "" : "no-") << "peeling;"; + if (UnrollOpts.AllowRuntime != std::nullopt) + OS << (*UnrollOpts.AllowRuntime ? "" : "no-") << "runtime;"; + if (UnrollOpts.AllowUpperBound != std::nullopt) + OS << (*UnrollOpts.AllowUpperBound ? "" : "no-") << "upperbound;"; + if (UnrollOpts.AllowProfileBasedPeeling != std::nullopt) + OS << (*UnrollOpts.AllowProfileBasedPeeling ? "" : "no-") << "profile-peeling;"; - if (UnrollOpts.FullUnrollMaxCount != None) + if (UnrollOpts.FullUnrollMaxCount != std::nullopt) OS << "full-unroll-max=" << UnrollOpts.FullUnrollMaxCount << ";"; OS << "O" << UnrollOpts.OptLevel; OS << ">"; diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index c733aa4701ed..848be25a2fe0 100644 --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -147,51 +147,31 @@ struct LoopVersioningLICM { // LoopAccessInfo will take place only when it's necessary. LoopVersioningLICM(AliasAnalysis *AA, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, - function_ref<const LoopAccessInfo &(Loop *)> GetLAI) - : AA(AA), SE(SE), GetLAI(GetLAI), + LoopAccessInfoManager &LAIs, LoopInfo &LI, + Loop *CurLoop) + : AA(AA), SE(SE), LAIs(LAIs), LI(LI), CurLoop(CurLoop), LoopDepthThreshold(LVLoopDepthThreshold), InvariantThreshold(LVInvarThreshold), ORE(ORE) {} - bool runOnLoop(Loop *L, LoopInfo *LI, DominatorTree *DT); - - void reset() { - AA = nullptr; - SE = nullptr; - CurLoop = nullptr; - LoadAndStoreCounter = 0; - InvariantCounter = 0; - IsReadOnlyLoop = true; - ORE = nullptr; - CurAST.reset(); - } - - class AutoResetter { - public: - AutoResetter(LoopVersioningLICM &LVLICM) : LVLICM(LVLICM) {} - ~AutoResetter() { LVLICM.reset(); } - - private: - LoopVersioningLICM &LVLICM; - }; + bool run(DominatorTree *DT); private: // Current AliasAnalysis information - AliasAnalysis *AA = nullptr; + AliasAnalysis *AA; // Current ScalarEvolution - ScalarEvolution *SE = nullptr; + ScalarEvolution *SE; // Current Loop's LoopAccessInfo const LoopAccessInfo *LAI = nullptr; // Proxy for retrieving LoopAccessInfo. - function_ref<const LoopAccessInfo &(Loop *)> GetLAI; + LoopAccessInfoManager &LAIs; - // The current loop we are working on. - Loop *CurLoop = nullptr; + LoopInfo &LI; - // AliasSet information for the current loop. - std::unique_ptr<AliasSetTracker> CurAST; + // The current loop we are working on. + Loop *CurLoop; // Maximum loop nest threshold unsigned LoopDepthThreshold; @@ -275,9 +255,15 @@ bool LoopVersioningLICM::legalLoopStructure() { /// Check memory accesses in loop and confirms it's good for /// LoopVersioningLICM. bool LoopVersioningLICM::legalLoopMemoryAccesses() { - bool HasMayAlias = false; - bool TypeSafety = false; - bool HasMod = false; + // Loop over the body of this loop, construct AST. + BatchAAResults BAA(*AA); + AliasSetTracker AST(BAA); + for (auto *Block : CurLoop->getBlocks()) { + // Ignore blocks in subloops. + if (LI.getLoopFor(Block) == CurLoop) + AST.add(*Block); + } + // Memory check: // Transform phase will generate a versioned loop and also a runtime check to // ensure the pointers are independent and they don’t alias. 
@@ -290,7 +276,10 @@ bool LoopVersioningLICM::legalLoopMemoryAccesses() { // // Iterate over alias tracker sets, and confirm AliasSets doesn't have any // must alias set. - for (const auto &I : *CurAST) { + bool HasMayAlias = false; + bool TypeSafety = false; + bool HasMod = false; + for (const auto &I : AST) { const AliasSet &AS = I; // Skip Forward Alias Sets, as this should be ignored as part of // the AliasSetTracker object. @@ -413,7 +402,7 @@ bool LoopVersioningLICM::legalLoopInstructions() { } } // Get LoopAccessInfo from current loop via the proxy. - LAI = &GetLAI(CurLoop); + LAI = &LAIs.getInfo(*CurLoop); // Check LoopAccessInfo for need of runtime check. if (LAI->getRuntimePointerChecking()->getChecks().empty()) { LLVM_DEBUG(dbgs() << " LAA: Runtime check not found !!\n"); @@ -582,35 +571,18 @@ bool LoopVersioningLICMLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); OptimizationRemarkEmitter *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); - auto GetLAI = [&](Loop *L) -> const LoopAccessInfo & { - return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(L); - }; - - return LoopVersioningLICM(AA, SE, ORE, GetLAI).runOnLoop(L, LI, DT); + return LoopVersioningLICM(AA, SE, ORE, LAIs, LI, L).run(DT); } -bool LoopVersioningLICM::runOnLoop(Loop *L, LoopInfo *LI, DominatorTree *DT) { - // This will automatically release all resources hold by the current - // LoopVersioningLICM object. - AutoResetter Resetter(*this); - +bool LoopVersioningLICM::run(DominatorTree *DT) { // Do not do the transformation if disabled by metadata. - if (hasLICMVersioningTransformation(L) & TM_Disable) + if (hasLICMVersioningTransformation(CurLoop) & TM_Disable) return false; - // Set Current Loop - CurLoop = L; - CurAST.reset(new AliasSetTracker(*AA)); - - // Loop over the body of this loop, construct AST. - for (auto *Block : L->getBlocks()) { - if (LI->getLoopFor(Block) == L) // Ignore blocks in subloop. - CurAST->add(*Block); // Incorporate the specified basic block - } - bool Changed = false; // Check feasiblity of LoopVersioningLICM. @@ -621,7 +593,7 @@ bool LoopVersioningLICM::runOnLoop(Loop *L, LoopInfo *LI, DominatorTree *DT) { // Create memcheck for memory accessed inside loop. // Clone original loop, and set blocks properly. LoopVersioning LVer(*LAI, LAI->getRuntimePointerChecking()->getChecks(), - CurLoop, LI, DT, SE); + CurLoop, &LI, DT, SE); LVer.versionLoop(); // Set Loop Versioning metaData for original loop. 
addStringMetadataToLoop(LVer.getNonVersionedLoop(), LICMVersioningMetaData); @@ -667,15 +639,11 @@ PreservedAnalyses LoopVersioningLICMPass::run(Loop &L, LoopAnalysisManager &AM, AliasAnalysis *AA = &LAR.AA; ScalarEvolution *SE = &LAR.SE; DominatorTree *DT = &LAR.DT; - LoopInfo *LI = &LAR.LI; const Function *F = L.getHeader()->getParent(); OptimizationRemarkEmitter ORE(F); - auto GetLAI = [&](Loop *L) -> const LoopAccessInfo & { - return AM.getResult<LoopAccessAnalysis>(*L, LAR); - }; - - if (!LoopVersioningLICM(AA, SE, &ORE, GetLAI).runOnLoop(&L, LI, DT)) + LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr); + if (!LoopVersioningLICM(AA, SE, &ORE, LAIs, LAR.LI, &L).run(DT)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); } diff --git a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp index 47493b54a527..ef22b0401b1b 100644 --- a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp @@ -31,6 +31,7 @@ #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -96,7 +97,7 @@ static bool replaceConditionalBranchesOnConstant(Instruction *II, static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo &TLI, DominatorTree *DT) { - Optional<DomTreeUpdater> DTU; + std::optional<DomTreeUpdater> DTU; if (DT) DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy); @@ -143,10 +144,10 @@ static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo &TLI, break; } HasDeadBlocks |= replaceConditionalBranchesOnConstant( - II, NewValue, DTU ? DTU.getPointer() : nullptr); + II, NewValue, DTU ? &*DTU : nullptr); } if (HasDeadBlocks) - removeUnreachableBlocks(F, DTU ? DTU.getPointer() : nullptr); + removeUnreachableBlocks(F, DTU ? &*DTU : nullptr); return !Worklist.empty(); } diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 88fad9896c59..454aa56be531 100644 --- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -27,6 +27,8 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/MisExpect.h" +#include <cmath> + using namespace llvm; #define DEBUG_TYPE "lower-expect-intrinsic" @@ -123,6 +125,17 @@ static void handlePhiDef(CallInst *Expect) { if (!ExpectedValue) return; const APInt &ExpectedPhiValue = ExpectedValue->getValue(); + bool ExpectedValueIsLikely = true; + Function *Fn = Expect->getCalledFunction(); + // If the function is expect_with_probability, then we need to take the + // probability into consideration. For example, in + // expect.with.probability.i64(i64 %a, i64 1, double 0.0), the + // "ExpectedValue" 1 is unlikely. This affects probability propagation later. + if (Fn->getIntrinsicID() == Intrinsic::expect_with_probability) { + auto *Confidence = cast<ConstantFP>(Expect->getArgOperand(2)); + double TrueProb = Confidence->getValueAPF().convertToDouble(); + ExpectedValueIsLikely = (TrueProb > 0.5); + } // Walk up in backward a list of instructions that // have 'copy' semantics by 'stripping' the copies @@ -164,7 +177,7 @@ static void handlePhiDef(CallInst *Expect) { // Executes the recorded operations on input 'Value'. 
auto ApplyOperations = [&](const APInt &Value) { APInt Result = Value; - for (auto Op : llvm::reverse(Operations)) { + for (auto *Op : llvm::reverse(Operations)) { switch (Op->getOpcode()) { case Instruction::Xor: Result ^= cast<ConstantInt>(Op->getOperand(1))->getValue(); @@ -211,9 +224,12 @@ static void handlePhiDef(CallInst *Expect) { continue; // Not an interesting case when IsUnlikely is false -- we can not infer - // anything useful when the operand value matches the expected phi - // output. - if (ExpectedPhiValue == ApplyOperations(CI->getValue())) + // anything useful when: + // (1) We expect some phi output and the operand value matches it, or + // (2) We don't expect some phi output (i.e. the "ExpectedValue" has low + // probability) and the operand value doesn't match that. + const APInt &CurrentPhiValue = ApplyOperations(CI->getValue()); + if (ExpectedValueIsLikely == (ExpectedPhiValue == CurrentPhiValue)) continue; BranchInst *BI = GetDomConditional(i); @@ -246,6 +262,8 @@ static void handlePhiDef(CallInst *Expect) { uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight( Expect->getCalledFunction()->getIntrinsicID(), Expect, 2); + if (!ExpectedValueIsLikely) + std::swap(LikelyBranchWeightVal, UnlikelyBranchWeightVal); if (IsOpndComingFromSuccessor(BI->getSuccessor(1))) BI->setMetadata(LLVMContext::MD_prof, diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index f1e1359255bd..17594b98c5bc 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -46,6 +46,8 @@ #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/MatrixUtils.h" +#include <cmath> + using namespace llvm; using namespace PatternMatch; @@ -80,6 +82,9 @@ static cl::opt<MatrixLayoutTy> MatrixLayout( clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout"))); +static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt", + cl::init(false)); + /// Helper function to either return Scope, if it is a subprogram or the /// attached subprogram for a local scope. static DISubprogram *getSubprogram(DIScope *Scope) { @@ -88,6 +93,39 @@ static DISubprogram *getSubprogram(DIScope *Scope) { return cast<DILocalScope>(Scope)->getSubprogram(); } +/// Erase \p V from \p BB and move \II forward to avoid invalidating +/// iterators. +static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II, + BasicBlock &BB) { + auto *Inst = cast<Instruction>(V); + // Still used, don't erase. + if (!Inst->use_empty()) + return; + if (II != BB.rend() && Inst == &*II) + ++II; + Inst->eraseFromParent(); +} + +/// Return true if V is a splat of a value (which is used when multiplying a +/// matrix with a scalar). +static bool isSplat(Value *V) { + if (auto *SV = dyn_cast<ShuffleVectorInst>(V)) + return SV->isZeroEltSplat(); + return false; +} + +/// Match any mul operation (fp or integer). +template <typename LTy, typename RTy> +auto m_AnyMul(const LTy &L, const RTy &R) { + return m_CombineOr(m_Mul(L, R), m_FMul(L, R)); +} + +/// Match any add operation (fp or integer). 
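Note: the ExpectedValueIsLikely logic above reduces to one decision: read the probability operand of llvm.expect.with.probability, and if it is not above 0.5 the annotated value is the unlikely outcome, so the likely/unlikely branch weights are swapped before being attached. A hedged sketch of just that decision, using placeholder weight values rather than the ones getBranchWeight() would return:

    #include <cstdint>
    #include <iostream>
    #include <utility>

    // Placeholder weights; the real values come from getBranchWeight().
    constexpr uint32_t LikelyWeight = 2000;
    constexpr uint32_t UnlikelyWeight = 1;

    // Returns {weight when the value matches, weight when it differs}.
    // For expect_with_probability(x, v, p) with p <= 0.5, matching v is the
    // unlikely outcome, so the two weights are swapped.
    std::pair<uint32_t, uint32_t> weightsFor(double TrueProb) {
      uint32_t Likely = LikelyWeight, Unlikely = UnlikelyWeight;
      bool ExpectedValueIsLikely = TrueProb > 0.5;
      if (!ExpectedValueIsLikely)
        std::swap(Likely, Unlikely);
      return {Likely, Unlikely};
    }

    int main() {
      auto [L, U] = weightsFor(0.0);      // expect(..., 1, 0.0): 1 is unlikely
      std::cout << L << " " << U << "\n"; // prints "1 2000"
    }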
+template <typename LTy, typename RTy> +auto m_AnyAdd(const LTy &L, const RTy &R) { + return m_CombineOr(m_Add(L, R), m_FAdd(L, R)); +} + namespace { // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute @@ -384,6 +422,9 @@ class LowerMatrixIntrinsics { return NumColumns; return NumRows; } + + /// Returns the transposed shape. + ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } }; /// Maps instructions to their shape information. The shape information @@ -437,10 +478,10 @@ public: /// Return the estimated number of vector ops required for an operation on /// \p VT * N. unsigned getNumOps(Type *ST, unsigned N) { - return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / + return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() / double(TTI.getRegisterBitWidth( TargetTransformInfo::RGK_FixedWidthVector) - .getFixedSize())); + .getFixedValue())); } /// Return the set of vectors that a matrix value is lowered to. @@ -684,115 +725,198 @@ public: return NewWorkList; } - /// Try moving transposes in order to fold them away or into multiplies. - void optimizeTransposes() { - auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) { - // We need to remove Old from the ShapeMap otherwise RAUW will replace it - // with New. We should only add New it it supportsShapeInfo so we insert - // it conditionally instead. - auto S = ShapeMap.find(&Old); - if (S != ShapeMap.end()) { - ShapeMap.erase(S); - if (supportsShapeInfo(New)) - ShapeMap.insert({New, S->second}); - } - Old.replaceAllUsesWith(New); + /// (Op0 op Op1)^T -> Op0^T op Op1^T + /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use + /// them on both sides of \p Operation. + Instruction *distributeTransposes( + Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1, + MatrixBuilder &Builder, + function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)> + Operation) { + Value *T0 = Builder.CreateMatrixTranspose( + Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t"); + // We are being run after shape prop, add shape for newly created + // instructions so that we lower them later. + setShapeInfo(T0, Shape0.t()); + Value *T1 = Builder.CreateMatrixTranspose( + Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t"); + setShapeInfo(T1, Shape1.t()); + return Operation(T0, Shape0.t(), T1, Shape1.t()); + } + + void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) { + // We need to remove Old from the ShapeMap otherwise RAUW will replace it + // with New. We should only add New it it supportsShapeInfo so we insert + // it conditionally instead. + auto S = ShapeMap.find(&Old); + if (S != ShapeMap.end()) { + ShapeMap.erase(S); + if (supportsShapeInfo(New)) + ShapeMap.insert({New, S->second}); + } + Old.replaceAllUsesWith(New); + } + + /// Sink a top-level transpose inside matmuls and adds. + /// This creates and erases instructions as needed, and returns the newly + /// created instruction while updating the iterator to avoid invalidation. If + /// this returns nullptr, no new instruction was created. 
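Note: distributeTransposes() mostly has to keep the shape bookkeeping straight: transposing an RxC operand yields a CxR value, which is what the new ShapeInfo::t() returns, and the distributed operation is then built from the transposed shapes. A small standalone illustration of that bookkeeping for the (A * B)^T case, with plain structs rather than the pass's types:

    #include <cassert>

    struct Shape {
      unsigned Rows, Cols;
      Shape t() const { return {Cols, Rows}; } // transposed shape
    };

    // (A * B)^T = B^T * A^T: for A of shape RxK and B of shape KxC, the
    // rewritten multiply takes B^T (CxK) and A^T (KxR) and produces CxR,
    // which is exactly the transpose of the original RxC result.
    Shape shapeOfTransposedProduct(Shape A, Shape B) {
      assert(A.Cols == B.Rows && "inner dimensions must match");
      Shape BT = B.t(), AT = A.t();
      assert(BT.Cols == AT.Rows && "still well-formed after transposing");
      return {BT.Rows, AT.Cols}; // CxR
    }

    int main() {
      Shape Res = shapeOfTransposedProduct({3, 4}, {4, 5});
      assert(Res.Rows == 5 && Res.Cols == 3);
      (void)Res;
    }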
+ Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) { + BasicBlock &BB = *I.getParent(); + IRBuilder<> IB(&I); + MatrixBuilder Builder(IB); + + Value *TA, *TAMA, *TAMB; + ConstantInt *R, *K, *C; + if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) + return nullptr; + + // Transpose of a transpose is a nop + Value *TATA; + if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { + updateShapeAndReplaceAllUsesWith(I, TATA); + eraseFromParentAndMove(&I, II, BB); + eraseFromParentAndMove(TA, II, BB); + return nullptr; + } + + // k^T -> k + if (isSplat(TA)) { + updateShapeAndReplaceAllUsesWith(I, TA); + eraseFromParentAndMove(&I, II, BB); + return nullptr; + } + + // (A * B)^t -> B^t * A^t + // RxK KxC CxK KxR + if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( + m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), + m_ConstantInt(K), m_ConstantInt(C)))) { + auto NewInst = distributeTransposes( + TAMB, {K, C}, TAMA, {R, K}, Builder, + [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { + return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows, + Shape0.NumColumns, + Shape1.NumColumns, "mmul"); + }); + updateShapeAndReplaceAllUsesWith(I, NewInst); + eraseFromParentAndMove(&I, II, BB); + eraseFromParentAndMove(TA, II, BB); + return NewInst; + } + + // Same as above, but with a mul, which occurs when multiplied + // with a scalar. + // (A * k)^t -> A^t * k + // R x C RxC + if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) && + (isSplat(TAMA) || isSplat(TAMB))) { + IRBuilder<> LocalBuilder(&I); + // We know that the transposed operand is of shape RxC. + // An when multiplied with a scalar, the shape is preserved. + auto NewInst = distributeTransposes( + TAMA, {R, C}, TAMB, {R, C}, Builder, + [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { + bool IsFP = I.getType()->isFPOrFPVectorTy(); + auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul") + : LocalBuilder.CreateMul(T0, T1, "mmul"); + auto *Result = cast<Instruction>(Mul); + setShapeInfo(Result, Shape0); + return Result; + }); + updateShapeAndReplaceAllUsesWith(I, NewInst); + eraseFromParentAndMove(&I, II, BB); + eraseFromParentAndMove(TA, II, BB); + return NewInst; + } + + // (A + B)^t -> A^t + B^t + // RxC RxC CxR CxR + if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) { + IRBuilder<> LocalBuilder(&I); + auto NewInst = distributeTransposes( + TAMA, {R, C}, TAMB, {R, C}, Builder, + [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { + auto *FAdd = + cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd")); + setShapeInfo(FAdd, Shape0); + return FAdd; + }); + updateShapeAndReplaceAllUsesWith(I, NewInst); + eraseFromParentAndMove(&I, II, BB); + eraseFromParentAndMove(TA, II, BB); + return NewInst; + } + + return nullptr; + } + + void liftTranspose(Instruction &I) { + // Erase dead Instructions after lifting transposes from binops. + auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) { + if (T.use_empty()) + T.eraseFromParent(); + if (A->use_empty()) + cast<Instruction>(A)->eraseFromParent(); + if (A != B && B->use_empty()) + cast<Instruction>(B)->eraseFromParent(); }; - // First sink all transposes inside matmuls, hoping that we end up with NN, - // NT or TN variants. 
+ Value *A, *B, *AT, *BT; + ConstantInt *R, *K, *C; + // A^t * B ^t -> (B * A)^t + if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( + m_Value(A), m_Value(B), m_ConstantInt(R), + m_ConstantInt(K), m_ConstantInt(C))) && + match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && + match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { + IRBuilder<> IB(&I); + MatrixBuilder Builder(IB); + Value *M = Builder.CreateMatrixMultiply( + BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); + setShapeInfo(M, {C, R}); + Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(), + R->getZExtValue()); + updateShapeAndReplaceAllUsesWith(I, NewInst); + CleanupBinOp(I, A, B); + } + // A^t + B ^t -> (A + B)^t + else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && + match(A, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && + match(B, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) { + IRBuilder<> Builder(&I); + Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); + setShapeInfo(Add, {C, R}); + MatrixBuilder MBuilder(Builder); + Instruction *NewInst = MBuilder.CreateMatrixTranspose( + Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t"); + updateShapeAndReplaceAllUsesWith(I, NewInst); + CleanupBinOp(I, A, B); + } + } + + /// Try moving transposes in order to fold them away or into multiplies. + void optimizeTransposes() { + // First sink all transposes inside matmuls and adds, hoping that we end up + // with NN, NT or TN variants. for (BasicBlock &BB : reverse(Func)) { for (auto II = BB.rbegin(); II != BB.rend();) { Instruction &I = *II; // We may remove II. By default continue on the next/prev instruction. ++II; - // If we were to erase II, move again. - auto EraseFromParent = [&II, &BB](Value *V) { - auto *Inst = cast<Instruction>(V); - if (Inst->use_empty()) { - if (II != BB.rend() && Inst == &*II) { - ++II; - } - Inst->eraseFromParent(); - } - }; - - // If we're creating a new instruction, continue from there. - Instruction *NewInst = nullptr; - - IRBuilder<> IB(&I); - MatrixBuilder Builder(IB); - - Value *TA, *TAMA, *TAMB; - ConstantInt *R, *K, *C; - if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) { - - // Transpose of a transpose is a nop - Value *TATA; - if (match(TA, - m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) { - ReplaceAllUsesWith(I, TATA); - EraseFromParent(&I); - EraseFromParent(TA); - } - - // (A * B)^t -> B^t * A^t - // RxK KxC CxK KxR - else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( - m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R), - m_ConstantInt(K), m_ConstantInt(C)))) { - Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(), - C->getZExtValue(), - TAMB->getName() + "_t"); - // We are being run after shape prop, add shape for newly created - // instructions so that we lower them later. - setShapeInfo(T0, {C, K}); - Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(), - K->getZExtValue(), - TAMA->getName() + "_t"); - setShapeInfo(T1, {K, R}); - NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(), - K->getZExtValue(), - R->getZExtValue(), "mmul"); - ReplaceAllUsesWith(I, NewInst); - EraseFromParent(&I); - EraseFromParent(TA); - } - } - - // If we replaced I with a new instruction, continue from there. 
- if (NewInst) + if (Instruction *NewInst = sinkTranspose(I, II)) II = std::next(BasicBlock::reverse_iterator(NewInst)); } } - // If we have a TT matmul, lift the transpose. We may be able to fold into - // consuming multiply. + // If we have a TT matmul or a TT add, lift the transpose. We may be able + // to fold into consuming multiply or add. for (BasicBlock &BB : Func) { for (Instruction &I : llvm::make_early_inc_range(BB)) { - Value *A, *B, *AT, *BT; - ConstantInt *R, *K, *C; - // A^t * B ^t -> (B * A)^t - if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( - m_Value(A), m_Value(B), m_ConstantInt(R), - m_ConstantInt(K), m_ConstantInt(C))) && - match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && - match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { - IRBuilder<> IB(&I); - MatrixBuilder Builder(IB); - Value *M = Builder.CreateMatrixMultiply( - BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); - setShapeInfo(M, {C, R}); - Instruction *NewInst = Builder.CreateMatrixTranspose( - M, C->getZExtValue(), R->getZExtValue()); - ReplaceAllUsesWith(I, NewInst); - if (I.use_empty()) - I.eraseFromParent(); - if (A->use_empty()) - cast<Instruction>(A)->eraseFromParent(); - if (A != B && B->use_empty()) - cast<Instruction>(B)->eraseFromParent(); - } + liftTranspose(I); } } } @@ -832,10 +956,10 @@ public: if (!isMinimal()) { optimizeTransposes(); - LLVM_DEBUG({ + if (PrintAfterTransposeOpt) { dbgs() << "Dump after matrix transpose optimization:\n"; - Func.dump(); - }); + Func.print(dbgs()); + } } bool Changed = false; @@ -1199,8 +1323,8 @@ public: bool IsScalarMatrixTransposed, FastMathFlags FMF) { const unsigned VF = std::max<unsigned>( TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) - .getFixedSize() / - Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), + .getFixedValue() / + Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(), 1U); unsigned R = Result.getNumRows(); unsigned C = Result.getNumColumns(); @@ -1378,8 +1502,8 @@ public: const unsigned VF = std::max<unsigned>( TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) - .getFixedSize() / - EltType->getPrimitiveSizeInBits().getFixedSize(), + .getFixedValue() / + EltType->getPrimitiveSizeInBits().getFixedValue(), 1U); // Cost model for tiling @@ -2160,7 +2284,7 @@ public: // the inlinedAt chain. If the function does not have a DISubprogram, we // only map them to the containing function. 
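Note: both rewrites lean on the standard transpose identities, (A*B)^T = B^T*A^T and (A+B)^T = A^T+B^T; sinkTranspose() applies them left to right and liftTranspose() right to left. A self-contained numeric check of the identities on 2x2 matrices (ordinary C++, not the pass's flattened column vectors):

    #include <array>
    #include <cassert>

    using M2 = std::array<std::array<double, 2>, 2>;

    M2 mul(const M2 &A, const M2 &B) {
      M2 C{};
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
          for (int k = 0; k < 2; ++k)
            C[i][j] += A[i][k] * B[k][j];
      return C;
    }

    M2 add(const M2 &A, const M2 &B) {
      M2 C{};
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
          C[i][j] = A[i][j] + B[i][j];
      return C;
    }

    M2 t(const M2 &A) { return {{{A[0][0], A[1][0]}, {A[0][1], A[1][1]}}}; }

    int main() {
      M2 A{{{1, 2}, {3, 4}}}, B{{{5, 6}, {7, 8}}};
      assert(t(mul(A, B)) == mul(t(B), t(A))); // (A*B)^T == B^T * A^T
      assert(t(add(A, B)) == add(t(A), t(B))); // (A+B)^T == A^T + B^T
    }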
MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs; - for (auto &KV : Inst2Matrix) { + for (const auto &KV : Inst2Matrix) { if (Func.getSubprogram()) { auto *I = cast<Instruction>(KV.first); DILocation *Context = I->getDebugLoc(); diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 1f5bc69acecd..64846484f936 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -13,7 +13,6 @@ #include "llvm/Transforms/Scalar/MemCpyOptimizer.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -58,6 +57,7 @@ #include <algorithm> #include <cassert> #include <cstdint> +#include <optional> using namespace llvm; @@ -176,8 +176,8 @@ public: void addStore(int64_t OffsetFromFirst, StoreInst *SI) { TypeSize StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType()); assert(!StoreSize.isScalable() && "Can't track scalable-typed stores"); - addRange(OffsetFromFirst, StoreSize.getFixedSize(), SI->getPointerOperand(), - SI->getAlign(), SI); + addRange(OffsetFromFirst, StoreSize.getFixedValue(), + SI->getPointerOperand(), SI->getAlign(), SI); } void addMemSet(int64_t OffsetFromFirst, MemSetInst *MSI) { @@ -331,23 +331,34 @@ void MemCpyOptPass::eraseInstruction(Instruction *I) { } // Check for mod or ref of Loc between Start and End, excluding both boundaries. -// Start and End must be in the same block -static bool accessedBetween(AliasAnalysis &AA, MemoryLocation Loc, +// Start and End must be in the same block. +// If SkippedLifetimeStart is provided, skip over one clobbering lifetime.start +// intrinsic and store it inside SkippedLifetimeStart. +static bool accessedBetween(BatchAAResults &AA, MemoryLocation Loc, const MemoryUseOrDef *Start, - const MemoryUseOrDef *End) { + const MemoryUseOrDef *End, + Instruction **SkippedLifetimeStart = nullptr) { assert(Start->getBlock() == End->getBlock() && "Only local supported"); for (const MemoryAccess &MA : make_range(++Start->getIterator(), End->getIterator())) { - if (isModOrRefSet(AA.getModRefInfo(cast<MemoryUseOrDef>(MA).getMemoryInst(), - Loc))) + Instruction *I = cast<MemoryUseOrDef>(MA).getMemoryInst(); + if (isModOrRefSet(AA.getModRefInfo(I, Loc))) { + auto *II = dyn_cast<IntrinsicInst>(I); + if (II && II->getIntrinsicID() == Intrinsic::lifetime_start && + SkippedLifetimeStart && !*SkippedLifetimeStart) { + *SkippedLifetimeStart = I; + continue; + } + return true; + } } return false; } // Check for mod of Loc between Start and End, excluding both boundaries. // Start and End can be in different blocks. -static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA, +static bool writtenBetween(MemorySSA *MSSA, BatchAAResults &AA, MemoryLocation Loc, const MemoryUseOrDef *Start, const MemoryUseOrDef *End) { if (isa<MemoryUse>(End)) { @@ -368,7 +379,7 @@ static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA, // TODO: Only walk until we hit Start. MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( - End->getDefiningAccess(), Loc); + End->getDefiningAccess(), Loc, AA); return !MSSA->dominates(Clobber, Start); } @@ -451,7 +462,7 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, break; // Check to see if this store is to a constant offset from the start ptr. 
- Optional<int64_t> Offset = + std::optional<int64_t> Offset = isPointerOffset(StartPtr, NextStore->getPointerOperand(), DL); if (!Offset) break; @@ -465,7 +476,8 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, break; // Check to see if this store is to a constant offset from the start ptr. - Optional<int64_t> Offset = isPointerOffset(StartPtr, MSI->getDest(), DL); + std::optional<int64_t> Offset = + isPointerOffset(StartPtr, MSI->getDest(), DL); if (!Offset) break; @@ -504,6 +516,8 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End - Range.Start, Range.Alignment); + AMemSet->mergeDIAssignID(Range.TheStores); + LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI : Range.TheStores) dbgs() << *SI << '\n'; @@ -546,9 +560,17 @@ bool MemCpyOptPass::moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI) { // Keep track of the arguments of all instruction we plan to lift // so we can make sure to lift them as well if appropriate. DenseSet<Instruction*> Args; - if (auto *Ptr = dyn_cast<Instruction>(SI->getPointerOperand())) - if (Ptr->getParent() == SI->getParent()) - Args.insert(Ptr); + auto AddArg = [&](Value *Arg) { + auto *I = dyn_cast<Instruction>(Arg); + if (I && I->getParent() == SI->getParent()) { + // Cannot hoist user of P above P + if (I == P) return false; + Args.insert(I); + } + return true; + }; + if (!AddArg(SI->getPointerOperand())) + return false; // Instruction to lift before P. SmallVector<Instruction *, 8> ToLift{SI}; @@ -569,7 +591,7 @@ bool MemCpyOptPass::moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI) { if (!isGuaranteedToTransferExecutionToSuccessor(C)) return false; - bool MayAlias = isModOrRefSet(AA->getModRefInfo(C, None)); + bool MayAlias = isModOrRefSet(AA->getModRefInfo(C, std::nullopt)); bool NeedLift = false; if (Args.erase(C)) @@ -612,14 +634,9 @@ bool MemCpyOptPass::moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI) { } ToLift.push_back(C); - for (unsigned k = 0, e = C->getNumOperands(); k != e; ++k) - if (auto *A = dyn_cast<Instruction>(C->getOperand(k))) { - if (A->getParent() == SI->getParent()) { - // Cannot hoist user of P above P - if(A == P) return false; - Args.insert(A); - } - } + for (Value *Op : C->operands()) + if (!AddArg(Op)) + return false; } // Find MSSA insertion point. Normally P will always have a corresponding @@ -657,6 +674,116 @@ bool MemCpyOptPass::moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI) { return true; } +bool MemCpyOptPass::processStoreOfLoad(StoreInst *SI, LoadInst *LI, + const DataLayout &DL, + BasicBlock::iterator &BBI) { + if (!LI->isSimple() || !LI->hasOneUse() || + LI->getParent() != SI->getParent()) + return false; + + auto *T = LI->getType(); + // Don't introduce calls to memcpy/memmove intrinsics out of thin air if + // the corresponding libcalls are not available. + // TODO: We should really distinguish between libcall availability and + // our ability to introduce intrinsics. + if (T->isAggregateType() && + (EnableMemCpyOptWithoutLibcalls || + (TLI->has(LibFunc_memcpy) && TLI->has(LibFunc_memmove)))) { + MemoryLocation LoadLoc = MemoryLocation::get(LI); + + // We use alias analysis to check if an instruction may store to + // the memory we load from in between the load and the store. If + // such an instruction is found, we try to promote there instead + // of at the store position. + // TODO: Can use MSSA for this. 
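Note: tryMergingIntoMemset() scans forward from a store and only keeps following stores (or memsets) whose address is at a known constant offset from the start pointer; a contiguous run of the same byte value can then be emitted as one memset, now carrying merged DIAssignID metadata. A hedged source-level picture of the pattern this targets; the rewrite itself happens on IR and is subject to the usual cost and libcall checks:

    #include <cstring>
    #include <cstdio>

    struct Packet {
      char a, b, c, d;
    };

    // Four adjacent one-byte stores of the same value; runs of stores like
    // these are candidates for being merged into a single memset-style write.
    void clearByFields(Packet &P) {
      P.a = 0;
      P.b = 0;
      P.c = 0;
      P.d = 0;
    }

    int main() {
      Packet P;
      clearByFields(P);
      Packet Q;
      std::memset(&Q, 0, sizeof(Q)); // the equivalent single call
      std::printf("%d\n", std::memcmp(&P, &Q, sizeof(P)) == 0); // prints 1
    }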
+ Instruction *P = SI; + for (auto &I : make_range(++LI->getIterator(), SI->getIterator())) { + if (isModSet(AA->getModRefInfo(&I, LoadLoc))) { + P = &I; + break; + } + } + + // We found an instruction that may write to the loaded memory. + // We can try to promote at this position instead of the store + // position if nothing aliases the store memory after this and the store + // destination is not in the range. + if (P && P != SI) { + if (!moveUp(SI, P, LI)) + P = nullptr; + } + + // If a valid insertion position is found, then we can promote + // the load/store pair to a memcpy. + if (P) { + // If we load from memory that may alias the memory we store to, + // memmove must be used to preserve semantic. If not, memcpy can + // be used. Also, if we load from constant memory, memcpy can be used + // as the constant memory won't be modified. + bool UseMemMove = false; + if (isModSet(AA->getModRefInfo(SI, LoadLoc))) + UseMemMove = true; + + uint64_t Size = DL.getTypeStoreSize(T); + + IRBuilder<> Builder(P); + Instruction *M; + if (UseMemMove) + M = Builder.CreateMemMove( + SI->getPointerOperand(), SI->getAlign(), + LI->getPointerOperand(), LI->getAlign(), Size); + else + M = Builder.CreateMemCpy( + SI->getPointerOperand(), SI->getAlign(), + LI->getPointerOperand(), LI->getAlign(), Size); + M->copyMetadata(*SI, LLVMContext::MD_DIAssignID); + + LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " + << *M << "\n"); + + auto *LastDef = + cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI)); + auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); + MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); + + eraseInstruction(SI); + eraseInstruction(LI); + ++NumMemCpyInstr; + + // Make sure we do not invalidate the iterator. + BBI = M->getIterator(); + return true; + } + } + + // Detect cases where we're performing call slot forwarding, but + // happen to be using a load-store pair to implement it, rather than + // a memcpy. + BatchAAResults BAA(*AA); + auto GetCall = [&]() -> CallInst * { + // We defer this expensive clobber walk until the cheap checks + // have been done on the source inside performCallSlotOptzn. + if (auto *LoadClobber = dyn_cast<MemoryUseOrDef>( + MSSA->getWalker()->getClobberingMemoryAccess(LI, BAA))) + return dyn_cast_or_null<CallInst>(LoadClobber->getMemoryInst()); + return nullptr; + }; + + bool Changed = performCallSlotOptzn( + LI, SI, SI->getPointerOperand()->stripPointerCasts(), + LI->getPointerOperand()->stripPointerCasts(), + DL.getTypeStoreSize(SI->getOperand(0)->getType()), + std::min(SI->getAlign(), LI->getAlign()), BAA, GetCall); + if (Changed) { + eraseInstruction(SI); + eraseInstruction(LI); + ++NumMemCpyInstr; + return true; + } + + return false; +} + bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (!SI->isSimple()) return false; @@ -679,109 +806,8 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { return false; // Load to store forwarding can be interpreted as memcpy. - if (auto *LI = dyn_cast<LoadInst>(StoredVal)) { - if (LI->isSimple() && LI->hasOneUse() && - LI->getParent() == SI->getParent()) { - - auto *T = LI->getType(); - // Don't introduce calls to memcpy/memmove intrinsics out of thin air if - // the corresponding libcalls are not available. - // TODO: We should really distinguish between libcall availability and - // our ability to introduce intrinsics. 
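Note: processStoreOfLoad() factors out the load/store-forwarding case: an aggregate copied through a load/store pair becomes a memcpy, or a memmove when alias analysis says source and destination may overlap, provided a clobber-free insertion point exists (moveUp() may hoist the store to create one). A hedged source-level illustration of the pattern; names are illustrative and the transform operates on IR:

    #include <cstring>
    #include <cstdio>

    struct Big {
      int Data[16];
    };

    // Aggregate assignment lowers to a load/store pair of the whole struct;
    // that pair can be turned into a memcpy (or a memmove if Dst and Src
    // might overlap, which is why the pass queries alias analysis first).
    void copy(Big *Dst, const Big *Src) { *Dst = *Src; }

    int main() {
      Big A, B;
      for (int I = 0; I < 16; ++I)
        A.Data[I] = I;
      copy(&B, &A);
      std::printf("%d\n", std::memcmp(&A, &B, sizeof(Big)) == 0); // prints 1
    }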
- if (T->isAggregateType() && - (EnableMemCpyOptWithoutLibcalls || - (TLI->has(LibFunc_memcpy) && TLI->has(LibFunc_memmove)))) { - MemoryLocation LoadLoc = MemoryLocation::get(LI); - - // We use alias analysis to check if an instruction may store to - // the memory we load from in between the load and the store. If - // such an instruction is found, we try to promote there instead - // of at the store position. - // TODO: Can use MSSA for this. - Instruction *P = SI; - for (auto &I : make_range(++LI->getIterator(), SI->getIterator())) { - if (isModSet(AA->getModRefInfo(&I, LoadLoc))) { - P = &I; - break; - } - } - - // We found an instruction that may write to the loaded memory. - // We can try to promote at this position instead of the store - // position if nothing aliases the store memory after this and the store - // destination is not in the range. - if (P && P != SI) { - if (!moveUp(SI, P, LI)) - P = nullptr; - } - - // If a valid insertion position is found, then we can promote - // the load/store pair to a memcpy. - if (P) { - // If we load from memory that may alias the memory we store to, - // memmove must be used to preserve semantic. If not, memcpy can - // be used. Also, if we load from constant memory, memcpy can be used - // as the constant memory won't be modified. - bool UseMemMove = false; - if (isModSet(AA->getModRefInfo(SI, LoadLoc))) - UseMemMove = true; - - uint64_t Size = DL.getTypeStoreSize(T); - - IRBuilder<> Builder(P); - Instruction *M; - if (UseMemMove) - M = Builder.CreateMemMove( - SI->getPointerOperand(), SI->getAlign(), - LI->getPointerOperand(), LI->getAlign(), Size); - else - M = Builder.CreateMemCpy( - SI->getPointerOperand(), SI->getAlign(), - LI->getPointerOperand(), LI->getAlign(), Size); - - LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " - << *M << "\n"); - - auto *LastDef = - cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI)); - auto *NewAccess = MSSAU->createMemoryAccessAfter(M, LastDef, LastDef); - MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); - - eraseInstruction(SI); - eraseInstruction(LI); - ++NumMemCpyInstr; - - // Make sure we do not invalidate the iterator. - BBI = M->getIterator(); - return true; - } - } - - // Detect cases where we're performing call slot forwarding, but - // happen to be using a load-store pair to implement it, rather than - // a memcpy. - auto GetCall = [&]() -> CallInst * { - // We defer this expensive clobber walk until the cheap checks - // have been done on the source inside performCallSlotOptzn. - if (auto *LoadClobber = dyn_cast<MemoryUseOrDef>( - MSSA->getWalker()->getClobberingMemoryAccess(LI))) - return dyn_cast_or_null<CallInst>(LoadClobber->getMemoryInst()); - return nullptr; - }; - - bool changed = performCallSlotOptzn( - LI, SI, SI->getPointerOperand()->stripPointerCasts(), - LI->getPointerOperand()->stripPointerCasts(), - DL.getTypeStoreSize(SI->getOperand(0)->getType()), - std::min(SI->getAlign(), LI->getAlign()), GetCall); - if (changed) { - eraseInstruction(SI); - eraseInstruction(LI); - ++NumMemCpyInstr; - return true; - } - } - } + if (auto *LI = dyn_cast<LoadInst>(StoredVal)) + return processStoreOfLoad(SI, LI, DL, BBI); // The following code creates memset intrinsics out of thin air. Don't do // this if the corresponding libfunc is not available. 
@@ -813,6 +839,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { IRBuilder<> Builder(SI); auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size, SI->getAlign()); + M->copyMetadata(*SI, LLVMContext::MD_DIAssignID); LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); @@ -853,7 +880,7 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, Instruction *cpyStore, Value *cpyDest, Value *cpySrc, TypeSize cpySize, - Align cpyAlign, + Align cpyDestAlign, BatchAAResults &BAA, std::function<CallInst *()> GetC) { // The general transformation to keep in mind is // @@ -910,22 +937,33 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, // Check that nothing touches the dest of the copy between // the call and the store/memcpy. - if (accessedBetween(*AA, DestLoc, MSSA->getMemoryAccess(C), - MSSA->getMemoryAccess(cpyStore))) { + Instruction *SkippedLifetimeStart = nullptr; + if (accessedBetween(BAA, DestLoc, MSSA->getMemoryAccess(C), + MSSA->getMemoryAccess(cpyStore), &SkippedLifetimeStart)) { LLVM_DEBUG(dbgs() << "Call Slot: Dest pointer modified after call\n"); return false; } + // If we need to move a lifetime.start above the call, make sure that we can + // actually do so. If the argument is bitcasted for example, we would have to + // move the bitcast as well, which we don't handle. + if (SkippedLifetimeStart) { + auto *LifetimeArg = + dyn_cast<Instruction>(SkippedLifetimeStart->getOperand(1)); + if (LifetimeArg && LifetimeArg->getParent() == C->getParent() && + C->comesBefore(LifetimeArg)) + return false; + } + // Check that accessing the first srcSize bytes of dest will not cause a // trap. Otherwise the transform is invalid since it might cause a trap // to occur earlier than it otherwise would. if (!isDereferenceableAndAlignedPointer(cpyDest, Align(1), APInt(64, cpySize), - DL, C, DT)) { + DL, C, AC, DT)) { LLVM_DEBUG(dbgs() << "Call Slot: Dest pointer not dereferenceable\n"); return false; } - // Make sure that nothing can observe cpyDest being written early. There are // a number of cases to consider: // 1. cpyDest cannot be accessed between C and cpyStore as a precondition of @@ -941,17 +979,19 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, // renders accesses from other threads undefined. // TODO: This is currently not checked. if (mayBeVisibleThroughUnwinding(cpyDest, C, cpyStore)) { - LLVM_DEBUG(dbgs() << "Call Slot: Dest may be visible through unwinding"); + LLVM_DEBUG(dbgs() << "Call Slot: Dest may be visible through unwinding\n"); return false; } // Check that dest points to memory that is at least as aligned as src. Align srcAlign = srcAlloca->getAlign(); - bool isDestSufficientlyAligned = srcAlign <= cpyAlign; + bool isDestSufficientlyAligned = srcAlign <= cpyDestAlign; // If dest is not aligned enough and we can't increase its alignment then // bail out. - if (!isDestSufficientlyAligned && !isa<AllocaInst>(cpyDest)) + if (!isDestSufficientlyAligned && !isa<AllocaInst>(cpyDest)) { + LLVM_DEBUG(dbgs() << "Call Slot: Dest not sufficiently aligned\n"); return false; + } // Check that src is not accessed except via the call and the memcpy. This // guarantees that it holds only undefined values when passed in (so the final @@ -1026,7 +1066,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, // pointer (we have already any direct mod/refs in the loop above). 
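Note: performCallSlotOptzn() implements the forwarding sketched in its header comment: a call fills a temporary, the temporary is copied into the real destination, and if the destination is provably untouched in between (a single clobbering lifetime.start can now be skipped and later re-hoisted above the call), the call can write into the destination directly and the copy is dropped. A hedged source-level picture with illustrative names:

    #include <cstring>
    #include <cstdio>

    struct Result {
      char Buf[32];
    };

    // The callee writes through its pointer argument.
    void produce(Result *Out) { std::memset(Out->Buf, 'x', sizeof(Out->Buf)); }

    // Before the optimization: produce() fills Tmp, then Tmp is copied into
    // *Dest. If nothing touches *Dest between the call and the copy, the pass
    // may rewrite this so produce() writes into *Dest directly and the memcpy
    // (and eventually Tmp) go away.
    void fill(Result *Dest) {
      Result Tmp;
      produce(&Tmp);
      std::memcpy(Dest, &Tmp, sizeof(Result));
    }

    int main() {
      Result R;
      fill(&R);
      std::printf("%c\n", R.Buf[0]); // prints 'x'
    }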
// Also bail if we hit a terminator, as we don't want to scan into other // blocks. - if (isModOrRefSet(AA->getModRefInfo(&I, SrcLoc)) || I.isTerminator()) + if (isModOrRefSet(BAA.getModRefInfo(&I, SrcLoc)) || I.isTerminator()) return false; } } @@ -1047,10 +1087,11 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, // unexpected manner, for example via a global, which we deduce from // the use analysis, we also need to know that it does not sneakily // access dest. We rely on AA to figure this out for us. - ModRefInfo MR = AA->getModRefInfo(C, cpyDest, LocationSize::precise(srcSize)); + MemoryLocation DestWithSrcSize(cpyDest, LocationSize::precise(srcSize)); + ModRefInfo MR = BAA.getModRefInfo(C, DestWithSrcSize); // If necessary, perform additional analysis. if (isModOrRefSet(MR)) - MR = AA->callCapturesBefore(C, cpyDest, LocationSize::precise(srcSize), DT); + MR = BAA.callCapturesBefore(C, DestWithSrcSize, DT); if (isModOrRefSet(MR)) return false; @@ -1090,6 +1131,12 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, cast<AllocaInst>(cpyDest)->setAlignment(srcAlign); } + if (SkippedLifetimeStart) { + SkippedLifetimeStart->moveBefore(C); + MSSAU->moveBefore(MSSA->getMemoryAccess(SkippedLifetimeStart), + MSSA->getMemoryAccess(C)); + } + // Update AA metadata // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be // handled here, but combineMetadata doesn't support them yet @@ -1108,7 +1155,8 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, /// We've found that the (upward scanning) memory dependence of memcpy 'M' is /// the memcpy 'MDep'. Try to simplify M to copy from MDep's input if we can. bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, - MemCpyInst *MDep) { + MemCpyInst *MDep, + BatchAAResults &BAA) { // We can only transforms memcpy's where the dest of one is the source of the // other. if (M->getSource() != MDep->getDest() || MDep->isVolatile()) @@ -1142,7 +1190,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, // then we could still perform the xform by moving M up to the first memcpy. // TODO: It would be sufficient to check the MDep source up to the memcpy // size of M, rather than MDep. - if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep), + if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep), MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M))) return false; @@ -1152,7 +1200,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, // still want to eliminate the intermediate value, but we have to generate a // memmove instead of memcpy. bool UseMemMove = false; - if (isModSet(AA->getModRefInfo(M, MemoryLocation::getForSource(MDep)))) + if (isModSet(BAA.getModRefInfo(M, MemoryLocation::getForSource(MDep)))) UseMemMove = true; // If all checks passed, then we can transform M. @@ -1178,6 +1226,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, NewM = Builder.CreateMemCpy(M->getRawDest(), M->getDestAlign(), MDep->getRawSource(), MDep->getSourceAlign(), M->getLength(), M->isVolatile()); + NewM->copyMetadata(*M, LLVMContext::MD_DIAssignID); assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M))); auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); @@ -1205,20 +1254,21 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, /// memset(dst + src_size, c, dst_size <= src_size ? 
0 : dst_size - src_size); /// \endcode bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, - MemSetInst *MemSet) { + MemSetInst *MemSet, + BatchAAResults &BAA) { // We can only transform memset/memcpy with the same destination. - if (!AA->isMustAlias(MemSet->getDest(), MemCpy->getDest())) + if (!BAA.isMustAlias(MemSet->getDest(), MemCpy->getDest())) return false; // Check that src and dst of the memcpy aren't the same. While memcpy // operands cannot partially overlap, exact equality is allowed. - if (isModSet(AA->getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy)))) + if (isModSet(BAA.getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy)))) return false; // We know that dst up to src_size is not written. We now need to make sure // that dst up to dst_size is not accessed. (If we did not move the memset, // checking for reads would be sufficient.) - if (accessedBetween(*AA, MemoryLocation::getForDest(MemSet), + if (accessedBetween(BAA, MemoryLocation::getForDest(MemSet), MSSA->getMemoryAccess(MemSet), MSSA->getMemoryAccess(MemCpy))) return false; @@ -1288,7 +1338,7 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, /// Determine whether the instruction has undefined content for the given Size, /// either because it was freshly alloca'd or started its lifetime. -static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V, +static bool hasUndefContents(MemorySSA *MSSA, BatchAAResults &AA, Value *V, MemoryDef *Def, Value *Size) { if (MSSA->isLiveOnEntryDef(Def)) return isa<AllocaInst>(getUnderlyingObject(V)); @@ -1298,7 +1348,7 @@ static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V, auto *LTSize = cast<ConstantInt>(II->getArgOperand(0)); if (auto *CSize = dyn_cast<ConstantInt>(Size)) { - if (AA->isMustAlias(V, II->getArgOperand(1)) && + if (AA.isMustAlias(V, II->getArgOperand(1)) && LTSize->getZExtValue() >= CSize->getZExtValue()) return true; } @@ -1310,9 +1360,9 @@ static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V, if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(V))) { if (getUnderlyingObject(II->getArgOperand(1)) == Alloca) { const DataLayout &DL = Alloca->getModule()->getDataLayout(); - if (Optional<TypeSize> AllocaSize = - Alloca->getAllocationSizeInBits(DL)) - if (*AllocaSize == LTSize->getValue() * 8) + if (std::optional<TypeSize> AllocaSize = + Alloca->getAllocationSize(DL)) + if (*AllocaSize == LTSize->getValue()) return true; } } @@ -1335,10 +1385,11 @@ static bool hasUndefContents(MemorySSA *MSSA, AliasAnalysis *AA, Value *V, /// \endcode /// When dst2_size <= dst1_size. bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, - MemSetInst *MemSet) { + MemSetInst *MemSet, + BatchAAResults &BAA) { // Make sure that memcpy(..., memset(...), ...), that is we are memsetting and // memcpying from the same address. Otherwise it is hard to reason about. 
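Note: processMemSetMemCpyDependence() relies on the identity quoted in its comment: once a memcpy overwrites the front of a freshly memset region, only the tail of the memset still matters. A small self-contained check of that identity:

    #include <cstddef>
    #include <cstring>
    #include <cstdio>

    int main() {
      const char Src[4] = {'a', 'b', 'c', 'd'};
      const std::size_t DstSize = 8, SrcSize = sizeof(Src);

      // Original form: memset the whole destination, then memcpy over the front.
      char D1[8];
      std::memset(D1, 0x7f, DstSize);
      std::memcpy(D1, Src, SrcSize);

      // Rewritten form: memcpy first, then memset only the remaining tail.
      char D2[8];
      std::memcpy(D2, Src, SrcSize);
      std::memset(D2 + SrcSize, 0x7f,
                  DstSize <= SrcSize ? 0 : DstSize - SrcSize);

      std::printf("%d\n", std::memcmp(D1, D2, DstSize) == 0); // prints 1
    }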
- if (!AA->isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource())) + if (!BAA.isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource())) return false; Value *MemSetSize = MemSet->getLength(); @@ -1366,9 +1417,9 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, bool CanReduceSize = false; MemoryUseOrDef *MemSetAccess = MSSA->getMemoryAccess(MemSet); MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( - MemSetAccess->getDefiningAccess(), MemCpyLoc); + MemSetAccess->getDefiningAccess(), MemCpyLoc, BAA); if (auto *MD = dyn_cast<MemoryDef>(Clobber)) - if (hasUndefContents(MSSA, AA, MemCpy->getSource(), MD, CopySize)) + if (hasUndefContents(MSSA, BAA, MemCpy->getSource(), MD, CopySize)) CanReduceSize = true; if (!CanReduceSize) @@ -1380,7 +1431,7 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, IRBuilder<> Builder(MemCpy); Instruction *NewM = Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), - CopySize, MaybeAlign(MemCpy->getDestAlignment())); + CopySize, MemCpy->getDestAlign()); auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); auto *NewAccess = MSSAU->createMemoryAccessAfter(NewM, LastDef, LastDef); @@ -1411,9 +1462,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { if (Value *ByteVal = isBytewiseValue(GV->getInitializer(), M->getModule()->getDataLayout())) { IRBuilder<> Builder(M); - Instruction *NewM = - Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), - MaybeAlign(M->getDestAlignment()), false); + Instruction *NewM = Builder.CreateMemSet( + M->getRawDest(), ByteVal, M->getLength(), M->getDestAlign(), false); auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M)); auto *NewAccess = @@ -1425,12 +1475,13 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { return true; } + BatchAAResults BAA(*AA); MemoryUseOrDef *MA = MSSA->getMemoryAccess(M); // FIXME: Not using getClobberingMemoryAccess() here due to PR54682. MemoryAccess *AnyClobber = MA->getDefiningAccess(); MemoryLocation DestLoc = MemoryLocation::getForDest(M); const MemoryAccess *DestClobber = - MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc); + MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc, BAA); // Try to turn a partially redundant memset + memcpy into // memcpy + smaller memset. We don't need the memcpy size for this. @@ -1439,11 +1490,11 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { if (auto *MD = dyn_cast<MemoryDef>(DestClobber)) if (auto *MDep = dyn_cast_or_null<MemSetInst>(MD->getMemoryInst())) if (DestClobber->getBlock() == M->getParent()) - if (processMemSetMemCpyDependence(M, MDep)) + if (processMemSetMemCpyDependence(M, MDep, BAA)) return true; MemoryAccess *SrcClobber = MSSA->getWalker()->getClobberingMemoryAccess( - AnyClobber, MemoryLocation::getForSource(M)); + AnyClobber, MemoryLocation::getForSource(M), BAA); // There are four possible optimizations we can do for memcpy: // a) memcpy-memcpy xform which exposes redundance for DSE. @@ -1456,14 +1507,10 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { if (Instruction *MI = MD->getMemoryInst()) { if (auto *CopySize = dyn_cast<ConstantInt>(M->getLength())) { if (auto *C = dyn_cast<CallInst>(MI)) { - // FIXME: Can we pass in either of dest/src alignment here instead - // of conservatively taking the minimum? 
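Note: performMemCpyToMemSetOptzn() uses the converse identity: copying out of a buffer that was just filled with one byte value is the same as filling the destination with that byte, provided the copy does not read past the memset region (or the extra bytes are known to be undef). A short standalone check of the identity itself:

    #include <cstring>
    #include <cstdio>

    int main() {
      char Src[16];
      std::memset(Src, 0x2a, sizeof(Src)); // memset(src, c, src_size)

      char D1[8], D2[8];
      std::memcpy(D1, Src, sizeof(D1));    // memcpy(dst, src, dst_size)
      std::memset(D2, 0x2a, sizeof(D2));   // equivalent memset(dst, c, dst_size)

      std::printf("%d\n", std::memcmp(D1, D2, sizeof(D1)) == 0); // prints 1
    }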
- Align Alignment = std::min(M->getDestAlign().valueOrOne(), - M->getSourceAlign().valueOrOne()); - if (performCallSlotOptzn( - M, M, M->getDest(), M->getSource(), - TypeSize::getFixed(CopySize->getZExtValue()), Alignment, - [C]() -> CallInst * { return C; })) { + if (performCallSlotOptzn(M, M, M->getDest(), M->getSource(), + TypeSize::getFixed(CopySize->getZExtValue()), + M->getDestAlign().valueOrOne(), BAA, + [C]() -> CallInst * { return C; })) { LLVM_DEBUG(dbgs() << "Performed call slot optimization:\n" << " call: " << *C << "\n" << " memcpy: " << *M << "\n"); @@ -1474,9 +1521,9 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { } } if (auto *MDep = dyn_cast<MemCpyInst>(MI)) - return processMemCpyMemCpyDependence(M, MDep); + return processMemCpyMemCpyDependence(M, MDep, BAA); if (auto *MDep = dyn_cast<MemSetInst>(MI)) { - if (performMemCpyToMemSetOptzn(M, MDep)) { + if (performMemCpyToMemSetOptzn(M, MDep, BAA)) { LLVM_DEBUG(dbgs() << "Converted memcpy to memset\n"); eraseInstruction(M); ++NumCpyToSet; @@ -1485,7 +1532,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { } } - if (hasUndefContents(MSSA, AA, M->getSource(), MD, M->getLength())) { + if (hasUndefContents(MSSA, BAA, M->getSource(), MD, M->getLength())) { LLVM_DEBUG(dbgs() << "Removed memcpy from undef\n"); eraseInstruction(M); ++NumMemCpyInstr; @@ -1532,8 +1579,9 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { if (!CallAccess) return false; MemCpyInst *MDep = nullptr; + BatchAAResults BAA(*AA); MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( - CallAccess->getDefiningAccess(), Loc); + CallAccess->getDefiningAccess(), Loc, BAA); if (auto *MD = dyn_cast<MemoryDef>(Clobber)) MDep = dyn_cast_or_null<MemCpyInst>(MD->getMemoryInst()); @@ -1574,7 +1622,7 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { // *b = 42; // foo(*a) // It would be invalid to transform the second memcpy into foo(*b). - if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep), + if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep), MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB))) return false; diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp index ce01ae5b2692..bcedb05890af 100644 --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -153,7 +153,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { if (!isDereferenceablePointer(Addr, LoadI->getType(), DL)) { LLVM_DEBUG(dbgs() << "not dereferenceable\n"); // We need to make sure that we can do comparison in any order, so we - // require memory to be unconditionnally dereferencable. + // require memory to be unconditionally dereferenceable. return {}; } @@ -300,9 +300,9 @@ bool BCECmpBlock::doesOtherWork() const { // Visit the given comparison. If this is a comparison between two valid // BCE atoms, returns the comparison. -Optional<BCECmp> visitICmp(const ICmpInst *const CmpI, - const ICmpInst::Predicate ExpectedPredicate, - BaseIdentifier &BaseId) { +std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI, + const ICmpInst::Predicate ExpectedPredicate, + BaseIdentifier &BaseId) { // The comparison can only be used once: // - For intermediate blocks, as a branch condition. // - For the final block, as an incoming value for the Phi. 
@@ -310,19 +310,19 @@ Optional<BCECmp> visitICmp(const ICmpInst *const CmpI, // other comparisons as we would create an orphan use of the value. if (!CmpI->hasOneUse()) { LLVM_DEBUG(dbgs() << "cmp has several uses\n"); - return None; + return std::nullopt; } if (CmpI->getPredicate() != ExpectedPredicate) - return None; + return std::nullopt; LLVM_DEBUG(dbgs() << "cmp " << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne") << "\n"); auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0), BaseId); if (!Lhs.BaseId) - return None; + return std::nullopt; auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId); if (!Rhs.BaseId) - return None; + return std::nullopt; const auto &DL = CmpI->getModule()->getDataLayout(); return BCECmp(std::move(Lhs), std::move(Rhs), DL.getTypeSizeInBits(CmpI->getOperand(0)->getType()), CmpI); @@ -330,12 +330,15 @@ Optional<BCECmp> visitICmp(const ICmpInst *const CmpI, // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. -Optional<BCECmpBlock> visitCmpBlock(Value *const Val, BasicBlock *const Block, - const BasicBlock *const PhiBlock, - BaseIdentifier &BaseId) { - if (Block->empty()) return None; +std::optional<BCECmpBlock> visitCmpBlock(Value *const Val, + BasicBlock *const Block, + const BasicBlock *const PhiBlock, + BaseIdentifier &BaseId) { + if (Block->empty()) + return std::nullopt; auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); - if (!BranchI) return None; + if (!BranchI) + return std::nullopt; LLVM_DEBUG(dbgs() << "branch\n"); Value *Cond; ICmpInst::Predicate ExpectedPredicate; @@ -351,7 +354,8 @@ Optional<BCECmpBlock> visitCmpBlock(Value *const Val, BasicBlock *const Block, // chained). const auto *const Const = cast<ConstantInt>(Val); LLVM_DEBUG(dbgs() << "const\n"); - if (!Const->isZero()) return None; + if (!Const->isZero()) + return std::nullopt; LLVM_DEBUG(dbgs() << "false\n"); assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); BasicBlock *const FalseBlock = BranchI->getSuccessor(1); @@ -361,12 +365,13 @@ Optional<BCECmpBlock> visitCmpBlock(Value *const Val, BasicBlock *const Block, } auto *CmpI = dyn_cast<ICmpInst>(Cond); - if (!CmpI) return None; + if (!CmpI) + return std::nullopt; LLVM_DEBUG(dbgs() << "icmp\n"); - Optional<BCECmp> Result = visitICmp(CmpI, ExpectedPredicate, BaseId); + std::optional<BCECmp> Result = visitICmp(CmpI, ExpectedPredicate, BaseId); if (!Result) - return None; + return std::nullopt; BCECmpBlock::InstructionSet BlockInsts( {Result->Lhs.LoadI, Result->Rhs.LoadI, Result->CmpI, BranchI}); @@ -472,7 +477,7 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, BaseIdentifier BaseId; for (BasicBlock *const Block : Blocks) { assert(Block && "invalid block"); - Optional<BCECmpBlock> Comparison = visitCmpBlock( + std::optional<BCECmpBlock> Comparison = visitCmpBlock( Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId); if (!Comparison) { LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); @@ -645,14 +650,18 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, Comparisons.begin(), Comparisons.end(), 0u, [](int Size, const BCECmpBlock &C) { return Size + C.SizeBits(); }); + // memcmp expects a 'size_t' argument and returns 'int'. + unsigned SizeTBits = TLI.getSizeTSize(*Phi.getModule()); + unsigned IntBits = TLI.getIntSize(); + // Create memcmp() == 0. 
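Note: the width changes in mergeComparisons() matter because memcmp takes a size_t length and returns int, so the call built below uses TLI.getSizeTSize() and TLI.getIntSize() instead of the pointer-index type and a hard-coded i32. At the source level, the pass turns a chain of adjacent field comparisons into one such call; a hedged illustration, valid here only because the struct has no padding bytes:

    #include <cstring>
    #include <cstdio>

    struct Point {
      int X, Y, Z; // contiguous, no padding
    };

    // A chain of equality comparisons over adjacent fields; chains like this
    // can be replaced with a single memcmp(&A, &B, sizeof(Point)) == 0.
    bool equalByFields(const Point &A, const Point &B) {
      return A.X == B.X && A.Y == B.Y && A.Z == B.Z;
    }

    bool equalByMemcmp(const Point &A, const Point &B) {
      return std::memcmp(&A, &B, sizeof(Point)) == 0; // size_t length, int result
    }

    int main() {
      Point A{1, 2, 3}, B{1, 2, 3};
      std::printf("%d %d\n", equalByFields(A, B), equalByMemcmp(A, B)); // 1 1
    }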
const auto &DL = Phi.getModule()->getDataLayout(); Value *const MemCmpCall = emitMemCmp( Lhs, Rhs, - ConstantInt::get(DL.getIntPtrType(Context), TotalSizeBits / 8), Builder, - DL, &TLI); + ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8), + Builder, DL, &TLI); IsEqual = Builder.CreateICmpEQ( - MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); + MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0)); } BasicBlock *const PhiBB = Phi.getParent(); diff --git a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 6383d6ea838b..62e75d98448c 100644 --- a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -220,27 +220,29 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, } /// -/// Check if 2 stores can be sunk together with corresponding GEPs +/// Check if 2 stores can be sunk, optionally together with corresponding GEPs. /// bool MergedLoadStoreMotion::canSinkStoresAndGEPs(StoreInst *S0, StoreInst *S1) const { - auto *A0 = dyn_cast<Instruction>(S0->getPointerOperand()); - auto *A1 = dyn_cast<Instruction>(S1->getPointerOperand()); - return A0 && A1 && A0->isIdenticalTo(A1) && A0->hasOneUse() && - (A0->getParent() == S0->getParent()) && A1->hasOneUse() && - (A1->getParent() == S1->getParent()) && isa<GetElementPtrInst>(A0); + if (S0->getPointerOperand() == S1->getPointerOperand()) + return true; + auto *GEP0 = dyn_cast<GetElementPtrInst>(S0->getPointerOperand()); + auto *GEP1 = dyn_cast<GetElementPtrInst>(S1->getPointerOperand()); + return GEP0 && GEP1 && GEP0->isIdenticalTo(GEP1) && GEP0->hasOneUse() && + (GEP0->getParent() == S0->getParent()) && GEP1->hasOneUse() && + (GEP1->getParent() == S1->getParent()); } /// /// Merge two stores to same address and sink into \p BB /// -/// Also sinks GEP instruction computing the store address +/// Optionally also sinks GEP instruction computing the store address /// void MergedLoadStoreMotion::sinkStoresAndGEPs(BasicBlock *BB, StoreInst *S0, StoreInst *S1) { + Value *Ptr0 = S0->getPointerOperand(); + Value *Ptr1 = S1->getPointerOperand(); // Only one definition? - auto *A0 = dyn_cast<Instruction>(S0->getPointerOperand()); - auto *A1 = dyn_cast<Instruction>(S1->getPointerOperand()); LLVM_DEBUG(dbgs() << "Sink Instruction into BB \n"; BB->dump(); dbgs() << "Instruction Left\n"; S0->dump(); dbgs() << "\n"; dbgs() << "Instruction Right\n"; S1->dump(); dbgs() << "\n"); @@ -249,25 +251,30 @@ void MergedLoadStoreMotion::sinkStoresAndGEPs(BasicBlock *BB, StoreInst *S0, // Intersect optional metadata. S0->andIRFlags(S1); S0->dropUnknownNonDebugMetadata(); + S0->applyMergedLocation(S0->getDebugLoc(), S1->getDebugLoc()); + S0->mergeDIAssignID(S1); // Create the new store to be inserted at the join point. StoreInst *SNew = cast<StoreInst>(S0->clone()); - Instruction *ANew = A0->clone(); SNew->insertBefore(&*InsertPt); - ANew->insertBefore(SNew); - - assert(S0->getParent() == A0->getParent()); - assert(S1->getParent() == A1->getParent()); - // New PHI operand? Use it. 
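Note: canSinkStoresAndGEPs() now also accepts the case where both stores use the same pointer value directly, not just identical single-use GEPs; sinkStoresAndGEPs() then sinks one store into the join block, feeding it from a phi of the stored values and merging debug locations and DIAssignID. A hedged source-level picture of the shape that gets merged:

    #include <cstdio>

    // Both arms store to the same address (or through identical GEPs); the
    // pass can sink a single store into the join block and feed it from a
    // phi of the two stored values.
    void select(int *P, bool C, int A, int B) {
      if (C)
        *P = A;
      else
        *P = B;
    }

    int main() {
      int X = 0;
      select(&X, true, 7, 9);
      std::printf("%d\n", X); // prints 7
    }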
if (PHINode *NewPN = getPHIOperand(BB, S0, S1)) SNew->setOperand(0, NewPN); S0->eraseFromParent(); S1->eraseFromParent(); - A0->replaceAllUsesWith(ANew); - A0->eraseFromParent(); - A1->replaceAllUsesWith(ANew); - A1->eraseFromParent(); + + if (Ptr0 != Ptr1) { + auto *GEP0 = cast<GetElementPtrInst>(Ptr0); + auto *GEP1 = cast<GetElementPtrInst>(Ptr1); + Instruction *GEPNew = GEP0->clone(); + GEPNew->insertBefore(SNew); + GEPNew->applyMergedLocation(GEP0->getDebugLoc(), GEP1->getDebugLoc()); + SNew->setOperand(1, GEPNew); + GEP0->replaceAllUsesWith(GEPNew); + GEP0->eraseFromParent(); + GEP1->replaceAllUsesWith(GEPNew); + GEP1->eraseFromParent(); + } } /// diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index 6dca30d9876e..19bee4fa3879 100644 --- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -403,8 +403,9 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, // Replace the I-th index with LHS. IndexExprs[I] = SE->getSCEV(LHS); if (isKnownNonNegative(LHS, *DL, 0, AC, GEP, DT) && - DL->getTypeSizeInBits(LHS->getType()).getFixedSize() < - DL->getTypeSizeInBits(GEP->getOperand(I)->getType()).getFixedSize()) { + DL->getTypeSizeInBits(LHS->getType()).getFixedValue() < + DL->getTypeSizeInBits(GEP->getOperand(I)->getType()) + .getFixedValue()) { // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to // zext if the source operand is proved non-negative. We should do that // consistently so that CandidateExpr more likely appears before. See @@ -576,13 +577,13 @@ NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr, } template <typename MaxMinT> static SCEVTypes convertToSCEVype(MaxMinT &MM) { - if (std::is_same<smax_pred_ty, typename MaxMinT::PredType>::value) + if (std::is_same_v<smax_pred_ty, typename MaxMinT::PredType>) return scSMaxExpr; - else if (std::is_same<umax_pred_ty, typename MaxMinT::PredType>::value) + else if (std::is_same_v<umax_pred_ty, typename MaxMinT::PredType>) return scUMaxExpr; - else if (std::is_same<smin_pred_ty, typename MaxMinT::PredType>::value) + else if (std::is_same_v<smin_pred_ty, typename MaxMinT::PredType>) return scSMinExpr; - else if (std::is_same<umin_pred_ty, typename MaxMinT::PredType>::value) + else if (std::is_same_v<umin_pred_ty, typename MaxMinT::PredType>) return scUMinExpr; llvm_unreachable("Can't convert MinMax pattern to SCEV type"); diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp index 876ef3c427a6..d3dba0c5f1d5 100644 --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -207,7 +207,7 @@ private: Root[I] = ++DFSNum; // Store the DFS Number we had before it possibly gets incremented. 
unsigned int OurDFS = DFSNum; - for (auto &Op : I->operands()) { + for (const auto &Op : I->operands()) { if (auto *InstOp = dyn_cast<Instruction>(Op)) { if (Root.lookup(Op) == 0) FindSCC(InstOp); @@ -766,9 +766,6 @@ private: SmallPtrSetImpl<Value *> &Visited, MemoryAccess *MemAccess, Instruction *OrigInst, BasicBlock *PredBB); - bool OpIsSafeForPHIOfOpsHelper(Value *V, const BasicBlock *PHIBlock, - SmallPtrSetImpl<const Value *> &Visited, - SmallVectorImpl<Instruction *> &Worklist); bool OpIsSafeForPHIOfOps(Value *Op, const BasicBlock *PHIBlock, SmallPtrSetImpl<const Value *> &); void addPhiOfOps(PHINode *Op, BasicBlock *BB, Instruction *ExistingValue); @@ -1203,10 +1200,9 @@ NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const { if (auto Simplified = checkExprResults(E, I, V)) return Simplified; } else if (auto *GEPI = dyn_cast<GetElementPtrInst>(I)) { - Value *V = - simplifyGEPInst(GEPI->getSourceElementType(), *E->op_begin(), - makeArrayRef(std::next(E->op_begin()), E->op_end()), - GEPI->isInBounds(), Q); + Value *V = simplifyGEPInst(GEPI->getSourceElementType(), *E->op_begin(), + ArrayRef(std::next(E->op_begin()), E->op_end()), + GEPI->isInBounds(), Q); if (auto Simplified = checkExprResults(E, I, V)) return Simplified; } else if (AllConstant) { @@ -1566,7 +1562,7 @@ NewGVN::performSymbolicPredicateInfoEvaluation(IntrinsicInst *I) const { LLVM_DEBUG(dbgs() << "Found predicate info from instruction !\n"); - const Optional<PredicateConstraint> &Constraint = PI->getConstraint(); + const std::optional<PredicateConstraint> &Constraint = PI->getConstraint(); if (!Constraint) return ExprResult::none(); @@ -1610,6 +1606,17 @@ NewGVN::ExprResult NewGVN::performSymbolicCallEvaluation(Instruction *I) const { return ExprResult::some(createVariableOrConstant(ReturnedValue)); } } + + // FIXME: Currently the calls which may access the thread id may + // be considered as not accessing the memory. But this is + // problematic for coroutines, since coroutines may resume in a + // different thread. So we disable the optimization here for the + // correctness. However, it may block many other correct + // optimizations. Revert this one when we detect the memory + // accessing kind more precisely. + if (CI->getFunction()->isPresplitCoroutine()) + return ExprResult::none(); + if (AA->doesNotAccessMemory(CI)) { return ExprResult::some( createCallExpression(CI, TOPClass->getMemoryLeader())); @@ -1699,7 +1706,7 @@ bool NewGVN::isCycleFree(const Instruction *I) const { return isa<PHINode>(V) || isCopyOfAPHI(V); }); ICS = AllPhis ? ICS_CycleFree : ICS_Cycle; - for (auto *Member : SCC) + for (const auto *Member : SCC) if (auto *MemberPhi = dyn_cast<PHINode>(Member)) InstCycleState.insert({MemberPhi, ICS}); } @@ -2090,7 +2097,7 @@ void NewGVN::markMemoryDefTouched(const MemoryAccess *MA) { void NewGVN::markMemoryUsersTouched(const MemoryAccess *MA) { if (isa<MemoryUse>(MA)) return; - for (auto U : MA->users()) + for (const auto *U : MA->users()) TouchedInstructions.set(MemoryToDFSNum(U)); touchAndErase(MemoryToUsers, MA); } @@ -2102,14 +2109,14 @@ void NewGVN::markPredicateUsersTouched(Instruction *I) { // Mark users affected by a memory leader change. void NewGVN::markMemoryLeaderChangeTouched(CongruenceClass *CC) { - for (auto M : CC->memory()) + for (const auto *M : CC->memory()) markMemoryDefTouched(M); } // Touch the instructions that need to be updated after a congruence class has a // leader change, and mark changed values. 
void NewGVN::markValueLeaderChangeTouched(CongruenceClass *CC) { - for (auto M : *CC) { + for (auto *M : *CC) { if (auto *I = dyn_cast<Instruction>(M)) TouchedInstructions.set(InstrToDFSNum(I)); LeaderChanges.insert(M); @@ -2574,58 +2581,6 @@ static bool okayForPHIOfOps(const Instruction *I) { isa<LoadInst>(I); } -bool NewGVN::OpIsSafeForPHIOfOpsHelper( - Value *V, const BasicBlock *PHIBlock, - SmallPtrSetImpl<const Value *> &Visited, - SmallVectorImpl<Instruction *> &Worklist) { - - if (!isa<Instruction>(V)) - return true; - auto OISIt = OpSafeForPHIOfOps.find(V); - if (OISIt != OpSafeForPHIOfOps.end()) - return OISIt->second; - - // Keep walking until we either dominate the phi block, or hit a phi, or run - // out of things to check. - if (DT->properlyDominates(getBlockForValue(V), PHIBlock)) { - OpSafeForPHIOfOps.insert({V, true}); - return true; - } - // PHI in the same block. - if (isa<PHINode>(V) && getBlockForValue(V) == PHIBlock) { - OpSafeForPHIOfOps.insert({V, false}); - return false; - } - - auto *OrigI = cast<Instruction>(V); - // When we hit an instruction that reads memory (load, call, etc), we must - // consider any store that may happen in the loop. For now, we assume the - // worst: there is a store in the loop that alias with this read. - // The case where the load is outside the loop is already covered by the - // dominator check above. - // TODO: relax this condition - if (OrigI->mayReadFromMemory()) - return false; - - for (auto *Op : OrigI->operand_values()) { - if (!isa<Instruction>(Op)) - continue; - // Stop now if we find an unsafe operand. - auto OISIt = OpSafeForPHIOfOps.find(OrigI); - if (OISIt != OpSafeForPHIOfOps.end()) { - if (!OISIt->second) { - OpSafeForPHIOfOps.insert({V, false}); - return false; - } - continue; - } - if (!Visited.insert(Op).second) - continue; - Worklist.push_back(cast<Instruction>(Op)); - } - return true; -} - // Return true if this operand will be safe to use for phi of ops. // // The reason some operands are unsafe is that we are not trying to recursively @@ -2635,13 +2590,56 @@ bool NewGVN::OpIsSafeForPHIOfOpsHelper( // be determined to be constant. bool NewGVN::OpIsSafeForPHIOfOps(Value *V, const BasicBlock *PHIBlock, SmallPtrSetImpl<const Value *> &Visited) { - SmallVector<Instruction *, 4> Worklist; - if (!OpIsSafeForPHIOfOpsHelper(V, PHIBlock, Visited, Worklist)) - return false; + SmallVector<Value *, 4> Worklist; + Worklist.push_back(V); while (!Worklist.empty()) { auto *I = Worklist.pop_back_val(); - if (!OpIsSafeForPHIOfOpsHelper(I, PHIBlock, Visited, Worklist)) + if (!isa<Instruction>(I)) + continue; + + auto OISIt = OpSafeForPHIOfOps.find(I); + if (OISIt != OpSafeForPHIOfOps.end()) + return OISIt->second; + + // Keep walking until we either dominate the phi block, or hit a phi, or run + // out of things to check. + if (DT->properlyDominates(getBlockForValue(I), PHIBlock)) { + OpSafeForPHIOfOps.insert({I, true}); + continue; + } + // PHI in the same block. + if (isa<PHINode>(I) && getBlockForValue(I) == PHIBlock) { + OpSafeForPHIOfOps.insert({I, false}); + return false; + } + + auto *OrigI = cast<Instruction>(I); + // When we hit an instruction that reads memory (load, call, etc), we must + // consider any store that may happen in the loop. For now, we assume the + // worst: there is a store in the loop that alias with this read. + // The case where the load is outside the loop is already covered by the + // dominator check above. 
+ // TODO: relax this condition + if (OrigI->mayReadFromMemory()) return false; + + // Check the operands of the current instruction. + for (auto *Op : OrigI->operand_values()) { + if (!isa<Instruction>(Op)) + continue; + // Stop now if we find an unsafe operand. + auto OISIt = OpSafeForPHIOfOps.find(OrigI); + if (OISIt != OpSafeForPHIOfOps.end()) { + if (!OISIt->second) { + OpSafeForPHIOfOps.insert({I, false}); + return false; + } + continue; + } + if (!Visited.insert(Op).second) + continue; + Worklist.push_back(cast<Instruction>(Op)); + } } OpSafeForPHIOfOps.insert({V, true}); return true; @@ -2798,7 +2796,7 @@ NewGVN::makePossiblePHIOfOps(Instruction *I, // We failed to find a leader for the current ValueOp, but this might // change in case of the translated operands change. if (SafeForPHIOfOps) - for (auto Dep : CurrentDeps) + for (auto *Dep : CurrentDeps) addAdditionalUsers(Dep, I); return nullptr; @@ -2816,7 +2814,7 @@ NewGVN::makePossiblePHIOfOps(Instruction *I, LLVM_DEBUG(dbgs() << "Found phi of ops operand " << *FoundVal << " in " << getBlockName(PredBB) << "\n"); } - for (auto Dep : Deps) + for (auto *Dep : Deps) addAdditionalUsers(Dep, I); sortPHIOps(PHIOps); auto *E = performSymbolicPHIEvaluation(PHIOps, I, PHIBlock); @@ -2883,7 +2881,7 @@ void NewGVN::initializeCongruenceClasses(Function &F) { MemoryAccessToClass[MSSA->getLiveOnEntryDef()] = createMemoryClass(MSSA->getLiveOnEntryDef()); - for (auto DTN : nodes(DT)) { + for (auto *DTN : nodes(DT)) { BasicBlock *BB = DTN->getBlock(); // All MemoryAccesses are equivalent to live on entry to start. They must // be initialized to something so that initial changes are noticed. For @@ -2929,14 +2927,13 @@ void NewGVN::initializeCongruenceClasses(Function &F) { } void NewGVN::cleanupTables() { - for (unsigned i = 0, e = CongruenceClasses.size(); i != e; ++i) { - LLVM_DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->getID() - << " has " << CongruenceClasses[i]->size() - << " members\n"); + for (CongruenceClass *&CC : CongruenceClasses) { + LLVM_DEBUG(dbgs() << "Congruence class " << CC->getID() << " has " + << CC->size() << " members\n"); // Make sure we delete the congruence class (probably worth switching to // a unique_ptr at some point. 
- delete CongruenceClasses[i]; - CongruenceClasses[i] = nullptr; + delete CC; + CC = nullptr; } // Destroy the value expressions @@ -3151,7 +3148,7 @@ bool NewGVN::singleReachablePHIPath( return true; const auto *EndDef = First; - for (auto *ChainDef : optimized_def_chain(First)) { + for (const auto *ChainDef : optimized_def_chain(First)) { if (ChainDef == Second) return true; if (MSSA->isLiveOnEntryDef(ChainDef)) @@ -3166,7 +3163,7 @@ bool NewGVN::singleReachablePHIPath( make_filter_range(MP->operands(), ReachableOperandPred); SmallVector<const Value *, 32> OperandList; llvm::copy(FilteredPhiArgs, std::back_inserter(OperandList)); - bool Okay = is_splat(OperandList); + bool Okay = all_equal(OperandList); if (Okay) return singleReachablePHIPath(Visited, cast<MemoryAccess>(OperandList[0]), Second); @@ -3196,7 +3193,7 @@ void NewGVN::verifyMemoryCongruency() const { assert(MemoryAccessToClass.lookup(CC->getMemoryLeader()) == CC && "Representative MemoryAccess does not appear to be reverse " "mapped properly"); - for (auto M : CC->memory()) + for (const auto *M : CC->memory()) assert(MemoryAccessToClass.lookup(M) == CC && "Memory member does not appear to be reverse mapped properly"); } @@ -3218,7 +3215,7 @@ void NewGVN::verifyMemoryCongruency() const { // We could have phi nodes which operands are all trivially dead, // so we don't process them. if (auto *MemPHI = dyn_cast<MemoryPhi>(Pair.first)) { - for (auto &U : MemPHI->incoming_values()) { + for (const auto &U : MemPHI->incoming_values()) { if (auto *I = dyn_cast<Instruction>(&*U)) { if (!isInstructionTriviallyDead(I)) return true; @@ -3261,7 +3258,7 @@ void NewGVN::verifyMemoryCongruency() const { const MemoryDef *MD = cast<MemoryDef>(U); return ValueToClass.lookup(MD->getMemoryInst()); }); - assert(is_splat(PhiOpClasses) && + assert(all_equal(PhiOpClasses) && "All MemoryPhi arguments should be in the same class"); } } @@ -3293,6 +3290,7 @@ void NewGVN::verifyIterationSettled(Function &F) { TouchedInstructions.set(); TouchedInstructions.reset(0); + OpSafeForPHIOfOps.clear(); iterateTouchedInstructions(); DenseSet<std::pair<const CongruenceClass *, const CongruenceClass *>> EqualClasses; @@ -3455,7 +3453,7 @@ bool NewGVN::runGVN() { } // Now a standard depth first ordering of the domtree is equivalent to RPO. - for (auto DTN : depth_first(DT->getRootNode())) { + for (auto *DTN : depth_first(DT->getRootNode())) { BasicBlock *B = DTN->getBlock(); const auto &BlockRange = assignDFSNumbers(B, ICount); BlockInstRange.insert({B, BlockRange}); @@ -3575,7 +3573,7 @@ void NewGVN::convertClassToDFSOrdered( const CongruenceClass &Dense, SmallVectorImpl<ValueDFS> &DFSOrderedSet, DenseMap<const Value *, unsigned int> &UseCounts, SmallPtrSetImpl<Instruction *> &ProbablyDead) const { - for (auto D : Dense) { + for (auto *D : Dense) { // First add the value. 
BasicBlock *BB = getBlockForValue(D); // Constants are handled prior to ever calling this function, so @@ -3665,7 +3663,7 @@ void NewGVN::convertClassToDFSOrdered( void NewGVN::convertClassToLoadsAndStores( const CongruenceClass &Dense, SmallVectorImpl<ValueDFS> &LoadsAndStores) const { - for (auto D : Dense) { + for (auto *D : Dense) { if (!isa<LoadInst>(D) && !isa<StoreInst>(D)) continue; @@ -3803,7 +3801,7 @@ Value *NewGVN::findPHIOfOpsLeader(const Expression *E, if (alwaysAvailable(CC->getLeader())) return CC->getLeader(); - for (auto Member : *CC) { + for (auto *Member : *CC) { auto *MemberInst = dyn_cast<Instruction>(Member); if (MemberInst == OrigInst) continue; @@ -3896,7 +3894,7 @@ bool NewGVN::eliminateInstructions(Function &F) { continue; // Everything still in the TOP class is unreachable or dead. if (CC == TOPClass) { - for (auto M : *CC) { + for (auto *M : *CC) { auto *VTE = ValueToExpression.lookup(M); if (VTE && isa<DeadExpression>(VTE)) markInstructionForDeletion(cast<Instruction>(M)); @@ -3917,7 +3915,7 @@ bool NewGVN::eliminateInstructions(Function &F) { CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader(); if (alwaysAvailable(Leader)) { CongruenceClass::MemberSet MembersLeft; - for (auto M : *CC) { + for (auto *M : *CC) { Value *Member = M; // Void things have no uses we can replace. if (Member == Leader || !isa<Instruction>(Member) || diff --git a/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 689a2a286cb9..3a699df1cde4 100644 --- a/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -22,6 +22,7 @@ #include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <optional> using namespace llvm; @@ -80,10 +81,9 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, Instruction *LibCall = Call->clone(); Builder.Insert(LibCall); - // Add attribute "readnone" so that backend can use a native sqrt instruction - // for this call. - Call->removeFnAttr(Attribute::WriteOnly); - Call->addFnAttr(Attribute::ReadNone); + // Add memory(none) attribute, so that the backend can use a native sqrt + // instruction for this call. + Call->setDoesNotAccessMemory(); // Insert a FP compare instruction and use it as the CurrBB branch condition. Builder.SetInsertPoint(CurrBBTerm); @@ -104,7 +104,7 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, DominatorTree *DT) { - Optional<DomTreeUpdater> DTU; + std::optional<DomTreeUpdater> DTU; if (DT) DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy); @@ -140,7 +140,7 @@ static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI, case LibFunc_sqrt: if (TTI->haveFastSqrt(Call->getType()) && optimizeSQRT(Call, CalledFunc, *CurrBB, BB, TTI, - DTU ? DTU.getPointer() : nullptr)) + DTU ? &*DTU : nullptr)) break; continue; default: diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index cd2ce8ce336e..21628b61edd6 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -833,9 +833,14 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, /// additional opportunities have been exposed. 
static Value *NegateValue(Value *V, Instruction *BI, ReassociatePass::OrderedSet &ToRedo) { - if (auto *C = dyn_cast<Constant>(V)) - return C->getType()->isFPOrFPVectorTy() ? ConstantExpr::getFNeg(C) : - ConstantExpr::getNeg(C); + if (auto *C = dyn_cast<Constant>(V)) { + const DataLayout &DL = BI->getModule()->getDataLayout(); + Constant *Res = C->getType()->isFPOrFPVectorTy() + ? ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL) + : ConstantExpr::getNeg(C); + if (Res) + return Res; + } // We are trying to expose opportunity for reassociation. One of the things // that we want to do to achieve this is to push a negation as deep into an @@ -880,46 +885,29 @@ static Value *NegateValue(Value *V, Instruction *BI, // this use. We do this by moving it to the entry block (if it is a // non-instruction value) or right after the definition. These negates will // be zapped by reassociate later, so we don't need much finesse here. - Instruction *TheNeg = cast<Instruction>(U); + Instruction *TheNeg = dyn_cast<Instruction>(U); - // Verify that the negate is in this function, V might be a constant expr. - if (TheNeg->getParent()->getParent() != BI->getParent()->getParent()) + // We can't safely propagate a vector zero constant with poison/undef lanes. + Constant *C; + if (match(TheNeg, m_BinOp(m_Constant(C), m_Value())) && + C->containsUndefOrPoisonElement()) continue; - bool FoundCatchSwitch = false; + // Verify that the negate is in this function, V might be a constant expr. + if (!TheNeg || + TheNeg->getParent()->getParent() != BI->getParent()->getParent()) + continue; - BasicBlock::iterator InsertPt; + Instruction *InsertPt; if (Instruction *InstInput = dyn_cast<Instruction>(V)) { - if (InvokeInst *II = dyn_cast<InvokeInst>(InstInput)) { - InsertPt = II->getNormalDest()->begin(); - } else { - InsertPt = ++InstInput->getIterator(); - } - - const BasicBlock *BB = InsertPt->getParent(); - - // Make sure we don't move anything before PHIs or exception - // handling pads. - while (InsertPt != BB->end() && (isa<PHINode>(InsertPt) || - InsertPt->isEHPad())) { - if (isa<CatchSwitchInst>(InsertPt)) - // A catchswitch cannot have anything in the block except - // itself and PHIs. We'll bail out below. - FoundCatchSwitch = true; - ++InsertPt; - } + InsertPt = InstInput->getInsertionPointAfterDef(); + if (!InsertPt) + continue; } else { - InsertPt = TheNeg->getParent()->getParent()->getEntryBlock().begin(); + InsertPt = &*TheNeg->getFunction()->getEntryBlock().begin(); } - // We found a catchswitch in the block where we want to move the - // neg. We cannot move anything into that block. Bail and just - // create the neg before BI, as if we hadn't found an existing - // neg. - if (FoundCatchSwitch) - break; - - TheNeg->moveBefore(&*InsertPt); + TheNeg->moveBefore(InsertPt); if (TheNeg->getOpcode() == Instruction::Sub) { TheNeg->setHasNoUnsignedWrap(false); TheNeg->setHasNoSignedWrap(false); @@ -1898,10 +1886,10 @@ ReassociatePass::buildMinimalMultiplyDAG(IRBuilderBase &Builder, // Iteratively collect the base of each factor with an add power into the // outer product, and halve each power in preparation for squaring the // expression. 
- for (unsigned Idx = 0, Size = Factors.size(); Idx != Size; ++Idx) { - if (Factors[Idx].Power & 1) - OuterProduct.push_back(Factors[Idx].Base); - Factors[Idx].Power >>= 1; + for (Factor &F : Factors) { + if (F.Power & 1) + OuterProduct.push_back(F.Base); + F.Power >>= 1; } if (Factors[0].Power) { Value *SquareRoot = buildMinimalMultiplyDAG(Builder, Factors); @@ -2027,7 +2015,7 @@ void ReassociatePass::RecursivelyEraseDeadInsts(Instruction *I, RedoInsts.remove(I); llvm::salvageDebugInfo(*I); I->eraseFromParent(); - for (auto Op : Ops) + for (auto *Op : Ops) if (Instruction *OpInst = dyn_cast<Instruction>(Op)) if (OpInst->use_empty()) Insts.insert(OpInst); diff --git a/llvm/lib/Transforms/Scalar/Reg2Mem.cpp b/llvm/lib/Transforms/Scalar/Reg2Mem.cpp index 9dc64493a9ee..db7a1f24660c 100644 --- a/llvm/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/llvm/lib/Transforms/Scalar/Reg2Mem.cpp @@ -40,6 +40,9 @@ STATISTIC(NumRegsDemoted, "Number of registers demoted"); STATISTIC(NumPhisDemoted, "Number of phi-nodes demoted"); static bool valueEscapes(const Instruction &Inst) { + if (!Inst.getType()->isSized()) + return false; + const BasicBlock *BB = Inst.getParent(); for (const User *U : Inst.users()) { const Instruction *UI = cast<Instruction>(U); diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index baf407c5037b..bcb012b79c2e 100644 --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -17,8 +17,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" @@ -71,6 +69,7 @@ #include <cstddef> #include <cstdint> #include <iterator> +#include <optional> #include <set> #include <string> #include <utility> @@ -110,6 +109,9 @@ static cl::opt<bool> AllowStatepointWithNoDeoptInfo("rs4gc-allow-statepoint-with-no-deopt-info", cl::Hidden, cl::init(true)); +static cl::opt<bool> RematDerivedAtUses("rs4gc-remat-derived-at-uses", + cl::Hidden, cl::init(true)); + /// The IR fed into RewriteStatepointsForGC may have had attributes and /// metadata implying dereferenceability that are no longer valid/correct after /// RewriteStatepointsForGC has run. This is because semantically, after @@ -295,13 +297,13 @@ using RematCandTy = MapVector<Value *, RematerizlizationCandidateRecord>; } // end anonymous namespace static ArrayRef<Use> GetDeoptBundleOperands(const CallBase *Call) { - Optional<OperandBundleUse> DeoptBundle = + std::optional<OperandBundleUse> DeoptBundle = Call->getOperandBundle(LLVMContext::OB_deopt); if (!DeoptBundle) { assert(AllowStatepointWithNoDeoptInfo && "Found non-leaf call without deopt info!"); - return None; + return std::nullopt; } return DeoptBundle->Inputs; @@ -317,7 +319,7 @@ static void findLiveSetAtInst(Instruction *inst, GCPtrLivenessData &Data, StatepointLiveSetTy &out); // TODO: Once we can get to the GCStrategy, this becomes -// Optional<bool> isGCManagedPointer(const Type *Ty) const override { +// std::optional<bool> isGCManagedPointer(const Type *Ty) const override { static bool isGCPointerType(Type *T) { if (auto *PT = dyn_cast<PointerType>(T)) @@ -1400,6 +1402,61 @@ static void recomputeLiveInValues( } } +// Utility function which clones all instructions from "ChainToBase" +// and inserts them before "InsertBefore". 
Returns rematerialized value +// which should be used after statepoint. +static Instruction *rematerializeChain(ArrayRef<Instruction *> ChainToBase, + Instruction *InsertBefore, + Value *RootOfChain, + Value *AlternateLiveBase) { + Instruction *LastClonedValue = nullptr; + Instruction *LastValue = nullptr; + // Walk backwards to visit top-most instructions first. + for (Instruction *Instr : + make_range(ChainToBase.rbegin(), ChainToBase.rend())) { + // Only GEP's and casts are supported as we need to be careful to not + // introduce any new uses of pointers not in the liveset. + // Note that it's fine to introduce new uses of pointers which were + // otherwise not used after this statepoint. + assert(isa<GetElementPtrInst>(Instr) || isa<CastInst>(Instr)); + + Instruction *ClonedValue = Instr->clone(); + ClonedValue->insertBefore(InsertBefore); + ClonedValue->setName(Instr->getName() + ".remat"); + + // If it is not first instruction in the chain then it uses previously + // cloned value. We should update it to use cloned value. + if (LastClonedValue) { + assert(LastValue); + ClonedValue->replaceUsesOfWith(LastValue, LastClonedValue); +#ifndef NDEBUG + for (auto *OpValue : ClonedValue->operand_values()) { + // Assert that cloned instruction does not use any instructions from + // this chain other than LastClonedValue + assert(!is_contained(ChainToBase, OpValue) && + "incorrect use in rematerialization chain"); + // Assert that the cloned instruction does not use the RootOfChain + // or the AlternateLiveBase. + assert(OpValue != RootOfChain && OpValue != AlternateLiveBase); + } +#endif + } else { + // For the first instruction, replace the use of unrelocated base i.e. + // RootOfChain/OrigRootPhi, with the corresponding PHI present in the + // live set. They have been proved to be the same PHI nodes. Note + // that the *only* use of the RootOfChain in the ChainToBase list is + // the first Value in the list. + if (RootOfChain != AlternateLiveBase) + ClonedValue->replaceUsesOfWith(RootOfChain, AlternateLiveBase); + } + + LastClonedValue = ClonedValue; + LastValue = Instr; + } + assert(LastClonedValue); + return LastClonedValue; +} + // When inserting gc.relocate and gc.result calls, we need to ensure there are // no uses of the original value / return value between the gc.statepoint and // the gc.relocate / gc.result call. One case which can arise is a phi node @@ -1430,10 +1487,7 @@ normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent, // machine model for purposes of optimization. We have to strip these on // both function declarations and call sites. static constexpr Attribute::AttrKind FnAttrsToStrip[] = - {Attribute::ReadNone, Attribute::ReadOnly, Attribute::WriteOnly, - Attribute::ArgMemOnly, Attribute::InaccessibleMemOnly, - Attribute::InaccessibleMemOrArgMemOnly, - Attribute::NoSync, Attribute::NoFree}; + {Attribute::Memory, Attribute::NoSync, Attribute::NoFree}; // Create new attribute set containing only attributes which can be transferred // from original call to the safepoint. 
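The shrinking of FnAttrsToStrip above reflects the consolidation of the old per-kind memory attributes (readnone, readonly, writeonly, argmemonly, inaccessiblememonly, inaccessiblemem_or_argmemonly) into the single memory(...) attribute: conservatively dropping a function's memory-effect claims is now a single-attribute removal. A minimal sketch of that pattern, assuming the LLVM C++ API of this era; the helper name and the loop over direct call sites are illustrative and not part of the patch:

#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;

// Conservatively forget memory-effect, nosync and nofree claims on F and on
// its direct call sites; after the statepoint rewrite these claims may no
// longer hold. (Illustrative sketch, not the in-tree implementation.)
static void clearUnsafeFnAttrs(Function &F) {
  for (Attribute::AttrKind K :
       {Attribute::Memory, Attribute::NoSync, Attribute::NoFree}) {
    F.removeFnAttr(K);
    for (User *U : F.users())
      if (auto *CB = dyn_cast<CallBase>(U))
        if (CB->getCalledFunction() == &F)
          CB->removeFnAttr(K);
  }
}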
@@ -1629,10 +1683,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ uint32_t Flags = uint32_t(StatepointFlags::None); SmallVector<Value *, 8> CallArgs(Call->args()); - Optional<ArrayRef<Use>> DeoptArgs; + std::optional<ArrayRef<Use>> DeoptArgs; if (auto Bundle = Call->getOperandBundle(LLVMContext::OB_deopt)) DeoptArgs = Bundle->Inputs; - Optional<ArrayRef<Use>> TransitionArgs; + std::optional<ArrayRef<Use>> TransitionArgs; if (auto Bundle = Call->getOperandBundle(LLVMContext::OB_gc_transition)) { TransitionArgs = Bundle->Inputs; // TODO: This flag no longer serves a purpose and can be removed later @@ -2082,8 +2136,12 @@ static void relocationViaAlloca( auto InsertClobbersAt = [&](Instruction *IP) { for (auto *AI : ToClobber) { - auto PT = cast<PointerType>(AI->getAllocatedType()); - Constant *CPN = ConstantPointerNull::get(PT); + auto AT = AI->getAllocatedType(); + Constant *CPN; + if (AT->isVectorTy()) + CPN = ConstantAggregateZero::get(AT); + else + CPN = ConstantPointerNull::get(cast<PointerType>(AT)); new StoreInst(CPN, AI, IP); } }; @@ -2379,6 +2437,126 @@ findRematerializationCandidates(PointerToBaseTy PointerToBase, } } +// Try to rematerialize derived pointers immediately before their uses +// (instead of rematerializing after every statepoint it is live through). +// This can be beneficial when derived pointer is live across many +// statepoints, but uses are rare. +static void rematerializeLiveValuesAtUses( + RematCandTy &RematerizationCandidates, + MutableArrayRef<PartiallyConstructedSafepointRecord> Records, + PointerToBaseTy &PointerToBase) { + if (!RematDerivedAtUses) + return; + + SmallVector<Instruction *, 32> LiveValuesToBeDeleted; + + LLVM_DEBUG(dbgs() << "Rematerialize derived pointers at uses, " + << "Num statepoints: " << Records.size() << '\n'); + + for (auto &It : RematerizationCandidates) { + Instruction *Cand = cast<Instruction>(It.first); + auto &Record = It.second; + + if (Record.Cost >= RematerializationThreshold) + continue; + + if (Cand->user_empty()) + continue; + + if (Cand->hasOneUse()) + if (auto *U = dyn_cast<Instruction>(Cand->getUniqueUndroppableUser())) + if (U->getParent() == Cand->getParent()) + continue; + + // Rematerialization before PHI nodes is not implemented. + if (llvm::any_of(Cand->users(), + [](const auto *U) { return isa<PHINode>(U); })) + continue; + + LLVM_DEBUG(dbgs() << "Trying cand " << *Cand << " ... "); + + // Count of rematerialization instructions we introduce is equal to number + // of candidate uses. + // Count of rematerialization instructions we eliminate is equal to number + // of statepoints it is live through. + // Consider transformation profitable if latter is greater than former + // (in other words, we create less than eliminate). + unsigned NumLiveStatepoints = llvm::count_if( + Records, [Cand](const auto &R) { return R.LiveSet.contains(Cand); }); + unsigned NumUses = Cand->getNumUses(); + + LLVM_DEBUG(dbgs() << "Num uses: " << NumUses << " Num live statepoints: " + << NumLiveStatepoints << " "); + + if (NumLiveStatepoints < NumUses) { + LLVM_DEBUG(dbgs() << "not profitable\n"); + continue; + } + + // If rematerialization is 'free', then favor rematerialization at + // uses as it generally shortens live ranges. + // TODO: Short (size ==1) chains only? 
+ if (NumLiveStatepoints == NumUses && Record.Cost > 0) { + LLVM_DEBUG(dbgs() << "not profitable\n"); + continue; + } + + LLVM_DEBUG(dbgs() << "looks profitable\n"); + + // ChainToBase may contain another remat candidate (as a sub chain) which + // has been rewritten by now. Need to recollect chain to have up to date + // value. + // TODO: sort records in findRematerializationCandidates() in + // decreasing chain size order? + if (Record.ChainToBase.size() > 1) { + Record.ChainToBase.clear(); + findRematerializableChainToBasePointer(Record.ChainToBase, Cand); + } + + // Current rematerialization algorithm is very simple: we rematerialize + // immediately before EVERY use, even if there are several uses in same + // block or if use is local to Cand Def. The reason is that this allows + // us to avoid recomputing liveness without complicated analysis: + // - If we did not eliminate all uses of original Candidate, we do not + // know exaclty in what BBs it is still live. + // - If we rematerialize once per BB, we need to find proper insertion + // place (first use in block, but after Def) and analyze if there is + // statepoint between uses in the block. + while (!Cand->user_empty()) { + Instruction *UserI = cast<Instruction>(*Cand->user_begin()); + Instruction *RematChain = rematerializeChain( + Record.ChainToBase, UserI, Record.RootOfChain, PointerToBase[Cand]); + UserI->replaceUsesOfWith(Cand, RematChain); + PointerToBase[RematChain] = PointerToBase[Cand]; + } + LiveValuesToBeDeleted.push_back(Cand); + } + + LLVM_DEBUG(dbgs() << "Rematerialized " << LiveValuesToBeDeleted.size() + << " derived pointers\n"); + for (auto *Cand : LiveValuesToBeDeleted) { + assert(Cand->use_empty() && "Unexpected user remain"); + RematerizationCandidates.erase(Cand); + for (auto &R : Records) { + assert(!R.LiveSet.contains(Cand) || + R.LiveSet.contains(PointerToBase[Cand])); + R.LiveSet.remove(Cand); + } + } + + // Recollect not rematerialized chains - we might have rewritten + // their sub-chains. + if (!LiveValuesToBeDeleted.empty()) { + for (auto &P : RematerizationCandidates) { + auto &R = P.second; + if (R.ChainToBase.size() > 1) { + R.ChainToBase.clear(); + findRematerializableChainToBasePointer(R.ChainToBase, P.first); + } + } + } +} + // From the statepoint live set pick values that are cheaper to recompute then // to relocate. Remove this values from the live set, rematerialize them after // statepoint and record them in "Info" structure. Note that similar to @@ -2414,69 +2592,14 @@ static void rematerializeLiveValues(CallBase *Call, // Clone instructions and record them inside "Info" structure. - // For each live pointer find get its defining chain. - SmallVector<Instruction *, 3> ChainToBase = Record.ChainToBase; - // Walk backwards to visit top-most instructions first. - std::reverse(ChainToBase.begin(), ChainToBase.end()); - - // Utility function which clones all instructions from "ChainToBase" - // and inserts them before "InsertBefore". Returns rematerialized value - // which should be used after statepoint. - auto rematerializeChain = [&ChainToBase]( - Instruction *InsertBefore, Value *RootOfChain, Value *AlternateLiveBase) { - Instruction *LastClonedValue = nullptr; - Instruction *LastValue = nullptr; - for (Instruction *Instr: ChainToBase) { - // Only GEP's and casts are supported as we need to be careful to not - // introduce any new uses of pointers not in the liveset. - // Note that it's fine to introduce new uses of pointers which were - // otherwise not used after this statepoint. 
- assert(isa<GetElementPtrInst>(Instr) || isa<CastInst>(Instr)); - - Instruction *ClonedValue = Instr->clone(); - ClonedValue->insertBefore(InsertBefore); - ClonedValue->setName(Instr->getName() + ".remat"); - - // If it is not first instruction in the chain then it uses previously - // cloned value. We should update it to use cloned value. - if (LastClonedValue) { - assert(LastValue); - ClonedValue->replaceUsesOfWith(LastValue, LastClonedValue); -#ifndef NDEBUG - for (auto OpValue : ClonedValue->operand_values()) { - // Assert that cloned instruction does not use any instructions from - // this chain other than LastClonedValue - assert(!is_contained(ChainToBase, OpValue) && - "incorrect use in rematerialization chain"); - // Assert that the cloned instruction does not use the RootOfChain - // or the AlternateLiveBase. - assert(OpValue != RootOfChain && OpValue != AlternateLiveBase); - } -#endif - } else { - // For the first instruction, replace the use of unrelocated base i.e. - // RootOfChain/OrigRootPhi, with the corresponding PHI present in the - // live set. They have been proved to be the same PHI nodes. Note - // that the *only* use of the RootOfChain in the ChainToBase list is - // the first Value in the list. - if (RootOfChain != AlternateLiveBase) - ClonedValue->replaceUsesOfWith(RootOfChain, AlternateLiveBase); - } - - LastClonedValue = ClonedValue; - LastValue = Instr; - } - assert(LastClonedValue); - return LastClonedValue; - }; - // Different cases for calls and invokes. For invokes we need to clone // instructions both on normal and unwind path. if (isa<CallInst>(Call)) { Instruction *InsertBefore = Call->getNextNode(); assert(InsertBefore); - Instruction *RematerializedValue = rematerializeChain( - InsertBefore, Record.RootOfChain, PointerToBase[LiveValue]); + Instruction *RematerializedValue = + rematerializeChain(Record.ChainToBase, InsertBefore, + Record.RootOfChain, PointerToBase[LiveValue]); Info.RematerializedValues[RematerializedValue] = LiveValue; } else { auto *Invoke = cast<InvokeInst>(Call); @@ -2486,18 +2609,20 @@ static void rematerializeLiveValues(CallBase *Call, Instruction *UnwindInsertBefore = &*Invoke->getUnwindDest()->getFirstInsertionPt(); - Instruction *NormalRematerializedValue = rematerializeChain( - NormalInsertBefore, Record.RootOfChain, PointerToBase[LiveValue]); - Instruction *UnwindRematerializedValue = rematerializeChain( - UnwindInsertBefore, Record.RootOfChain, PointerToBase[LiveValue]); + Instruction *NormalRematerializedValue = + rematerializeChain(Record.ChainToBase, NormalInsertBefore, + Record.RootOfChain, PointerToBase[LiveValue]); + Instruction *UnwindRematerializedValue = + rematerializeChain(Record.ChainToBase, UnwindInsertBefore, + Record.RootOfChain, PointerToBase[LiveValue]); Info.RematerializedValues[NormalRematerializedValue] = LiveValue; Info.RematerializedValues[UnwindRematerializedValue] = LiveValue; } } - // Remove rematerializaed values from the live set - for (auto LiveValue: LiveValuesToBeDeleted) { + // Remove rematerialized values from the live set. + for (auto *LiveValue: LiveValuesToBeDeleted) { Info.LiveSet.remove(LiveValue); } } @@ -2697,6 +2822,9 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // In order to reduce live set of statepoint we might choose to rematerialize // some values instead of relocating them. This is purely an optimization and // does not influence correctness. + // First try rematerialization at uses, then after statepoints. 
+ rematerializeLiveValuesAtUses(RematerizationCandidates, Records, + PointerToBase); for (size_t i = 0; i < Records.size(); i++) rematerializeLiveValues(ToUpdate[i], Records[i], PointerToBase, RematerizationCandidates, TTI); @@ -3266,7 +3394,7 @@ static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData, // We may have base pointers which are now live that weren't before. We need // to update the PointerToBase structure to reflect this. - for (auto V : Updated) + for (auto *V : Updated) PointerToBase.insert({ V, V }); Info.LiveSet = Updated; diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp index 2282ef636076..7b396c6ee074 100644 --- a/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -27,19 +27,15 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueLattice.h" -#include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" -#include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" @@ -67,133 +63,6 @@ STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable"); STATISTIC(NumInstReplaced, "Number of instructions replaced with (simpler) instruction"); -STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP"); -STATISTIC(IPNumArgsElimed ,"Number of arguments constant propagated by IPSCCP"); -STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP"); -STATISTIC( - IPNumInstReplaced, - "Number of instructions replaced with (simpler) instruction by IPSCCP"); - -// Helper to check if \p LV is either a constant or a constant -// range with a single element. This should cover exactly the same cases as the -// old ValueLatticeElement::isConstant() and is intended to be used in the -// transition to ValueLatticeElement. -static bool isConstant(const ValueLatticeElement &LV) { - return LV.isConstant() || - (LV.isConstantRange() && LV.getConstantRange().isSingleElement()); -} - -// Helper to check if \p LV is either overdefined or a constant range with more -// than a single element. This should cover exactly the same cases as the old -// ValueLatticeElement::isOverdefined() and is intended to be used in the -// transition to ValueLatticeElement. -static bool isOverdefined(const ValueLatticeElement &LV) { - return !LV.isUnknownOrUndef() && !isConstant(LV); -} - -static bool canRemoveInstruction(Instruction *I) { - if (wouldInstructionBeTriviallyDead(I)) - return true; - - // Some instructions can be handled but are rejected above. Catch - // those cases by falling through to here. - // TODO: Mark globals as being constant earlier, so - // TODO: wouldInstructionBeTriviallyDead() knows that atomic loads - // TODO: are safe to remove. 
- return isa<LoadInst>(I); -} - -static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { - Constant *Const = nullptr; - if (V->getType()->isStructTy()) { - std::vector<ValueLatticeElement> IVs = Solver.getStructLatticeValueFor(V); - if (llvm::any_of(IVs, isOverdefined)) - return false; - std::vector<Constant *> ConstVals; - auto *ST = cast<StructType>(V->getType()); - for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { - ValueLatticeElement V = IVs[i]; - ConstVals.push_back(isConstant(V) - ? Solver.getConstant(V) - : UndefValue::get(ST->getElementType(i))); - } - Const = ConstantStruct::get(ST, ConstVals); - } else { - const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); - if (isOverdefined(IV)) - return false; - - Const = - isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType()); - } - assert(Const && "Constant is nullptr here!"); - - // Replacing `musttail` instructions with constant breaks `musttail` invariant - // unless the call itself can be removed. - // Calls with "clang.arc.attachedcall" implicitly use the return value and - // those uses cannot be updated with a constant. - CallBase *CB = dyn_cast<CallBase>(V); - if (CB && ((CB->isMustTailCall() && - !canRemoveInstruction(CB)) || - CB->getOperandBundle(LLVMContext::OB_clang_arc_attachedcall))) { - Function *F = CB->getCalledFunction(); - - // Don't zap returns of the callee - if (F) - Solver.addToMustPreserveReturnsInFunctions(F); - - LLVM_DEBUG(dbgs() << " Can\'t treat the result of call " << *CB - << " as a constant\n"); - return false; - } - - LLVM_DEBUG(dbgs() << " Constant: " << *Const << " = " << *V << '\n'); - - // Replaces all of the uses of a variable with uses of the constant. - V->replaceAllUsesWith(Const); - return true; -} - -static bool simplifyInstsInBlock(SCCPSolver &Solver, BasicBlock &BB, - SmallPtrSetImpl<Value *> &InsertedValues, - Statistic &InstRemovedStat, - Statistic &InstReplacedStat) { - bool MadeChanges = false; - for (Instruction &Inst : make_early_inc_range(BB)) { - if (Inst.getType()->isVoidTy()) - continue; - if (tryToReplaceWithConstant(Solver, &Inst)) { - if (canRemoveInstruction(&Inst)) - Inst.eraseFromParent(); - - MadeChanges = true; - ++InstRemovedStat; - } else if (isa<SExtInst>(&Inst)) { - Value *ExtOp = Inst.getOperand(0); - if (isa<Constant>(ExtOp) || InsertedValues.count(ExtOp)) - continue; - const ValueLatticeElement &IV = Solver.getLatticeValueFor(ExtOp); - if (!IV.isConstantRange(/*UndefAllowed=*/false)) - continue; - if (IV.getConstantRange().isAllNonNegative()) { - auto *ZExt = new ZExtInst(ExtOp, Inst.getType(), "", &Inst); - ZExt->takeName(&Inst); - InsertedValues.insert(ZExt); - Inst.replaceAllUsesWith(ZExt); - Solver.removeLatticeValueFor(&Inst); - Inst.eraseFromParent(); - InstReplacedStat++; - MadeChanges = true; - } - } - } - return MadeChanges; -} - -static bool removeNonFeasibleEdges(const SCCPSolver &Solver, BasicBlock *BB, - DomTreeUpdater &DTU, - BasicBlock *&NewUnreachableBB); - // runSCCP() - Run the Sparse Conditional Constant Propagation algorithm, // and return true if the function was modified. static bool runSCCP(Function &F, const DataLayout &DL, @@ -235,8 +104,8 @@ static bool runSCCP(Function &F, const DataLayout &DL, continue; } - MadeChanges |= simplifyInstsInBlock(Solver, BB, InsertedValues, - NumInstRemoved, NumInstReplaced); + MadeChanges |= Solver.simplifyInstsInBlock(BB, InsertedValues, + NumInstRemoved, NumInstReplaced); } // Remove unreachable blocks and non-feasible edges. 
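For reference, the lattice queries removed above (together with tryToReplaceWithConstant and simplifyInstsInBlock, which runSCCP now reaches through SCCPSolver) treat a constant range containing exactly one value the same as a plain constant. A condensed sketch of that distinction, copied in simplified form from the deleted helpers rather than describing any new API:

#include "llvm/Analysis/ValueLattice.h"
using namespace llvm;

// A single-element constant range is as good as a constant for replacement.
static bool latticeIsConstant(const ValueLatticeElement &LV) {
  return LV.isConstant() ||
         (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
}

// Overdefined means "known to be neither unknown/undef nor a usable constant".
static bool latticeIsOverdefined(const ValueLatticeElement &LV) {
  return !LV.isUnknownOrUndef() && !latticeIsConstant(LV);
}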
@@ -246,7 +115,7 @@ static bool runSCCP(Function &F, const DataLayout &DL, BasicBlock *NewUnreachableBB = nullptr; for (BasicBlock &BB : F) - MadeChanges |= removeNonFeasibleEdges(Solver, &BB, DTU, NewUnreachableBB); + MadeChanges |= Solver.removeNonFeasibleEdges(&BB, DTU, NewUnreachableBB); for (BasicBlock *DeadBB : BlocksToErase) if (!DeadBB->hasAddressTaken()) @@ -318,407 +187,3 @@ INITIALIZE_PASS_END(SCCPLegacyPass, "sccp", // createSCCPPass - This is the public interface to this file. FunctionPass *llvm::createSCCPPass() { return new SCCPLegacyPass(); } -static void findReturnsToZap(Function &F, - SmallVector<ReturnInst *, 8> &ReturnsToZap, - SCCPSolver &Solver) { - // We can only do this if we know that nothing else can call the function. - if (!Solver.isArgumentTrackedFunction(&F)) - return; - - if (Solver.mustPreserveReturn(&F)) { - LLVM_DEBUG( - dbgs() - << "Can't zap returns of the function : " << F.getName() - << " due to present musttail or \"clang.arc.attachedcall\" call of " - "it\n"); - return; - } - - assert( - all_of(F.users(), - [&Solver](User *U) { - if (isa<Instruction>(U) && - !Solver.isBlockExecutable(cast<Instruction>(U)->getParent())) - return true; - // Non-callsite uses are not impacted by zapping. Also, constant - // uses (like blockaddresses) could stuck around, without being - // used in the underlying IR, meaning we do not have lattice - // values for them. - if (!isa<CallBase>(U)) - return true; - if (U->getType()->isStructTy()) { - return all_of(Solver.getStructLatticeValueFor(U), - [](const ValueLatticeElement &LV) { - return !isOverdefined(LV); - }); - } - return !isOverdefined(Solver.getLatticeValueFor(U)); - }) && - "We can only zap functions where all live users have a concrete value"); - - for (BasicBlock &BB : F) { - if (CallInst *CI = BB.getTerminatingMustTailCall()) { - LLVM_DEBUG(dbgs() << "Can't zap return of the block due to present " - << "musttail call : " << *CI << "\n"); - (void)CI; - return; - } - - if (auto *RI = dyn_cast<ReturnInst>(BB.getTerminator())) - if (!isa<UndefValue>(RI->getOperand(0))) - ReturnsToZap.push_back(RI); - } -} - -static bool removeNonFeasibleEdges(const SCCPSolver &Solver, BasicBlock *BB, - DomTreeUpdater &DTU, - BasicBlock *&NewUnreachableBB) { - SmallPtrSet<BasicBlock *, 8> FeasibleSuccessors; - bool HasNonFeasibleEdges = false; - for (BasicBlock *Succ : successors(BB)) { - if (Solver.isEdgeFeasible(BB, Succ)) - FeasibleSuccessors.insert(Succ); - else - HasNonFeasibleEdges = true; - } - - // All edges feasible, nothing to do. - if (!HasNonFeasibleEdges) - return false; - - // SCCP can only determine non-feasible edges for br, switch and indirectbr. - Instruction *TI = BB->getTerminator(); - assert((isa<BranchInst>(TI) || isa<SwitchInst>(TI) || - isa<IndirectBrInst>(TI)) && - "Terminator must be a br, switch or indirectbr"); - - if (FeasibleSuccessors.size() == 0) { - // Branch on undef/poison, replace with unreachable. - SmallPtrSet<BasicBlock *, 8> SeenSuccs; - SmallVector<DominatorTree::UpdateType, 8> Updates; - for (BasicBlock *Succ : successors(BB)) { - Succ->removePredecessor(BB); - if (SeenSuccs.insert(Succ).second) - Updates.push_back({DominatorTree::Delete, BB, Succ}); - } - TI->eraseFromParent(); - new UnreachableInst(BB->getContext(), BB); - DTU.applyUpdatesPermissive(Updates); - } else if (FeasibleSuccessors.size() == 1) { - // Replace with an unconditional branch to the only feasible successor. 
- BasicBlock *OnlyFeasibleSuccessor = *FeasibleSuccessors.begin(); - SmallVector<DominatorTree::UpdateType, 8> Updates; - bool HaveSeenOnlyFeasibleSuccessor = false; - for (BasicBlock *Succ : successors(BB)) { - if (Succ == OnlyFeasibleSuccessor && !HaveSeenOnlyFeasibleSuccessor) { - // Don't remove the edge to the only feasible successor the first time - // we see it. We still do need to remove any multi-edges to it though. - HaveSeenOnlyFeasibleSuccessor = true; - continue; - } - - Succ->removePredecessor(BB); - Updates.push_back({DominatorTree::Delete, BB, Succ}); - } - - BranchInst::Create(OnlyFeasibleSuccessor, BB); - TI->eraseFromParent(); - DTU.applyUpdatesPermissive(Updates); - } else if (FeasibleSuccessors.size() > 1) { - SwitchInstProfUpdateWrapper SI(*cast<SwitchInst>(TI)); - SmallVector<DominatorTree::UpdateType, 8> Updates; - - // If the default destination is unfeasible it will never be taken. Replace - // it with a new block with a single Unreachable instruction. - BasicBlock *DefaultDest = SI->getDefaultDest(); - if (!FeasibleSuccessors.contains(DefaultDest)) { - if (!NewUnreachableBB) { - NewUnreachableBB = - BasicBlock::Create(DefaultDest->getContext(), "default.unreachable", - DefaultDest->getParent(), DefaultDest); - new UnreachableInst(DefaultDest->getContext(), NewUnreachableBB); - } - - SI->setDefaultDest(NewUnreachableBB); - Updates.push_back({DominatorTree::Delete, BB, DefaultDest}); - Updates.push_back({DominatorTree::Insert, BB, NewUnreachableBB}); - } - - for (auto CI = SI->case_begin(); CI != SI->case_end();) { - if (FeasibleSuccessors.contains(CI->getCaseSuccessor())) { - ++CI; - continue; - } - - BasicBlock *Succ = CI->getCaseSuccessor(); - Succ->removePredecessor(BB); - Updates.push_back({DominatorTree::Delete, BB, Succ}); - SI.removeCase(CI); - // Don't increment CI, as we removed a case. - } - - DTU.applyUpdatesPermissive(Updates); - } else { - llvm_unreachable("Must have at least one feasible successor"); - } - return true; -} - -bool llvm::runIPSCCP( - Module &M, const DataLayout &DL, - std::function<const TargetLibraryInfo &(Function &)> GetTLI, - function_ref<AnalysisResultsForFn(Function &)> getAnalysis) { - SCCPSolver Solver(DL, GetTLI, M.getContext()); - - // Loop over all functions, marking arguments to those with their addresses - // taken or that are external as overdefined. - for (Function &F : M) { - if (F.isDeclaration()) - continue; - - Solver.addAnalysis(F, getAnalysis(F)); - - // Determine if we can track the function's return values. If so, add the - // function to the solver's set of return-tracked functions. - if (canTrackReturnsInterprocedurally(&F)) - Solver.addTrackedFunction(&F); - - // Determine if we can track the function's arguments. If so, add the - // function to the solver's set of argument-tracked functions. - if (canTrackArgumentsInterprocedurally(&F)) { - Solver.addArgumentTrackedFunction(&F); - continue; - } - - // Assume the function is called. - Solver.markBlockExecutable(&F.front()); - - // Assume nothing about the incoming arguments. - for (Argument &AI : F.args()) - Solver.markOverdefined(&AI); - } - - // Determine if we can track any of the module's global variables. If so, add - // the global variables we can track to the solver's set of tracked global - // variables. - for (GlobalVariable &G : M.globals()) { - G.removeDeadConstantUsers(); - if (canTrackGlobalVariableInterprocedurally(&G)) - Solver.trackValueOfGlobalVariable(&G); - } - - // Solve for constants. 
- bool ResolvedUndefs = true; - Solver.solve(); - while (ResolvedUndefs) { - LLVM_DEBUG(dbgs() << "RESOLVING UNDEFS\n"); - ResolvedUndefs = false; - for (Function &F : M) { - if (Solver.resolvedUndefsIn(F)) - ResolvedUndefs = true; - } - if (ResolvedUndefs) - Solver.solve(); - } - - bool MadeChanges = false; - - // Iterate over all of the instructions in the module, replacing them with - // constants if we have found them to be of constant values. - - for (Function &F : M) { - if (F.isDeclaration()) - continue; - - SmallVector<BasicBlock *, 512> BlocksToErase; - - if (Solver.isBlockExecutable(&F.front())) { - bool ReplacedPointerArg = false; - for (Argument &Arg : F.args()) { - if (!Arg.use_empty() && tryToReplaceWithConstant(Solver, &Arg)) { - ReplacedPointerArg |= Arg.getType()->isPointerTy(); - ++IPNumArgsElimed; - } - } - - // If we replaced an argument, the argmemonly and - // inaccessiblemem_or_argmemonly attributes do not hold any longer. Remove - // them from both the function and callsites. - if (ReplacedPointerArg) { - AttributeMask AttributesToRemove; - AttributesToRemove.addAttribute(Attribute::ArgMemOnly); - AttributesToRemove.addAttribute(Attribute::InaccessibleMemOrArgMemOnly); - F.removeFnAttrs(AttributesToRemove); - - for (User *U : F.users()) { - auto *CB = dyn_cast<CallBase>(U); - if (!CB || CB->getCalledFunction() != &F) - continue; - - CB->removeFnAttrs(AttributesToRemove); - } - } - MadeChanges |= ReplacedPointerArg; - } - - SmallPtrSet<Value *, 32> InsertedValues; - for (BasicBlock &BB : F) { - if (!Solver.isBlockExecutable(&BB)) { - LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB); - ++NumDeadBlocks; - - MadeChanges = true; - - if (&BB != &F.front()) - BlocksToErase.push_back(&BB); - continue; - } - - MadeChanges |= simplifyInstsInBlock(Solver, BB, InsertedValues, - IPNumInstRemoved, IPNumInstReplaced); - } - - DomTreeUpdater DTU = Solver.getDTU(F); - // Change dead blocks to unreachable. We do it after replacing constants - // in all executable blocks, because changeToUnreachable may remove PHI - // nodes in executable blocks we found values for. The function's entry - // block is not part of BlocksToErase, so we have to handle it separately. - for (BasicBlock *BB : BlocksToErase) { - NumInstRemoved += changeToUnreachable(BB->getFirstNonPHI(), - /*PreserveLCSSA=*/false, &DTU); - } - if (!Solver.isBlockExecutable(&F.front())) - NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHI(), - /*PreserveLCSSA=*/false, &DTU); - - BasicBlock *NewUnreachableBB = nullptr; - for (BasicBlock &BB : F) - MadeChanges |= removeNonFeasibleEdges(Solver, &BB, DTU, NewUnreachableBB); - - for (BasicBlock *DeadBB : BlocksToErase) - if (!DeadBB->hasAddressTaken()) - DTU.deleteBB(DeadBB); - - for (BasicBlock &BB : F) { - for (Instruction &Inst : llvm::make_early_inc_range(BB)) { - if (Solver.getPredicateInfoFor(&Inst)) { - if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { - if (II->getIntrinsicID() == Intrinsic::ssa_copy) { - Value *Op = II->getOperand(0); - Inst.replaceAllUsesWith(Op); - Inst.eraseFromParent(); - } - } - } - } - } - } - - // If we inferred constant or undef return values for a function, we replaced - // all call uses with the inferred value. This means we don't need to bother - // actually returning anything from the function. Replace all return - // instructions with return undef. - // - // Do this in two stages: first identify the functions we should process, then - // actually zap their returns. 
This is important because we can only do this - // if the address of the function isn't taken. In cases where a return is the - // last use of a function, the order of processing functions would affect - // whether other functions are optimizable. - SmallVector<ReturnInst*, 8> ReturnsToZap; - - for (const auto &I : Solver.getTrackedRetVals()) { - Function *F = I.first; - const ValueLatticeElement &ReturnValue = I.second; - - // If there is a known constant range for the return value, add !range - // metadata to the function's call sites. - if (ReturnValue.isConstantRange() && - !ReturnValue.getConstantRange().isSingleElement()) { - // Do not add range metadata if the return value may include undef. - if (ReturnValue.isConstantRangeIncludingUndef()) - continue; - - auto &CR = ReturnValue.getConstantRange(); - for (User *User : F->users()) { - auto *CB = dyn_cast<CallBase>(User); - if (!CB || CB->getCalledFunction() != F) - continue; - - // Limit to cases where the return value is guaranteed to be neither - // poison nor undef. Poison will be outside any range and currently - // values outside of the specified range cause immediate undefined - // behavior. - if (!isGuaranteedNotToBeUndefOrPoison(CB, nullptr, CB)) - continue; - - // Do not touch existing metadata for now. - // TODO: We should be able to take the intersection of the existing - // metadata and the inferred range. - if (CB->getMetadata(LLVMContext::MD_range)) - continue; - - LLVMContext &Context = CB->getParent()->getContext(); - Metadata *RangeMD[] = { - ConstantAsMetadata::get(ConstantInt::get(Context, CR.getLower())), - ConstantAsMetadata::get(ConstantInt::get(Context, CR.getUpper()))}; - CB->setMetadata(LLVMContext::MD_range, MDNode::get(Context, RangeMD)); - } - continue; - } - if (F->getReturnType()->isVoidTy()) - continue; - if (isConstant(ReturnValue) || ReturnValue.isUnknownOrUndef()) - findReturnsToZap(*F, ReturnsToZap, Solver); - } - - for (auto F : Solver.getMRVFunctionsTracked()) { - assert(F->getReturnType()->isStructTy() && - "The return type should be a struct"); - StructType *STy = cast<StructType>(F->getReturnType()); - if (Solver.isStructLatticeConstant(F, STy)) - findReturnsToZap(*F, ReturnsToZap, Solver); - } - - // Zap all returns which we've identified as zap to change. - SmallSetVector<Function *, 8> FuncZappedReturn; - for (unsigned i = 0, e = ReturnsToZap.size(); i != e; ++i) { - Function *F = ReturnsToZap[i]->getParent()->getParent(); - ReturnsToZap[i]->setOperand(0, UndefValue::get(F->getReturnType())); - // Record all functions that are zapped. - FuncZappedReturn.insert(F); - } - - // Remove the returned attribute for zapped functions and the - // corresponding call sites. - for (Function *F : FuncZappedReturn) { - for (Argument &A : F->args()) - F->removeParamAttr(A.getArgNo(), Attribute::Returned); - for (Use &U : F->uses()) { - // Skip over blockaddr users. - if (isa<BlockAddress>(U.getUser())) - continue; - CallBase *CB = cast<CallBase>(U.getUser()); - for (Use &Arg : CB->args()) - CB->removeParamAttr(CB->getArgOperandNo(&Arg), Attribute::Returned); - } - } - - // If we inferred constant or undef values for globals variables, we can - // delete the global and any stores that remain to it. 
- for (auto &I : make_early_inc_range(Solver.getTrackedGlobals())) { - GlobalVariable *GV = I.first; - if (isOverdefined(I.second)) - continue; - LLVM_DEBUG(dbgs() << "Found that GV '" << GV->getName() - << "' is constant!\n"); - while (!GV->use_empty()) { - StoreInst *SI = cast<StoreInst>(GV->user_back()); - SI->eraseFromParent(); - MadeChanges = true; - } - M.getGlobalList().erase(GV); - ++IPNumGlobalConst; - } - - return MadeChanges; -} diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 143a035749c7..8339981e1bdc 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -38,6 +38,7 @@ #include "llvm/ADT/iterator.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/PtrUseVisitor.h" @@ -78,6 +79,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> @@ -104,6 +106,11 @@ STATISTIC(MaxUsesPerAllocaPartition, "Maximum number of uses of a partition"); STATISTIC(NumNewAllocas, "Number of new, smaller allocas introduced"); STATISTIC(NumPromoted, "Number of allocas promoted to SSA values"); STATISTIC(NumLoadsSpeculated, "Number of loads speculated to allow promotion"); +STATISTIC(NumLoadsPredicated, + "Number of loads rewritten into predicated loads to allow promotion"); +STATISTIC( + NumStoresPredicated, + "Number of stores rewritten into predicated loads to allow promotion"); STATISTIC(NumDeleted, "Number of instructions deleted"); STATISTIC(NumVectorized, "Number of vectorized aggregates"); @@ -111,8 +118,111 @@ STATISTIC(NumVectorized, "Number of vectorized aggregates"); /// GEPs. static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), cl::Hidden); - namespace { +/// Find linked dbg.assign and generate a new one with the correct +/// FragmentInfo. Link Inst to the new dbg.assign. If Value is nullptr the +/// value component is copied from the old dbg.assign to the new. +/// \param OldAlloca Alloca for the variable before splitting. +/// \param RelativeOffsetInBits Offset into \p OldAlloca relative to the +/// offset prior to splitting (change in offset). +/// \param SliceSizeInBits New number of bits being written to. +/// \param OldInst Instruction that is being split. +/// \param Inst New instruction performing this part of the +/// split store. +/// \param Dest Store destination. +/// \param Value Stored value. +/// \param DL Datalayout. +static void migrateDebugInfo(AllocaInst *OldAlloca, + uint64_t RelativeOffsetInBits, + uint64_t SliceSizeInBits, Instruction *OldInst, + Instruction *Inst, Value *Dest, Value *Value, + const DataLayout &DL) { + auto MarkerRange = at::getAssignmentMarkers(OldInst); + // Nothing to do if OldInst has no linked dbg.assign intrinsics. 
+ if (MarkerRange.empty()) + return; + + LLVM_DEBUG(dbgs() << " migrateDebugInfo\n"); + LLVM_DEBUG(dbgs() << " OldAlloca: " << *OldAlloca << "\n"); + LLVM_DEBUG(dbgs() << " RelativeOffset: " << RelativeOffsetInBits << "\n"); + LLVM_DEBUG(dbgs() << " SliceSizeInBits: " << SliceSizeInBits << "\n"); + LLVM_DEBUG(dbgs() << " OldInst: " << *OldInst << "\n"); + LLVM_DEBUG(dbgs() << " Inst: " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << " Dest: " << *Dest << "\n"); + if (Value) + LLVM_DEBUG(dbgs() << " Value: " << *Value << "\n"); + + // The new inst needs a DIAssignID unique metadata tag (if OldInst has + // one). It shouldn't already have one: assert this assumption. + assert(!Inst->getMetadata(LLVMContext::MD_DIAssignID)); + DIAssignID *NewID = nullptr; + auto &Ctx = Inst->getContext(); + DIBuilder DIB(*OldInst->getModule(), /*AllowUnresolved*/ false); + uint64_t AllocaSizeInBits = *OldAlloca->getAllocationSizeInBits(DL); + assert(OldAlloca->isStaticAlloca()); + + for (DbgAssignIntrinsic *DbgAssign : MarkerRange) { + LLVM_DEBUG(dbgs() << " existing dbg.assign is: " << *DbgAssign + << "\n"); + auto *Expr = DbgAssign->getExpression(); + + // Check if the dbg.assign already describes a fragment. + auto GetCurrentFragSize = [AllocaSizeInBits, DbgAssign, + Expr]() -> uint64_t { + if (auto FI = Expr->getFragmentInfo()) + return FI->SizeInBits; + if (auto VarSize = DbgAssign->getVariable()->getSizeInBits()) + return *VarSize; + // The variable type has an unspecified size. This can happen in the + // case of DW_TAG_unspecified_type types, e.g. std::nullptr_t. Because + // there is no fragment and we do not know the size of the variable type, + // we'll guess by looking at the alloca. + return AllocaSizeInBits; + }; + uint64_t CurrentFragSize = GetCurrentFragSize(); + bool MakeNewFragment = CurrentFragSize != SliceSizeInBits; + assert(MakeNewFragment || RelativeOffsetInBits == 0); + + assert(SliceSizeInBits <= AllocaSizeInBits); + if (MakeNewFragment) { + assert(RelativeOffsetInBits + SliceSizeInBits <= CurrentFragSize); + auto E = DIExpression::createFragmentExpression( + Expr, RelativeOffsetInBits, SliceSizeInBits); + assert(E && "Failed to create fragment expr!"); + Expr = *E; + } + + // If we haven't created a DIAssignID ID do that now and attach it to Inst. + if (!NewID) { + NewID = DIAssignID::getDistinct(Ctx); + Inst->setMetadata(LLVMContext::MD_DIAssignID, NewID); + } + + Value = Value ? Value : DbgAssign->getValue(); + auto *NewAssign = DIB.insertDbgAssign( + Inst, Value, DbgAssign->getVariable(), Expr, Dest, + DIExpression::get(Ctx, std::nullopt), DbgAssign->getDebugLoc()); + + // We could use more precision here at the cost of some additional (code) + // complexity - if the original dbg.assign was adjacent to its store, we + // could position this new dbg.assign adjacent to its store rather than the + // old dbg.assgn. That would result in interleaved dbg.assigns rather than + // what we get now: + // split store !1 + // split store !2 + // dbg.assign !1 + // dbg.assign !2 + // This (current behaviour) results results in debug assignments being + // noted as slightly offset (in code) from the store. In practice this + // should have little effect on the debugging experience due to the fact + // that all the split stores should get the same line number. 
+ NewAssign->moveBefore(DbgAssign); + + NewAssign->setDebugLoc(DbgAssign->getDebugLoc()); + LLVM_DEBUG(dbgs() << "Created new assign intrinsic: " << *NewAssign + << "\n"); + } +} /// A custom IRBuilder inserter which prefixes all names, but only in /// Assert builds. @@ -653,7 +763,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> { public: SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS) : PtrUseVisitor<SliceBuilder>(DL), - AllocSize(DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize()), + AllocSize(DL.getTypeAllocSize(AI.getAllocatedType()).getFixedValue()), AS(AS) {} private: @@ -746,7 +856,7 @@ private: GEPOffset += Index * APInt(Offset.getBitWidth(), - DL.getTypeAllocSize(GTI.getIndexedType()).getFixedSize()); + DL.getTypeAllocSize(GTI.getIndexedType()).getFixedValue()); } // If this index has computed an intermediate pointer which is not @@ -778,14 +888,10 @@ private: if (!IsOffsetKnown) return PI.setAborted(&LI); - if (LI.isVolatile() && - LI.getPointerAddressSpace() != DL.getAllocaAddrSpace()) - return PI.setAborted(&LI); - if (isa<ScalableVectorType>(LI.getType())) return PI.setAborted(&LI); - uint64_t Size = DL.getTypeStoreSize(LI.getType()).getFixedSize(); + uint64_t Size = DL.getTypeStoreSize(LI.getType()).getFixedValue(); return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile()); } @@ -796,14 +902,10 @@ private: if (!IsOffsetKnown) return PI.setAborted(&SI); - if (SI.isVolatile() && - SI.getPointerAddressSpace() != DL.getAllocaAddrSpace()) - return PI.setAborted(&SI); - if (isa<ScalableVectorType>(ValOp->getType())) return PI.setAborted(&SI); - uint64_t Size = DL.getTypeStoreSize(ValOp->getType()).getFixedSize(); + uint64_t Size = DL.getTypeStoreSize(ValOp->getType()).getFixedValue(); // If this memory access can be shown to *statically* extend outside the // bounds of the allocation, it's behavior is undefined, so simply @@ -837,11 +939,6 @@ private: if (!IsOffsetKnown) return PI.setAborted(&II); - // Don't replace this with a store with a different address space. TODO: - // Use a store with the casted new alloca? - if (II.isVolatile() && II.getDestAddressSpace() != DL.getAllocaAddrSpace()) - return PI.setAborted(&II); - insertUse(II, Offset, Length ? Length->getLimitedValue() : AllocSize - Offset.getLimitedValue(), (bool)Length); @@ -861,13 +958,6 @@ private: if (!IsOffsetKnown) return PI.setAborted(&II); - // Don't replace this with a load/store with a different address space. - // TODO: Use a store with the casted new alloca? - if (II.isVolatile() && - (II.getDestAddressSpace() != DL.getAllocaAddrSpace() || - II.getSourceAddressSpace() != DL.getAllocaAddrSpace())) - return PI.setAborted(&II); - // This side of the transfer is completely out-of-bounds, and so we can // nuke the entire transfer. However, we also need to nuke the other side // if already added to our partitions. 
@@ -971,16 +1061,16 @@ private: std::tie(UsedI, I) = Uses.pop_back_val(); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - Size = std::max(Size, - DL.getTypeStoreSize(LI->getType()).getFixedSize()); + Size = + std::max(Size, DL.getTypeStoreSize(LI->getType()).getFixedValue()); continue; } if (StoreInst *SI = dyn_cast<StoreInst>(I)) { Value *Op = SI->getOperand(0); if (Op == UsedI) return SI; - Size = std::max(Size, - DL.getTypeStoreSize(Op->getType()).getFixedSize()); + Size = + std::max(Size, DL.getTypeStoreSize(Op->getType()).getFixedValue()); continue; } @@ -1210,8 +1300,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) { BasicBlock *BB = PN.getParent(); Align MaxAlign; uint64_t APWidth = DL.getIndexTypeSizeInBits(PN.getType()); - APInt MaxSize(APWidth, 0); - bool HaveLoad = false; + Type *LoadType = nullptr; for (User *U : PN.users()) { LoadInst *LI = dyn_cast<LoadInst>(U); if (!LI || !LI->isSimple()) @@ -1223,21 +1312,28 @@ static bool isSafePHIToSpeculate(PHINode &PN) { if (LI->getParent() != BB) return false; + if (LoadType) { + if (LoadType != LI->getType()) + return false; + } else { + LoadType = LI->getType(); + } + // Ensure that there are no instructions between the PHI and the load that // could store. for (BasicBlock::iterator BBI(PN); &*BBI != LI; ++BBI) if (BBI->mayWriteToMemory()) return false; - uint64_t Size = DL.getTypeStoreSize(LI->getType()).getFixedSize(); MaxAlign = std::max(MaxAlign, LI->getAlign()); - MaxSize = MaxSize.ult(Size) ? APInt(APWidth, Size) : MaxSize; - HaveLoad = true; } - if (!HaveLoad) + if (!LoadType) return false; + APInt LoadSize = + APInt(APWidth, DL.getTypeStoreSize(LoadType).getFixedValue()); + // We can only transform this if it is safe to push the loads into the // predecessor blocks. The only thing to watch out for is that we can't put // a possibly trapping load in the predecessor if it is a critical edge. @@ -1259,7 +1355,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) { // If this pointer is always safe to load, or if we can prove that there // is already a load in the block, then we can move the load to the pred // block. - if (isSafeToLoadUnconditionally(InVal, MaxAlign, MaxSize, DL, TI)) + if (isSafeToLoadUnconditionally(InVal, MaxAlign, LoadSize, DL, TI)) continue; return false; @@ -1321,102 +1417,241 @@ static void speculatePHINodeLoads(IRBuilderTy &IRB, PHINode &PN) { PN.eraseFromParent(); } -/// Select instructions that use an alloca and are subsequently loaded can be -/// rewritten to load both input pointers and then select between the result, -/// allowing the load of the alloca to be promoted. -/// From this: -/// %P2 = select i1 %cond, i32* %Alloca, i32* %Other -/// %V = load i32* %P2 -/// to: -/// %V1 = load i32* %Alloca -> will be mem2reg'd -/// %V2 = load i32* %Other -/// %V = select i1 %cond, i32 %V1, i32 %V2 -/// -/// We can do this to a select if its only uses are loads and if the operand -/// to the select can be loaded unconditionally. If found an intervening bitcast -/// with a single use of the load, allow the promotion. 
-static bool isSafeSelectToSpeculate(SelectInst &SI) { - Value *TValue = SI.getTrueValue(); - Value *FValue = SI.getFalseValue(); +sroa::SelectHandSpeculativity & +sroa::SelectHandSpeculativity::setAsSpeculatable(bool isTrueVal) { + if (isTrueVal) + Bitfield::set<sroa::SelectHandSpeculativity::TrueVal>(Storage, true); + else + Bitfield::set<sroa::SelectHandSpeculativity::FalseVal>(Storage, true); + return *this; +} + +bool sroa::SelectHandSpeculativity::isSpeculatable(bool isTrueVal) const { + return isTrueVal + ? Bitfield::get<sroa::SelectHandSpeculativity::TrueVal>(Storage) + : Bitfield::get<sroa::SelectHandSpeculativity::FalseVal>(Storage); +} + +bool sroa::SelectHandSpeculativity::areAllSpeculatable() const { + return isSpeculatable(/*isTrueVal=*/true) && + isSpeculatable(/*isTrueVal=*/false); +} + +bool sroa::SelectHandSpeculativity::areAnySpeculatable() const { + return isSpeculatable(/*isTrueVal=*/true) || + isSpeculatable(/*isTrueVal=*/false); +} +bool sroa::SelectHandSpeculativity::areNoneSpeculatable() const { + return !areAnySpeculatable(); +} + +static sroa::SelectHandSpeculativity +isSafeLoadOfSelectToSpeculate(LoadInst &LI, SelectInst &SI, bool PreserveCFG) { + assert(LI.isSimple() && "Only for simple loads"); + sroa::SelectHandSpeculativity Spec; + const DataLayout &DL = SI.getModule()->getDataLayout(); + for (Value *Value : {SI.getTrueValue(), SI.getFalseValue()}) + if (isSafeToLoadUnconditionally(Value, LI.getType(), LI.getAlign(), DL, + &LI)) + Spec.setAsSpeculatable(/*isTrueVal=*/Value == SI.getTrueValue()); + else if (PreserveCFG) + return Spec; + + return Spec; +} + +std::optional<sroa::RewriteableMemOps> +SROAPass::isSafeSelectToSpeculate(SelectInst &SI, bool PreserveCFG) { + RewriteableMemOps Ops; for (User *U : SI.users()) { - LoadInst *LI; - BitCastInst *BC = dyn_cast<BitCastInst>(U); - if (BC && BC->hasOneUse()) - LI = dyn_cast<LoadInst>(*BC->user_begin()); - else - LI = dyn_cast<LoadInst>(U); + if (auto *BC = dyn_cast<BitCastInst>(U); BC && BC->hasOneUse()) + U = *BC->user_begin(); + + if (auto *Store = dyn_cast<StoreInst>(U)) { + // Note that atomic stores can be transformed; atomic semantics do not + // have any meaning for a local alloca. Stores are not speculatable, + // however, so if we can't turn it into a predicated store, we are done. + if (Store->isVolatile() || PreserveCFG) + return {}; // Give up on this `select`. + Ops.emplace_back(Store); + continue; + } - if (!LI || !LI->isSimple()) - return false; + auto *LI = dyn_cast<LoadInst>(U); - // Both operands to the select need to be dereferenceable, either - // absolutely (e.g. allocas) or at this point because we can see other - // accesses to it. - if (!isSafeToLoadUnconditionally(TValue, LI->getType(), - LI->getAlign(), DL, LI)) - return false; - if (!isSafeToLoadUnconditionally(FValue, LI->getType(), - LI->getAlign(), DL, LI)) - return false; + // Note that atomic loads can be transformed; + // atomic semantics do not have any meaning for a local alloca. + if (!LI || LI->isVolatile()) + return {}; // Give up on this `select`. + + PossiblySpeculatableLoad Load(LI); + if (!LI->isSimple()) { + // If the `load` is not simple, we can't speculatively execute it, + // but we could handle this via a CFG modification. But can we? + if (PreserveCFG) + return {}; // Give up on this `select`. 
+ Ops.emplace_back(Load); + continue; + } + + sroa::SelectHandSpeculativity Spec = + isSafeLoadOfSelectToSpeculate(*LI, SI, PreserveCFG); + if (PreserveCFG && !Spec.areAllSpeculatable()) + return {}; // Give up on this `select`. + + Load.setInt(Spec); + Ops.emplace_back(Load); } - return true; + return Ops; } -static void speculateSelectInstLoads(IRBuilderTy &IRB, SelectInst &SI) { - LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); +static void speculateSelectInstLoads(SelectInst &SI, LoadInst &LI, + IRBuilderTy &IRB) { + LLVM_DEBUG(dbgs() << " original load: " << SI << "\n"); - IRB.SetInsertPoint(&SI); Value *TV = SI.getTrueValue(); Value *FV = SI.getFalseValue(); - // Replace the loads of the select with a select of two loads. - while (!SI.use_empty()) { - LoadInst *LI; - BitCastInst *BC = dyn_cast<BitCastInst>(SI.user_back()); - if (BC) { - assert(BC->hasOneUse() && "Bitcast should have a single use."); - LI = cast<LoadInst>(BC->user_back()); - } else { - LI = cast<LoadInst>(SI.user_back()); - } + // Replace the given load of the select with a select of two loads. - assert(LI->isSimple() && "We only speculate simple loads"); + assert(LI.isSimple() && "We only speculate simple loads"); - IRB.SetInsertPoint(LI); - Value *NewTV = - BC ? IRB.CreateBitCast(TV, BC->getType(), TV->getName() + ".sroa.cast") - : TV; - Value *NewFV = - BC ? IRB.CreateBitCast(FV, BC->getType(), FV->getName() + ".sroa.cast") - : FV; - LoadInst *TL = IRB.CreateLoad(LI->getType(), NewTV, - LI->getName() + ".sroa.speculate.load.true"); - LoadInst *FL = IRB.CreateLoad(LI->getType(), NewFV, - LI->getName() + ".sroa.speculate.load.false"); - NumLoadsSpeculated += 2; - - // Transfer alignment and AA info if present. - TL->setAlignment(LI->getAlign()); - FL->setAlignment(LI->getAlign()); - - AAMDNodes Tags = LI->getAAMetadata(); - if (Tags) { - TL->setAAMetadata(Tags); - FL->setAAMetadata(Tags); - } + IRB.SetInsertPoint(&LI); + + if (auto *TypedPtrTy = LI.getPointerOperandType(); + !TypedPtrTy->isOpaquePointerTy() && SI.getType() != TypedPtrTy) { + TV = IRB.CreateBitOrPointerCast(TV, TypedPtrTy, ""); + FV = IRB.CreateBitOrPointerCast(FV, TypedPtrTy, ""); + } - Value *V = IRB.CreateSelect(SI.getCondition(), TL, FL, - LI->getName() + ".sroa.speculated"); + LoadInst *TL = + IRB.CreateAlignedLoad(LI.getType(), TV, LI.getAlign(), + LI.getName() + ".sroa.speculate.load.true"); + LoadInst *FL = + IRB.CreateAlignedLoad(LI.getType(), FV, LI.getAlign(), + LI.getName() + ".sroa.speculate.load.false"); + NumLoadsSpeculated += 2; + + // Transfer alignment and AA info if present. 
+ TL->setAlignment(LI.getAlign()); + FL->setAlignment(LI.getAlign()); + + AAMDNodes Tags = LI.getAAMetadata(); + if (Tags) { + TL->setAAMetadata(Tags); + FL->setAAMetadata(Tags); + } - LLVM_DEBUG(dbgs() << " speculated to: " << *V << "\n"); - LI->replaceAllUsesWith(V); - LI->eraseFromParent(); - if (BC) - BC->eraseFromParent(); + Value *V = IRB.CreateSelect(SI.getCondition(), TL, FL, + LI.getName() + ".sroa.speculated"); + + LLVM_DEBUG(dbgs() << " speculated to: " << *V << "\n"); + LI.replaceAllUsesWith(V); +} + +template <typename T> +static void rewriteMemOpOfSelect(SelectInst &SI, T &I, + sroa::SelectHandSpeculativity Spec, + DomTreeUpdater &DTU) { + assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Only for load and store!"); + LLVM_DEBUG(dbgs() << " original mem op: " << I << "\n"); + BasicBlock *Head = I.getParent(); + Instruction *ThenTerm = nullptr; + Instruction *ElseTerm = nullptr; + if (Spec.areNoneSpeculatable()) + SplitBlockAndInsertIfThenElse(SI.getCondition(), &I, &ThenTerm, &ElseTerm, + SI.getMetadata(LLVMContext::MD_prof), &DTU); + else { + SplitBlockAndInsertIfThen(SI.getCondition(), &I, /*Unreachable=*/false, + SI.getMetadata(LLVMContext::MD_prof), &DTU, + /*LI=*/nullptr, /*ThenBlock=*/nullptr); + if (Spec.isSpeculatable(/*isTrueVal=*/true)) + cast<BranchInst>(Head->getTerminator())->swapSuccessors(); + } + auto *HeadBI = cast<BranchInst>(Head->getTerminator()); + Spec = {}; // Do not use `Spec` beyond this point. + BasicBlock *Tail = I.getParent(); + Tail->setName(Head->getName() + ".cont"); + PHINode *PN; + if (isa<LoadInst>(I)) + PN = PHINode::Create(I.getType(), 2, "", &I); + for (BasicBlock *SuccBB : successors(Head)) { + bool IsThen = SuccBB == HeadBI->getSuccessor(0); + int SuccIdx = IsThen ? 0 : 1; + auto *NewMemOpBB = SuccBB == Tail ? Head : SuccBB; + if (NewMemOpBB != Head) { + NewMemOpBB->setName(Head->getName() + (IsThen ? ".then" : ".else")); + if (isa<LoadInst>(I)) + ++NumLoadsPredicated; + else + ++NumStoresPredicated; + } else + ++NumLoadsSpeculated; + auto &CondMemOp = cast<T>(*I.clone()); + CondMemOp.insertBefore(NewMemOpBB->getTerminator()); + Value *Ptr = SI.getOperand(1 + SuccIdx); + if (auto *PtrTy = Ptr->getType(); + !PtrTy->isOpaquePointerTy() && + PtrTy != CondMemOp.getPointerOperandType()) + Ptr = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( + Ptr, CondMemOp.getPointerOperandType(), "", &CondMemOp); + CondMemOp.setOperand(I.getPointerOperandIndex(), Ptr); + if (isa<LoadInst>(I)) { + CondMemOp.setName(I.getName() + (IsThen ? 
".then" : ".else") + ".val"); + PN->addIncoming(&CondMemOp, NewMemOpBB); + } else + LLVM_DEBUG(dbgs() << " to: " << CondMemOp << "\n"); + } + if (isa<LoadInst>(I)) { + PN->takeName(&I); + LLVM_DEBUG(dbgs() << " to: " << *PN << "\n"); + I.replaceAllUsesWith(PN); + } +} + +static void rewriteMemOpOfSelect(SelectInst &SelInst, Instruction &I, + sroa::SelectHandSpeculativity Spec, + DomTreeUpdater &DTU) { + if (auto *LI = dyn_cast<LoadInst>(&I)) + rewriteMemOpOfSelect(SelInst, *LI, Spec, DTU); + else if (auto *SI = dyn_cast<StoreInst>(&I)) + rewriteMemOpOfSelect(SelInst, *SI, Spec, DTU); + else + llvm_unreachable_internal("Only for load and store."); +} + +static bool rewriteSelectInstMemOps(SelectInst &SI, + const sroa::RewriteableMemOps &Ops, + IRBuilderTy &IRB, DomTreeUpdater *DTU) { + bool CFGChanged = false; + LLVM_DEBUG(dbgs() << " original select: " << SI << "\n"); + + for (const RewriteableMemOp &Op : Ops) { + sroa::SelectHandSpeculativity Spec; + Instruction *I; + if (auto *const *US = std::get_if<UnspeculatableStore>(&Op)) { + I = *US; + } else { + auto PSL = std::get<PossiblySpeculatableLoad>(Op); + I = PSL.getPointer(); + Spec = PSL.getInt(); + } + if (Spec.areAllSpeculatable()) { + speculateSelectInstLoads(SI, cast<LoadInst>(*I), IRB); + } else { + assert(DTU && "Should not get here when not allowed to modify the CFG!"); + rewriteMemOpOfSelect(SI, *I, Spec, *DTU); + CFGChanged = true; + } + I->eraseFromParent(); } + + for (User *U : make_early_inc_range(SI.users())) + cast<BitCastInst>(U)->eraseFromParent(); SI.eraseFromParent(); + return CFGChanged; } /// Build a GEP out of a base pointer and indices. @@ -1678,8 +1913,8 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) { return false; } - if (DL.getTypeSizeInBits(NewTy).getFixedSize() != - DL.getTypeSizeInBits(OldTy).getFixedSize()) + if (DL.getTypeSizeInBits(NewTy).getFixedValue() != + DL.getTypeSizeInBits(OldTy).getFixedValue()) return false; if (!NewTy->isSingleValueType() || !OldTy->isSingleValueType()) return false; @@ -1714,6 +1949,9 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) { return false; } + if (OldTy->isTargetExtTy() || NewTy->isTargetExtTy()) + return false; + return true; } @@ -1847,6 +2085,34 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S, return true; } +/// Test whether a vector type is viable for promotion. +/// +/// This implements the necessary checking for \c isVectorPromotionViable over +/// all slices of the alloca for the given VectorType. +static bool checkVectorTypeForPromotion(Partition &P, VectorType *VTy, + const DataLayout &DL) { + uint64_t ElementSize = + DL.getTypeSizeInBits(VTy->getElementType()).getFixedValue(); + + // While the definition of LLVM vectors is bitpacked, we don't support sizes + // that aren't byte sized. + if (ElementSize % 8) + return false; + assert((DL.getTypeSizeInBits(VTy).getFixedValue() % 8) == 0 && + "vector size not a multiple of element size?"); + ElementSize /= 8; + + for (const Slice &S : P) + if (!isVectorPromotionViableForSlice(P, S, VTy, ElementSize, DL)) + return false; + + for (const Slice *S : P.splitSliceTails()) + if (!isVectorPromotionViableForSlice(P, *S, VTy, ElementSize, DL)) + return false; + + return true; +} + /// Test whether the given alloca partitioning and range of slices can be /// promoted to a vector. /// @@ -1861,23 +2127,36 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // we have different element types. 
SmallVector<VectorType *, 4> CandidateTys; Type *CommonEltTy = nullptr; + VectorType *CommonVecPtrTy = nullptr; + bool HaveVecPtrTy = false; bool HaveCommonEltTy = true; + bool HaveCommonVecPtrTy = true; auto CheckCandidateType = [&](Type *Ty) { if (auto *VTy = dyn_cast<VectorType>(Ty)) { // Return if bitcast to vectors is different for total size in bits. if (!CandidateTys.empty()) { VectorType *V = CandidateTys[0]; - if (DL.getTypeSizeInBits(VTy).getFixedSize() != - DL.getTypeSizeInBits(V).getFixedSize()) { + if (DL.getTypeSizeInBits(VTy).getFixedValue() != + DL.getTypeSizeInBits(V).getFixedValue()) { CandidateTys.clear(); return; } } CandidateTys.push_back(VTy); + Type *EltTy = VTy->getElementType(); + if (!CommonEltTy) - CommonEltTy = VTy->getElementType(); - else if (CommonEltTy != VTy->getElementType()) + CommonEltTy = EltTy; + else if (CommonEltTy != EltTy) HaveCommonEltTy = false; + + if (EltTy->isPointerTy()) { + HaveVecPtrTy = true; + if (!CommonVecPtrTy) + CommonVecPtrTy = VTy; + else if (CommonVecPtrTy != VTy) + HaveCommonVecPtrTy = false; + } } }; // Consider any loads or stores that are the exact size of the slice. @@ -1894,25 +2173,32 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { if (CandidateTys.empty()) return nullptr; - // Remove non-integer vector types if we had multiple common element types. - // FIXME: It'd be nice to replace them with integer vector types, but we can't - // do that until all the backends are known to produce good code for all - // integer vector types. - if (!HaveCommonEltTy) { - llvm::erase_if(CandidateTys, [](VectorType *VTy) { - return !VTy->getElementType()->isIntegerTy(); - }); - - // If there were no integer vector types, give up. - if (CandidateTys.empty()) - return nullptr; + // Pointer-ness is sticky, if we had a vector-of-pointers candidate type, + // then we should choose it, not some other alternative. + // But, we can't perform a no-op pointer address space change via bitcast, + // so if we didn't have a common pointer element type, bail. + if (HaveVecPtrTy && !HaveCommonVecPtrTy) + return nullptr; + + // Try to pick the "best" element type out of the choices. + if (!HaveCommonEltTy && HaveVecPtrTy) { + // If there was a pointer element type, there's really only one choice. + CandidateTys.clear(); + CandidateTys.push_back(CommonVecPtrTy); + } else if (!HaveCommonEltTy && !HaveVecPtrTy) { + // Integer-ify vector types. + for (VectorType *&VTy : CandidateTys) { + if (!VTy->getElementType()->isIntegerTy()) + VTy = cast<VectorType>(VTy->getWithNewType(IntegerType::getIntNTy( + VTy->getContext(), VTy->getScalarSizeInBits()))); + } // Rank the remaining candidate vector types. This is easy because we know // they're all integer vectors. We sort by ascending number of elements. auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) { (void)DL; - assert(DL.getTypeSizeInBits(RHSTy).getFixedSize() == - DL.getTypeSizeInBits(LHSTy).getFixedSize() && + assert(DL.getTypeSizeInBits(RHSTy).getFixedValue() == + DL.getTypeSizeInBits(LHSTy).getFixedValue() && "Cannot have vector types of different sizes!"); assert(RHSTy->getElementType()->isIntegerTy() && "All non-integer types eliminated!"); @@ -1939,31 +2225,15 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { CandidateTys.resize(1); } - // Try each vector type, and return the one which works. 
- auto CheckVectorTypeForPromotion = [&](VectorType *VTy) { - uint64_t ElementSize = - DL.getTypeSizeInBits(VTy->getElementType()).getFixedSize(); - - // While the definition of LLVM vectors is bitpacked, we don't support sizes - // that aren't byte sized. - if (ElementSize % 8) - return false; - assert((DL.getTypeSizeInBits(VTy).getFixedSize() % 8) == 0 && - "vector size not a multiple of element size?"); - ElementSize /= 8; - - for (const Slice &S : P) - if (!isVectorPromotionViableForSlice(P, S, VTy, ElementSize, DL)) - return false; - - for (const Slice *S : P.splitSliceTails()) - if (!isVectorPromotionViableForSlice(P, *S, VTy, ElementSize, DL)) - return false; + // FIXME: hack. Do we have a named constant for this? + // SDAG SDNode can't have more than 65535 operands. + llvm::erase_if(CandidateTys, [](VectorType *VTy) { + return cast<FixedVectorType>(VTy)->getNumElements() > + std::numeric_limits<unsigned short>::max(); + }); - return true; - }; for (VectorType *VTy : CandidateTys) - if (CheckVectorTypeForPromotion(VTy)) + if (checkVectorTypeForPromotion(P, VTy, DL)) return VTy; return nullptr; @@ -1978,7 +2248,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, Type *AllocaTy, const DataLayout &DL, bool &WholeAllocaOp) { - uint64_t Size = DL.getTypeStoreSize(AllocaTy).getFixedSize(); + uint64_t Size = DL.getTypeStoreSize(AllocaTy).getFixedValue(); uint64_t RelBegin = S.beginOffset() - AllocBeginOffset; uint64_t RelEnd = S.endOffset() - AllocBeginOffset; @@ -2003,7 +2273,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (LI->isVolatile()) return false; // We can't handle loads that extend past the allocated memory. - if (DL.getTypeStoreSize(LI->getType()).getFixedSize() > Size) + if (DL.getTypeStoreSize(LI->getType()).getFixedValue() > Size) return false; // So far, AllocaSliceRewriter does not support widening split slice tails // in rewriteIntegerLoad. @@ -2015,7 +2285,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (!isa<VectorType>(LI->getType()) && RelBegin == 0 && RelEnd == Size) WholeAllocaOp = true; if (IntegerType *ITy = dyn_cast<IntegerType>(LI->getType())) { - if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize()) + if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedValue()) return false; } else if (RelBegin != 0 || RelEnd != Size || !canConvertValue(DL, AllocaTy, LI->getType())) { @@ -2028,7 +2298,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (SI->isVolatile()) return false; // We can't handle stores that extend past the allocated memory. - if (DL.getTypeStoreSize(ValueTy).getFixedSize() > Size) + if (DL.getTypeStoreSize(ValueTy).getFixedValue() > Size) return false; // So far, AllocaSliceRewriter does not support widening split slice tails // in rewriteIntegerStore. @@ -2040,7 +2310,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (!isa<VectorType>(ValueTy) && RelBegin == 0 && RelEnd == Size) WholeAllocaOp = true; if (IntegerType *ITy = dyn_cast<IntegerType>(ValueTy)) { - if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize()) + if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedValue()) return false; } else if (RelBegin != 0 || RelEnd != Size || !canConvertValue(DL, ValueTy, AllocaTy)) { @@ -2068,13 +2338,13 @@ static bool isIntegerWideningViableForSlice(const Slice &S, /// promote the resulting alloca. 
static bool isIntegerWideningViable(Partition &P, Type *AllocaTy, const DataLayout &DL) { - uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy).getFixedSize(); + uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy).getFixedValue(); // Don't create integer types larger than the maximum bitwidth. if (SizeInBits > IntegerType::MAX_INT_BITS) return false; // Don't try to handle allocas with bit-padding. - if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy).getFixedSize()) + if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy).getFixedValue()) return false; // We need to ensure that an integer type with the appropriate bitwidth can @@ -2112,13 +2382,13 @@ static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V, const Twine &Name) { LLVM_DEBUG(dbgs() << " start: " << *V << "\n"); IntegerType *IntTy = cast<IntegerType>(V->getType()); - assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <= - DL.getTypeStoreSize(IntTy).getFixedSize() && + assert(DL.getTypeStoreSize(Ty).getFixedValue() + Offset <= + DL.getTypeStoreSize(IntTy).getFixedValue() && "Element extends past full value"); uint64_t ShAmt = 8 * Offset; if (DL.isBigEndian()) - ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() - - DL.getTypeStoreSize(Ty).getFixedSize() - Offset); + ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedValue() - + DL.getTypeStoreSize(Ty).getFixedValue() - Offset); if (ShAmt) { V = IRB.CreateLShr(V, ShAmt, Name + ".shift"); LLVM_DEBUG(dbgs() << " shifted: " << *V << "\n"); @@ -2143,13 +2413,13 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old, V = IRB.CreateZExt(V, IntTy, Name + ".ext"); LLVM_DEBUG(dbgs() << " extended: " << *V << "\n"); } - assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <= - DL.getTypeStoreSize(IntTy).getFixedSize() && + assert(DL.getTypeStoreSize(Ty).getFixedValue() + Offset <= + DL.getTypeStoreSize(IntTy).getFixedValue() && "Element store outside of alloca store"); uint64_t ShAmt = 8 * Offset; if (DL.isBigEndian()) - ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() - - DL.getTypeStoreSize(Ty).getFixedSize() - Offset); + ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedValue() - + DL.getTypeStoreSize(Ty).getFixedValue() - Offset); if (ShAmt) { V = IRB.CreateShl(V, ShAmt, Name + ".shift"); LLVM_DEBUG(dbgs() << " shifted: " << *V << "\n"); @@ -2284,6 +2554,7 @@ class llvm::sroa::AllocaSliceRewriter // original alloca. uint64_t NewBeginOffset = 0, NewEndOffset = 0; + uint64_t RelativeOffset = 0; uint64_t SliceSize = 0; bool IsSplittable = false; bool IsSplit = false; @@ -2298,6 +2569,16 @@ class llvm::sroa::AllocaSliceRewriter // the insertion point is set to point to the user. IRBuilderTy IRB; + // Return the new alloca, addrspacecasted if required to avoid changing the + // addrspace of a volatile access. + Value *getPtrToNewAI(unsigned AddrSpace, bool IsVolatile) { + if (!IsVolatile || AddrSpace == NewAI.getType()->getPointerAddressSpace()) + return &NewAI; + + Type *AccessTy = NewAI.getAllocatedType()->getPointerTo(AddrSpace); + return IRB.CreateAddrSpaceCast(&NewAI, AccessTy); + } + public: AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROAPass &Pass, AllocaInst &OldAI, AllocaInst &NewAI, @@ -2314,16 +2595,16 @@ public: IsIntegerPromotable ? Type::getIntNTy(NewAI.getContext(), DL.getTypeSizeInBits(NewAI.getAllocatedType()) - .getFixedSize()) + .getFixedValue()) : nullptr), VecTy(PromotableVecTy), ElementTy(VecTy ? VecTy->getElementType() : nullptr), - ElementSize(VecTy ? 
DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8 + ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy).getFixedValue() / 8 : 0), PHIUsers(PHIUsers), SelectUsers(SelectUsers), IRB(NewAI.getContext(), ConstantFolder()) { if (VecTy) { - assert((DL.getTypeSizeInBits(ElementTy).getFixedSize() % 8) == 0 && + assert((DL.getTypeSizeInBits(ElementTy).getFixedValue() % 8) == 0 && "Only multiple-of-8 sized vector elements are viable"); ++NumVectorized; } @@ -2347,8 +2628,14 @@ public: NewBeginOffset = std::max(BeginOffset, NewAllocaBeginOffset); NewEndOffset = std::min(EndOffset, NewAllocaEndOffset); + RelativeOffset = NewBeginOffset - BeginOffset; SliceSize = NewEndOffset - NewBeginOffset; - + LLVM_DEBUG(dbgs() << " Begin:(" << BeginOffset << ", " << EndOffset + << ") NewBegin:(" << NewBeginOffset << ", " + << NewEndOffset << ") NewAllocaBegin:(" + << NewAllocaBeginOffset << ", " << NewAllocaEndOffset + << ")\n"); + assert(IsSplit || RelativeOffset == 0); OldUse = I->getUse(); OldPtr = cast<Instruction>(OldUse->get()); @@ -2486,7 +2773,7 @@ private: Type *TargetTy = IsSplit ? Type::getIntNTy(LI.getContext(), SliceSize * 8) : LI.getType(); const bool IsLoadPastEnd = - DL.getTypeStoreSize(TargetTy).getFixedSize() > SliceSize; + DL.getTypeStoreSize(TargetTy).getFixedValue() > SliceSize; bool IsPtrAdjusted = false; Value *V; if (VecTy) { @@ -2498,28 +2785,24 @@ private: (canConvertValue(DL, NewAllocaTy, TargetTy) || (IsLoadPastEnd && NewAllocaTy->isIntegerTy() && TargetTy->isIntegerTy()))) { - LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + Value *NewPtr = + getPtrToNewAI(LI.getPointerAddressSpace(), LI.isVolatile()); + LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), NewPtr, NewAI.getAlign(), LI.isVolatile(), LI.getName()); - if (AATags) - NewLI->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); if (LI.isVolatile()) NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); if (NewLI->isAtomic()) NewLI->setAlignment(LI.getAlign()); - // Any !nonnull metadata or !range metadata on the old load is also valid - // on the new load. This is even true in some cases even when the loads - // are different types, for example by mapping !nonnull metadata to - // !range metadata by modeling the null pointer constant converted to the - // integer type. - // FIXME: Add support for range metadata here. Currently the utilities - // for this don't propagate range metadata in trivial cases from one - // integer load to another, don't handle non-addrspace-0 null pointers - // correctly, and don't have any support for mapping ranges as the - // integer type becomes winder or narrower. - if (MDNode *N = LI.getMetadata(LLVMContext::MD_nonnull)) - copyNonnullMetadata(LI, N, *NewLI); + // Copy any metadata that is valid for the new load. This may require + // conversion to a different kind of metadata, e.g. !nonnull might change + // to !range or vice versa. + copyMetadataForLoad(*NewLI, LI); + + // Do this after copyMetadataForLoad() to preserve the TBAA shift. 
+ if (AATags) + NewLI->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); // Try to preserve nonnull metadata V = NewLI; @@ -2556,7 +2839,7 @@ private: assert(!LI.isVolatile()); assert(LI.getType()->isIntegerTy() && "Only integer type loads and stores are split"); - assert(SliceSize < DL.getTypeStoreSize(LI.getType()).getFixedSize() && + assert(SliceSize < DL.getTypeStoreSize(LI.getType()).getFixedValue() && "Split load isn't smaller than original load"); assert(DL.typeSizeEqualsStoreSize(LI.getType()) && "Non-byte-multiple bit width"); @@ -2586,6 +2869,9 @@ private: bool rewriteVectorizedStoreInst(Value *V, StoreInst &SI, Value *OldOp, AAMDNodes AATags) { + // Capture V for the purpose of debug-info accounting once it's converted + // to a vector store. + Value *OrigV = V; if (V->getType() != VecTy) { unsigned BeginIndex = getIndex(NewBeginOffset); unsigned EndIndex = getIndex(NewEndOffset); @@ -2611,6 +2897,9 @@ private: Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); Pass.DeadInsts.push_back(&SI); + // NOTE: Careful to use OrigV rather than V. + migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, Store, + Store->getPointerOperand(), OrigV, DL); LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return true; } @@ -2618,7 +2907,7 @@ private: bool rewriteIntegerStore(Value *V, StoreInst &SI, AAMDNodes AATags) { assert(IntTy && "We cannot extract an integer from the alloca"); assert(!SI.isVolatile()); - if (DL.getTypeSizeInBits(V->getType()).getFixedSize() != + if (DL.getTypeSizeInBits(V->getType()).getFixedValue() != IntTy->getBitWidth()) { Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, NewAI.getAlign(), "oldload"); @@ -2633,6 +2922,10 @@ private: LLVMContext::MD_access_group}); if (AATags) Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + + migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, Store, + Store->getPointerOperand(), Store->getValueOperand(), DL); + Pass.DeadInsts.push_back(&SI); LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return true; @@ -2652,7 +2945,7 @@ private: if (AllocaInst *AI = dyn_cast<AllocaInst>(V->stripInBoundsOffsets())) Pass.PostPromotionWorklist.insert(AI); - if (SliceSize < DL.getTypeStoreSize(V->getType()).getFixedSize()) { + if (SliceSize < DL.getTypeStoreSize(V->getType()).getFixedValue()) { assert(!SI.isVolatile()); assert(V->getType()->isIntegerTy() && "Only integer type loads and stores are split"); @@ -2669,7 +2962,7 @@ private: return rewriteIntegerStore(V, SI, AATags); const bool IsStorePastEnd = - DL.getTypeStoreSize(V->getType()).getFixedSize() > SliceSize; + DL.getTypeStoreSize(V->getType()).getFixedValue() > SliceSize; StoreInst *NewSI; if (NewBeginOffset == NewAllocaBeginOffset && NewEndOffset == NewAllocaEndOffset && @@ -2689,8 +2982,11 @@ private: } V = convertValue(DL, IRB, V, NewAllocaTy); + Value *NewPtr = + getPtrToNewAI(SI.getPointerAddressSpace(), SI.isVolatile()); + NewSI = - IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign(), SI.isVolatile()); + IRB.CreateAlignedStore(V, NewPtr, NewAI.getAlign(), SI.isVolatile()); } else { unsigned AS = SI.getPointerAddressSpace(); Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS)); @@ -2705,6 +3001,10 @@ private: NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); if (NewSI->isAtomic()) NewSI->setAlignment(SI.getAlign()); + + migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &SI, NewSI, + NewSI->getPointerOperand(), NewSI->getValueOperand(), DL); + 
Pass.DeadInsts.push_back(&SI); deleteIfTriviallyDead(OldOp); @@ -2760,7 +3060,11 @@ private: assert(NewBeginOffset == BeginOffset); II.setDest(getNewAllocaSlicePtr(IRB, OldPtr->getType())); II.setDestAlignment(getSliceAlign()); - + // In theory we should call migrateDebugInfo here. However, we do not + // emit dbg.assign intrinsics for mem intrinsics storing through non- + // constant geps, or storing a variable number of bytes. + assert(at::getAssignmentMarkers(&II).empty() && + "AT: Unexpected link to non-const GEP"); deleteIfTriviallyDead(OldPtr); return false; } @@ -2785,7 +3089,7 @@ private: auto *Int8Ty = IntegerType::getInt8Ty(NewAI.getContext()); auto *SrcTy = FixedVectorType::get(Int8Ty, Len); return canConvertValue(DL, SrcTy, AllocaTy) && - DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy).getFixedSize()); + DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy).getFixedValue()); }(); // If this doesn't map cleanly onto the alloca type, and that type isn't @@ -2793,11 +3097,15 @@ private: if (!CanContinue) { Type *SizeTy = II.getLength()->getType(); Constant *Size = ConstantInt::get(SizeTy, NewEndOffset - NewBeginOffset); - CallInst *New = IRB.CreateMemSet( + MemIntrinsic *New = cast<MemIntrinsic>(IRB.CreateMemSet( getNewAllocaSlicePtr(IRB, OldPtr->getType()), II.getValue(), Size, - MaybeAlign(getSliceAlign()), II.isVolatile()); + MaybeAlign(getSliceAlign()), II.isVolatile())); if (AATags) New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + + migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New, + New->getRawDest(), nullptr, DL); + LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return false; } @@ -2821,7 +3129,7 @@ private: "Too many elements!"); Value *Splat = getIntegerSplat( - II.getValue(), DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8); + II.getValue(), DL.getTypeSizeInBits(ElementTy).getFixedValue() / 8); Splat = convertValue(DL, IRB, Splat, ElementTy); if (NumElements > 1) Splat = getVectorSplat(Splat, NumElements); @@ -2855,7 +3163,7 @@ private: assert(NewEndOffset == NewAllocaEndOffset); V = getIntegerSplat(II.getValue(), - DL.getTypeSizeInBits(ScalarTy).getFixedSize() / 8); + DL.getTypeSizeInBits(ScalarTy).getFixedValue() / 8); if (VectorType *AllocaVecTy = dyn_cast<VectorType>(AllocaTy)) V = getVectorSplat( V, cast<FixedVectorType>(AllocaVecTy)->getNumElements()); @@ -2863,12 +3171,17 @@ private: V = convertValue(DL, IRB, V, AllocaTy); } + Value *NewPtr = getPtrToNewAI(II.getDestAddressSpace(), II.isVolatile()); StoreInst *New = - IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign(), II.isVolatile()); + IRB.CreateAlignedStore(V, NewPtr, NewAI.getAlign(), II.isVolatile()); New->copyMetadata(II, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); if (AATags) New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + + migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New, + New->getPointerOperand(), V, DL); + LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return !II.isVolatile(); } @@ -2886,7 +3199,6 @@ private: (!IsDest && II.getRawSource() == OldPtr)); Align SliceAlign = getSliceAlign(); - // For unsplit intrinsics, we simply modify the source and destination // pointers in place. This isn't just an optimization, it is a matter of // correctness. With unsplit intrinsics we may be dealing with transfers @@ -2897,10 +3209,16 @@ private: if (!IsSplittable) { Value *AdjustedPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType()); if (IsDest) { + // Update the address component of linked dbg.assigns. 
+ for (auto *DAI : at::getAssignmentMarkers(&II)) { + if (any_of(DAI->location_ops(), + [&](Value *V) { return V == II.getDest(); }) || + DAI->getAddress() == II.getDest()) + DAI->replaceVariableLocationOp(II.getDest(), AdjustedPtr); + } II.setDest(AdjustedPtr); II.setDestAlignment(SliceAlign); - } - else { + } else { II.setSource(AdjustedPtr); II.setSourceAlignment(SliceAlign); } @@ -2921,7 +3239,7 @@ private: !VecTy && !IntTy && (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset || SliceSize != - DL.getTypeStoreSize(NewAI.getAllocatedType()).getFixedSize() || + DL.getTypeStoreSize(NewAI.getAllocatedType()).getFixedValue() || !NewAI.getAllocatedType()->isSingleValueType()); // If we're just going to emit a memcpy, the alloca hasn't changed, and the @@ -2989,6 +3307,9 @@ private: Size, II.isVolatile()); if (AATags) New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + + migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, New, + DestPtr, nullptr, DL); LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return false; } @@ -3017,14 +3338,22 @@ private: } OtherPtrTy = OtherTy->getPointerTo(OtherAS); - Value *SrcPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy, + Value *AdjPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy, OtherPtr->getName() + "."); MaybeAlign SrcAlign = OtherAlign; - Value *DstPtr = &NewAI; MaybeAlign DstAlign = SliceAlign; - if (!IsDest) { - std::swap(SrcPtr, DstPtr); + if (!IsDest) std::swap(SrcAlign, DstAlign); + + Value *SrcPtr; + Value *DstPtr; + + if (IsDest) { + DstPtr = getPtrToNewAI(II.getDestAddressSpace(), II.isVolatile()); + SrcPtr = AdjPtr; + } else { + DstPtr = AdjPtr; + SrcPtr = getPtrToNewAI(II.getSourceAddressSpace(), II.isVolatile()); } Value *Src; @@ -3067,6 +3396,9 @@ private: LLVMContext::MD_access_group}); if (AATags) Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + + migrateDebugInfo(&OldAI, RelativeOffset * 8, SliceSize * 8, &II, Store, + DstPtr, Src, DL); LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return !II.isVolatile(); } @@ -3404,12 +3736,13 @@ private: struct StoreOpSplitter : public OpSplitter<StoreOpSplitter> { StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy, - AAMDNodes AATags, Align BaseAlign, const DataLayout &DL, - IRBuilderTy &IRB) + AAMDNodes AATags, StoreInst *AggStore, Align BaseAlign, + const DataLayout &DL, IRBuilderTy &IRB) : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr, BaseTy, BaseAlign, DL, IRB), - AATags(AATags) {} + AATags(AATags), AggStore(AggStore) {} AAMDNodes AATags; + StoreInst *AggStore; /// Emit a leaf store of a single value. This is called at the leaves of the /// recursive emission to actually produce stores. void emitFunc(Type *Ty, Value *&Agg, Align Alignment, const Twine &Name) { @@ -3431,6 +3764,24 @@ private: GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset)) Store->setAAMetadata(AATags.shift(Offset.getZExtValue())); + // migrateDebugInfo requires the base Alloca. Walk to it from this gep. + // If we cannot (because there's an intervening non-const or unbounded + // gep) then we wouldn't expect to see dbg.assign intrinsics linked to + // this instruction. 
+ APInt OffsetInBytes(DL.getTypeSizeInBits(Ptr->getType()), false); + Value *Base = InBoundsGEP->stripAndAccumulateInBoundsConstantOffsets( + DL, OffsetInBytes); + if (auto *OldAI = dyn_cast<AllocaInst>(Base)) { + uint64_t SizeInBits = + DL.getTypeSizeInBits(Store->getValueOperand()->getType()); + migrateDebugInfo(OldAI, OffsetInBytes.getZExtValue() * 8, SizeInBits, + AggStore, Store, Store->getPointerOperand(), + Store->getValueOperand(), DL); + } else { + assert(at::getAssignmentMarkers(Store).empty() && + "AT: unexpected debug.assign linked to store through " + "unbounded GEP"); + } LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); } }; @@ -3444,7 +3795,7 @@ private: // We have an aggregate being stored, split it apart. LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); - StoreOpSplitter Splitter(&SI, *U, V->getType(), SI.getAAMetadata(), + StoreOpSplitter Splitter(&SI, *U, V->getType(), SI.getAAMetadata(), &SI, getAdjustedAlignment(&SI, 0), DL, IRB); Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); Visited.erase(&SI); @@ -3593,8 +3944,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) { if (Ty->isSingleValueType()) return Ty; - uint64_t AllocSize = DL.getTypeAllocSize(Ty).getFixedSize(); - uint64_t TypeSize = DL.getTypeSizeInBits(Ty).getFixedSize(); + uint64_t AllocSize = DL.getTypeAllocSize(Ty).getFixedValue(); + uint64_t TypeSize = DL.getTypeSizeInBits(Ty).getFixedValue(); Type *InnerTy; if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) { @@ -3607,8 +3958,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) { return Ty; } - if (AllocSize > DL.getTypeAllocSize(InnerTy).getFixedSize() || - TypeSize > DL.getTypeSizeInBits(InnerTy).getFixedSize()) + if (AllocSize > DL.getTypeAllocSize(InnerTy).getFixedValue() || + TypeSize > DL.getTypeSizeInBits(InnerTy).getFixedValue()) return Ty; return stripAggregateTypeWrapping(DL, InnerTy); @@ -3629,10 +3980,10 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) { /// return a type if necessary. static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, uint64_t Size) { - if (Offset == 0 && DL.getTypeAllocSize(Ty).getFixedSize() == Size) + if (Offset == 0 && DL.getTypeAllocSize(Ty).getFixedValue() == Size) return stripAggregateTypeWrapping(DL, Ty); - if (Offset > DL.getTypeAllocSize(Ty).getFixedSize() || - (DL.getTypeAllocSize(Ty).getFixedSize() - Offset) < Size) + if (Offset > DL.getTypeAllocSize(Ty).getFixedValue() || + (DL.getTypeAllocSize(Ty).getFixedValue() - Offset) < Size) return nullptr; if (isa<ArrayType>(Ty) || isa<VectorType>(Ty)) { @@ -3648,7 +3999,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, ElementTy = VT->getElementType(); TyNumElements = VT->getNumElements(); } - uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize(); + uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedValue(); uint64_t NumSkippedElements = Offset / ElementSize; if (NumSkippedElements >= TyNumElements) return nullptr; @@ -3688,7 +4039,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, Offset -= SL->getElementOffset(Index); Type *ElementTy = STy->getElementType(Index); - uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize(); + uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedValue(); if (Offset >= ElementSize) return nullptr; // The offset points into alignment padding. 
@@ -3723,7 +4074,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, // Try to build up a sub-structure. StructType *SubTy = - StructType::get(STy->getContext(), makeArrayRef(EI, EE), STy->isPacked()); + StructType::get(STy->getContext(), ArrayRef(EI, EE), STy->isPacked()); const StructLayout *SubSL = DL.getStructLayout(SubTy); if (Size != SubSL->getSizeInBytes()) return nullptr; // The sub-struct doesn't have quite the size needed. @@ -3741,20 +4092,15 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, /// the following: /// /// %a = alloca [12 x i8] -/// %gep1 = getelementptr [12 x i8]* %a, i32 0, i32 0 -/// %gep2 = getelementptr [12 x i8]* %a, i32 0, i32 4 -/// %gep3 = getelementptr [12 x i8]* %a, i32 0, i32 8 -/// %iptr1 = bitcast i8* %gep1 to i64* -/// %iptr2 = bitcast i8* %gep2 to i64* -/// %fptr1 = bitcast i8* %gep1 to float* -/// %fptr2 = bitcast i8* %gep2 to float* -/// %fptr3 = bitcast i8* %gep3 to float* -/// store float 0.0, float* %fptr1 -/// store float 1.0, float* %fptr2 -/// %v = load i64* %iptr1 -/// store i64 %v, i64* %iptr2 -/// %f1 = load float* %fptr2 -/// %f2 = load float* %fptr3 +/// %gep1 = getelementptr i8, ptr %a, i32 0 +/// %gep2 = getelementptr i8, ptr %a, i32 4 +/// %gep3 = getelementptr i8, ptr %a, i32 8 +/// store float 0.0, ptr %gep1 +/// store float 1.0, ptr %gep2 +/// %v = load i64, ptr %gep1 +/// store i64 %v, ptr %gep2 +/// %f1 = load float, ptr %gep2 +/// %f2 = load float, ptr %gep3 /// /// Here we want to form 3 partitions of the alloca, each 4 bytes large, and /// promote everything so we recover the 2 SSA values that should have been @@ -4050,7 +4396,8 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { getAdjustedAlignment(SI, PartOffset), /*IsVolatile*/ false); PStore->copyMetadata(*SI, {LLVMContext::MD_mem_parallel_loop_access, - LLVMContext::MD_access_group}); + LLVMContext::MD_access_group, + LLVMContext::MD_DIAssignID}); LLVM_DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); } @@ -4246,29 +4593,48 @@ AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // won't always succeed, in which case we fall back to a legal integer type // or an i8 array of an appropriate size. Type *SliceTy = nullptr; + VectorType *SliceVecTy = nullptr; const DataLayout &DL = AI.getModule()->getDataLayout(); std::pair<Type *, IntegerType *> CommonUseTy = findCommonType(P.begin(), P.end(), P.endOffset()); // Do all uses operate on the same type? if (CommonUseTy.first) - if (DL.getTypeAllocSize(CommonUseTy.first).getFixedSize() >= P.size()) + if (DL.getTypeAllocSize(CommonUseTy.first).getFixedValue() >= P.size()) { SliceTy = CommonUseTy.first; + SliceVecTy = dyn_cast<VectorType>(SliceTy); + } // If not, can we find an appropriate subtype in the original allocated type? if (!SliceTy) if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(), P.beginOffset(), P.size())) SliceTy = TypePartitionTy; + // If still not, can we use the largest bitwidth integer type used? 
if (!SliceTy && CommonUseTy.second) - if (DL.getTypeAllocSize(CommonUseTy.second).getFixedSize() >= P.size()) + if (DL.getTypeAllocSize(CommonUseTy.second).getFixedValue() >= P.size()) { SliceTy = CommonUseTy.second; + SliceVecTy = dyn_cast<VectorType>(SliceTy); + } if ((!SliceTy || (SliceTy->isArrayTy() && SliceTy->getArrayElementType()->isIntegerTy())) && - DL.isLegalInteger(P.size() * 8)) + DL.isLegalInteger(P.size() * 8)) { SliceTy = Type::getIntNTy(*C, P.size() * 8); + } + + // If the common use types are not viable for promotion then attempt to find + // another type that is viable. + if (SliceVecTy && !checkVectorTypeForPromotion(P, SliceVecTy, DL)) + if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(), + P.beginOffset(), P.size())) { + VectorType *TypePartitionVecTy = dyn_cast<VectorType>(TypePartitionTy); + if (TypePartitionVecTy && + checkVectorTypeForPromotion(P, TypePartitionVecTy, DL)) + SliceTy = TypePartitionTy; + } + if (!SliceTy) SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size()); - assert(DL.getTypeAllocSize(SliceTy).getFixedSize() >= P.size()); + assert(DL.getTypeAllocSize(SliceTy).getFixedValue() >= P.size()); bool IsIntegerPromotable = isIntegerWideningViable(P, SliceTy, DL); @@ -4296,7 +4662,7 @@ AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // the alloca's alignment unconstrained. const bool IsUnconstrained = Alignment <= DL.getABITypeAlign(SliceTy); NewAI = new AllocaInst( - SliceTy, AI.getType()->getAddressSpace(), nullptr, + SliceTy, AI.getAddressSpace(), nullptr, IsUnconstrained ? DL.getPrefTypeAlign(SliceTy) : Alignment, AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI); // Copy the old AI debug location over to the new one. @@ -4342,13 +4708,21 @@ AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS, break; } - for (SelectInst *Sel : SelectUsers) - if (!isSafeSelectToSpeculate(*Sel)) { + SmallVector<std::pair<SelectInst *, RewriteableMemOps>, 2> + NewSelectsToRewrite; + NewSelectsToRewrite.reserve(SelectUsers.size()); + for (SelectInst *Sel : SelectUsers) { + std::optional<RewriteableMemOps> Ops = + isSafeSelectToSpeculate(*Sel, PreserveCFG); + if (!Ops) { Promotable = false; PHIUsers.clear(); SelectUsers.clear(); + NewSelectsToRewrite.clear(); break; } + NewSelectsToRewrite.emplace_back(std::make_pair(Sel, *Ops)); + } if (Promotable) { for (Use *U : AS.getDeadUsesIfPromotable()) { @@ -4367,8 +4741,12 @@ AllocaInst *SROAPass::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // next iteration. 
for (PHINode *PHIUser : PHIUsers) SpeculatablePHIs.insert(PHIUser); - for (SelectInst *SelectUser : SelectUsers) - SpeculatableSelects.insert(SelectUser); + SelectsToRewrite.reserve(SelectsToRewrite.size() + + NewSelectsToRewrite.size()); + for (auto &&KV : llvm::make_range( + std::make_move_iterator(NewSelectsToRewrite.begin()), + std::make_move_iterator(NewSelectsToRewrite.end()))) + SelectsToRewrite.insert(std::move(KV)); Worklist.insert(NewAI); } } else { @@ -4412,7 +4790,7 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { bool IsSorted = true; uint64_t AllocaSize = - DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize(); + DL.getTypeAllocSize(AI.getAllocatedType()).getFixedValue(); const uint64_t MaxBitVectorSize = 1024; if (AllocaSize <= MaxBitVectorSize) { // If a byte boundary is included in any load or store, a slice starting or @@ -4477,7 +4855,7 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { if (NewAI != &AI) { uint64_t SizeOfByte = 8; uint64_t AllocaSize = - DL.getTypeSizeInBits(NewAI->getAllocatedType()).getFixedSize(); + DL.getTypeSizeInBits(NewAI->getAllocatedType()).getFixedValue(); // Don't include any padding. uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte); Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size)); @@ -4492,11 +4870,13 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { // Migrate debug information from the old alloca to the new alloca(s) // and the individual partitions. TinyPtrVector<DbgVariableIntrinsic *> DbgDeclares = FindDbgAddrUses(&AI); + for (auto *DbgAssign : at::getAssignmentMarkers(&AI)) + DbgDeclares.push_back(DbgAssign); for (DbgVariableIntrinsic *DbgDeclare : DbgDeclares) { auto *Expr = DbgDeclare->getExpression(); DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); uint64_t AllocaSize = - DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedSize(); + DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedValue(); for (auto Fragment : Fragments) { // Create a fragment expression describing the new partition or reuse AI's // expression if there is only one partition. @@ -4511,9 +4891,10 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { if (ExprFragment) { uint64_t AbsEnd = ExprFragment->OffsetInBits + ExprFragment->SizeInBits; - if (Start >= AbsEnd) + if (Start >= AbsEnd) { // No need to describe a SROAed padding. continue; + } Size = std::min(Size, AbsEnd - Start); } // The new, smaller fragment is stenciled out from the old fragment. 
@@ -4555,8 +4936,23 @@ bool SROAPass::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { OldDII->eraseFromParent(); } - DIB.insertDeclare(Fragment.Alloca, DbgDeclare->getVariable(), FragmentExpr, - DbgDeclare->getDebugLoc(), &AI); + if (auto *DbgAssign = dyn_cast<DbgAssignIntrinsic>(DbgDeclare)) { + if (!Fragment.Alloca->hasMetadata(LLVMContext::MD_DIAssignID)) { + Fragment.Alloca->setMetadata( + LLVMContext::MD_DIAssignID, + DIAssignID::getDistinct(AI.getContext())); + } + auto *NewAssign = DIB.insertDbgAssign( + Fragment.Alloca, DbgAssign->getValue(), DbgAssign->getVariable(), + FragmentExpr, Fragment.Alloca, DbgAssign->getAddressExpression(), + DbgAssign->getDebugLoc()); + NewAssign->setDebugLoc(DbgAssign->getDebugLoc()); + LLVM_DEBUG(dbgs() << "Created new assign intrinsic: " << *NewAssign + << "\n"); + } else { + DIB.insertDeclare(Fragment.Alloca, DbgDeclare->getVariable(), + FragmentExpr, DbgDeclare->getDebugLoc(), &AI); + } } } return Changed; @@ -4582,24 +4978,27 @@ void SROAPass::clobberUse(Use &U) { /// This analyzes the alloca to ensure we can reason about it, builds /// the slices of the alloca, and then hands it off to be split and /// rewritten as needed. -bool SROAPass::runOnAlloca(AllocaInst &AI) { +std::pair<bool /*Changed*/, bool /*CFGChanged*/> +SROAPass::runOnAlloca(AllocaInst &AI) { + bool Changed = false; + bool CFGChanged = false; + LLVM_DEBUG(dbgs() << "SROA alloca: " << AI << "\n"); ++NumAllocasAnalyzed; // Special case dead allocas, as they're trivial. if (AI.use_empty()) { AI.eraseFromParent(); - return true; + Changed = true; + return {Changed, CFGChanged}; } const DataLayout &DL = AI.getModule()->getDataLayout(); // Skip alloca forms that this analysis can't handle. auto *AT = AI.getAllocatedType(); if (AI.isArrayAllocation() || !AT->isSized() || isa<ScalableVectorType>(AT) || - DL.getTypeAllocSize(AT).getFixedSize() == 0) - return false; - - bool Changed = false; + DL.getTypeAllocSize(AT).getFixedValue() == 0) + return {Changed, CFGChanged}; // First, split any FCA loads and stores touching this alloca to promote // better splitting and promotion opportunities. @@ -4611,7 +5010,7 @@ bool SROAPass::runOnAlloca(AllocaInst &AI) { AllocaSlices AS(DL, AI); LLVM_DEBUG(AS.print(dbgs())); if (AS.isEscaped()) - return Changed; + return {Changed, CFGChanged}; // Delete all the dead users of this alloca before splitting and rewriting it. for (Instruction *DeadUser : AS.getDeadUsers()) { @@ -4633,7 +5032,7 @@ bool SROAPass::runOnAlloca(AllocaInst &AI) { // No slices to split. Leave the dead alloca for a later pass to clean up. if (AS.begin() == AS.end()) - return Changed; + return {Changed, CFGChanged}; Changed |= splitAlloca(AI, AS); @@ -4641,11 +5040,15 @@ bool SROAPass::runOnAlloca(AllocaInst &AI) { while (!SpeculatablePHIs.empty()) speculatePHINodeLoads(IRB, *SpeculatablePHIs.pop_back_val()); - LLVM_DEBUG(dbgs() << " Speculating Selects\n"); - while (!SpeculatableSelects.empty()) - speculateSelectInstLoads(IRB, *SpeculatableSelects.pop_back_val()); + LLVM_DEBUG(dbgs() << " Rewriting Selects\n"); + auto RemainingSelectsToRewrite = SelectsToRewrite.takeVector(); + while (!RemainingSelectsToRewrite.empty()) { + const auto [K, V] = RemainingSelectsToRewrite.pop_back_val(); + CFGChanged |= + rewriteSelectInstMemOps(*K, V, IRB, PreserveCFG ? nullptr : DTU); + } - return Changed; + return {Changed, CFGChanged}; } /// Delete the dead instructions accumulated in this run. 
@@ -4662,7 +5065,8 @@ bool SROAPass::deleteDeadInstructions( bool Changed = false; while (!DeadInsts.empty()) { Instruction *I = dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val()); - if (!I) continue; + if (!I) + continue; LLVM_DEBUG(dbgs() << "Deleting dead instruction: " << *I << "\n"); // If the instruction is an alloca, find the possible dbg.declare connected @@ -4674,6 +5078,7 @@ bool SROAPass::deleteDeadInstructions( OldDII->eraseFromParent(); } + at::deleteAssignmentMarkers(I); I->replaceAllUsesWith(UndefValue::get(I->getType())); for (Use &Operand : I->operands()) @@ -4703,16 +5108,16 @@ bool SROAPass::promoteAllocas(Function &F) { NumPromoted += PromotableAllocas.size(); LLVM_DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); - PromoteMemToReg(PromotableAllocas, *DT, AC); + PromoteMemToReg(PromotableAllocas, DTU->getDomTree(), AC); PromotableAllocas.clear(); return true; } -PreservedAnalyses SROAPass::runImpl(Function &F, DominatorTree &RunDT, +PreservedAnalyses SROAPass::runImpl(Function &F, DomTreeUpdater &RunDTU, AssumptionCache &RunAC) { LLVM_DEBUG(dbgs() << "SROA function: " << F.getName() << "\n"); C = &F.getContext(); - DT = &RunDT; + DTU = &RunDTU; AC = &RunAC; BasicBlock &EntryBB = F.getEntryBlock(); @@ -4729,13 +5134,18 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DominatorTree &RunDT, } bool Changed = false; + bool CFGChanged = false; // A set of deleted alloca instruction pointers which should be removed from // the list of promotable allocas. SmallPtrSet<AllocaInst *, 4> DeletedAllocas; do { while (!Worklist.empty()) { - Changed |= runOnAlloca(*Worklist.pop_back_val()); + auto [IterationChanged, IterationCFGChanged] = + runOnAlloca(*Worklist.pop_back_val()); + Changed |= IterationChanged; + CFGChanged |= IterationCFGChanged; + Changed |= deleteDeadInstructions(DeletedAllocas); // Remove the deleted allocas from various lists so that we don't try to @@ -4755,19 +5165,41 @@ PreservedAnalyses SROAPass::runImpl(Function &F, DominatorTree &RunDT, PostPromotionWorklist.clear(); } while (!Worklist.empty()); + assert((!CFGChanged || Changed) && "Can not only modify the CFG."); + assert((!CFGChanged || !PreserveCFG) && + "Should not have modified the CFG when told to preserve it."); + if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; - PA.preserveSet<CFGAnalyses>(); + if (!CFGChanged) + PA.preserveSet<CFGAnalyses>(); + PA.preserve<DominatorTreeAnalysis>(); return PA; } +PreservedAnalyses SROAPass::runImpl(Function &F, DominatorTree &RunDT, + AssumptionCache &RunAC) { + DomTreeUpdater DTU(RunDT, DomTreeUpdater::UpdateStrategy::Lazy); + return runImpl(F, DTU, RunAC); +} + PreservedAnalyses SROAPass::run(Function &F, FunctionAnalysisManager &AM) { return runImpl(F, AM.getResult<DominatorTreeAnalysis>(F), AM.getResult<AssumptionAnalysis>(F)); } +void SROAPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<SROAPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << (PreserveCFG ? "<preserve-cfg>" : "<modify-cfg>"); +} + +SROAPass::SROAPass(SROAOptions PreserveCFG_) + : PreserveCFG(PreserveCFG_ == SROAOptions::PreserveCFG) {} + /// A legacy pass for the legacy pass manager that wraps the \c SROA pass. 
/// /// This is in the llvm namespace purely to allow it to be a friend of the \c @@ -4779,7 +5211,8 @@ class llvm::sroa::SROALegacyPass : public FunctionPass { public: static char ID; - SROALegacyPass() : FunctionPass(ID) { + SROALegacyPass(SROAOptions PreserveCFG = SROAOptions::PreserveCFG) + : FunctionPass(ID), Impl(PreserveCFG) { initializeSROALegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -4797,7 +5230,7 @@ public: AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); - AU.setPreservesCFG(); + AU.addPreserved<DominatorTreeWrapperPass>(); } StringRef getPassName() const override { return "SROA"; } @@ -4805,7 +5238,10 @@ public: char SROALegacyPass::ID = 0; -FunctionPass *llvm::createSROAPass() { return new SROALegacyPass(); } +FunctionPass *llvm::createSROAPass(bool PreserveCFG) { + return new SROALegacyPass(PreserveCFG ? SROAOptions::PreserveCFG + : SROAOptions::ModifyCFG); +} INITIALIZE_PASS_BEGIN(SROALegacyPass, "sroa", "Scalar Replacement Of Aggregates", false, false) diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp index 5ab9e25577d8..8aee8d140a29 100644 --- a/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -31,12 +31,10 @@ using namespace llvm; /// ScalarOpts library. void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeADCELegacyPassPass(Registry); - initializeAnnotationRemarksLegacyPass(Registry); initializeBDCELegacyPassPass(Registry); initializeAlignmentFromAssumptionsPass(Registry); initializeCallSiteSplittingLegacyPassPass(Registry); initializeConstantHoistingLegacyPassPass(Registry); - initializeConstraintEliminationPass(Registry); initializeCorrelatedValuePropagationPass(Registry); initializeDCELegacyPassPass(Registry); initializeDivRemPairsLegacyPassPass(Registry); diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index e2976ace3a4a..1c8e4e3512dc 100644 --- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -35,6 +35,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <cassert> +#include <optional> using namespace llvm; @@ -656,7 +657,7 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, // shuffle blend with the pass through value. if (isConstantIntVector(Mask)) { unsigned MemIndex = 0; - VResult = UndefValue::get(VecType); + VResult = PoisonValue::get(VecType); SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem); for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { Value *InsertElt; @@ -861,7 +862,7 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, static bool runImpl(Function &F, const TargetTransformInfo &TTI, DominatorTree *DT) { - Optional<DomTreeUpdater> DTU; + std::optional<DomTreeUpdater> DTU; if (DT) DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy); @@ -873,7 +874,7 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI, for (BasicBlock &BB : llvm::make_early_inc_range(F)) { bool ModifiedDTOnIteration = false; MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL, - DTU ? DTU.getPointer() : nullptr); + DTU ? 
&*DTU : nullptr); // Restart BB iteration if the dominator tree of the Function was changed if (ModifiedDTOnIteration) diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 08f4b2173da2..4aab88b74f10 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -76,10 +76,13 @@ BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) { // Used to store the scattered form of a vector. using ValueVector = SmallVector<Value *, 8>; -// Used to map a vector Value to its scattered form. We use std::map -// because we want iterators to persist across insertion and because the -// values are relatively large. -using ScatterMap = std::map<Value *, ValueVector>; +// Used to map a vector Value and associated type to its scattered form. +// The associated type is only non-null for pointer values that are "scattered" +// when used as pointer operands to load or store. +// +// We use std::map because we want iterators to persist across insertion and +// because the values are relatively large. +using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>; // Lists Instructions that have been replaced with scalar implementations, // along with a pointer to their scattered forms. @@ -113,7 +116,7 @@ private: unsigned Size; }; -// FCmpSpliiter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp +// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp // called Name that compares X and Y in the same way as FCI. struct FCmpSplitter { FCmpSplitter(FCmpInst &fci) : FCI(fci) {} @@ -126,7 +129,7 @@ struct FCmpSplitter { FCmpInst &FCI; }; -// ICmpSpliiter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp +// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp // called Name that compares X and Y in the same way as ICI. struct ICmpSplitter { ICmpSplitter(ICmpInst &ici) : ICI(ici) {} @@ -139,7 +142,7 @@ struct ICmpSplitter { ICmpInst &ICI; }; -// UnarySpliiter(UO)(Builder, X, Name) uses Builder to create +// UnarySplitter(UO)(Builder, X, Name) uses Builder to create // a unary operator like UO called Name with operand X. struct UnarySplitter { UnarySplitter(UnaryOperator &uo) : UO(uo) {} @@ -151,7 +154,7 @@ struct UnarySplitter { UnaryOperator &UO; }; -// BinarySpliiter(BO)(Builder, X, Y, Name) uses Builder to create +// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create // a binary operator like BO called Name with operands X and Y. struct BinarySplitter { BinarySplitter(BinaryOperator &bo) : BO(bo) {} @@ -174,7 +177,7 @@ struct VectorLayout { } // The type of the vector. - VectorType *VecTy = nullptr; + FixedVectorType *VecTy = nullptr; // The type of each element. Type *ElemTy = nullptr; @@ -188,7 +191,7 @@ struct VectorLayout { template <typename T> T getWithDefaultOverride(const cl::opt<T> &ClOption, - const llvm::Optional<T> &DefaultOverride) { + const std::optional<T> &DefaultOverride) { return ClOption.getNumOccurrences() ? 
ClOption : DefaultOverride.value_or(ClOption); } @@ -232,8 +235,8 @@ private: void replaceUses(Instruction *Op, Value *CV); bool canTransferMetadata(unsigned Kind); void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV); - Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment, - const DataLayout &DL); + std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment, + const DataLayout &DL); bool finish(); template<typename T> bool splitUnary(Instruction &, const T &); @@ -389,7 +392,7 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V, // so that it can be used everywhere. Function *F = VArg->getParent(); BasicBlock *BB = &F->getEntryBlock(); - return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[V]); + return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[{V, PtrElemTy}]); } if (Instruction *VOp = dyn_cast<Instruction>(V)) { // When scalarizing PHI nodes we might try to examine/rewrite InsertElement @@ -406,7 +409,7 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V, BasicBlock *BB = VOp->getParent(); return Scatterer( BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, - PtrElemTy, &Scattered[V]); + PtrElemTy, &Scattered[{V, PtrElemTy}]); } // In the fallback case, just put the scattered before Point and // keep the result local to Point. @@ -422,7 +425,7 @@ void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { // If we already have a scattered form of Op (created from ExtractElements // of Op itself), replace them with the new form. - ValueVector &SV = Scattered[Op]; + ValueVector &SV = Scattered[{Op, nullptr}]; if (!SV.empty()) { for (unsigned I = 0, E = SV.size(); I != E; ++I) { Value *V = SV[I]; @@ -481,19 +484,20 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op, } // Try to fill in Layout from Ty, returning true on success. Alignment is -// the alignment of the vector, or None if the ABI default should be used. -Optional<VectorLayout> +// the alignment of the vector, or std::nullopt if the ABI default should be +// used. +std::optional<VectorLayout> ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment, const DataLayout &DL) { VectorLayout Layout; // Make sure we're dealing with a vector. - Layout.VecTy = dyn_cast<VectorType>(Ty); + Layout.VecTy = dyn_cast<FixedVectorType>(Ty); if (!Layout.VecTy) - return None; + return std::nullopt; // Check that we're dealing with full-byte elements. Layout.ElemTy = Layout.VecTy->getElementType(); if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy)) - return None; + return std::nullopt; Layout.VecAlign = Alignment; Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy); return Layout; @@ -503,11 +507,11 @@ ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment, // to create an instruction like I with operand X and name Name. template<typename Splitter> bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) { - VectorType *VT = dyn_cast<VectorType>(I.getType()); + auto *VT = dyn_cast<FixedVectorType>(I.getType()); if (!VT) return false; - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumElems = VT->getNumElements(); IRBuilder<> Builder(&I); Scatterer Op = scatter(&I, I.getOperand(0)); assert(Op.size() == NumElems && "Mismatched unary operation"); @@ -523,11 +527,11 @@ bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) { // to create an instruction like I with operands X and Y and name Name. 
template<typename Splitter> bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) { - VectorType *VT = dyn_cast<VectorType>(I.getType()); + auto *VT = dyn_cast<FixedVectorType>(I.getType()); if (!VT) return false; - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumElems = VT->getNumElements(); IRBuilder<> Builder(&I); Scatterer VOp0 = scatter(&I, I.getOperand(0)); Scatterer VOp1 = scatter(&I, I.getOperand(1)); @@ -558,7 +562,7 @@ static Function *getScalarIntrinsicDeclaration(Module *M, /// If a call to a vector typed intrinsic function, split into a scalar call per /// element if possible for the intrinsic. bool ScalarizerVisitor::splitCall(CallInst &CI) { - VectorType *VT = dyn_cast<VectorType>(CI.getType()); + auto *VT = dyn_cast<FixedVectorType>(CI.getType()); if (!VT) return false; @@ -570,7 +574,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID)) return false; - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumElems = VT->getNumElements(); unsigned NumArgs = CI.arg_size(); ValueVector ScalarOperands(NumArgs); @@ -623,11 +627,11 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { } bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) { - VectorType *VT = dyn_cast<VectorType>(SI.getType()); + auto *VT = dyn_cast<FixedVectorType>(SI.getType()); if (!VT) return false; - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumElems = VT->getNumElements(); IRBuilder<> Builder(&SI); Scatterer VOp1 = scatter(&SI, SI.getOperand(1)); Scatterer VOp2 = scatter(&SI, SI.getOperand(2)); @@ -676,12 +680,12 @@ bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) { } bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { - VectorType *VT = dyn_cast<VectorType>(GEPI.getType()); + auto *VT = dyn_cast<FixedVectorType>(GEPI.getType()); if (!VT) return false; IRBuilder<> Builder(&GEPI); - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumElems = VT->getNumElements(); unsigned NumIndices = GEPI.getNumIndices(); // The base pointer might be scalar even if it's a vector GEP. 
In those cases, @@ -722,11 +726,11 @@ bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { } bool ScalarizerVisitor::visitCastInst(CastInst &CI) { - VectorType *VT = dyn_cast<VectorType>(CI.getDestTy()); + auto *VT = dyn_cast<FixedVectorType>(CI.getDestTy()); if (!VT) return false; - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumElems = VT->getNumElements(); IRBuilder<> Builder(&CI); Scatterer Op0 = scatter(&CI, CI.getOperand(0)); assert(Op0.size() == NumElems && "Mismatched cast"); @@ -740,13 +744,13 @@ bool ScalarizerVisitor::visitCastInst(CastInst &CI) { } bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) { - VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy()); - VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy()); + auto *DstVT = dyn_cast<FixedVectorType>(BCI.getDestTy()); + auto *SrcVT = dyn_cast<FixedVectorType>(BCI.getSrcTy()); if (!DstVT || !SrcVT) return false; - unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements(); - unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements(); + unsigned DstNumElems = DstVT->getNumElements(); + unsigned SrcNumElems = SrcVT->getNumElements(); IRBuilder<> Builder(&BCI); Scatterer Op0 = scatter(&BCI, BCI.getOperand(0)); ValueVector Res; @@ -795,11 +799,11 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) { } bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) { - VectorType *VT = dyn_cast<VectorType>(IEI.getType()); + auto *VT = dyn_cast<FixedVectorType>(IEI.getType()); if (!VT) return false; - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumElems = VT->getNumElements(); IRBuilder<> Builder(&IEI); Scatterer Op0 = scatter(&IEI, IEI.getOperand(0)); Value *NewElt = IEI.getOperand(1); @@ -830,11 +834,11 @@ bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) { } bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { - VectorType *VT = dyn_cast<VectorType>(EEI.getOperand(0)->getType()); + auto *VT = dyn_cast<FixedVectorType>(EEI.getOperand(0)->getType()); if (!VT) return false; - unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumSrcElems = VT->getNumElements(); IRBuilder<> Builder(&EEI); Scatterer Op0 = scatter(&EEI, EEI.getOperand(0)); Value *ExtIdx = EEI.getOperand(1); @@ -848,7 +852,7 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { if (!ScalarizeVariableInsertExtract) return false; - Value *Res = UndefValue::get(VT->getElementType()); + Value *Res = PoisonValue::get(VT->getElementType()); for (unsigned I = 0; I < NumSrcElems; ++I) { Value *ShouldExtract = Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I), @@ -862,11 +866,11 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { } bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) { - VectorType *VT = dyn_cast<VectorType>(SVI.getType()); + auto *VT = dyn_cast<FixedVectorType>(SVI.getType()); if (!VT) return false; - unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); + unsigned NumElems = VT->getNumElements(); Scatterer Op0 = scatter(&SVI, SVI.getOperand(0)); Scatterer Op1 = scatter(&SVI, SVI.getOperand(1)); ValueVector Res; @@ -886,7 +890,7 @@ bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) { } bool ScalarizerVisitor::visitPHINode(PHINode &PHI) { - VectorType *VT = dyn_cast<VectorType>(PHI.getType()); + auto *VT = 
dyn_cast<FixedVectorType>(PHI.getType()); if (!VT) return false; @@ -916,7 +920,7 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) { if (!LI.isSimple()) return false; - Optional<VectorLayout> Layout = getVectorLayout( + std::optional<VectorLayout> Layout = getVectorLayout( LI.getType(), LI.getAlign(), LI.getModule()->getDataLayout()); if (!Layout) return false; @@ -942,7 +946,7 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) { return false; Value *FullValue = SI.getValueOperand(); - Optional<VectorLayout> Layout = getVectorLayout( + std::optional<VectorLayout> Layout = getVectorLayout( FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout()); if (!Layout) return false; @@ -981,9 +985,9 @@ bool ScalarizerVisitor::finish() { // The value is still needed, so recreate it using a series of // InsertElements. Value *Res = PoisonValue::get(Op->getType()); - if (auto *Ty = dyn_cast<VectorType>(Op->getType())) { + if (auto *Ty = dyn_cast<FixedVectorType>(Op->getType())) { BasicBlock *BB = Op->getParent(); - unsigned Count = cast<FixedVectorType>(Ty)->getNumElements(); + unsigned Count = Ty->getNumElements(); IRBuilder<> Builder(Op); if (isa<PHINode>(Op)) Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 7da5a78772ad..4fb90bcea4f0 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -817,6 +817,10 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, gep_type_iterator GTI = gep_type_begin(*GEP); for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { if (GTI.isSequential()) { + // Constant offsets of scalable types are not really constant. + if (isa<ScalableVectorType>(GTI.getIndexedType())) + continue; + // Tries to extract a constant offset from this GEP index. int64_t ConstantOffset = ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP, DT); @@ -1006,6 +1010,10 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { gep_type_iterator GTI = gep_type_begin(*GEP); for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { if (GTI.isSequential()) { + // Constant offsets of scalable types are not really constant. + if (isa<ScalableVectorType>(GTI.getIndexedType())) + continue; + // Splits this GEP index into a variadic part and a constant offset, and // uses the variadic part as the new index. Value *OldIdx = GEP->getOperand(I); @@ -1122,18 +1130,17 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // sizeof(int64). // // Emit an uglygep in this case. - Type *I8PtrTy = Type::getInt8PtrTy(GEP->getContext(), - GEP->getPointerAddressSpace()); - NewGEP = new BitCastInst(NewGEP, I8PtrTy, "", GEP); - NewGEP = GetElementPtrInst::Create( - Type::getInt8Ty(GEP->getContext()), NewGEP, - ConstantInt::get(IntPtrTy, AccumulativeByteOffset, true), "uglygep", - GEP); + IRBuilder<> Builder(GEP); + Type *I8PtrTy = + Builder.getInt8Ty()->getPointerTo(GEP->getPointerAddressSpace()); + + NewGEP = cast<Instruction>(Builder.CreateGEP( + Builder.getInt8Ty(), Builder.CreateBitCast(NewGEP, I8PtrTy), + {ConstantInt::get(IntPtrTy, AccumulativeByteOffset, true)}, "uglygep", + GEPWasInBounds)); + NewGEP->copyMetadata(*GEP); - // Inherit the inbounds attribute of the original GEP. 
- cast<GetElementPtrInst>(NewGEP)->setIsInBounds(GEPWasInBounds); - if (GEP->getType() != I8PtrTy) - NewGEP = new BitCastInst(NewGEP, GEP->getType(), GEP->getName(), GEP); + NewGEP = cast<Instruction>(Builder.CreateBitCast(NewGEP, GEP->getType())); } GEP->replaceAllUsesWith(NewGEP); diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 0535608244cc..7e08120f923d 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/GuardUtils.h" @@ -26,6 +27,7 @@ #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -61,6 +63,7 @@ #include <cassert> #include <iterator> #include <numeric> +#include <optional> #include <utility> #define DEBUG_TYPE "simple-loop-unswitch" @@ -115,6 +118,18 @@ static cl::opt<bool> FreezeLoopUnswitchCond( cl::desc("If enabled, the freeze instruction will be added to condition " "of loop unswitch to prevent miscompilation.")); +namespace { +struct NonTrivialUnswitchCandidate { + Instruction *TI = nullptr; + TinyPtrVector<Value *> Invariants; + std::optional<InstructionCost> Cost; + NonTrivialUnswitchCandidate( + Instruction *TI, ArrayRef<Value *> Invariants, + std::optional<InstructionCost> Cost = std::nullopt) + : TI(TI), Invariants(Invariants), Cost(Cost){}; +}; +} // end anonymous namespace. + // Helper to skip (select x, true, false), which matches both a logical AND and // OR and can confuse code that tries to determine if \p Cond is either a // logical AND or OR but not both. @@ -133,8 +148,8 @@ static Value *skipTrivialSelect(Value *Cond) { /// inputs which are loop invariant. For some operations these can be /// re-associated and unswitched out of the loop entirely. static TinyPtrVector<Value *> -collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, - LoopInfo &LI) { +collectHomogenousInstGraphLoopInvariants(const Loop &L, Instruction &Root, + const LoopInfo &LI) { assert(!L.isLoopInvariant(&Root) && "Only need to walk the graph if root itself is not invariant."); TinyPtrVector<Value *> Invariants; @@ -175,7 +190,7 @@ collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, return Invariants; } -static void replaceLoopInvariantUses(Loop &L, Value *Invariant, +static void replaceLoopInvariantUses(const Loop &L, Value *Invariant, Constant &Replacement) { assert(!isa<Constant>(Invariant) && "Why are we unswitching on a constant?"); @@ -192,9 +207,10 @@ static void replaceLoopInvariantUses(Loop &L, Value *Invariant, /// Check that all the LCSSA PHI nodes in the loop exit block have trivial /// incoming values along this edge. -static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, - BasicBlock &ExitBB) { - for (Instruction &I : ExitBB) { +static bool areLoopExitPHIsLoopInvariant(const Loop &L, + const BasicBlock &ExitingBB, + const BasicBlock &ExitBB) { + for (const Instruction &I : ExitBB) { auto *PN = dyn_cast<PHINode>(&I); if (!PN) // No more PHIs to check. 
@@ -214,7 +230,7 @@ static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, static void buildPartialUnswitchConditionalBranch( BasicBlock &BB, ArrayRef<Value *> Invariants, bool Direction, BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc, bool InsertFreeze, - Instruction *I, AssumptionCache *AC, DominatorTree &DT) { + const Instruction *I, AssumptionCache *AC, const DominatorTree &DT) { IRBuilder<> IRB(&BB); SmallVector<Value *> FrozenInvariants; @@ -239,7 +255,7 @@ static void buildPartialInvariantUnswitchConditionalBranch( for (auto *Val : reverse(ToDuplicate)) { Instruction *Inst = cast<Instruction>(Val); Instruction *NewInst = Inst->clone(); - BB.getInstList().insert(BB.end(), NewInst); + NewInst->insertInto(&BB, BB.end()); RemapInstruction(NewInst, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); VMap[Val] = NewInst; @@ -418,9 +434,10 @@ static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader, // Return the top-most loop containing ExitBB and having ExitBB as exiting block // or the loop containing ExitBB, if there is no parent loop containing ExitBB // as exiting block. -static Loop *getTopMostExitingLoop(BasicBlock *ExitBB, LoopInfo &LI) { - Loop *TopMost = LI.getLoopFor(ExitBB); - Loop *Current = TopMost; +static const Loop *getTopMostExitingLoop(const BasicBlock *ExitBB, + const LoopInfo &LI) { + const Loop *TopMost = LI.getLoopFor(ExitBB); + const Loop *Current = TopMost; while (Current) { if (Current->isLoopExiting(ExitBB)) TopMost = Current; @@ -521,11 +538,12 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // loop, the loop containing the exit block and the topmost parent loop // exiting via LoopExitBB. if (SE) { - if (Loop *ExitL = getTopMostExitingLoop(LoopExitBB, LI)) + if (const Loop *ExitL = getTopMostExitingLoop(LoopExitBB, LI)) SE->forgetLoop(ExitL); else // Forget the entire nest as this exits the entire nest. SE->forgetTopmostLoop(&L); + SE->forgetBlockAndLoopDispositions(); } if (MSSAU && VerifyMemorySSA) @@ -562,13 +580,12 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // If fully unswitching, we can use the existing branch instruction. // Splice it into the old PH to gate reaching the new preheader and re-point // its successors. - OldPH->getInstList().splice(OldPH->end(), BI.getParent()->getInstList(), - BI); + OldPH->splice(OldPH->end(), BI.getParent(), BI.getIterator()); BI.setCondition(Cond); if (MSSAU) { // Temporarily clone the terminator, to make MSSA update cheaper by // separating "insert edge" updates from "remove edge" ones. - ParentBB->getInstList().push_back(BI.clone()); + BI.clone()->insertInto(ParentBB, ParentBB->end()); } else { // Create a new unconditional branch that will continue the loop as a new // terminator. @@ -1098,7 +1115,8 @@ static BasicBlock *buildClonedLoopBlocks( const SmallDenseMap<BasicBlock *, BasicBlock *, 16> &DominatingSucc, ValueToValueMapTy &VMap, SmallVectorImpl<DominatorTree::UpdateType> &DTUpdates, AssumptionCache &AC, - DominatorTree &DT, LoopInfo &LI, MemorySSAUpdater *MSSAU) { + DominatorTree &DT, LoopInfo &LI, MemorySSAUpdater *MSSAU, + ScalarEvolution *SE) { SmallVector<BasicBlock *, 4> NewBlocks; NewBlocks.reserve(L.getNumBlocks() + ExitBlocks.size()); @@ -1174,6 +1192,10 @@ static BasicBlock *buildClonedLoopBlocks( // We should have a value map between the instruction and its clone. 
assert(VMap.lookup(&I) == &ClonedI && "Mismatch in the value map!"); + // Forget SCEVs based on exit phis in case SCEV looked through the phi. + if (SE && isa<PHINode>(I)) + SE->forgetValue(&I); + auto *MergePN = PHINode::Create(I.getType(), /*NumReservedValues*/ 2, ".us-phi", &*MergeBB->getFirstInsertionPt()); @@ -1550,7 +1572,7 @@ static void buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, // We need a stable insertion order. We use the order of the original loop // order and map into the correct parent loop. for (auto *BB : llvm::concat<BasicBlock *const>( - makeArrayRef(ClonedPH), ClonedLoopBlocks, ClonedExitsInLoops)) + ArrayRef(ClonedPH), ClonedLoopBlocks, ClonedExitsInLoops)) if (Loop *OuterL = ExitLoopMap.lookup(BB)) OuterL->addBasicBlockToLoop(BB, LI); @@ -1590,7 +1612,7 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, // Find all the dead clones, and remove them from their successors. SmallVector<BasicBlock *, 16> DeadBlocks; for (BasicBlock *BB : llvm::concat<BasicBlock *const>(L.blocks(), ExitBlocks)) - for (auto &VMap : VMaps) + for (const auto &VMap : VMaps) if (BasicBlock *ClonedBB = cast_or_null<BasicBlock>(VMap->lookup(BB))) if (!DT.isReachableFromEntry(ClonedBB)) { for (BasicBlock *SuccBB : successors(ClonedBB)) @@ -1618,6 +1640,7 @@ deleteDeadBlocksFromLoop(Loop &L, SmallVectorImpl<BasicBlock *> &ExitBlocks, DominatorTree &DT, LoopInfo &LI, MemorySSAUpdater *MSSAU, + ScalarEvolution *SE, function_ref<void(Loop &, StringRef)> DestroyLoopCB) { // Find all the dead blocks tied to this loop, and remove them from their // successors. @@ -1669,6 +1692,8 @@ deleteDeadBlocksFromLoop(Loop &L, "If the child loop header is dead all blocks in the child loop must " "be dead as well!"); DestroyLoopCB(*ChildL, ChildL->getName()); + if (SE) + SE->forgetBlockAndLoopDispositions(); LI.destroy(ChildL); return true; }); @@ -1818,7 +1843,8 @@ static SmallPtrSet<const BasicBlock *, 16> recomputeLoopBlockSet(Loop &L, /// referenced). static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, LoopInfo &LI, - SmallVectorImpl<Loop *> &HoistedLoops) { + SmallVectorImpl<Loop *> &HoistedLoops, + ScalarEvolution *SE) { auto *PH = L.getLoopPreheader(); // Compute the actual parent loop from the exit blocks. Because we may have @@ -2011,6 +2037,8 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, LI.removeLoop(llvm::find(LI, &L)); // markLoopAsDeleted for L should be triggered by the caller (it is typically // done by using the UnswitchCB callback). + if (SE) + SE->forgetBlockAndLoopDispositions(); LI.destroy(&L); return false; } @@ -2047,8 +2075,8 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) { static void unswitchNontrivialInvariants( Loop &L, Instruction &TI, ArrayRef<Value *> Invariants, - SmallVectorImpl<BasicBlock *> &ExitBlocks, IVConditionInfo &PartialIVInfo, - DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, + IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, ScalarEvolution *SE, MemorySSAUpdater *MSSAU, function_ref<void(Loop &, StringRef)> DestroyLoopCB) { @@ -2129,6 +2157,8 @@ static void unswitchNontrivialInvariants( // furthest up our loopnest which can be mutated, which we will use below to // update things. 
Loop *OuterExitL = &L; + SmallVector<BasicBlock *, 4> ExitBlocks; + L.getUniqueExitBlocks(ExitBlocks); for (auto *ExitBB : ExitBlocks) { Loop *NewOuterExitL = LI.getLoopFor(ExitBB); if (!NewOuterExitL) { @@ -2148,6 +2178,7 @@ static void unswitchNontrivialInvariants( SE->forgetLoop(OuterExitL); else SE->forgetTopmostLoop(&L); + SE->forgetBlockAndLoopDispositions(); } bool InsertFreeze = false; @@ -2157,14 +2188,26 @@ static void unswitchNontrivialInvariants( InsertFreeze = !SafetyInfo.isGuaranteedToExecute(TI, &DT, &L); } + // Perform the isGuaranteedNotToBeUndefOrPoison() query before the transform, + // otherwise the branch instruction will have been moved outside the loop + // already, and may imply that a poison condition is always UB. + Value *FullUnswitchCond = nullptr; + if (FullUnswitch) { + FullUnswitchCond = + BI ? skipTrivialSelect(BI->getCondition()) : SI->getCondition(); + if (InsertFreeze) + InsertFreeze = !isGuaranteedNotToBeUndefOrPoison( + FullUnswitchCond, &AC, L.getLoopPreheader()->getTerminator(), &DT); + } + // If the edge from this terminator to a successor dominates that successor, // store a map from each block in its dominator subtree to it. This lets us // tell when cloning for a particular successor if a block is dominated by // some *other* successor with a single data structure. We use this to // significantly reduce cloning. SmallDenseMap<BasicBlock *, BasicBlock *, 16> DominatingSucc; - for (auto *SuccBB : llvm::concat<BasicBlock *const>( - makeArrayRef(RetainedSuccBB), UnswitchedSuccBBs)) + for (auto *SuccBB : llvm::concat<BasicBlock *const>(ArrayRef(RetainedSuccBB), + UnswitchedSuccBBs)) if (SuccBB->getUniquePredecessor() || llvm::all_of(predecessors(SuccBB), [&](BasicBlock *PredBB) { return PredBB == ParentBB || DT.dominates(SuccBB, PredBB); @@ -2193,7 +2236,7 @@ static void unswitchNontrivialInvariants( VMaps.emplace_back(new ValueToValueMapTy()); ClonedPHs[SuccBB] = buildClonedLoopBlocks( L, LoopPH, SplitBB, ExitBlocks, ParentBB, SuccBB, RetainedSuccBB, - DominatingSucc, *VMaps.back(), DTUpdates, AC, DT, LI, MSSAU); + DominatingSucc, *VMaps.back(), DTUpdates, AC, DT, LI, MSSAU, SE); } // Drop metadata if we may break its semantics by moving this instr into the @@ -2220,23 +2263,21 @@ static void unswitchNontrivialInvariants( if (FullUnswitch) { // Splice the terminator from the original loop and rewrite its // successors. - SplitBB->getInstList().splice(SplitBB->end(), ParentBB->getInstList(), TI); + SplitBB->splice(SplitBB->end(), ParentBB, TI.getIterator()); // Keep a clone of the terminator for MSSA updates. Instruction *NewTI = TI.clone(); - ParentBB->getInstList().push_back(NewTI); + NewTI->insertInto(ParentBB, ParentBB->end()); // First wire up the moved terminator to the preheaders. 
if (BI) { BasicBlock *ClonedPH = ClonedPHs.begin()->second; BI->setSuccessor(ClonedSucc, ClonedPH); BI->setSuccessor(1 - ClonedSucc, LoopPH); - Value *Cond = skipTrivialSelect(BI->getCondition()); - if (InsertFreeze) { - if (!isGuaranteedNotToBeUndefOrPoison(Cond, &AC, BI, &DT)) - Cond = new FreezeInst(Cond, Cond->getName() + ".fr", BI); - } - BI->setCondition(Cond); + if (InsertFreeze) + FullUnswitchCond = new FreezeInst( + FullUnswitchCond, FullUnswitchCond->getName() + ".fr", BI); + BI->setCondition(FullUnswitchCond); DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); } else { assert(SI && "Must either be a branch or switch!"); @@ -2245,17 +2286,16 @@ static void unswitchNontrivialInvariants( assert(SI->getDefaultDest() == RetainedSuccBB && "Not retaining default successor!"); SI->setDefaultDest(LoopPH); - for (auto &Case : SI->cases()) + for (const auto &Case : SI->cases()) if (Case.getCaseSuccessor() == RetainedSuccBB) Case.setSuccessor(LoopPH); else Case.setSuccessor(ClonedPHs.find(Case.getCaseSuccessor())->second); - if (InsertFreeze) { - auto Cond = SI->getCondition(); - if (!isGuaranteedNotToBeUndefOrPoison(Cond, &AC, SI, &DT)) - SI->setCondition(new FreezeInst(Cond, Cond->getName() + ".fr", SI)); - } + if (InsertFreeze) + SI->setCondition(new FreezeInst( + FullUnswitchCond, FullUnswitchCond->getName() + ".fr", SI)); + // We need to use the set to populate domtree updates as even when there // are multiple cases pointing at the same successor we only want to // remove and insert one edge in the domtree. @@ -2306,7 +2346,7 @@ static void unswitchNontrivialInvariants( SwitchInst *NewSI = cast<SwitchInst>(NewTI); assert(NewSI->getDefaultDest() == RetainedSuccBB && "Not retaining default successor!"); - for (auto &Case : NewSI->cases()) + for (const auto &Case : NewSI->cases()) Case.getCaseSuccessor()->removePredecessor( ParentBB, /*KeepOneInputPHIs*/ true); @@ -2372,13 +2412,14 @@ static void unswitchNontrivialInvariants( // Now that our cloned loops have been built, we can update the original loop. // First we delete the dead blocks from it and then we rebuild the loop // structure taking these deletions into account. - deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, DestroyLoopCB); + deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE,DestroyLoopCB); if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); SmallVector<Loop *, 4> HoistedLoops; - bool IsStillLoop = rebuildLoopAfterUnswitch(L, ExitBlocks, LI, HoistedLoops); + bool IsStillLoop = + rebuildLoopAfterUnswitch(L, ExitBlocks, LI, HoistedLoops, SE); if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -2573,10 +2614,9 @@ static InstructionCost computeDomSubtreeCost( /// /// It also makes all relevant DT and LI updates, so that all structures are in /// valid state after this transform. -static BranchInst * -turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, - SmallVectorImpl<BasicBlock *> &ExitBlocks, - DominatorTree &DT, LoopInfo &LI, MemorySSAUpdater *MSSAU) { +static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, + DominatorTree &DT, LoopInfo &LI, + MemorySSAUpdater *MSSAU) { SmallVector<DominatorTree::UpdateType, 4> DTUpdates; LLVM_DEBUG(dbgs() << "Turning " << *GI << " into a branch.\n"); BasicBlock *CheckBB = GI->getParent(); @@ -2603,9 +2643,6 @@ turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, CheckBI->getSuccessor(1)->setName("deopt"); BasicBlock *DeoptBlock = CheckBI->getSuccessor(1); - // We now have a new exit block. 
- ExitBlocks.push_back(CheckBI->getSuccessor(1)); - if (MSSAU) MSSAU->moveAllAfterSpliceBlocks(CheckBB, GuardedBlock, GI); @@ -2651,19 +2688,19 @@ turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, /// That requires knowing not just the number of "remaining" candidates but /// also costs of unswitching for each of these candidates. static int CalculateUnswitchCostMultiplier( - Instruction &TI, Loop &L, LoopInfo &LI, DominatorTree &DT, - ArrayRef<std::pair<Instruction *, TinyPtrVector<Value *>>> - UnswitchCandidates) { + const Instruction &TI, const Loop &L, const LoopInfo &LI, + const DominatorTree &DT, + ArrayRef<NonTrivialUnswitchCandidate> UnswitchCandidates) { // Guards and other exiting conditions do not contribute to exponential // explosion as soon as they dominate the latch (otherwise there might be // another path to the latch remaining that does not allow to eliminate the // loop copy on unswitch). - BasicBlock *Latch = L.getLoopLatch(); - BasicBlock *CondBlock = TI.getParent(); + const BasicBlock *Latch = L.getLoopLatch(); + const BasicBlock *CondBlock = TI.getParent(); if (DT.dominates(CondBlock, Latch) && (isGuard(&TI) || - llvm::count_if(successors(&TI), [&L](BasicBlock *SuccBB) { + llvm::count_if(successors(&TI), [&L](const BasicBlock *SuccBB) { return L.contains(SuccBB); }) <= 1)) { NumCostMultiplierSkipped++; @@ -2677,16 +2714,17 @@ static int CalculateUnswitchCostMultiplier( // unswitching. Branch/guard counts as 1, switch counts as log2 of its cases. int UnswitchedClones = 0; for (auto Candidate : UnswitchCandidates) { - Instruction *CI = Candidate.first; - BasicBlock *CondBlock = CI->getParent(); + const Instruction *CI = Candidate.TI; + const BasicBlock *CondBlock = CI->getParent(); bool SkipExitingSuccessors = DT.dominates(CondBlock, Latch); if (isGuard(CI)) { if (!SkipExitingSuccessors) UnswitchedClones++; continue; } - int NonExitingSuccessors = llvm::count_if( - successors(CondBlock), [SkipExitingSuccessors, &L](BasicBlock *SuccBB) { + int NonExitingSuccessors = + llvm::count_if(successors(CondBlock), + [SkipExitingSuccessors, &L](const BasicBlock *SuccBB) { return !SkipExitingSuccessors || L.contains(SuccBB); }); UnswitchedClones += Log2_32(NonExitingSuccessors); @@ -2722,17 +2760,12 @@ static int CalculateUnswitchCostMultiplier( return CostMultiplier; } -static bool unswitchBestCondition( - Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - AAResults &AA, TargetTransformInfo &TTI, - function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU, - function_ref<void(Loop &, StringRef)> DestroyLoopCB) { - // Collect all invariant conditions within this loop (as opposed to an inner - // loop which would be handled when visiting that inner loop). - SmallVector<std::pair<Instruction *, TinyPtrVector<Value *>>, 4> - UnswitchCandidates; - +static bool collectUnswitchCandidates( + SmallVectorImpl<NonTrivialUnswitchCandidate> &UnswitchCandidates, + IVConditionInfo &PartialIVInfo, Instruction *&PartialIVCondBranch, + const Loop &L, const LoopInfo &LI, AAResults &AA, + const MemorySSAUpdater *MSSAU) { + assert(UnswitchCandidates.empty() && "Should be!"); // Whether or not we should also collect guards in the loop. 
bool CollectGuards = false; if (UnswitchGuards) { @@ -2742,7 +2775,6 @@ static bool unswitchBestCondition( CollectGuards = true; } - IVConditionInfo PartialIVInfo; for (auto *BB : L.blocks()) { if (LI.getLoopFor(BB) != &L) continue; @@ -2750,7 +2782,8 @@ static bool unswitchBestCondition( if (CollectGuards) for (auto &I : *BB) if (isGuard(&I)) { - auto *Cond = cast<IntrinsicInst>(&I)->getArgOperand(0); + auto *Cond = + skipTrivialSelect(cast<IntrinsicInst>(&I)->getArgOperand(0)); // TODO: Support AND, OR conditions and partial unswitching. if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond)) UnswitchCandidates.push_back({&I, {Cond}}); @@ -2791,11 +2824,10 @@ static bool unswitchBestCondition( } } - Instruction *PartialIVCondBranch = nullptr; if (MSSAU && !findOptionMDForLoop(&L, "llvm.loop.unswitch.partial.disable") && !any_of(UnswitchCandidates, [&L](auto &TerminatorAndInvariants) { - return TerminatorAndInvariants.first == L.getHeader()->getTerminator(); - })) { + return TerminatorAndInvariants.TI == L.getHeader()->getTerminator(); + })) { MemorySSA *MSSA = MSSAU->getMemorySSA(); if (auto Info = hasPartialIVCondition(L, MSSAThreshold, *MSSA, AA)) { LLVM_DEBUG( @@ -2809,10 +2841,22 @@ static bool unswitchBestCondition( {L.getHeader()->getTerminator(), std::move(ValsToDuplicate)}); } } + return !UnswitchCandidates.empty(); +} - // If we didn't find any candidates, we're done. - if (UnswitchCandidates.empty()) +static bool isSafeForNoNTrivialUnswitching(Loop &L, LoopInfo &LI) { + if (!L.isSafeToClone()) return false; + for (auto *BB : L.blocks()) + for (auto &I : *BB) { + if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB)) + return false; + if (auto *CB = dyn_cast<CallBase>(&I)) { + assert(!CB->cannotDuplicate() && "Checked by L.isSafeToClone()."); + if (CB->isConvergent()) + return false; + } + } // Check if there are irreducible CFG cycles in this loop. If so, we cannot // easily unswitch non-trivial edges out of the loop. Doing so might turn the @@ -2827,7 +2871,6 @@ static bool unswitchBestCondition( SmallVector<BasicBlock *, 4> ExitBlocks; L.getUniqueExitBlocks(ExitBlocks); - // We cannot unswitch if exit blocks contain a cleanuppad/catchswitch // instruction as we don't know how to split those exit blocks. // FIXME: We should teach SplitBlock to handle this and remove this @@ -2841,10 +2884,13 @@ static bool unswitchBestCondition( } } - LLVM_DEBUG( - dbgs() << "Considering " << UnswitchCandidates.size() - << " non-trivial loop invariant conditions for unswitching.\n"); + return true; +} +static NonTrivialUnswitchCandidate findBestNonTrivialUnswitchCandidate( + ArrayRef<NonTrivialUnswitchCandidate> UnswitchCandidates, const Loop &L, + const DominatorTree &DT, const LoopInfo &LI, AssumptionCache &AC, + const TargetTransformInfo &TTI, const IVConditionInfo &PartialIVInfo) { // Given that unswitching these terminators will require duplicating parts of // the loop, so we need to be able to model that cost. Compute the ephemeral // values and set up a data structure to hold per-BB costs. 
We cache each @@ -2869,14 +2915,7 @@ static bool unswitchBestCondition( for (auto &I : *BB) { if (EphValues.count(&I)) continue; - - if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB)) - return false; - if (auto *CB = dyn_cast<CallBase>(&I)) - if (CB->isConvergent() || CB->cannotDuplicate()) - return false; - - Cost += TTI.getUserCost(&I, CostKind); + Cost += TTI.getInstructionCost(&I, CostKind); } assert(Cost >= 0 && "Must not have negative costs!"); LoopCost += Cost; @@ -2958,12 +2997,11 @@ static bool unswitchBestCondition( "Cannot unswitch a condition without multiple distinct successors!"); return (LoopCost - Cost) * (SuccessorsCount - 1); }; - Instruction *BestUnswitchTI = nullptr; - InstructionCost BestUnswitchCost = 0; - ArrayRef<Value *> BestUnswitchInvariants; - for (auto &TerminatorAndInvariants : UnswitchCandidates) { - Instruction &TI = *TerminatorAndInvariants.first; - ArrayRef<Value *> Invariants = TerminatorAndInvariants.second; + + std::optional<NonTrivialUnswitchCandidate> Best; + for (auto &Candidate : UnswitchCandidates) { + Instruction &TI = *Candidate.TI; + ArrayRef<Value *> Invariants = Candidate.Invariants; BranchInst *BI = dyn_cast<BranchInst>(&TI); InstructionCost CandidateCost = ComputeUnswitchedCost( TI, /*FullUnswitch*/ !BI || @@ -2986,34 +3024,59 @@ static bool unswitchBestCondition( << " for unswitch candidate: " << TI << "\n"); } - if (!BestUnswitchTI || CandidateCost < BestUnswitchCost) { - BestUnswitchTI = &TI; - BestUnswitchCost = CandidateCost; - BestUnswitchInvariants = Invariants; + if (!Best || CandidateCost < Best->Cost) { + Best = Candidate; + Best->Cost = CandidateCost; } } - assert(BestUnswitchTI && "Failed to find loop unswitch candidate"); + assert(Best && "Must be!"); + return *Best; +} - if (BestUnswitchCost >= UnswitchThreshold) { - LLVM_DEBUG(dbgs() << "Cannot unswitch, lowest cost found: " - << BestUnswitchCost << "\n"); +static bool unswitchBestCondition( + Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, + AAResults &AA, TargetTransformInfo &TTI, + function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE, MemorySSAUpdater *MSSAU, + function_ref<void(Loop &, StringRef)> DestroyLoopCB) { + // Collect all invariant conditions within this loop (as opposed to an inner + // loop which would be handled when visiting that inner loop). + SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates; + IVConditionInfo PartialIVInfo; + Instruction *PartialIVCondBranch = nullptr; + // If we didn't find any candidates, we're done. + if (!collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo, + PartialIVCondBranch, L, LI, AA, MSSAU)) + return false; + + LLVM_DEBUG( + dbgs() << "Considering " << UnswitchCandidates.size() + << " non-trivial loop invariant conditions for unswitching.\n"); + + NonTrivialUnswitchCandidate Best = findBestNonTrivialUnswitchCandidate( + UnswitchCandidates, L, DT, LI, AC, TTI, PartialIVInfo); + + assert(Best.TI && "Failed to find loop unswitch candidate"); + assert(Best.Cost && "Failed to compute cost"); + + if (*Best.Cost >= UnswitchThreshold) { + LLVM_DEBUG(dbgs() << "Cannot unswitch, lowest cost found: " << *Best.Cost + << "\n"); return false; } - if (BestUnswitchTI != PartialIVCondBranch) + if (Best.TI != PartialIVCondBranch) PartialIVInfo.InstToDuplicate.clear(); // If the best candidate is a guard, turn it into a branch. 
- if (isGuard(BestUnswitchTI)) - BestUnswitchTI = turnGuardIntoBranch(cast<IntrinsicInst>(BestUnswitchTI), L, - ExitBlocks, DT, LI, MSSAU); - - LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " - << BestUnswitchCost << ") terminator: " << *BestUnswitchTI - << "\n"); - unswitchNontrivialInvariants(L, *BestUnswitchTI, BestUnswitchInvariants, - ExitBlocks, PartialIVInfo, DT, LI, AC, - UnswitchCB, SE, MSSAU, DestroyLoopCB); + if (isGuard(Best.TI)) + Best.TI = + turnGuardIntoBranch(cast<IntrinsicInst>(Best.TI), L, DT, LI, MSSAU); + + LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << Best.Cost + << ") terminator: " << *Best.TI << "\n"); + unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT, + LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB); return true; } @@ -3044,6 +3107,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, bool NonTrivial, function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, ScalarEvolution *SE, MemorySSAUpdater *MSSAU, + ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI, function_ref<void(Loop &, StringRef)> DestroyLoopCB) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); @@ -3080,8 +3144,16 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, if (L.getHeader()->getParent()->hasOptSize()) return false; - // Skip non-trivial unswitching for loops that cannot be cloned. - if (!L.isSafeToClone()) + // Skip cold loops, as unswitching them brings little benefit + // but increases the code size + if (PSI && PSI->hasProfileSummary() && BFI && + PSI->isFunctionColdInCallGraph(L.getHeader()->getParent(), *BFI)) { + LLVM_DEBUG(dbgs() << " Skip cold loop: " << L << "\n"); + return false; + } + + // Perform legality checks. + if (!isSafeForNoNTrivialUnswitching(L, LI)) return false; // For non-trivial unswitching, because it often creates new loops, we rely on @@ -3105,7 +3177,11 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, LPMUpdater &U) { Function &F = *L.getHeader()->getParent(); (void)F; - + ProfileSummaryInfo *PSI = nullptr; + if (auto OuterProxy = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR) + .getCachedResult<ModuleAnalysisManagerFunctionProxy>(F)) + PSI = OuterProxy->getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L << "\n"); @@ -3144,14 +3220,14 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, U.markLoopAsDeleted(L, Name); }; - Optional<MemorySSAUpdater> MSSAU; + std::optional<MemorySSAUpdater> MSSAU; if (AR.MSSA) { MSSAU = MemorySSAUpdater(AR.MSSA); if (VerifyMemorySSA) AR.MSSA->verifyMemorySSA(); } if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial, - UnswitchCB, &AR.SE, MSSAU ? MSSAU.getPointer() : nullptr, + UnswitchCB, &AR.SE, MSSAU ? 
&*MSSAU : nullptr, PSI, AR.BFI, DestroyLoopCB)) return PreservedAnalyses::all(); @@ -3214,7 +3290,6 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L << "\n"); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); @@ -3251,9 +3326,9 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { if (VerifyMemorySSA) MSSA->verifyMemorySSA(); - - bool Changed = unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial, - UnswitchCB, SE, &MSSAU, DestroyLoopCB); + bool Changed = + unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial, UnswitchCB, SE, + &MSSAU, nullptr, nullptr, DestroyLoopCB); if (VerifyMemorySSA) MSSA->verifyMemorySSA(); diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index fb2d812a186d..e014f5d1eb04 100644 --- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -108,12 +108,12 @@ performBlockTailMerging(Function &F, ArrayRef<BasicBlock *> BBs, std::get<1>(I) = PHINode::Create(std::get<0>(I)->getType(), /*NumReservedValues=*/BBs.size(), CanonicalBB->getName() + ".op"); - CanonicalBB->getInstList().push_back(std::get<1>(I)); + std::get<1>(I)->insertInto(CanonicalBB, CanonicalBB->end()); } // Make it so that this canonical block actually has the right // terminator. CanonicalTerm = Term->clone(); - CanonicalBB->getInstList().push_back(CanonicalTerm); + CanonicalTerm->insertInto(CanonicalBB, CanonicalBB->end()); // If the canonical terminator has operands, rewrite it to take PHI's. for (auto I : zip(NewOps, CanonicalTerm->operands())) std::get<1>(I) = std::get<0>(I); diff --git a/llvm/lib/Transforms/Scalar/Sink.cpp b/llvm/lib/Transforms/Scalar/Sink.cpp index e8fde53005f0..8b99f73b850b 100644 --- a/llvm/lib/Transforms/Scalar/Sink.cpp +++ b/llvm/lib/Transforms/Scalar/Sink.cpp @@ -79,7 +79,8 @@ static bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo, if (SuccToSinkTo->getUniquePredecessor() != Inst->getParent()) { // We cannot sink a load across a critical edge - there may be stores in // other code paths. - if (Inst->mayReadFromMemory()) + if (Inst->mayReadFromMemory() && + !Inst->hasMetadata(LLVMContext::MD_invariant_load)) return false; // We don't want to sink across a critical edge if we don't dominate the @@ -173,9 +174,6 @@ static bool SinkInstruction(Instruction *Inst, static bool ProcessBlock(BasicBlock &BB, DominatorTree &DT, LoopInfo &LI, AAResults &AA) { - // Can't sink anything out of a block that has less than two successors. - if (BB.getTerminator()->getNumSuccessors() <= 1) return false; - // Don't bother sinking code out of unreachable blocks. In addition to being // unprofitable, it can also lead to infinite looping, because in an // unreachable loop there may be nowhere to stop. 
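A minimal standalone sketch of the relaxed sinking rule from the Sink.cpp hunk above, assuming a hypothetical helper name (blocksSinkAcrossCriticalEdge is invented for illustration); the LLVM calls it leans on — mayReadFromMemory(), hasMetadata(LLVMContext::MD_invariant_load), getUniquePredecessor() — are the same ones the pass itself uses:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"

using namespace llvm;

// A memory-reading instruction normally may not be sunk across a critical
// edge, because a store on another path into SuccToSinkTo could change what
// it reads. A load tagged !invariant.load always yields the same value, so
// that restriction no longer applies to it.
static bool blocksSinkAcrossCriticalEdge(const Instruction *Inst,
                                         const BasicBlock *SuccToSinkTo) {
  // Not a critical edge: the successor is reached only from Inst's block.
  if (SuccToSinkTo->getUniquePredecessor() == Inst->getParent())
    return false;
  // Reads memory and is not an invariant load: keep it where it is.
  return Inst->mayReadFromMemory() &&
         !Inst->hasMetadata(LLVMContext::MD_invariant_load);
}

Under this reading, an !invariant.load may migrate into a conditionally executed successor even when that successor has other predecessors, since no store reachable along those other paths can alter the loaded value.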
diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index 9ac4608134c2..65f8d760ede3 100644 --- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -252,7 +252,7 @@ static InstructionCost ComputeSpeculationCost(const Instruction *I, case Instruction::ShuffleVector: case Instruction::ExtractValue: case Instruction::InsertValue: - return TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency); + return TTI.getInstructionCost(I, TargetTransformInfo::TCK_SizeAndLatency); default: return InstructionCost::getInvalid(); // Disallow anything not explicitly diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index 0b797abefe20..81d151c2904e 100644 --- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -12,6 +12,7 @@ #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LegacyDivergenceAnalysis.h" @@ -87,6 +88,8 @@ using BBPredicates = DenseMap<BasicBlock *, Value *>; using PredMap = DenseMap<BasicBlock *, BBPredicates>; using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>; +using BranchDebugLocMap = DenseMap<BasicBlock *, DebugLoc>; + // A traits type that is intended to be used in graph algorithms. The graph // traits starts at an entry node, and traverses the RegionNodes that are in // the Nodes set. @@ -246,6 +249,7 @@ class StructurizeCFG { SmallVector<RegionNode *, 8> Order; BBSet Visited; + BBSet FlowSet; SmallVector<WeakVH, 8> AffectedPhis; BBPhiMap DeletedPhis; @@ -258,6 +262,8 @@ class StructurizeCFG { PredMap LoopPreds; BranchVector LoopConds; + BranchDebugLocMap TermDL; + RegionNode *PrevNode; void orderNodes(); @@ -278,6 +284,9 @@ class StructurizeCFG { void addPhiValues(BasicBlock *From, BasicBlock *To); + void findUndefBlocks(BasicBlock *PHIBlock, + const SmallSet<BasicBlock *, 8> &Incomings, + SmallVector<BasicBlock *> &UndefBlks) const; void setPhiValues(); void simplifyAffectedPhis(); @@ -395,7 +404,7 @@ void StructurizeCFG::orderNodes() { WorkList.emplace_back(I, I + Size); // Add the SCC nodes to the Order array. - for (auto &N : SCC) { + for (const auto &N : SCC) { assert(I < E && "SCC size mismatch!"); Order[I++] = N.first; } @@ -536,6 +545,14 @@ void StructurizeCFG::collectInfos() { // Find the last back edges analyzeLoops(RN); } + + // Reset the collected term debug locations + TermDL.clear(); + + for (BasicBlock &BB : *Func) { + if (const DebugLoc &DL = BB.getTerminator()->getDebugLoc()) + TermDL[&BB] = DL; + } } /// Insert the missing branch conditions @@ -632,6 +649,67 @@ void StructurizeCFG::addPhiValues(BasicBlock *From, BasicBlock *To) { AddedPhis[To].push_back(From); } +/// When we are reconstructing a PHI inside \p PHIBlock with incoming values +/// from predecessors \p Incomings, we have a chance to mark the available value +/// from some blocks as undefined. The function will find out all such blocks +/// and return in \p UndefBlks. 
+void StructurizeCFG::findUndefBlocks( + BasicBlock *PHIBlock, const SmallSet<BasicBlock *, 8> &Incomings, + SmallVector<BasicBlock *> &UndefBlks) const { + // We may get a post-structured CFG like below: + // + // | P1 + // |/ + // F1 + // |\ + // | N + // |/ + // F2 + // |\ + // | P2 + // |/ + // F3 + // |\ + // B + // + // B is the block that has a PHI being reconstructed. P1/P2 are predecessors + // of B before structurization. F1/F2/F3 are flow blocks inserted during + // structurization process. Block N is not a predecessor of B before + // structurization, but are placed between the predecessors(P1/P2) of B after + // structurization. This usually means that threads went to N never take the + // path N->F2->F3->B. For example, the threads take the branch F1->N may + // always take the branch F2->P2. So, when we are reconstructing a PHI + // originally in B, we can safely say the incoming value from N is undefined. + SmallSet<BasicBlock *, 8> VisitedBlock; + SmallVector<BasicBlock *, 8> Stack; + if (PHIBlock == ParentRegion->getExit()) { + for (auto P : predecessors(PHIBlock)) { + if (ParentRegion->contains(P)) + Stack.push_back(P); + } + } else { + append_range(Stack, predecessors(PHIBlock)); + } + + // Do a backward traversal over the CFG, and stop further searching if + // the block is not a Flow. If a block is neither flow block nor the + // incoming predecessor, then the incoming value from the block is + // undefined value for the PHI being reconstructed. + while (!Stack.empty()) { + BasicBlock *Current = Stack.pop_back_val(); + if (VisitedBlock.contains(Current)) + continue; + + VisitedBlock.insert(Current); + if (FlowSet.contains(Current)) { + for (auto P : predecessors(Current)) + Stack.push_back(P); + } else if (!Incomings.contains(Current)) { + UndefBlks.push_back(Current); + } + } +} + /// Add the real PHI value as soon as everything is set up void StructurizeCFG::setPhiValues() { SmallVector<PHINode *, 8> InsertedPhis; @@ -643,6 +721,8 @@ void StructurizeCFG::setPhiValues() { if (!DeletedPhis.count(To)) continue; + SmallVector<BasicBlock *> UndefBlks; + bool CachedUndefs = false; PhiMap &Map = DeletedPhis[To]; for (const auto &PI : Map) { PHINode *Phi = PI.first; @@ -651,15 +731,30 @@ void StructurizeCFG::setPhiValues() { Updater.AddAvailableValue(&Func->getEntryBlock(), Undef); Updater.AddAvailableValue(To, Undef); - NearestCommonDominator Dominator(DT); - Dominator.addBlock(To); + SmallSet<BasicBlock *, 8> Incomings; + SmallVector<BasicBlock *> ConstantPreds; for (const auto &VI : PI.second) { + Incomings.insert(VI.first); Updater.AddAvailableValue(VI.first, VI.second); - Dominator.addAndRememberBlock(VI.first); + if (isa<Constant>(VI.second)) + ConstantPreds.push_back(VI.first); } - if (!Dominator.resultIsRememberedBlock()) - Updater.AddAvailableValue(Dominator.result(), Undef); + if (!CachedUndefs) { + findUndefBlocks(To, Incomings, UndefBlks); + CachedUndefs = true; + } + + for (auto UB : UndefBlks) { + // If this undef block is dominated by any predecessor(before + // structurization) of reconstructed PHI with constant incoming value, + // don't mark the available value as undefined. Setting undef to such + // block will stop us from getting optimal phi insertion. 
+ if (any_of(ConstantPreds, + [&](BasicBlock *CP) { return DT->dominates(CP, UB); })) + continue; + Updater.AddAvailableValue(UB, Undef); + } for (BasicBlock *FI : From) Phi->setIncomingValueForBlock(FI, Updater.GetValueAtEndOfBlock(FI)); @@ -679,6 +774,9 @@ void StructurizeCFG::simplifyAffectedPhis() { Changed = false; SimplifyQuery Q(Func->getParent()->getDataLayout()); Q.DT = DT; + // Setting CanUseUndef to true might extend value liveness, set it to false + // to achieve better register pressure. + Q.CanUseUndef = false; for (WeakVH VH : AffectedPhis) { if (auto Phi = dyn_cast_or_null<PHINode>(VH)) { if (auto NewValue = simplifyInstruction(Phi, Q)) { @@ -742,7 +840,8 @@ void StructurizeCFG::changeExit(RegionNode *Node, BasicBlock *NewExit, } else { BasicBlock *BB = Node->getNodeAs<BasicBlock>(); killTerminator(BB); - BranchInst::Create(NewExit, BB); + BranchInst *Br = BranchInst::Create(NewExit, BB); + Br->setDebugLoc(TermDL[BB]); addPhiValues(BB, NewExit); if (IncludeDominator) DT->changeImmediateDominator(NewExit, BB); @@ -756,6 +855,13 @@ BasicBlock *StructurizeCFG::getNextFlow(BasicBlock *Dominator) { Order.back()->getEntry(); BasicBlock *Flow = BasicBlock::Create(Context, FlowBlockName, Func, Insert); + FlowSet.insert(Flow); + + // use a temporary variable to avoid a use-after-free if the map's storage is + // reallocated + DebugLoc DL = TermDL[Dominator]; + TermDL[Flow] = std::move(DL); + DT->addNewBlock(Flow, Dominator); ParentRegion->getRegionInfo()->setRegionFor(Flow, ParentRegion); return Flow; @@ -851,7 +957,9 @@ void StructurizeCFG::wireFlow(bool ExitUseAllowed, BasicBlock *Next = needPostfix(Flow, ExitUseAllowed); // let it point to entry and next block - Conditions.push_back(BranchInst::Create(Entry, Next, BoolUndef, Flow)); + BranchInst *Br = BranchInst::Create(Entry, Next, BoolUndef, Flow); + Br->setDebugLoc(TermDL[Flow]); + Conditions.push_back(Br); addPhiValues(Flow, Entry); DT->changeImmediateDominator(Entry, Flow); @@ -885,26 +993,14 @@ void StructurizeCFG::handleLoops(bool ExitUseAllowed, handleLoops(false, LoopEnd); } - // If the start of the loop is the entry block, we can't branch to it so - // insert a new dummy entry block. - Function *LoopFunc = LoopStart->getParent(); - if (LoopStart == &LoopFunc->getEntryBlock()) { - LoopStart->setName("entry.orig"); - - BasicBlock *NewEntry = - BasicBlock::Create(LoopStart->getContext(), - "entry", - LoopFunc, - LoopStart); - BranchInst::Create(LoopStart, NewEntry); - DT->setNewRoot(NewEntry); - } + assert(LoopStart != &LoopStart->getParent()->getEntryBlock()); // Create an extra loop end node LoopEnd = needPrefix(false); BasicBlock *Next = needPostfix(LoopEnd, ExitUseAllowed); - LoopConds.push_back(BranchInst::Create(Next, LoopStart, - BoolUndef, LoopEnd)); + BranchInst *Br = BranchInst::Create(Next, LoopStart, BoolUndef, LoopEnd); + Br->setDebugLoc(TermDL[LoopEnd]); + LoopConds.push_back(Br); addPhiValues(LoopEnd, LoopStart); setPrevNode(Next); } @@ -974,7 +1070,7 @@ static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, // Count of how many direct children are conditional. unsigned ConditionalDirectChildren = 0; - for (auto E : R->elements()) { + for (auto *E : R->elements()) { if (!E->isSubRegion()) { auto Br = dyn_cast<BranchInst>(E->getEntry()->getTerminator()); if (!Br || !Br->isConditional()) @@ -998,7 +1094,7 @@ static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, // their direct child basic blocks' terminators, regardless of whether // subregions are uniform or not. 
However, this requires a very careful // look at SIAnnotateControlFlow to make sure nothing breaks there. - for (auto BB : E->getNodeAs<Region>()->blocks()) { + for (auto *BB : E->getNodeAs<Region>()->blocks()) { auto Br = dyn_cast<BranchInst>(BB->getTerminator()); if (!Br || !Br->isConditional()) continue; @@ -1100,6 +1196,8 @@ bool StructurizeCFG::run(Region *R, DominatorTree *DT) { Loops.clear(); LoopPreds.clear(); LoopConds.clear(); + FlowSet.clear(); + TermDL.clear(); return true; } diff --git a/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp b/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp index 16b3483f9687..4ec7181ad859 100644 --- a/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp +++ b/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp @@ -187,19 +187,7 @@ Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1, Instruction *I2) { if (!I1) return I2; - if (DT->dominates(I1, I2)) - return I1; - if (DT->dominates(I2, I1)) - return I2; - - // If there is no dominance relation, use common dominator. - BasicBlock *DomBB = - DT->findNearestCommonDominator(I1->getParent(), I2->getParent()); - - Instruction *Dom = DomBB->getTerminator(); - assert(Dom && "Common dominator not found!"); - - return Dom; + return DT->findNearestCommonDominator(I1, I2); } BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn, @@ -234,7 +222,7 @@ Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn, BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB); Type *Ty = GV->getType(); auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast"); - PosBB->getInstList().insert(Iter, CastInst); + CastInst->insertInto(PosBB, Iter); return CastInst; } diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index 27c04177e894..4f1350e4ebb9 100644 --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -243,10 +243,12 @@ static bool markTails(Function &F, OptimizationRemarkEmitter *ORE) { isa<PseudoProbeInst>(&I)) continue; - // Special-case operand bundles "clang.arc.attachedcall" and "ptrauth". - bool IsNoTail = - CI->isNoTailCall() || CI->hasOperandBundlesOtherThan( - {LLVMContext::OB_clang_arc_attachedcall, LLVMContext::OB_ptrauth}); + // Special-case operand bundles "clang.arc.attachedcall", "ptrauth", and + // "kcfi". + bool IsNoTail = CI->isNoTailCall() || + CI->hasOperandBundlesOtherThan( + {LLVMContext::OB_clang_arc_attachedcall, + LLVMContext::OB_ptrauth, LLVMContext::OB_kcfi}); if (!IsNoTail && CI->doesNotAccessMemory()) { // A call to a readnone function whose arguments are all things computed @@ -714,8 +716,8 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) { BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret); NewBI->setDebugLoc(CI->getDebugLoc()); - BB->getInstList().erase(Ret); // Remove return. - BB->getInstList().erase(CI); // Remove call. + Ret->eraseFromParent(); // Remove return. + CI->eraseFromParent(); // Remove call. 
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}}); ++NumEliminated; return true; diff --git a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp index 8367e61c1a47..9e08954ef643 100644 --- a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp +++ b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp @@ -48,9 +48,9 @@ static void warnAboutLeftoverTransformations(Loop *L, if (hasVectorizeTransformation(L) == TM_ForcedByUser) { LLVM_DEBUG(dbgs() << "Leftover vectorization transformation\n"); - Optional<ElementCount> VectorizeWidth = + std::optional<ElementCount> VectorizeWidth = getOptionalElementCountLoopAttribute(L); - Optional<int> InterleaveCount = + std::optional<int> InterleaveCount = getOptionalIntLoopAttribute(L, "llvm.loop.interleave.count"); if (!VectorizeWidth || VectorizeWidth->isVector()) diff --git a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp index e6372fc5ab86..56acdcc0bc3c 100644 --- a/llvm/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/llvm/lib/Transforms/Utils/AddDiscriminators.cpp @@ -193,7 +193,7 @@ static bool addDiscriminators(Function &F) { // of the instruction appears in other basic block, assign a new // discriminator for this instruction. for (BasicBlock &B : F) { - for (auto &I : B.getInstList()) { + for (auto &I : B) { // Not all intrinsic calls should have a discriminator. // We want to avoid a non-deterministic assignment of discriminators at // different debug levels. We still allow discriminators on memory @@ -237,7 +237,7 @@ static bool addDiscriminators(Function &F) { // a same source line for correct profile annotation. for (BasicBlock &B : F) { LocationSet CallLocations; - for (auto &I : B.getInstList()) { + for (auto &I : B) { // We bypass intrinsic calls for the following two reasons: // 1) We want to avoid a non-deterministic assignment of // discriminators. diff --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp index 02ea17825c2f..d17c399ba798 100644 --- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp +++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp @@ -254,7 +254,7 @@ struct AssumeBuilderState { unsigned DerefSize = MemInst->getModule() ->getDataLayout() .getTypeStoreSize(AccType) - .getKnownMinSize(); + .getKnownMinValue(); if (DerefSize != 0) { addKnowledge({Attribute::Dereferenceable, DerefSize, Pointer}); if (!NullPointerIsDefined(MemInst->getFunction(), diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index e3cb5f359e34..58a226fc601c 100644 --- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -81,10 +82,10 @@ void llvm::detachDeadBlocks( // eventually be removed (they are themselves dead). 
if (!I.use_empty()) I.replaceAllUsesWith(PoisonValue::get(I.getType())); - BB->getInstList().pop_back(); + BB->back().eraseFromParent(); } new UnreachableInst(BB->getContext(), BB); - assert(BB->getInstList().size() == 1 && + assert(BB->size() == 1 && isa<UnreachableInst>(BB->getTerminator()) && "The successor list of BB isn't empty before " "applying corresponding DTU updates."); @@ -149,7 +150,7 @@ bool llvm::FoldSingleEntryPHINodes(BasicBlock *BB, if (PN->getIncomingValue(0) != PN) PN->replaceAllUsesWith(PN->getIncomingValue(0)); else - PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + PN->replaceAllUsesWith(PoisonValue::get(PN->getType())); if (MemDep) MemDep->removeInstruction(PN); // Memdep updates AA itself. @@ -178,7 +179,8 @@ bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI, bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, LoopInfo *LI, MemorySSAUpdater *MSSAU, MemoryDependenceResults *MemDep, - bool PredecessorWithTwoSuccessors) { + bool PredecessorWithTwoSuccessors, + DominatorTree *DT) { if (BB->hasAddressTaken()) return false; @@ -231,10 +233,21 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, FoldSingleEntryPHINodes(BB, MemDep); } + if (DT) { + assert(!DTU && "cannot use both DT and DTU for updates"); + DomTreeNode *PredNode = DT->getNode(PredBB); + DomTreeNode *BBNode = DT->getNode(BB); + if (PredNode) { + assert(BBNode && "PredNode unreachable but BBNode reachable?"); + for (DomTreeNode *C : to_vector(BBNode->children())) + C->setIDom(PredNode); + } + } // DTU update: Collect all the edges that exit BB. // These dominator edges will be redirected from Pred. std::vector<DominatorTree::UpdateType> Updates; if (DTU) { + assert(!DT && "cannot use both DT and DTU for updates"); // To avoid processing the same predecessor more than once. SmallPtrSet<BasicBlock *, 8> SeenSuccs; SmallPtrSet<BasicBlock *, 2> SuccsOfPredBB(succ_begin(PredBB), @@ -266,8 +279,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, Start = PTI; // Move all definitions in the successor to the predecessor... - PredBB->getInstList().splice(PTI->getIterator(), BB->getInstList(), - BB->begin(), STI->getIterator()); + PredBB->splice(PTI->getIterator(), BB, BB->begin(), STI->getIterator()); if (MSSAU) MSSAU->moveAllAfterMergeBlocks(BB, PredBB, Start); @@ -278,16 +290,16 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, if (PredecessorWithTwoSuccessors) { // Delete the unconditional branch from BB. - BB->getInstList().pop_back(); + BB->back().eraseFromParent(); // Update branch in the predecessor. PredBB_BI->setSuccessor(FallThruPath, NewSucc); } else { // Delete the unconditional branch from the predecessor. - PredBB->getInstList().pop_back(); + PredBB->back().eraseFromParent(); // Move terminator instruction. - PredBB->getInstList().splice(PredBB->end(), BB->getInstList()); + PredBB->splice(PredBB->end(), BB); // Terminator may be a memory accessing instruction too. if (MSSAU) @@ -311,6 +323,12 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, if (DTU) DTU->applyUpdates(Updates); + if (DT) { + assert(succ_empty(BB) && + "successors should have been transferred to PredBB"); + DT->eraseNode(BB); + } + // Finally, erase the old block and update dominator info. 
DeleteDeadBlock(BB, DTU); @@ -372,11 +390,22 @@ static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { DVI->getExpression(), DVI->getDebugLoc()->getInlinedAt()); auto R = VariableSet.insert(Key); + // If the variable fragment hasn't been seen before then we don't want + // to remove this dbg intrinsic. + if (R.second) + continue; + + if (auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI)) { + // Don't delete dbg.assign intrinsics that are linked to instructions. + if (!at::getAssignmentInsts(DAI).empty()) + continue; + // Unlinked dbg.assign intrinsics can be treated like dbg.values. + } + // If the same variable fragment is described more than once it is enough // to keep the last one (i.e. the first found since we for reverse // iteration). - if (!R.second) - ToBeRemoved.push_back(DVI); + ToBeRemoved.push_back(DVI); continue; } // Sequence with consecutive dbg.value instrs ended. Clear the map to @@ -416,19 +445,32 @@ static bool removeRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { VariableMap; for (auto &I : *BB) { if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(&I)) { - DebugVariable Key(DVI->getVariable(), - NoneType(), + DebugVariable Key(DVI->getVariable(), std::nullopt, DVI->getDebugLoc()->getInlinedAt()); auto VMI = VariableMap.find(Key); + auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI); + // A dbg.assign with no linked instructions can be treated like a + // dbg.value (i.e. can be deleted). + bool IsDbgValueKind = (!DAI || at::getAssignmentInsts(DAI).empty()); + // Update the map if we found a new value/expression describing the // variable, or if the variable wasn't mapped already. SmallVector<Value *, 4> Values(DVI->getValues()); if (VMI == VariableMap.end() || VMI->second.first != Values || VMI->second.second != DVI->getExpression()) { - VariableMap[Key] = {Values, DVI->getExpression()}; + // Use a sentinal value (nullptr) for the DIExpression when we see a + // linked dbg.assign so that the next debug intrinsic will never match + // it (i.e. always treat linked dbg.assigns as if they're unique). + if (IsDbgValueKind) + VariableMap[Key] = {Values, DVI->getExpression()}; + else + VariableMap[Key] = {Values, nullptr}; continue; } - // Found an identical mapping. Remember the instruction for later removal. + + // Don't delete dbg.assign intrinsics that are linked to instructions. + if (!IsDbgValueKind) + continue; ToBeRemoved.push_back(DVI); } } @@ -439,6 +481,60 @@ static bool removeRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { return !ToBeRemoved.empty(); } +/// Remove redundant undef dbg.assign intrinsic from an entry block using a +/// forward scan. +/// Strategy: +/// --------------------- +/// Scanning forward, delete dbg.assign intrinsics iff they are undef, not +/// linked to an intrinsic, and don't share an aggregate variable with a debug +/// intrinsic that didn't meet the criteria. In other words, undef dbg.assigns +/// that come before non-undef debug intrinsics for the variable are +/// deleted. Given: +/// +/// dbg.assign undef, "x", FragmentX1 (*) +/// <block of instructions, none being "dbg.value ..., "x", ..."> +/// dbg.value %V, "x", FragmentX2 +/// <block of instructions, none being "dbg.value ..., "x", ..."> +/// dbg.assign undef, "x", FragmentX1 +/// +/// then (only) the instruction marked with (*) can be removed. +/// Possible improvements: +/// - Keep track of non-overlapping fragments. 
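As a rough standalone illustration of the scan described above (the record type and helper name below are hypothetical, not the LLVM intrinsics), the rule reduces to: drop undef ("kill") records that appear before any real definition of the same aggregate variable, and keep everything else in order.

#include <set>
#include <string>
#include <vector>

struct DbgRecord {
  std::string Variable; // aggregate variable, fragment info stripped
  bool IsUndef;         // an undef/kill location not linked to a store
};

// Keep records in order; remove undef records seen before the first real
// definition of their variable.
std::vector<DbgRecord> pruneLeadingUndefs(const std::vector<DbgRecord> &Entry) {
  std::set<std::string> SeenDef;
  std::vector<DbgRecord> Kept;
  for (const DbgRecord &R : Entry) {
    if (R.IsUndef && !SeenDef.count(R.Variable))
      continue; // undef before any definition of this variable: redundant
    if (!R.IsUndef)
      SeenDef.insert(R.Variable);
    Kept.push_back(R);
  }
  return Kept;
}

The real implementation, shown next, keys on DebugVariable with the fragment info stripped and only deletes dbg.assign intrinsics that are not linked to an instruction.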
+static bool removeUndefDbgAssignsFromEntryBlock(BasicBlock *BB) { + assert(BB->isEntryBlock() && "expected entry block"); + SmallVector<DbgAssignIntrinsic *, 8> ToBeRemoved; + DenseSet<DebugVariable> SeenDefForAggregate; + // Returns the DebugVariable for DVI with no fragment info. + auto GetAggregateVariable = [](DbgValueInst *DVI) { + return DebugVariable(DVI->getVariable(), std::nullopt, + DVI->getDebugLoc()->getInlinedAt()); + }; + + // Remove undef dbg.assign intrinsics that are encountered before + // any non-undef intrinsics from the entry block. + for (auto &I : *BB) { + DbgValueInst *DVI = dyn_cast<DbgValueInst>(&I); + if (!DVI) + continue; + auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI); + bool IsDbgValueKind = (!DAI || at::getAssignmentInsts(DAI).empty()); + DebugVariable Aggregate = GetAggregateVariable(DVI); + if (!SeenDefForAggregate.contains(Aggregate)) { + bool IsKill = DVI->isKillLocation() && IsDbgValueKind; + if (!IsKill) { + SeenDefForAggregate.insert(Aggregate); + } else if (DAI) { + ToBeRemoved.push_back(DAI); + } + } + } + + for (DbgAssignIntrinsic *DAI : ToBeRemoved) + DAI->eraseFromParent(); + + return !ToBeRemoved.empty(); +} + bool llvm::RemoveRedundantDbgInstrs(BasicBlock *BB) { bool MadeChanges = false; // By using the "backward scan" strategy before the "forward scan" strategy we @@ -453,6 +549,9 @@ bool llvm::RemoveRedundantDbgInstrs(BasicBlock *BB) { // getting (2) out of the way, the forward scan will remove (3) since "x" // already is described as having the value V1 at (1). MadeChanges |= removeRedundantDbgInstrsUsingBackwardScan(BB); + if (BB->isEntryBlock() && + isAssignmentTrackingEnabled(*BB->getParent()->getParent())) + MadeChanges |= removeUndefDbgAssignsFromEntryBlock(BB); MadeChanges |= removeRedundantDbgInstrsUsingForwardScan(BB); if (MadeChanges) @@ -461,8 +560,7 @@ bool llvm::RemoveRedundantDbgInstrs(BasicBlock *BB) { return MadeChanges; } -void llvm::ReplaceInstWithValue(BasicBlock::InstListType &BIL, - BasicBlock::iterator &BI, Value *V) { +void llvm::ReplaceInstWithValue(BasicBlock::iterator &BI, Value *V) { Instruction &I = *BI; // Replaces all of the uses of the instruction with uses of the value I.replaceAllUsesWith(V); @@ -472,11 +570,11 @@ void llvm::ReplaceInstWithValue(BasicBlock::InstListType &BIL, V->takeName(&I); // Delete the unnecessary instruction now... - BI = BIL.erase(BI); + BI = BI->eraseFromParent(); } -void llvm::ReplaceInstWithInst(BasicBlock::InstListType &BIL, - BasicBlock::iterator &BI, Instruction *I) { +void llvm::ReplaceInstWithInst(BasicBlock *BB, BasicBlock::iterator &BI, + Instruction *I) { assert(I->getParent() == nullptr && "ReplaceInstWithInst: Instruction already inserted into basic block!"); @@ -486,10 +584,10 @@ void llvm::ReplaceInstWithInst(BasicBlock::InstListType &BIL, I->setDebugLoc(BI->getDebugLoc()); // Insert the new instruction into the basic block... - BasicBlock::iterator New = BIL.insert(BI, I); + BasicBlock::iterator New = I->insertInto(BB, BI); // Replace all uses of the old instruction, and delete it.
- ReplaceInstWithValue(BIL, BI, I); + ReplaceInstWithValue(BI, I); // Move BI back to point to the newly inserted instruction BI = New; @@ -511,7 +609,7 @@ bool llvm::IsBlockFollowedByDeoptOrUnreachable(const BasicBlock *BB) { void llvm::ReplaceInstWithInst(Instruction *From, Instruction *To) { BasicBlock::iterator BI(From); - ReplaceInstWithInst(From->getParent()->getInstList(), BI, To); + ReplaceInstWithInst(From->getParent(), BI, To); } BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, DominatorTree *DT, @@ -1126,13 +1224,13 @@ SplitBlockPredecessorsImpl(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, BI->setDebugLoc(BB->getFirstNonPHIOrDbg()->getDebugLoc()); // Move the edges from Preds to point to NewBB instead of BB. - for (unsigned i = 0, e = Preds.size(); i != e; ++i) { + for (BasicBlock *Pred : Preds) { // This is slightly more strict than necessary; the minimum requirement // is that there be no more than one indirectbr branching to BB. And // all BlockAddress uses would need to be updated. - assert(!isa<IndirectBrInst>(Preds[i]->getTerminator()) && + assert(!isa<IndirectBrInst>(Pred->getTerminator()) && "Cannot split an edge from an IndirectBrInst"); - Preds[i]->getTerminator()->replaceSuccessorWith(BB, NewBB); + Pred->getTerminator()->replaceSuccessorWith(BB, NewBB); } // Insert a new PHI node into NewBB for every PHI node in BB and that new PHI @@ -1208,13 +1306,13 @@ static void SplitLandingPadPredecessorsImpl( BI1->setDebugLoc(OrigBB->getFirstNonPHI()->getDebugLoc()); // Move the edges from Preds to point to NewBB1 instead of OrigBB. - for (unsigned i = 0, e = Preds.size(); i != e; ++i) { + for (BasicBlock *Pred : Preds) { // This is slightly more strict than necessary; the minimum requirement // is that there be no more than one indirectbr branching to BB. And // all BlockAddress uses would need to be updated. - assert(!isa<IndirectBrInst>(Preds[i]->getTerminator()) && + assert(!isa<IndirectBrInst>(Pred->getTerminator()) && "Cannot split an edge from an IndirectBrInst"); - Preds[i]->getTerminator()->replaceUsesOfWith(OrigBB, NewBB1); + Pred->getTerminator()->replaceUsesOfWith(OrigBB, NewBB1); } bool HasLoopExit = false; @@ -1264,12 +1362,12 @@ static void SplitLandingPadPredecessorsImpl( LandingPadInst *LPad = OrigBB->getLandingPadInst(); Instruction *Clone1 = LPad->clone(); Clone1->setName(Twine("lpad") + Suffix1); - NewBB1->getInstList().insert(NewBB1->getFirstInsertionPt(), Clone1); + Clone1->insertInto(NewBB1, NewBB1->getFirstInsertionPt()); if (NewBB2) { Instruction *Clone2 = LPad->clone(); Clone2->setName(Twine("lpad") + Suffix2); - NewBB2->getInstList().insert(NewBB2->getFirstInsertionPt(), Clone2); + Clone2->insertInto(NewBB2, NewBB2->getFirstInsertionPt()); // Create a PHI node for the two cloned landingpad instructions only // if the original landingpad instruction has some uses. @@ -1320,7 +1418,7 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, Instruction *UncondBranch = Pred->getTerminator(); // Clone the return and add it to the end of the predecessor. Instruction *NewRet = RI->clone(); - Pred->getInstList().push_back(NewRet); + NewRet->insertInto(Pred, Pred->end()); // If the return instruction returns a value, and if the value was a // PHI node in "BB", propagate the right value into the return. @@ -1332,7 +1430,7 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, // return instruction. 
V = BCI->getOperand(0); NewBC = BCI->clone(); - Pred->getInstList().insert(NewRet->getIterator(), NewBC); + NewBC->insertInto(Pred, NewRet->getIterator()); Op = NewBC; } @@ -1342,9 +1440,9 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, NewEV = EVI->clone(); if (NewBC) { NewBC->setOperand(0, NewEV); - Pred->getInstList().insert(NewBC->getIterator(), NewEV); + NewEV->insertInto(Pred, NewBC->getIterator()); } else { - Pred->getInstList().insert(NewRet->getIterator(), NewEV); + NewEV->insertInto(Pred, NewRet->getIterator()); Op = NewEV; } } @@ -1465,8 +1563,14 @@ Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond, void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, Instruction **ThenTerm, Instruction **ElseTerm, - MDNode *BranchWeights) { + MDNode *BranchWeights, + DomTreeUpdater *DTU) { BasicBlock *Head = SplitBefore->getParent(); + + SmallPtrSet<BasicBlock *, 8> UniqueOrigSuccessors; + if (DTU) + UniqueOrigSuccessors.insert(succ_begin(Head), succ_end(Head)); + BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator()); Instruction *HeadOldTerm = Head->getTerminator(); LLVMContext &C = Head->getContext(); @@ -1480,6 +1584,19 @@ void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, BranchInst::Create(/*ifTrue*/ThenBlock, /*ifFalse*/ElseBlock, Cond); HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights); ReplaceInstWithInst(HeadOldTerm, HeadNewTerm); + if (DTU) { + SmallVector<DominatorTree::UpdateType, 8> Updates; + Updates.reserve(4 + 2 * UniqueOrigSuccessors.size()); + for (BasicBlock *Succ : successors(Head)) { + Updates.push_back({DominatorTree::Insert, Head, Succ}); + Updates.push_back({DominatorTree::Insert, Succ, Tail}); + } + for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors) + Updates.push_back({DominatorTree::Insert, Tail, UniqueOrigSuccessor}); + for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors) + Updates.push_back({DominatorTree::Delete, Head, UniqueOrigSuccessor}); + DTU->applyUpdates(Updates); + } } BranchInst *llvm::GetIfCondition(BasicBlock *BB, BasicBlock *&IfTrue, @@ -1591,8 +1708,8 @@ static void reconnectPhis(BasicBlock *Out, BasicBlock *GuardBlock, auto Phi = cast<PHINode>(I); auto NewPhi = PHINode::Create(Phi->getType(), Incoming.size(), - Phi->getName() + ".moved", &FirstGuardBlock->back()); - for (auto In : Incoming) { + Phi->getName() + ".moved", &FirstGuardBlock->front()); + for (auto *In : Incoming) { Value *V = UndefValue::get(Phi->getType()); if (In == Out) { V = NewPhi; @@ -1612,7 +1729,7 @@ static void reconnectPhis(BasicBlock *Out, BasicBlock *GuardBlock, } } -using BBPredicates = DenseMap<BasicBlock *, PHINode *>; +using BBPredicates = DenseMap<BasicBlock *, Instruction *>; using BBSetVector = SetVector<BasicBlock *>; // Redirects the terminator of the incoming block to the first guard @@ -1628,6 +1745,8 @@ using BBSetVector = SetVector<BasicBlock *>; static std::tuple<Value *, BasicBlock *, BasicBlock *> redirectToHub(BasicBlock *BB, BasicBlock *FirstGuardBlock, const BBSetVector &Outgoing) { + assert(isa<BranchInst>(BB->getTerminator()) && + "Only support branch terminator."); auto Branch = cast<BranchInst>(BB->getTerminator()); auto Condition = Branch->isConditional() ? 
Branch->getCondition() : nullptr; @@ -1655,38 +1774,101 @@ redirectToHub(BasicBlock *BB, BasicBlock *FirstGuardBlock, assert(Succ0 || Succ1); return std::make_tuple(Condition, Succ0, Succ1); } - -// Capture the existing control flow as guard predicates, and redirect -// control flow from every incoming block to the first guard block in -// the hub. +// Setup the branch instructions for guard blocks. // -// There is one guard predicate for each outgoing block OutBB. The -// predicate is a PHINode with one input for each InBB which -// represents whether the hub should transfer control flow to OutBB if -// it arrived from InBB. These predicates are NOT ORTHOGONAL. The Hub -// evaluates them in the same order as the Outgoing set-vector, and -// control branches to the first outgoing block whose predicate -// evaluates to true. -static void convertToGuardPredicates( - BasicBlock *FirstGuardBlock, BBPredicates &GuardPredicates, - SmallVectorImpl<WeakVH> &DeletionCandidates, const BBSetVector &Incoming, - const BBSetVector &Outgoing) { +// Each guard block terminates in a conditional branch that transfers +// control to the corresponding outgoing block or the next guard +// block. The last guard block has two outgoing blocks as successors +// since the condition for the final outgoing block is trivially +// true. So we create one less block (including the first guard block) +// than the number of outgoing blocks. +static void setupBranchForGuard(SmallVectorImpl<BasicBlock *> &GuardBlocks, + const BBSetVector &Outgoing, + BBPredicates &GuardPredicates) { + // To help keep the loop simple, temporarily append the last + // outgoing block to the list of guard blocks. + GuardBlocks.push_back(Outgoing.back()); + + for (int i = 0, e = GuardBlocks.size() - 1; i != e; ++i) { + auto Out = Outgoing[i]; + assert(GuardPredicates.count(Out)); + BranchInst::Create(Out, GuardBlocks[i + 1], GuardPredicates[Out], + GuardBlocks[i]); + } + + // Remove the last block from the guard list. + GuardBlocks.pop_back(); +} + +/// We are using one integer to represent the block we are branching to. Then at +/// each guard block, the predicate was calcuated using a simple `icmp eq`. +static void calcPredicateUsingInteger( + const BBSetVector &Incoming, const BBSetVector &Outgoing, + SmallVectorImpl<BasicBlock *> &GuardBlocks, BBPredicates &GuardPredicates) { + auto &Context = Incoming.front()->getContext(); + auto FirstGuardBlock = GuardBlocks.front(); + + auto Phi = PHINode::Create(Type::getInt32Ty(Context), Incoming.size(), + "merged.bb.idx", FirstGuardBlock); + + for (auto In : Incoming) { + Value *Condition; + BasicBlock *Succ0; + BasicBlock *Succ1; + std::tie(Condition, Succ0, Succ1) = + redirectToHub(In, FirstGuardBlock, Outgoing); + Value *IncomingId = nullptr; + if (Succ0 && Succ1) { + // target_bb_index = Condition ? index_of_succ0 : index_of_succ1. + auto Succ0Iter = find(Outgoing, Succ0); + auto Succ1Iter = find(Outgoing, Succ1); + Value *Id0 = ConstantInt::get(Type::getInt32Ty(Context), + std::distance(Outgoing.begin(), Succ0Iter)); + Value *Id1 = ConstantInt::get(Type::getInt32Ty(Context), + std::distance(Outgoing.begin(), Succ1Iter)); + IncomingId = SelectInst::Create(Condition, Id0, Id1, "target.bb.idx", + In->getTerminator()); + } else { + // Get the index of the non-null successor. + auto SuccIter = Succ0 ? 
find(Outgoing, Succ0) : find(Outgoing, Succ1); + IncomingId = ConstantInt::get(Type::getInt32Ty(Context), + std::distance(Outgoing.begin(), SuccIter)); + } + Phi->addIncoming(IncomingId, In); + } + + for (int i = 0, e = Outgoing.size() - 1; i != e; ++i) { + auto Out = Outgoing[i]; + auto Cmp = ICmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, Phi, + ConstantInt::get(Type::getInt32Ty(Context), i), + Out->getName() + ".predicate", GuardBlocks[i]); + GuardPredicates[Out] = Cmp; + } +} + +/// We record the predicate of each outgoing block using a phi of boolean. +static void calcPredicateUsingBooleans( + const BBSetVector &Incoming, const BBSetVector &Outgoing, + SmallVectorImpl<BasicBlock *> &GuardBlocks, BBPredicates &GuardPredicates, + SmallVectorImpl<WeakVH> &DeletionCandidates) { auto &Context = Incoming.front()->getContext(); auto BoolTrue = ConstantInt::getTrue(Context); auto BoolFalse = ConstantInt::getFalse(Context); + auto FirstGuardBlock = GuardBlocks.front(); // The predicate for the last outgoing is trivially true, and so we // process only the first N-1 successors. for (int i = 0, e = Outgoing.size() - 1; i != e; ++i) { auto Out = Outgoing[i]; LLVM_DEBUG(dbgs() << "Creating guard for " << Out->getName() << "\n"); + auto Phi = PHINode::Create(Type::getInt1Ty(Context), Incoming.size(), StringRef("Guard.") + Out->getName(), FirstGuardBlock); GuardPredicates[Out] = Phi; } - for (auto In : Incoming) { + for (auto *In : Incoming) { Value *Condition; BasicBlock *Succ0; BasicBlock *Succ1; @@ -1698,105 +1880,103 @@ static void convertToGuardPredicates( // for Succ0 and Succ1 complement each other. If Succ0 is visited // first in the loop below, control will branch to Succ0 using the // corresponding predicate. But if that branch is not taken, then - // control must reach Succ1, which means that the predicate for - // Succ1 is always true. + // control must reach Succ1, which means that the incoming value of + // the predicate from `In` is true for Succ1. bool OneSuccessorDone = false; for (int i = 0, e = Outgoing.size() - 1; i != e; ++i) { auto Out = Outgoing[i]; - auto Phi = GuardPredicates[Out]; + PHINode *Phi = cast<PHINode>(GuardPredicates[Out]); if (Out != Succ0 && Out != Succ1) { Phi->addIncoming(BoolFalse, In); - continue; - } - // Optimization: When only one successor is an outgoing block, - // the predicate is always true. - if (!Succ0 || !Succ1 || OneSuccessorDone) { + } else if (!Succ0 || !Succ1 || OneSuccessorDone) { + // Optimization: When only one successor is an outgoing block, + // the incoming predicate from `In` is always true. Phi->addIncoming(BoolTrue, In); - continue; - } - assert(Succ0 && Succ1); - OneSuccessorDone = true; - if (Out == Succ0) { - Phi->addIncoming(Condition, In); - continue; + } else { + assert(Succ0 && Succ1); + if (Out == Succ0) { + Phi->addIncoming(Condition, In); + } else { + auto Inverted = invertCondition(Condition); + DeletionCandidates.push_back(Condition); + Phi->addIncoming(Inverted, In); + } + OneSuccessorDone = true; } - auto Inverted = invertCondition(Condition); - DeletionCandidates.push_back(Condition); - Phi->addIncoming(Inverted, In); } } } -// For each outgoing block OutBB, create a guard block in the Hub. The -// first guard block was already created outside, and available as the -// first element in the vector of guard blocks. +// Capture the existing control flow as guard predicates, and redirect +// control flow from \p Incoming block through the \p GuardBlocks to the +// \p Outgoing blocks. 
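The ordered, non-orthogonal evaluation described in the comment above, combined with the single-integer encoding introduced by calcPredicateUsingInteger, can be pictured with a small scalar sketch (hypothetical plain C++, not the IR the pass emits):

// Guard block i transfers control to outgoing block i when its predicate
// holds and otherwise falls through to the next guard; the last outgoing
// block needs no guard because its predicate is trivially true. With the
// integer encoding the predicate is simply "index == i".
int selectOutgoing(int TargetIdx, int NumOutgoing) {
  for (int i = 0; i < NumOutgoing - 1; ++i)
    if (TargetIdx == i) // one 'icmp eq' per guard block
      return i;
  return NumOutgoing - 1;
}

With the boolean encoding the hub instead carries one i1 phi per non-final outgoing block, which is why the patch falls back to the integer form once Outgoing.size() exceeds MaxControlFlowBooleans.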
// -// Each guard block terminates in a conditional branch that transfers -// control to the corresponding outgoing block or the next guard -// block. The last guard block has two outgoing blocks as successors -// since the condition for the final outgoing block is trivially -// true. So we create one less block (including the first guard block) -// than the number of outgoing blocks. -static void createGuardBlocks(SmallVectorImpl<BasicBlock *> &GuardBlocks, - Function *F, const BBSetVector &Outgoing, - BBPredicates &GuardPredicates, StringRef Prefix) { - for (int i = 0, e = Outgoing.size() - 2; i != e; ++i) { +// There is one guard predicate for each outgoing block OutBB. The +// predicate represents whether the hub should transfer control flow +// to OutBB. These predicates are NOT ORTHOGONAL. The Hub evaluates +// them in the same order as the Outgoing set-vector, and control +// branches to the first outgoing block whose predicate evaluates to true. +static void +convertToGuardPredicates(SmallVectorImpl<BasicBlock *> &GuardBlocks, + SmallVectorImpl<WeakVH> &DeletionCandidates, + const BBSetVector &Incoming, + const BBSetVector &Outgoing, const StringRef Prefix, + std::optional<unsigned> MaxControlFlowBooleans) { + BBPredicates GuardPredicates; + auto F = Incoming.front()->getParent(); + + for (int i = 0, e = Outgoing.size() - 1; i != e; ++i) GuardBlocks.push_back( BasicBlock::Create(F->getContext(), Prefix + ".guard", F)); - } - assert(GuardBlocks.size() == GuardPredicates.size()); - - // To help keep the loop simple, temporarily append the last - // outgoing block to the list of guard blocks. - GuardBlocks.push_back(Outgoing.back()); - for (int i = 0, e = GuardBlocks.size() - 1; i != e; ++i) { - auto Out = Outgoing[i]; - assert(GuardPredicates.count(Out)); - BranchInst::Create(Out, GuardBlocks[i + 1], GuardPredicates[Out], - GuardBlocks[i]); - } + // When we are using an integer to record which target block to jump to, we + // are creating less live values, actually we are using one single integer to + // store the index of the target block. When we are using booleans to store + // the branching information, we need (N-1) boolean values, where N is the + // number of outgoing block. + if (!MaxControlFlowBooleans || Outgoing.size() <= *MaxControlFlowBooleans) + calcPredicateUsingBooleans(Incoming, Outgoing, GuardBlocks, GuardPredicates, + DeletionCandidates); + else + calcPredicateUsingInteger(Incoming, Outgoing, GuardBlocks, GuardPredicates); - // Remove the last block from the guard list. 
- GuardBlocks.pop_back(); + setupBranchForGuard(GuardBlocks, Outgoing, GuardPredicates); } BasicBlock *llvm::CreateControlFlowHub( DomTreeUpdater *DTU, SmallVectorImpl<BasicBlock *> &GuardBlocks, const BBSetVector &Incoming, const BBSetVector &Outgoing, - const StringRef Prefix) { - auto F = Incoming.front()->getParent(); - auto FirstGuardBlock = - BasicBlock::Create(F->getContext(), Prefix + ".guard", F); + const StringRef Prefix, std::optional<unsigned> MaxControlFlowBooleans) { + if (Outgoing.size() < 2) + return Outgoing.front(); SmallVector<DominatorTree::UpdateType, 16> Updates; if (DTU) { - for (auto In : Incoming) { - Updates.push_back({DominatorTree::Insert, In, FirstGuardBlock}); - for (auto Succ : successors(In)) { + for (auto *In : Incoming) { + for (auto Succ : successors(In)) if (Outgoing.count(Succ)) Updates.push_back({DominatorTree::Delete, In, Succ}); - } } } - BBPredicates GuardPredicates; SmallVector<WeakVH, 8> DeletionCandidates; - convertToGuardPredicates(FirstGuardBlock, GuardPredicates, DeletionCandidates, - Incoming, Outgoing); - - GuardBlocks.push_back(FirstGuardBlock); - createGuardBlocks(GuardBlocks, F, Outgoing, GuardPredicates, Prefix); - + convertToGuardPredicates(GuardBlocks, DeletionCandidates, Incoming, Outgoing, + Prefix, MaxControlFlowBooleans); + auto FirstGuardBlock = GuardBlocks.front(); + // Update the PHINodes in each outgoing block to match the new control flow. - for (int i = 0, e = GuardBlocks.size(); i != e; ++i) { + for (int i = 0, e = GuardBlocks.size(); i != e; ++i) reconnectPhis(Outgoing[i], GuardBlocks[i], Incoming, FirstGuardBlock); - } + reconnectPhis(Outgoing.back(), GuardBlocks.back(), Incoming, FirstGuardBlock); if (DTU) { int NumGuards = GuardBlocks.size(); assert((int)Outgoing.size() == NumGuards + 1); + + for (auto In : Incoming) + Updates.push_back({DominatorTree::Insert, In, FirstGuardBlock}); + for (int i = 0; i != NumGuards - 1; ++i) { Updates.push_back({DominatorTree::Insert, GuardBlocks[i], Outgoing[i]}); Updates.push_back( diff --git a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp index 9c595401ce29..ddb35756030f 100644 --- a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -179,7 +179,7 @@ llvm::SplitKnownCriticalEdge(Instruction *TI, unsigned SuccNum, // Insert the block into the function... right after the block TI lives in. Function &F = *TIBB->getParent(); Function::iterator FBBI = TIBB->getIterator(); - F.getBasicBlockList().insert(++FBBI, NewBB); + F.insert(++FBBI, NewBB); // Branch to the new block, breaking the edge. TI->setSuccessor(SuccNum, NewBB); diff --git a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index e25ec74a0572..1e21a2f85446 100644 --- a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/TypeSize.h" +#include <optional> using namespace llvm; @@ -75,11 +76,6 @@ static bool setOnlyReadsMemory(Function &F) { static bool setOnlyWritesMemory(Function &F) { if (F.onlyWritesMemory()) // writeonly or readnone return false; - // Turn readonly and writeonly into readnone. 
- if (F.hasFnAttribute(Attribute::ReadOnly)) { - F.removeFnAttr(Attribute::ReadOnly); - return setDoesNotAccessMemory(F); - } ++NumWriteOnly; F.setOnlyWritesMemory(); return true; @@ -231,7 +227,7 @@ static bool setAllocatedPointerParam(Function &F, unsigned ArgNo) { } static bool setAllocSize(Function &F, unsigned ElemSizeArg, - Optional<unsigned> NumElemsArg) { + std::optional<unsigned> NumElemsArg) { if (F.hasFnAttribute(Attribute::AllocSize)) return false; F.addFnAttr(Attribute::getWithAllocSizeArgs(F.getContext(), ElemSizeArg, @@ -316,7 +312,7 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_strcpy: case LibFunc_strncpy: Changed |= setReturnedArg(F, 0); - LLVM_FALLTHROUGH; + [[fallthrough]]; case LibFunc_stpcpy: case LibFunc_stpncpy: Changed |= setOnlyAccessesArgMemory(F); @@ -386,7 +382,7 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, break; case LibFunc_strndup: Changed |= setArgNoUndef(F, 1); - LLVM_FALLTHROUGH; + [[fallthrough]]; case LibFunc_strdup: Changed |= setAllocFamily(F, "malloc"); Changed |= setOnlyAccessesInaccessibleMemOrArgMem(F); @@ -446,16 +442,16 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, break; case LibFunc_aligned_alloc: Changed |= setAlignedAllocParam(F, 0); - Changed |= setAllocSize(F, 1, None); + Changed |= setAllocSize(F, 1, std::nullopt); Changed |= setAllocKind(F, AllocFnKind::Alloc | AllocFnKind::Uninitialized | AllocFnKind::Aligned); - LLVM_FALLTHROUGH; + [[fallthrough]]; case LibFunc_valloc: case LibFunc_malloc: case LibFunc_vec_malloc: Changed |= setAllocFamily(F, TheLibFunc == LibFunc_vec_malloc ? "vec_malloc" : "malloc"); Changed |= setAllocKind(F, AllocFnKind::Alloc | AllocFnKind::Uninitialized); - Changed |= setAllocSize(F, 0, None); + Changed |= setAllocSize(F, 0, std::nullopt); Changed |= setOnlyAccessesInaccessibleMemory(F); Changed |= setRetAndArgsNoUndef(F); Changed |= setDoesNotThrow(F); @@ -507,7 +503,7 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_mempcpy: case LibFunc_memccpy: Changed |= setWillReturn(F); - LLVM_FALLTHROUGH; + [[fallthrough]]; case LibFunc_memcpy_chk: Changed |= setDoesNotThrow(F); Changed |= setOnlyAccessesArgMemory(F); @@ -521,7 +517,7 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, Changed |= setAllocFamily(F, "malloc"); Changed |= setAllocKind(F, AllocFnKind::Alloc | AllocFnKind::Aligned | AllocFnKind::Uninitialized); - Changed |= setAllocSize(F, 1, None); + Changed |= setAllocSize(F, 1, std::nullopt); Changed |= setAlignedAllocParam(F, 0); Changed |= setOnlyAccessesInaccessibleMemory(F); Changed |= setRetNoUndef(F); @@ -548,7 +544,7 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, F, TheLibFunc == LibFunc_vec_realloc ? 
"vec_malloc" : "malloc"); Changed |= setAllocKind(F, AllocFnKind::Realloc); Changed |= setAllocatedPointerParam(F, 0); - Changed |= setAllocSize(F, 1, None); + Changed |= setAllocSize(F, 1, std::nullopt); Changed |= setOnlyAccessesInaccessibleMemOrArgMem(F); Changed |= setRetNoUndef(F); Changed |= setDoesNotThrow(F); @@ -985,7 +981,7 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, break; case LibFunc_dunder_strndup: Changed |= setArgNoUndef(F, 1); - LLVM_FALLTHROUGH; + [[fallthrough]]; case LibFunc_dunder_strdup: Changed |= setDoesNotThrow(F); Changed |= setRetDoesNotAlias(F); @@ -1078,10 +1074,10 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); Changed |= setOnlyReadsMemory(F, 1); - LLVM_FALLTHROUGH; + [[fallthrough]]; case LibFunc_memset: Changed |= setWillReturn(F); - LLVM_FALLTHROUGH; + [[fallthrough]]; case LibFunc_memset_chk: Changed |= setOnlyAccessesArgMemory(F); Changed |= setOnlyWritesMemory(F, 0); @@ -1232,7 +1228,7 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, } // We have to do this step after AllocKind has been inferred on functions so // we can reliably identify free-like and realloc-like functions. - if (!isLibFreeFunction(&F, TheLibFunc) && !isReallocLikeFn(&F, &TLI)) + if (!isLibFreeFunction(&F, TheLibFunc) && !isReallocLikeFn(&F)) Changed |= setDoesNotFreeMemory(F); return Changed; } @@ -1244,6 +1240,13 @@ static void setArgExtAttr(Function &F, unsigned ArgNo, F.addParamAttr(ArgNo, ExtAttr); } +static void setRetExtAttr(Function &F, + const TargetLibraryInfo &TLI, bool Signed = true) { + Attribute::AttrKind ExtAttr = TLI.getExtAttrForI32Return(Signed); + if (ExtAttr != Attribute::None && !F.hasRetAttribute(ExtAttr)) + F.addRetAttr(ExtAttr); +} + // Modeled after X86TargetLowering::markLibCallAttributes. static void markRegisterParameterAttributes(Function *F) { if (!F->arg_size() || F->isVarArg()) @@ -1319,6 +1322,8 @@ FunctionCallee llvm::getOrInsertLibFunc(Module *M, const TargetLibraryInfo &TLI, // on any target: A size_t argument (which may be an i32 on some targets) // should not trigger the assert below. 
case LibFunc_bcmp: + setRetExtAttr(*F, TLI); + break; case LibFunc_calloc: case LibFunc_fwrite: case LibFunc_malloc: @@ -1421,6 +1426,15 @@ Value *llvm::castToCStr(Value *V, IRBuilderBase &B) { return B.CreateBitCast(V, B.getInt8PtrTy(AS), "cstr"); } +static IntegerType *getIntTy(IRBuilderBase &B, const TargetLibraryInfo *TLI) { + return B.getIntNTy(TLI->getIntSize()); +} + +static IntegerType *getSizeTTy(IRBuilderBase &B, const TargetLibraryInfo *TLI) { + const Module *M = B.GetInsertBlock()->getModule(); + return B.getIntNTy(TLI->getSizeTSize(*M)); +} + static Value *emitLibCall(LibFunc TheLibFunc, Type *ReturnType, ArrayRef<Type *> ParamTypes, ArrayRef<Value *> Operands, IRBuilderBase &B, @@ -1443,8 +1457,8 @@ static Value *emitLibCall(LibFunc TheLibFunc, Type *ReturnType, Value *llvm::emitStrLen(Value *Ptr, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - LLVMContext &Context = B.GetInsertBlock()->getContext(); - return emitLibCall(LibFunc_strlen, DL.getIntPtrType(Context), + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_strlen, SizeTTy, B.getInt8PtrTy(), castToCStr(Ptr, B), B, TLI); } @@ -1457,17 +1471,18 @@ Value *llvm::emitStrDup(Value *Ptr, IRBuilderBase &B, Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilderBase &B, const TargetLibraryInfo *TLI) { Type *I8Ptr = B.getInt8PtrTy(); - Type *I32Ty = B.getInt32Ty(); - return emitLibCall(LibFunc_strchr, I8Ptr, {I8Ptr, I32Ty}, - {castToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, B, TLI); + Type *IntTy = getIntTy(B, TLI); + return emitLibCall(LibFunc_strchr, I8Ptr, {I8Ptr, IntTy}, + {castToCStr(Ptr, B), ConstantInt::get(IntTy, C)}, B, TLI); } Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - LLVMContext &Context = B.GetInsertBlock()->getContext(); + Type *IntTy = getIntTy(B, TLI); + Type *SizeTTy = getSizeTTy(B, TLI); return emitLibCall( - LibFunc_strncmp, B.getInt32Ty(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)}, + LibFunc_strncmp, IntTy, + {B.getInt8PtrTy(), B.getInt8PtrTy(), SizeTTy}, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); } @@ -1488,14 +1503,16 @@ Value *llvm::emitStpCpy(Value *Dst, Value *Src, IRBuilderBase &B, Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B, const TargetLibraryInfo *TLI) { Type *I8Ptr = B.getInt8PtrTy(); - return emitLibCall(LibFunc_strncpy, I8Ptr, {I8Ptr, I8Ptr, Len->getType()}, + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_strncpy, I8Ptr, {I8Ptr, I8Ptr, SizeTTy}, {castToCStr(Dst, B), castToCStr(Src, B), Len}, B, TLI); } Value *llvm::emitStpNCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B, const TargetLibraryInfo *TLI) { Type *I8Ptr = B.getInt8PtrTy(); - return emitLibCall(LibFunc_stpncpy, I8Ptr, {I8Ptr, I8Ptr, Len->getType()}, + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_stpncpy, I8Ptr, {I8Ptr, I8Ptr, SizeTTy}, {castToCStr(Dst, B), castToCStr(Src, B), Len}, B, TLI); } @@ -1509,11 +1526,11 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, AttributeList AS; AS = AttributeList::get(M->getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind); - LLVMContext &Context = B.GetInsertBlock()->getContext(); + Type *I8Ptr = B.getInt8PtrTy(); + Type *SizeTTy = getSizeTTy(B, TLI); FunctionCallee MemCpy = getOrInsertLibFunc(M, *TLI, LibFunc_memcpy_chk, - AttributeList::get(M->getContext(), AS), B.getInt8PtrTy(), - B.getInt8PtrTy(), 
B.getInt8PtrTy(), DL.getIntPtrType(Context), - DL.getIntPtrType(Context)); + AttributeList::get(M->getContext(), AS), I8Ptr, + I8Ptr, I8Ptr, SizeTTy, SizeTTy); Dst = castToCStr(Dst, B); Src = castToCStr(Src, B); CallInst *CI = B.CreateCall(MemCpy, {Dst, Src, Len, ObjSize}); @@ -1525,74 +1542,85 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, Value *llvm::emitMemPCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - LLVMContext &Context = B.GetInsertBlock()->getContext(); - return emitLibCall( - LibFunc_mempcpy, B.getInt8PtrTy(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)}, - {Dst, Src, Len}, B, TLI); + Type *I8Ptr = B.getInt8PtrTy(); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_mempcpy, I8Ptr, + {I8Ptr, I8Ptr, SizeTTy}, + {Dst, Src, Len}, B, TLI); } Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - LLVMContext &Context = B.GetInsertBlock()->getContext(); - return emitLibCall( - LibFunc_memchr, B.getInt8PtrTy(), - {B.getInt8PtrTy(), B.getInt32Ty(), DL.getIntPtrType(Context)}, - {castToCStr(Ptr, B), Val, Len}, B, TLI); + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_memchr, I8Ptr, + {I8Ptr, IntTy, SizeTTy}, + {castToCStr(Ptr, B), Val, Len}, B, TLI); } Value *llvm::emitMemRChr(Value *Ptr, Value *Val, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - LLVMContext &Context = B.GetInsertBlock()->getContext(); - return emitLibCall( - LibFunc_memrchr, B.getInt8PtrTy(), - {B.getInt8PtrTy(), B.getInt32Ty(), DL.getIntPtrType(Context)}, - {castToCStr(Ptr, B), Val, Len}, B, TLI); + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_memrchr, I8Ptr, + {I8Ptr, IntTy, SizeTTy}, + {castToCStr(Ptr, B), Val, Len}, B, TLI); } Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - LLVMContext &Context = B.GetInsertBlock()->getContext(); - return emitLibCall( - LibFunc_memcmp, B.getInt32Ty(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)}, - {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_memcmp, IntTy, + {I8Ptr, I8Ptr, SizeTTy}, + {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); } Value *llvm::emitBCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - LLVMContext &Context = B.GetInsertBlock()->getContext(); - return emitLibCall( - LibFunc_bcmp, B.getInt32Ty(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)}, - {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_bcmp, IntTy, + {I8Ptr, I8Ptr, SizeTTy}, + {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); } Value *llvm::emitMemCCpy(Value *Ptr1, Value *Ptr2, Value *Val, Value *Len, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - return emitLibCall( - LibFunc_memccpy, B.getInt8PtrTy(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt32Ty(), Len->getType()}, - {Ptr1, Ptr2, 
Val, Len}, B, TLI); + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_memccpy, I8Ptr, + {I8Ptr, I8Ptr, IntTy, SizeTTy}, + {Ptr1, Ptr2, Val, Len}, B, TLI); } Value *llvm::emitSNPrintf(Value *Dest, Value *Size, Value *Fmt, ArrayRef<Value *> VariadicArgs, IRBuilderBase &B, const TargetLibraryInfo *TLI) { + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); + Type *SizeTTy = getSizeTTy(B, TLI); SmallVector<Value *, 8> Args{castToCStr(Dest, B), Size, castToCStr(Fmt, B)}; llvm::append_range(Args, VariadicArgs); - return emitLibCall(LibFunc_snprintf, B.getInt32Ty(), - {B.getInt8PtrTy(), Size->getType(), B.getInt8PtrTy()}, + return emitLibCall(LibFunc_snprintf, IntTy, + {I8Ptr, SizeTTy, I8Ptr}, Args, B, TLI, /*IsVaArgs=*/true); } Value *llvm::emitSPrintf(Value *Dest, Value *Fmt, ArrayRef<Value *> VariadicArgs, IRBuilderBase &B, const TargetLibraryInfo *TLI) { + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); SmallVector<Value *, 8> Args{castToCStr(Dest, B), castToCStr(Fmt, B)}; llvm::append_range(Args, VariadicArgs); - return emitLibCall(LibFunc_sprintf, B.getInt32Ty(), - {B.getInt8PtrTy(), B.getInt8PtrTy()}, Args, B, TLI, + return emitLibCall(LibFunc_sprintf, IntTy, + {I8Ptr, I8Ptr}, Args, B, TLI, /*IsVaArgs=*/true); } @@ -1605,37 +1633,48 @@ Value *llvm::emitStrCat(Value *Dest, Value *Src, IRBuilderBase &B, Value *llvm::emitStrLCpy(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - return emitLibCall(LibFunc_strlcpy, Size->getType(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), Size->getType()}, + Type *I8Ptr = B.getInt8PtrTy(); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_strlcpy, SizeTTy, + {I8Ptr, I8Ptr, SizeTTy}, {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); } Value *llvm::emitStrLCat(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - return emitLibCall(LibFunc_strlcat, Size->getType(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), Size->getType()}, + Type *I8Ptr = B.getInt8PtrTy(); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_strlcat, SizeTTy, + {I8Ptr, I8Ptr, SizeTTy}, {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); } Value *llvm::emitStrNCat(Value *Dest, Value *Src, Value *Size, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - return emitLibCall(LibFunc_strncat, B.getInt8PtrTy(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), Size->getType()}, + Type *I8Ptr = B.getInt8PtrTy(); + Type *SizeTTy = getSizeTTy(B, TLI); + return emitLibCall(LibFunc_strncat, I8Ptr, + {I8Ptr, I8Ptr, SizeTTy}, {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); } Value *llvm::emitVSNPrintf(Value *Dest, Value *Size, Value *Fmt, Value *VAList, IRBuilderBase &B, const TargetLibraryInfo *TLI) { + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); + Type *SizeTTy = getSizeTTy(B, TLI); return emitLibCall( - LibFunc_vsnprintf, B.getInt32Ty(), - {B.getInt8PtrTy(), Size->getType(), B.getInt8PtrTy(), VAList->getType()}, + LibFunc_vsnprintf, IntTy, + {I8Ptr, SizeTTy, I8Ptr, VAList->getType()}, {castToCStr(Dest, B), Size, castToCStr(Fmt, B), VAList}, B, TLI); } Value *llvm::emitVSPrintf(Value *Dest, Value *Fmt, Value *VAList, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - return emitLibCall(LibFunc_vsprintf, B.getInt32Ty(), - {B.getInt8PtrTy(), B.getInt8PtrTy(), VAList->getType()}, + Type *I8Ptr = B.getInt8PtrTy(); + Type *IntTy = getIntTy(B, TLI); + return 
emitLibCall(LibFunc_vsprintf, IntTy, + {I8Ptr, I8Ptr, VAList->getType()}, {castToCStr(Dest, B), castToCStr(Fmt, B), VAList}, B, TLI); } @@ -1756,22 +1795,20 @@ Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, return emitBinaryFloatFnCallHelper(Op1, Op2, TheLibFunc, Name, B, Attrs, TLI); } +// Emit a call to putchar(int) with Char as the argument. Char must have +// the same precision as int, which need not be 32 bits. Value *llvm::emitPutChar(Value *Char, IRBuilderBase &B, const TargetLibraryInfo *TLI) { Module *M = B.GetInsertBlock()->getModule(); if (!isLibFuncEmittable(M, TLI, LibFunc_putchar)) return nullptr; + Type *IntTy = getIntTy(B, TLI); StringRef PutCharName = TLI->getName(LibFunc_putchar); FunctionCallee PutChar = getOrInsertLibFunc(M, *TLI, LibFunc_putchar, - B.getInt32Ty(), B.getInt32Ty()); + IntTy, IntTy); inferNonMandatoryLibFuncAttrs(M, PutCharName, *TLI); - CallInst *CI = B.CreateCall(PutChar, - B.CreateIntCast(Char, - B.getInt32Ty(), - /*isSigned*/true, - "chari"), - PutCharName); + CallInst *CI = B.CreateCall(PutChar, Char, PutCharName); if (const Function *F = dyn_cast<Function>(PutChar.getCallee()->stripPointerCasts())) @@ -1785,8 +1822,9 @@ Value *llvm::emitPutS(Value *Str, IRBuilderBase &B, if (!isLibFuncEmittable(M, TLI, LibFunc_puts)) return nullptr; + Type *IntTy = getIntTy(B, TLI); StringRef PutsName = TLI->getName(LibFunc_puts); - FunctionCallee PutS = getOrInsertLibFunc(M, *TLI, LibFunc_puts, B.getInt32Ty(), + FunctionCallee PutS = getOrInsertLibFunc(M, *TLI, LibFunc_puts, IntTy, B.getInt8PtrTy()); inferNonMandatoryLibFuncAttrs(M, PutsName, *TLI); CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), PutsName); @@ -1802,13 +1840,12 @@ Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilderBase &B, if (!isLibFuncEmittable(M, TLI, LibFunc_fputc)) return nullptr; + Type *IntTy = getIntTy(B, TLI); StringRef FPutcName = TLI->getName(LibFunc_fputc); - FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fputc, B.getInt32Ty(), - B.getInt32Ty(), File->getType()); + FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fputc, IntTy, + IntTy, File->getType()); if (File->getType()->isPointerTy()) inferNonMandatoryLibFuncAttrs(M, FPutcName, *TLI); - Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, - "chari"); CallInst *CI = B.CreateCall(F, {Char, File}, FPutcName); if (const Function *Fn = @@ -1823,8 +1860,9 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilderBase &B, if (!isLibFuncEmittable(M, TLI, LibFunc_fputs)) return nullptr; + Type *IntTy = getIntTy(B, TLI); StringRef FPutsName = TLI->getName(LibFunc_fputs); - FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fputs, B.getInt32Ty(), + FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fputs, IntTy, B.getInt8PtrTy(), File->getType()); if (File->getType()->isPointerTy()) inferNonMandatoryLibFuncAttrs(M, FPutsName, *TLI); @@ -1842,17 +1880,17 @@ Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilderBase &B, if (!isLibFuncEmittable(M, TLI, LibFunc_fwrite)) return nullptr; - LLVMContext &Context = B.GetInsertBlock()->getContext(); + Type *SizeTTy = getSizeTTy(B, TLI); StringRef FWriteName = TLI->getName(LibFunc_fwrite); FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fwrite, - DL.getIntPtrType(Context), B.getInt8PtrTy(), DL.getIntPtrType(Context), - DL.getIntPtrType(Context), File->getType()); + SizeTTy, B.getInt8PtrTy(), SizeTTy, + SizeTTy, File->getType()); if (File->getType()->isPointerTy()) inferNonMandatoryLibFuncAttrs(M, FWriteName, *TLI); 
CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, - ConstantInt::get(DL.getIntPtrType(Context), 1), File}); + ConstantInt::get(SizeTTy, 1), File}); if (const Function *Fn = dyn_cast<Function>(F.getCallee()->stripPointerCasts())) @@ -1867,9 +1905,9 @@ Value *llvm::emitMalloc(Value *Num, IRBuilderBase &B, const DataLayout &DL, return nullptr; StringRef MallocName = TLI->getName(LibFunc_malloc); - LLVMContext &Context = B.GetInsertBlock()->getContext(); + Type *SizeTTy = getSizeTTy(B, TLI); FunctionCallee Malloc = getOrInsertLibFunc(M, *TLI, LibFunc_malloc, - B.getInt8PtrTy(), DL.getIntPtrType(Context)); + B.getInt8PtrTy(), SizeTTy); inferNonMandatoryLibFuncAttrs(M, MallocName, *TLI); CallInst *CI = B.CreateCall(Malloc, Num, MallocName); @@ -1887,10 +1925,9 @@ Value *llvm::emitCalloc(Value *Num, Value *Size, IRBuilderBase &B, return nullptr; StringRef CallocName = TLI.getName(LibFunc_calloc); - const DataLayout &DL = M->getDataLayout(); - IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext())); + Type *SizeTTy = getSizeTTy(B, &TLI); FunctionCallee Calloc = getOrInsertLibFunc(M, TLI, LibFunc_calloc, - B.getInt8PtrTy(), PtrType, PtrType); + B.getInt8PtrTy(), SizeTTy, SizeTTy); inferNonMandatoryLibFuncAttrs(M, CallocName, TLI); CallInst *CI = B.CreateCall(Calloc, {Num, Size}, CallocName); diff --git a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp index 833d04210629..930a0bcbfac5 100644 --- a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -16,8 +16,6 @@ #include "llvm/Transforms/Utils/BypassSlowDivision.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Transforms/Utils/Local.h" @@ -87,7 +85,7 @@ class FastDivInsertionTask { QuotRemPair createDivRemPhiNodes(QuotRemWithBB &LHS, QuotRemWithBB &RHS, BasicBlock *PhiBB); Value *insertOperandRuntimeCheck(Value *Op1, Value *Op2); - Optional<QuotRemPair> insertFastDivAndRem(); + std::optional<QuotRemPair> insertFastDivAndRem(); bool isSignedOp() { return SlowDivOrRem->getOpcode() == Instruction::SDiv || @@ -161,7 +159,7 @@ Value *FastDivInsertionTask::getReplacement(DivCacheTy &Cache) { if (CacheI == Cache.end()) { // If previous instance does not exist, try to insert fast div. - Optional<QuotRemPair> OptResult = insertFastDivAndRem(); + std::optional<QuotRemPair> OptResult = insertFastDivAndRem(); // Bail out if insertFastDivAndRem has failed. if (!OptResult) return nullptr; @@ -350,19 +348,19 @@ Value *FastDivInsertionTask::insertOperandRuntimeCheck(Value *Op1, Value *Op2) { /// Substitutes the div/rem instruction with code that checks the value of the /// operands and uses a shorter-faster div/rem instruction when possible. 
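/// For illustration only (a sketch of the emitted shape for a 64-bit unsigned
/// division bypassed to a 32-bit one):
///   if (((Dividend | Divisor) >> 32) == 0) {
///     // truncate both operands, emit a 32-bit udiv/urem, zero-extend results
///   } else {
///     // keep the original 64-bit udiv/urem
///   }
/// with the quotient/remainder pairs of the two arms joined by phi nodes.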
-Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { +std::optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { Value *Dividend = SlowDivOrRem->getOperand(0); Value *Divisor = SlowDivOrRem->getOperand(1); VisitedSetTy SetL; ValueRange DividendRange = getValueRange(Dividend, SetL); if (DividendRange == VALRNG_LIKELY_LONG) - return None; + return std::nullopt; VisitedSetTy SetR; ValueRange DivisorRange = getValueRange(Divisor, SetR); if (DivisorRange == VALRNG_LIKELY_LONG) - return None; + return std::nullopt; bool DividendShort = (DividendRange == VALRNG_KNOWN_SHORT); bool DivisorShort = (DivisorRange == VALRNG_KNOWN_SHORT); @@ -387,7 +385,7 @@ Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { // If the divisor is not a constant, DAGCombiner will convert it to a // multiplication by a magic constant. It isn't clear if it is worth // introducing control flow to get a narrower multiply. - return None; + return std::nullopt; } // After Constant Hoisting pass, long constants may be represented as @@ -397,7 +395,7 @@ Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { if (auto *BCI = dyn_cast<BitCastInst>(Divisor)) if (BCI->getParent() == SlowDivOrRem->getParent() && isa<ConstantInt>(BCI->getOperand(0))) - return None; + return std::nullopt; IRBuilder<> Builder(MainBB, MainBB->end()); Builder.SetCurrentDebugLocation(SlowDivOrRem->getDebugLoc()); @@ -417,7 +415,7 @@ Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { // Split the basic block before the div/rem. BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem); // Remove the unconditional branch from MainBB to SuccessorBB. - MainBB->getInstList().back().eraseFromParent(); + MainBB->back().eraseFromParent(); QuotRemWithBB Long; Long.BB = MainBB; Long.Quotient = ConstantInt::get(getSlowType(), 0); @@ -434,7 +432,7 @@ Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { // Split the basic block before the div/rem. BasicBlock *SuccessorBB = MainBB->splitBasicBlock(SlowDivOrRem); // Remove the unconditional branch from MainBB to SuccessorBB. - MainBB->getInstList().back().eraseFromParent(); + MainBB->back().eraseFromParent(); QuotRemWithBB Fast = createFastBB(SuccessorBB); QuotRemWithBB Slow = createSlowBB(SuccessorBB); QuotRemPair Result = createDivRemPhiNodes(Fast, Slow, SuccessorBB); diff --git a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp index 1840f26add2d..d0b89ba2606e 100644 --- a/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp +++ b/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp @@ -36,7 +36,7 @@ bool CallGraphUpdater::finalize() { CallGraphNode *DeadCGN = (*CG)[DeadFn]; DeadCGN->removeAllCalledFunctions(); CG->getExternalCallingNode()->removeAnyCallEdgeTo(DeadCGN); - DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType())); + DeadFn->replaceAllUsesWith(PoisonValue::get(DeadFn->getType())); } // Then remove the node and function from the module. @@ -51,7 +51,7 @@ bool CallGraphUpdater::finalize() { // no call graph was provided. 
for (Function *DeadFn : DeadFunctions) { DeadFn->removeDeadConstantUsers(); - DeadFn->replaceAllUsesWith(UndefValue::get(DeadFn->getType())); + DeadFn->replaceAllUsesWith(PoisonValue::get(DeadFn->getType())); if (LCG && !ReplacedFunctions.count(DeadFn)) { // Taken mostly from the inliner: diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index e530afc277db..4a82f9606d3f 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -415,18 +415,8 @@ bool llvm::isLegalToPromote(const CallBase &CB, Function *Callee, // site. unsigned I = 0; for (; I < NumParams; ++I) { - Type *FormalTy = Callee->getFunctionType()->getFunctionParamType(I); - Type *ActualTy = CB.getArgOperand(I)->getType(); - if (FormalTy == ActualTy) - continue; - if (!CastInst::isBitOrNoopPointerCastable(ActualTy, FormalTy, DL)) { - if (FailureReason) - *FailureReason = "Argument type mismatch"; - return false; - } // Make sure that the callee and call agree on byval/inalloca. The types do // not have to match. - if (Callee->hasParamAttribute(I, Attribute::ByVal) != CB.getAttributes().hasParamAttr(I, Attribute::ByVal)) { if (FailureReason) @@ -439,6 +429,28 @@ bool llvm::isLegalToPromote(const CallBase &CB, Function *Callee, *FailureReason = "inalloca mismatch"; return false; } + + Type *FormalTy = Callee->getFunctionType()->getFunctionParamType(I); + Type *ActualTy = CB.getArgOperand(I)->getType(); + if (FormalTy == ActualTy) + continue; + if (!CastInst::isBitOrNoopPointerCastable(ActualTy, FormalTy, DL)) { + if (FailureReason) + *FailureReason = "Argument type mismatch"; + return false; + } + + // MustTail call needs stricter type match. See + // Verifier::verifyMustTailCall(). + if (CB.isMustTailCall()) { + PointerType *PF = dyn_cast<PointerType>(FormalTy); + PointerType *PA = dyn_cast<PointerType>(ActualTy); + if (!PF || !PA || PF->getAddressSpace() != PA->getAddressSpace()) { + if (FailureReason) + *FailureReason = "Musttail call Argument type mismatch"; + return false; + } + } } for (; I < NumArgs; I++) { // Vararg functions can have more arguments than parameters. diff --git a/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp b/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp index 9101a1e41f7b..4d622679dbdb 100644 --- a/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp +++ b/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp @@ -16,7 +16,7 @@ // @a = alias i8, i8 *@g <-- @a is now an alias to base object @g // @b = alias i8, i8 *@g // -// Eventually this file will implement full alias canonicalation, so that +// Eventually this file will implement full alias canonicalization, so that // all aliasees are private anonymous values. E.g. 
// @a = alias i8, i8 *@g // @g = global i8 0 diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp index 1d348213bfdb..87822ee85c2b 100644 --- a/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -33,6 +33,7 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <map> +#include <optional> using namespace llvm; #define DEBUG_TYPE "clone-function" @@ -46,7 +47,7 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, if (BB->hasName()) NewBB->setName(BB->getName() + NameSuffix); - bool hasCalls = false, hasDynamicAllocas = false; + bool hasCalls = false, hasDynamicAllocas = false, hasMemProfMetadata = false; Module *TheModule = F ? F->getParent() : nullptr; // Loop over all instructions, and copy them over. @@ -57,10 +58,13 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, Instruction *NewInst = I.clone(); if (I.hasName()) NewInst->setName(I.getName() + NameSuffix); - NewBB->getInstList().push_back(NewInst); + NewInst->insertInto(NewBB, NewBB->end()); VMap[&I] = NewInst; // Add instruction map to value. - hasCalls |= (isa<CallInst>(I) && !I.isDebugOrPseudoInst()); + if (isa<CallInst>(I) && !I.isDebugOrPseudoInst()) { + hasCalls = true; + hasMemProfMetadata |= I.hasMetadata(LLVMContext::MD_memprof); + } if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) { if (!AI->isStaticAlloca()) { hasDynamicAllocas = true; @@ -70,6 +74,7 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, if (CodeInfo) { CodeInfo->ContainsCalls |= hasCalls; + CodeInfo->ContainsMemProfMetadata |= hasMemProfMetadata; CodeInfo->ContainsDynamicAllocas |= hasDynamicAllocas; } return NewBB; @@ -100,12 +105,26 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, NewFunc->copyAttributesFrom(OldFunc); NewFunc->setAttributes(NewAttrs); + const RemapFlags FuncGlobalRefFlags = + ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges; + // Fix up the personality function that got copied over. if (OldFunc->hasPersonalityFn()) - NewFunc->setPersonalityFn( - MapValue(OldFunc->getPersonalityFn(), VMap, - ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges, - TypeMapper, Materializer)); + NewFunc->setPersonalityFn(MapValue(OldFunc->getPersonalityFn(), VMap, + FuncGlobalRefFlags, TypeMapper, + Materializer)); + + if (OldFunc->hasPrefixData()) { + NewFunc->setPrefixData(MapValue(OldFunc->getPrefixData(), VMap, + FuncGlobalRefFlags, TypeMapper, + Materializer)); + } + + if (OldFunc->hasPrologueData()) { + NewFunc->setPrologueData(MapValue(OldFunc->getPrologueData(), VMap, + FuncGlobalRefFlags, TypeMapper, + Materializer)); + } SmallVector<AttributeSet, 4> NewArgAttrs(NewFunc->arg_size()); AttributeList OldAttrs = OldFunc->getAttributes(); @@ -132,7 +151,7 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, // duplicate instructions and then freeze them in the MD map. We also record // information about dbg.value and dbg.declare to avoid duplicating the // types. - Optional<DebugInfoFinder> DIFinder; + std::optional<DebugInfoFinder> DIFinder; // Track the subprogram attachment that needs to be cloned to fine-tune the // mapping within the same module. 
@@ -471,6 +490,7 @@ void PruningFunctionCloner::CloneBlock( } bool hasCalls = false, hasDynamicAllocas = false, hasStaticAllocas = false; + bool hasMemProfMetadata = false; // Loop over all instructions, and copy them over, DCE'ing as we go. This // loop doesn't include the terminator. @@ -487,8 +507,9 @@ void PruningFunctionCloner::CloneBlock( } // Eagerly remap operands to the newly cloned instruction, except for PHI - // nodes for which we defer processing until we update the CFG. - if (!isa<PHINode>(NewInst)) { + // nodes for which we defer processing until we update the CFG. Also defer + // debug intrinsic processing because they may contain use-before-defs. + if (!isa<PHINode>(NewInst) && !isa<DbgVariableIntrinsic>(NewInst)) { RemapInstruction(NewInst, VMap, ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges); @@ -514,8 +535,11 @@ void PruningFunctionCloner::CloneBlock( if (II->hasName()) NewInst->setName(II->getName() + NameSuffix); VMap[&*II] = NewInst; // Add instruction map to value. - NewBB->getInstList().push_back(NewInst); - hasCalls |= (isa<CallInst>(II) && !II->isDebugOrPseudoInst()); + NewInst->insertInto(NewBB, NewBB->end()); + if (isa<CallInst>(II) && !II->isDebugOrPseudoInst()) { + hasCalls = true; + hasMemProfMetadata |= II->hasMetadata(LLVMContext::MD_memprof); + } if (CodeInfo) { CodeInfo->OrigVMap[&*II] = NewInst; @@ -573,7 +597,7 @@ void PruningFunctionCloner::CloneBlock( Instruction *NewInst = OldTI->clone(); if (OldTI->hasName()) NewInst->setName(OldTI->getName() + NameSuffix); - NewBB->getInstList().push_back(NewInst); + NewInst->insertInto(NewBB, NewBB->end()); VMap[OldTI] = NewInst; // Add instruction map to value. if (CodeInfo) { @@ -589,6 +613,7 @@ void PruningFunctionCloner::CloneBlock( if (CodeInfo) { CodeInfo->ContainsCalls |= hasCalls; + CodeInfo->ContainsMemProfMetadata |= hasMemProfMetadata; CodeInfo->ContainsDynamicAllocas |= hasDynamicAllocas; CodeInfo->ContainsDynamicAllocas |= hasStaticAllocas && BB != &BB->getParent()->front(); @@ -628,6 +653,15 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, StartingInst = &StartingBB->front(); } + // Collect debug intrinsics for remapping later. + SmallVector<const DbgVariableIntrinsic *, 8> DbgIntrinsics; + for (const auto &BB : *OldFunc) { + for (const auto &I : BB) { + if (const auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) + DbgIntrinsics.push_back(DVI); + } + } + // Clone the entry block, and anything recursively reachable from it. std::vector<const BasicBlock *> CloneWorklist; PFC.CloneBlock(StartingBB, StartingInst->getIterator(), CloneWorklist); @@ -650,7 +684,7 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, continue; // Dead block. // Add the new block to the new function. - NewFunc->getBasicBlockList().push_back(NewBB); + NewFunc->insert(NewFunc->end(), NewBB); // Handle PHI nodes specially, as we have to remove references to dead // blocks. @@ -799,6 +833,19 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, VMap[OrigV] = I; } + // Remap debug intrinsic operands now that all values have been mapped. + // Doing this now (late) preserves use-before-defs in debug intrinsics. If + // we didn't do this, ValueAsMetadata(use-before-def) operands would be + // replaced by empty metadata. This would signal later cleanup passes to + // remove the debug intrinsics, potentially causing incorrect locations. 
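// (A concrete example of the hazard: a cloned dbg.value whose wrapped operand
// is a value defined later in the same block. Such use-before-defs are
// tolerated for debug intrinsics because the operand is metadata, but
// remapping the intrinsic before the later value has been cloned would leave
// an empty metadata operand behind.)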
+ for (const auto *DVI : DbgIntrinsics) { + if (DbgVariableIntrinsic *NewDVI = + cast_or_null<DbgVariableIntrinsic>(VMap.lookup(DVI))) + RemapInstruction(NewDVI, VMap, + ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges, + TypeMapper, Materializer); + } + // Simplify conditional branches and switches with a constant operand. We try // to prune these out when cloning, but if the simplification required // looking through PHI nodes, those are only available after forming the full @@ -856,7 +903,7 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, Dest->replaceAllUsesWith(&*I); // Move all the instructions in the succ to the pred. - I->getInstList().splice(I->end(), Dest->getInstList()); + I->splice(I->end(), Dest); // Remove the dest block. Dest->eraseFromParent(); @@ -980,10 +1027,9 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, } // Move them physically from the end of the block list. - F->getBasicBlockList().splice(Before->getIterator(), F->getBasicBlockList(), - NewPH); - F->getBasicBlockList().splice(Before->getIterator(), F->getBasicBlockList(), - NewLoop->getHeader()->getIterator(), F->end()); + F->splice(Before->getIterator(), F, NewPH->getIterator()); + F->splice(Before->getIterator(), F, NewLoop->getHeader()->getIterator(), + F->end()); return NewLoop; } @@ -1041,7 +1087,7 @@ void llvm::cloneNoAliasScopes(ArrayRef<MDNode *> NoAliasDeclScopes, MDBuilder MDB(Context); for (auto *ScopeList : NoAliasDeclScopes) { - for (auto &MDOperand : ScopeList->operands()) { + for (const auto &MDOperand : ScopeList->operands()) { if (MDNode *MD = dyn_cast<MDNode>(MDOperand)) { AliasScopeNode SNANode(MD); @@ -1066,7 +1112,7 @@ void llvm::adaptNoAliasScopes(Instruction *I, auto CloneScopeList = [&](const MDNode *ScopeList) -> MDNode * { bool NeedsReplacement = false; SmallVector<Metadata *, 8> NewScopeList; - for (auto &MDOp : ScopeList->operands()) { + for (const auto &MDOp : ScopeList->operands()) { if (MDNode *MD = dyn_cast<MDNode>(MDOp)) { if (auto *NewMD = ClonedScopes.lookup(MD)) { NewScopeList.push_back(NewMD); diff --git a/llvm/lib/Transforms/Utils/CloneModule.cpp b/llvm/lib/Transforms/Utils/CloneModule.cpp index 55cda0f11e47..55e051298a9a 100644 --- a/llvm/lib/Transforms/Utils/CloneModule.cpp +++ b/llvm/lib/Transforms/Utils/CloneModule.cpp @@ -109,6 +109,15 @@ std::unique_ptr<Module> llvm::CloneModule( VMap[&I] = GA; } + for (const GlobalIFunc &I : M.ifuncs()) { + // Defer setting the resolver function until after functions are cloned. + auto *GI = + GlobalIFunc::create(I.getValueType(), I.getAddressSpace(), + I.getLinkage(), I.getName(), nullptr, New.get()); + GI->copyAttributesFrom(&I); + VMap[&I] = GI; + } + // Now that all of the things that global variable initializer can refer to // have been created, loop through and copy the global variable referrers // over... We also set the attributes on the global now. @@ -184,6 +193,12 @@ std::unique_ptr<Module> llvm::CloneModule( GA->setAliasee(MapValue(C, VMap)); } + for (const GlobalIFunc &I : M.ifuncs()) { + GlobalIFunc *GI = cast<GlobalIFunc>(VMap[&I]); + if (const Constant *Resolver = I.getResolver()) + GI->setResolver(MapValue(Resolver, VMap)); + } + // And named metadata.... 
for (const NamedMDNode &NMD : M.named_metadata()) { NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName()); diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 421f1f329f07..c1fe10504e45 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -15,7 +15,6 @@ #include "llvm/Transforms/Utils/CodeExtractor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -138,7 +137,7 @@ static bool isBlockValidForExtraction(const BasicBlock &BB, if (auto *UBB = CSI->getUnwindDest()) if (!Result.count(UBB)) return false; - for (auto *HBB : CSI->handlers()) + for (const auto *HBB : CSI->handlers()) if (!Result.count(const_cast<BasicBlock*>(HBB))) return false; continue; @@ -831,6 +830,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, std::vector<Type *> ParamTy; std::vector<Type *> AggParamTy; ValueSet StructValues; + const DataLayout &DL = M->getDataLayout(); // Add the types of the input values to the function's argument list for (Value *value : inputs) { @@ -849,7 +849,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, AggParamTy.push_back(output->getType()); StructValues.insert(output); } else - ParamTy.push_back(PointerType::getUnqual(output->getType())); + ParamTy.push_back( + PointerType::get(output->getType(), DL.getAllocaAddrSpace())); } assert( @@ -864,7 +865,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, StructType *StructTy = nullptr; if (AggregateArgs && !AggParamTy.empty()) { StructTy = StructType::get(M->getContext(), AggParamTy); - ParamTy.push_back(PointerType::getUnqual(StructTy)); + ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace())); } LLVM_DEBUG({ @@ -902,26 +903,21 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, // Those attributes cannot be propagated safely. Explicitly list them // here so we get a warning if new attributes are added. case Attribute::AllocSize: - case Attribute::ArgMemOnly: case Attribute::Builtin: case Attribute::Convergent: - case Attribute::InaccessibleMemOnly: - case Attribute::InaccessibleMemOrArgMemOnly: case Attribute::JumpTable: case Attribute::Naked: case Attribute::NoBuiltin: case Attribute::NoMerge: case Attribute::NoReturn: case Attribute::NoSync: - case Attribute::ReadNone: - case Attribute::ReadOnly: case Attribute::ReturnsTwice: case Attribute::Speculatable: case Attribute::StackAlignment: case Attribute::WillReturn: - case Attribute::WriteOnly: case Attribute::AllocKind: case Attribute::PresplitCoroutine: + case Attribute::Memory: continue; // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: @@ -963,6 +959,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::NoCfCheck: case Attribute::MustProgress: case Attribute::NoProfile: + case Attribute::SkipProfile: break; // These attributes cannot be applied to functions. 
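// Note on the attribute churn above: the function-level forms of ReadNone,
// ReadOnly and WriteOnly have been folded into the single 'memory' attribute,
// so they now show up below among the parameter-only attributes. Roughly, and
// assuming the MemoryEffects helpers, a function that used to be
// readonly+argmemonly is now summarized as:
//   F->setMemoryEffects(MemoryEffects::argMemOnly(ModRefInfo::Ref));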
case Attribute::Alignment: @@ -980,6 +977,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::NoUndef: case Attribute::NonNull: case Attribute::Preallocated: + case Attribute::ReadNone: + case Attribute::ReadOnly: case Attribute::Returned: case Attribute::SExt: case Attribute::StructRet: @@ -989,6 +988,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::ZExt: case Attribute::ImmArg: case Attribute::ByRef: + case Attribute::WriteOnly: // These are not really attributes. case Attribute::None: case Attribute::EndAttrKinds: @@ -999,7 +999,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, newFunction->addFnAttr(Attr); } - newFunction->getBasicBlockList().push_back(newRootNode); + newFunction->insert(newFunction->end(), newRootNode); // Create scalar and aggregate iterators to name all of the arguments we // inserted. @@ -1208,7 +1208,7 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); GetElementPtrInst *GEP = GetElementPtrInst::Create( StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName()); - codeReplacer->getInstList().push_back(GEP); + GEP->insertInto(codeReplacer, codeReplacer->end()); new StoreInst(StructValues[i], GEP, codeReplacer); NumAggregatedInputs++; } @@ -1226,7 +1226,7 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc()) call->setDebugLoc(DL); } - codeReplacer->getInstList().push_back(call); + call->insertInto(codeReplacer, codeReplacer->end()); // Set swifterror parameter attributes. for (unsigned SwiftErrArgNo : SwiftErrorArgs) { @@ -1246,7 +1246,7 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx); GetElementPtrInst *GEP = GetElementPtrInst::Create( StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName()); - codeReplacer->getInstList().push_back(GEP); + GEP->insertInto(codeReplacer, codeReplacer->end()); Output = GEP; ++aggIdx; } else { @@ -1258,8 +1258,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, codeReplacer); Reloads.push_back(load); std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end()); - for (unsigned u = 0, e = Users.size(); u != e; ++u) { - Instruction *inst = cast<Instruction>(Users[u]); + for (User *U : Users) { + Instruction *inst = cast<Instruction>(U); if (!Blocks.count(inst->getParent())) inst->replaceUsesOfWith(outputs[i], load); } @@ -1435,21 +1435,17 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } void CodeExtractor::moveCodeToFunction(Function *newFunction) { - Function *oldFunc = (*Blocks.begin())->getParent(); - Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); - Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); - auto newFuncIt = newFunction->front().getIterator(); for (BasicBlock *Block : Blocks) { // Delete the basic block from the old function, and the list of blocks - oldBlocks.remove(Block); + Block->removeFromParent(); // Insert this basic block into the new function // Insert the original blocks after the entry block created // for the new function. The entry block may be followed // by a set of exit blocks at this point, but these exit // blocks better be placed at the end of the new function. 
- newFuncIt = newBlocks.insertAfter(newFuncIt, Block); + newFuncIt = newFunction->insert(std::next(newFuncIt), Block); } } @@ -1538,7 +1534,8 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, assert(OldSP->getUnit() && "Missing compile unit for subprogram"); DIBuilder DIB(*OldFunc.getParent(), /*AllowUnresolved=*/false, OldSP->getUnit()); - auto SPType = DIB.createSubroutineType(DIB.getOrCreateTypeArray(None)); + auto SPType = + DIB.createSubroutineType(DIB.getOrCreateTypeArray(std::nullopt)); DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition | DISubprogram::SPFlagOptimized | DISubprogram::SPFlagLocalToUnit; @@ -1555,18 +1552,25 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, // point to a variable in the wrong scope. SmallDenseMap<DINode *, DINode *> RemappedMetadata; SmallVector<Instruction *, 4> DebugIntrinsicsToDelete; + DenseMap<const MDNode *, MDNode *> Cache; for (Instruction &I : instructions(NewFunc)) { auto *DII = dyn_cast<DbgInfoIntrinsic>(&I); if (!DII) continue; - // Point the intrinsic to a fresh label within the new function. + // Point the intrinsic to a fresh label within the new function if the + // intrinsic was not inlined from some other function. if (auto *DLI = dyn_cast<DbgLabelInst>(&I)) { + if (DLI->getDebugLoc().getInlinedAt()) + continue; DILabel *OldLabel = DLI->getLabel(); DINode *&NewLabel = RemappedMetadata[OldLabel]; - if (!NewLabel) - NewLabel = DILabel::get(Ctx, NewSP, OldLabel->getName(), + if (!NewLabel) { + DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram( + *OldLabel->getScope(), *NewSP, Ctx, Cache); + NewLabel = DILabel::get(Ctx, NewScope, OldLabel->getName(), OldLabel->getFile(), OldLabel->getLine()); + } DLI->setArgOperand(0, MetadataAsValue::get(Ctx, NewLabel)); continue; } @@ -1587,17 +1591,23 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, DebugIntrinsicsToDelete.push_back(DVI); continue; } - - // Point the intrinsic to a fresh variable within the new function. - DILocalVariable *OldVar = DVI->getVariable(); - DINode *&NewVar = RemappedMetadata[OldVar]; - if (!NewVar) - NewVar = DIB.createAutoVariable( - NewSP, OldVar->getName(), OldVar->getFile(), OldVar->getLine(), - OldVar->getType(), /*AlwaysPreserve=*/false, DINode::FlagZero, - OldVar->getAlignInBits()); - DVI->setVariable(cast<DILocalVariable>(NewVar)); + // If the variable was in the scope of the old function, i.e. it was not + // inlined, point the intrinsic to a fresh variable within the new function. + if (!DVI->getDebugLoc().getInlinedAt()) { + DILocalVariable *OldVar = DVI->getVariable(); + DINode *&NewVar = RemappedMetadata[OldVar]; + if (!NewVar) { + DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram( + *OldVar->getScope(), *NewSP, Ctx, Cache); + NewVar = DIB.createAutoVariable( + NewScope, OldVar->getName(), OldVar->getFile(), OldVar->getLine(), + OldVar->getType(), /*AlwaysPreserve=*/false, DINode::FlagZero, + OldVar->getAlignInBits()); + } + DVI->setVariable(cast<DILocalVariable>(NewVar)); + } } + for (auto *DII : DebugIntrinsicsToDelete) DII->eraseFromParent(); DIB.finalizeSubprogram(NewSP); @@ -1606,13 +1616,13 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, // function. 
for (Instruction &I : instructions(NewFunc)) { if (const DebugLoc &DL = I.getDebugLoc()) - I.setDebugLoc(DILocation::get(Ctx, DL.getLine(), DL.getCol(), NewSP)); + I.setDebugLoc( + DebugLoc::replaceInlinedAtSubprogram(DL, *NewSP, Ctx, Cache)); // Loop info metadata may contain line locations. Fix them up. - auto updateLoopInfoLoc = [&Ctx, NewSP](Metadata *MD) -> Metadata * { + auto updateLoopInfoLoc = [&Ctx, &Cache, NewSP](Metadata *MD) -> Metadata * { if (auto *Loc = dyn_cast_or_null<DILocation>(MD)) - return DILocation::get(Ctx, Loc->getLine(), Loc->getColumn(), NewSP, - nullptr); + return DebugLoc::replaceInlinedAtSubprogram(Loc, *NewSP, Ctx, Cache); return MD; }; updateLoopMetadataDebugLocations(I, updateLoopInfoLoc); @@ -1653,14 +1663,14 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, } } - // Remove @llvm.assume calls that will be moved to the new function from the - // old function's assumption cache. + // Remove CondGuardInsts that will be moved to the new function from the old + // function's assumption cache. for (BasicBlock *Block : Blocks) { for (Instruction &I : llvm::make_early_inc_range(*Block)) { - if (auto *AI = dyn_cast<AssumeInst>(&I)) { + if (auto *CI = dyn_cast<CondGuardInst>(&I)) { if (AC) - AC->unregisterAssumption(AI); - AI->eraseFromParent(); + AC->unregisterAssumption(CI); + CI->eraseFromParent(); } } } @@ -1725,7 +1735,7 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, }); }); } - newFuncRoot->getInstList().push_back(BranchI); + BranchI->insertInto(newFuncRoot, newFuncRoot->end()); ValueSet SinkingCands, HoistingCands; BasicBlock *CommonExit = nullptr; @@ -1778,7 +1788,7 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); if (Count) newFunction->setEntryCount( - ProfileCount(Count.value(), Function::PCT_Real)); // FIXME + ProfileCount(*Count, Function::PCT_Real)); // FIXME BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); } @@ -1854,7 +1864,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc, const Function &NewFunc, AssumptionCache *AC) { for (auto AssumeVH : AC->assumptions()) { - auto *I = dyn_cast_or_null<CallInst>(AssumeVH); + auto *I = dyn_cast_or_null<CondGuardInst>(AssumeVH); if (!I) continue; @@ -1866,7 +1876,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc, // that were previously in the old function, but that have now been moved // to the new function. for (auto AffectedValVH : AC->assumptionsFor(I->getOperand(0))) { - auto *AffectedCI = dyn_cast_or_null<CallInst>(AffectedValVH); + auto *AffectedCI = dyn_cast_or_null<CondGuardInst>(AffectedValVH); if (!AffectedCI) continue; if (AffectedCI->getFunction() != &OldFunc) diff --git a/llvm/lib/Transforms/Utils/CodeLayout.cpp b/llvm/lib/Transforms/Utils/CodeLayout.cpp index 1ff0f148b3a9..9eb3aff3ffe8 100644 --- a/llvm/lib/Transforms/Utils/CodeLayout.cpp +++ b/llvm/lib/Transforms/Utils/CodeLayout.cpp @@ -35,12 +35,15 @@ // Reference: // * A. Newell and S. 
Pupyrev, Improved Basic Block Reordering,
// IEEE Transactions on Computers, 2020
+// https://arxiv.org/abs/1809.04676
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/CodeLayout.h"
#include "llvm/Support/CommandLine.h"
+#include <cmath>
+
using namespace llvm;
#define DEBUG_TYPE "code-layout"
@@ -54,40 +57,56 @@ cl::opt<bool> ApplyExtTspWithoutProfile(
cl::desc("Whether to apply ext-tsp placement for instances w/o profile"),
cl::init(true), cl::Hidden);
-// Algorithm-specific constants. The values are tuned for the best performance
+// Algorithm-specific params. The values are tuned for the best performance
// of large-scale front-end bound binaries.
-static cl::opt<double>
- ForwardWeight("ext-tsp-forward-weight", cl::Hidden, cl::init(0.1),
- cl::desc("The weight of forward jumps for ExtTSP value"));
+static cl::opt<double> ForwardWeightCond(
+ "ext-tsp-forward-weight-cond", cl::ReallyHidden, cl::init(0.1),
+ cl::desc("The weight of conditional forward jumps for ExtTSP value"));
+
+static cl::opt<double> ForwardWeightUncond(
+ "ext-tsp-forward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
+ cl::desc("The weight of unconditional forward jumps for ExtTSP value"));
+
+static cl::opt<double> BackwardWeightCond(
+ "ext-tsp-backward-weight-cond", cl::ReallyHidden, cl::init(0.1),
+ cl::desc("The weight of conditional backward jumps for ExtTSP value"));
+
+static cl::opt<double> BackwardWeightUncond(
+ "ext-tsp-backward-weight-uncond", cl::ReallyHidden, cl::init(0.1),
+ cl::desc("The weight of unconditional backward jumps for ExtTSP value"));
+
+static cl::opt<double> FallthroughWeightCond(
+ "ext-tsp-fallthrough-weight-cond", cl::ReallyHidden, cl::init(1.0),
+ cl::desc("The weight of conditional fallthrough jumps for ExtTSP value"));
-static cl::opt<double>
- BackwardWeight("ext-tsp-backward-weight", cl::Hidden, cl::init(0.1),
- cl::desc("The weight of backward jumps for ExtTSP value"));
+static cl::opt<double> FallthroughWeightUncond(
+ "ext-tsp-fallthrough-weight-uncond", cl::ReallyHidden, cl::init(1.05),
+ cl::desc("The weight of unconditional fallthrough jumps for ExtTSP value"));
static cl::opt<unsigned> ForwardDistance(
- "ext-tsp-forward-distance", cl::Hidden, cl::init(1024),
+ "ext-tsp-forward-distance", cl::ReallyHidden, cl::init(1024),
cl::desc("The maximum distance (in bytes) of a forward jump for ExtTSP"));
static cl::opt<unsigned> BackwardDistance(
- "ext-tsp-backward-distance", cl::Hidden, cl::init(640),
+ "ext-tsp-backward-distance", cl::ReallyHidden, cl::init(640),
cl::desc("The maximum distance (in bytes) of a backward jump for ExtTSP"));
// The maximum size of a chain created by the algorithm. The size is bounded
// so that the algorithm can efficiently process extremely large instances.
static cl::opt<unsigned>
- MaxChainSize("ext-tsp-max-chain-size", cl::Hidden, cl::init(4096),
+ MaxChainSize("ext-tsp-max-chain-size", cl::ReallyHidden, cl::init(4096),
cl::desc("The maximum size of a chain to create."));
// The maximum size of a chain for splitting. Larger values of the threshold
// may yield better quality at the cost of worse run-time.
static cl::opt<unsigned> ChainSplitThreshold(
- "ext-tsp-chain-split-threshold", cl::Hidden, cl::init(128),
+ "ext-tsp-chain-split-threshold", cl::ReallyHidden, cl::init(128),
cl::desc("The maximum size of a chain to apply splitting"));
// The option enables splitting (large) chains along in-coming and out-going
// jumps. This typically results in better quality.
static cl::opt<bool> EnableChainSplitAlongJumps( - "ext-tsp-enable-chain-split-along-jumps", cl::Hidden, cl::init(true), + "ext-tsp-enable-chain-split-along-jumps", cl::ReallyHidden, cl::init(true), cl::desc("The maximum size of a chain to apply splitting")); namespace { @@ -95,31 +114,37 @@ namespace { // Epsilon for comparison of doubles. constexpr double EPS = 1e-8; +// Compute the Ext-TSP score for a given jump. +double jumpExtTSPScore(uint64_t JumpDist, uint64_t JumpMaxDist, uint64_t Count, + double Weight) { + if (JumpDist > JumpMaxDist) + return 0; + double Prob = 1.0 - static_cast<double>(JumpDist) / JumpMaxDist; + return Weight * Prob * Count; +} + // Compute the Ext-TSP score for a jump between a given pair of blocks, // using their sizes, (estimated) addresses and the jump execution count. double extTSPScore(uint64_t SrcAddr, uint64_t SrcSize, uint64_t DstAddr, - uint64_t Count) { + uint64_t Count, bool IsConditional) { // Fallthrough if (SrcAddr + SrcSize == DstAddr) { - // Assume that FallthroughWeight = 1.0 after normalization - return static_cast<double>(Count); + return jumpExtTSPScore(0, 1, Count, + IsConditional ? FallthroughWeightCond + : FallthroughWeightUncond); } // Forward if (SrcAddr + SrcSize < DstAddr) { - const auto Dist = DstAddr - (SrcAddr + SrcSize); - if (Dist <= ForwardDistance) { - double Prob = 1.0 - static_cast<double>(Dist) / ForwardDistance; - return ForwardWeight * Prob * Count; - } - return 0; + const uint64_t Dist = DstAddr - (SrcAddr + SrcSize); + return jumpExtTSPScore(Dist, ForwardDistance, Count, + IsConditional ? ForwardWeightCond + : ForwardWeightUncond); } // Backward - const auto Dist = SrcAddr + SrcSize - DstAddr; - if (Dist <= BackwardDistance) { - double Prob = 1.0 - static_cast<double>(Dist) / BackwardDistance; - return BackwardWeight * Prob * Count; - } - return 0; + const uint64_t Dist = SrcAddr + SrcSize - DstAddr; + return jumpExtTSPScore(Dist, BackwardDistance, Count, + IsConditional ? BackwardWeightCond + : BackwardWeightUncond); } /// A type of merging two chains, X and Y. The former chain is split into @@ -191,8 +216,8 @@ public: std::vector<Jump *> InJumps; public: - explicit Block(size_t Index, uint64_t Size_, uint64_t EC) - : Index(Index), Size(Size_), ExecutionCount(EC) {} + explicit Block(size_t Index, uint64_t Size, uint64_t EC) + : Index(Index), Size(Size), ExecutionCount(EC) {} bool isEntry() const { return Index == 0; } }; @@ -210,6 +235,8 @@ public: Block *Target; // Execution count of the arc in the profile data. uint64_t ExecutionCount{0}; + // Whether the jump corresponds to a conditional branch. + bool IsConditional{false}; public: explicit Jump(Block *Source, Block *Target, uint64_t ExecutionCount) @@ -231,6 +258,14 @@ public: bool isEntry() const { return Blocks[0]->Index == 0; } + bool isCold() const { + for (auto *Block : Blocks) { + if (Block->ExecutionCount > 0) + return false; + } + return true; + } + double score() const { return Score; } void setScore(double NewScore) { Score = NewScore; } @@ -371,10 +406,10 @@ void Chain::mergeEdges(Chain *Other) { // Update edges adjacent to chain Other for (auto EdgeIt : Other->Edges) { - const auto DstChain = EdgeIt.first; - const auto DstEdge = EdgeIt.second; - const auto TargetChain = DstChain == Other ? this : DstChain; - auto CurEdge = getEdge(TargetChain); + Chain *DstChain = EdgeIt.first; + ChainEdge *DstEdge = EdgeIt.second; + Chain *TargetChain = DstChain == Other ? 
this : DstChain;
+ ChainEdge *CurEdge = getEdge(TargetChain);
if (CurEdge == nullptr) {
DstEdge->changeEndpoint(Other, this);
this->addEdge(TargetChain, DstEdge);
if (DstChain != this && DstChain != Other) {
DstChain->addEdge(this, DstEdge);
@@ -436,7 +471,7 @@ private:
/// The implementation of the ExtTSP algorithm.
class ExtTSPImpl {
using EdgeT = std::pair<uint64_t, uint64_t>;
- using EdgeCountMap = DenseMap<EdgeT, uint64_t>;
+ using EdgeCountMap = std::vector<std::pair<EdgeT, uint64_t>>;
public:
ExtTSPImpl(size_t NumNodes, const std::vector<uint64_t> &NodeSizes,
@@ -478,12 +513,14 @@ private:
}
// Initialize jumps between blocks
- SuccNodes = std::vector<std::vector<uint64_t>>(NumNodes);
- PredNodes = std::vector<std::vector<uint64_t>>(NumNodes);
+ SuccNodes.resize(NumNodes);
+ PredNodes.resize(NumNodes);
+ std::vector<uint64_t> OutDegree(NumNodes, 0);
AllJumps.reserve(EdgeCounts.size());
for (auto It : EdgeCounts) {
auto Pred = It.first.first;
auto Succ = It.first.second;
+ OutDegree[Pred]++;
// Ignore self-edges
if (Pred == Succ)
continue;
@@ -499,11 +536,15 @@ private:
Block.OutJumps.push_back(&AllJumps.back());
}
}
+ for (auto &Jump : AllJumps) {
+ assert(OutDegree[Jump.Source->Index] > 0);
+ Jump.IsConditional = OutDegree[Jump.Source->Index] > 1;
+ }
// Initialize chains
AllChains.reserve(NumNodes);
HotChains.reserve(NumNodes);
- for (auto &Block : AllBlocks) {
+ for (Block &Block : AllBlocks) {
AllChains.emplace_back(Block.Index, &Block);
Block.CurChain = &AllChains.back();
if (Block.ExecutionCount > 0) {
@@ -513,10 +554,10 @@ private:
// Initialize chain edges
AllEdges.reserve(AllJumps.size());
- for (auto &Block : AllBlocks) {
+ for (Block &Block : AllBlocks) {
for (auto &Jump : Block.OutJumps) {
auto SuccBlock = Jump->Target;
- auto CurEdge = Block.CurChain->getEdge(SuccBlock->CurChain);
+ ChainEdge *CurEdge = Block.CurChain->getEdge(SuccBlock->CurChain);
// this edge is already present in the graph
if (CurEdge != nullptr) {
assert(SuccBlock->CurChain->getEdge(Block.CurChain) != nullptr);
@@ -596,11 +637,11 @@ private:
Chain *BestChainSucc = nullptr;
auto BestGain = MergeGainTy();
// Iterate over all pairs of chains
- for (auto ChainPred : HotChains) {
+ for (Chain *ChainPred : HotChains) {
// Get candidates for merging with the current chain
for (auto EdgeIter : ChainPred->edges()) {
- auto ChainSucc = EdgeIter.first;
- auto ChainEdge = EdgeIter.second;
+ Chain *ChainSucc = EdgeIter.first;
+ class ChainEdge *ChainEdge = EdgeIter.second;
// Ignore loop edges
if (ChainPred == ChainSucc)
continue;
@@ -610,7 +651,8 @@ private:
continue;
// Compute the gain of merging the two chains
- auto CurGain = getBestMergeGain(ChainPred, ChainSucc, ChainEdge);
+ MergeGainTy CurGain =
+ getBestMergeGain(ChainPred, ChainSucc, ChainEdge);
if (CurGain.score() <= EPS)
continue;
@@ -635,11 +677,13 @@ private:
}
}
- /// Merge cold blocks to reduce code size.
+ /// Merge remaining blocks into chains w/o taking jump counts into
+ /// consideration. This allows maintaining the original block order in the
+ /// absence of profile data.
void mergeColdChains() {
for (size_t SrcBB = 0; SrcBB < NumNodes; SrcBB++) {
- // Iterating over neighbors in the reverse order to make sure original
- // fallthrough jumps are merged first
+ // Iterating in reverse order to make sure original fallthrough jumps are
+ // merged first; this might be beneficial for code size.
size_t NumSuccs = SuccNodes[SrcBB].size(); for (size_t Idx = 0; Idx < NumSuccs; Idx++) { auto DstBB = SuccNodes[SrcBB][NumSuccs - Idx - 1]; @@ -647,7 +691,8 @@ private: auto DstChain = AllBlocks[DstBB].CurChain; if (SrcChain != DstChain && !DstChain->isEntry() && SrcChain->blocks().back()->Index == SrcBB && - DstChain->blocks().front()->Index == DstBB) { + DstChain->blocks().front()->Index == DstBB && + SrcChain->isCold() == DstChain->isCold()) { mergeChains(SrcChain, DstChain, 0, MergeTypeTy::X_Y); } } @@ -667,10 +712,11 @@ private: double Score = 0; for (auto &Jump : Jumps) { - const auto SrcBlock = Jump->Source; - const auto DstBlock = Jump->Target; + const Block *SrcBlock = Jump->Source; + const Block *DstBlock = Jump->Target; Score += ::extTSPScore(SrcBlock->EstimatedAddr, SrcBlock->Size, - DstBlock->EstimatedAddr, Jump->ExecutionCount); + DstBlock->EstimatedAddr, Jump->ExecutionCount, + Jump->IsConditional); } return Score; } @@ -689,7 +735,7 @@ private: // Precompute jumps between ChainPred and ChainSucc auto Jumps = Edge->jumps(); - auto EdgePP = ChainPred->getEdge(ChainPred); + ChainEdge *EdgePP = ChainPred->getEdge(ChainPred); if (EdgePP != nullptr) { Jumps.insert(Jumps.end(), EdgePP->jumps().begin(), EdgePP->jumps().end()); } @@ -711,7 +757,7 @@ private: return; // Apply the merge, compute the corresponding gain, and update the best // value, if the merge is beneficial - for (auto &MergeType : MergeTypes) { + for (const auto &MergeType : MergeTypes) { Gain.updateIfLessThan( computeMergeGain(ChainPred, ChainSucc, Jumps, Offset, MergeType)); } @@ -778,7 +824,7 @@ private: /// Merge two chains of blocks respecting a given merge 'type' and 'offset'. /// - /// If MergeType == 0, then the result is a concatentation of two chains. + /// If MergeType == 0, then the result is a concatenation of two chains. /// Otherwise, the first chain is cut into two sub-chains at the offset, /// and merged using all possible ways of concatenating three chains. 
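/// Example: with X = (x1 x2 x3) split at Offset = 2 and Y = (y1 y2), the plain
/// concatenation MergeTypeTy::X_Y yields x1 x2 x3 y1 y2, while the variant
/// that places Y between the two halves yields x1 x2 y1 y2 x3; the gain
/// computation simply evaluates the ExtTSP score of each candidate order.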
MergedChain mergeBlocks(const std::vector<Block *> &X, @@ -813,22 +859,21 @@ private: assert(Into != From && "a chain cannot be merged with itself"); // Merge the blocks - auto MergedBlocks = + MergedChain MergedBlocks = mergeBlocks(Into->blocks(), From->blocks(), MergeOffset, MergeType); Into->merge(From, MergedBlocks.getBlocks()); Into->mergeEdges(From); From->clear(); // Update cached ext-tsp score for the new chain - auto SelfEdge = Into->getEdge(Into); + ChainEdge *SelfEdge = Into->getEdge(Into); if (SelfEdge != nullptr) { MergedBlocks = MergedChain(Into->blocks().begin(), Into->blocks().end()); Into->setScore(extTSPScore(MergedBlocks, SelfEdge->jumps())); } // Remove chain From from the list of active chains - auto Iter = std::remove(HotChains.begin(), HotChains.end(), From); - HotChains.erase(Iter, HotChains.end()); + llvm::erase_value(HotChains, From); // Invalidate caches for (auto EdgeIter : Into->edges()) { @@ -847,7 +892,7 @@ private: // Using doubles to avoid overflow of ExecutionCount double Size = 0; double ExecutionCount = 0; - for (auto Block : Chain.blocks()) { + for (auto *Block : Chain.blocks()) { Size += static_cast<double>(Block->Size); ExecutionCount += static_cast<double>(Block->ExecutionCount); } @@ -859,7 +904,7 @@ private: // Sorting chains by density in the decreasing order std::stable_sort(SortedChains.begin(), SortedChains.end(), [&](const Chain *C1, const Chain *C2) { - // Makre sure the original entry block is at the + // Make sure the original entry block is at the // beginning of the order if (C1->isEntry() != C2->isEntry()) { return C1->isEntry(); @@ -873,8 +918,8 @@ private: // Collect the blocks in the order specified by their chains Order.reserve(NumNodes); - for (auto Chain : SortedChains) { - for (auto Block : Chain->blocks()) { + for (Chain *Chain : SortedChains) { + for (Block *Block : Chain->blocks()) { Order.push_back(Block->Index); } } @@ -911,7 +956,7 @@ private: std::vector<uint64_t> llvm::applyExtTspLayout( const std::vector<uint64_t> &NodeSizes, const std::vector<uint64_t> &NodeCounts, - const DenseMap<std::pair<uint64_t, uint64_t>, uint64_t> &EdgeCounts) { + const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) { size_t NumNodes = NodeSizes.size(); // Verify correctness of the input data. 
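To make the scoring above concrete: jumpExtTSPScore is a weighted linear decay in the jump distance, and extTSPScore merely picks the weight and distance cap from the kind of jump. A small self-contained sketch using the default option values from this file (0.1 for forward and backward jumps, 1.0/1.05 for conditional/unconditional fallthroughs):

#include <cstdint>

// Same formula as jumpExtTSPScore: zero beyond the distance cap, otherwise a
// linear decay from Weight * Count at distance 0 down to zero at the cap.
static double jumpScore(uint64_t Dist, uint64_t MaxDist, uint64_t Count,
                        double Weight) {
  if (Dist > MaxDist)
    return 0.0;
  return Weight * (1.0 - static_cast<double>(Dist) / MaxDist) * Count;
}

// A conditional forward jump of 256 bytes executed 1000 times scores
//   jumpScore(256, 1024, 1000, 0.1) == 75.0
// while the same edge laid out as an unconditional fallthrough scores
//   jumpScore(0, 1, 1000, 1.05) == 1050.0
// which is why the algorithm so strongly favors fallthrough placement.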
@@ -932,12 +977,17 @@ std::vector<uint64_t> llvm::applyExtTspLayout( double llvm::calcExtTspScore( const std::vector<uint64_t> &Order, const std::vector<uint64_t> &NodeSizes, const std::vector<uint64_t> &NodeCounts, - const DenseMap<std::pair<uint64_t, uint64_t>, uint64_t> &EdgeCounts) { + const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) { // Estimate addresses of the blocks in memory - auto Addr = std::vector<uint64_t>(NodeSizes.size(), 0); + std::vector<uint64_t> Addr(NodeSizes.size(), 0); for (size_t Idx = 1; Idx < Order.size(); Idx++) { Addr[Order[Idx]] = Addr[Order[Idx - 1]] + NodeSizes[Order[Idx - 1]]; } + std::vector<uint64_t> OutDegree(NodeSizes.size(), 0); + for (auto It : EdgeCounts) { + auto Pred = It.first.first; + OutDegree[Pred]++; + } // Increase the score for each jump double Score = 0; @@ -945,7 +995,9 @@ double llvm::calcExtTspScore( auto Pred = It.first.first; auto Succ = It.first.second; uint64_t Count = It.second; - Score += extTSPScore(Addr[Pred], NodeSizes[Pred], Addr[Succ], Count); + bool IsConditional = OutDegree[Pred] > 1; + Score += ::extTSPScore(Addr[Pred], NodeSizes[Pred], Addr[Succ], Count, + IsConditional); } return Score; } @@ -953,8 +1005,8 @@ double llvm::calcExtTspScore( double llvm::calcExtTspScore( const std::vector<uint64_t> &NodeSizes, const std::vector<uint64_t> &NodeCounts, - const DenseMap<std::pair<uint64_t, uint64_t>, uint64_t> &EdgeCounts) { - auto Order = std::vector<uint64_t>(NodeSizes.size()); + const std::vector<std::pair<EdgeT, uint64_t>> &EdgeCounts) { + std::vector<uint64_t> Order(NodeSizes.size()); for (size_t Idx = 0; Idx < NodeSizes.size(); Idx++) { Order[Idx] = Idx; } diff --git a/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp b/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp index 648f4e64a4d2..4a6719741719 100644 --- a/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp +++ b/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CodeMoverUtils.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/PostDominators.h" @@ -58,9 +57,9 @@ class ControlConditions { public: /// Return a ControlConditions which stores all conditions required to execute /// \p BB from \p Dominator. If \p MaxLookup is non-zero, it limits the - /// number of conditions to collect. Return None if not all conditions are - /// collected successfully, or we hit the limit. - static const Optional<ControlConditions> + /// number of conditions to collect. Return std::nullopt if not all conditions + /// are collected successfully, or we hit the limit. 
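/// Example: if BB executes only when the branch in its dominating block takes
/// the true successor and a nested branch then takes the false successor, the
/// collected set is {(Cond1, true), (Cond2, false)} (names illustrative);
/// isControlFlowEquivalent compares such sets for two blocks.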
+ static const std::optional<ControlConditions> collectControlConditions(const BasicBlock &BB, const BasicBlock &Dominator, const DominatorTree &DT, const PostDominatorTree &PDT, @@ -105,9 +104,12 @@ static bool domTreeLevelBefore(DominatorTree *DT, const Instruction *InstA, return DA->getLevel() < DB->getLevel(); } -const Optional<ControlConditions> ControlConditions::collectControlConditions( - const BasicBlock &BB, const BasicBlock &Dominator, const DominatorTree &DT, - const PostDominatorTree &PDT, unsigned MaxLookup) { +const std::optional<ControlConditions> +ControlConditions::collectControlConditions(const BasicBlock &BB, + const BasicBlock &Dominator, + const DominatorTree &DT, + const PostDominatorTree &PDT, + unsigned MaxLookup) { assert(DT.dominates(&Dominator, &BB) && "Expecting Dominator to dominate BB"); ControlConditions Conditions; @@ -129,7 +131,7 @@ const Optional<ControlConditions> ControlConditions::collectControlConditions( // Limitation: can only handle branch instruction currently. const BranchInst *BI = dyn_cast<BranchInst>(IDom->getTerminator()); if (!BI) - return None; + return std::nullopt; bool Inserted = false; if (PDT.dominates(CurBlock, IDom)) { @@ -149,13 +151,13 @@ const Optional<ControlConditions> ControlConditions::collectControlConditions( Inserted = Conditions.addControlCondition( ControlCondition(BI->getCondition(), false)); } else - return None; + return std::nullopt; if (Inserted) ++NumConditions; if (MaxLookup != 0 && NumConditions > MaxLookup) - return None; + return std::nullopt; CurBlock = IDom; } while (CurBlock != &Dominator); @@ -249,16 +251,16 @@ bool llvm::isControlFlowEquivalent(const BasicBlock &BB0, const BasicBlock &BB1, << " and " << BB1.getName() << " is " << CommonDominator->getName() << "\n"); - const Optional<ControlConditions> BB0Conditions = + const std::optional<ControlConditions> BB0Conditions = ControlConditions::collectControlConditions(BB0, *CommonDominator, DT, PDT); - if (BB0Conditions == None) + if (BB0Conditions == std::nullopt) return false; - const Optional<ControlConditions> BB1Conditions = + const std::optional<ControlConditions> BB1Conditions = ControlConditions::collectControlConditions(BB1, *CommonDominator, DT, PDT); - if (BB1Conditions == None) + if (BB1Conditions == std::nullopt) return false; return BB0Conditions->isEquivalent(*BB1Conditions); @@ -455,7 +457,7 @@ bool llvm::nonStrictlyPostDominate(const BasicBlock *ThisBlock, if (PDT->dominates(CurBlock, OtherBlock)) return true; - for (auto *Pred : predecessors(CurBlock)) { + for (const auto *Pred : predecessors(CurBlock)) { if (Pred == CommonDominator || Visited.count(Pred)) continue; WorkList.push_back(Pred); diff --git a/llvm/lib/Transforms/Utils/Debugify.cpp b/llvm/lib/Transforms/Utils/Debugify.cpp index 24126b5ab67b..989473693a0b 100644 --- a/llvm/lib/Transforms/Utils/Debugify.cpp +++ b/llvm/lib/Transforms/Utils/Debugify.cpp @@ -27,6 +27,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" +#include <optional> #define DEBUG_TYPE "debugify" @@ -114,7 +115,8 @@ bool llvm::applyDebugifyMetadata( continue; bool InsertedDbgVal = false; - auto SPType = DIB.createSubroutineType(DIB.getOrCreateTypeArray(None)); + auto SPType = + DIB.createSubroutineType(DIB.getOrCreateTypeArray(std::nullopt)); DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition | DISubprogram::SPFlagOptimized; if (F.hasPrivateLinkage() || F.hasInternalLinkage()) @@ -243,13 +245,18 @@ applyDebugify(Module &M, bool 
llvm::stripDebugifyMetadata(Module &M) { bool Changed = false; - // Remove the llvm.debugify module-level named metadata. + // Remove the llvm.debugify and llvm.mir.debugify module-level named metadata. NamedMDNode *DebugifyMD = M.getNamedMetadata("llvm.debugify"); if (DebugifyMD) { M.eraseNamedMetadata(DebugifyMD); Changed = true; } + if (auto *MIRDebugifyMD = M.getNamedMetadata("llvm.mir.debugify")) { + M.eraseNamedMetadata(MIRDebugifyMD); + Changed = true; + } + // Strip out all debug intrinsics and supporting metadata (subprograms, types, // variables, etc). Changed |= StripDebugInfo(M); @@ -338,7 +345,7 @@ bool llvm::collectDebugInfoMetadata(Module &M, if (I.getDebugLoc().getInlinedAt()) continue; // Skip undef values. - if (DVI->isUndef()) + if (DVI->isKillLocation()) continue; auto *Var = DVI->getVariable(); @@ -513,15 +520,19 @@ static void writeJSON(StringRef OrigDIVerifyBugsReportFilePath, return; } - OS_FILE << "{\"file\":\"" << FileNameFromCU << "\", "; + if (auto L = OS_FILE.lock()) { + OS_FILE << "{\"file\":\"" << FileNameFromCU << "\", "; - StringRef PassName = NameOfWrappedPass != "" ? NameOfWrappedPass : "no-name"; - OS_FILE << "\"pass\":\"" << PassName << "\", "; + StringRef PassName = + NameOfWrappedPass != "" ? NameOfWrappedPass : "no-name"; + OS_FILE << "\"pass\":\"" << PassName << "\", "; - llvm::json::Value BugsToPrint{std::move(Bugs)}; - OS_FILE << "\"bugs\": " << BugsToPrint; + llvm::json::Value BugsToPrint{std::move(Bugs)}; + OS_FILE << "\"bugs\": " << BugsToPrint; - OS_FILE << "}\n"; + OS_FILE << "}\n"; + } + OS_FILE.close(); } bool llvm::checkDebugInfoMetadata(Module &M, @@ -577,7 +588,7 @@ bool llvm::checkDebugInfoMetadata(Module &M, if (I.getDebugLoc().getInlinedAt()) continue; // Skip undef values. - if (DVI->isUndef()) + if (DVI->isKillLocation()) continue; auto *Var = DVI->getVariable(); @@ -670,7 +681,7 @@ bool diagnoseMisSizedDbgValue(Module &M, DbgValueInst *DVI) { Type *Ty = V->getType(); uint64_t ValueOperandSize = getAllocSizeInBits(M, Ty); - Optional<uint64_t> DbgVarSize = DVI->getFragmentSizeInBits(); + std::optional<uint64_t> DbgVarSize = DVI->getFragmentSizeInBits(); if (!ValueOperandSize || !DbgVarSize) return false; @@ -1020,19 +1031,19 @@ void DebugifyEachInstrumentation::registerCallbacks( PIC.registerBeforeNonSkippedPassCallback([this](StringRef P, Any IR) { if (isIgnoredPass(P)) return; - if (any_isa<const Function *>(IR)) - applyDebugify(*const_cast<Function *>(any_cast<const Function *>(IR)), + if (const auto **F = any_cast<const Function *>(&IR)) + applyDebugify(*const_cast<Function *>(*F), Mode, DebugInfoBeforePass, P); - else if (any_isa<const Module *>(IR)) - applyDebugify(*const_cast<Module *>(any_cast<const Module *>(IR)), + else if (const auto **M = any_cast<const Module *>(&IR)) + applyDebugify(*const_cast<Module *>(*M), Mode, DebugInfoBeforePass, P); }); PIC.registerAfterPassCallback([this](StringRef P, Any IR, const PreservedAnalyses &PassPA) { if (isIgnoredPass(P)) return; - if (any_isa<const Function *>(IR)) { - auto &F = *const_cast<Function *>(any_cast<const Function *>(IR)); + if (const auto **CF = any_cast<const Function *>(&IR)) { + auto &F = *const_cast<Function *>(*CF); Module &M = *F.getParent(); auto It = F.getIterator(); if (Mode == DebugifyMode::SyntheticDebugInfo) @@ -1043,8 +1054,8 @@ void DebugifyEachInstrumentation::registerCallbacks( M, make_range(It, std::next(It)), *DebugInfoBeforePass, "CheckModuleDebugify (original debuginfo)", P, OrigDIVerifyBugsReportFilePath); - } else if (any_isa<const Module *>(IR)) { - 
auto &M = *const_cast<Module *>(any_cast<const Module *>(IR)); + } else if (const auto **CM = any_cast<const Module *>(&IR)) { + auto &M = *const_cast<Module *>(*CM); if (Mode == DebugifyMode::SyntheticDebugInfo) checkDebugifyMetadata(M, M.functions(), P, "CheckModuleDebugify", /*Strip=*/true, DIStatsMap); diff --git a/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp b/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp index f6f80540ad95..086ea088dc5e 100644 --- a/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -92,8 +92,15 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, BasicBlock::iterator InsertPt; if (!I.isTerminator()) { InsertPt = ++I.getIterator(); + // Don't insert before PHI nodes or landingpad instrs. for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt) - /* empty */; // Don't insert before PHI nodes or landingpad instrs. + if (isa<CatchSwitchInst>(InsertPt)) + break; + if (isa<CatchSwitchInst>(InsertPt)) { + for (BasicBlock *Handler : successors(&*InsertPt)) + new StoreInst(&I, Slot, &*Handler->getFirstInsertionPt()); + return Slot; + } } else { InvokeInst &II = cast<InvokeInst>(I); InsertPt = II.getNormalDest()->getFirstInsertionPt(); @@ -138,14 +145,27 @@ AllocaInst *llvm::DemotePHIToStack(PHINode *P, Instruction *AllocaPoint) { // Insert a load in place of the PHI and replace all uses. BasicBlock::iterator InsertPt = P->getIterator(); - + // Don't insert before PHI nodes or landingpad instrs. for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt) - /* empty */; // Don't insert before PHI nodes or landingpad instrs. - - Value *V = - new LoadInst(P->getType(), Slot, P->getName() + ".reload", &*InsertPt); - P->replaceAllUsesWith(V); - + if (isa<CatchSwitchInst>(InsertPt)) + break; + if (isa<CatchSwitchInst>(InsertPt)) { + // We need a separate load before each actual use of the PHI + SmallVector<Instruction *, 4> Users; + for (User *U : P->users()) { + Instruction *User = cast<Instruction>(U); + Users.push_back(User); + } + for (Instruction *User : Users) { + Value *V = + new LoadInst(P->getType(), Slot, P->getName() + ".reload", User); + User->replaceUsesOfWith(P, V); + } + } else { + Value *V = + new LoadInst(P->getType(), Slot, P->getName() + ".reload", &*InsertPt); + P->replaceAllUsesWith(V); + } // Delete PHI. 
P->eraseFromParent(); return Slot; diff --git a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp index 60f910bceab8..53af1b1969c2 100644 --- a/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/EntryExitInstrumenter.h" +#include "llvm/ADT/Triple.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" @@ -34,9 +35,24 @@ static void insertCall(Function &CurFn, StringRef Func, Func == "__mcount" || Func == "_mcount" || Func == "__cyg_profile_func_enter_bare") { - FunctionCallee Fn = M.getOrInsertFunction(Func, Type::getVoidTy(C)); - CallInst *Call = CallInst::Create(Fn, "", InsertionPt); - Call->setDebugLoc(DL); + Triple TargetTriple(M.getTargetTriple()); + if (TargetTriple.isOSAIX() && Func == "__mcount") { + Type *SizeTy = M.getDataLayout().getIntPtrType(C); + Type *SizePtrTy = SizeTy->getPointerTo(); + GlobalVariable *GV = new GlobalVariable(M, SizeTy, /*isConstant=*/false, + GlobalValue::InternalLinkage, + ConstantInt::get(SizeTy, 0)); + CallInst *Call = CallInst::Create( + M.getOrInsertFunction(Func, + FunctionType::get(Type::getVoidTy(C), {SizePtrTy}, + /*isVarArg=*/false)), + {GV}, "", InsertionPt); + Call->setDebugLoc(DL); + } else { + FunctionCallee Fn = M.getOrInsertFunction(Func, Type::getVoidTy(C)); + CallInst *Call = CallInst::Create(Fn, "", InsertionPt); + Call->setDebugLoc(DL); + } return; } diff --git a/llvm/lib/Transforms/Utils/Evaluator.cpp b/llvm/lib/Transforms/Utils/Evaluator.cpp index 7509fde6df9d..dc58bebd724b 100644 --- a/llvm/lib/Transforms/Utils/Evaluator.cpp +++ b/llvm/lib/Transforms/Utils/Evaluator.cpp @@ -132,7 +132,7 @@ Constant *Evaluator::MutableValue::read(Type *Ty, APInt Offset, const MutableValue *V = this; while (const auto *Agg = V->Val.dyn_cast<MutableAggregate *>()) { Type *AggTy = Agg->Ty; - Optional<APInt> Index = DL.getGEPIndexForOffset(AggTy, Offset); + std::optional<APInt> Index = DL.getGEPIndexForOffset(AggTy, Offset); if (!Index || Index->uge(Agg->Elements.size()) || !TypeSize::isKnownLE(TySize, DL.getTypeStoreSize(AggTy))) return nullptr; @@ -176,7 +176,7 @@ bool Evaluator::MutableValue::write(Constant *V, APInt Offset, MutableAggregate *Agg = MV->Val.get<MutableAggregate *>(); Type *AggTy = Agg->Ty; - Optional<APInt> Index = DL.getGEPIndexForOffset(AggTy, Offset); + std::optional<APInt> Index = DL.getGEPIndexForOffset(AggTy, Offset); if (!Index || Index->uge(Agg->Elements.size()) || !TypeSize::isKnownLE(TySize, DL.getTypeStoreSize(AggTy))) return false; @@ -626,10 +626,8 @@ bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, CallStack.push_back(F); // Initialize arguments to the incoming values specified. - unsigned ArgNo = 0; - for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); AI != E; - ++AI, ++ArgNo) - setVal(&*AI, ActualArgs[ArgNo]); + for (const auto &[ArgNo, Arg] : llvm::enumerate(F->args())) + setVal(&Arg, ActualArgs[ArgNo]); // ExecutedBlocks - We only handle non-looping, non-recursive code. As such, // we can only evaluate any one basic block at most once. 
This set keeps diff --git a/llvm/lib/Transforms/Utils/FixIrreducible.cpp b/llvm/lib/Transforms/Utils/FixIrreducible.cpp index 24539bd231c6..dda236167363 100644 --- a/llvm/lib/Transforms/Utils/FixIrreducible.cpp +++ b/llvm/lib/Transforms/Utils/FixIrreducible.cpp @@ -137,7 +137,7 @@ static void reconnectChildLoops(LoopInfo &LI, Loop *ParentLoop, Loop *NewLoop, // SCC gets destroyed since its backedges are removed. That may // not be necessary if we can retain such backedges. if (Headers.count(Child->getHeader())) { - for (auto BB : Child->blocks()) { + for (auto *BB : Child->blocks()) { if (LI.getLoopFor(BB) != Child) continue; LI.changeLoopFor(BB, NewLoop); @@ -146,7 +146,7 @@ static void reconnectChildLoops(LoopInfo &LI, Loop *ParentLoop, Loop *NewLoop, } std::vector<Loop *> GrandChildLoops; std::swap(GrandChildLoops, Child->getSubLoopsVector()); - for (auto GrandChildLoop : GrandChildLoops) { + for (auto *GrandChildLoop : GrandChildLoops) { GrandChildLoop->setParentLoop(nullptr); NewLoop->addChildLoop(GrandChildLoop); } @@ -170,14 +170,14 @@ static void createNaturalLoopInternal(LoopInfo &LI, DominatorTree &DT, SetVector<BasicBlock *> &Headers) { #ifndef NDEBUG // All headers are part of the SCC - for (auto H : Headers) { + for (auto *H : Headers) { assert(Blocks.count(H)); } #endif SetVector<BasicBlock *> Predecessors; - for (auto H : Headers) { - for (auto P : predecessors(H)) { + for (auto *H : Headers) { + for (auto *P : predecessors(H)) { Predecessors.insert(P); } } @@ -214,13 +214,13 @@ static void createNaturalLoopInternal(LoopInfo &LI, DominatorTree &DT, // in the loop. This ensures that it is recognized as the // header. Since the new loop is already in LoopInfo, the new blocks // are also propagated up the chain of parent loops. - for (auto G : GuardBlocks) { + for (auto *G : GuardBlocks) { LLVM_DEBUG(dbgs() << "added guard block: " << G->getName() << "\n"); NewLoop->addBasicBlockToLoop(G, LI); } // Add the SCC blocks to the new loop. - for (auto BB : Blocks) { + for (auto *BB : Blocks) { NewLoop->addBlockEntry(BB); if (LI.getLoopFor(BB) == ParentLoop) { LLVM_DEBUG(dbgs() << "moved block from parent: " << BB->getName() @@ -288,7 +288,7 @@ static bool makeReducible(LoopInfo &LI, DominatorTree &DT, Graph &&G) { // match. So we discover the headers using the reverse of the block order. SetVector<BasicBlock *> Headers; LLVM_DEBUG(dbgs() << "Found headers:"); - for (auto BB : reverse(Blocks)) { + for (auto *BB : reverse(Blocks)) { for (const auto P : predecessors(BB)) { // Skip unreachable predecessors. if (!DT.isReachableFromEntry(P)) diff --git a/llvm/lib/Transforms/Utils/FlattenCFG.cpp b/llvm/lib/Transforms/Utils/FlattenCFG.cpp index ddd3f597ae01..2fb2ab82e41a 100644 --- a/llvm/lib/Transforms/Utils/FlattenCFG.cpp +++ b/llvm/lib/Transforms/Utils/FlattenCFG.cpp @@ -145,9 +145,7 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { // Check predecessors of \param BB. SmallPtrSet<BasicBlock *, 16> Preds(pred_begin(BB), pred_end(BB)); - for (SmallPtrSetIterator<BasicBlock *> PI = Preds.begin(), PE = Preds.end(); - PI != PE; ++PI) { - BasicBlock *Pred = *PI; + for (BasicBlock *Pred : Preds) { BranchInst *PBI = dyn_cast<BranchInst>(Pred->getTerminator()); // All predecessors should terminate with a branch. @@ -286,9 +284,8 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { do { CB = PBI->getSuccessor(1 - Idx); // Delete the conditional branch. 
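A small sketch of the llvm::enumerate idiom adopted in the Evaluator.cpp hunk above (the helper and its argument are illustrative only): enumerate() pairs each element with its index, so the manual ArgNo counter carried next to the iterator disappears.

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  #include "llvm/Support/raw_ostream.h"

  static void printIndexed(const llvm::SmallVectorImpl<int> &Values) {
    // Structured bindings unpack the (index, element) pair produced by
    // llvm::enumerate for each entry of the range.
    for (const auto &[Idx, V] : llvm::enumerate(Values))
      llvm::errs() << Idx << ": " << V << "\n";
  }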
- FirstCondBlock->getInstList().pop_back(); - FirstCondBlock->getInstList() - .splice(FirstCondBlock->end(), CB->getInstList()); + FirstCondBlock->back().eraseFromParent(); + FirstCondBlock->splice(FirstCondBlock->end(), CB); PBI = cast<BranchInst>(FirstCondBlock->getTerminator()); Value *CC = PBI->getCondition(); // Merge conditions. @@ -431,6 +428,9 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { return false; BasicBlock *FirstEntryBlock = CInst1->getParent(); + // Don't die trying to process degenerate/unreachable code. + if (FirstEntryBlock == SecondEntryBlock) + return false; // Either then-path or else-path should be empty. bool InvertCond2 = false; @@ -479,9 +479,8 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { } // Merge \param SecondEntryBlock into \param FirstEntryBlock. - FirstEntryBlock->getInstList().pop_back(); - FirstEntryBlock->getInstList() - .splice(FirstEntryBlock->end(), SecondEntryBlock->getInstList()); + FirstEntryBlock->back().eraseFromParent(); + FirstEntryBlock->splice(FirstEntryBlock->end(), SecondEntryBlock); BranchInst *PBI = cast<BranchInst>(FirstEntryBlock->getTerminator()); assert(PBI->getCondition() == CInst2); BasicBlock *SaveInsertBB = Builder.GetInsertBlock(); diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp index 06596f7b04e1..3fa61ec68cd3 100644 --- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp +++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -110,7 +110,7 @@ int FunctionComparator::cmpMem(StringRef L, StringRef R) const { // Compare strings lexicographically only when it is necessary: only when // strings are equal in size. - return L.compare(R); + return std::clamp(L.compare(R), -1, 1); } int FunctionComparator::cmpAttrs(const AttributeList L, @@ -241,9 +241,9 @@ int FunctionComparator::cmpConstants(const Constant *L, unsigned TyRWidth = 0; if (auto *VecTyL = dyn_cast<VectorType>(TyL)) - TyLWidth = VecTyL->getPrimitiveSizeInBits().getFixedSize(); + TyLWidth = VecTyL->getPrimitiveSizeInBits().getFixedValue(); if (auto *VecTyR = dyn_cast<VectorType>(TyR)) - TyRWidth = VecTyR->getPrimitiveSizeInBits().getFixedSize(); + TyRWidth = VecTyR->getPrimitiveSizeInBits().getFixedValue(); if (TyLWidth != TyRWidth) return cmpNumbers(TyLWidth, TyRWidth); @@ -381,7 +381,7 @@ int FunctionComparator::cmpConstants(const Constant *L, BasicBlock *RBB = RBA->getBasicBlock(); if (LBB == RBB) return 0; - for (BasicBlock &BB : F->getBasicBlockList()) { + for (BasicBlock &BB : *F) { if (&BB == LBB) { assert(&BB != RBB); return -1; @@ -402,6 +402,15 @@ int FunctionComparator::cmpConstants(const Constant *L, return cmpValues(LBA->getBasicBlock(), RBA->getBasicBlock()); } } + case Value::DSOLocalEquivalentVal: { + // dso_local_equivalent is functionally equivalent to whatever it points to. + // This means the behavior of the IR should be the exact same as if the + // function was referenced directly rather than through a + // dso_local_equivalent. + const auto *LEquiv = cast<DSOLocalEquivalent>(L); + const auto *REquiv = cast<DSOLocalEquivalent>(R); + return cmpGlobalValues(LEquiv->getGlobalValue(), REquiv->getGlobalValue()); + } default: // Unknown constant, abort. 
LLVM_DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n"); llvm_unreachable("Constant ValueID not recognized."); @@ -968,7 +977,7 @@ FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) { // This random value acts as a block header, as otherwise the partition of // opcodes into BBs wouldn't affect the hash, only the order of the opcodes H.add(45798); - for (auto &Inst : *BB) { + for (const auto &Inst : *BB) { H.add(Inst.getOpcode()); } const Instruction *Term = BB->getTerminator(); diff --git a/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp b/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp index 8e6d4626c9fd..87be6be01885 100644 --- a/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -12,8 +12,18 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/FunctionImportUtils.h" +#include "llvm/Support/CommandLine.h" using namespace llvm; +/// Uses the "source_filename" instead of a Module hash ID for the suffix of +/// promoted locals during LTO. NOTE: This requires that the source filename +/// has a unique name / path to avoid name collisions. +static cl::opt<bool> UseSourceFilenameForPromotedLocals( + "use-source-filename-for-promoted-locals", cl::Hidden, + cl::desc("Uses the source file name instead of the Module hash. " + "This requires that the source filename has a unique name / " + "path to avoid name collisions.")); + /// Checks if we should import SGV as a definition, otherwise import as a /// declaration. bool FunctionImportGlobalProcessing::doImportAsDefinition( @@ -94,9 +104,19 @@ bool FunctionImportGlobalProcessing::isNonRenamableLocal( std::string FunctionImportGlobalProcessing::getPromotedName(const GlobalValue *SGV) { assert(SGV->hasLocalLinkage()); + // For locals that must be promoted to global scope, ensure that // the promoted name uniquely identifies the copy in the original module, // using the ID assigned during combined index creation. 
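To illustrate the new -use-source-filename-for-promoted-locals option above, here is a hedged sketch of the suffix sanitization it relies on, mirroring the code that follows (the helper name is illustrative): every character of the source file name that is not alphanumeric is rewritten to '_' before the suffix is appended to the promoted symbol.

  #include "llvm/ADT/SmallString.h"
  #include "llvm/ADT/StringExtras.h"
  #include "llvm/ADT/StringRef.h"
  #include <algorithm>

  static llvm::SmallString<256> sanitizeForSuffix(llvm::StringRef SourceFileName) {
    llvm::SmallString<256> Suffix(SourceFileName);
    // Keep only [0-9A-Za-z]; e.g. "dir/foo.c" becomes "dir_foo_c".
    std::replace_if(Suffix.begin(), Suffix.end(),
                    [](char C) { return !llvm::isAlnum(C); }, '_');
    return Suffix;
  }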
+ if (UseSourceFilenameForPromotedLocals && + !SGV->getParent()->getSourceFileName().empty()) { + SmallString<256> Suffix(SGV->getParent()->getSourceFileName()); + std::replace_if(std::begin(Suffix), std::end(Suffix), + [&](char ch) { return !isAlnum(ch); }, '_'); + return ModuleSummaryIndex::getGlobalNameForLocal( + SGV->getName(), Suffix); + } + return ModuleSummaryIndex::getGlobalNameForLocal( SGV->getName(), ImportIndex.getModuleHash(SGV->getParent()->getModuleIdentifier())); @@ -206,7 +226,7 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { if (VI && ImportIndex.hasSyntheticEntryCounts()) { if (Function *F = dyn_cast<Function>(&GV)) { if (!F->isDeclaration()) { - for (auto &S : VI.getSummaryList()) { + for (const auto &S : VI.getSummaryList()) { auto *FS = cast<FunctionSummary>(S->getBaseObject()); if (FS->modulePath() == M.getModuleIdentifier()) { F->setEntryCount(Function::ProfileCount(FS->entryCount(), diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp index 878f9477a29d..399c9a43793f 100644 --- a/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -12,8 +12,6 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -27,6 +25,7 @@ #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/MemoryProfileInfo.h" #include "llvm/Analysis/ObjCARCAnalysisUtils.h" #include "llvm/Analysis/ObjCARCUtil.h" #include "llvm/Analysis/ProfileSummaryInfo.h" @@ -70,11 +69,15 @@ #include <cstdint> #include <iterator> #include <limits> +#include <optional> #include <string> #include <utility> #include <vector> +#define DEBUG_TYPE "inline-function" + using namespace llvm; +using namespace llvm::memprof; using ProfileCount = Function::ProfileCount; static cl::opt<bool> @@ -547,13 +550,6 @@ static BasicBlock *HandleCallsInBlockInlinedThroughInvoke( if (!CI || CI->doesNotThrow()) continue; - if (CI->isInlineAsm()) { - InlineAsm *IA = cast<InlineAsm>(CI->getCalledOperand()); - if (!IA->canThrow()) { - continue; - } - } - // We do not need to (and in fact, cannot) convert possibly throwing calls // to @llvm.experimental_deoptimize (resp. @llvm.experimental.guard) into // invokes. The caller's "segment" of the deoptimization continuation @@ -782,6 +778,140 @@ static void HandleInlinedEHPad(InvokeInst *II, BasicBlock *FirstNewBlock, UnwindDest->removePredecessor(InvokeBB); } +static bool haveCommonPrefix(MDNode *MIBStackContext, + MDNode *CallsiteStackContext) { + assert(MIBStackContext->getNumOperands() > 0 && + CallsiteStackContext->getNumOperands() > 0); + // Because of the context trimming performed during matching, the callsite + // context could have more stack ids than the MIB. We match up to the end of + // the shortest stack context. 
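A hypothetical worked example of the prefix test that the loop below implements: contexts are compared element-wise only up to the shorter of the two, so a trimmed callsite context {1, 2} matches an MIB stack {1, 2, 3}, while {1, 4} does not (stackIdsSharePrefix is an illustrative standalone helper, not part of the patch).

  #include "llvm/ADT/ArrayRef.h"
  #include <algorithm>
  #include <cstdint>

  static bool stackIdsSharePrefix(llvm::ArrayRef<uint64_t> MIBStack,
                                  llvm::ArrayRef<uint64_t> CallsiteStack) {
    // Match up to the end of the shortest stack context, as described above.
    size_t N = std::min(MIBStack.size(), CallsiteStack.size());
    return MIBStack.take_front(N) == CallsiteStack.take_front(N);
  }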
+ for (auto MIBStackIter = MIBStackContext->op_begin(), + CallsiteStackIter = CallsiteStackContext->op_begin(); + MIBStackIter != MIBStackContext->op_end() && + CallsiteStackIter != CallsiteStackContext->op_end(); + MIBStackIter++, CallsiteStackIter++) { + auto *Val1 = mdconst::dyn_extract<ConstantInt>(*MIBStackIter); + auto *Val2 = mdconst::dyn_extract<ConstantInt>(*CallsiteStackIter); + assert(Val1 && Val2); + if (Val1->getZExtValue() != Val2->getZExtValue()) + return false; + } + return true; +} + +static void removeMemProfMetadata(CallBase *Call) { + Call->setMetadata(LLVMContext::MD_memprof, nullptr); +} + +static void removeCallsiteMetadata(CallBase *Call) { + Call->setMetadata(LLVMContext::MD_callsite, nullptr); +} + +static void updateMemprofMetadata(CallBase *CI, + const std::vector<Metadata *> &MIBList) { + assert(!MIBList.empty()); + // Remove existing memprof, which will either be replaced or may not be needed + // if we are able to use a single allocation type function attribute. + removeMemProfMetadata(CI); + CallStackTrie CallStack; + for (Metadata *MIB : MIBList) + CallStack.addCallStack(cast<MDNode>(MIB)); + bool MemprofMDAttached = CallStack.buildAndAttachMIBMetadata(CI); + assert(MemprofMDAttached == CI->hasMetadata(LLVMContext::MD_memprof)); + if (!MemprofMDAttached) + // If we used a function attribute remove the callsite metadata as well. + removeCallsiteMetadata(CI); +} + +// Update the metadata on the inlined copy ClonedCall of a call OrigCall in the +// inlined callee body, based on the callsite metadata InlinedCallsiteMD from +// the call that was inlined. +static void propagateMemProfHelper(const CallBase *OrigCall, + CallBase *ClonedCall, + MDNode *InlinedCallsiteMD) { + MDNode *OrigCallsiteMD = ClonedCall->getMetadata(LLVMContext::MD_callsite); + MDNode *ClonedCallsiteMD = nullptr; + // Check if the call originally had callsite metadata, and update it for the + // new call in the inlined body. + if (OrigCallsiteMD) { + // The cloned call's context is now the concatenation of the original call's + // callsite metadata and the callsite metadata on the call where it was + // inlined. + ClonedCallsiteMD = MDNode::concatenate(OrigCallsiteMD, InlinedCallsiteMD); + ClonedCall->setMetadata(LLVMContext::MD_callsite, ClonedCallsiteMD); + } + + // Update any memprof metadata on the cloned call. + MDNode *OrigMemProfMD = ClonedCall->getMetadata(LLVMContext::MD_memprof); + if (!OrigMemProfMD) + return; + // We currently expect that allocations with memprof metadata also have + // callsite metadata for the allocation's part of the context. + assert(OrigCallsiteMD); + + // New call's MIB list. + std::vector<Metadata *> NewMIBList; + + // For each MIB metadata, check if its call stack context starts with the + // new clone's callsite metadata. If so, that MIB goes onto the cloned call in + // the inlined body. If not, it stays on the out-of-line original call. + for (auto &MIBOp : OrigMemProfMD->operands()) { + MDNode *MIB = dyn_cast<MDNode>(MIBOp); + // Stack is first operand of MIB. + MDNode *StackMD = getMIBStackNode(MIB); + assert(StackMD); + // See if the new cloned callsite context matches this profiled context. + if (haveCommonPrefix(StackMD, ClonedCallsiteMD)) + // Add it to the cloned call's MIB list. 
+ NewMIBList.push_back(MIB); + } + if (NewMIBList.empty()) { + removeMemProfMetadata(ClonedCall); + removeCallsiteMetadata(ClonedCall); + return; + } + if (NewMIBList.size() < OrigMemProfMD->getNumOperands()) + updateMemprofMetadata(ClonedCall, NewMIBList); +} + +// Update memprof related metadata (!memprof and !callsite) based on the +// inlining of Callee into the callsite at CB. The updates include merging the +// inlined callee's callsite metadata with that of the inlined call, +// and moving the subset of any memprof contexts to the inlined callee +// allocations if they match the new inlined call stack. +// FIXME: Replace memprof metadata with function attribute if all MIB end up +// having the same behavior. Do other context trimming/merging optimizations +// too. +static void +propagateMemProfMetadata(Function *Callee, CallBase &CB, + bool ContainsMemProfMetadata, + const ValueMap<const Value *, WeakTrackingVH> &VMap) { + MDNode *CallsiteMD = CB.getMetadata(LLVMContext::MD_callsite); + // Only need to update if the inlined callsite had callsite metadata, or if + // there was any memprof metadata inlined. + if (!CallsiteMD && !ContainsMemProfMetadata) + return; + + // Propagate metadata onto the cloned calls in the inlined callee. + for (const auto &Entry : VMap) { + // See if this is a call that has been inlined and remapped, and not + // simplified away in the process. + auto *OrigCall = dyn_cast_or_null<CallBase>(Entry.first); + auto *ClonedCall = dyn_cast_or_null<CallBase>(Entry.second); + if (!OrigCall || !ClonedCall) + continue; + // If the inlined callsite did not have any callsite metadata, then it isn't + // involved in any profiled call contexts, and we can remove any memprof + // metadata on the cloned call. + if (!CallsiteMD) { + removeMemProfMetadata(ClonedCall); + removeCallsiteMetadata(ClonedCall); + continue; + } + propagateMemProfHelper(OrigCall, ClonedCall, CallsiteMD); + } +} + /// When inlining a call site that has !llvm.mem.parallel_loop_access, /// !llvm.access.group, !alias.scope or !noalias metadata, that metadata should /// be propagated to all memory-accessing cloned instructions. @@ -911,7 +1041,7 @@ void ScopedAliasMetadataDeepCloner::clone() { SmallVector<TempMDTuple, 16> DummyNodes; for (const MDNode *I : MD) { - DummyNodes.push_back(MDTuple::getTemporary(I->getContext(), None)); + DummyNodes.push_back(MDTuple::getTemporary(I->getContext(), std::nullopt)); MDMap[I].reset(DummyNodes.back().get()); } @@ -1061,13 +1191,13 @@ static void AddAliasScopeMetadata(CallBase &CB, ValueToValueMapTy &VMap, IsFuncCall = true; if (CalleeAAR) { - FunctionModRefBehavior MRB = CalleeAAR->getModRefBehavior(Call); + MemoryEffects ME = CalleeAAR->getMemoryEffects(Call); // We'll retain this knowledge without additional metadata. - if (AAResults::onlyAccessesInaccessibleMem(MRB)) + if (ME.onlyAccessesInaccessibleMem()) continue; - if (AAResults::onlyAccessesArgPointees(MRB)) + if (ME.onlyAccessesArgPointees()) IsArgMemOnlyCall = true; } @@ -1307,23 +1437,26 @@ static void AddAlignmentAssumptions(CallBase &CB, InlineFunctionInfo &IFI) { Function *CalledFunc = CB.getCalledFunction(); for (Argument &Arg : CalledFunc->args()) { - unsigned Align = Arg.getType()->isPointerTy() ? 
Arg.getParamAlignment() : 0; - if (Align && !Arg.hasPassPointeeByValueCopyAttr() && !Arg.hasNUses(0)) { - if (!DTCalculated) { - DT.recalculate(*CB.getCaller()); - DTCalculated = true; - } - - // If we can already prove the asserted alignment in the context of the - // caller, then don't bother inserting the assumption. - Value *ArgVal = CB.getArgOperand(Arg.getArgNo()); - if (getKnownAlignment(ArgVal, DL, &CB, AC, &DT) >= Align) - continue; + if (!Arg.getType()->isPointerTy() || Arg.hasPassPointeeByValueCopyAttr() || + Arg.hasNUses(0)) + continue; + MaybeAlign Alignment = Arg.getParamAlign(); + if (!Alignment) + continue; - CallInst *NewAsmp = - IRBuilder<>(&CB).CreateAlignmentAssumption(DL, ArgVal, Align); - AC->registerAssumption(cast<AssumeInst>(NewAsmp)); + if (!DTCalculated) { + DT.recalculate(*CB.getCaller()); + DTCalculated = true; } + // If we can already prove the asserted alignment in the context of the + // caller, then don't bother inserting the assumption. + Value *ArgVal = CB.getArgOperand(Arg.getArgNo()); + if (getKnownAlignment(ArgVal, DL, &CB, AC, &DT) >= *Alignment) + continue; + + CallInst *NewAsmp = IRBuilder<>(&CB).CreateAlignmentAssumption( + DL, ArgVal, Alignment->value()); + AC->registerAssumption(cast<AssumeInst>(NewAsmp)); } } @@ -1423,7 +1556,7 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, Instruction *TheCall, const Function *CalledFunc, InlineFunctionInfo &IFI, - unsigned ByValAlignment) { + MaybeAlign ByValAlignment) { assert(cast<PointerType>(Arg->getType()) ->isOpaqueOrPointeeTypeMatches(ByValType)); Function *Caller = TheCall->getFunction(); @@ -1436,7 +1569,7 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, // If the byval argument has a specified alignment that is greater than the // passed in pointer, then we either have to round up the input pointer or // give up on this transformation. - if (ByValAlignment <= 1) // 0 = unspecified, 1 = no particular alignment. + if (ByValAlignment.valueOrOne() == 1) return Arg; AssumptionCache *AC = @@ -1444,8 +1577,8 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, // If the pointer is already known to be sufficiently aligned, or if we can // round it up to a larger alignment, then we don't need a temporary. - if (getOrEnforceKnownAlignment(Arg, Align(ByValAlignment), DL, TheCall, - AC) >= ByValAlignment) + if (getOrEnforceKnownAlignment(Arg, *ByValAlignment, DL, TheCall, AC) >= + *ByValAlignment) return Arg; // Otherwise, we have to make a memcpy to get a safe alignment. This is bad @@ -1453,13 +1586,13 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, } // Create the alloca. If we have DataLayout, use nice alignment. - Align Alignment(DL.getPrefTypeAlignment(ByValType)); + Align Alignment = DL.getPrefTypeAlign(ByValType); // If the byval had an alignment specified, we *must* use at least that // alignment, as it is required by the byval argument (and uses of the // pointer inside the callee). - if (ByValAlignment > 0) - Alignment = std::max(Alignment, Align(ByValAlignment)); + if (ByValAlignment) + Alignment = std::max(Alignment, *ByValAlignment); Value *NewAlloca = new AllocaInst(ByValType, DL.getAllocaAddrSpace(), nullptr, Alignment, @@ -1595,6 +1728,94 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, } } +#undef DEBUG_TYPE +#define DEBUG_TYPE "assignment-tracking" +/// Find Alloca and linked DbgAssignIntrinsic for locals escaped by \p CB. 
+static at::StorageToVarsMap collectEscapedLocals(const DataLayout &DL, + const CallBase &CB) { + at::StorageToVarsMap EscapedLocals; + SmallPtrSet<const Value *, 4> SeenBases; + + LLVM_DEBUG( + errs() << "# Finding caller local variables escaped by callee\n"); + for (const Value *Arg : CB.args()) { + LLVM_DEBUG(errs() << "INSPECT: " << *Arg << "\n"); + if (!Arg->getType()->isPointerTy()) { + LLVM_DEBUG(errs() << " | SKIP: Not a pointer\n"); + continue; + } + + const Instruction *I = dyn_cast<Instruction>(Arg); + if (!I) { + LLVM_DEBUG(errs() << " | SKIP: Not result of instruction\n"); + continue; + } + + // Walk back to the base storage. + assert(Arg->getType()->isPtrOrPtrVectorTy()); + APInt TmpOffset(DL.getIndexTypeSizeInBits(Arg->getType()), 0, false); + const AllocaInst *Base = dyn_cast<AllocaInst>( + Arg->stripAndAccumulateConstantOffsets(DL, TmpOffset, true)); + if (!Base) { + LLVM_DEBUG(errs() << " | SKIP: Couldn't walk back to base storage\n"); + continue; + } + + assert(Base); + LLVM_DEBUG(errs() << " | BASE: " << *Base << "\n"); + // We only need to process each base address once - skip any duplicates. + if (!SeenBases.insert(Base).second) + continue; + + // Find all local variables associated with the backing storage. + for (auto *DAI : at::getAssignmentMarkers(Base)) { + // Skip variables from inlined functions - they are not local variables. + if (DAI->getDebugLoc().getInlinedAt()) + continue; + LLVM_DEBUG(errs() << " > DEF : " << *DAI << "\n"); + EscapedLocals[Base].insert(at::VarRecord(DAI)); + } + } + return EscapedLocals; +} + +static void trackInlinedStores(Function::iterator Start, Function::iterator End, + const CallBase &CB) { + LLVM_DEBUG(errs() << "trackInlinedStores into " + << Start->getParent()->getName() << " from " + << CB.getCalledFunction()->getName() << "\n"); + std::unique_ptr<DataLayout> DL = std::make_unique<DataLayout>(CB.getModule()); + at::trackAssignments(Start, End, collectEscapedLocals(*DL, CB), *DL); +} + +/// Update inlined instructions' DIAssignID metadata. We need to do this +/// otherwise a function inlined more than once into the same function +/// will cause DIAssignID to be shared by many instructions. +static void fixupAssignments(Function::iterator Start, Function::iterator End) { + // Map {Old, New} metadata. Not used directly - use GetNewID. + DenseMap<DIAssignID *, DIAssignID *> Map; + auto GetNewID = [&Map](Metadata *Old) { + DIAssignID *OldID = cast<DIAssignID>(Old); + if (DIAssignID *NewID = Map.lookup(OldID)) + return NewID; + DIAssignID *NewID = DIAssignID::getDistinct(OldID->getContext()); + Map[OldID] = NewID; + return NewID; + }; + // Loop over all the inlined instructions. If we find a DIAssignID + // attachment or use, replace it with a new version. + for (auto BBI = Start; BBI != End; ++BBI) { + for (Instruction &I : *BBI) { + if (auto *ID = I.getMetadata(LLVMContext::MD_DIAssignID)) + I.setMetadata(LLVMContext::MD_DIAssignID, GetNewID(ID)); + else if (auto *DAI = dyn_cast<DbgAssignIntrinsic>(&I)) + DAI->setAssignId(GetNewID(DAI->getAssignID())); + } + } +} +#undef DEBUG_TYPE +#define DEBUG_TYPE "inline-function" + /// Update the block frequencies of the caller after a callee has been inlined. /// /// Each block cloned into the caller has its block frequency scaled by the @@ -1636,7 +1857,8 @@ static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap, BlockFrequencyInfo *CallerBFI) { if (CalleeEntryCount.isSynthetic() || CalleeEntryCount.getCount() < 1) return; - auto CallSiteCount = PSI ? 
PSI->getProfileCount(TheCall, CallerBFI) : None; + auto CallSiteCount = + PSI ? PSI->getProfileCount(TheCall, CallerBFI) : std::nullopt; int64_t CallCount = std::min(CallSiteCount.value_or(0), CalleeEntryCount.getCount()); updateProfileCallee(Callee, -CallCount, &VMap); @@ -1784,6 +2006,7 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind, /// exists in the instruction stream. Similarly this will inline a recursive /// function by one level. llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, + bool MergeAttributes, AAResults *CalleeAAR, bool InsertLifetime, Function *ForwardVarArgsTo) { @@ -1814,6 +2037,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, continue; if (Tag == LLVMContext::OB_clang_arc_attachedcall) continue; + if (Tag == LLVMContext::OB_kcfi) + continue; return InlineResult::failure("unsupported operand bundle"); } @@ -1874,7 +2099,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if (CallerPersonality) { EHPersonality Personality = classifyEHPersonality(CallerPersonality); if (isScopedEHPersonality(Personality)) { - Optional<OperandBundleUse> ParentFunclet = + std::optional<OperandBundleUse> ParentFunclet = CB.getOperandBundle(LLVMContext::OB_funclet); if (ParentFunclet) CallSiteEHPad = cast<FuncletPadInst>(ParentFunclet->Inputs.front()); @@ -1963,7 +2188,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if (CB.isByValArgument(ArgNo)) { ActualArg = HandleByValArgument(CB.getParamByValType(ArgNo), ActualArg, &CB, CalledFunc, IFI, - CalledFunc->getParamAlignment(ArgNo)); + CalledFunc->getParamAlign(ArgNo)); if (ActualArg != *AI) ByValInits.push_back( {ActualArg, (Value *)*AI, CB.getParamByValType(ArgNo)}); @@ -2019,7 +2244,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, HandleByValArgumentInit(Init.Ty, Init.Dst, Init.Src, Caller->getParent(), &*FirstNewBlock, IFI); - Optional<OperandBundleUse> ParentDeopt = + std::optional<OperandBundleUse> ParentDeopt = CB.getOperandBundle(LLVMContext::OB_deopt); if (ParentDeopt) { SmallVector<OperandBundleDef, 2> OpDefs; @@ -2077,6 +2302,15 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, fixupLineNumbers(Caller, FirstNewBlock, &CB, CalledFunc->getSubprogram() != nullptr); + if (isAssignmentTrackingEnabled(*Caller->getParent())) { + // Interpret inlined stores to caller-local variables as assignments. + trackInlinedStores(FirstNewBlock, Caller->end(), CB); + + // Update DIAssignID metadata attachments and uses so that they are + // unique to this inlined instance. + fixupAssignments(FirstNewBlock, Caller->end()); + } + // Now clone the inlined noalias scope metadata. SAMetadataCloner.clone(); SAMetadataCloner.remap(FirstNewBlock, Caller->end()); @@ -2088,6 +2322,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // function which feed into its return value. AddReturnAttributes(CB, VMap); + propagateMemProfMetadata(CalledFunc, CB, + InlinedFunctionInfo.ContainsMemProfMetadata, VMap); + // Propagate metadata on the callsite if necessary. 
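The llvm::Optional to std::optional migration visible throughout this diff keeps the same accessor surface; a minimal sketch (with made-up values) of the replacements used at these call sites:

  #include <cstdint>
  #include <optional>

  static std::optional<uint64_t> lookupCount(bool Known, uint64_t N) {
    if (!Known)
      return std::nullopt; // replaces llvm::None
    return N;
  }

  static uint64_t countOrZero(std::optional<uint64_t> C) {
    // value_or() matches the CallSiteCount.value_or(0) pattern above.
    return C.value_or(0);
  }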
PropagateCallSiteMetadata(CB, FirstNewBlock, Caller->end()); @@ -2096,7 +2333,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, for (BasicBlock &NewBlock : make_range(FirstNewBlock->getIterator(), Caller->end())) for (Instruction &I : NewBlock) - if (auto *II = dyn_cast<AssumeInst>(&I)) + if (auto *II = dyn_cast<CondGuardInst>(&I)) IFI.GetAssumptionCache(*Caller).registerAssumption(II); } @@ -2136,8 +2373,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Transfer all of the allocas over in a block. Using splice means // that the instructions aren't removed from the symbol table, then // reinserted. - Caller->getEntryBlock().getInstList().splice( - InsertPoint, FirstNewBlock->getInstList(), AI->getIterator(), I); + Caller->getEntryBlock().splice(InsertPoint, &*FirstNewBlock, + AI->getIterator(), I); } } @@ -2270,7 +2507,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if (!AllocaTypeSize.isScalable() && AllocaArraySize != std::numeric_limits<uint64_t>::max() && std::numeric_limits<uint64_t>::max() / AllocaArraySize >= - AllocaTypeSize.getFixedSize()) { + AllocaTypeSize.getFixedValue()) { AllocaSize = ConstantInt::get(Type::getInt64Ty(AI->getContext()), AllocaArraySize * AllocaTypeSize); } @@ -2480,10 +2717,10 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // the calling basic block. if (Returns.size() == 1 && std::distance(FirstNewBlock, Caller->end()) == 1) { // Move all of the instructions right before the call. - OrigBB->getInstList().splice(CB.getIterator(), FirstNewBlock->getInstList(), - FirstNewBlock->begin(), FirstNewBlock->end()); + OrigBB->splice(CB.getIterator(), &*FirstNewBlock, FirstNewBlock->begin(), + FirstNewBlock->end()); // Remove the cloned basic block. - Caller->getBasicBlockList().pop_back(); + Caller->back().eraseFromParent(); // If the call site was an invoke instruction, add a branch to the normal // destination. @@ -2507,6 +2744,9 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Since we are now done with the return instruction, delete it also. Returns[0]->eraseFromParent(); + if (MergeAttributes) + AttributeFuncs::mergeAttributesForInlining(*Caller, *CalledFunc); + // We are now done with the inlining. return InlineResult::success(); } @@ -2556,9 +2796,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Now that the function is correct, make it a little bit nicer. In // particular, move the basic blocks inserted from the end of the function // into the space made by splitting the source basic block. - Caller->getBasicBlockList().splice(AfterCallBB->getIterator(), - Caller->getBasicBlockList(), FirstNewBlock, - Caller->end()); + Caller->splice(AfterCallBB->getIterator(), Caller, FirstNewBlock, + Caller->end()); // Handle all of the return instructions that we just cloned in, and eliminate // any users of the original call/invoke instruction. @@ -2618,8 +2857,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Splice the code from the return block into the block that it will return // to, which contains the code that was after the call. 
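A minimal sketch of the BasicBlock splice API this diff migrates to (block names are illustrative): instruction lists are no longer manipulated through getInstList(); the block itself exposes splice().

  #include "llvm/IR/BasicBlock.h"

  static void appendAllInstructions(llvm::BasicBlock *Dest, llvm::BasicBlock *Src) {
    // Moves every instruction of Src to the end of Dest in order, keeping
    // symbol-table entries intact, mirroring the hunks above and below.
    Dest->splice(Dest->end(), Src);
  }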
- AfterCallBB->getInstList().splice(AfterCallBB->begin(), - ReturnBB->getInstList()); + AfterCallBB->splice(AfterCallBB->begin(), ReturnBB); if (CreatedBranchToNormalDest) CreatedBranchToNormalDest->setDebugLoc(Returns[0]->getDebugLoc()); @@ -2649,13 +2887,13 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Splice the code entry block into calling block, right before the // unconditional branch. CalleeEntry->replaceAllUsesWith(OrigBB); // Update PHI nodes - OrigBB->getInstList().splice(Br->getIterator(), CalleeEntry->getInstList()); + OrigBB->splice(Br->getIterator(), CalleeEntry); // Remove the unconditional branch. - OrigBB->getInstList().erase(Br); + Br->eraseFromParent(); // Now we can remove the CalleeEntry block, which is now empty. - Caller->getBasicBlockList().erase(CalleeEntry); + CalleeEntry->eraseFromParent(); // If we inserted a phi node, check to see if it has a single value (e.g. all // the entries are the same or undef). If so, remove the PHI so it doesn't @@ -2670,5 +2908,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, } } + if (MergeAttributes) + AttributeFuncs::mergeAttributesForInlining(*Caller, *CalledFunc); + return InlineResult::success(); } diff --git a/llvm/lib/Transforms/Utils/IntegerDivision.cpp b/llvm/lib/Transforms/Utils/IntegerDivision.cpp index 47ab30f03d14..cea095408b0c 100644 --- a/llvm/lib/Transforms/Utils/IntegerDivision.cpp +++ b/llvm/lib/Transforms/Utils/IntegerDivision.cpp @@ -32,14 +32,7 @@ using namespace llvm; static Value *generateSignedRemainderCode(Value *Dividend, Value *Divisor, IRBuilder<> &Builder) { unsigned BitWidth = Dividend->getType()->getIntegerBitWidth(); - ConstantInt *Shift; - - if (BitWidth == 64) { - Shift = Builder.getInt64(63); - } else { - assert(BitWidth == 32 && "Unexpected bit width"); - Shift = Builder.getInt32(31); - } + ConstantInt *Shift = Builder.getIntN(BitWidth, BitWidth - 1); // Following instructions are generated for both i32 (shift 31) and // i64 (shift 63). 
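A short sketch of the width-generic constant creation used in the IntegerDivision.cpp hunk above (signBitOf is an illustrative helper): getIntN builds an iN constant for the operand's own bit width, so the expansion no longer special-cases the i32 (shift 31) and i64 (shift 63) forms.

  #include "llvm/IR/IRBuilder.h"

  static llvm::Value *signBitOf(llvm::Value *V, llvm::IRBuilder<> &Builder) {
    unsigned BitWidth = V->getType()->getIntegerBitWidth();
    // Arithmetic shift right by (BitWidth - 1) broadcasts the sign bit.
    llvm::ConstantInt *Shift = Builder.getIntN(BitWidth, BitWidth - 1);
    return Builder.CreateAShr(V, Shift);
  }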
@@ -53,6 +46,8 @@ static Value *generateSignedRemainderCode(Value *Dividend, Value *Divisor, // ; %urem = urem i32 %dividend, %divisor // ; %xored = xor i32 %urem, %dividend_sgn // ; %srem = sub i32 %xored, %dividend_sgn + Dividend = Builder.CreateFreeze(Dividend); + Divisor = Builder.CreateFreeze(Divisor); Value *DividendSign = Builder.CreateAShr(Dividend, Shift); Value *DivisorSign = Builder.CreateAShr(Divisor, Shift); Value *DvdXor = Builder.CreateXor(Dividend, DividendSign); @@ -84,6 +79,8 @@ static Value *generatedUnsignedRemainderCode(Value *Dividend, Value *Divisor, // ; %quotient = udiv i32 %dividend, %divisor // ; %product = mul i32 %divisor, %quotient // ; %remainder = sub i32 %dividend, %product + Dividend = Builder.CreateFreeze(Dividend); + Divisor = Builder.CreateFreeze(Divisor); Value *Quotient = Builder.CreateUDiv(Dividend, Divisor); Value *Product = Builder.CreateMul(Divisor, Quotient); Value *Remainder = Builder.CreateSub(Dividend, Product); @@ -104,14 +101,7 @@ static Value *generateSignedDivisionCode(Value *Dividend, Value *Divisor, // Implementation taken from compiler-rt's __divsi3 and __divdi3 unsigned BitWidth = Dividend->getType()->getIntegerBitWidth(); - ConstantInt *Shift; - - if (BitWidth == 64) { - Shift = Builder.getInt64(63); - } else { - assert(BitWidth == 32 && "Unexpected bit width"); - Shift = Builder.getInt32(31); - } + ConstantInt *Shift = Builder.getIntN(BitWidth, BitWidth - 1); // Following instructions are generated for both i32 (shift 31) and // i64 (shift 63). @@ -126,6 +116,8 @@ static Value *generateSignedDivisionCode(Value *Dividend, Value *Divisor, // ; %q_mag = udiv i32 %u_dvnd, %u_dvsr // ; %tmp4 = xor i32 %q_mag, %q_sgn // ; %q = sub i32 %tmp4, %q_sgn + Dividend = Builder.CreateFreeze(Dividend); + Divisor = Builder.CreateFreeze(Divisor); Value *Tmp = Builder.CreateAShr(Dividend, Shift); Value *Tmp1 = Builder.CreateAShr(Divisor, Shift); Value *Tmp2 = Builder.CreateXor(Tmp, Dividend); @@ -156,23 +148,10 @@ static Value *generateUnsignedDivisionCode(Value *Dividend, Value *Divisor, IntegerType *DivTy = cast<IntegerType>(Dividend->getType()); unsigned BitWidth = DivTy->getBitWidth(); - ConstantInt *Zero; - ConstantInt *One; - ConstantInt *NegOne; - ConstantInt *MSB; - - if (BitWidth == 64) { - Zero = Builder.getInt64(0); - One = Builder.getInt64(1); - NegOne = ConstantInt::getSigned(DivTy, -1); - MSB = Builder.getInt64(63); - } else { - assert(BitWidth == 32 && "Unexpected bit width"); - Zero = Builder.getInt32(0); - One = Builder.getInt32(1); - NegOne = ConstantInt::getSigned(DivTy, -1); - MSB = Builder.getInt32(31); - } + ConstantInt *Zero = ConstantInt::get(DivTy, 0); + ConstantInt *One = ConstantInt::get(DivTy, 1); + ConstantInt *NegOne = ConstantInt::getSigned(DivTy, -1); + ConstantInt *MSB = ConstantInt::get(DivTy, BitWidth - 1); ConstantInt *True = Builder.getTrue(); @@ -241,12 +220,14 @@ static Value *generateUnsignedDivisionCode(Value *Dividend, Value *Divisor, // ; %tmp1 = tail call i32 @llvm.ctlz.i32(i32 %dividend, i1 true) // ; %sr = sub nsw i32 %tmp0, %tmp1 // ; %ret0_4 = icmp ugt i32 %sr, 31 - // ; %ret0 = or i1 %ret0_3, %ret0_4 + // ; %ret0 = select i1 %ret0_3, i1 true, i1 %ret0_4 // ; %retDividend = icmp eq i32 %sr, 31 // ; %retVal = select i1 %ret0, i32 0, i32 %dividend - // ; %earlyRet = or i1 %ret0, %retDividend + // ; %earlyRet = select i1 %ret0, i1 true, %retDividend // ; br i1 %earlyRet, label %end, label %bb1 Builder.SetInsertPoint(SpecialCases); + Divisor = Builder.CreateFreeze(Divisor); + Dividend = 
Builder.CreateFreeze(Dividend); Value *Ret0_1 = Builder.CreateICmpEQ(Divisor, Zero); Value *Ret0_2 = Builder.CreateICmpEQ(Dividend, Zero); Value *Ret0_3 = Builder.CreateOr(Ret0_1, Ret0_2); @@ -254,10 +235,10 @@ static Value *generateUnsignedDivisionCode(Value *Dividend, Value *Divisor, Value *Tmp1 = Builder.CreateCall(CTLZ, {Dividend, True}); Value *SR = Builder.CreateSub(Tmp0, Tmp1); Value *Ret0_4 = Builder.CreateICmpUGT(SR, MSB); - Value *Ret0 = Builder.CreateOr(Ret0_3, Ret0_4); + Value *Ret0 = Builder.CreateLogicalOr(Ret0_3, Ret0_4); Value *RetDividend = Builder.CreateICmpEQ(SR, MSB); Value *RetVal = Builder.CreateSelect(Ret0, Zero, Dividend); - Value *EarlyRet = Builder.CreateOr(Ret0, RetDividend); + Value *EarlyRet = Builder.CreateLogicalOr(Ret0, RetDividend); Builder.CreateCondBr(EarlyRet, End, BB1); // ; bb1: ; preds = %special-cases @@ -367,8 +348,7 @@ static Value *generateUnsignedDivisionCode(Value *Dividend, Value *Divisor, /// Generate code to calculate the remainder of two integers, replacing Rem with /// the generated code. This currently generates code using the udiv expansion, /// but future work includes generating more specialized code, e.g. when more -/// information about the operands are known. Implements both 32bit and 64bit -/// scalar division. +/// information about the operands are known. /// /// Replace Rem with generated code. bool llvm::expandRemainder(BinaryOperator *Rem) { @@ -379,9 +359,6 @@ bool llvm::expandRemainder(BinaryOperator *Rem) { IRBuilder<> Builder(Rem); assert(!Rem->getType()->isVectorTy() && "Div over vectors not supported"); - assert((Rem->getType()->getIntegerBitWidth() == 32 || - Rem->getType()->getIntegerBitWidth() == 64) && - "Div of bitwidth other than 32 or 64 not supported"); // First prepare the sign if it's a signed remainder if (Rem->getOpcode() == Instruction::SRem) { @@ -421,12 +398,10 @@ bool llvm::expandRemainder(BinaryOperator *Rem) { return true; } - /// Generate code to divide two integers, replacing Div with the generated /// code. This currently generates code similarly to compiler-rt's /// implementations, but future work includes generating more specialized code -/// when more information about the operands are known. Implements both -/// 32bit and 64bit scalar division. +/// when more information about the operands are known. /// /// Replace Div with generated code. 
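Why the expansions above now freeze their operands first (a sketch under the assumption suggested by the CreateFreeze calls in this diff): the expanded sequence reads Dividend and Divisor more than once, and freeze pins a single concrete value so an undef or poison input cannot take different values at different uses. The unsigned-remainder shape it protects looks like this:

  #include "llvm/IR/IRBuilder.h"

  static llvm::Value *expandURemSketch(llvm::Value *Dividend, llvm::Value *Divisor,
                                       llvm::IRBuilder<> &Builder) {
    Dividend = Builder.CreateFreeze(Dividend);
    Divisor = Builder.CreateFreeze(Divisor);
    // remainder = dividend - divisor * (dividend / divisor)
    llvm::Value *Quotient = Builder.CreateUDiv(Dividend, Divisor);
    llvm::Value *Product = Builder.CreateMul(Divisor, Quotient);
    return Builder.CreateSub(Dividend, Product);
  }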
bool llvm::expandDivision(BinaryOperator *Div) { @@ -437,9 +412,6 @@ bool llvm::expandDivision(BinaryOperator *Div) { IRBuilder<> Builder(Div); assert(!Div->getType()->isVectorTy() && "Div over vectors not supported"); - assert((Div->getType()->getIntegerBitWidth() == 32 || - Div->getType()->getIntegerBitWidth() == 64) && - "Div of bitwidth other than 32 or 64 not supported"); // First prepare the sign if it's a signed division if (Div->getOpcode() == Instruction::SDiv) { @@ -540,9 +512,7 @@ bool llvm::expandRemainderUpTo64Bits(BinaryOperator *Rem) { unsigned RemTyBitWidth = RemTy->getIntegerBitWidth(); - assert(RemTyBitWidth <= 64 && "Div of bitwidth greater than 64 not supported"); - - if (RemTyBitWidth == 64) + if (RemTyBitWidth >= 64) return expandRemainder(Rem); // If bitwidth smaller than 64 extend inputs, extend output and proceed @@ -637,10 +607,7 @@ bool llvm::expandDivisionUpTo64Bits(BinaryOperator *Div) { unsigned DivTyBitWidth = DivTy->getIntegerBitWidth(); - assert(DivTyBitWidth <= 64 && - "Div of bitwidth greater than 64 not supported"); - - if (DivTyBitWidth == 64) + if (DivTyBitWidth >= 64) return expandDivision(Div); // If bitwidth smaller than 64 extend inputs, extend output and proceed diff --git a/llvm/lib/Transforms/Utils/LCSSA.cpp b/llvm/lib/Transforms/Utils/LCSSA.cpp index 84d377d835f3..af79dc456ea6 100644 --- a/llvm/lib/Transforms/Utils/LCSSA.cpp +++ b/llvm/lib/Transforms/Utils/LCSSA.cpp @@ -107,10 +107,16 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, if (ExitBlocks.empty()) continue; - for (Use &U : I->uses()) { + for (Use &U : make_early_inc_range(I->uses())) { Instruction *User = cast<Instruction>(U.getUser()); BasicBlock *UserBB = User->getParent(); + // Skip uses in unreachable blocks. + if (!DT.isReachableFromEntry(UserBB)) { + U.set(PoisonValue::get(I->getType())); + continue; + } + // For practical purposes, we consider that the use in a PHI // occurs in the respective predecessor block. For more info, // see the `phi` doc in LangRef and the LCSSA doc. @@ -235,7 +241,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, llvm::findDbgValues(DbgValues, I); // Update pre-existing debug value uses that reside outside the loop. 
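A compact sketch of the use-iteration idiom the LCSSA hunk above adopts (poisonAllUses is illustrative): make_early_inc_range advances the iterator before the loop body runs, so the body may rewrite or drop the current use without invalidating the traversal.

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/IR/Constants.h"
  #include "llvm/IR/Instruction.h"

  static void poisonAllUses(llvm::Instruction &I) {
    // Safe even though U.set() replaces the use being visited.
    for (llvm::Use &U : llvm::make_early_inc_range(I.uses()))
      U.set(llvm::PoisonValue::get(I.getType()));
  }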
- for (auto DVI : DbgValues) { + for (auto *DVI : DbgValues) { BasicBlock *UserBB = DVI->getParent(); if (InstBB == UserBB || L->contains(UserBB)) continue; @@ -417,7 +423,7 @@ bool llvm::formLCSSARecursively(Loop &L, const DominatorTree &DT, static bool formLCSSAOnAllLoops(const LoopInfo *LI, const DominatorTree &DT, ScalarEvolution *SE) { bool Changed = false; - for (auto &L : *LI) + for (const auto &L : *LI) Changed |= formLCSSARecursively(*L, DT, LI, SE); return Changed; } diff --git a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index 6e87da9fb168..5dd469c7af4b 100644 --- a/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -40,6 +40,9 @@ #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" + +#include <cmath> + using namespace llvm; #define DEBUG_TYPE "libcalls-shrinkwrap" @@ -493,7 +496,7 @@ void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) { assert(SuccBB && "The split block should have a single successor"); SuccBB->setName("cdce.end"); CI->removeFromParent(); - CallBB->getInstList().insert(CallBB->getFirstInsertionPt(), CI); + CI->insertInto(CallBB, CallBB->getFirstInsertionPt()); LLVM_DEBUG(dbgs() << "== Basic Block After =="); LLVM_DEBUG(dbgs() << *CallBB->getSinglePredecessor() << *CallBB << *CallBB->getSingleSuccessor() << "\n"); diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 2f1d0c2f9012..31cdd2ee56b9 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -17,8 +17,6 @@ #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -58,11 +56,13 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsWebAssembly.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -80,6 +80,7 @@ #include <cstdint> #include <iterator> #include <map> +#include <optional> #include <utility> using namespace llvm; @@ -210,20 +211,18 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. if (i->getCaseSuccessor() == DefaultDest) { - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = getValidBranchWeightMDNode(*SI); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches // left, unless the metadata doesn't match the switch. - if (NCases > 1 && MD && MD->getNumOperands() == 2 + NCases) { + if (NCases > 1 && MD) { // Collect branch weights into a vector. SmallVector<uint32_t, 8> Weights; - for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; - ++MD_i) { - auto *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i)); - Weights.push_back(CI->getValue().getZExtValue()); - } + extractBranchWeights(MD, Weights); + // Merge weight of this case to the default weight. unsigned idx = i->getCaseIndex(); + // TODO: Add overflow check. 
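A minimal sketch of the ProfDataUtils helpers the ConstantFoldTerminator rewrite above switches to (branchWeightsOf is an illustrative wrapper): extractBranchWeights hides the raw !prof operand layout and simply fills a vector of weights, returning false when no valid branch_weights metadata is attached.

  #include "llvm/ADT/SmallVector.h"
  #include "llvm/IR/Instruction.h"
  #include "llvm/IR/ProfDataUtils.h"

  static llvm::SmallVector<uint32_t> branchWeightsOf(const llvm::Instruction &Term) {
    llvm::SmallVector<uint32_t> Weights;
    if (!llvm::extractBranchWeights(Term, Weights))
      Weights.clear(); // no usable branch weight metadata
    return Weights;
  }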
Weights[0] += Weights[idx+1]; // Remove weight for this case. std::swap(Weights[idx+1], Weights.back()); @@ -237,6 +236,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, DefaultDest->removePredecessor(ParentBB); i = SI->removeCase(i); e = SI->case_end(); + + // Removing this case may have made the condition constant. In that + // case, update CI and restart iteration through the cases. + if (auto *NewCI = dyn_cast<ConstantInt>(SI->getCondition())) { + CI = NewCI; + i = SI->case_begin(); + } + Changed = true; continue; } @@ -305,18 +312,14 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, BranchInst *NewBr = Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(), SI->getDefaultDest()); - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); - if (MD && MD->getNumOperands() == 3) { - ConstantInt *SICase = - mdconst::dyn_extract<ConstantInt>(MD->getOperand(2)); - ConstantInt *SIDef = - mdconst::dyn_extract<ConstantInt>(MD->getOperand(1)); - assert(SICase && SIDef); + SmallVector<uint32_t> Weights; + if (extractBranchWeights(*SI, Weights) && Weights.size() == 2) { + uint32_t DefWeight = Weights[0]; + uint32_t CaseWeight = Weights[1]; // The TrueWeight should be the weight for the single case of SI. NewBr->setMetadata(LLVMContext::MD_prof, - MDBuilder(BB->getContext()). - createBranchWeights(SICase->getValue().getZExtValue(), - SIDef->getValue().getZExtValue())); + MDBuilder(BB->getContext()) + .createBranchWeights(CaseWeight, DefWeight)); } // Update make.implicit metadata to the newly-created conditional branch. @@ -443,8 +446,23 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, if (isRemovableAlloc(CB, TLI)) return true; - if (!I->willReturn()) - return false; + if (!I->willReturn()) { + auto *II = dyn_cast<IntrinsicInst>(I); + if (!II) + return false; + + // TODO: These intrinsics are not safe to remove, because this may remove + // a well-defined trap. + switch (II->getIntrinsicID()) { + case Intrinsic::wasm_trunc_signed: + case Intrinsic::wasm_trunc_unsigned: + case Intrinsic::ptrauth_auth: + case Intrinsic::ptrauth_resign: + return true; + default: + return false; + } + } if (!I->mayHaveSideEffects()) return true; @@ -488,7 +506,8 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, } if (auto *FPI = dyn_cast<ConstrainedFPIntrinsic>(I)) { - Optional<fp::ExceptionBehavior> ExBehavior = FPI->getExceptionBehavior(); + std::optional<fp::ExceptionBehavior> ExBehavior = + FPI->getExceptionBehavior(); return *ExBehavior != fp::ebStrict; } } @@ -595,10 +614,8 @@ void llvm::RecursivelyDeleteTriviallyDeadInstructions( bool llvm::replaceDbgUsesWithUndef(Instruction *I) { SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; findDbgUsers(DbgUsers, I); - for (auto *DII : DbgUsers) { - Value *Undef = UndefValue::get(I->getType()); - DII->replaceVariableLocationOp(I, Undef); - } + for (auto *DII : DbgUsers) + DII->setKillLocation(); return !DbgUsers.empty(); } @@ -798,7 +815,7 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, // Splice all the instructions from PredBB to DestBB. 
PredBB->getTerminator()->eraseFromParent(); - DestBB->getInstList().splice(DestBB->begin(), PredBB->getInstList()); + DestBB->splice(DestBB->begin(), PredBB); new UnreachableInst(PredBB->getContext(), PredBB); // If the PredBB is the entry block of the function, move DestBB up to @@ -807,7 +824,7 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DestBB->moveAfter(PredBB); if (DTU) { - assert(PredBB->getInstList().size() == 1 && + assert(PredBB->size() == 1 && isa<UnreachableInst>(PredBB->getTerminator()) && "The successor list of PredBB isn't empty before " "applying corresponding DTU updates."); @@ -1090,17 +1107,77 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, } } - // We cannot fold the block if it's a branch to an already present callbr - // successor because that creates duplicate successors. - for (BasicBlock *PredBB : predecessors(BB)) { - if (auto *CBI = dyn_cast<CallBrInst>(PredBB->getTerminator())) { - if (Succ == CBI->getDefaultDest()) - return false; - for (unsigned i = 0, e = CBI->getNumIndirectDests(); i != e; ++i) - if (Succ == CBI->getIndirectDest(i)) - return false; - } - } + // 'BB' and 'BB->Pred' are loop latches, bail out to preserve inner loop + // metadata. + // + // FIXME: This is a stop-gap solution to preserve inner-loop metadata given + // current status (that loop metadata is implemented as metadata attached to + // the branch instruction in the loop latch block). To quote from review + // comments, "the current representation of loop metadata (using a loop latch + // terminator attachment) is known to be fundamentally broken. Loop latches + // are not uniquely associated with loops (both in that a latch can be part of + // multiple loops and a loop may have multiple latches). Loop headers are. The + // solution to this problem is also known: Add support for basic block + // metadata, and attach loop metadata to the loop header." + // + // Why bail out: + // In this case, we expect 'BB' is the latch for outer-loop and 'BB->Pred' is + // the latch for inner-loop (see reason below), so bail out to preserve + // inner-loop metadata rather than eliminating 'BB' and attaching its metadata + // to this inner-loop. + // - The reason we believe 'BB' and 'BB->Pred' have different inner-most + // loops: assuming 'BB' and 'BB->Pred' are from the same inner-most loop L, + // then 'BB' is the header and latch of 'L' and thereby 'L' must consist of + // one self-looping basic block, which is contradictory with the assumption. + // + // To illustrate how inner-loop metadata is dropped: + // + // CFG Before + // + // BB is while.cond.exit, attached with loop metadata md2. + // BB->Pred is for.body, attached with loop metadata md1. + // + // entry + // | + // v + // ---> while.cond -------------> while.end + // | | + // | v + // | while.body + // | | + // | v + // | for.body <---- (md1) + // | | |______| + // | v + // | while.cond.exit (md2) + // | | + // |_______| + // + // CFG After + // + // while.cond1 is the merge of while.cond.exit and while.cond above. + // for.body is attached with md2, and md1 is dropped. + // If LoopSimplify runs later (as part of a loop pass), it could create + // dedicated exits for inner-loop (essentially adding `while.cond.exit` + // back), but it won't see 'md1' nor restore it for the inner-loop.
+ // + // entry + // | + // v + // ---> while.cond1 -------------> while.end + // | | + // | v + // | while.body + // | | + // | v + // | for.body <---- (md2) + // |_______| |______| + if (Instruction *TI = BB->getTerminator()) + if (TI->hasMetadata(LLVMContext::MD_loop)) + for (BasicBlock *Pred : predecessors(BB)) + if (Instruction *PredTI = Pred->getTerminator()) + if (PredTI->hasMetadata(LLVMContext::MD_loop)) + return false; LLVM_DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB); @@ -1143,8 +1220,7 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, // Copy over any phi, debug or lifetime instruction. BB->getTerminator()->eraseFromParent(); - Succ->getInstList().splice(Succ->getFirstNonPHI()->getIterator(), - BB->getInstList()); + Succ->splice(Succ->getFirstNonPHI()->getIterator(), BB); } else { while (PHINode *PN = dyn_cast<PHINode>(&BB->front())) { // We explicitly check for such uses in CanPropagatePredecessorsForPHIs. @@ -1168,7 +1244,7 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, // Clear the successor list of BB to match updates applying to DTU later. if (BB->getTerminator()) - BB->getInstList().pop_back(); + BB->back().eraseFromParent(); new UnreachableInst(BB->getContext(), BB); assert(succ_empty(BB) && "The successor list of BB isn't empty before " "applying corresponding DTU updates."); @@ -1412,10 +1488,10 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { const DataLayout &DL = DII->getModule()->getDataLayout(); TypeSize ValueSize = DL.getTypeAllocSizeInBits(ValTy); - if (Optional<uint64_t> FragmentSize = DII->getFragmentSizeInBits()) { + if (std::optional<uint64_t> FragmentSize = DII->getFragmentSizeInBits()) { assert(!ValueSize.isScalable() && "Fragments don't work on scalable types."); - return ValueSize.getFixedSize() >= *FragmentSize; + return ValueSize.getFixedValue() >= *FragmentSize; } // We can't always calculate the size of the DI variable (e.g. if it is a // VLA). Try to use the size of the alloca that the dbg intrinsic describes @@ -1426,7 +1502,8 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { "address of variable must have exactly 1 location operand."); if (auto *AI = dyn_cast_or_null<AllocaInst>(DII->getVariableLocationOp(0))) { - if (Optional<TypeSize> FragmentSize = AI->getAllocationSizeInBits(DL)) { + if (std::optional<TypeSize> FragmentSize = + AI->getAllocationSizeInBits(DL)) { return TypeSize::isKnownGE(ValueSize, *FragmentSize); } } @@ -1435,30 +1512,17 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { return false; } -/// Produce a DebugLoc to use for each dbg.declare/inst pair that are promoted -/// to a dbg.value. Because no machine insts can come from debug intrinsics, -/// only the scope and inlinedAt is significant. Zero line numbers are used in -/// case this DebugLoc leaks into any adjacent instructions. -static DebugLoc getDebugValueLoc(DbgVariableIntrinsic *DII, Instruction *Src) { - // Original dbg.declare must have a location. - const DebugLoc &DeclareLoc = DII->getDebugLoc(); - MDNode *Scope = DeclareLoc.getScope(); - DILocation *InlinedAt = DeclareLoc.getInlinedAt(); - // Produce an unknown location with the correct scope / inlinedAt fields. 
- return DILocation::get(DII->getContext(), 0, 0, Scope, InlinedAt); -} - /// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value /// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, StoreInst *SI, DIBuilder &Builder) { - assert(DII->isAddressOfVariable()); + assert(DII->isAddressOfVariable() || isa<DbgAssignIntrinsic>(DII)); auto *DIVar = DII->getVariable(); assert(DIVar && "Missing variable"); auto *DIExpr = DII->getExpression(); Value *DV = SI->getValueOperand(); - DebugLoc NewLoc = getDebugValueLoc(DII, SI); + DebugLoc NewLoc = getDebugValueLoc(DII); if (!valueCoversEntireFragment(DV->getType(), DII)) { // FIXME: If storing to a part of the variable described by the dbg.declare, @@ -1493,7 +1557,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, return; } - DebugLoc NewLoc = getDebugValueLoc(DII, nullptr); + DebugLoc NewLoc = getDebugValueLoc(DII); // We are now tracking the loaded value instead of the address. In the // future if multi-location support is added to the IR, it might be @@ -1527,7 +1591,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, BasicBlock *BB = APN->getParent(); auto InsertionPt = BB->getFirstInsertionPt(); - DebugLoc NewLoc = getDebugValueLoc(DII, nullptr); + DebugLoc NewLoc = getDebugValueLoc(DII); // The block may be a catchswitch block, which does not have a valid // insertion point. @@ -1587,7 +1651,7 @@ bool llvm::LowerDbgDeclare(Function &F) { WorkList.push_back(AI); while (!WorkList.empty()) { const Value *V = WorkList.pop_back_val(); - for (auto &AIUse : V->uses()) { + for (const auto &AIUse : V->uses()) { User *U = AIUse.getUser(); if (StoreInst *SI = dyn_cast<StoreInst>(U)) { if (AIUse.getOperandNo() == 1) @@ -1599,7 +1663,7 @@ bool llvm::LowerDbgDeclare(Function &F) { // pointer to the variable. Insert a *value* intrinsic that describes // the variable by dereferencing the alloca. if (!CI->isLifetimeStartOrEnd()) { - DebugLoc NewLoc = getDebugValueLoc(DDI, nullptr); + DebugLoc NewLoc = getDebugValueLoc(DDI); auto *DerefExpr = DIExpression::append(DDI->getExpression(), dwarf::DW_OP_deref); DIB.insertDbgValueIntrinsic(AI, DDI->getVariable(), DerefExpr, @@ -1653,12 +1717,12 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB, // propagate the info through the new PHI. If we use more than one new PHI in // a single destination BB with the same old dbg.value, merge the updates so // that we get a single new dbg.value with all the new PHIs. - for (auto PHI : InsertedPHIs) { + for (auto *PHI : InsertedPHIs) { BasicBlock *Parent = PHI->getParent(); // Avoid inserting an intrinsic into an EH block. if (Parent->getFirstNonPHI()->isEHPad()) continue; - for (auto VI : PHI->operand_values()) { + for (auto *VI : PHI->operand_values()) { auto V = DbgValueMap.find(VI); if (V != DbgValueMap.end()) { auto *DbgII = cast<DbgVariableIntrinsic>(V->second); @@ -1735,14 +1799,48 @@ void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress, replaceOneDbgValueForAlloca(DVI, NewAllocaAddress, Builder, Offset); } -/// Where possible to salvage debug information for \p I do so -/// and return True. If not possible mark undef and return False. +/// Where possible to salvage debug information for \p I do so. +/// If not possible mark undef. 
void llvm::salvageDebugInfo(Instruction &I) { SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; findDbgUsers(DbgUsers, &I); salvageDebugInfoForDbgValues(I, DbgUsers); } +/// Salvage the address component of \p DAI. +static void salvageDbgAssignAddress(DbgAssignIntrinsic *DAI) { + Instruction *I = dyn_cast<Instruction>(DAI->getAddress()); + // Only instructions can be salvaged at the moment. + if (!I) + return; + + assert(!DAI->getAddressExpression()->getFragmentInfo().has_value() && + "address-expression shouldn't have fragment info"); + + // The address component of a dbg.assign cannot be variadic. + uint64_t CurrentLocOps = 0; + SmallVector<Value *, 4> AdditionalValues; + SmallVector<uint64_t, 16> Ops; + Value *NewV = salvageDebugInfoImpl(*I, CurrentLocOps, Ops, AdditionalValues); + + // Check if the salvage failed. + if (!NewV) + return; + + DIExpression *SalvagedExpr = DIExpression::appendOpsToArg( + DAI->getAddressExpression(), Ops, 0, /*StackValue=*/false); + assert(!SalvagedExpr->getFragmentInfo().has_value() && + "address-expression shouldn't have fragment info"); + + // Salvage succeeds if no additional values are required. + if (AdditionalValues.empty()) { + DAI->setAddress(NewV); + DAI->setAddressExpression(SalvagedExpr); + } else { + DAI->setKillAddress(); + } +} + void llvm::salvageDebugInfoForDbgValues( Instruction &I, ArrayRef<DbgVariableIntrinsic *> DbgUsers) { // These are arbitrary chosen limits on the maximum number of values and the @@ -1753,6 +1851,15 @@ void llvm::salvageDebugInfoForDbgValues( bool Salvaged = false; for (auto *DII : DbgUsers) { + if (auto *DAI = dyn_cast<DbgAssignIntrinsic>(DII)) { + if (DAI->getAddress() == &I) { + salvageDbgAssignAddress(DAI); + Salvaged = true; + } + if (DAI->getValue() != &I) + continue; + } + // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they // are implicitly pointing out the value as a DWARF memory location // description. @@ -1789,17 +1896,18 @@ void llvm::salvageDebugInfoForDbgValues( bool IsValidSalvageExpr = SalvagedExpr->getNumElements() <= MaxExpressionSize; if (AdditionalValues.empty() && IsValidSalvageExpr) { DII->setExpression(SalvagedExpr); - } else if (isa<DbgValueInst>(DII) && IsValidSalvageExpr && + } else if (isa<DbgValueInst>(DII) && !isa<DbgAssignIntrinsic>(DII) && + IsValidSalvageExpr && DII->getNumVariableLocationOps() + AdditionalValues.size() <= MaxDebugArgs) { DII->addVariableLocationOps(AdditionalValues, SalvagedExpr); } else { // Do not salvage using DIArgList for dbg.addr/dbg.declare, as it is - // currently only valid for stack value expressions. + // not currently supported in those instructions. Do not salvage using + // DIArgList for dbg.assign yet. FIXME: support this. // Also do not salvage if the resulting DIArgList would contain an // unreasonably large number of values. - Value *Undef = UndefValue::get(I.getOperand(0)->getType()); - DII->replaceVariableLocationOp(I.getOperand(0), Undef); + DII->setKillLocation(); } LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); Salvaged = true; @@ -1808,10 +1916,8 @@ void llvm::salvageDebugInfoForDbgValues( if (Salvaged) return; - for (auto *DII : DbgUsers) { - Value *Undef = UndefValue::get(I.getType()); - DII->replaceVariableLocationOp(&I, Undef); - } + for (auto *DII : DbgUsers) + DII->setKillLocation(); } Value *getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL, @@ -1956,7 +2062,7 @@ Value *llvm::salvageDebugInfoImpl(Instruction &I, uint64_t CurrentLocOps, } /// A replacement for a dbg.value expression. 
-using DbgValReplacement = Optional<DIExpression *>; +using DbgValReplacement = std::optional<DIExpression *>; /// Point debug users of \p From to \p To using exprs given by \p RewriteExpr, /// possibly moving/undefing users to prevent use-before-def. Returns true if @@ -2082,7 +2188,7 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, // Without knowing signedness, sign/zero extension isn't possible. auto Signedness = Var->getSignedness(); if (!Signedness) - return None; + return std::nullopt; bool Signed = *Signedness == DIBasicType::Signedness::Signed; return DIExpression::appendExt(DII.getExpression(), ToBits, FromBits, @@ -2146,7 +2252,7 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool PreserveLCSSA, while (BBI != BBE) { if (!BBI->use_empty()) BBI->replaceAllUsesWith(PoisonValue::get(BBI->getType())); - BB->getInstList().erase(BBI++); + BBI++->eraseFromParent(); ++NumInstrsRemoved; } if (DTU) { @@ -2216,7 +2322,7 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, CI->getName() + ".noexc"); // Delete the unconditional branch inserted by SplitBlock - BB->getInstList().pop_back(); + BB->back().eraseFromParent(); // Create the new invoke instruction. SmallVector<Value *, 8> InvokeArgs(CI->args()); @@ -2244,7 +2350,7 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, CI->replaceAllUsesWith(II); // Delete the original call - Split->getInstList().pop_front(); + Split->front().eraseFromParent(); return Split; } @@ -2297,7 +2403,9 @@ static bool markAliveBlocks(Function &F, } } } else if ((isa<ConstantPointerNull>(Callee) && - !NullPointerIsDefined(CI->getFunction())) || + !NullPointerIsDefined(CI->getFunction(), + cast<PointerType>(Callee->getType()) + ->getAddressSpace())) || isa<UndefValue>(Callee)) { changeToUnreachable(CI, false, DTU); Changed = true; @@ -2444,13 +2552,11 @@ static bool markAliveBlocks(Function &F, return Changed; } -void llvm::removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU) { +Instruction *llvm::removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU) { Instruction *TI = BB->getTerminator(); - if (auto *II = dyn_cast<InvokeInst>(TI)) { - changeToCall(II, DTU); - return; - } + if (auto *II = dyn_cast<InvokeInst>(TI)) + return changeToCall(II, DTU); Instruction *NewTI; BasicBlock *UnwindDest; @@ -2478,6 +2584,7 @@ void llvm::removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU) { TI->eraseFromParent(); if (DTU) DTU->applyUpdates({{DominatorTree::Delete, BB, UnwindDest}}); + return NewTI; } /// removeUnreachableBlocks - Remove blocks that are not reachable, even @@ -2536,6 +2643,9 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, break; case LLVMContext::MD_dbg: llvm_unreachable("getAllMetadataOtherThanDebugLoc returned a MD_dbg"); + case LLVMContext::MD_DIAssignID: + K->mergeDIAssignID(J); + break; case LLVMContext::MD_tbaa: K->setMetadata(Kind, MDNode::getMostGenericTBAA(JMD, KMD)); break; @@ -2642,6 +2752,7 @@ void llvm::copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source) { case LLVMContext::MD_nontemporal: case LLVMContext::MD_mem_parallel_loop_access: case LLVMContext::MD_access_group: + case LLVMContext::MD_noundef: // All of these directly apply. Dest.setMetadata(ID, N); break; @@ -2805,6 +2916,11 @@ void llvm::copyNonnullMetadata(const LoadInst &OldLI, MDNode *N, void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, MDNode *N, LoadInst &NewLI) { auto *NewTy = NewLI.getType(); + // Simply copy the metadata if the type did not change. 
+ if (NewTy == OldLI.getType()) { + NewLI.setMetadata(LLVMContext::MD_range, N); + return; + } // Give up unless it is converted to a pointer where there is a single very // valuable mapping we can do reliably. @@ -2815,7 +2931,7 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, unsigned BitWidth = DL.getPointerTypeSizeInBits(NewTy); if (!getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) { - MDNode *NN = MDNode::get(OldLI.getContext(), None); + MDNode *NN = MDNode::get(OldLI.getContext(), std::nullopt); NewLI.setMetadata(LLVMContext::MD_nonnull, NN); } } @@ -2864,9 +2980,8 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, I->setDebugLoc(InsertPt->getDebugLoc()); ++II; } - DomBlock->getInstList().splice(InsertPt->getIterator(), BB->getInstList(), - BB->begin(), - BB->getTerminator()->getIterator()); + DomBlock->splice(InsertPt->getIterator(), BB, BB->begin(), + BB->getTerminator()->getIterator()); } namespace { @@ -2917,15 +3032,15 @@ struct BitPart { /// /// Because we pass around references into \c BPS, we must use a container that /// does not invalidate internal references (std::map instead of DenseMap). -static const Optional<BitPart> & +static const std::optional<BitPart> & collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, - std::map<Value *, Optional<BitPart>> &BPS, int Depth, + std::map<Value *, std::optional<BitPart>> &BPS, int Depth, bool &FoundRoot) { auto I = BPS.find(V); if (I != BPS.end()) return I->second; - auto &Result = BPS[V] = None; + auto &Result = BPS[V] = std::nullopt; auto BitWidth = V->getType()->getScalarSizeInBits(); // Can't do integer/elements > 128 bits. @@ -2961,7 +3076,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, if (A->Provenance[BitIdx] != BitPart::Unset && B->Provenance[BitIdx] != BitPart::Unset && A->Provenance[BitIdx] != B->Provenance[BitIdx]) - return Result = None; + return Result = std::nullopt; if (A->Provenance[BitIdx] == BitPart::Unset) Result->Provenance[BitIdx] = B->Provenance[BitIdx]; @@ -3169,7 +3284,7 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( // Try to find all the pieces corresponding to the bswap. 
bool FoundRoot = false; - std::map<Value *, Optional<BitPart>> BPS; + std::map<Value *, std::optional<BitPart>> BPS; const auto &Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0, FoundRoot); if (!Res) diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp index f093fea19c4d..2acbe9002309 100644 --- a/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -11,7 +11,6 @@ #include "llvm/Transforms/Utils/LoopPeel.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/Loads.h" @@ -29,6 +28,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -41,6 +41,7 @@ #include <algorithm> #include <cassert> #include <cstdint> +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -71,25 +72,20 @@ static cl::opt<unsigned> UnrollForcePeelCount( "unroll-force-peel-count", cl::init(0), cl::Hidden, cl::desc("Force a peel count regardless of profiling information.")); +static cl::opt<bool> DisableAdvancedPeeling( + "disable-advanced-peeling", cl::init(false), cl::Hidden, + cl::desc( + "Disable advance peeling. Issues for convergent targets (D134803).")); + static const char *PeeledCountMetaData = "llvm.loop.peeled.count"; // Check whether we are capable of peeling this loop. -bool llvm::canPeel(Loop *L) { +bool llvm::canPeel(const Loop *L) { // Make sure the loop is in simplified form if (!L->isLoopSimplifyForm()) return false; - - // Don't try to peel loops where the latch is not the exiting block. - // This can be an indication of two different things: - // 1) The loop is not rotated. - // 2) The loop contains irreducible control flow that involves the latch. - const BasicBlock *Latch = L->getLoopLatch(); - if (!L->isLoopExiting(Latch)) - return false; - - // Peeling is only supported if the latch is a branch. - if (!isa<BranchInst>(Latch->getTerminator())) - return false; + if (!DisableAdvancedPeeling) + return true; SmallVector<BasicBlock *, 4> Exits; L->getUniqueNonLatchExitBlocks(Exits); @@ -104,63 +100,182 @@ bool llvm::canPeel(Loop *L) { return llvm::all_of(Exits, IsBlockFollowedByDeoptOrUnreachable); } -// This function calculates the number of iterations after which the given Phi -// becomes an invariant. The pre-calculated values are memorized in the map. The -// function (shortcut is I) is calculated according to the following definition: +namespace { + +// As a loop is peeled, it may be the case that Phi nodes become +// loop-invariant (ie, known because there is only one choice). +// For example, consider the following function: +// void g(int); +// void binary() { +// int x = 0; +// int y = 0; +// int a = 0; +// for(int i = 0; i <100000; ++i) { +// g(x); +// x = y; +// g(a); +// y = a + 1; +// a = 5; +// } +// } +// Peeling 3 iterations is beneficial because the values for x, y and a +// become known. The IR for this loop looks something like the following: +// +// %i = phi i32 [ 0, %entry ], [ %inc, %if.end ] +// %a = phi i32 [ 0, %entry ], [ 5, %if.end ] +// %y = phi i32 [ 0, %entry ], [ %add, %if.end ] +// %x = phi i32 [ 0, %entry ], [ %y, %if.end ] +// ... 
+// tail call void @_Z1gi(i32 signext %x) +// tail call void @_Z1gi(i32 signext %a) +// %add = add nuw nsw i32 %a, 1 +// %inc = add nuw nsw i32 %i, 1 +// %exitcond = icmp eq i32 %inc, 100000 +// br i1 %exitcond, label %for.cond.cleanup, label %for.body +// +// The arguments for the calls to g will become known after 3 iterations +// of the loop, because the phi nodes values become known after 3 iterations +// of the loop (ie, they are known on the 4th iteration, so peel 3 iterations). +// The first iteration has g(0), g(0); the second has g(0), g(5); the +// third has g(1), g(5) and the fourth (and all subsequent) have g(6), g(5). +// Now consider the phi nodes: +// %a is a phi with constants so it is determined after iteration 1. +// %y is a phi based on a constant and %a so it is determined on +// the iteration after %a is determined, so iteration 2. +// %x is a phi based on a constant and %y so it is determined on +// the iteration after %y, so iteration 3. +// %i is based on itself (and is an induction variable) so it is +// never determined. +// This means that peeling off 3 iterations will result in being able to +// remove the phi nodes for %a, %y, and %x. The arguments for the +// corresponding calls to g are determined and the code for computing +// x, y, and a can be removed. +// +// The PhiAnalyzer class calculates how many times a loop should be +// peeled based on the above analysis of the phi nodes in the loop while +// respecting the maximum specified. +class PhiAnalyzer { +public: + PhiAnalyzer(const Loop &L, unsigned MaxIterations); + + // Calculate the sufficient minimum number of iterations of the loop to peel + // such that phi instructions become determined (subject to allowable limits) + std::optional<unsigned> calculateIterationsToPeel(); + +protected: + using PeelCounter = std::optional<unsigned>; + const PeelCounter Unknown = std::nullopt; + + // Add 1 respecting Unknown and return Unknown if result over MaxIterations + PeelCounter addOne(PeelCounter PC) const { + if (PC == Unknown) + return Unknown; + return (*PC + 1 <= MaxIterations) ? PeelCounter{*PC + 1} : Unknown; + } + + // Calculate the number of iterations after which the given value + // becomes an invariant. + PeelCounter calculate(const Value &); + + const Loop &L; + const unsigned MaxIterations; + + // Map of Values to number of iterations to invariance + SmallDenseMap<const Value *, PeelCounter> IterationsToInvariance; +}; + +PhiAnalyzer::PhiAnalyzer(const Loop &L, unsigned MaxIterations) + : L(L), MaxIterations(MaxIterations) { + assert(canPeel(&L) && "loop is not suitable for peeling"); + assert(MaxIterations > 0 && "no peeling is allowed?"); +} + +// This function calculates the number of iterations after which the value +// becomes an invariant. The pre-calculated values are memorized in a map. +// N.B. This number will be Unknown or <= MaxIterations. +// The function is calculated according to the following definition: // Given %x = phi <Inputs from above the loop>, ..., [%y, %back.edge]. -// If %y is a loop invariant, then I(%x) = 1. -// If %y is a Phi from the loop header, I(%x) = I(%y) + 1. -// Otherwise, I(%x) is infinite. -// TODO: Actually if %y is an expression that depends only on Phi %z and some -// loop invariants, we can estimate I(%x) = I(%z) + 1. The example -// looks like: -// %x = phi(0, %a), <-- becomes invariant starting from 3rd iteration. -// %y = phi(0, 5), -// %a = %y + 1. 
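(Editorial illustration, outside the patch: the iterations-to-invariance recurrence described above can be sketched as a small standalone C++ program. The toy Expr representation, the names, and the MaxIterations value are invented for the example; the real PhiAnalyzer walks llvm::Value and PHINode. For the loop above it computes 1 for %a, 2 for %y and 3 for %x, i.e. a suggested peel count of 3.)

// Minimal model of the iterations-to-invariance recurrence, using a toy
// expression representation instead of llvm::Value.
#include <algorithm>
#include <cstdio>
#include <map>
#include <optional>
#include <string>
#include <vector>

using PeelCounter = std::optional<unsigned>;

struct Expr {
  enum Kind { Invariant, HeaderPhi, BinOp } K;
  std::vector<std::string> Ops; // HeaderPhi: back-edge input; BinOp: operands
};

static PeelCounter calculate(const std::string &Name,
                             const std::map<std::string, Expr> &Defs,
                             std::map<std::string, PeelCounter> &Memo,
                             unsigned MaxIterations) {
  auto Known = Memo.find(Name);
  if (Known != Memo.end())
    return Known->second;
  Memo[Name] = std::nullopt; // break cycles: self-dependent phis stay Unknown
  const Expr &E = Defs.at(Name);
  PeelCounter R = std::nullopt;
  switch (E.K) {
  case Expr::Invariant:
    R = 0; // known before the first iteration
    break;
  case Expr::HeaderPhi: {
    PeelCounter In = calculate(E.Ops[0], Defs, Memo, MaxIterations);
    if (In && *In + 1 <= MaxIterations)
      R = *In + 1; // known one iteration later than its back-edge input
    break;
  }
  case Expr::BinOp: {
    PeelCounter A = calculate(E.Ops[0], Defs, Memo, MaxIterations);
    PeelCounter B = calculate(E.Ops[1], Defs, Memo, MaxIterations);
    if (A && B)
      R = std::max(*A, *B); // known once both operands are known
    break;
  }
  }
  return Memo[Name] = R;
}

int main() {
  // %a = phi(0, 5); %add = %a + 1; %y = phi(0, %add); %x = phi(0, %y)
  std::map<std::string, Expr> Defs = {
      {"c0", {Expr::Invariant, {}}},
      {"c5", {Expr::Invariant, {}}},
      {"a", {Expr::HeaderPhi, {"c5"}}},
      {"add", {Expr::BinOp, {"a", "c0"}}},
      {"y", {Expr::HeaderPhi, {"add"}}},
      {"x", {Expr::HeaderPhi, {"y"}}}};
  std::map<std::string, PeelCounter> Memo;
  unsigned Peel = 0;
  for (const char *Phi : {"a", "y", "x"})
    if (PeelCounter C = calculate(Phi, Defs, Memo, /*MaxIterations=*/7))
      Peel = std::max(Peel, *C);
  std::printf("suggested peel count: %u\n", Peel); // prints 3
}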
-static Optional<unsigned> calculateIterationsToInvariance( - PHINode *Phi, Loop *L, BasicBlock *BackEdge, - SmallDenseMap<PHINode *, Optional<unsigned> > &IterationsToInvariance) { - assert(Phi->getParent() == L->getHeader() && - "Non-loop Phi should not be checked for turning into invariant."); - assert(BackEdge == L->getLoopLatch() && "Wrong latch?"); +// F(%x) = G(%y) + 1 (N.B. [MaxIterations | Unknown] + 1 => Unknown) +// G(%y) = 0 if %y is a loop invariant +// G(%y) = G(%BackEdgeValue) if %y is a phi in the header block +// G(%y) = TODO: if %y is an expression based on phis and loop invariants +// The example looks like: +// %x = phi(0, %a) <-- becomes invariant starting from 3rd iteration. +// %y = phi(0, 5) +// %a = %y + 1 +// G(%y) = Unknown otherwise (including phi not in header block) +PhiAnalyzer::PeelCounter PhiAnalyzer::calculate(const Value &V) { // If we already know the answer, take it from the map. - auto I = IterationsToInvariance.find(Phi); + auto I = IterationsToInvariance.find(&V); if (I != IterationsToInvariance.end()) return I->second; - // Otherwise we need to analyze the input from the back edge. - Value *Input = Phi->getIncomingValueForBlock(BackEdge); - // Place infinity to map to avoid infinite recursion for cycled Phis. Such + // Place Unknown to map to avoid infinite recursion. Such // cycles can never stop on an invariant. - IterationsToInvariance[Phi] = None; - Optional<unsigned> ToInvariance = None; - - if (L->isLoopInvariant(Input)) - ToInvariance = 1u; - else if (PHINode *IncPhi = dyn_cast<PHINode>(Input)) { - // Only consider Phis in header block. - if (IncPhi->getParent() != L->getHeader()) - return None; - // If the input becomes an invariant after X iterations, then our Phi - // becomes an invariant after X + 1 iterations. - auto InputToInvariance = calculateIterationsToInvariance( - IncPhi, L, BackEdge, IterationsToInvariance); - if (InputToInvariance) - ToInvariance = *InputToInvariance + 1u; + IterationsToInvariance[&V] = Unknown; + + if (L.isLoopInvariant(&V)) + // Loop invariant so known at start. + return (IterationsToInvariance[&V] = 0); + if (const PHINode *Phi = dyn_cast<PHINode>(&V)) { + if (Phi->getParent() != L.getHeader()) { + // Phi is not in header block so Unknown. + assert(IterationsToInvariance[&V] == Unknown && "unexpected value saved"); + return Unknown; + } + // We need to analyze the input from the back edge and add 1. + Value *Input = Phi->getIncomingValueForBlock(L.getLoopLatch()); + PeelCounter Iterations = calculate(*Input); + assert(IterationsToInvariance[Input] == Iterations && + "unexpected value saved"); + return (IterationsToInvariance[Phi] = addOne(Iterations)); + } + if (const Instruction *I = dyn_cast<Instruction>(&V)) { + if (isa<CmpInst>(I) || I->isBinaryOp()) { + // Binary instructions get the max of the operands. + PeelCounter LHS = calculate(*I->getOperand(0)); + if (LHS == Unknown) + return Unknown; + PeelCounter RHS = calculate(*I->getOperand(1)); + if (RHS == Unknown) + return Unknown; + return (IterationsToInvariance[I] = {std::max(*LHS, *RHS)}); + } + if (I->isCast()) + // Cast instructions get the value of the operand. + return (IterationsToInvariance[I] = calculate(*I->getOperand(0))); } + // TODO: handle more expressions + + // Everything else is Unknown. + assert(IterationsToInvariance[&V] == Unknown && "unexpected value saved"); + return Unknown; +} - // If we found that this Phi lies in an invariant chain, update the map. 
- if (ToInvariance) - IterationsToInvariance[Phi] = ToInvariance; - return ToInvariance; +std::optional<unsigned> PhiAnalyzer::calculateIterationsToPeel() { + unsigned Iterations = 0; + for (auto &PHI : L.getHeader()->phis()) { + PeelCounter ToInvariance = calculate(PHI); + if (ToInvariance != Unknown) { + assert(*ToInvariance <= MaxIterations && "bad result in phi analysis"); + Iterations = std::max(Iterations, *ToInvariance); + if (Iterations == MaxIterations) + break; + } + } + assert((Iterations <= MaxIterations) && "bad result in phi analysis"); + return Iterations ? std::optional<unsigned>(Iterations) : std::nullopt; } +} // unnamed namespace + // Try to find any invariant memory reads that will become dereferenceable in // the remainder loop after peeling. The load must also be used (transitively) // by an exit condition. Returns the number of iterations to peel off (at the // moment either 0 or 1). static unsigned peelToTurnInvariantLoadsDerefencebale(Loop &L, - DominatorTree &DT) { + DominatorTree &DT, + AssumptionCache *AC) { // Skip loops with a single exiting block, because there should be no benefit // for the heuristic below. if (L.getExitingBlock()) @@ -201,7 +316,7 @@ static unsigned peelToTurnInvariantLoadsDerefencebale(Loop &L, if (auto *LI = dyn_cast<LoadInst>(&I)) { Value *Ptr = LI->getPointerOperand(); if (DT.dominates(BB, Latch) && L.isLoopInvariant(Ptr) && - !isDereferenceablePointer(Ptr, LI->getType(), DL, LI, &DT)) + !isDereferenceablePointer(Ptr, LI->getType(), DL, LI, AC, &DT)) for (Value *U : I.users()) LoadUsers.insert(U); } @@ -330,7 +445,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, /// This "heuristic" exactly matches implicit behavior which used to exist /// inside getLoopEstimatedTripCount. It was added here to keep an -/// improvement inside that API from causing peeling to become more agressive. +/// improvement inside that API from causing peeling to become more aggressive. /// This should probably be removed. static bool violatesLegacyMultiExitLoopCheck(Loop *L) { BasicBlock *Latch = L->getLoopLatch(); @@ -357,7 +472,8 @@ static bool violatesLegacyMultiExitLoopCheck(Loop *L) { void llvm::computePeelCount(Loop *L, unsigned LoopSize, TargetTransformInfo::PeelingPreferences &PP, unsigned TripCount, DominatorTree &DT, - ScalarEvolution &SE, unsigned Threshold) { + ScalarEvolution &SE, AssumptionCache *AC, + unsigned Threshold) { assert(LoopSize > 0 && "Zero loop size is not allowed!"); // Save the PP.PeelCount value set by the target in // TTI.getPeelingPreferences or by the flag -unroll-peel-count. @@ -397,38 +513,31 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, if (AlreadyPeeled >= UnrollPeelMaxCount) return; + // Pay respect to limitations implied by loop size and the max peel count. + unsigned MaxPeelCount = UnrollPeelMaxCount; + MaxPeelCount = std::min(MaxPeelCount, Threshold / LoopSize - 1); + + // Start the max computation with the PP.PeelCount value set by the target + // in TTI.getPeelingPreferences or by the flag -unroll-peel-count. + unsigned DesiredPeelCount = TargetPeelCount; + // Here we try to get rid of Phis which become invariants after 1, 2, ..., N // iterations of the loop. For this we compute the number for iterations after // which every Phi is guaranteed to become an invariant, and try to peel the // maximum number of iterations among these values, thus turning all those // Phis into invariants. - - // Store the pre-calculated values here. 
- SmallDenseMap<PHINode *, Optional<unsigned>> IterationsToInvariance;
- // Now go through all Phis to calculate their the number of iterations they
- // need to become invariants.
- // Start the max computation with the PP.PeelCount value set by the target
- // in TTI.getPeelingPreferences or by the flag -unroll-peel-count.
- unsigned DesiredPeelCount = TargetPeelCount;
- BasicBlock *BackEdge = L->getLoopLatch();
- assert(BackEdge && "Loop is not in simplified form?");
- for (auto BI = L->getHeader()->begin(); isa<PHINode>(&*BI); ++BI) {
- PHINode *Phi = cast<PHINode>(&*BI);
- auto ToInvariance = calculateIterationsToInvariance(Phi, L, BackEdge,
- IterationsToInvariance);
- if (ToInvariance)
- DesiredPeelCount = std::max(DesiredPeelCount, *ToInvariance);
+ if (MaxPeelCount > DesiredPeelCount) {
+ // Check how many iterations are useful for resolving Phis.
+ auto NumPeels = PhiAnalyzer(*L, MaxPeelCount).calculateIterationsToPeel();
+ if (NumPeels)
+ DesiredPeelCount = std::max(DesiredPeelCount, *NumPeels);
 }
 
- // Pay respect to limitations implied by loop size and the max peel count.
- unsigned MaxPeelCount = UnrollPeelMaxCount;
- MaxPeelCount = std::min(MaxPeelCount, Threshold / LoopSize - 1);
-
 DesiredPeelCount = std::max(DesiredPeelCount,
 countToEliminateCompares(*L, MaxPeelCount, SE));
 
 if (DesiredPeelCount == 0)
- DesiredPeelCount = peelToTurnInvariantLoadsDerefencebale(*L, DT);
+ DesiredPeelCount = peelToTurnInvariantLoadsDerefencebale(*L, DT, AC);
 
 if (DesiredPeelCount > 0) {
 DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount);
@@ -460,7 +569,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
 if (L->getHeader()->getParent()->hasProfileData()) {
 if (violatesLegacyMultiExitLoopCheck(L))
 return;
- Optional<unsigned> EstimatedTripCount = getLoopEstimatedTripCount(L);
+ std::optional<unsigned> EstimatedTripCount = getLoopEstimatedTripCount(L);
 if (!EstimatedTripCount)
 return;
 
@@ -484,82 +593,87 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
 }
 }
 
-/// Update the branch weights of the latch of a peeled-off loop
+struct WeightInfo {
+ // Weights for current iteration.
+ SmallVector<uint32_t> Weights;
+ // Weights to subtract after each iteration.
+ const SmallVector<uint32_t> SubWeights;
+};
+
+/// Update the branch weights of an exiting block of a peeled-off loop
 /// iteration.
-/// This sets the branch weights for the latch of the recently peeled off loop
-/// iteration correctly.
-/// Let F is a weight of the edge from latch to header.
-/// Let E is a weight of the edge from latch to exit.
+/// Let F be the weight of the edge to continue (fallthrough) into the loop.
+/// Let E be the weight of the edge to an exit.
 /// F/(F+E) is the probability to go to the loop and E/(F+E) is the
 /// probability to go to the exit.
-/// Then, Estimated TripCount = F / E.
+/// Then, Estimated ExitCount = F / E.
 /// For the I-th (counting from 0) peeled-off iteration we set the weights for
-/// the peeled latch as (TC - I, 1). It gives us reasonable distribution,
-/// The probability to go to exit 1/(TC-I) increases. At the same time
-/// the estimated trip count of remaining loop reduces by I.
+/// the peeled exit as (EC - I, 1). This gives us a reasonable distribution:
+/// the probability to take the exit, 1/(EC-I), increases. At the same time
+/// the estimated exit count in the remainder loop reduces by I.
 /// To avoid dealing with division rounding we can just multiply both parts
 /// of the weights by E and use the weights (F - I * E, E).
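(Editorial illustration, outside the patch: a tiny standalone program applying the (F - I * E, E) rule. The starting weights F = 90 and E = 10 are made up for the example; they correspond to an estimated exit count of 9, which drops to 6 after peeling 3 iterations.)

// Worked example of the (F - I * E, E) update over three peeled iterations.
#include <cstdio>

int main() {
  unsigned F = 90; // fallthrough weight: the edge that stays in the loop
  unsigned E = 10; // exit weight: estimated exit count is F / E = 9
  for (unsigned I = 0; I < 3; ++I) {
    // Weights placed on the I-th peeled copy of the exiting branch.
    std::printf("peel %u: fallthrough=%u exit=%u\n", I, F, E);
    // Subtract E for the next copy, saturating at 1 in the same way the
    // rewritten weights are clamped above.
    F = F > E ? F - E : 1;
  }
  // The remainder loop is left with (F, E) = (60, 10), i.e. an estimated
  // exit count of 6, reduced by the three peeled iterations.
  std::printf("remainder: fallthrough=%u exit=%u\n", F, E);
  return 0;
}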
-/// -/// \param Header The copy of the header block that belongs to next iteration. -/// \param LatchBR The copy of the latch branch that belongs to this iteration. -/// \param[in,out] FallThroughWeight The weight of the edge from latch to -/// header before peeling (in) and after peeled off one iteration (out). -static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - uint64_t ExitWeight, - uint64_t &FallThroughWeight) { - // FallThroughWeight is 0 means that there is no branch weights on original - // latch block or estimated trip count is zero. - if (!FallThroughWeight) - return; - - unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); - MDBuilder MDB(LatchBR->getContext()); - MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight) - : MDB.createBranchWeights(FallThroughWeight, ExitWeight); - LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); - FallThroughWeight = - FallThroughWeight > ExitWeight ? FallThroughWeight - ExitWeight : 1; +static void updateBranchWeights(Instruction *Term, WeightInfo &Info) { + MDBuilder MDB(Term->getContext()); + Term->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(Info.Weights)); + for (auto [Idx, SubWeight] : enumerate(Info.SubWeights)) + if (SubWeight != 0) + Info.Weights[Idx] = Info.Weights[Idx] > SubWeight + ? Info.Weights[Idx] - SubWeight + : 1; } -/// Initialize the weights. -/// -/// \param Header The header block. -/// \param LatchBR The latch branch. -/// \param[out] ExitWeight The weight of the edge from Latch to Exit. -/// \param[out] FallThroughWeight The weight of the edge from Latch to Header. -static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - uint64_t &ExitWeight, - uint64_t &FallThroughWeight) { - uint64_t TrueWeight, FalseWeight; - if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) - return; - unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; - ExitWeight = HeaderIdx ? TrueWeight : FalseWeight; - FallThroughWeight = HeaderIdx ? FalseWeight : TrueWeight; -} +/// Initialize the weights for all exiting blocks. +static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos, + Loop *L) { + SmallVector<BasicBlock *> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (BasicBlock *ExitingBlock : ExitingBlocks) { + Instruction *Term = ExitingBlock->getTerminator(); + SmallVector<uint32_t> Weights; + if (!extractBranchWeights(*Term, Weights)) + continue; -/// Update the weights of original Latch block after peeling off all iterations. -/// -/// \param Header The header block. -/// \param LatchBR The latch branch. -/// \param ExitWeight The weight of the edge from Latch to Exit. -/// \param FallThroughWeight The weight of the edge from Latch to Header. -static void fixupBranchWeights(BasicBlock *Header, BranchInst *LatchBR, - uint64_t ExitWeight, - uint64_t FallThroughWeight) { - // FallThroughWeight is 0 means that there is no branch weights on original - // latch block or estimated trip count is zero. - if (!FallThroughWeight) - return; + // See the comment on updateBranchWeights() for an explanation of what we + // do here. + uint32_t FallThroughWeights = 0; + uint32_t ExitWeights = 0; + for (auto [Succ, Weight] : zip(successors(Term), Weights)) { + if (L->contains(Succ)) + FallThroughWeights += Weight; + else + ExitWeights += Weight; + } + + // Don't try to update weights for degenerate case. 
+ if (FallThroughWeights == 0) + continue; + + SmallVector<uint32_t> SubWeights; + for (auto [Succ, Weight] : zip(successors(Term), Weights)) { + if (!L->contains(Succ)) { + // Exit weights stay the same. + SubWeights.push_back(0); + continue; + } + + // Subtract exit weights on each iteration, distributed across all + // fallthrough edges. + double W = (double)Weight / (double)FallThroughWeights; + SubWeights.push_back((uint32_t)(ExitWeights * W)); + } + + WeightInfos.insert({Term, {std::move(Weights), std::move(SubWeights)}}); + } +} - // Sets the branch weights on the loop exit. - MDBuilder MDB(LatchBR->getContext()); - unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; - MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight) - : MDB.createBranchWeights(FallThroughWeight, ExitWeight); - LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); +/// Update the weights of original exiting block after peeling off all +/// iterations. +static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) { + MDBuilder MDB(Term->getContext()); + Term->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(Info.Weights)); } /// Clones the body of the loop L, putting it between \p InsertTop and \p @@ -641,10 +755,10 @@ static void cloneLoopBlocks( // header (for the last peeled iteration) or the copied header of the next // iteration (for every other iteration) BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); - BranchInst *LatchBR = cast<BranchInst>(NewLatch->getTerminator()); - for (unsigned idx = 0, e = LatchBR->getNumSuccessors(); idx < e; ++idx) - if (LatchBR->getSuccessor(idx) == Header) { - LatchBR->setSuccessor(idx, InsertBot); + auto *LatchTerm = cast<Instruction>(NewLatch->getTerminator()); + for (unsigned idx = 0, e = LatchTerm->getNumSuccessors(); idx < e; ++idx) + if (LatchTerm->getSuccessor(idx) == Header) { + LatchTerm->setSuccessor(idx, InsertBot); break; } if (DT) @@ -670,7 +784,7 @@ static void cloneLoopBlocks( else VMap[&*I] = LatchVal; } - cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); + NewPHI->eraseFromParent(); } // Fix up the outgoing values - we need to add a value for the iteration @@ -693,10 +807,12 @@ static void cloneLoopBlocks( LVMap[KV.first] = KV.second; } -TargetTransformInfo::PeelingPreferences llvm::gatherPeelingPreferences( - Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, - Optional<bool> UserAllowPeeling, - Optional<bool> UserAllowProfileBasedPeeling, bool UnrollingSpecficValues) { +TargetTransformInfo::PeelingPreferences +llvm::gatherPeelingPreferences(Loop *L, ScalarEvolution &SE, + const TargetTransformInfo &TTI, + std::optional<bool> UserAllowPeeling, + std::optional<bool> UserAllowProfileBasedPeeling, + bool UnrollingSpecficValues) { TargetTransformInfo::PeelingPreferences PP; // Set the default values. @@ -738,7 +854,7 @@ TargetTransformInfo::PeelingPreferences llvm::gatherPeelingPreferences( /// optimizations. 
bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, ScalarEvolution *SE, DominatorTree &DT, AssumptionCache *AC, - bool PreserveLCSSA) { + bool PreserveLCSSA, ValueToValueMapTy &LVMap) { assert(PeelCount > 0 && "Attempt to peel out zero iterations?"); assert(canPeel(L) && "Attempt to peel a loop which is not peelable?"); @@ -830,14 +946,13 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, InsertBot->setName(Header->getName() + ".peel.next"); NewPreHeader->setName(PreHeader->getName() + ".peel.newph"); - ValueToValueMapTy LVMap; + Instruction *LatchTerm = + cast<Instruction>(cast<BasicBlock>(Latch)->getTerminator()); // If we have branch weight information, we'll want to update it for the // newly created branches. - BranchInst *LatchBR = - cast<BranchInst>(cast<BasicBlock>(Latch)->getTerminator()); - uint64_t ExitWeight = 0, FallThroughWeight = 0; - initBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight); + DenseMap<Instruction *, WeightInfo> Weights; + initBranchWeights(Weights, L); // Identify what noalias metadata is inside the loop: if it is inside the // loop, the associated metadata must be cloned for each iteration. @@ -866,19 +981,22 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, assert(DT.verify(DominatorTree::VerificationLevel::Fast)); #endif - auto *LatchBRCopy = cast<BranchInst>(VMap[LatchBR]); - updateBranchWeights(InsertBot, LatchBRCopy, ExitWeight, FallThroughWeight); + for (auto &[Term, Info] : Weights) { + auto *TermCopy = cast<Instruction>(VMap[Term]); + updateBranchWeights(TermCopy, Info); + } + // Remove Loop metadata from the latch branch instruction // because it is not the Loop's latch branch anymore. - LatchBRCopy->setMetadata(LLVMContext::MD_loop, nullptr); + auto *LatchTermCopy = cast<Instruction>(VMap[LatchTerm]); + LatchTermCopy->setMetadata(LLVMContext::MD_loop, nullptr); InsertTop = InsertBot; InsertBot = SplitBlock(InsertBot, InsertBot->getTerminator(), &DT, LI); InsertBot->setName(Header->getName() + ".peel.next"); - F->getBasicBlockList().splice(InsertTop->getIterator(), - F->getBasicBlockList(), - NewBlocks[0]->getIterator(), F->end()); + F->splice(InsertTop->getIterator(), F, NewBlocks[0]->getIterator(), + F->end()); } // Now adjust the phi nodes in the loop header to get their initial values @@ -893,7 +1011,8 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, PHI->setIncomingValueForBlock(NewPreHeader, NewVal); } - fixupBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight); + for (const auto &[Term, Info] : Weights) + fixupBranchWeights(Term, Info); // Update Metadata for count of peeled off iterations. unsigned AlreadyPeeled = 0; diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index 597c88ad13df..1a9eaf242190 100644 --- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -316,7 +316,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { L->dump()); return Rotated; } - if (*Metrics.NumInsts.getValue() > MaxHeaderSize) { + if (Metrics.NumInsts > MaxHeaderSize) { LLVM_DEBUG(dbgs() << "LoopRotation: NOT rotating - contains " << Metrics.NumInsts << " instructions, which is more than the threshold (" @@ -345,8 +345,14 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // all outer loops because insertion and deletion of blocks that happens // during the rotation may violate invariants related to backedge taken // infos in them. 
- if (SE)
+ if (SE) {
 SE->forgetTopmostLoop(L);
+ // We may hoist some instructions out of the loop. If they were cached
+ // as "loop variant" or "loop computable", these caches must be dropped.
+ // We may also fold basic blocks, so cached block dispositions also need
+ // to be dropped.
+ SE->forgetBlockAndLoopDispositions();
+ }
 
 LLVM_DEBUG(dbgs() << "LoopRotation: rotating "; L->dump());
 if (MSSAU && VerifyMemorySSA)
@@ -713,7 +719,7 @@ static bool shouldSpeculateInstrs(BasicBlock::iterator Begin,
 if (!cast<GEPOperator>(I)->hasAllConstantIndices())
 return false;
 // fall-thru to increment case
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
 case Instruction::Add:
 case Instruction::Sub:
 case Instruction::And:
@@ -789,6 +795,11 @@ bool LoopRotate::simplifyLoopLatch(Loop *L) {
 MergeBlockIntoPredecessor(Latch, &DTU, LI, MSSAU, nullptr,
 /*PredecessorWithTwoSuccessors=*/true);
 
+ if (SE) {
+ // Merging blocks may remove blocks referenced in the block disposition cache. Clear the cache.
+ SE->forgetBlockAndLoopDispositions();
+ }
+
 if (MSSAU && VerifyMemorySSA)
 MSSAU->getMemorySSA()->verifyMemorySSA();
 
diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
index 2ff8a3f7b228..87a0e54e2704 100644
--- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp
+++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp
@@ -231,7 +231,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader,
 // a function call is present until a better alternative becomes
 // available. This is similar to the conservative treatment of
 // convergent function calls in GVNHoist and JumpThreading.
- for (auto BB : L->blocks()) {
+ for (auto *BB : L->blocks()) {
 for (auto &II : *BB) {
 if (auto CI = dyn_cast<CallBase>(&II)) {
 if (CI->isConvergent()) {
@@ -392,7 +392,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader,
 
 // Move the new backedge block to right after the last backedge block.
 Function::iterator InsertPos = ++BackedgeBlocks.back()->getIterator();
- F->getBasicBlockList().splice(InsertPos, F->getBasicBlockList(), BEBlock);
+ F->splice(InsertPos, F, BEBlock->getIterator());
 
 // Now that the block has been inserted into the function, create PHI nodes in
 // the backedge block which correspond to any PHI nodes in the header block.
@@ -440,7 +440,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader,
 // eliminate the PHI Node.
 if (HasUniqueIncomingValue) {
 NewPN->replaceAllUsesWith(UniqueValue);
- BEBlock->getInstList().erase(NewPN);
+ NewPN->eraseFromParent();
 }
 }
 
@@ -450,8 +450,8 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader,
 // it from the backedge and add it to BEBlock.
 unsigned LoopMDKind = BEBlock->getContext().getMDKindID("llvm.loop");
 MDNode *LoopMD = nullptr;
- for (unsigned i = 0, e = BackedgeBlocks.size(); i != e; ++i) {
- Instruction *TI = BackedgeBlocks[i]->getTerminator();
+ for (BasicBlock *BB : BackedgeBlocks) {
+ Instruction *TI = BB->getTerminator();
 if (!LoopMD)
 LoopMD = TI->getMetadata(LoopMDKind);
 TI->setMetadata(LoopMDKind, nullptr);
@@ -649,18 +649,13 @@ ReprocessLoop:
 continue;
 if (!L->makeLoopInvariant(
 Inst, AnyInvariant,
- Preheader ? Preheader->getTerminator() : nullptr, MSSAU)) {
+ Preheader ? Preheader->getTerminator() : nullptr, MSSAU, SE)) {
 AllInvariant = false;
 break;
 }
 }
- if (AnyInvariant) {
+ if (AnyInvariant)
 Changed = true;
- // The loop disposition of all SCEV expressions that depend on any
- // hoisted values have also changed. 
- if (SE) - SE->forgetLoopDispositions(L); - } if (!AllInvariant) continue; // The block has now been cleared of all instructions except for diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp index 1be1082002fc..e8f585b4a94d 100644 --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -17,7 +17,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" @@ -66,6 +65,7 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <assert.h> +#include <numeric> #include <type_traits> #include <vector> @@ -321,6 +321,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, unsigned TripMultiple; unsigned BreakoutTrip; bool ExitOnTrue; + BasicBlock *FirstExitingBlock = nullptr; SmallVector<BasicBlock *> ExitingBlocks; }; DenseMap<BasicBlock *, ExitInfo> ExitInfos; @@ -341,7 +342,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, Info.TripMultiple = 0; } else { Info.BreakoutTrip = Info.TripMultiple = - (unsigned)GreatestCommonDivisor64(ULO.Count, Info.TripMultiple); + (unsigned)std::gcd(ULO.Count, Info.TripMultiple); } Info.ExitOnTrue = !L->contains(BI->getSuccessor(0)); Info.ExitingBlocks.push_back(ExitingBlock); @@ -464,8 +465,10 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (SE) { if (ULO.ForgetAllSCEV) SE->forgetAllLoops(); - else + else { SE->forgetTopmostLoop(L); + SE->forgetBlockAndLoopDispositions(); + } } if (!LatchIsExiting) @@ -506,7 +509,8 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, // When a FSDiscriminator is enabled, we don't need to add the multiply // factors to the discriminators. - if (Header->getParent()->isDebugInfoForProfiling() && !EnableFSDiscriminator) + if (Header->getParent()->shouldEmitDebugInfoForProfiling() && + !EnableFSDiscriminator) for (BasicBlock *BB : L->getBlocks()) for (Instruction &I : *BB) if (!isa<DbgInfoIntrinsic>(&I)) @@ -537,7 +541,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { ValueToValueMapTy VMap; BasicBlock *New = CloneBasicBlock(*BB, VMap, "." 
+ Twine(It)); - Header->getParent()->getBasicBlockList().insert(BlockInsertPt, New); + Header->getParent()->insert(BlockInsertPt, New); assert((*BB != Header || LI->getLoopFor(*BB) == L) && "Header should not be in a sub-loop"); @@ -556,7 +560,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (It > 1 && L->contains(InValI)) InVal = LastValueMap[InValI]; VMap[OrigPHI] = InVal; - New->getInstList().erase(NewPHI); + NewPHI->eraseFromParent(); } // Update our running map of newest clones @@ -575,6 +579,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (It != LastValueMap.end()) Incoming = It->second; PHI.addIncoming(Incoming, New); + SE->forgetValue(&PHI); } } // Keep track of new headers and latches as we create them, so that @@ -629,7 +634,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, for (PHINode *PN : OrigPHINode) { if (CompletelyUnroll) { PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader)); - Header->getInstList().erase(PN); + PN->eraseFromParent(); } else if (ULO.Count > 1) { Value *InVal = PN->removeIncomingValue(LatchBlock, false); // If this value was defined in the loop, take the value defined by the @@ -676,8 +681,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, assert(!UnrollVerifyDomtree || DT->verify(DominatorTree::VerificationLevel::Fast)); - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); - + SmallVector<DominatorTree::UpdateType> DTUpdates; auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) { auto *Term = cast<BranchInst>(Src->getTerminator()); const unsigned Idx = ExitOnTrue ^ WillExit; @@ -691,15 +695,15 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, BranchInst::Create(Dest, Term); Term->eraseFromParent(); - DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}}); + DTUpdates.emplace_back(DominatorTree::Delete, Src, DeadSucc); }; auto WillExit = [&](const ExitInfo &Info, unsigned i, unsigned j, - bool IsLatch) -> Optional<bool> { + bool IsLatch) -> std::optional<bool> { if (CompletelyUnroll) { if (PreserveOnlyFirst) { if (i == 0) - return None; + return std::nullopt; return j == 0; } // Complete (but possibly inexact) unrolling @@ -707,7 +711,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, return true; if (Info.TripCount && j != Info.TripCount) return false; - return None; + return std::nullopt; } if (ULO.Runtime) { @@ -715,7 +719,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, // exits may be stale. if (IsLatch && j != 0) return false; - return None; + return std::nullopt; } if (j != Info.BreakoutTrip && @@ -724,36 +728,69 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, // unconditional branch for some iterations. return false; } - return None; + return std::nullopt; }; // Fold branches for iterations where we know that they will exit or not // exit. - for (const auto &Pair : ExitInfos) { - const ExitInfo &Info = Pair.second; + for (auto &Pair : ExitInfos) { + ExitInfo &Info = Pair.second; for (unsigned i = 0, e = Info.ExitingBlocks.size(); i != e; ++i) { // The branch destination. 
unsigned j = (i + 1) % e; bool IsLatch = Pair.first == LatchBlock; - Optional<bool> KnownWillExit = WillExit(Info, i, j, IsLatch); - if (!KnownWillExit) + std::optional<bool> KnownWillExit = WillExit(Info, i, j, IsLatch); + if (!KnownWillExit) { + if (!Info.FirstExitingBlock) + Info.FirstExitingBlock = Info.ExitingBlocks[i]; continue; + } // We don't fold known-exiting branches for non-latch exits here, // because this ensures that both all loop blocks and all exit blocks // remain reachable in the CFG. // TODO: We could fold these branches, but it would require much more // sophisticated updates to LoopInfo. - if (*KnownWillExit && !IsLatch) + if (*KnownWillExit && !IsLatch) { + if (!Info.FirstExitingBlock) + Info.FirstExitingBlock = Info.ExitingBlocks[i]; continue; + } SetDest(Info.ExitingBlocks[i], *KnownWillExit, Info.ExitOnTrue); } } + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + DomTreeUpdater *DTUToUse = &DTU; + if (ExitingBlocks.size() == 1 && ExitInfos.size() == 1) { + // Manually update the DT if there's a single exiting node. In that case + // there's a single exit node and it is sufficient to update the nodes + // immediately dominated by the original exiting block. They will become + // dominated by the first exiting block that leaves the loop after + // unrolling. Note that the CFG inside the loop does not change, so there's + // no need to update the DT inside the unrolled loop. + DTUToUse = nullptr; + auto &[OriginalExit, Info] = *ExitInfos.begin(); + if (!Info.FirstExitingBlock) + Info.FirstExitingBlock = Info.ExitingBlocks.back(); + for (auto *C : to_vector(DT->getNode(OriginalExit)->children())) { + if (L->contains(C->getBlock())) + continue; + C->setIDom(DT->getNode(Info.FirstExitingBlock)); + } + } else { + DTU.applyUpdates(DTUpdates); + } + // When completely unrolling, the last latch becomes unreachable. - if (!LatchIsExiting && CompletelyUnroll) - changeToUnreachable(Latches.back()->getTerminator(), PreserveLCSSA, &DTU); + if (!LatchIsExiting && CompletelyUnroll) { + // There is no need to update the DT here, because there must be a unique + // latch. Hence if the latch is not exiting it must directly branch back to + // the original loop header and does not dominate any nodes. + assert(LatchBlock->getSingleSuccessor() && "Loop with multiple latches?"); + changeToUnreachable(Latches.back()->getTerminator(), PreserveLCSSA); + } // Merge adjacent basic blocks, if possible. for (BasicBlock *Latch : Latches) { @@ -765,16 +802,21 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (Term && Term->isUnconditional()) { BasicBlock *Dest = Term->getSuccessor(0); BasicBlock *Fold = Dest->getUniquePredecessor(); - if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) { + if (MergeBlockIntoPredecessor(Dest, /*DTU=*/DTUToUse, LI, + /*MSSAU=*/nullptr, /*MemDep=*/nullptr, + /*PredecessorWithTwoSuccessors=*/false, + DTUToUse ? nullptr : DT)) { // Dest has been folded into Fold. Update our worklists accordingly. std::replace(Latches.begin(), Latches.end(), Dest, Fold); llvm::erase_value(UnrolledLoopBlocks, Dest); } } } - // Apply updates to the DomTree. - DT = &DTU.getDomTree(); + if (DTUToUse) { + // Apply updates to the DomTree. 
+ DT = &DTU.getDomTree(); + } assert(!UnrollVerifyDomtree || DT->verify(DominatorTree::VerificationLevel::Fast)); diff --git a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp index 96485d15c75b..b125e952ec94 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -13,7 +13,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -138,25 +137,28 @@ static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop, template <typename T> static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch, BasicBlockSet &AftBlocks, T Visit) { - SmallVector<Instruction *, 8> Worklist; SmallPtrSet<Instruction *, 8> VisitedInstr; - for (auto &Phi : Header->phis()) { - Value *V = Phi.getIncomingValueForBlock(Latch); - if (Instruction *I = dyn_cast<Instruction>(V)) - Worklist.push_back(I); - } - while (!Worklist.empty()) { - Instruction *I = Worklist.pop_back_val(); - if (!Visit(I)) - return false; + std::function<bool(Instruction * I)> ProcessInstr = [&](Instruction *I) { + if (VisitedInstr.count(I)) + return true; + VisitedInstr.insert(I); if (AftBlocks.count(I->getParent())) for (auto &U : I->operands()) if (Instruction *II = dyn_cast<Instruction>(U)) - if (!VisitedInstr.count(II)) - Worklist.push_back(II); + if (!ProcessInstr(II)) + return false; + + return Visit(I); + }; + + for (auto &Phi : Header->phis()) { + Value *V = Phi.getIncomingValueForBlock(Latch); + if (Instruction *I = dyn_cast<Instruction>(V)) + if (!ProcessInstr(I)) + return false; } return true; @@ -169,20 +171,12 @@ static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header, BasicBlockSet &AftBlocks) { // We need to ensure we move the instructions in the correct order, // starting with the earliest required instruction and moving forward. - std::vector<Instruction *> Visited; processHeaderPhiOperands(Header, Latch, AftBlocks, - [&Visited, &AftBlocks](Instruction *I) { + [&AftBlocks, &InsertLoc](Instruction *I) { if (AftBlocks.count(I->getParent())) - Visited.push_back(I); + I->moveBefore(InsertLoc); return true; }); - - // Move all instructions in program order to before the InsertLoc - BasicBlock *InsertLocBB = InsertLoc->getParent(); - for (Instruction *I : reverse(Visited)) { - if (I->getParent() != InsertLocBB) - I->moveBefore(InsertLoc); - } } /* @@ -261,7 +255,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, // if not outright eliminated. if (SE) { SE->forgetLoop(L); - SE->forgetLoop(SubLoop); + SE->forgetBlockAndLoopDispositions(); } using namespace ore; @@ -349,7 +343,8 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, // When a FSDiscriminator is enabled, we don't need to add the multiply // factors to the discriminators. - if (Header->getParent()->isDebugInfoForProfiling() && !EnableFSDiscriminator) + if (Header->getParent()->shouldEmitDebugInfoForProfiling() && + !EnableFSDiscriminator) for (BasicBlock *BB : L->getBlocks()) for (Instruction &I : *BB) if (!isa<DbgInfoIntrinsic>(&I)) @@ -375,7 +370,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { ValueToValueMapTy VMap; BasicBlock *New = CloneBasicBlock(*BB, VMap, "." 
+ Twine(It)); - Header->getParent()->getBasicBlockList().push_back(New); + Header->getParent()->insert(Header->getParent()->end(), New); // Tell LI about New. addClonedBlockToLoopInfo(*BB, New, LI, NewLoops); @@ -497,7 +492,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, if (CompletelyUnroll) { while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) { Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader)); - Phi->getParent()->getInstList().erase(Phi); + Phi->eraseFromParent(); } } else { // Update the PHI values to point to the last aft block diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 023a0afd329b..b19156bcb420 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -30,6 +30,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -217,7 +218,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, for (PHINode &PN : NewExit->phis()) { // PN should be used in another PHI located in Exit block as // Exit was split by SplitBlockPredecessors into Exit and NewExit - // Basicaly it should look like: + // Basically it should look like: // NewExit: // PN = PHI [I, Latch] // ... @@ -399,10 +400,10 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder, if (UnrollRemainder) return NewLoop; - Optional<MDNode *> NewLoopID = makeFollowupLoopID( + std::optional<MDNode *> NewLoopID = makeFollowupLoopID( LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); if (NewLoopID) { - NewLoop->setLoopID(NewLoopID.value()); + NewLoop->setLoopID(*NewLoopID); // Do not setLoopAlreadyUnrolled if loop attributes have been defined // explicitly. @@ -471,7 +472,7 @@ static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop, uint64_t TrueWeight, FalseWeight; BranchInst *LatchBR = cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator()); - if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) + if (!extractBranchWeights(*LatchBR, TrueWeight, FalseWeight)) return; uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader() ? FalseWeight @@ -811,10 +812,7 @@ bool llvm::UnrollRuntimeLoopRemainder( updateLatchBranchWeightsForRemainderLoop(L, remainderLoop, Count); // Insert the cloned blocks into the function. - F->getBasicBlockList().splice(InsertBot->getIterator(), - F->getBasicBlockList(), - NewBlocks[0]->getIterator(), - F->end()); + F->splice(InsertBot->getIterator(), F, NewBlocks[0]->getIterator(), F->end()); // Now the loop blocks are cloned and the other exiting blocks from the // remainder are connected to the original Loop's exit blocks. 
The remaining diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 349063dd5e89..7df8651ede15 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -12,7 +12,6 @@ #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" @@ -38,6 +37,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -246,27 +246,27 @@ void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD, TheLoop->setLoopID(NewLoopID); } -Optional<ElementCount> +std::optional<ElementCount> llvm::getOptionalElementCountLoopAttribute(const Loop *TheLoop) { - Optional<int> Width = + std::optional<int> Width = getOptionalIntLoopAttribute(TheLoop, "llvm.loop.vectorize.width"); if (Width) { - Optional<int> IsScalable = getOptionalIntLoopAttribute( + std::optional<int> IsScalable = getOptionalIntLoopAttribute( TheLoop, "llvm.loop.vectorize.scalable.enable"); return ElementCount::get(*Width, IsScalable.value_or(false)); } - return None; + return std::nullopt; } -Optional<MDNode *> llvm::makeFollowupLoopID( +std::optional<MDNode *> llvm::makeFollowupLoopID( MDNode *OrigLoopID, ArrayRef<StringRef> FollowupOptions, const char *InheritOptionsExceptPrefix, bool AlwaysNew) { if (!OrigLoopID) { if (AlwaysNew) return nullptr; - return None; + return std::nullopt; } assert(OrigLoopID->getOperand(0) == OrigLoopID); @@ -325,7 +325,7 @@ Optional<MDNode *> llvm::makeFollowupLoopID( // Attributes of the followup loop not specified explicity, so signal to the // transformation pass to add suitable attributes. if (!AlwaysNew && !HasAnyFollowup) - return None; + return std::nullopt; // If no attributes were added or remove, the previous loop Id can be reused. if (!AlwaysNew && !Changed) @@ -353,10 +353,10 @@ TransformationMode llvm::hasUnrollTransformation(const Loop *L) { if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable")) return TM_SuppressedByUser; - Optional<int> Count = + std::optional<int> Count = getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count"); if (Count) - return Count.value() == 1 ? TM_SuppressedByUser : TM_ForcedByUser; + return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser; if (getBooleanLoopAttribute(L, "llvm.loop.unroll.enable")) return TM_ForcedByUser; @@ -374,10 +374,10 @@ TransformationMode llvm::hasUnrollAndJamTransformation(const Loop *L) { if (getBooleanLoopAttribute(L, "llvm.loop.unroll_and_jam.disable")) return TM_SuppressedByUser; - Optional<int> Count = + std::optional<int> Count = getOptionalIntLoopAttribute(L, "llvm.loop.unroll_and_jam.count"); if (Count) - return Count.value() == 1 ? TM_SuppressedByUser : TM_ForcedByUser; + return *Count == 1 ? 
TM_SuppressedByUser : TM_ForcedByUser; if (getBooleanLoopAttribute(L, "llvm.loop.unroll_and_jam.enable")) return TM_ForcedByUser; @@ -389,15 +389,15 @@ TransformationMode llvm::hasUnrollAndJamTransformation(const Loop *L) { } TransformationMode llvm::hasVectorizeTransformation(const Loop *L) { - Optional<bool> Enable = + std::optional<bool> Enable = getOptionalBoolLoopAttribute(L, "llvm.loop.vectorize.enable"); if (Enable == false) return TM_SuppressedByUser; - Optional<ElementCount> VectorizeWidth = + std::optional<ElementCount> VectorizeWidth = getOptionalElementCountLoopAttribute(L); - Optional<int> InterleaveCount = + std::optional<int> InterleaveCount = getOptionalIntLoopAttribute(L, "llvm.loop.interleave.count"); // 'Forcing' vector width and interleave count to one effectively disables @@ -485,8 +485,10 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // Tell ScalarEvolution that the loop is deleted. Do this before // deleting the loop so that ScalarEvolution can look at the loop // to determine what it needs to clean up. - if (SE) + if (SE) { SE->forgetLoop(L); + SE->forgetBlockAndLoopDispositions(); + } Instruction *OldTerm = Preheader->getTerminator(); assert(!OldTerm->mayHaveSideEffects() && @@ -591,7 +593,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, } // Use a map to unique and a vector to guarantee deterministic ordering. - llvm::SmallDenseSet<std::pair<DIVariable *, DIExpression *>, 4> DeadDebugSet; + llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet; llvm::SmallVector<DbgVariableIntrinsic *, 4> DeadDebugInst; if (ExitBlock) { @@ -620,11 +622,8 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I); if (!DVI) continue; - auto Key = - DeadDebugSet.find({DVI->getVariable(), DVI->getExpression()}); - if (Key != DeadDebugSet.end()) + if (!DeadDebugSet.insert(DebugVariable(DVI)).second) continue; - DeadDebugSet.insert({DVI->getVariable(), DVI->getExpression()}); DeadDebugInst.push_back(DVI); } @@ -633,15 +632,14 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // Since debug values in the loop have been deleted, inserting an undef // dbg.value truncates the range of any dbg.value before the loop where the // loop used to be. This is particularly important for constant values. - DIBuilder DIB(*ExitBlock->getModule()); Instruction *InsertDbgValueBefore = ExitBlock->getFirstNonPHI(); assert(InsertDbgValueBefore && "There should be a non-PHI instruction in exit block, else these " "instructions will have no parent."); - for (auto *DVI : DeadDebugInst) - DIB.insertDbgValueIntrinsic(UndefValue::get(Builder.getInt32Ty()), - DVI->getVariable(), DVI->getExpression(), - DVI->getDebugLoc(), InsertDbgValueBefore); + for (auto *DVI : DeadDebugInst) { + DVI->setKillLocation(); + DVI->moveBefore(InsertDbgValueBefore); + } } // Remove the block from the reference counting scheme, so that we can @@ -693,6 +691,7 @@ void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE, Loop *OutermostLoop = L->getOutermostLoop(); SE.forgetLoop(L); + SE.forgetBlockAndLoopDispositions(); std::unique_ptr<MemorySSAUpdater> MSSAU; if (MSSA) @@ -782,22 +781,22 @@ static BranchInst *getExpectedExitLoopLatchBranch(Loop *L) { /// Return the estimated trip count for any exiting branch which dominates /// the loop latch. 
-static Optional<uint64_t> -getEstimatedTripCount(BranchInst *ExitingBranch, Loop *L, - uint64_t &OrigExitWeight) { +static std::optional<uint64_t> getEstimatedTripCount(BranchInst *ExitingBranch, + Loop *L, + uint64_t &OrigExitWeight) { // To estimate the number of times the loop body was executed, we want to // know the number of times the backedge was taken, vs. the number of times // we exited the loop. uint64_t LoopWeight, ExitWeight; - if (!ExitingBranch->extractProfMetadata(LoopWeight, ExitWeight)) - return None; + if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) + return std::nullopt; if (L->contains(ExitingBranch->getSuccessor(1))) std::swap(LoopWeight, ExitWeight); if (!ExitWeight) // Don't have a way to return predicated infinite - return None; + return std::nullopt; OrigExitWeight = ExitWeight; @@ -808,7 +807,7 @@ getEstimatedTripCount(BranchInst *ExitingBranch, Loop *L, return ExitCount + 1; } -Optional<unsigned> +std::optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L, unsigned *EstimatedLoopInvocationWeight) { // Currently we take the estimate exit count only from the loop latch, @@ -817,14 +816,14 @@ llvm::getLoopEstimatedTripCount(Loop *L, // TODO: incorporate information from other exits if (BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L)) { uint64_t ExitWeight; - if (Optional<uint64_t> EstTripCount = - getEstimatedTripCount(LatchBranch, L, ExitWeight)) { + if (std::optional<uint64_t> EstTripCount = + getEstimatedTripCount(LatchBranch, L, ExitWeight)) { if (EstimatedLoopInvocationWeight) *EstimatedLoopInvocationWeight = ExitWeight; return *EstTripCount; } } - return None; + return std::nullopt; } bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount, @@ -1165,7 +1164,7 @@ static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) { if (Curr->mayHaveSideEffects()) return true; // Otherwise, add all its users to worklist. - for (auto U : Curr->users()) { + for (const auto *U : Curr->users()) { auto *UI = cast<Instruction>(U); if (Visited.insert(UI).second) WorkList.push_back(UI); @@ -1394,7 +1393,10 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI, // and next SCEV may errneously get smaller cost. // Collect all the candidate PHINodes to be rewritten. - RewritePhiSet.emplace_back(PN, i, ExitValue, Inst, HighCost); + Instruction *InsertPt = + (isa<PHINode>(Inst) || isa<LandingPadInst>(Inst)) ? + &*Inst->getParent()->getFirstInsertionPt() : Inst; + RewritePhiSet.emplace_back(PN, i, ExitValue, InsertPt, HighCost); } } } @@ -1474,7 +1476,7 @@ void llvm::setProfileInfoAfterUnrolling(Loop *OrigLoop, Loop *UnrolledLoop, // Get number of iterations in the original scalar loop. unsigned OrigLoopInvocationWeight = 0; - Optional<unsigned> OrigAverageTripCount = + std::optional<unsigned> OrigAverageTripCount = getLoopEstimatedTripCount(OrigLoop, &OrigLoopInvocationWeight); if (!OrigAverageTripCount) return; @@ -1664,8 +1666,7 @@ Value *llvm::addRuntimeChecks( } Value *llvm::addDiffRuntimeChecks( - Instruction *Loc, Loop *TheLoop, ArrayRef<PointerDiffInfo> Checks, - SCEVExpander &Expander, + Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander, function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) { LLVMContext &Ctx = Loc->getContext(); @@ -1675,7 +1676,7 @@ Value *llvm::addDiffRuntimeChecks( // Our instructions might fold to a constant. 
Value *MemoryRuntimeCheck = nullptr; - for (auto &C : Checks) { + for (const auto &C : Checks) { Type *Ty = C.SinkStart->getType(); // Compute VF * IC * AccessSize. auto *VFTimesUFTimesSize = @@ -1702,10 +1703,9 @@ Value *llvm::addDiffRuntimeChecks( return MemoryRuntimeCheck; } -Optional<IVConditionInfo> llvm::hasPartialIVCondition(Loop &L, - unsigned MSSAThreshold, - MemorySSA &MSSA, - AAResults &AA) { +std::optional<IVConditionInfo> +llvm::hasPartialIVCondition(const Loop &L, unsigned MSSAThreshold, + const MemorySSA &MSSA, AAResults &AA) { auto *TI = dyn_cast<BranchInst>(L.getHeader()->getTerminator()); if (!TI || !TI->isConditional()) return {}; @@ -1762,7 +1762,7 @@ Optional<IVConditionInfo> llvm::hasPartialIVCondition(Loop &L, [&L, &AA, &AccessedLocs, &ExitingBlocks, &InstToDuplicate, MSSAThreshold](BasicBlock *Succ, BasicBlock *Header, SmallVector<MemoryAccess *, 4> AccessesToCheck) - -> Optional<IVConditionInfo> { + -> std::optional<IVConditionInfo> { IVConditionInfo Info; // First, collect all blocks in the loop that are on a patch from Succ // to the header. @@ -1840,7 +1840,7 @@ Optional<IVConditionInfo> llvm::hasPartialIVCondition(Loop &L, if (L.contains(Succ)) continue; - Info.PathIsNoop &= llvm::empty(Succ->phis()) && + Info.PathIsNoop &= Succ->phis().empty() && (!Info.ExitForPath || Info.ExitForPath == Succ); if (!Info.PathIsNoop) break; diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index 97f29527bb95..17e71cf5a6c4 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -137,8 +137,10 @@ void LoopVersioning::addPHINodes( // See if we have a single-operand PHI with the value defined by the // original loop. for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) { - if (PN->getIncomingValue(0) == Inst) + if (PN->getIncomingValue(0) == Inst) { + SE->forgetValue(PN); break; + } } // If not create it. if (!PN) { @@ -254,8 +256,8 @@ void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst, } namespace { -bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA, - DominatorTree *DT, ScalarEvolution *SE) { +bool runImpl(LoopInfo *LI, LoopAccessInfoManager &LAIs, DominatorTree *DT, + ScalarEvolution *SE) { // Build up a worklist of inner-loops to version. This is necessary as the // act of versioning a loop creates new loops and can invalidate iterators // across the loops. 
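A short note on the getEstimatedTripCount hunk above: the change only swaps the accessor used to read the latch branch's !prof metadata (extractBranchWeights from llvm/IR/ProfDataUtils.h instead of the member extractProfMetadata); the trip-count arithmetic itself is unchanged. A minimal sketch of that arithmetic, using hypothetical weights that are not taken from the patch:

  // Hypothetical latch weights: the backedge was taken 300 times for every
  // 10 loop exits. getEstimatedTripCount normalizes the pair so that
  // LoopWeight is the backedge weight and ExitWeight the exit weight.
  uint64_t LoopWeight = 300, ExitWeight = 10;
  uint64_t ExitCount = LoopWeight / ExitWeight; // 30 backedges per loop entry
  uint64_t EstimatedTripCount = ExitCount + 1;  // estimate: 31 iterations

As the surrounding hunk shows, getLoopEstimatedTripCount then reports ExitWeight through its EstimatedLoopInvocationWeight out-parameter.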
@@ -273,7 +275,7 @@ bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA, if (!L->isLoopSimplifyForm() || !L->isRotatedForm() || !L->getExitingBlock()) continue; - const LoopAccessInfo &LAI = GetLAA(*L); + const LoopAccessInfo &LAI = LAIs.getInfo(*L); if (!LAI.hasConvergentOp() && (LAI.getNumRuntimePointerChecks() || !LAI.getPSE().getPredicate().isAlwaysTrue())) { @@ -282,6 +284,7 @@ bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA, LVer.versionLoop(); LVer.annotateLoopWithNoAlias(); Changed = true; + LAIs.clear(); } } @@ -299,14 +302,11 @@ public: bool runOnFunction(Function &F) override { auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - return getAnalysis<LoopAccessLegacyAnalysis>().getInfo(&L); - }; - + auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - return runImpl(LI, GetLAA, DT, SE); + return runImpl(LI, LAIs, DT, SE); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -346,20 +346,10 @@ PreservedAnalyses LoopVersioningPass::run(Function &F, FunctionAnalysisManager &AM) { auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); - auto &TTI = AM.getResult<TargetIRAnalysis>(F); + LoopAccessInfoManager &LAIs = AM.getResult<LoopAccessAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - auto &AA = AM.getResult<AAManager>(F); - auto &AC = AM.getResult<AssumptionAnalysis>(F); - - auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); - auto GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, - TLI, TTI, nullptr, nullptr, nullptr}; - return LAM.getResult<LoopAccessAnalysis>(L, AR); - }; - - if (runImpl(&LI, GetLAA, &DT, &SE)) + + if (runImpl(&LI, LAIs, &DT, &SE)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } diff --git a/llvm/lib/Transforms/Utils/LowerAtomic.cpp b/llvm/lib/Transforms/Utils/LowerAtomic.cpp index 2247b8107739..b6f40de0daa6 100644 --- a/llvm/lib/Transforms/Utils/LowerAtomic.cpp +++ b/llvm/lib/Transforms/Utils/LowerAtomic.cpp @@ -41,43 +41,60 @@ bool llvm::lowerAtomicCmpXchgInst(AtomicCmpXchgInst *CXI) { Value *llvm::buildAtomicRMWValue(AtomicRMWInst::BinOp Op, IRBuilderBase &Builder, Value *Loaded, - Value *Inc) { + Value *Val) { Value *NewVal; switch (Op) { case AtomicRMWInst::Xchg: - return Inc; + return Val; case AtomicRMWInst::Add: - return Builder.CreateAdd(Loaded, Inc, "new"); + return Builder.CreateAdd(Loaded, Val, "new"); case AtomicRMWInst::Sub: - return Builder.CreateSub(Loaded, Inc, "new"); + return Builder.CreateSub(Loaded, Val, "new"); case AtomicRMWInst::And: - return Builder.CreateAnd(Loaded, Inc, "new"); + return Builder.CreateAnd(Loaded, Val, "new"); case AtomicRMWInst::Nand: - return Builder.CreateNot(Builder.CreateAnd(Loaded, Inc), "new"); + return Builder.CreateNot(Builder.CreateAnd(Loaded, Val), "new"); case AtomicRMWInst::Or: - return Builder.CreateOr(Loaded, Inc, "new"); + return Builder.CreateOr(Loaded, Val, "new"); case AtomicRMWInst::Xor: - return Builder.CreateXor(Loaded, Inc, "new"); + return Builder.CreateXor(Loaded, Val, "new"); case AtomicRMWInst::Max: - NewVal = Builder.CreateICmpSGT(Loaded, Inc); - return Builder.CreateSelect(NewVal, Loaded, Inc, "new"); + NewVal = 
Builder.CreateICmpSGT(Loaded, Val); + return Builder.CreateSelect(NewVal, Loaded, Val, "new"); case AtomicRMWInst::Min: - NewVal = Builder.CreateICmpSLE(Loaded, Inc); - return Builder.CreateSelect(NewVal, Loaded, Inc, "new"); + NewVal = Builder.CreateICmpSLE(Loaded, Val); + return Builder.CreateSelect(NewVal, Loaded, Val, "new"); case AtomicRMWInst::UMax: - NewVal = Builder.CreateICmpUGT(Loaded, Inc); - return Builder.CreateSelect(NewVal, Loaded, Inc, "new"); + NewVal = Builder.CreateICmpUGT(Loaded, Val); + return Builder.CreateSelect(NewVal, Loaded, Val, "new"); case AtomicRMWInst::UMin: - NewVal = Builder.CreateICmpULE(Loaded, Inc); - return Builder.CreateSelect(NewVal, Loaded, Inc, "new"); + NewVal = Builder.CreateICmpULE(Loaded, Val); + return Builder.CreateSelect(NewVal, Loaded, Val, "new"); case AtomicRMWInst::FAdd: - return Builder.CreateFAdd(Loaded, Inc, "new"); + return Builder.CreateFAdd(Loaded, Val, "new"); case AtomicRMWInst::FSub: - return Builder.CreateFSub(Loaded, Inc, "new"); + return Builder.CreateFSub(Loaded, Val, "new"); case AtomicRMWInst::FMax: - return Builder.CreateMaxNum(Loaded, Inc); + return Builder.CreateMaxNum(Loaded, Val); case AtomicRMWInst::FMin: - return Builder.CreateMinNum(Loaded, Inc); + return Builder.CreateMinNum(Loaded, Val); + case AtomicRMWInst::UIncWrap: { + Constant *One = ConstantInt::get(Loaded->getType(), 1); + Value *Inc = Builder.CreateAdd(Loaded, One); + Value *Cmp = Builder.CreateICmpUGE(Loaded, Val); + Constant *Zero = ConstantInt::get(Loaded->getType(), 0); + return Builder.CreateSelect(Cmp, Zero, Inc, "new"); + } + case AtomicRMWInst::UDecWrap: { + Constant *Zero = ConstantInt::get(Loaded->getType(), 0); + Constant *One = ConstantInt::get(Loaded->getType(), 1); + + Value *Dec = Builder.CreateSub(Loaded, One); + Value *CmpEq0 = Builder.CreateICmpEQ(Loaded, Zero); + Value *CmpOldGtVal = Builder.CreateICmpUGT(Loaded, Val); + Value *Or = Builder.CreateOr(CmpEq0, CmpOldGtVal); + return Builder.CreateSelect(Or, Val, Dec, "new"); + } default: llvm_unreachable("Unknown atomic op"); } diff --git a/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp b/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp index 010deb77a883..195c274ff18e 100644 --- a/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp +++ b/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp @@ -175,7 +175,7 @@ static bool runImpl(Module &M) { FunctionType *VoidVoid = FunctionType::get(Type::getVoidTy(C), /*isVarArg=*/false); - for (auto Dtor : reverse(AssociatedAndMore.second)) + for (auto *Dtor : reverse(AssociatedAndMore.second)) CallInst::Create(VoidVoid, Dtor, "", BB); ReturnInst::Create(C, BB); diff --git a/llvm/lib/Transforms/Utils/LowerIFunc.cpp b/llvm/lib/Transforms/Utils/LowerIFunc.cpp new file mode 100644 index 000000000000..18ae0bbe2e73 --- /dev/null +++ b/llvm/lib/Transforms/Utils/LowerIFunc.cpp @@ -0,0 +1,27 @@ +//===- LowerIFunc.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements replacing calls to ifuncs by introducing indirect calls. 
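The new UIncWrap and UDecWrap cases added to buildAtomicRMWValue above expand the wrapping atomic increment/decrement operations into plain compare-and-select IR. An illustrative C-style restatement of the values those selects compute (the helper names below are not part of the patch):

  // uinc_wrap: increment Old, wrapping back to 0 once Old reaches Val.
  unsigned uinc_wrap(unsigned Old, unsigned Val) {
    return Old >= Val ? 0u : Old + 1u;
  }

  // udec_wrap: decrement Old, wrapping back to Val when Old is 0 or
  // already greater than Val.
  unsigned udec_wrap(unsigned Old, unsigned Val) {
    return (Old == 0u || Old > Val) ? Val : Old - 1u;
  }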
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/LowerIFunc.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +/// Replace all call users of ifuncs in the module. +PreservedAnalyses LowerIFuncPass::run(Module &M, ModuleAnalysisManager &AM) { + if (M.ifunc_empty()) + return PreservedAnalyses::all(); + + lowerGlobalIFuncUsersAsGlobalCtor(M, {}); + return PreservedAnalyses::none(); +} diff --git a/llvm/lib/Transforms/Utils/LowerInvoke.cpp b/llvm/lib/Transforms/Utils/LowerInvoke.cpp index 59cfa41fb7fd..6d788857c1ea 100644 --- a/llvm/lib/Transforms/Utils/LowerInvoke.cpp +++ b/llvm/lib/Transforms/Utils/LowerInvoke.cpp @@ -66,7 +66,7 @@ static bool runImpl(Function &F) { II->getUnwindDest()->removePredecessor(&BB); // Remove the invoke instruction now. - BB.getInstList().erase(II); + II->eraseFromParent(); ++NumInvokes; Changed = true; diff --git a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp index b4acb1b2ae90..165740b55298 100644 --- a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -13,16 +13,15 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <optional> using namespace llvm; -void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, - Value *DstAddr, ConstantInt *CopyLen, - Align SrcAlign, Align DstAlign, - bool SrcIsVolatile, bool DstIsVolatile, - bool CanOverlap, - const TargetTransformInfo &TTI, - Optional<uint32_t> AtomicElementSize) { +void llvm::createMemCpyLoopKnownSize( + Instruction *InsertBefore, Value *SrcAddr, Value *DstAddr, + ConstantInt *CopyLen, Align SrcAlign, Align DstAlign, bool SrcIsVolatile, + bool DstIsVolatile, bool CanOverlap, const TargetTransformInfo &TTI, + std::optional<uint32_t> AtomicElementSize) { // No need to expand zero length copies. 
if (CopyLen->isZero()) return; @@ -122,11 +121,11 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, SrcAS, DstAS, SrcAlign.value(), DstAlign.value(), AtomicElementSize); - for (auto OpTy : RemainingOps) { + for (auto *OpTy : RemainingOps) { Align PartSrcAlign(commonAlignment(SrcAlign, BytesCopied)); Align PartDstAlign(commonAlignment(DstAlign, BytesCopied)); - // Calaculate the new index + // Calculate the new index unsigned OperandSize = DL.getTypeStoreSize(OpTy); assert( (!AtomicElementSize || OperandSize % *AtomicElementSize == 0) && @@ -173,13 +172,11 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, "Bytes copied should match size in the call!"); } -void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, - Value *SrcAddr, Value *DstAddr, - Value *CopyLen, Align SrcAlign, - Align DstAlign, bool SrcIsVolatile, - bool DstIsVolatile, bool CanOverlap, - const TargetTransformInfo &TTI, - Optional<uint32_t> AtomicElementSize) { +void llvm::createMemCpyLoopUnknownSize( + Instruction *InsertBefore, Value *SrcAddr, Value *DstAddr, Value *CopyLen, + Align SrcAlign, Align DstAlign, bool SrcIsVolatile, bool DstIsVolatile, + bool CanOverlap, const TargetTransformInfo &TTI, + std::optional<uint32_t> AtomicElementSize) { BasicBlock *PreLoopBB = InsertBefore->getParent(); BasicBlock *PostLoopBB = PreLoopBB->splitBasicBlock(InsertBefore, "post-loop-memcpy-expansion"); diff --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp index 44aeb26fadf9..227de425ff85 100644 --- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -51,9 +51,9 @@ using namespace llvm; namespace { - struct IntRange { - int64_t Low, High; - }; +struct IntRange { + APInt Low, High; +}; } // end anonymous namespace @@ -66,8 +66,8 @@ bool IsInRanges(const IntRange &R, const std::vector<IntRange> &Ranges) { // then check if the Low field is <= R.Low. If so, we // have a Range that covers R. auto I = llvm::lower_bound( - Ranges, R, [](IntRange A, IntRange B) { return A.High < B.High; }); - return I != Ranges.end() && I->Low <= R.Low; + Ranges, R, [](IntRange A, IntRange B) { return A.High.slt(B.High); }); + return I != Ranges.end() && I->Low.sle(R.Low); } struct CaseRange { @@ -116,15 +116,14 @@ raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) { /// 2) Removed if subsequent incoming values now share the same case, i.e., /// multiple outcome edges are condensed into one. This is necessary to keep the /// number of phi values equal to the number of branches to SuccBB. -void FixPhis( - BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, - const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) { +void FixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, + const APInt &NumMergedCases) { for (auto &I : SuccBB->phis()) { PHINode *PN = cast<PHINode>(&I); // Only update the first occurrence if NewBB exists. unsigned Idx = 0, E = PN->getNumIncomingValues(); - unsigned LocalNumMergedCases = NumMergedCases; + APInt LocalNumMergedCases = NumMergedCases; for (; Idx != E && NewBB; ++Idx) { if (PN->getIncomingBlock(Idx) == OrigBB) { PN->setIncomingBlock(Idx, NewBB); @@ -139,10 +138,10 @@ void FixPhis( // Remove additional occurrences coming from condensed cases and keep the // number of incoming values equal to the number of branches to SuccBB. 
SmallVector<unsigned, 8> Indices; - for (; LocalNumMergedCases > 0 && Idx < E; ++Idx) + for (; LocalNumMergedCases.ugt(0) && Idx < E; ++Idx) if (PN->getIncomingBlock(Idx) == OrigBB) { Indices.push_back(Idx); - LocalNumMergedCases--; + LocalNumMergedCases -= 1; } // Remove incoming values in the reverse order to prevent invalidating // *successive* index. @@ -160,7 +159,7 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound, BasicBlock *Default) { Function *F = OrigBlock->getParent(); BasicBlock *NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); - F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); + F->insert(++OrigBlock->getIterator(), NewLeaf); // Emit comparison ICmpInst *Comp = nullptr; @@ -209,8 +208,8 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound, for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { PHINode *PN = cast<PHINode>(I); // Remove all but one incoming entries from the cluster - uint64_t Range = Leaf.High->getSExtValue() - Leaf.Low->getSExtValue(); - for (uint64_t j = 0; j < Range; ++j) { + APInt Range = Leaf.High->getValue() - Leaf.Low->getValue(); + for (APInt j(Range.getBitWidth(), 0, true); j.slt(Range); ++j) { PN->removeIncomingValue(OrigBlock); } @@ -241,8 +240,7 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, // emitting the code that checks if the value actually falls in the range // because the bounds already tell us so. if (Begin->Low == LowerBound && Begin->High == UpperBound) { - unsigned NumMergedCases = 0; - NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue(); + APInt NumMergedCases = UpperBound->getValue() - LowerBound->getValue(); FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); return Begin->BB; } @@ -273,25 +271,24 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, if (!UnreachableRanges.empty()) { // Check if the gap between LHS's highest and NewLowerBound is unreachable. - int64_t GapLow = LHS.back().High->getSExtValue() + 1; - int64_t GapHigh = NewLowerBound->getSExtValue() - 1; - IntRange Gap = { GapLow, GapHigh }; - if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) + APInt GapLow = LHS.back().High->getValue() + 1; + APInt GapHigh = NewLowerBound->getValue() - 1; + IntRange Gap = {GapLow, GapHigh}; + if (GapHigh.sge(GapLow) && IsInRanges(Gap, UnreachableRanges)) NewUpperBound = LHS.back().High; } - LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", " - << NewUpperBound->getSExtValue() << "]\n" - << "RHS Bounds ==> [" << NewLowerBound->getSExtValue() - << ", " << UpperBound->getSExtValue() << "]\n"); + LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getValue() << ", " + << NewUpperBound->getValue() << "]\n" + << "RHS Bounds ==> [" << NewLowerBound->getValue() << ", " + << UpperBound->getValue() << "]\n"); // Create a new node that checks if the value is < pivot. Go to the // left branch if it is and right branch if not. 
- Function* F = OrigBlock->getParent(); - BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); + Function *F = OrigBlock->getParent(); + BasicBlock *NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); - ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, - Val, Pivot.Low, "Pivot"); + ICmpInst *Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot"); BasicBlock *LBranch = SwitchConvert(LHS.begin(), LHS.end(), LowerBound, NewUpperBound, Val, @@ -300,8 +297,8 @@ BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, SwitchConvert(RHS.begin(), RHS.end(), NewLowerBound, UpperBound, Val, NewNode, OrigBlock, Default, UnreachableRanges); - F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewNode); - NewNode->getInstList().push_back(Comp); + F->insert(++OrigBlock->getIterator(), NewNode); + Comp->insertInto(NewNode, NewNode->end()); BranchInst::Create(LBranch, RBranch, Comp, NewNode); return NewNode; @@ -328,14 +325,15 @@ unsigned Clusterify(CaseVector &Cases, SwitchInst *SI) { if (Cases.size() >= 2) { CaseItr I = Cases.begin(); for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) { - int64_t nextValue = J->Low->getSExtValue(); - int64_t currentValue = I->High->getSExtValue(); - BasicBlock* nextBB = J->BB; - BasicBlock* currentBB = I->BB; + const APInt &nextValue = J->Low->getValue(); + const APInt ¤tValue = I->High->getValue(); + BasicBlock *nextBB = J->BB; + BasicBlock *currentBB = I->BB; // If the two neighboring cases go to the same destination, merge them // into a single case. - assert(nextValue > currentValue && "Cases should be strictly ascending"); + assert(nextValue.sgt(currentValue) && + "Cases should be strictly ascending"); if ((nextValue == currentValue + 1) && (currentBB == nextBB)) { I->High = J->High; // FIXME: Combine branch weights. @@ -356,8 +354,8 @@ void ProcessSwitchInst(SwitchInst *SI, AssumptionCache *AC, LazyValueInfo *LVI) { BasicBlock *OrigBlock = SI->getParent(); Function *F = OrigBlock->getParent(); - Value *Val = SI->getCondition(); // The value we are switching on... - BasicBlock* Default = SI->getDefaultDest(); + Value *Val = SI->getCondition(); // The value we are switching on... + BasicBlock *Default = SI->getDefaultDest(); // Don't handle unreachable blocks. If there are successors with phis, this // would leave them behind with missing predecessors. @@ -370,6 +368,12 @@ void ProcessSwitchInst(SwitchInst *SI, // Prepare cases vector. CaseVector Cases; const unsigned NumSimpleCases = Clusterify(Cases, SI); + IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType()); + const unsigned BitWidth = IT->getBitWidth(); + // Explictly use higher precision to prevent unsigned overflow where + // `UnsignedMax - 0 + 1 == 0` + APInt UnsignedZero(BitWidth + 1, 0); + APInt UnsignedMax = APInt::getMaxValue(BitWidth); LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() << ". Total non-default cases: " << NumSimpleCases << "\nCase clusters: " << Cases << "\n"); @@ -378,7 +382,7 @@ void ProcessSwitchInst(SwitchInst *SI, if (Cases.empty()) { BranchInst::Create(Default, OrigBlock); // Remove all the references from Default's PHIs to OrigBlock, but one. 
- FixPhis(Default, OrigBlock, OrigBlock); + FixPhis(Default, OrigBlock, OrigBlock, UnsignedMax); SI->eraseFromParent(); return; } @@ -415,8 +419,8 @@ void ProcessSwitchInst(SwitchInst *SI, // the unlikely event that some of them survived, we just conservatively // maintain the invariant that all the cases lie between the bounds. This // may, however, still render the default case effectively unreachable. - APInt Low = Cases.front().Low->getValue(); - APInt High = Cases.back().High->getValue(); + const APInt &Low = Cases.front().Low->getValue(); + const APInt &High = Cases.back().High->getValue(); APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low); APInt Max = APIntOps::smax(ValRange.getSignedMax(), High); @@ -428,35 +432,38 @@ void ProcessSwitchInst(SwitchInst *SI, std::vector<IntRange> UnreachableRanges; if (DefaultIsUnreachableFromSwitch) { - DenseMap<BasicBlock *, unsigned> Popularity; - unsigned MaxPop = 0; + DenseMap<BasicBlock *, APInt> Popularity; + APInt MaxPop(UnsignedZero); BasicBlock *PopSucc = nullptr; - IntRange R = {std::numeric_limits<int64_t>::min(), - std::numeric_limits<int64_t>::max()}; + APInt SignedMax = APInt::getSignedMaxValue(BitWidth); + APInt SignedMin = APInt::getSignedMinValue(BitWidth); + IntRange R = {SignedMin, SignedMax}; UnreachableRanges.push_back(R); for (const auto &I : Cases) { - int64_t Low = I.Low->getSExtValue(); - int64_t High = I.High->getSExtValue(); + const APInt &Low = I.Low->getValue(); + const APInt &High = I.High->getValue(); IntRange &LastRange = UnreachableRanges.back(); - if (LastRange.Low == Low) { + if (LastRange.Low.eq(Low)) { // There is nothing left of the previous range. UnreachableRanges.pop_back(); } else { // Terminate the previous range. - assert(Low > LastRange.Low); + assert(Low.sgt(LastRange.Low)); LastRange.High = Low - 1; } - if (High != std::numeric_limits<int64_t>::max()) { - IntRange R = { High + 1, std::numeric_limits<int64_t>::max() }; + if (High.ne(SignedMax)) { + IntRange R = {High + 1, SignedMax}; UnreachableRanges.push_back(R); } // Count popularity. - int64_t N = High - Low + 1; - unsigned &Pop = Popularity[I.BB]; - if ((Pop += N) > MaxPop) { + assert(High.sge(Low) && "Popularity shouldn't be negative."); + APInt N = High.sext(BitWidth + 1) - Low.sext(BitWidth + 1) + 1; + // Explict insert to make sure the bitwidth of APInts match + APInt &Pop = Popularity.insert({I.BB, APInt(UnsignedZero)}).first->second; + if ((Pop += N).ugt(MaxPop)) { MaxPop = Pop; PopSucc = I.BB; } @@ -465,10 +472,10 @@ void ProcessSwitchInst(SwitchInst *SI, /* UnreachableRanges should be sorted and the ranges non-adjacent. */ for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end(); I != E; ++I) { - assert(I->Low <= I->High); + assert(I->Low.sle(I->High)); auto Next = I + 1; if (Next != E) { - assert(Next->Low > I->High); + assert(Next->Low.sgt(I->High)); } } #endif @@ -481,7 +488,6 @@ void ProcessSwitchInst(SwitchInst *SI, // Use the most popular block as the new default, reducing the number of // cases. - assert(MaxPop > 0 && PopSucc); Default = PopSucc; llvm::erase_if(Cases, [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }); @@ -492,8 +498,9 @@ void ProcessSwitchInst(SwitchInst *SI, SI->eraseFromParent(); // As all the cases have been replaced with a single branch, only keep // one entry in the PHI nodes. 
- for (unsigned I = 0 ; I < (MaxPop - 1) ; ++I) - PopSucc->removePredecessor(OrigBlock); + if (!MaxPop.isZero()) + for (APInt I(UnsignedZero); I.ult(MaxPop - 1); ++I) + PopSucc->removePredecessor(OrigBlock); return; } @@ -513,14 +520,14 @@ void ProcessSwitchInst(SwitchInst *SI, // that SwitchBlock is the same as Default, under which the PHIs in Default // are fixed inside SwitchConvert(). if (SwitchBlock != Default) - FixPhis(Default, OrigBlock, nullptr); + FixPhis(Default, OrigBlock, nullptr, UnsignedMax); // Branch to our shiny new if-then stuff... BranchInst::Create(SwitchBlock, OrigBlock); // We are now done with the switch instruction, delete it. BasicBlock *OldDefault = SI->getDefaultDest(); - OrigBlock->getInstList().erase(SI); + SI->eraseFromParent(); // If the Default block has no more predecessors just add it to DeleteList. if (pred_empty(OldDefault)) diff --git a/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp b/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp index 68d4dd9d576b..899928c085c6 100644 --- a/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp +++ b/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp @@ -16,6 +16,7 @@ #include "llvm/IR/DebugInfo.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include <optional> using namespace llvm; using namespace llvm::ore; @@ -144,9 +145,10 @@ static void inlineVolatileOrAtomicWithExtraArgs(bool *Inline, bool Volatile, R << " Atomic: " << NV("StoreAtomic", false) << "."; } -static Optional<uint64_t> getSizeInBytes(Optional<uint64_t> SizeInBits) { +static std::optional<uint64_t> +getSizeInBytes(std::optional<uint64_t> SizeInBits) { if (!SizeInBits || *SizeInBits % 8 != 0) - return None; + return std::nullopt; return *SizeInBits / 8; } @@ -297,17 +299,17 @@ void MemoryOpRemark::visitSizeOperand(Value *V, DiagnosticInfoIROptimization &R) } } -static Optional<StringRef> nameOrNone(const Value *V) { +static std::optional<StringRef> nameOrNone(const Value *V) { if (V->hasName()) return V->getName(); - return None; + return std::nullopt; } void MemoryOpRemark::visitVariable(const Value *V, SmallVectorImpl<VariableInfo> &Result) { if (auto *GV = dyn_cast<GlobalVariable>(V)) { auto *Ty = GV->getValueType(); - uint64_t Size = DL.getTypeSizeInBits(Ty).getFixedSize(); + uint64_t Size = DL.getTypeSizeInBits(Ty).getFixedValue(); VariableInfo Var{nameOrNone(GV), Size}; if (!Var.isEmpty()) Result.push_back(std::move(Var)); @@ -321,7 +323,7 @@ void MemoryOpRemark::visitVariable(const Value *V, for (const DbgVariableIntrinsic *DVI : FindDbgAddrUses(const_cast<Value *>(V))) { if (DILocalVariable *DILV = DVI->getVariable()) { - Optional<uint64_t> DISize = getSizeInBytes(DILV->getSizeInBits()); + std::optional<uint64_t> DISize = getSizeInBytes(DILV->getSizeInBits()); VariableInfo Var{DILV->getName(), DISize}; if (!Var.isEmpty()) { Result.push_back(std::move(Var)); @@ -339,9 +341,9 @@ void MemoryOpRemark::visitVariable(const Value *V, return; // If not, get it from the alloca. - Optional<TypeSize> TySize = AI->getAllocationSizeInBits(DL); - Optional<uint64_t> Size = - TySize ? getSizeInBytes(TySize->getFixedSize()) : None; + std::optional<TypeSize> TySize = AI->getAllocationSize(DL); + std::optional<uint64_t> Size = + TySize ? 
std::optional(TySize->getFixedValue()) : std::nullopt; VariableInfo Var{nameOrNone(AI), Size}; if (!Var.isEmpty()) Result.push_back(std::move(Var)); @@ -361,7 +363,7 @@ void MemoryOpRemark::visitPtr(Value *Ptr, bool IsRead, DiagnosticInfoIROptimizat uint64_t Size = Ptr->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed); if (!Size) return; - VIs.push_back({None, Size}); + VIs.push_back({std::nullopt, Size}); } R << (IsRead ? "\n Read Variables: " : "\n Written Variables: "); diff --git a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp index a1029475cf1d..1e42d7491676 100644 --- a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp +++ b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp @@ -14,9 +14,11 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/StackSafetyAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" namespace llvm { namespace memtag { @@ -114,7 +116,7 @@ void StackInfoBuilder::visit(Instruction &Inst) { } } if (AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) { - if (IsInterestingAlloca(*AI)) { + if (isInterestingAlloca(*AI)) { Info.AllocasToInstrument[AI].AI = AI; } return; @@ -127,7 +129,7 @@ void StackInfoBuilder::visit(Instruction &Inst) { Info.UnrecognizedLifetimes.push_back(&Inst); return; } - if (!IsInterestingAlloca(*AI)) + if (!isInterestingAlloca(*AI)) return; if (II->getIntrinsicID() == Intrinsic::lifetime_start) Info.AllocasToInstrument[AI].LifetimeStart.push_back(II); @@ -138,7 +140,7 @@ void StackInfoBuilder::visit(Instruction &Inst) { if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&Inst)) { for (Value *V : DVI->location_ops()) { if (auto *AI = dyn_cast_or_null<AllocaInst>(V)) { - if (!IsInterestingAlloca(*AI)) + if (!isInterestingAlloca(*AI)) continue; AllocaInfo &AInfo = Info.AllocasToInstrument[AI]; auto &DVIVec = AInfo.DbgVariableIntrinsics; @@ -152,9 +154,27 @@ void StackInfoBuilder::visit(Instruction &Inst) { Info.RetVec.push_back(ExitUntag); } +bool StackInfoBuilder::isInterestingAlloca(const AllocaInst &AI) { + return (AI.getAllocatedType()->isSized() && + // FIXME: instrument dynamic allocas, too + AI.isStaticAlloca() && + // alloca() may be called with 0 size, ignore it. + memtag::getAllocaSizeInBytes(AI) > 0 && + // We are only interested in allocas not promotable to registers. + // Promotable allocas are common under -O0. + !isAllocaPromotable(&AI) && + // inalloca allocas are not treated as static, and we don't want + // dynamic alloca instrumentation for them as well. 
+ !AI.isUsedWithInAlloca() && + // swifterror allocas are register promoted by ISel + !AI.isSwiftError()) && + // safe allocas are not interesting + !(SSI && SSI->isSafe(AI)); +} + uint64_t getAllocaSizeInBytes(const AllocaInst &AI) { auto DL = AI.getModule()->getDataLayout(); - return *AI.getAllocationSizeInBits(DL) / 8; + return *AI.getAllocationSize(DL); } void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Alignment) { @@ -176,16 +196,20 @@ void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Alignment) { : Info.AI->getAllocatedType(); Type *PaddingType = ArrayType::get(Type::getInt8Ty(Ctx), AlignedSize - Size); Type *TypeWithPadding = StructType::get(AllocatedType, PaddingType); - auto *NewAI = - new AllocaInst(TypeWithPadding, Info.AI->getType()->getAddressSpace(), - nullptr, "", Info.AI); + auto *NewAI = new AllocaInst(TypeWithPadding, Info.AI->getAddressSpace(), + nullptr, "", Info.AI); NewAI->takeName(Info.AI); NewAI->setAlignment(Info.AI->getAlign()); NewAI->setUsedWithInAlloca(Info.AI->isUsedWithInAlloca()); NewAI->setSwiftError(Info.AI->isSwiftError()); NewAI->copyMetadata(*Info.AI); - auto *NewPtr = new BitCastInst(NewAI, Info.AI->getType(), "", Info.AI); + Value *NewPtr = NewAI; + + // TODO: Remove when typed pointers dropped + if (Info.AI->getType() != NewAI->getType()) + NewPtr = new BitCastInst(NewAI, Info.AI->getType(), "", Info.AI); + Info.AI->replaceAllUsesWith(NewPtr); Info.AI->eraseFromParent(); Info.AI = NewAI; diff --git a/llvm/lib/Transforms/Utils/MetaRenamer.cpp b/llvm/lib/Transforms/Utils/MetaRenamer.cpp index 9fba2f3f86b5..0ea210671b93 100644 --- a/llvm/lib/Transforms/Utils/MetaRenamer.cpp +++ b/llvm/lib/Transforms/Utils/MetaRenamer.cpp @@ -87,7 +87,7 @@ struct Renamer { Renamer(unsigned int seed) { prng.srand(seed); } const char *newName() { - return metaNames[prng.rand() % array_lengthof(metaNames)]; + return metaNames[prng.rand() % std::size(metaNames)]; } PRNG prng; diff --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp index 4414b04c7264..6f5a25a26821 100644 --- a/llvm/lib/Transforms/Utils/MisExpect.cpp +++ b/llvm/lib/Transforms/Utils/MisExpect.cpp @@ -35,10 +35,12 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" +#include <algorithm> #include <cstdint> #include <functional> #include <numeric> @@ -57,7 +59,7 @@ static cl::opt<bool> PGOWarnMisExpect( cl::desc("Use this option to turn on/off " "warnings about incorrect usage of llvm.expect intrinsics.")); -static cl::opt<unsigned> MisExpectTolerance( +static cl::opt<uint32_t> MisExpectTolerance( "misexpect-tolerance", cl::init(0), cl::desc("Prevents emiting diagnostics when profile counts are " "within N% of the threshold..")); @@ -70,8 +72,8 @@ bool isMisExpectDiagEnabled(LLVMContext &Ctx) { return PGOWarnMisExpect || Ctx.getMisExpectWarningRequested(); } -uint64_t getMisExpectTolerance(LLVMContext &Ctx) { - return std::max(static_cast<uint64_t>(MisExpectTolerance), +uint32_t getMisExpectTolerance(LLVMContext &Ctx) { + return std::max(static_cast<uint32_t>(MisExpectTolerance), Ctx.getDiagnosticsMisExpectTolerance()); } @@ -118,43 +120,6 @@ void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx, namespace llvm { namespace misexpect { -// Helper function to extract branch weights into a vector 
-Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I, - LLVMContext &Ctx) { - assert(I && "MisExpect::extractWeights given invalid pointer"); - - auto *ProfileData = I->getMetadata(LLVMContext::MD_prof); - if (!ProfileData) - return None; - - unsigned NOps = ProfileData->getNumOperands(); - if (NOps < 3) - return None; - - auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); - if (!ProfDataName || !ProfDataName->getString().equals("branch_weights")) - return None; - - SmallVector<uint32_t, 4> Weights(NOps - 1); - for (unsigned Idx = 1; Idx < NOps; Idx++) { - ConstantInt *Value = - mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); - uint32_t V = Value->getZExtValue(); - Weights[Idx - 1] = V; - } - - return Weights; -} - -// TODO: when clang allows c++17, use std::clamp instead -uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) { - if (value > hi) - return hi; - if (value < low) - return low; - return value; -} - void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights, ArrayRef<uint32_t> ExpectedWeights) { // To determine if we emit a diagnostic, we need to compare the branch weights @@ -190,6 +155,8 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights, // We cannot calculate branch probability if either of these invariants aren't // met. However, MisExpect diagnostics should not prevent code from compiling, // so we simply forgo emitting diagnostics here, and return early. + // assert((TotalBranchWeight >= LikelyBranchWeight) && (TotalBranchWeight > 0) + // && "TotalBranchWeight is less than the Likely branch weight"); if ((TotalBranchWeight == 0) || (TotalBranchWeight <= LikelyBranchWeight)) return; @@ -203,7 +170,7 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights, // clamp tolerance range to [0, 100) auto Tolerance = getMisExpectTolerance(I.getContext()); - Tolerance = clamp(Tolerance, 0, 99); + Tolerance = std::clamp(Tolerance, 0u, 99u); // Allow users to relax checking by N% i.e., if they use a 5% tolerance, // then we check against 0.95*ScaledThreshold @@ -218,26 +185,24 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights, void checkBackendInstrumentation(Instruction &I, const ArrayRef<uint32_t> RealWeights) { - auto ExpectedWeightsOpt = extractWeights(&I, I.getContext()); - if (!ExpectedWeightsOpt) + SmallVector<uint32_t> ExpectedWeights; + if (!extractBranchWeights(I, ExpectedWeights)) return; - auto ExpectedWeights = ExpectedWeightsOpt.value(); verifyMisExpect(I, RealWeights, ExpectedWeights); } void checkFrontendInstrumentation(Instruction &I, const ArrayRef<uint32_t> ExpectedWeights) { - auto RealWeightsOpt = extractWeights(&I, I.getContext()); - if (!RealWeightsOpt) + SmallVector<uint32_t> RealWeights; + if (!extractBranchWeights(I, RealWeights)) return; - auto RealWeights = RealWeightsOpt.value(); verifyMisExpect(I, RealWeights, ExpectedWeights); } void checkExpectAnnotations(Instruction &I, const ArrayRef<uint32_t> ExistingWeights, - bool IsFrontendInstr) { - if (IsFrontendInstr) { + bool IsFrontend) { + if (IsFrontend) { checkFrontendInstrumentation(I, ExistingWeights); } else { checkBackendInstrumentation(I, ExistingWeights); diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp index 9e1492b97a86..6d17a466957e 100644 --- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -15,13 +15,15 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include 
"llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Support/xxhash.h" using namespace llvm; #define DEBUG_TYPE "moduleutils" -static void appendToGlobalArray(const char *Array, Module &M, Function *F, +static void appendToGlobalArray(StringRef ArrayName, Module &M, Function *F, int Priority, Constant *Data) { IRBuilder<> IRB(M.getContext()); FunctionType *FnTy = FunctionType::get(IRB.getVoidTy(), false); @@ -30,8 +32,10 @@ static void appendToGlobalArray(const char *Array, Module &M, Function *F, // to the list. SmallVector<Constant *, 16> CurrentCtors; StructType *EltTy = StructType::get( - IRB.getInt32Ty(), PointerType::getUnqual(FnTy), IRB.getInt8PtrTy()); - if (GlobalVariable *GVCtor = M.getNamedGlobal(Array)) { + IRB.getInt32Ty(), PointerType::get(FnTy, F->getAddressSpace()), + IRB.getInt8PtrTy()); + + if (GlobalVariable *GVCtor = M.getNamedGlobal(ArrayName)) { if (Constant *Init = GVCtor->getInitializer()) { unsigned n = Init->getNumOperands(); CurrentCtors.reserve(n + 1); @@ -48,7 +52,7 @@ static void appendToGlobalArray(const char *Array, Module &M, Function *F, CSVals[2] = Data ? ConstantExpr::getPointerCast(Data, IRB.getInt8PtrTy()) : Constant::getNullValue(IRB.getInt8PtrTy()); Constant *RuntimeCtorInit = - ConstantStruct::get(EltTy, makeArrayRef(CSVals, EltTy->getNumElements())); + ConstantStruct::get(EltTy, ArrayRef(CSVals, EltTy->getNumElements())); CurrentCtors.push_back(RuntimeCtorInit); @@ -59,7 +63,7 @@ static void appendToGlobalArray(const char *Array, Module &M, Function *F, // Create the new global variable and replace all uses of // the old global variable with the new one. (void)new GlobalVariable(M, NewInit->getType(), false, - GlobalValue::AppendingLinkage, NewInit, Array); + GlobalValue::AppendingLinkage, NewInit, ArrayName); } void llvm::appendToGlobalCtors(Module &M, Function *F, int Priority, Constant *Data) { @@ -70,35 +74,35 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *D appendToGlobalArray("llvm.global_dtors", M, F, Priority, Data); } +static void collectUsedGlobals(GlobalVariable *GV, + SmallSetVector<Constant *, 16> &Init) { + if (!GV || !GV->hasInitializer()) + return; + + auto *CA = cast<ConstantArray>(GV->getInitializer()); + for (Use &Op : CA->operands()) + Init.insert(cast<Constant>(Op)); +} + static void appendToUsedList(Module &M, StringRef Name, ArrayRef<GlobalValue *> Values) { GlobalVariable *GV = M.getGlobalVariable(Name); - SmallPtrSet<Constant *, 16> InitAsSet; - SmallVector<Constant *, 16> Init; - if (GV) { - if (GV->hasInitializer()) { - auto *CA = cast<ConstantArray>(GV->getInitializer()); - for (auto &Op : CA->operands()) { - Constant *C = cast_or_null<Constant>(Op); - if (InitAsSet.insert(C).second) - Init.push_back(C); - } - } + + SmallSetVector<Constant *, 16> Init; + collectUsedGlobals(GV, Init); + if (GV) GV->eraseFromParent(); - } - Type *Int8PtrTy = llvm::Type::getInt8PtrTy(M.getContext()); - for (auto *V : Values) { - Constant *C = ConstantExpr::getPointerBitCastOrAddrSpaceCast(V, Int8PtrTy); - if (InitAsSet.insert(C).second) - Init.push_back(C); - } + Type *ArrayEltTy = llvm::Type::getInt8PtrTy(M.getContext()); + for (auto *V : Values) + Init.insert(ConstantExpr::getPointerBitCastOrAddrSpaceCast(V, ArrayEltTy)); if (Init.empty()) return; - ArrayType *ATy = ArrayType::get(Int8PtrTy, Init.size()); + ArrayType *ATy = ArrayType::get(ArrayEltTy, Init.size()); GV = new llvm::GlobalVariable(M, ATy, false, 
GlobalValue::AppendingLinkage, - ConstantArray::get(ATy, Init), Name); + ConstantArray::get(ATy, Init.getArrayRef()), + Name); GV->setSection("llvm.metadata"); } @@ -110,21 +114,82 @@ void llvm::appendToCompilerUsed(Module &M, ArrayRef<GlobalValue *> Values) { appendToUsedList(M, "llvm.compiler.used", Values); } -FunctionCallee -llvm::declareSanitizerInitFunction(Module &M, StringRef InitName, - ArrayRef<Type *> InitArgTypes) { +static void removeFromUsedList(Module &M, StringRef Name, + function_ref<bool(Constant *)> ShouldRemove) { + GlobalVariable *GV = M.getNamedGlobal(Name); + if (!GV) + return; + + SmallSetVector<Constant *, 16> Init; + collectUsedGlobals(GV, Init); + + Type *ArrayEltTy = cast<ArrayType>(GV->getValueType())->getElementType(); + + SmallVector<Constant *, 16> NewInit; + for (Constant *MaybeRemoved : Init) { + if (!ShouldRemove(MaybeRemoved->stripPointerCasts())) + NewInit.push_back(MaybeRemoved); + } + + if (!NewInit.empty()) { + ArrayType *ATy = ArrayType::get(ArrayEltTy, NewInit.size()); + GlobalVariable *NewGV = + new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage, + ConstantArray::get(ATy, NewInit), "", GV, + GV->getThreadLocalMode(), GV->getAddressSpace()); + NewGV->setSection(GV->getSection()); + NewGV->takeName(GV); + } + + GV->eraseFromParent(); +} + +void llvm::removeFromUsedLists(Module &M, + function_ref<bool(Constant *)> ShouldRemove) { + removeFromUsedList(M, "llvm.used", ShouldRemove); + removeFromUsedList(M, "llvm.compiler.used", ShouldRemove); +} + +void llvm::setKCFIType(Module &M, Function &F, StringRef MangledType) { + if (!M.getModuleFlag("kcfi")) + return; + // Matches CodeGenModule::CreateKCFITypeId in Clang. + LLVMContext &Ctx = M.getContext(); + MDBuilder MDB(Ctx); + F.setMetadata( + LLVMContext::MD_kcfi_type, + MDNode::get(Ctx, MDB.createConstant(ConstantInt::get( + Type::getInt32Ty(Ctx), + static_cast<uint32_t>(xxHash64(MangledType)))))); + // If the module was compiled with -fpatchable-function-entry, ensure + // we use the same patchable-function-prefix. + if (auto *MD = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("kcfi-offset"))) { + if (unsigned Offset = MD->getZExtValue()) + F.addFnAttr("patchable-function-prefix", std::to_string(Offset)); + } +} + +FunctionCallee llvm::declareSanitizerInitFunction(Module &M, StringRef InitName, + ArrayRef<Type *> InitArgTypes, + bool Weak) { assert(!InitName.empty() && "Expected init function name"); - return M.getOrInsertFunction( - InitName, - FunctionType::get(Type::getVoidTy(M.getContext()), InitArgTypes, false), - AttributeList()); + auto *VoidTy = Type::getVoidTy(M.getContext()); + auto *FnTy = FunctionType::get(VoidTy, InitArgTypes, false); + auto FnCallee = M.getOrInsertFunction(InitName, FnTy); + auto *Fn = cast<Function>(FnCallee.getCallee()); + if (Weak && Fn->isDeclaration()) + Fn->setLinkage(Function::ExternalWeakLinkage); + return FnCallee; } Function *llvm::createSanitizerCtor(Module &M, StringRef CtorName) { Function *Ctor = Function::createWithDefaultAttr( FunctionType::get(Type::getVoidTy(M.getContext()), false), - GlobalValue::InternalLinkage, 0, CtorName, &M); + GlobalValue::InternalLinkage, M.getDataLayout().getProgramAddressSpace(), + CtorName, &M); Ctor->addFnAttr(Attribute::NoUnwind); + setKCFIType(M, *Ctor, "_ZTSFvvE"); // void (*)(void) BasicBlock *CtorBB = BasicBlock::Create(M.getContext(), "", Ctor); ReturnInst::Create(M.getContext(), CtorBB); // Ensure Ctor cannot be discarded, even if in a comdat. 
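The new setKCFIType helper above attaches !kcfi_type metadata whose value is the low 32 bits of xxHash64 over the Itanium-mangled function type, deliberately matching CodeGenModule::CreateKCFITypeId in Clang; the "_ZTSFvvE" string passed when building the sanitizer constructor is glossed in the patch as void (*)(void). A minimal sketch of just the id computation, with an illustrative helper name that is not part of the patch:

  #include "llvm/Support/xxhash.h"

  // Compute the 32-bit KCFI type id for a mangled function type,
  // e.g. kcfiTypeId("_ZTSFvvE") for a void (*)(void) constructor.
  static uint32_t kcfiTypeId(llvm::StringRef MangledType) {
    return static_cast<uint32_t>(llvm::xxHash64(MangledType));
  }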
@@ -135,14 +200,33 @@ Function *llvm::createSanitizerCtor(Module &M, StringRef CtorName) { std::pair<Function *, FunctionCallee> llvm::createSanitizerCtorAndInitFunctions( Module &M, StringRef CtorName, StringRef InitName, ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs, - StringRef VersionCheckName) { + StringRef VersionCheckName, bool Weak) { assert(!InitName.empty() && "Expected init function name"); assert(InitArgs.size() == InitArgTypes.size() && "Sanitizer's init function expects different number of arguments"); FunctionCallee InitFunction = - declareSanitizerInitFunction(M, InitName, InitArgTypes); + declareSanitizerInitFunction(M, InitName, InitArgTypes, Weak); Function *Ctor = createSanitizerCtor(M, CtorName); - IRBuilder<> IRB(Ctor->getEntryBlock().getTerminator()); + IRBuilder<> IRB(M.getContext()); + + BasicBlock *RetBB = &Ctor->getEntryBlock(); + if (Weak) { + RetBB->setName("ret"); + auto *EntryBB = BasicBlock::Create(M.getContext(), "entry", Ctor, RetBB); + auto *CallInitBB = + BasicBlock::Create(M.getContext(), "callfunc", Ctor, RetBB); + auto *InitFn = cast<Function>(InitFunction.getCallee()); + auto *InitFnPtr = + PointerType::get(InitFn->getType(), InitFn->getAddressSpace()); + IRB.SetInsertPoint(EntryBB); + Value *InitNotNull = + IRB.CreateICmpNE(InitFn, ConstantPointerNull::get(InitFnPtr)); + IRB.CreateCondBr(InitNotNull, CallInitBB, RetBB); + IRB.SetInsertPoint(CallInitBB); + } else { + IRB.SetInsertPoint(RetBB->getTerminator()); + } + IRB.CreateCall(InitFunction, InitArgs); if (!VersionCheckName.empty()) { FunctionCallee VersionCheckFunction = M.getOrInsertFunction( @@ -150,6 +234,10 @@ std::pair<Function *, FunctionCallee> llvm::createSanitizerCtorAndInitFunctions( AttributeList()); IRB.CreateCall(VersionCheckFunction, {}); } + + if (Weak) + IRB.CreateBr(RetBB); + return std::make_pair(Ctor, InitFunction); } @@ -158,7 +246,7 @@ llvm::getOrCreateSanitizerCtorAndInitFunctions( Module &M, StringRef CtorName, StringRef InitName, ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs, function_ref<void(Function *, FunctionCallee)> FunctionsCreatedCallback, - StringRef VersionCheckName) { + StringRef VersionCheckName, bool Weak) { assert(!CtorName.empty() && "Expected ctor function name"); if (Function *Ctor = M.getFunction(CtorName)) @@ -166,12 +254,13 @@ llvm::getOrCreateSanitizerCtorAndInitFunctions( // globals. This will make moving to a concurrent model much easier. 
if (Ctor->arg_empty() || Ctor->getReturnType() == Type::getVoidTy(M.getContext())) - return {Ctor, declareSanitizerInitFunction(M, InitName, InitArgTypes)}; + return {Ctor, + declareSanitizerInitFunction(M, InitName, InitArgTypes, Weak)}; Function *Ctor; FunctionCallee InitFunction; std::tie(Ctor, InitFunction) = llvm::createSanitizerCtorAndInitFunctions( - M, CtorName, InitName, InitArgTypes, InitArgs, VersionCheckName); + M, CtorName, InitName, InitArgTypes, InitArgs, VersionCheckName, Weak); FunctionsCreatedCallback(Ctor, InitFunction); return std::make_pair(Ctor, InitFunction); } @@ -253,9 +342,9 @@ void VFABI::setVectorVariantNames(CallInst *CI, #ifndef NDEBUG for (const std::string &VariantMapping : VariantMappings) { LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << VariantMapping << "'\n"); - Optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *M); + std::optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *M); assert(VI && "Cannot add an invalid VFABI name."); - assert(M->getNamedValue(VI.value().VectorName) && + assert(M->getNamedValue(VI->VectorName) && "Cannot add variant to attribute: " "vector function declaration is missing."); } @@ -268,7 +357,7 @@ void llvm::embedBufferInModule(Module &M, MemoryBufferRef Buf, StringRef SectionName, Align Alignment) { // Embed the memory buffer into the module. Constant *ModuleConstant = ConstantDataArray::get( - M.getContext(), makeArrayRef(Buf.getBufferStart(), Buf.getBufferSize())); + M.getContext(), ArrayRef(Buf.getBufferStart(), Buf.getBufferSize())); GlobalVariable *GV = new GlobalVariable( M, ModuleConstant->getType(), true, GlobalValue::PrivateLinkage, ModuleConstant, "llvm.embedded.object"); @@ -285,3 +374,102 @@ void llvm::embedBufferInModule(Module &M, MemoryBufferRef Buf, appendToCompilerUsed(M, GV); } + +bool llvm::lowerGlobalIFuncUsersAsGlobalCtor( + Module &M, ArrayRef<GlobalIFunc *> FilteredIFuncsToLower) { + SmallVector<GlobalIFunc *, 32> AllIFuncs; + ArrayRef<GlobalIFunc *> IFuncsToLower = FilteredIFuncsToLower; + if (FilteredIFuncsToLower.empty()) { // Default to lowering all ifuncs + for (GlobalIFunc &GI : M.ifuncs()) + AllIFuncs.push_back(&GI); + IFuncsToLower = AllIFuncs; + } + + bool UnhandledUsers = false; + LLVMContext &Ctx = M.getContext(); + const DataLayout &DL = M.getDataLayout(); + + PointerType *TableEntryTy = + Ctx.supportsTypedPointers() + ? PointerType::get(Type::getInt8Ty(Ctx), DL.getProgramAddressSpace()) + : PointerType::get(Ctx, DL.getProgramAddressSpace()); + + ArrayType *FuncPtrTableTy = + ArrayType::get(TableEntryTy, IFuncsToLower.size()); + + Align PtrAlign = DL.getABITypeAlign(TableEntryTy); + + // Create a global table of function pointers we'll initialize in a global + // constructor. + auto *FuncPtrTable = new GlobalVariable( + M, FuncPtrTableTy, false, GlobalValue::InternalLinkage, + PoisonValue::get(FuncPtrTableTy), "", nullptr, + GlobalVariable::NotThreadLocal, DL.getDefaultGlobalsAddressSpace()); + FuncPtrTable->setAlignment(PtrAlign); + + // Create a function to initialize the function pointer table. 
+ Function *NewCtor = Function::Create( + FunctionType::get(Type::getVoidTy(Ctx), false), Function::InternalLinkage, + DL.getProgramAddressSpace(), "", &M); + + BasicBlock *BB = BasicBlock::Create(Ctx, "", NewCtor); + IRBuilder<> InitBuilder(BB); + + size_t TableIndex = 0; + for (GlobalIFunc *GI : IFuncsToLower) { + Function *ResolvedFunction = GI->getResolverFunction(); + + // We don't know what to pass to a resolver function taking arguments + // + // FIXME: Is this even valid? clang and gcc don't complain but this + // probably should be invalid IR. We could just pass through undef. + if (!std::empty(ResolvedFunction->getFunctionType()->params())) { + LLVM_DEBUG(dbgs() << "Not lowering ifunc resolver function " + << ResolvedFunction->getName() << " with parameters\n"); + UnhandledUsers = true; + continue; + } + + // Initialize the function pointer table. + CallInst *ResolvedFunc = InitBuilder.CreateCall(ResolvedFunction); + Value *Casted = InitBuilder.CreatePointerCast(ResolvedFunc, TableEntryTy); + Constant *GEP = cast<Constant>(InitBuilder.CreateConstInBoundsGEP2_32( + FuncPtrTableTy, FuncPtrTable, 0, TableIndex++)); + InitBuilder.CreateAlignedStore(Casted, GEP, PtrAlign); + + // Update all users to load a pointer from the global table. + for (User *User : make_early_inc_range(GI->users())) { + Instruction *UserInst = dyn_cast<Instruction>(User); + if (!UserInst) { + // TODO: Should handle constantexpr casts in user instructions. Probably + // can't do much about constant initializers. + UnhandledUsers = true; + continue; + } + + IRBuilder<> UseBuilder(UserInst); + LoadInst *ResolvedTarget = + UseBuilder.CreateAlignedLoad(TableEntryTy, GEP, PtrAlign); + Value *ResolvedCast = + UseBuilder.CreatePointerCast(ResolvedTarget, GI->getType()); + UserInst->replaceUsesOfWith(GI, ResolvedCast); + } + + // If we handled all users, erase the ifunc. + if (GI->use_empty()) + GI->eraseFromParent(); + } + + InitBuilder.CreateRetVoid(); + + PointerType *ConstantDataTy = Ctx.supportsTypedPointers() + ? PointerType::get(Type::getInt8Ty(Ctx), 0) + : PointerType::get(Ctx, 0); + + // TODO: Is this the right priority? Probably should be before any other + // constructors? + const int Priority = 10; + appendToGlobalCtors(M, NewCtor, Priority, + ConstantPointerNull::get(ConstantDataTy)); + return UnhandledUsers; +} diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp index 53334bc2a369..1f16ba78bdb0 100644 --- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp +++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp @@ -509,7 +509,7 @@ void PredicateInfoBuilder::buildPredicateInfo() { // Collect operands to rename from all conditional branch terminators, as well // as assume statements. SmallVector<Value *, 8> OpsToRename; - for (auto DTN : depth_first(DT.getRootNode())) { + for (auto *DTN : depth_first(DT.getRootNode())) { BasicBlock *BranchBB = DTN->getBlock(); if (auto *BI = dyn_cast<BranchInst>(BranchBB->getTerminator())) { if (!BI->isConditional()) @@ -626,7 +626,7 @@ void PredicateInfoBuilder::renameUses(SmallVectorImpl<Value *> &OpsToRename) { // Insert the possible copies into the def/use list. // They will become real copies if we find a real use for them, and never // created otherwise. - for (auto &PossibleCopy : ValueInfo.Infos) { + for (const auto &PossibleCopy : ValueInfo.Infos) { ValueDFS VD; // Determine where we are going to place the copy by the copy type. 
// The predicate info for branches always come first, they will get @@ -772,7 +772,7 @@ PredicateInfo::~PredicateInfo() { // Collect function pointers in set first, as SmallSet uses a SmallVector // internally and we have to remove the asserting value handles first. SmallPtrSet<Function *, 20> FunctionPtrs; - for (auto &F : CreatedDeclarations) + for (const auto &F : CreatedDeclarations) FunctionPtrs.insert(&*F); CreatedDeclarations.clear(); @@ -783,7 +783,7 @@ PredicateInfo::~PredicateInfo() { } } -Optional<PredicateConstraint> PredicateBase::getConstraint() const { +std::optional<PredicateConstraint> PredicateBase::getConstraint() const { switch (Type) { case PT_Assume: case PT_Branch: { @@ -800,7 +800,7 @@ Optional<PredicateConstraint> PredicateBase::getConstraint() const { CmpInst *Cmp = dyn_cast<CmpInst>(Condition); if (!Cmp) { // TODO: Make this an assertion once RenamedOp is fully accurate. - return None; + return std::nullopt; } CmpInst::Predicate Pred; @@ -813,7 +813,7 @@ Optional<PredicateConstraint> PredicateBase::getConstraint() const { OtherOp = Cmp->getOperand(0); } else { // TODO: Make this an assertion once RenamedOp is fully accurate. - return None; + return std::nullopt; } // Invert predicate along false edge. @@ -825,7 +825,7 @@ Optional<PredicateConstraint> PredicateBase::getConstraint() const { case PT_Switch: if (Condition != RenamedOp) { // TODO: Make this an assertion once RenamedOp is fully accurate. - return None; + return std::nullopt; } return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}}; diff --git a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index bec1db896efb..75ea9dc5dfc0 100644 --- a/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -100,6 +100,67 @@ bool llvm::isAllocaPromotable(const AllocaInst *AI) { namespace { +/// Helper for updating assignment tracking debug info when promoting allocas. +class AssignmentTrackingInfo { + /// DbgAssignIntrinsics linked to the alloca with at most one per variable + /// fragment. (i.e. not be a comprehensive set if there are multiple + /// dbg.assigns for one variable fragment). + SmallVector<DbgVariableIntrinsic *> DbgAssigns; + +public: + void init(AllocaInst *AI) { + SmallSet<DebugVariable, 2> Vars; + for (DbgAssignIntrinsic *DAI : at::getAssignmentMarkers(AI)) { + if (Vars.insert(DebugVariable(DAI)).second) + DbgAssigns.push_back(DAI); + } + } + + /// Update assignment tracking debug info given for the to-be-deleted store + /// \p ToDelete that stores to this alloca. + void updateForDeletedStore(StoreInst *ToDelete, DIBuilder &DIB) const { + // There's nothing to do if the alloca doesn't have any variables using + // assignment tracking. + if (DbgAssigns.empty()) { + assert(at::getAssignmentMarkers(ToDelete).empty()); + return; + } + + // Just leave dbg.assign intrinsics in place and remember that we've seen + // one for each variable fragment. + SmallSet<DebugVariable, 2> VarHasDbgAssignForStore; + for (DbgAssignIntrinsic *DAI : at::getAssignmentMarkers(ToDelete)) + VarHasDbgAssignForStore.insert(DebugVariable(DAI)); + + // It's possible for variables using assignment tracking to have no + // dbg.assign linked to this store. These are variables in DbgAssigns that + // are missing from VarHasDbgAssignForStore. Since there isn't a dbg.assign + // to mark the assignment - and the store is going to be deleted - insert a + // dbg.value to do that now. 
An untracked store may be either one that + // cannot be represented using assignment tracking (non-const offset or + // size) or one that is trackable but has had its DIAssignID attachment + // dropped accidentally. + for (auto *DAI : DbgAssigns) { + if (VarHasDbgAssignForStore.contains(DebugVariable(DAI))) + continue; + ConvertDebugDeclareToDebugValue(DAI, ToDelete, DIB); + } + } + + /// Update assignment tracking debug info given for the newly inserted PHI \p + /// NewPhi. + void updateForNewPhi(PHINode *NewPhi, DIBuilder &DIB) const { + // Regardless of the position of dbg.assigns relative to stores, the + // incoming values into a new PHI should be the same for the (imaginary) + // debug-phi. + for (auto *DAI : DbgAssigns) + ConvertDebugDeclareToDebugValue(DAI, NewPhi, DIB); + } + + void clear() { DbgAssigns.clear(); } + bool empty() { return DbgAssigns.empty(); } +}; + struct AllocaInfo { using DbgUserVec = SmallVector<DbgVariableIntrinsic *, 1>; @@ -110,7 +171,10 @@ struct AllocaInfo { BasicBlock *OnlyBlock; bool OnlyUsedInOneBlock; + /// Debug users of the alloca - does not include dbg.assign intrinsics. DbgUserVec DbgUsers; + /// Helper to update assignment tracking debug info. + AssignmentTrackingInfo AssignmentTracking; void clear() { DefiningBlocks.clear(); @@ -119,6 +183,7 @@ struct AllocaInfo { OnlyBlock = nullptr; OnlyUsedInOneBlock = true; DbgUsers.clear(); + AssignmentTracking.clear(); } /// Scan the uses of the specified alloca, filling in the AllocaInfo used @@ -150,8 +215,13 @@ struct AllocaInfo { OnlyUsedInOneBlock = false; } } - - findDbgUsers(DbgUsers, AI); + DbgUserVec AllDbgUsers; + findDbgUsers(AllDbgUsers, AI); + std::copy_if(AllDbgUsers.begin(), AllDbgUsers.end(), + std::back_inserter(DbgUsers), [](DbgVariableIntrinsic *DII) { + return !isa<DbgAssignIntrinsic>(DII); + }); + AssignmentTracking.init(AI); } }; @@ -251,6 +321,10 @@ struct PromoteMem2Reg { /// intrinsic if the alloca gets promoted. SmallVector<AllocaInfo::DbgUserVec, 8> AllocaDbgUsers; + /// For each alloca, keep an instance of a helper class that gives us an easy + /// way to update assignment tracking debug info if the alloca is promoted. + SmallVector<AssignmentTrackingInfo, 8> AllocaATInfo; + /// The set of basic blocks the renamer has already visited. SmallPtrSet<BasicBlock *, 16> Visited; @@ -309,6 +383,19 @@ static void addAssumeNonNull(AssumptionCache *AC, LoadInst *LI) { AC->registerAssumption(cast<AssumeInst>(CI)); } +static void convertMetadataToAssumes(LoadInst *LI, Value *Val, + const DataLayout &DL, AssumptionCache *AC, + const DominatorTree *DT) { + // If the load was marked as nonnull we don't want to lose that information + // when we erase this Load. So we preserve it with an assume. As !nonnull + // returns poison while assume violations are immediate undefined behavior, + // we can only do this if the value is known non-poison. + if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && + LI->getMetadata(LLVMContext::MD_noundef) && + !isKnownNonZero(Val, DL, 0, AC, LI, DT)) + addAssumeNonNull(AC, LI); +} + static void removeIntrinsicUsers(AllocaInst *AI) { // Knowing that this alloca is promotable, we know that it's safe to kill all // instructions except for load and store. @@ -401,13 +488,7 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, if (ReplVal == LI) ReplVal = PoisonValue::get(LI->getType()); - // If the load was marked as nonnull we don't want to lose - // that information when we erase this Load. So we preserve - // it with an assume. 
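This hunk replaces the open-coded !nonnull preservation with the new convertMetadataToAssumes helper above, which also requires !noundef: a violated !nonnull only yields poison, while a violated llvm.assume is immediate undefined behavior. A stripped-down sketch of that idea (the in-tree helper additionally skips values already known non-zero and registers the assumption with the AssumptionCache):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

static void preserveNonNullAsAssume(LoadInst *LI, Value *ReplVal) {
  // Only rewrite when the load carried both !nonnull and !noundef.
  if (!LI->getMetadata(LLVMContext::MD_nonnull) ||
      !LI->getMetadata(LLVMContext::MD_noundef))
    return;
  IRBuilder<> B(LI);
  // Materialize the fact as an explicit assume on the replacement value.
  B.CreateAssumption(B.CreateIsNotNull(ReplVal));
}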
- if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && - !isKnownNonZero(ReplVal, DL, 0, AC, LI, &DT)) - addAssumeNonNull(AC, LI); - + convertMetadataToAssumes(LI, ReplVal, DL, AC, &DT); LI->replaceAllUsesWith(ReplVal); LI->eraseFromParent(); LBI.deleteValue(LI); @@ -417,17 +498,24 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, if (!Info.UsingBlocks.empty()) return false; // If not, we'll have to fall back for the remainder. + DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); + // Update assignment tracking info for the store we're going to delete. + Info.AssignmentTracking.updateForDeletedStore(Info.OnlyStore, DIB); + // Record debuginfo for the store and remove the declaration's // debuginfo. for (DbgVariableIntrinsic *DII : Info.DbgUsers) { if (DII->isAddressOfVariable()) { - DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DII, Info.OnlyStore, DIB); DII->eraseFromParent(); } else if (DII->getExpression()->startsWithDeref()) { DII->eraseFromParent(); } } + + // Remove dbg.assigns linked to the alloca as these are now redundant. + at::deleteAssignmentMarkers(AI); + // Remove the (now dead) store and alloca. Info.OnlyStore->eraseFromParent(); LBI.deleteValue(Info.OnlyStore); @@ -503,11 +591,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, ReplVal = std::prev(I)->second->getOperand(0); } - // Note, if the load was marked as nonnull we don't want to lose that - // information when we erase it. So we preserve it with an assume. - if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && - !isKnownNonZero(ReplVal, DL, 0, AC, LI, &DT)) - addAssumeNonNull(AC, LI); + convertMetadataToAssumes(LI, ReplVal, DL, AC, &DT); // If the replacement value is the load, this must occur in unreachable // code. @@ -520,12 +604,14 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, } // Remove the (now dead) stores and alloca. + DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); while (!AI->use_empty()) { StoreInst *SI = cast<StoreInst>(AI->user_back()); + // Update assignment tracking info for the store we're going to delete. + Info.AssignmentTracking.updateForDeletedStore(SI, DIB); // Record debuginfo for the store before removing it. for (DbgVariableIntrinsic *DII : Info.DbgUsers) { if (DII->isAddressOfVariable()) { - DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DII, SI, DIB); } } @@ -533,6 +619,8 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, LBI.deleteValue(SI); } + // Remove dbg.assigns linked to the alloca as these are now redundant. + at::deleteAssignmentMarkers(AI); AI->eraseFromParent(); // The alloca's debuginfo can be removed as well. @@ -548,6 +636,7 @@ void PromoteMem2Reg::run() { Function &F = *DT.getRoot()->getParent(); AllocaDbgUsers.resize(Allocas.size()); + AllocaATInfo.resize(Allocas.size()); AllocaInfo Info; LargeBlockInfo LBI; @@ -607,6 +696,8 @@ void PromoteMem2Reg::run() { // Remember the dbg.declare intrinsic describing this alloca, if any. if (!Info.DbgUsers.empty()) AllocaDbgUsers[AllocaNum] = Info.DbgUsers; + if (!Info.AssignmentTracking.empty()) + AllocaATInfo[AllocaNum] = Info.AssignmentTracking; // Keep the reverse mapping of the 'Allocas' array for the rename pass. AllocaLookup[Allocas[AllocaNum]] = AllocaNum; @@ -670,6 +761,8 @@ void PromoteMem2Reg::run() { // Remove the allocas themselves from the function. 
for (Instruction *A : Allocas) { + // Remove dbg.assigns linked to the alloca as these are now redundant. + at::deleteAssignmentMarkers(A); // If there are any uses of the alloca instructions left, they must be in // unreachable basic blocks that were not processed by walking the dominator // tree. Just delete the users now. @@ -923,6 +1016,7 @@ NextIteration: // The currently active variable for this block is now the PHI. IncomingVals[AllocaNo] = APN; + AllocaATInfo[AllocaNo].updateForNewPhi(APN, DIB); for (DbgVariableIntrinsic *DII : AllocaDbgUsers[AllocaNo]) if (DII->isAddressOfVariable()) ConvertDebugDeclareToDebugValue(DII, APN, DIB); @@ -956,17 +1050,11 @@ NextIteration: continue; Value *V = IncomingVals[AI->second]; - - // If the load was marked as nonnull we don't want to lose - // that information when we erase this Load. So we preserve - // it with an assume. - if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && - !isKnownNonZero(V, SQ.DL, 0, AC, LI, &DT)) - addAssumeNonNull(AC, LI); + convertMetadataToAssumes(LI, V, SQ.DL, AC, &DT); // Anything using the load now uses the current value. LI->replaceAllUsesWith(V); - BB->getInstList().erase(LI); + LI->eraseFromParent(); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { // Delete this instruction and mark the name as the current holder of the // value @@ -984,10 +1072,11 @@ NextIteration: // Record debuginfo for the store before removing it. IncomingLocs[AllocaNo] = SI->getDebugLoc(); + AllocaATInfo[AllocaNo].updateForDeletedStore(SI, DIB); for (DbgVariableIntrinsic *DII : AllocaDbgUsers[ai->second]) if (DII->isAddressOfVariable()) ConvertDebugDeclareToDebugValue(DII, SI, DIB); - BB->getInstList().erase(SI); + SI->eraseFromParent(); } } diff --git a/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp b/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp index 926427450682..c9ff94dc9744 100644 --- a/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp +++ b/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp @@ -57,11 +57,15 @@ static bool shouldConvertToRelLookupTable(Module &M, GlobalVariable &GV) { return false; ConstantArray *Array = dyn_cast<ConstantArray>(GV.getInitializer()); - // If values are not pointers, do not generate a relative lookup table. - if (!Array || !Array->getType()->getElementType()->isPointerTy()) + if (!Array) return false; + // If values are not 64-bit pointers, do not generate a relative lookup table. 
const DataLayout &DL = M.getDataLayout(); + Type *ElemType = Array->getType()->getElementType(); + if (!ElemType->isPointerTy() || DL.getPointerTypeSizeInBits(ElemType) != 64) + return false; + for (const Use &Op : Array->operands()) { Constant *ConstOp = cast<Constant>(&Op); GlobalValue *GVOp; diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 09a83f1ea094..8d03a0d8a2c4 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -16,11 +16,13 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueLattice.h" +#include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/IR/InstVisitor.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <utility> #include <vector> @@ -39,28 +41,257 @@ static ValueLatticeElement::MergeOptions getMaxWidenStepsOpts() { MaxNumRangeExtensions); } -namespace { +namespace llvm { -// Helper to check if \p LV is either a constant or a constant -// range with a single element. This should cover exactly the same cases as the -// old ValueLatticeElement::isConstant() and is intended to be used in the -// transition to ValueLatticeElement. -bool isConstant(const ValueLatticeElement &LV) { +bool SCCPSolver::isConstant(const ValueLatticeElement &LV) { return LV.isConstant() || (LV.isConstantRange() && LV.getConstantRange().isSingleElement()); } -// Helper to check if \p LV is either overdefined or a constant range with more -// than a single element. This should cover exactly the same cases as the old -// ValueLatticeElement::isOverdefined() and is intended to be used in the -// transition to ValueLatticeElement. -bool isOverdefined(const ValueLatticeElement &LV) { - return !LV.isUnknownOrUndef() && !isConstant(LV); +bool SCCPSolver::isOverdefined(const ValueLatticeElement &LV) { + return !LV.isUnknownOrUndef() && !SCCPSolver::isConstant(LV); } -} // namespace +static bool canRemoveInstruction(Instruction *I) { + if (wouldInstructionBeTriviallyDead(I)) + return true; -namespace llvm { + // Some instructions can be handled but are rejected above. Catch + // those cases by falling through to here. + // TODO: Mark globals as being constant earlier, so + // TODO: wouldInstructionBeTriviallyDead() knows that atomic loads + // TODO: are safe to remove. + return isa<LoadInst>(I); +} + +bool SCCPSolver::tryToReplaceWithConstant(Value *V) { + Constant *Const = nullptr; + if (V->getType()->isStructTy()) { + std::vector<ValueLatticeElement> IVs = getStructLatticeValueFor(V); + if (llvm::any_of(IVs, isOverdefined)) + return false; + std::vector<Constant *> ConstVals; + auto *ST = cast<StructType>(V->getType()); + for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { + ValueLatticeElement V = IVs[i]; + ConstVals.push_back(SCCPSolver::isConstant(V) + ? getConstant(V) + : UndefValue::get(ST->getElementType(i))); + } + Const = ConstantStruct::get(ST, ConstVals); + } else { + const ValueLatticeElement &IV = getLatticeValueFor(V); + if (isOverdefined(IV)) + return false; + + Const = SCCPSolver::isConstant(IV) ? getConstant(IV) + : UndefValue::get(V->getType()); + } + assert(Const && "Constant is nullptr here!"); + + // Replacing `musttail` instructions with constant breaks `musttail` invariant + // unless the call itself can be removed. 
+ // Calls with "clang.arc.attachedcall" implicitly use the return value and + // those uses cannot be updated with a constant. + CallBase *CB = dyn_cast<CallBase>(V); + if (CB && ((CB->isMustTailCall() && + !canRemoveInstruction(CB)) || + CB->getOperandBundle(LLVMContext::OB_clang_arc_attachedcall))) { + Function *F = CB->getCalledFunction(); + + // Don't zap returns of the callee + if (F) + addToMustPreserveReturnsInFunctions(F); + + LLVM_DEBUG(dbgs() << " Can\'t treat the result of call " << *CB + << " as a constant\n"); + return false; + } + + LLVM_DEBUG(dbgs() << " Constant: " << *Const << " = " << *V << '\n'); + + // Replaces all of the uses of a variable with uses of the constant. + V->replaceAllUsesWith(Const); + return true; +} + +/// Try to replace signed instructions with their unsigned equivalent. +static bool replaceSignedInst(SCCPSolver &Solver, + SmallPtrSetImpl<Value *> &InsertedValues, + Instruction &Inst) { + // Determine if a signed value is known to be >= 0. + auto isNonNegative = [&Solver](Value *V) { + // If this value was constant-folded, it may not have a solver entry. + // Handle integers. Otherwise, return false. + if (auto *C = dyn_cast<Constant>(V)) { + auto *CInt = dyn_cast<ConstantInt>(C); + return CInt && !CInt->isNegative(); + } + const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); + return IV.isConstantRange(/*UndefAllowed=*/false) && + IV.getConstantRange().isAllNonNegative(); + }; + + Instruction *NewInst = nullptr; + switch (Inst.getOpcode()) { + // Note: We do not fold sitofp -> uitofp here because that could be more + // expensive in codegen and may not be reversible in the backend. + case Instruction::SExt: { + // If the source value is not negative, this is a zext. + Value *Op0 = Inst.getOperand(0); + if (InsertedValues.count(Op0) || !isNonNegative(Op0)) + return false; + NewInst = new ZExtInst(Op0, Inst.getType(), "", &Inst); + break; + } + case Instruction::AShr: { + // If the shifted value is not negative, this is a logical shift right. + Value *Op0 = Inst.getOperand(0); + if (InsertedValues.count(Op0) || !isNonNegative(Op0)) + return false; + NewInst = BinaryOperator::CreateLShr(Op0, Inst.getOperand(1), "", &Inst); + break; + } + case Instruction::SDiv: + case Instruction::SRem: { + // If both operands are not negative, this is the same as udiv/urem. + Value *Op0 = Inst.getOperand(0), *Op1 = Inst.getOperand(1); + if (InsertedValues.count(Op0) || InsertedValues.count(Op1) || + !isNonNegative(Op0) || !isNonNegative(Op1)) + return false; + auto NewOpcode = Inst.getOpcode() == Instruction::SDiv ? Instruction::UDiv + : Instruction::URem; + NewInst = BinaryOperator::Create(NewOpcode, Op0, Op1, "", &Inst); + break; + } + default: + return false; + } + + // Wire up the new instruction and update state. 
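As a quick standalone reminder of why replaceSignedInst insists on provably non-negative operands before rewriting sext/ashr/sdiv/srem into their unsigned forms (plain C++, not LLVM code):

#include <cstdint>
#include <cstdio>

int main() {
  int8_t A = -8, B = 2;
  std::printf("sdiv %d   udiv %u\n", A / B, unsigned(uint8_t(A) / uint8_t(B)));
  std::printf("ashr %d   lshr %u\n", A >> 1, unsigned(uint8_t(A) >> 1));
  // Prints "sdiv -4  udiv 124" and "ashr -4  lshr 124": once the sign bit can
  // be set, the signed and unsigned operations diverge, so the rewrite is only
  // sound when both operands are known non-negative.
  return 0;
}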
+ assert(NewInst && "Expected replacement instruction"); + NewInst->takeName(&Inst); + InsertedValues.insert(NewInst); + Inst.replaceAllUsesWith(NewInst); + Solver.removeLatticeValueFor(&Inst); + Inst.eraseFromParent(); + return true; +} + +bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB, + SmallPtrSetImpl<Value *> &InsertedValues, + Statistic &InstRemovedStat, + Statistic &InstReplacedStat) { + bool MadeChanges = false; + for (Instruction &Inst : make_early_inc_range(BB)) { + if (Inst.getType()->isVoidTy()) + continue; + if (tryToReplaceWithConstant(&Inst)) { + if (canRemoveInstruction(&Inst)) + Inst.eraseFromParent(); + + MadeChanges = true; + ++InstRemovedStat; + } else if (replaceSignedInst(*this, InsertedValues, Inst)) { + MadeChanges = true; + ++InstReplacedStat; + } + } + return MadeChanges; +} + +bool SCCPSolver::removeNonFeasibleEdges(BasicBlock *BB, DomTreeUpdater &DTU, + BasicBlock *&NewUnreachableBB) const { + SmallPtrSet<BasicBlock *, 8> FeasibleSuccessors; + bool HasNonFeasibleEdges = false; + for (BasicBlock *Succ : successors(BB)) { + if (isEdgeFeasible(BB, Succ)) + FeasibleSuccessors.insert(Succ); + else + HasNonFeasibleEdges = true; + } + + // All edges feasible, nothing to do. + if (!HasNonFeasibleEdges) + return false; + + // SCCP can only determine non-feasible edges for br, switch and indirectbr. + Instruction *TI = BB->getTerminator(); + assert((isa<BranchInst>(TI) || isa<SwitchInst>(TI) || + isa<IndirectBrInst>(TI)) && + "Terminator must be a br, switch or indirectbr"); + + if (FeasibleSuccessors.size() == 0) { + // Branch on undef/poison, replace with unreachable. + SmallPtrSet<BasicBlock *, 8> SeenSuccs; + SmallVector<DominatorTree::UpdateType, 8> Updates; + for (BasicBlock *Succ : successors(BB)) { + Succ->removePredecessor(BB); + if (SeenSuccs.insert(Succ).second) + Updates.push_back({DominatorTree::Delete, BB, Succ}); + } + TI->eraseFromParent(); + new UnreachableInst(BB->getContext(), BB); + DTU.applyUpdatesPermissive(Updates); + } else if (FeasibleSuccessors.size() == 1) { + // Replace with an unconditional branch to the only feasible successor. + BasicBlock *OnlyFeasibleSuccessor = *FeasibleSuccessors.begin(); + SmallVector<DominatorTree::UpdateType, 8> Updates; + bool HaveSeenOnlyFeasibleSuccessor = false; + for (BasicBlock *Succ : successors(BB)) { + if (Succ == OnlyFeasibleSuccessor && !HaveSeenOnlyFeasibleSuccessor) { + // Don't remove the edge to the only feasible successor the first time + // we see it. We still do need to remove any multi-edges to it though. + HaveSeenOnlyFeasibleSuccessor = true; + continue; + } + + Succ->removePredecessor(BB); + Updates.push_back({DominatorTree::Delete, BB, Succ}); + } + + BranchInst::Create(OnlyFeasibleSuccessor, BB); + TI->eraseFromParent(); + DTU.applyUpdatesPermissive(Updates); + } else if (FeasibleSuccessors.size() > 1) { + SwitchInstProfUpdateWrapper SI(*cast<SwitchInst>(TI)); + SmallVector<DominatorTree::UpdateType, 8> Updates; + + // If the default destination is unfeasible it will never be taken. Replace + // it with a new block with a single Unreachable instruction. 
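The NewUnreachableBB out-parameter lets a caller share one "default.unreachable" block across every switch it rewrites in a function. A sketch of the expected call pattern (hedged; the in-tree SCCP-based passes do essentially this, plus statistics and cleanup of non-executable blocks):

#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
using namespace llvm;

static bool pruneInfeasibleEdges(Function &F, SCCPSolver &Solver,
                                 DomTreeUpdater &DTU) {
  bool Changed = false;
  // One unreachable default block is reused by all rewritten switches in F.
  BasicBlock *NewUnreachableBB = nullptr;
  for (BasicBlock &BB : F)
    if (Solver.isBlockExecutable(&BB))
      Changed |= Solver.removeNonFeasibleEdges(&BB, DTU, NewUnreachableBB);
  return Changed;
}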
+ BasicBlock *DefaultDest = SI->getDefaultDest(); + if (!FeasibleSuccessors.contains(DefaultDest)) { + if (!NewUnreachableBB) { + NewUnreachableBB = + BasicBlock::Create(DefaultDest->getContext(), "default.unreachable", + DefaultDest->getParent(), DefaultDest); + new UnreachableInst(DefaultDest->getContext(), NewUnreachableBB); + } + + SI->setDefaultDest(NewUnreachableBB); + Updates.push_back({DominatorTree::Delete, BB, DefaultDest}); + Updates.push_back({DominatorTree::Insert, BB, NewUnreachableBB}); + } + + for (auto CI = SI->case_begin(); CI != SI->case_end();) { + if (FeasibleSuccessors.contains(CI->getCaseSuccessor())) { + ++CI; + continue; + } + + BasicBlock *Succ = CI->getCaseSuccessor(); + Succ->removePredecessor(BB); + Updates.push_back({DominatorTree::Delete, BB, Succ}); + SI.removeCase(CI); + // Don't increment CI, as we removed a case. + } + + DTU.applyUpdatesPermissive(Updates); + } else { + llvm_unreachable("Must have at least one feasible successor"); + } + return true; +} /// Helper class for SCCPSolver. This implements the instruction visitor and /// holds all the state. @@ -270,6 +501,8 @@ private: void handleCallOverdefined(CallBase &CB); void handleCallResult(CallBase &CB); void handleCallArguments(CallBase &CB); + void handleExtractOfWithOverflow(ExtractValueInst &EVI, + const WithOverflowInst *WO, unsigned Idx); private: friend class InstVisitor<SCCPInstVisitor>; @@ -339,6 +572,13 @@ public: return A->second.PredInfo->getPredicateInfoFor(I); } + const LoopInfo &getLoopInfo(Function &F) { + auto A = AnalysisResults.find(&F); + assert(A != AnalysisResults.end() && A->second.LI && + "Need LoopInfo analysis results for function."); + return *A->second.LI; + } + DomTreeUpdater getDTU(Function &F) { auto A = AnalysisResults.find(&F); assert(A != AnalysisResults.end() && "Need analysis results for function."); @@ -442,6 +682,7 @@ public: bool isStructLatticeConstant(Function *F, StructType *STy); Constant *getConstant(const ValueLatticeElement &LV) const; + ConstantRange getConstantRange(const ValueLatticeElement &LV, Type *Ty) const; SmallPtrSetImpl<Function *> &getArgumentTrackedFunctions() { return TrackingIncomingArguments; @@ -454,6 +695,26 @@ public: for (auto &BB : *F) BBExecutable.erase(&BB); } + + void solveWhileResolvedUndefsIn(Module &M) { + bool ResolvedUndefs = true; + while (ResolvedUndefs) { + solve(); + ResolvedUndefs = false; + for (Function &F : M) + ResolvedUndefs |= resolvedUndefsIn(F); + } + } + + void solveWhileResolvedUndefsIn(SmallVectorImpl<Function *> &WorkList) { + bool ResolvedUndefs = true; + while (ResolvedUndefs) { + solve(); + ResolvedUndefs = false; + for (Function *F : WorkList) + ResolvedUndefs |= resolvedUndefsIn(*F); + } + } }; } // namespace llvm @@ -504,7 +765,7 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) { const auto &It = TrackedMultipleRetVals.find(std::make_pair(F, i)); assert(It != TrackedMultipleRetVals.end()); ValueLatticeElement LV = It->second; - if (!isConstant(LV)) + if (!SCCPSolver::isConstant(LV)) return false; } return true; @@ -522,6 +783,15 @@ Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const { return nullptr; } +ConstantRange +SCCPInstVisitor::getConstantRange(const ValueLatticeElement &LV, + Type *Ty) const { + assert(Ty->isIntOrIntVectorTy() && "Should be int or int vector"); + if (LV.isConstantRange()) + return LV.getConstantRange(); + return ConstantRange::getFull(Ty->getScalarSizeInBits()); +} + void SCCPInstVisitor::markArgInFuncSpecialization( Function 
*F, const SmallVectorImpl<ArgInfo> &Args) { assert(!Args.empty() && "Specialization without arguments"); @@ -820,13 +1090,10 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { // Fold the constant as we build. Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL); markConstant(&I, C); - } else if (I.getDestTy()->isIntegerTy()) { + } else if (I.getDestTy()->isIntegerTy() && + I.getSrcTy()->isIntOrIntVectorTy()) { auto &LV = getValueState(&I); - ConstantRange OpRange = - OpSt.isConstantRange() - ? OpSt.getConstantRange() - : ConstantRange::getFull( - I.getOperand(0)->getType()->getScalarSizeInBits()); + ConstantRange OpRange = getConstantRange(OpSt, I.getSrcTy()); Type *DestTy = I.getDestTy(); // Vectors where all elements have the same known constant range are treated @@ -846,6 +1113,33 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { markOverdefined(&I); } +void SCCPInstVisitor::handleExtractOfWithOverflow(ExtractValueInst &EVI, + const WithOverflowInst *WO, + unsigned Idx) { + Value *LHS = WO->getLHS(), *RHS = WO->getRHS(); + ValueLatticeElement L = getValueState(LHS); + ValueLatticeElement R = getValueState(RHS); + addAdditionalUser(LHS, &EVI); + addAdditionalUser(RHS, &EVI); + if (L.isUnknownOrUndef() || R.isUnknownOrUndef()) + return; // Wait to resolve. + + Type *Ty = LHS->getType(); + ConstantRange LR = getConstantRange(L, Ty); + ConstantRange RR = getConstantRange(R, Ty); + if (Idx == 0) { + ConstantRange Res = LR.binaryOp(WO->getBinaryOp(), RR); + mergeInValue(&EVI, ValueLatticeElement::getRange(Res)); + } else { + assert(Idx == 1 && "Index can only be 0 or 1"); + ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + WO->getBinaryOp(), RR, WO->getNoWrapKind()); + if (NWRegion.contains(LR)) + return (void)markConstant(&EVI, ConstantInt::getFalse(EVI.getType())); + markOverdefined(&EVI); + } +} + void SCCPInstVisitor::visitExtractValueInst(ExtractValueInst &EVI) { // If this returns a struct, mark all elements over defined, we don't track // structs in structs. @@ -864,6 +1158,8 @@ void SCCPInstVisitor::visitExtractValueInst(ExtractValueInst &EVI) { Value *AggVal = EVI.getAggregateOperand(); if (AggVal->getType()->isStructTy()) { unsigned i = *EVI.idx_begin(); + if (auto *WO = dyn_cast<WithOverflowInst>(AggVal)) + return handleExtractOfWithOverflow(EVI, WO, i); ValueLatticeElement EltVal = getStructValueState(AggVal, i); mergeInValue(getValueState(&EVI), &EVI, EltVal); } else { @@ -879,7 +1175,7 @@ void SCCPInstVisitor::visitInsertValueInst(InsertValueInst &IVI) { // resolvedUndefsIn might mark I as overdefined. Bail out, even if we would // discover a concrete value later. - if (isOverdefined(ValueState[&IVI])) + if (SCCPSolver::isOverdefined(ValueState[&IVI])) return (void)markOverdefined(&IVI); // If this has more than one index, we can't handle it, drive all results to @@ -950,14 +1246,14 @@ void SCCPInstVisitor::visitUnaryOperator(Instruction &I) { ValueLatticeElement &IV = ValueState[&I]; // resolvedUndefsIn might mark I as overdefined. Bail out, even if we would // discover a concrete value later. - if (isOverdefined(IV)) + if (SCCPSolver::isOverdefined(IV)) return (void)markOverdefined(&I); // If something is unknown/undef, wait for it to resolve. 
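To make the handleExtractOfWithOverflow logic added above concrete, here is a small self-contained example of the same ConstantRange reasoning for an i8 uadd.with.overflow whose operands are known to lie in [0, 10) and [0, 5); the overflow bit folds to false because no pair of such values can wrap:

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Operator.h"
using namespace llvm;

static bool addCannotOverflow() {
  ConstantRange LHS(APInt(8, 0), APInt(8, 10)); // i8 values in [0, 10)
  ConstantRange RHS(APInt(8, 0), APInt(8, 5));  // i8 values in [0, 5)
  ConstantRange NoWrap = ConstantRange::makeGuaranteedNoWrapRegion(
      Instruction::Add, RHS, OverflowingBinaryOperator::NoUnsignedWrap);
  // If every possible LHS lies in the no-wrap region for this RHS range, the
  // overflow flag (extractvalue index 1) is a constant false.
  return NoWrap.contains(LHS);
}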
if (V0State.isUnknownOrUndef()) return; - if (isConstant(V0State)) + if (SCCPSolver::isConstant(V0State)) if (Constant *C = ConstantFoldUnaryOpOperand(I.getOpcode(), getConstant(V0State), DL)) return (void)markConstant(IV, &I, C); @@ -984,8 +1280,10 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { // If either of the operands is a constant, try to fold it to a constant. // TODO: Use information from notconstant better. if ((V1State.isConstant() || V2State.isConstant())) { - Value *V1 = isConstant(V1State) ? getConstant(V1State) : I.getOperand(0); - Value *V2 = isConstant(V2State) ? getConstant(V2State) : I.getOperand(1); + Value *V1 = SCCPSolver::isConstant(V1State) ? getConstant(V1State) + : I.getOperand(0); + Value *V2 = SCCPSolver::isConstant(V2State) ? getConstant(V2State) + : I.getOperand(1); Value *R = simplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL)); auto *C = dyn_cast_or_null<Constant>(R); if (C) { @@ -1005,13 +1303,8 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { return markOverdefined(&I); // Try to simplify to a constant range. - ConstantRange A = ConstantRange::getFull(I.getType()->getScalarSizeInBits()); - ConstantRange B = ConstantRange::getFull(I.getType()->getScalarSizeInBits()); - if (V1State.isConstantRange()) - A = V1State.getConstantRange(); - if (V2State.isConstantRange()) - B = V2State.getConstantRange(); - + ConstantRange A = getConstantRange(V1State, I.getType()); + ConstantRange B = getConstantRange(V2State, I.getType()); ConstantRange R = A.binaryOp(cast<BinaryOperator>(&I)->getOpcode(), B); mergeInValue(&I, ValueLatticeElement::getRange(R)); @@ -1024,7 +1317,7 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { void SCCPInstVisitor::visitCmpInst(CmpInst &I) { // Do not cache this lookup, getValueState calls later in the function might // invalidate the reference. - if (isOverdefined(ValueState[&I])) + if (SCCPSolver::isOverdefined(ValueState[&I])) return (void)markOverdefined(&I); Value *Op1 = I.getOperand(0); @@ -1035,11 +1328,8 @@ void SCCPInstVisitor::visitCmpInst(CmpInst &I) { auto V1State = getValueState(Op1); auto V2State = getValueState(Op2); - Constant *C = V1State.getCompare(I.getPredicate(), I.getType(), V2State); + Constant *C = V1State.getCompare(I.getPredicate(), I.getType(), V2State, DL); if (C) { - // TODO: getCompare() currently has incorrect handling for unknown/undef. - if (isa<UndefValue>(C)) - return; ValueLatticeElement CV; CV.markConstant(C); mergeInValue(&I, CV); @@ -1048,7 +1338,7 @@ void SCCPInstVisitor::visitCmpInst(CmpInst &I) { // If operands are still unknown, wait for it to resolve. if ((V1State.isUnknownOrUndef() || V2State.isUnknownOrUndef()) && - !isConstant(ValueState[&I])) + !SCCPSolver::isConstant(ValueState[&I])) return; markOverdefined(&I); @@ -1057,7 +1347,7 @@ void SCCPInstVisitor::visitCmpInst(CmpInst &I) { // Handle getelementptr instructions. If all operands are constants then we // can turn this into a getelementptr ConstantExpr. void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { - if (isOverdefined(ValueState[&I])) + if (SCCPSolver::isOverdefined(ValueState[&I])) return (void)markOverdefined(&I); SmallVector<Constant *, 8> Operands; @@ -1068,7 +1358,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { if (State.isUnknownOrUndef()) return; // Operands are not resolved yet. 
- if (isOverdefined(State)) + if (SCCPSolver::isOverdefined(State)) return (void)markOverdefined(&I); if (Constant *C = getConstant(State)) { @@ -1080,7 +1370,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) { } Constant *Ptr = Operands[0]; - auto Indices = makeArrayRef(Operands.begin() + 1, Operands.end()); + auto Indices = ArrayRef(Operands.begin() + 1, Operands.end()); Constant *C = ConstantExpr::getGetElementPtr(I.getSourceElementType(), Ptr, Indices); markConstant(&I, C); @@ -1136,7 +1426,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) { ValueLatticeElement &IV = ValueState[&I]; - if (isConstant(PtrVal)) { + if (SCCPSolver::isConstant(PtrVal)) { Constant *Ptr = getConstant(PtrVal); // load null is undefined. @@ -1191,17 +1481,19 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) { for (const Use &A : CB.args()) { if (A.get()->getType()->isStructTy()) return markOverdefined(&CB); // Can't handle struct args. + if (A.get()->getType()->isMetadataTy()) + continue; // Carried in CB, not allowed in Operands. ValueLatticeElement State = getValueState(A); if (State.isUnknownOrUndef()) return; // Operands are not resolved yet. - if (isOverdefined(State)) + if (SCCPSolver::isOverdefined(State)) return (void)markOverdefined(&CB); - assert(isConstant(State) && "Unknown state!"); + assert(SCCPSolver::isConstant(State) && "Unknown state!"); Operands.push_back(getConstant(State)); } - if (isOverdefined(getValueState(&CB))) + if (SCCPSolver::isOverdefined(getValueState(&CB))) return (void)markOverdefined(&CB); // If we can constant fold this, mark the result of the call as a @@ -1219,8 +1511,7 @@ void SCCPInstVisitor::handleCallArguments(CallBase &CB) { // If this is a local function that doesn't have its address taken, mark its // entry block executable and merge in the actual arguments to the call into // the formal arguments of the function. - if (!TrackingIncomingArguments.empty() && - TrackingIncomingArguments.count(F)) { + if (TrackingIncomingArguments.count(F)) { markBlockExecutable(&F->front()); // Propagate information from this call site into the callee. @@ -1259,7 +1550,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { const auto *PI = getPredicateInfoFor(&CB); assert(PI && "Missing predicate info for ssa.copy"); - const Optional<PredicateConstraint> &Constraint = PI->getConstraint(); + const std::optional<PredicateConstraint> &Constraint = + PI->getConstraint(); if (!Constraint) { mergeInValue(ValueState[&CB], &CB, CopyOfVal); return; @@ -1287,10 +1579,7 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { // Combine range info for the original value with the new range from the // condition. - auto CopyOfCR = CopyOfVal.isConstantRange() - ? CopyOfVal.getConstantRange() - : ConstantRange::getFull( - DL.getTypeSizeInBits(CopyOf->getType())); + auto CopyOfCR = getConstantRange(CopyOfVal, CopyOf->getType()); auto NewCR = ImposedCR.intersectWith(CopyOfCR); // If the existing information is != x, do not use the information from // a chained predicate, as the != x information is more likely to be @@ -1308,9 +1597,10 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { IV, &CB, ValueLatticeElement::getRange(NewCR, /*MayIncludeUndef*/ false)); return; - } else if (Pred == CmpInst::ICMP_EQ && CondVal.isConstant()) { + } else if (Pred == CmpInst::ICMP_EQ && + (CondVal.isConstant() || CondVal.isNotConstant())) { // For non-integer values or integer constant expressions, only - // propagate equal constants. 
+ // propagate equal constants or not-constants. addAdditionalUser(OtherOp, &CB); mergeInValue(IV, &CB, CondVal); return; @@ -1332,11 +1622,7 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { SmallVector<ConstantRange, 2> OpRanges; for (Value *Op : II->args()) { const ValueLatticeElement &State = getValueState(Op); - if (State.isConstantRange()) - OpRanges.push_back(State.getConstantRange()); - else - OpRanges.push_back( - ConstantRange::getFull(Op->getType()->getScalarSizeInBits())); + OpRanges.push_back(getConstantRange(State, Op->getType())); } ConstantRange Result = @@ -1498,6 +1784,9 @@ bool SCCPInstVisitor::resolvedUndefsIn(Function &F) { } } + LLVM_DEBUG(if (MadeChange) dbgs() + << "\nResolved undefs in " << F.getName() << '\n'); + return MadeChange; } @@ -1525,6 +1814,10 @@ const PredicateBase *SCCPSolver::getPredicateInfoFor(Instruction *I) { return Visitor->getPredicateInfoFor(I); } +const LoopInfo &SCCPSolver::getLoopInfo(Function &F) { + return Visitor->getLoopInfo(F); +} + DomTreeUpdater SCCPSolver::getDTU(Function &F) { return Visitor->getDTU(F); } void SCCPSolver::trackValueOfGlobalVariable(GlobalVariable *GV) { @@ -1557,6 +1850,15 @@ bool SCCPSolver::resolvedUndefsIn(Function &F) { return Visitor->resolvedUndefsIn(F); } +void SCCPSolver::solveWhileResolvedUndefsIn(Module &M) { + Visitor->solveWhileResolvedUndefsIn(M); +} + +void +SCCPSolver::solveWhileResolvedUndefsIn(SmallVectorImpl<Function *> &WorkList) { + Visitor->solveWhileResolvedUndefsIn(WorkList); +} + bool SCCPSolver::isBlockExecutable(BasicBlock *BB) const { return Visitor->isBlockExecutable(BB); } diff --git a/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/llvm/lib/Transforms/Utils/SSAUpdater.cpp index 37019e3bf95b..2520aa5d9db0 100644 --- a/llvm/lib/Transforms/Utils/SSAUpdater.cpp +++ b/llvm/lib/Transforms/Utils/SSAUpdater.cpp @@ -434,7 +434,7 @@ void LoadAndStorePromoter::run(const SmallVectorImpl<Instruction *> &Insts) { replaceLoadWithValue(ALoad, NewVal); // Avoid assertions in unreachable code. 
- if (NewVal == ALoad) NewVal = UndefValue::get(NewVal->getType()); + if (NewVal == ALoad) NewVal = PoisonValue::get(NewVal->getType()); ALoad->replaceAllUsesWith(NewVal); ReplacedLoads[ALoad] = NewVal; } diff --git a/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp b/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp index 7de76b86817b..cad7ff64c01f 100644 --- a/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp +++ b/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp @@ -51,7 +51,7 @@ unsigned SSAUpdaterBulk::AddVariable(StringRef Name, Type *Ty) { void SSAUpdaterBulk::AddAvailableValue(unsigned Var, BasicBlock *BB, Value *V) { assert(Var < Rewrites.size() && "Variable not found!"); LLVM_DEBUG(dbgs() << "SSAUpdater: Var=" << Var - << ": added new available value" << *V << " in " + << ": added new available value " << *V << " in " << BB->getName() << "\n"); Rewrites[Var].Defines[BB] = V; } diff --git a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp index 5e92b9852a9f..691ee00bd831 100644 --- a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp +++ b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp @@ -26,34 +26,42 @@ using namespace llvm; namespace { -static cl::opt<bool> SampleProfileEvenCountDistribution( - "sample-profile-even-count-distribution", cl::init(true), cl::Hidden, - cl::desc("Try to evenly distribute counts when there are multiple equally " +static cl::opt<bool> SampleProfileEvenFlowDistribution( + "sample-profile-even-flow-distribution", cl::init(true), cl::Hidden, + cl::desc("Try to evenly distribute flow when there are multiple equally " "likely options.")); -static cl::opt<unsigned> SampleProfileMaxDfsCalls( - "sample-profile-max-dfs-calls", cl::init(10), cl::Hidden, - cl::desc("Maximum number of dfs iterations for even count distribution.")); +static cl::opt<bool> SampleProfileRebalanceUnknown( + "sample-profile-rebalance-unknown", cl::init(true), cl::Hidden, + cl::desc("Evenly re-distribute flow among unknown subgraphs.")); -static cl::opt<unsigned> SampleProfileProfiCostInc( - "sample-profile-profi-cost-inc", cl::init(10), cl::Hidden, - cl::desc("A cost of increasing a block's count by one.")); +static cl::opt<bool> SampleProfileJoinIslands( + "sample-profile-join-islands", cl::init(true), cl::Hidden, + cl::desc("Join isolated components having positive flow.")); -static cl::opt<unsigned> SampleProfileProfiCostDec( - "sample-profile-profi-cost-dec", cl::init(20), cl::Hidden, - cl::desc("A cost of decreasing a block's count by one.")); +static cl::opt<unsigned> SampleProfileProfiCostBlockInc( + "sample-profile-profi-cost-block-inc", cl::init(10), cl::Hidden, + cl::desc("The cost of increasing a block's count by one.")); -static cl::opt<unsigned> SampleProfileProfiCostIncZero( - "sample-profile-profi-cost-inc-zero", cl::init(11), cl::Hidden, - cl::desc("A cost of increasing a count of zero-weight block by one.")); +static cl::opt<unsigned> SampleProfileProfiCostBlockDec( + "sample-profile-profi-cost-block-dec", cl::init(20), cl::Hidden, + cl::desc("The cost of decreasing a block's count by one.")); -static cl::opt<unsigned> SampleProfileProfiCostIncEntry( - "sample-profile-profi-cost-inc-entry", cl::init(40), cl::Hidden, - cl::desc("A cost of increasing the entry block's count by one.")); +static cl::opt<unsigned> SampleProfileProfiCostBlockEntryInc( + "sample-profile-profi-cost-block-entry-inc", cl::init(40), cl::Hidden, + cl::desc("The cost of increasing the entry block's count by one.")); -static cl::opt<unsigned> 
SampleProfileProfiCostDecEntry( - "sample-profile-profi-cost-dec-entry", cl::init(10), cl::Hidden, - cl::desc("A cost of decreasing the entry block's count by one.")); +static cl::opt<unsigned> SampleProfileProfiCostBlockEntryDec( + "sample-profile-profi-cost-block-entry-dec", cl::init(10), cl::Hidden, + cl::desc("The cost of decreasing the entry block's count by one.")); + +static cl::opt<unsigned> SampleProfileProfiCostBlockZeroInc( + "sample-profile-profi-cost-block-zero-inc", cl::init(11), cl::Hidden, + cl::desc("The cost of increasing a count of zero-weight block by one.")); + +static cl::opt<unsigned> SampleProfileProfiCostBlockUnknownInc( + "sample-profile-profi-cost-block-unknown-inc", cl::init(0), cl::Hidden, + cl::desc("The cost of increasing an unknown block's count by one.")); /// A value indicating an infinite flow/capacity/weight of a block/edge. /// Not using numeric_limits<int64_t>::max(), as the values can be summed up @@ -76,6 +84,8 @@ static constexpr int64_t INF = ((int64_t)1) << 50; /// minimum total cost respecting the given edge capacities. class MinCostMaxFlow { public: + MinCostMaxFlow(const ProfiParams &Params) : Params(Params) {} + // Initialize algorithm's data structures for a network of a given size. void initialize(uint64_t NodeCount, uint64_t SourceNode, uint64_t SinkNode) { Source = SourceNode; @@ -83,13 +93,15 @@ public: Nodes = std::vector<Node>(NodeCount); Edges = std::vector<std::vector<Edge>>(NodeCount, std::vector<Edge>()); - if (SampleProfileEvenCountDistribution) + if (Params.EvenFlowDistribution) AugmentingEdges = std::vector<std::vector<Edge *>>(NodeCount, std::vector<Edge *>()); } // Run the algorithm. int64_t run() { + LLVM_DEBUG(dbgs() << "Starting profi for " << Nodes.size() << " nodes\n"); + // Iteratively find an augmentation path/dag in the network and send the // flow along its edges size_t AugmentationIters = applyFlowAugmentation(); @@ -148,7 +160,7 @@ public: /// Returns a list of pairs (target node, amount of flow to the target). const std::vector<std::pair<uint64_t, int64_t>> getFlow(uint64_t Src) const { std::vector<std::pair<uint64_t, int64_t>> Flow; - for (auto &Edge : Edges[Src]) { + for (const auto &Edge : Edges[Src]) { if (Edge.Flow > 0) Flow.push_back(std::make_pair(Edge.Dst, Edge.Flow)); } @@ -158,7 +170,7 @@ public: /// Get the total flow between a pair of nodes. int64_t getFlow(uint64_t Src, uint64_t Dst) const { int64_t Flow = 0; - for (auto &Edge : Edges[Src]) { + for (const auto &Edge : Edges[Src]) { if (Edge.Dst == Dst) { Flow += Edge.Flow; } @@ -166,11 +178,6 @@ public: return Flow; } - /// A cost of taking an unlikely jump. - static constexpr int64_t AuxCostUnlikely = ((int64_t)1) << 30; - /// Minimum BaseDistance for the jump distance values in island joining. - static constexpr uint64_t MinBaseDistance = 10000; - private: /// Iteratively find an augmentation path/dag in the network and send the /// flow along its edges. The method returns the number of applied iterations. 
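A toy illustration of the file-local MinCostMaxFlow interface as it is used further down. This is only a sketch placed conceptually inside this file: the three-argument addEdge overload is assumed to be the unlimited-capacity form, matching how initializeNetwork uses it, and the expected flow amount is an assumption about the solver's behavior rather than a documented guarantee.

static void toyNetwork(const ProfiParams &Params) {
  // Nodes 0 and 1 model two "blocks"; 2 and 3 are the dummy source and sink.
  MinCostMaxFlow Network(Params);
  Network.initialize(/*NodeCount=*/4, /*SourceNode=*/2, /*SinkNode=*/3);
  Network.addEdge(/*Src=*/2, /*Dst=*/0, /*Capacity=*/5, /*Cost=*/0);
  Network.addEdge(/*Src=*/0, /*Dst=*/1, /*Cost=*/1);  // uncapped, cost 1/unit
  Network.addEdge(/*Src=*/1, /*Dst=*/3, /*Capacity=*/5, /*Cost=*/0);
  Network.run();
  int64_t Sent = Network.getFlow(0, 1); // expected to route all 5 units here
  (void)Sent;
}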
@@ -180,7 +187,7 @@ private: uint64_t PathCapacity = computeAugmentingPathCapacity(); while (PathCapacity > 0) { bool Progress = false; - if (SampleProfileEvenCountDistribution) { + if (Params.EvenFlowDistribution) { // Identify node/edge candidates for augmentation identifyShortestEdges(PathCapacity); @@ -253,7 +260,7 @@ private: // from Source to Target; it follows from inequalities // Dist[Source, Target] >= Dist[Source, V] + Dist[V, Target] // >= Dist[Source, V] - if (!SampleProfileEvenCountDistribution && Nodes[Target].Distance == 0) + if (!Params.EvenFlowDistribution && Nodes[Target].Distance == 0) break; if (Nodes[Src].Distance > Nodes[Target].Distance) continue; @@ -342,7 +349,7 @@ private: if (Edge.OnShortestPath) { // If we haven't seen Edge.Dst so far, continue DFS search there - if (Dst.Discovery == 0 && Dst.NumCalls < SampleProfileMaxDfsCalls) { + if (Dst.Discovery == 0 && Dst.NumCalls < MaxDfsCalls) { Dst.Discovery = ++Time; Stack.emplace(Edge.Dst, 0); Dst.NumCalls++; @@ -512,6 +519,9 @@ private: } } + /// Maximum number of DFS iterations for DAG finding. + static constexpr uint64_t MaxDfsCalls = 10; + /// A node in a flow network. struct Node { /// The cost of the cheapest path from the source to the current node. @@ -566,12 +576,11 @@ private: uint64_t Target; /// Augmenting edges. std::vector<std::vector<Edge *>> AugmentingEdges; + /// Params for flow computation. + const ProfiParams &Params; }; -constexpr int64_t MinCostMaxFlow::AuxCostUnlikely; -constexpr uint64_t MinCostMaxFlow::MinBaseDistance; - -/// A post-processing adjustment of control flow. It applies two steps by +/// A post-processing adjustment of the control flow. It applies two steps by /// rerouting some flow and making it more realistic: /// /// - First, it removes all isolated components ("islands") with a positive flow @@ -589,18 +598,20 @@ constexpr uint64_t MinCostMaxFlow::MinBaseDistance; /// class FlowAdjuster { public: - FlowAdjuster(FlowFunction &Func) : Func(Func) { - assert(Func.Blocks[Func.Entry].isEntry() && - "incorrect index of the entry block"); - } + FlowAdjuster(const ProfiParams &Params, FlowFunction &Func) + : Params(Params), Func(Func) {} - // Run the post-processing + /// Apply the post-processing. void run() { - /// Adjust the flow to get rid of isolated components. - joinIsolatedComponents(); + if (Params.JoinIslands) { + // Adjust the flow to get rid of isolated components + joinIsolatedComponents(); + } - /// Rebalance the flow inside unknown subgraphs. - rebalanceUnknownSubgraphs(); + if (Params.RebalanceUnknown) { + // Rebalance the flow inside unknown subgraphs + rebalanceUnknownSubgraphs(); + } } private: @@ -640,7 +651,7 @@ private: while (!Queue.empty()) { Src = Queue.front(); Queue.pop(); - for (auto Jump : Func.Blocks[Src].SuccJumps) { + for (auto *Jump : Func.Blocks[Src].SuccJumps) { uint64_t Dst = Jump->Target; if (Jump->Flow > 0 && !Visited[Dst]) { Queue.push(Dst); @@ -691,7 +702,7 @@ private: (Func.Blocks[Src].isExit() && Target == AnyExitBlock)) break; - for (auto Jump : Func.Blocks[Src].SuccJumps) { + for (auto *Jump : Func.Blocks[Src].SuccJumps) { uint64_t Dst = Jump->Target; int64_t JumpDist = jumpDistance(Jump); if (Distance[Dst] > Distance[Src] + JumpDist) { @@ -739,15 +750,15 @@ private: /// To capture this objective with integer distances, we round off fractional /// parts to a multiple of 1 / BaseDistance. 
int64_t jumpDistance(FlowJump *Jump) const { + if (Jump->IsUnlikely) + return Params.CostUnlikely; uint64_t BaseDistance = - std::max(static_cast<uint64_t>(MinCostMaxFlow::MinBaseDistance), + std::max(FlowAdjuster::MinBaseDistance, std::min(Func.Blocks[Func.Entry].Flow, - MinCostMaxFlow::AuxCostUnlikely / NumBlocks())); - if (Jump->IsUnlikely) - return MinCostMaxFlow::AuxCostUnlikely; + Params.CostUnlikely / (2 * (NumBlocks() + 1)))); if (Jump->Flow > 0) return BaseDistance + BaseDistance / Jump->Flow; - return BaseDistance * NumBlocks(); + return 2 * BaseDistance * (NumBlocks() + 1); }; uint64_t NumBlocks() const { return Func.Blocks.size(); } @@ -758,31 +769,30 @@ private: /// blocks. Then it verifies if flow rebalancing is feasible and applies it. void rebalanceUnknownSubgraphs() { // Try to find unknown subgraphs from each block - for (uint64_t I = 0; I < Func.Blocks.size(); I++) { - auto SrcBlock = &Func.Blocks[I]; + for (const FlowBlock &SrcBlock : Func.Blocks) { // Verify if rebalancing rooted at SrcBlock is feasible - if (!canRebalanceAtRoot(SrcBlock)) + if (!canRebalanceAtRoot(&SrcBlock)) continue; // Find an unknown subgraphs starting at SrcBlock. Along the way, // fill in known destinations and intermediate unknown blocks. std::vector<FlowBlock *> UnknownBlocks; std::vector<FlowBlock *> KnownDstBlocks; - findUnknownSubgraph(SrcBlock, KnownDstBlocks, UnknownBlocks); + findUnknownSubgraph(&SrcBlock, KnownDstBlocks, UnknownBlocks); // Verify if rebalancing of the subgraph is feasible. If the search is // successful, find the unique destination block (which can be null) FlowBlock *DstBlock = nullptr; - if (!canRebalanceSubgraph(SrcBlock, KnownDstBlocks, UnknownBlocks, + if (!canRebalanceSubgraph(&SrcBlock, KnownDstBlocks, UnknownBlocks, DstBlock)) continue; // We cannot rebalance subgraphs containing cycles among unknown blocks - if (!isAcyclicSubgraph(SrcBlock, DstBlock, UnknownBlocks)) + if (!isAcyclicSubgraph(&SrcBlock, DstBlock, UnknownBlocks)) continue; // Rebalance the flow - rebalanceUnknownSubgraph(SrcBlock, DstBlock, UnknownBlocks); + rebalanceUnknownSubgraph(&SrcBlock, DstBlock, UnknownBlocks); } } @@ -790,13 +800,13 @@ private: bool canRebalanceAtRoot(const FlowBlock *SrcBlock) { // Do not attempt to find unknown subgraphs from an unknown or a // zero-flow block - if (SrcBlock->UnknownWeight || SrcBlock->Flow == 0) + if (SrcBlock->HasUnknownWeight || SrcBlock->Flow == 0) return false; // Do not attempt to process subgraphs from a block w/o unknown sucessors bool HasUnknownSuccs = false; - for (auto Jump : SrcBlock->SuccJumps) { - if (Func.Blocks[Jump->Target].UnknownWeight) { + for (auto *Jump : SrcBlock->SuccJumps) { + if (Func.Blocks[Jump->Target].HasUnknownWeight) { HasUnknownSuccs = true; break; } @@ -823,7 +833,7 @@ private: auto &Block = Func.Blocks[Queue.front()]; Queue.pop(); // Process blocks reachable from Block - for (auto Jump : Block.SuccJumps) { + for (auto *Jump : Block.SuccJumps) { // If Jump can be ignored, skip it if (ignoreJump(SrcBlock, nullptr, Jump)) continue; @@ -834,7 +844,7 @@ private: continue; // Process block Dst Visited[Dst] = true; - if (!Func.Blocks[Dst].UnknownWeight) { + if (!Func.Blocks[Dst].HasUnknownWeight) { KnownDstBlocks.push_back(&Func.Blocks[Dst]); } else { Queue.push(Dst); @@ -860,7 +870,7 @@ private: DstBlock = KnownDstBlocks.empty() ? 
nullptr : KnownDstBlocks.front(); // Verify sinks of the subgraph - for (auto Block : UnknownBlocks) { + for (auto *Block : UnknownBlocks) { if (Block->SuccJumps.empty()) { // If there are multiple (known and unknown) sinks, we can't rebalance if (DstBlock != nullptr) @@ -868,7 +878,7 @@ private: continue; } size_t NumIgnoredJumps = 0; - for (auto Jump : Block->SuccJumps) { + for (auto *Jump : Block->SuccJumps) { if (ignoreJump(SrcBlock, DstBlock, Jump)) NumIgnoredJumps++; } @@ -897,11 +907,11 @@ private: return false; // Ignore jumps out of SrcBlock to known blocks - if (!JumpTarget->UnknownWeight && JumpSource == SrcBlock) + if (!JumpTarget->HasUnknownWeight && JumpSource == SrcBlock) return true; // Ignore jumps to known blocks with zero flow - if (!JumpTarget->UnknownWeight && JumpTarget->Flow == 0) + if (!JumpTarget->HasUnknownWeight && JumpTarget->Flow == 0) return true; return false; @@ -914,14 +924,14 @@ private: // Extract local in-degrees in the considered subgraph auto LocalInDegree = std::vector<uint64_t>(NumBlocks(), 0); auto fillInDegree = [&](const FlowBlock *Block) { - for (auto Jump : Block->SuccJumps) { + for (auto *Jump : Block->SuccJumps) { if (ignoreJump(SrcBlock, DstBlock, Jump)) continue; LocalInDegree[Jump->Target]++; } }; fillInDegree(SrcBlock); - for (auto Block : UnknownBlocks) { + for (auto *Block : UnknownBlocks) { fillInDegree(Block); } // A loop containing SrcBlock @@ -939,11 +949,11 @@ private: break; // Keep an acyclic order of unknown blocks - if (Block->UnknownWeight && Block != SrcBlock) + if (Block->HasUnknownWeight && Block != SrcBlock) AcyclicOrder.push_back(Block); // Add to the queue all successors with zero local in-degree - for (auto Jump : Block->SuccJumps) { + for (auto *Jump : Block->SuccJumps) { if (ignoreJump(SrcBlock, DstBlock, Jump)) continue; uint64_t Dst = Jump->Target; @@ -972,7 +982,7 @@ private: // Ditribute flow from the source block uint64_t BlockFlow = 0; // SrcBlock's flow is the sum of outgoing flows along non-ignored jumps - for (auto Jump : SrcBlock->SuccJumps) { + for (auto *Jump : SrcBlock->SuccJumps) { if (ignoreJump(SrcBlock, DstBlock, Jump)) continue; BlockFlow += Jump->Flow; @@ -980,11 +990,11 @@ private: rebalanceBlock(SrcBlock, DstBlock, SrcBlock, BlockFlow); // Ditribute flow from the remaining blocks - for (auto Block : UnknownBlocks) { - assert(Block->UnknownWeight && "incorrect unknown subgraph"); + for (auto *Block : UnknownBlocks) { + assert(Block->HasUnknownWeight && "incorrect unknown subgraph"); uint64_t BlockFlow = 0; // Block's flow is the sum of incoming flows - for (auto Jump : Block->PredJumps) { + for (auto *Jump : Block->PredJumps) { BlockFlow += Jump->Flow; } Block->Flow = BlockFlow; @@ -998,7 +1008,7 @@ private: const FlowBlock *Block, uint64_t BlockFlow) { // Process all successor jumps and update corresponding flow values size_t BlockDegree = 0; - for (auto Jump : Block->SuccJumps) { + for (auto *Jump : Block->SuccJumps) { if (ignoreJump(SrcBlock, DstBlock, Jump)) continue; BlockDegree++; @@ -1011,7 +1021,7 @@ private: // Each of the Block's successors gets the following amount of flow. // Rounding the value up so that all flow is propagated uint64_t SuccFlow = (BlockFlow + BlockDegree - 1) / BlockDegree; - for (auto Jump : Block->SuccJumps) { + for (auto *Jump : Block->SuccJumps) { if (ignoreJump(SrcBlock, DstBlock, Jump)) continue; uint64_t Flow = std::min(SuccFlow, BlockFlow); @@ -1023,104 +1033,88 @@ private: /// A constant indicating an arbitrary exit block of a function. 
static constexpr uint64_t AnyExitBlock = uint64_t(-1); + /// Minimum BaseDistance for the jump distance values in island joining. + static constexpr uint64_t MinBaseDistance = 10000; + /// Params for flow computation. + const ProfiParams &Params; /// The function. FlowFunction &Func; }; +std::pair<int64_t, int64_t> assignBlockCosts(const ProfiParams &Params, + const FlowBlock &Block); +std::pair<int64_t, int64_t> assignJumpCosts(const ProfiParams &Params, + const FlowJump &Jump); + /// Initializing flow network for a given function. /// -/// Every block is split into three nodes that are responsible for (i) an -/// incoming flow, (ii) an outgoing flow, and (iii) penalizing an increase or +/// Every block is split into two nodes that are responsible for (i) an +/// incoming flow, (ii) an outgoing flow; they penalize an increase or a /// reduction of the block weight. -void initializeNetwork(MinCostMaxFlow &Network, FlowFunction &Func) { +void initializeNetwork(const ProfiParams &Params, MinCostMaxFlow &Network, + FlowFunction &Func) { uint64_t NumBlocks = Func.Blocks.size(); assert(NumBlocks > 1 && "Too few blocks in a function"); - LLVM_DEBUG(dbgs() << "Initializing profi for " << NumBlocks << " blocks\n"); + uint64_t NumJumps = Func.Jumps.size(); + assert(NumJumps > 0 && "Too few jumps in a function"); - // Pre-process data: make sure the entry weight is at least 1 - if (Func.Blocks[Func.Entry].Weight == 0) { - Func.Blocks[Func.Entry].Weight = 1; - } // Introducing dummy source/sink pairs to allow flow circulation. - // The nodes corresponding to blocks of Func have indicies in the range - // [0..3 * NumBlocks); the dummy nodes are indexed by the next four values. - uint64_t S = 3 * NumBlocks; + // The nodes corresponding to blocks of the function have indicies in + // the range [0 .. 2 * NumBlocks); the dummy sources/sinks are indexed by the + // next four values. + uint64_t S = 2 * NumBlocks; uint64_t T = S + 1; uint64_t S1 = S + 2; uint64_t T1 = S + 3; - Network.initialize(3 * NumBlocks + 4, S1, T1); + Network.initialize(2 * NumBlocks + 4, S1, T1); - // Create three nodes for every block of the function + // Initialize nodes of the flow network for (uint64_t B = 0; B < NumBlocks; B++) { auto &Block = Func.Blocks[B]; - assert((!Block.UnknownWeight || Block.Weight == 0 || Block.isEntry()) && - "non-zero weight of a block w/o weight except for an entry"); - // Split every block into two nodes - uint64_t Bin = 3 * B; - uint64_t Bout = 3 * B + 1; - uint64_t Baux = 3 * B + 2; - if (Block.Weight > 0) { - Network.addEdge(S1, Bout, Block.Weight, 0); - Network.addEdge(Bin, T1, Block.Weight, 0); - } + // Split every block into two auxiliary nodes to allow + // increase/reduction of the block count. + uint64_t Bin = 2 * B; + uint64_t Bout = 2 * B + 1; // Edges from S and to T - assert((!Block.isEntry() || !Block.isExit()) && - "a block cannot be an entry and an exit"); if (Block.isEntry()) { Network.addEdge(S, Bin, 0); } else if (Block.isExit()) { Network.addEdge(Bout, T, 0); } - // An auxiliary node to allow increase/reduction of block counts: - // We assume that decreasing block counts is more expensive than increasing, - // and thus, setting separate costs here. In the future we may want to tune - // the relative costs so as to maximize the quality of generated profiles. 
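
The rewritten initializeNetwork above packs every basic block into exactly two flow-network nodes and appends the dummy source/sink nodes after all block nodes. The stand-alone sketch below only illustrates that indexing arithmetic under the stated assumptions; the helper names blockIn/blockOut and the main driver are invented for this example and are not part of the LLVM sources.

#include <cassert>
#include <cstdint>

// Hypothetical helpers mirroring the indexing convention described above:
// block B owns nodes 2*B (incoming half) and 2*B+1 (outgoing half); the four
// dummy source/sink nodes are appended after all block nodes.
static uint64_t blockIn(uint64_t B) { return 2 * B; }
static uint64_t blockOut(uint64_t B) { return 2 * B + 1; }

int main() {
  const uint64_t NumBlocks = 5;
  const uint64_t S = 2 * NumBlocks; // dummy source
  const uint64_t T = S + 1;         // dummy sink
  const uint64_t S1 = S + 2;        // second dummy source
  const uint64_t T1 = S + 3;        // second dummy sink
  // A jump from block 1 to block 3 connects blockOut(1) to blockIn(3).
  assert(blockOut(1) == 3 && blockIn(3) == 6);
  assert(S == 10 && T == 11 && S1 == 12 && T1 == 13);
  return 0;
}
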
- int64_t AuxCostInc = SampleProfileProfiCostInc; - int64_t AuxCostDec = SampleProfileProfiCostDec; - if (Block.UnknownWeight) { - // Do not penalize changing weights of blocks w/o known profile count - AuxCostInc = 0; - AuxCostDec = 0; - } else { - // Increasing the count for "cold" blocks with zero initial count is more - // expensive than for "hot" ones - if (Block.Weight == 0) { - AuxCostInc = SampleProfileProfiCostIncZero; - } - // Modifying the count of the entry block is expensive - if (Block.isEntry()) { - AuxCostInc = SampleProfileProfiCostIncEntry; - AuxCostDec = SampleProfileProfiCostDecEntry; - } - } - // For blocks with self-edges, do not penalize a reduction of the count, - // as all of the increase can be attributed to the self-edge - if (Block.HasSelfEdge) { - AuxCostDec = 0; - } + // Assign costs for increasing/decreasing the block counts + auto [AuxCostInc, AuxCostDec] = assignBlockCosts(Params, Block); - Network.addEdge(Bin, Baux, AuxCostInc); - Network.addEdge(Baux, Bout, AuxCostInc); + // Add the corresponding edges to the network + Network.addEdge(Bin, Bout, AuxCostInc); if (Block.Weight > 0) { - Network.addEdge(Bout, Baux, AuxCostDec); - Network.addEdge(Baux, Bin, AuxCostDec); + Network.addEdge(Bout, Bin, Block.Weight, AuxCostDec); + Network.addEdge(S1, Bout, Block.Weight, 0); + Network.addEdge(Bin, T1, Block.Weight, 0); } } - // Creating edges for every jump - for (auto &Jump : Func.Jumps) { - uint64_t Src = Jump.Source; - uint64_t Dst = Jump.Target; - if (Src != Dst) { - uint64_t SrcOut = 3 * Src + 1; - uint64_t DstIn = 3 * Dst; - uint64_t Cost = Jump.IsUnlikely ? MinCostMaxFlow::AuxCostUnlikely : 0; - Network.addEdge(SrcOut, DstIn, Cost); + // Initialize edges of the flow network + for (uint64_t J = 0; J < NumJumps; J++) { + auto &Jump = Func.Jumps[J]; + + // Get the endpoints corresponding to the jump + uint64_t Jin = 2 * Jump.Source + 1; + uint64_t Jout = 2 * Jump.Target; + + // Assign costs for increasing/decreasing the jump counts + auto [AuxCostInc, AuxCostDec] = assignJumpCosts(Params, Jump); + + // Add the corresponding edges to the network + Network.addEdge(Jin, Jout, AuxCostInc); + if (Jump.Weight > 0) { + Network.addEdge(Jout, Jin, Jump.Weight, AuxCostDec); + Network.addEdge(S1, Jout, Jump.Weight, 0); + Network.addEdge(Jin, T1, Jump.Weight, 0); } } @@ -1128,55 +1122,130 @@ void initializeNetwork(MinCostMaxFlow &Network, FlowFunction &Func) { Network.addEdge(T, S, 0); } -/// Extract resulting block and edge counts from the flow network. -void extractWeights(MinCostMaxFlow &Network, FlowFunction &Func) { - uint64_t NumBlocks = Func.Blocks.size(); - - // Extract resulting block counts - for (uint64_t Src = 0; Src < NumBlocks; Src++) { - auto &Block = Func.Blocks[Src]; - uint64_t SrcOut = 3 * Src + 1; - int64_t Flow = 0; - for (auto &Adj : Network.getFlow(SrcOut)) { - uint64_t DstIn = Adj.first; - int64_t DstFlow = Adj.second; - bool IsAuxNode = (DstIn < 3 * NumBlocks && DstIn % 3 == 2); - if (!IsAuxNode || Block.HasSelfEdge) { - Flow += DstFlow; - } +/// Assign costs for increasing/decreasing the block counts. 
+std::pair<int64_t, int64_t> assignBlockCosts(const ProfiParams &Params, + const FlowBlock &Block) { + // Modifying the weight of an unlikely block is expensive + if (Block.IsUnlikely) + return std::make_pair(Params.CostUnlikely, Params.CostUnlikely); + + // Assign default values for the costs + int64_t CostInc = Params.CostBlockInc; + int64_t CostDec = Params.CostBlockDec; + // Update the costs depending on the block metadata + if (Block.HasUnknownWeight) { + CostInc = Params.CostBlockUnknownInc; + CostDec = 0; + } else { + // Increasing the count for "cold" blocks with zero initial count is more + // expensive than for "hot" ones + if (Block.Weight == 0) + CostInc = Params.CostBlockZeroInc; + // Modifying the count of the entry block is expensive + if (Block.isEntry()) { + CostInc = Params.CostBlockEntryInc; + CostDec = Params.CostBlockEntryDec; } - Block.Flow = Flow; - assert(Flow >= 0 && "negative block flow"); } + return std::make_pair(CostInc, CostDec); +} + +/// Assign costs for increasing/decreasing the jump counts. +std::pair<int64_t, int64_t> assignJumpCosts(const ProfiParams &Params, + const FlowJump &Jump) { + // Modifying the weight of an unlikely jump is expensive + if (Jump.IsUnlikely) + return std::make_pair(Params.CostUnlikely, Params.CostUnlikely); + + // Assign default values for the costs + int64_t CostInc = Params.CostJumpInc; + int64_t CostDec = Params.CostJumpDec; + // Update the costs depending on the block metadata + if (Jump.Source + 1 == Jump.Target) { + // Adjusting the fall-through branch + CostInc = Params.CostJumpFTInc; + CostDec = Params.CostJumpFTDec; + } + if (Jump.HasUnknownWeight) { + // The cost is different for fall-through and non-fall-through branches + if (Jump.Source + 1 == Jump.Target) + CostInc = Params.CostJumpUnknownFTInc; + else + CostInc = Params.CostJumpUnknownInc; + CostDec = 0; + } else { + assert(Jump.Weight > 0 && "found zero-weight jump with a positive weight"); + } + return std::make_pair(CostInc, CostDec); +} + +/// Extract resulting block and edge counts from the flow network. +void extractWeights(const ProfiParams &Params, MinCostMaxFlow &Network, + FlowFunction &Func) { + uint64_t NumBlocks = Func.Blocks.size(); + uint64_t NumJumps = Func.Jumps.size(); // Extract resulting jump counts - for (auto &Jump : Func.Jumps) { - uint64_t Src = Jump.Source; - uint64_t Dst = Jump.Target; + for (uint64_t J = 0; J < NumJumps; J++) { + auto &Jump = Func.Jumps[J]; + uint64_t SrcOut = 2 * Jump.Source + 1; + uint64_t DstIn = 2 * Jump.Target; + int64_t Flow = 0; - if (Src != Dst) { - uint64_t SrcOut = 3 * Src + 1; - uint64_t DstIn = 3 * Dst; - Flow = Network.getFlow(SrcOut, DstIn); - } else { - uint64_t SrcOut = 3 * Src + 1; - uint64_t SrcAux = 3 * Src + 2; - int64_t AuxFlow = Network.getFlow(SrcOut, SrcAux); - if (AuxFlow > 0) - Flow = AuxFlow; - } + int64_t AuxFlow = Network.getFlow(SrcOut, DstIn); + if (Jump.Source != Jump.Target) + Flow = int64_t(Jump.Weight) + AuxFlow; + else + Flow = int64_t(Jump.Weight) + (AuxFlow > 0 ? 
AuxFlow : 0); + Jump.Flow = Flow; assert(Flow >= 0 && "negative jump flow"); } + + // Extract resulting block counts + auto InFlow = std::vector<uint64_t>(NumBlocks, 0); + auto OutFlow = std::vector<uint64_t>(NumBlocks, 0); + for (auto &Jump : Func.Jumps) { + InFlow[Jump.Target] += Jump.Flow; + OutFlow[Jump.Source] += Jump.Flow; + } + for (uint64_t B = 0; B < NumBlocks; B++) { + auto &Block = Func.Blocks[B]; + Block.Flow = std::max(OutFlow[B], InFlow[B]); + } } #ifndef NDEBUG -/// Verify that the computed flow values satisfy flow conservation rules -void verifyWeights(const FlowFunction &Func) { +/// Verify that the provided block/jump weights are as expected. +void verifyInput(const FlowFunction &Func) { + // Verify the entry block + assert(Func.Entry == 0 && Func.Blocks[0].isEntry()); + for (size_t I = 1; I < Func.Blocks.size(); I++) { + assert(!Func.Blocks[I].isEntry() && "multiple entry blocks"); + } + // Verify CFG jumps + for (auto &Block : Func.Blocks) { + assert((!Block.isEntry() || !Block.isExit()) && + "a block cannot be an entry and an exit"); + } + // Verify input block weights + for (auto &Block : Func.Blocks) { + assert((!Block.HasUnknownWeight || Block.Weight == 0 || Block.isEntry()) && + "non-zero weight of a block w/o weight except for an entry"); + } + // Verify input jump weights + for (auto &Jump : Func.Jumps) { + assert((!Jump.HasUnknownWeight || Jump.Weight == 0) && + "non-zero weight of a jump w/o weight"); + } +} + +/// Verify that the computed flow values satisfy flow conservation rules. +void verifyOutput(const FlowFunction &Func) { const uint64_t NumBlocks = Func.Blocks.size(); auto InFlow = std::vector<uint64_t>(NumBlocks, 0); auto OutFlow = std::vector<uint64_t>(NumBlocks, 0); - for (auto &Jump : Func.Jumps) { + for (const auto &Jump : Func.Jumps) { InFlow[Jump.Target] += Jump.Flow; OutFlow[Jump.Source] += Jump.Flow; } @@ -1202,7 +1271,7 @@ void verifyWeights(const FlowFunction &Func) { // One could modify FlowFunction to hold edges indexed by the sources, which // will avoid a creation of the object auto PositiveFlowEdges = std::vector<std::vector<uint64_t>>(NumBlocks); - for (auto &Jump : Func.Jumps) { + for (const auto &Jump : Func.Jumps) { if (Jump.Flow > 0) { PositiveFlowEdges[Jump.Source].push_back(Jump.Target); } @@ -1235,22 +1304,44 @@ void verifyWeights(const FlowFunction &Func) { } // end of anonymous namespace -/// Apply the profile inference algorithm for a given flow function -void llvm::applyFlowInference(FlowFunction &Func) { +/// Apply the profile inference algorithm for a given function +void llvm::applyFlowInference(const ProfiParams &Params, FlowFunction &Func) { +#ifndef NDEBUG + // Verify the input data + verifyInput(Func); +#endif + // Create and apply an inference network model - auto InferenceNetwork = MinCostMaxFlow(); - initializeNetwork(InferenceNetwork, Func); + auto InferenceNetwork = MinCostMaxFlow(Params); + initializeNetwork(Params, InferenceNetwork, Func); InferenceNetwork.run(); // Extract flow values for every block and every edge - extractWeights(InferenceNetwork, Func); + extractWeights(Params, InferenceNetwork, Func); // Post-processing adjustments to the flow - auto Adjuster = FlowAdjuster(Func); + auto Adjuster = FlowAdjuster(Params, Func); Adjuster.run(); #ifndef NDEBUG // Verify the result - verifyWeights(Func); + verifyOutput(Func); #endif } + +/// Apply the profile inference algorithm for a given flow function +void llvm::applyFlowInference(FlowFunction &Func) { + ProfiParams Params; + // Set the params from the 
command-line flags. + Params.EvenFlowDistribution = SampleProfileEvenFlowDistribution; + Params.RebalanceUnknown = SampleProfileRebalanceUnknown; + Params.JoinIslands = SampleProfileJoinIslands; + Params.CostBlockInc = SampleProfileProfiCostBlockInc; + Params.CostBlockDec = SampleProfileProfiCostBlockDec; + Params.CostBlockEntryInc = SampleProfileProfiCostBlockEntryInc; + Params.CostBlockEntryDec = SampleProfileProfiCostBlockEntryDec; + Params.CostBlockZeroInc = SampleProfileProfiCostBlockZeroInc; + Params.CostBlockUnknownInc = SampleProfileProfiCostBlockUnknownInc; + + applyFlowInference(Params, Func); +} diff --git a/llvm/lib/Transforms/Utils/SampleProfileLoaderBaseUtil.cpp b/llvm/lib/Transforms/Utils/SampleProfileLoaderBaseUtil.cpp index a2588b8cec7d..f7ae6ad84494 100644 --- a/llvm/lib/Transforms/Utils/SampleProfileLoaderBaseUtil.cpp +++ b/llvm/lib/Transforms/Utils/SampleProfileLoaderBaseUtil.cpp @@ -42,10 +42,6 @@ cl::opt<bool> SampleProfileUseProfi( "sample-profile-use-profi", cl::Hidden, cl::desc("Use profi to infer block and edge counts.")); -cl::opt<bool> SampleProfileInferEntryCount( - "sample-profile-infer-entry-count", cl::init(true), cl::Hidden, - cl::desc("Use profi to infer function entry count.")); - namespace sampleprofutil { /// Return true if the given callsite is hot wrt to hot cutoff threshold. diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 372cd74ea01d..24f1966edd37 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -380,7 +381,7 @@ static void SimplifyAddOperands(SmallVectorImpl<const SCEV *> &Ops, // the sum into a single value, so just use that. Ops.clear(); if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Sum)) - Ops.append(Add->op_begin(), Add->op_end()); + append_range(Ops, Add->operands()); else if (!Sum->isZero()) Ops.push_back(Sum); // Then append the addrecs. @@ -408,7 +409,7 @@ static void SplitAddRecs(SmallVectorImpl<const SCEV *> &Ops, A->getNoWrapFlags(SCEV::FlagNW))); if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Start)) { Ops[i] = Zero; - Ops.append(Add->op_begin(), Add->op_end()); + append_range(Ops, Add->operands()); e += Add->getNumOperands(); } else { Ops[i] = Start; @@ -509,7 +510,7 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, Value *Scaled = ScaledOps.empty() ? Constant::getNullValue(Ty) - : expandCodeForImpl(SE.getAddExpr(ScaledOps), Ty, false); + : expandCodeForImpl(SE.getAddExpr(ScaledOps), Ty); GepIndices.push_back(Scaled); // Collect struct field index operands. @@ -570,13 +571,12 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint())); // Expand the operands for a plain byte offset. - Value *Idx = expandCodeForImpl(SE.getAddExpr(Ops), Ty, false); + Value *Idx = expandCodeForImpl(SE.getAddExpr(Ops), Ty); // Fold a GEP with constant operands. 
if (Constant *CLHS = dyn_cast<Constant>(V)) if (Constant *CRHS = dyn_cast<Constant>(Idx)) - return ConstantExpr::getGetElementPtr(Type::getInt8Ty(Ty->getContext()), - CLHS, CRHS); + return Builder.CreateGEP(Builder.getInt8Ty(), CLHS, CRHS); // Do a quick scan to see if we have this GEP nearby. If so, reuse it. unsigned ScanLimit = 6; @@ -678,31 +678,38 @@ const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) { if (!Pair.second) return Pair.first->second; - if (isa<SCEVConstant>(S)) - // A constant has no relevant loops. - return nullptr; - if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { - if (const Instruction *I = dyn_cast<Instruction>(U->getValue())) - return Pair.first->second = SE.LI.getLoopFor(I->getParent()); - // A non-instruction has no relevant loops. - return nullptr; - } - if (const SCEVNAryExpr *N = dyn_cast<SCEVNAryExpr>(S)) { + switch (S->getSCEVType()) { + case scConstant: + return nullptr; // A constant has no relevant loops. + case scTruncate: + case scZeroExtend: + case scSignExtend: + case scPtrToInt: + case scAddExpr: + case scMulExpr: + case scUDivExpr: + case scAddRecExpr: + case scUMaxExpr: + case scSMaxExpr: + case scUMinExpr: + case scSMinExpr: + case scSequentialUMinExpr: { const Loop *L = nullptr; if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) L = AR->getLoop(); - for (const SCEV *Op : N->operands()) + for (const SCEV *Op : S->operands()) L = PickMostRelevantLoop(L, getRelevantLoop(Op), SE.DT); - return RelevantLoops[N] = L; + return RelevantLoops[S] = L; } - if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(S)) { - const Loop *Result = getRelevantLoop(C->getOperand()); - return RelevantLoops[C] = Result; + case scUnknown: { + const SCEVUnknown *U = cast<SCEVUnknown>(S); + if (const Instruction *I = dyn_cast<Instruction>(U->getValue())) + return Pair.first->second = SE.LI.getLoopFor(I->getParent()); + // A non-instruction has no relevant loops. + return nullptr; } - if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { - const Loop *Result = PickMostRelevantLoop( - getRelevantLoop(D->getLHS()), getRelevantLoop(D->getRHS()), SE.DT); - return RelevantLoops[D] = Result; + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); } llvm_unreachable("Unexpected SCEV type!"); } @@ -787,14 +794,14 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum); } else if (Op->isNonConstantNegative()) { // Instead of doing a negate and add, just do a subtract. - Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty, false); + Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty); Sum = InsertNoopCastOfTo(Sum, Ty); Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true); ++I; } else { // A simple add. - Value *W = expandCodeForImpl(Op, Ty, false); + Value *W = expandCodeForImpl(Op, Ty); Sum = InsertNoopCastOfTo(Sum, Ty); // Canonicalize a constant to the RHS. if (isa<Constant>(Sum)) std::swap(Sum, W); @@ -845,7 +852,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { // Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them // that are needed into the result. 
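
The comment just above describes visitMulExpr's strategy of repeated squaring: powers with exponents 1, 2, 4, 8, ... are built once, and only those selected by the exponent's bits are multiplied into the result. As a rough, self-contained illustration of that idea on plain integers (not on expanded SCEV values), it looks like the sketch below; powBySquaring is an invented name used only here.

#include <cassert>
#include <cstdint>

// Illustrative only: repeated squaring on plain integers. visitMulExpr applies
// the same structure to expanded SCEV values rather than to uint64_t.
static uint64_t powBySquaring(uint64_t Base, unsigned Exponent) {
  uint64_t Result = 1;
  uint64_t Power = Base; // holds Base^(2^k) on the k-th iteration
  while (Exponent != 0) {
    if (Exponent & 1)
      Result *= Power; // this power-of-two factor is needed in the result
    Power *= Power;
    Exponent >>= 1;
  }
  return Result;
}

int main() {
  assert(powBySquaring(3, 5) == 243);
  assert(powBySquaring(2, 10) == 1024);
  assert(powBySquaring(7, 0) == 1);
  return 0;
}
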
- Value *P = expandCodeForImpl(I->second, Ty, false); + Value *P = expandCodeForImpl(I->second, Ty); Value *Result = nullptr; if (Exponent & 1) Result = P; @@ -904,7 +911,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); - Value *LHS = expandCodeForImpl(S->getLHS(), Ty, false); + Value *LHS = expandCodeForImpl(S->getLHS(), Ty); if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) { const APInt &RHS = SC->getAPInt(); if (RHS.isPowerOf2()) @@ -913,7 +920,7 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) { SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true); } - Value *RHS = expandCodeForImpl(S->getRHS(), Ty, false); + Value *RHS = expandCodeForImpl(S->getRHS(), Ty); return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap, /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS())); } @@ -1024,9 +1031,27 @@ void SCEVExpander::fixupInsertPoints(Instruction *I) { /// hoistStep - Attempt to hoist a simple IV increment above InsertPos to make /// it available to other uses in this loop. Recursively hoist any operands, /// until we reach a value that dominates InsertPos. -bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos) { - if (SE.DT.dominates(IncV, InsertPos)) - return true; +bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos, + bool RecomputePoisonFlags) { + auto FixupPoisonFlags = [this](Instruction *I) { + // Drop flags that are potentially inferred from old context and infer flags + // in new context. + I->dropPoisonGeneratingFlags(); + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) + if (auto Flags = SE.getStrengthenedNoWrapFlagsFromBinOp(OBO)) { + auto *BO = cast<BinaryOperator>(I); + BO->setHasNoUnsignedWrap( + ScalarEvolution::maskFlags(*Flags, SCEV::FlagNUW) == SCEV::FlagNUW); + BO->setHasNoSignedWrap( + ScalarEvolution::maskFlags(*Flags, SCEV::FlagNSW) == SCEV::FlagNSW); + } + }; + + if (SE.DT.dominates(IncV, InsertPos)) { + if (RecomputePoisonFlags) + FixupPoisonFlags(IncV); + return true; + } // InsertPos must itself dominate IncV so that IncV's new position satisfies // its existing users. @@ -1052,6 +1077,8 @@ bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos) { for (Instruction *I : llvm::reverse(IVIncs)) { fixupInsertPoints(I); I->moveBefore(InsertPos); + if (RecomputePoisonFlags) + FixupPoisonFlags(I); } return true; } @@ -1278,7 +1305,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, "Can't expand add recurrences without a loop preheader!"); Value *StartV = expandCodeForImpl(Normalized->getStart(), ExpandTy, - L->getLoopPreheader()->getTerminator(), false); + L->getLoopPreheader()->getTerminator()); // StartV must have been be inserted into L's preheader to dominate the new // phi. @@ -1297,7 +1324,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, Step = SE.getNegativeSCEV(Step); // Expand the step somewhere that dominates the loop header. Value *StepV = expandCodeForImpl( - Step, IntTy, &*L->getHeader()->getFirstInsertionPt(), false); + Step, IntTy, &*L->getHeader()->getFirstInsertionPt()); // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if // we actually do emit an addition. It does not apply if we emit a @@ -1455,7 +1482,7 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { // Expand the step somewhere that dominates the loop header. 
SCEVInsertPointGuard Guard(Builder, this); StepV = expandCodeForImpl( - Step, IntTy, &*L->getHeader()->getFirstInsertionPt(), false); + Step, IntTy, &*L->getHeader()->getFirstInsertionPt()); } Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); } @@ -1475,7 +1502,7 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { // Invert the result. if (InvertStep) Result = Builder.CreateSub( - expandCodeForImpl(Normalized->getStart(), TruncTy, false), Result); + expandCodeForImpl(Normalized->getStart(), TruncTy), Result); } // Re-apply any non-loop-dominating scale. @@ -1483,14 +1510,14 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { assert(S->isAffine() && "Can't linearly scale non-affine recurrences."); Result = InsertNoopCastOfTo(Result, IntTy); Result = Builder.CreateMul(Result, - expandCodeForImpl(PostLoopScale, IntTy, false)); + expandCodeForImpl(PostLoopScale, IntTy)); } // Re-apply any non-loop-dominating offset. if (PostLoopOffset) { if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) { if (Result->getType()->isIntegerTy()) { - Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy, false); + Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy); Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base); } else { Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result); @@ -1498,7 +1525,7 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { } else { Result = InsertNoopCastOfTo(Result, IntTy); Result = Builder.CreateAdd( - Result, expandCodeForImpl(PostLoopOffset, IntTy, false)); + Result, expandCodeForImpl(PostLoopOffset, IntTy)); } } @@ -1508,7 +1535,7 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { // In canonical mode we compute the addrec as an expression of a canonical IV // using evaluateAtIteration and expand the resulting SCEV expression. This - // way we avoid introducing new IVs to carry on the comutation of the addrec + // way we avoid introducing new IVs to carry on the computation of the addrec // throughout the loop. 
// // For nested addrecs evaluateAtIteration might need a canonical IV of a @@ -1535,13 +1562,13 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { !S->getType()->isPointerTy()) { SmallVector<const SCEV *, 4> NewOps(S->getNumOperands()); for (unsigned i = 0, e = S->getNumOperands(); i != e; ++i) - NewOps[i] = SE.getAnyExtendExpr(S->op_begin()[i], CanonicalIV->getType()); + NewOps[i] = SE.getAnyExtendExpr(S->getOperand(i), CanonicalIV->getType()); Value *V = expand(SE.getAddRecExpr(NewOps, S->getLoop(), S->getNoWrapFlags(SCEV::FlagNW))); BasicBlock::iterator NewInsertPt = findInsertPointAfter(cast<Instruction>(V), &*Builder.GetInsertPoint()); V = expandCodeForImpl(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr, - &*NewInsertPt, false); + &*NewInsertPt); return V; } @@ -1643,7 +1670,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { Value *SCEVExpander::visitPtrToIntExpr(const SCEVPtrToIntExpr *S) { Value *V = - expandCodeForImpl(S->getOperand(), S->getOperand()->getType(), false); + expandCodeForImpl(S->getOperand(), S->getOperand()->getType()); return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToInt, GetOptimalInsertionPointForCastOf(V)); } @@ -1651,24 +1678,24 @@ Value *SCEVExpander::visitPtrToIntExpr(const SCEVPtrToIntExpr *S) { Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); Value *V = expandCodeForImpl( - S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()), - false); + S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()) + ); return Builder.CreateTrunc(V, Ty); } Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); Value *V = expandCodeForImpl( - S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()), - false); + S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()) + ); return Builder.CreateZExt(V, Ty); } Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); Value *V = expandCodeForImpl( - S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()), - false); + S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()) + ); return Builder.CreateSExt(V, Ty); } @@ -1680,7 +1707,7 @@ Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S, if (IsSequential) LHS = Builder.CreateFreeze(LHS); for (int i = S->getNumOperands() - 2; i >= 0; --i) { - Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); + Value *RHS = expandCodeForImpl(S->getOperand(i), Ty); if (IsSequential && i != 0) RHS = Builder.CreateFreeze(RHS); Value *Sel; @@ -1718,44 +1745,16 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) { } Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, - Instruction *IP, bool Root) { + Instruction *IP) { setInsertPoint(IP); - Value *V = expandCodeForImpl(SH, Ty, Root); + Value *V = expandCodeForImpl(SH, Ty); return V; } -Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root) { +Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty) { // Expand the code for this SCEV. Value *V = expand(SH); - if (PreserveLCSSA) { - if (auto *Inst = dyn_cast<Instruction>(V)) { - // Create a temporary instruction to at the current insertion point, so we - // can hand it off to the helper to create LCSSA PHIs if required for the - // new use. 
- // FIXME: Ideally formLCSSAForInstructions (used in fixupLCSSAFormFor) - // would accept a insertion point and return an LCSSA phi for that - // insertion point, so there is no need to insert & remove the temporary - // instruction. - Instruction *Tmp; - if (Inst->getType()->isIntegerTy()) - Tmp = cast<Instruction>(Builder.CreateIntToPtr( - Inst, Inst->getType()->getPointerTo(), "tmp.lcssa.user")); - else { - assert(Inst->getType()->isPointerTy()); - Tmp = cast<Instruction>(Builder.CreatePtrToInt( - Inst, Type::getInt32Ty(Inst->getContext()), "tmp.lcssa.user")); - } - V = fixupLCSSAFormFor(Tmp, 0); - - // Clean up temporary instruction. - InsertedValues.erase(Tmp); - InsertedPostIncValues.erase(Tmp); - Tmp->eraseFromParent(); - } - } - - InsertedExpressions[std::make_pair(SH, &*Builder.GetInsertPoint())] = V; if (Ty) { assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) && "non-trivial casts should be done with the SCEVs directly!"); @@ -1860,9 +1859,10 @@ Value *SCEVExpander::expand(const SCEV *S) { // Expand the expression into instructions. Value *V = FindValueInExprValueMap(S, InsertPt); - if (!V) + if (!V) { V = visit(S); - else { + V = fixupLCSSAFormFor(V); + } else { // If we're reusing an existing instruction, we are effectively CSEing two // copies of the instruction (with potentially different flags). As such, // we need to drop any poison generating flags unless we can prove that @@ -1889,18 +1889,6 @@ void SCEVExpander::rememberInstruction(Value *I) { InsertedValues.insert(V); }; DoInsert(I); - - if (!PreserveLCSSA) - return; - - if (auto *Inst = dyn_cast<Instruction>(I)) { - // A new instruction has been added, which might introduce new uses outside - // a defining loop. Fix LCSSA from for each operand of the new instruction, - // if required. - for (unsigned OpIdx = 0, OpEnd = Inst->getNumOperands(); OpIdx != OpEnd; - OpIdx++) - fixupLCSSAFormFor(Inst, OpIdx); - } } /// replaceCongruentIVs - Check for congruent phis in this loop header and @@ -1925,8 +1913,8 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, // Put pointers at the back and make sure pointer < pointer = false. if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return RHS->getType()->isIntegerTy() && !LHS->getType()->isIntegerTy(); - return RHS->getType()->getPrimitiveSizeInBits().getFixedSize() < - LHS->getType()->getPrimitiveSizeInBits().getFixedSize(); + return RHS->getType()->getPrimitiveSizeInBits().getFixedValue() < + LHS->getType()->getPrimitiveSizeInBits().getFixedValue(); }); unsigned NumElim = 0; @@ -1950,6 +1938,7 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, if (Value *V = SimplifyPHINode(Phi)) { if (V->getType() != Phi->getType()) continue; + SE.forgetValue(Phi); Phi->replaceAllUsesWith(V); DeadInsts.emplace_back(Phi); ++NumElim; @@ -2006,12 +1995,14 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, // with the original phi. It's worth eagerly cleaning up the // common case of a single IV increment so that DeleteDeadPHIs // can remove cycles that had postinc uses. + // Because we may potentially introduce a new use of OrigIV that didn't + // exist before at this point, its poison flags need readjustment. 
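
The comment above motivates calling hoistIVInc with RecomputePoisonFlags: flags justified only by the old context are dropped, then no-unsigned-wrap/no-signed-wrap are re-established only when the strengthened flags contain the corresponding bit, via a mask-and-compare. The toy stand-in below demonstrates just that "drop, then re-infer" shape; the types and names are illustrative and deliberately not the ScalarEvolution or IR API.

#include <cassert>

// Toy stand-in for the FixupPoisonFlags pattern; not the LLVM API.
enum NoWrapFlags { FlagAnyWrap = 0, FlagNUW = 1 << 0, FlagNSW = 1 << 1 };

struct ToyBinOp {
  bool HasNUW = false;
  bool HasNSW = false;
};

static void fixupFlags(ToyBinOp &BO, unsigned Strengthened) {
  // Drop flags that were justified only by the old context...
  BO.HasNUW = false;
  BO.HasNSW = false;
  // ...and re-establish them from flags proven in the new context,
  // mirroring: maskFlags(Flags, FlagNUW) == FlagNUW, etc.
  BO.HasNUW = (Strengthened & FlagNUW) == FlagNUW;
  BO.HasNSW = (Strengthened & FlagNSW) == FlagNSW;
}

int main() {
  ToyBinOp Add{/*HasNUW=*/true, /*HasNSW=*/true};
  fixupFlags(Add, FlagNUW); // only NUW can be proven at the new position
  assert(Add.HasNUW && !Add.HasNSW);
  return 0;
}
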
const SCEV *TruncExpr = SE.getTruncateOrNoop(SE.getSCEV(OrigInc), IsomorphicInc->getType()); if (OrigInc != IsomorphicInc && TruncExpr == SE.getSCEV(IsomorphicInc) && SE.LI.replacementPreservesLCSSAForm(IsomorphicInc, OrigInc) && - hoistIVInc(OrigInc, IsomorphicInc)) { + hoistIVInc(OrigInc, IsomorphicInc, /*RecomputePoisonFlags*/ true)) { SCEV_DEBUG_WITH_TYPE( DebugType, dbgs() << "INDVARS: Eliminated congruent iv.inc: " << *IsomorphicInc << '\n'); @@ -2122,7 +2113,7 @@ template<typename T> static InstructionCost costAndCollectOperands( auto CmpSelCost = [&](unsigned Opcode, unsigned NumRequired, unsigned MinIdx, unsigned MaxIdx) -> InstructionCost { Operations.emplace_back(Opcode, MinIdx, MaxIdx); - Type *OpType = S->getOperand(0)->getType(); + Type *OpType = S->getType(); return NumRequired * TTI.getCmpSelInstrCost( Opcode, OpType, CmpInst::makeCmpResultType(OpType), CmpInst::BAD_ICMP_PREDICATE, CostKind); @@ -2191,7 +2182,7 @@ template<typename T> static InstructionCost costAndCollectOperands( } case scAddRecExpr: { // In this polynominal, we may have some zero operands, and we shouldn't - // really charge for those. So how many non-zero coeffients are there? + // really charge for those. So how many non-zero coefficients are there? int NumTerms = llvm::count_if(S->operands(), [](const SCEV *Op) { return !Op->isZero(); }); @@ -2200,7 +2191,7 @@ template<typename T> static InstructionCost costAndCollectOperands( assert(!(*std::prev(S->operands().end()))->isZero() && "Last operand should not be zero"); - // Ignoring constant term (operand 0), how many of the coeffients are u> 1? + // Ignoring constant term (operand 0), how many of the coefficients are u> 1? int NumNonZeroDegreeNonOneTerms = llvm::count_if(S->operands(), [](const SCEV *Op) { auto *SConst = dyn_cast<SCEVConstant>(Op); @@ -2351,9 +2342,9 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, Value *SCEVExpander::expandComparePredicate(const SCEVComparePredicate *Pred, Instruction *IP) { Value *Expr0 = - expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP, false); + expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP); Value *Expr1 = - expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP, false); + expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP); Builder.SetInsertPoint(IP); auto InvPred = ICmpInst::getInversePredicate(Pred->getPredicate()); @@ -2387,15 +2378,15 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits); Builder.SetInsertPoint(Loc); - Value *TripCountVal = expandCodeForImpl(ExitCount, CountTy, Loc, false); + Value *TripCountVal = expandCodeForImpl(ExitCount, CountTy, Loc); IntegerType *Ty = IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy)); - Value *StepValue = expandCodeForImpl(Step, Ty, Loc, false); + Value *StepValue = expandCodeForImpl(Step, Ty, Loc); Value *NegStepValue = - expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc, false); - Value *StartValue = expandCodeForImpl(Start, ARTy, Loc, false); + expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc); + Value *StartValue = expandCodeForImpl(Start, ARTy, Loc); ConstantInt *Zero = ConstantInt::get(Loc->getContext(), APInt::getZero(DstBits)); @@ -2519,7 +2510,7 @@ Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, Instruction *IP) { // Loop over all checks in this set. 
SmallVector<Value *> Checks; - for (auto Pred : Union->getPredicates()) { + for (const auto *Pred : Union->getPredicates()) { Checks.push_back(expandCodeForPredicate(Pred, IP)); Builder.SetInsertPoint(IP); } @@ -2529,21 +2520,36 @@ Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, return Builder.CreateOr(Checks); } -Value *SCEVExpander::fixupLCSSAFormFor(Instruction *User, unsigned OpIdx) { - assert(PreserveLCSSA); - SmallVector<Instruction *, 1> ToUpdate; - - auto *OpV = User->getOperand(OpIdx); - auto *OpI = dyn_cast<Instruction>(OpV); - if (!OpI) - return OpV; +Value *SCEVExpander::fixupLCSSAFormFor(Value *V) { + auto *DefI = dyn_cast<Instruction>(V); + if (!PreserveLCSSA || !DefI) + return V; - Loop *DefLoop = SE.LI.getLoopFor(OpI->getParent()); - Loop *UseLoop = SE.LI.getLoopFor(User->getParent()); + Instruction *InsertPt = &*Builder.GetInsertPoint(); + Loop *DefLoop = SE.LI.getLoopFor(DefI->getParent()); + Loop *UseLoop = SE.LI.getLoopFor(InsertPt->getParent()); if (!DefLoop || UseLoop == DefLoop || DefLoop->contains(UseLoop)) - return OpV; + return V; + + // Create a temporary instruction to at the current insertion point, so we + // can hand it off to the helper to create LCSSA PHIs if required for the + // new use. + // FIXME: Ideally formLCSSAForInstructions (used in fixupLCSSAFormFor) + // would accept a insertion point and return an LCSSA phi for that + // insertion point, so there is no need to insert & remove the temporary + // instruction. + Type *ToTy; + if (DefI->getType()->isIntegerTy()) + ToTy = DefI->getType()->getPointerTo(); + else + ToTy = Type::getInt32Ty(DefI->getContext()); + Instruction *User = + CastInst::CreateBitOrPointerCast(DefI, ToTy, "tmp.lcssa.user", InsertPt); + auto RemoveUserOnExit = + make_scope_exit([User]() { User->eraseFromParent(); }); - ToUpdate.push_back(OpI); + SmallVector<Instruction *, 1> ToUpdate; + ToUpdate.push_back(DefI); SmallVector<PHINode *, 16> PHIsToRemove; formLCSSAForInstructions(ToUpdate, SE.DT, SE.LI, &SE, Builder, &PHIsToRemove); for (PHINode *PN : PHIsToRemove) { @@ -2554,7 +2560,7 @@ Value *SCEVExpander::fixupLCSSAFormFor(Instruction *User, unsigned OpIdx) { PN->eraseFromParent(); } - return User->getOperand(OpIdx); + return User->getOperand(0); } namespace { @@ -2666,7 +2672,7 @@ void SCEVExpanderCleaner::cleanup() { #endif assert(!I->getType()->isVoidTy() && "inserted instruction should have non-void types"); - I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->replaceAllUsesWith(PoisonValue::get(I->getType())); I->eraseFromParent(); } } diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 1806081678a8..9e0483966d3e 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -14,7 +14,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Sequence.h" @@ -41,6 +40,7 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" @@ -57,6 +57,7 @@ #include "llvm/IR/NoFolder.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -80,6 +81,7 @@ #include 
<cstdint> #include <iterator> #include <map> +#include <optional> #include <set> #include <tuple> #include <utility> @@ -115,6 +117,12 @@ static cl::opt<bool> HoistCommon("simplifycfg-hoist-common", cl::Hidden, cl::init(true), cl::desc("Hoist common instructions up to the parent block")); +static cl::opt<unsigned> + HoistCommonSkipLimit("simplifycfg-hoist-common-skip-limit", cl::Hidden, + cl::init(20), + cl::desc("Allow reordering across at most this many " + "instructions when hoisting")); + static cl::opt<bool> SinkCommon("simplifycfg-sink-common", cl::Hidden, cl::init(true), cl::desc("Sink common instructions down to the end block")); @@ -380,7 +388,7 @@ static InstructionCost computeSpeculationCost(const User *I, assert((!isa<Instruction>(I) || isSafeToSpeculativelyExecute(cast<Instruction>(I))) && "Instruction is not safe to speculatively execute!"); - return TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency); + return TTI.getInstructionCost(I, TargetTransformInfo::TCK_SizeAndLatency); } /// If we have a merge point of an "if condition" as accepted above, @@ -472,7 +480,8 @@ static bool dominatesMergePoint(Value *V, BasicBlock *BB, static ConstantInt *GetConstantInt(Value *V, const DataLayout &DL) { // Normal constant int. ConstantInt *CI = dyn_cast<ConstantInt>(V); - if (CI || !isa<Constant>(V) || !V->getType()->isPointerTy()) + if (CI || !isa<Constant>(V) || !V->getType()->isPointerTy() || + DL.isNonIntegralPointerType(V->getType())) return CI; // This is some kind of pointer constant. Turn it into a pointer-sized @@ -829,8 +838,8 @@ static bool ValuesOverlap(std::vector<ValueEqualityComparisonCase> &C1, if (V1->size() == 1) { // Just scan V2. ConstantInt *TheVal = (*V1)[0].Value; - for (unsigned i = 0, e = V2->size(); i != e; ++i) - if (TheVal == (*V2)[i].Value) + for (const ValueEqualityComparisonCase &VECC : *V2) + if (TheVal == VECC.Value) return true; } @@ -1050,15 +1059,6 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1, return LHS->getValue().ult(RHS->getValue()) ? 1 : -1; } -static inline bool HasBranchWeights(const Instruction *I) { - MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof); - if (ProfMD && ProfMD->getOperand(0)) - if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0))) - return MDS->getString().equals("branch_weights"); - - return false; -} - /// Get Weights of a given terminator, the default weight is at the front /// of the vector. If TI is a conditional eq, we need to swap the branch-weight /// metadata. @@ -1128,7 +1128,7 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses( NewBonusInst->dropUndefImplyingAttrsAndUnknownMetadata( LLVMContext::MD_annotation); - PredBlock->getInstList().insert(PTI->getIterator(), NewBonusInst); + NewBonusInst->insertInto(PredBlock, PTI->getIterator()); NewBonusInst->takeName(&BonusInst); BonusInst.setName(NewBonusInst->getName() + ".old"); @@ -1177,8 +1177,8 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding( // Update the branch weight metadata along the way SmallVector<uint64_t, 8> Weights; - bool PredHasWeights = HasBranchWeights(PTI); - bool SuccHasWeights = HasBranchWeights(TI); + bool PredHasWeights = hasBranchWeightMD(*PTI); + bool SuccHasWeights = hasBranchWeightMD(*TI); if (PredHasWeights) { GetBranchWeights(PTI, Weights); @@ -1430,6 +1430,64 @@ static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2, return true; } +// Get interesting characteristics of instructions that `HoistThenElseCodeToIf` +// didn't hoist. 
They restrict what kind of instructions can be reordered +// across. +enum SkipFlags { + SkipReadMem = 1, + SkipSideEffect = 2, + SkipImplicitControlFlow = 4 +}; + +static unsigned skippedInstrFlags(Instruction *I) { + unsigned Flags = 0; + if (I->mayReadFromMemory()) + Flags |= SkipReadMem; + // We can't arbitrarily move around allocas, e.g. moving allocas (especially + // inalloca) across stacksave/stackrestore boundaries. + if (I->mayHaveSideEffects() || isa<AllocaInst>(I)) + Flags |= SkipSideEffect; + if (!isGuaranteedToTransferExecutionToSuccessor(I)) + Flags |= SkipImplicitControlFlow; + return Flags; +} + +// Returns true if it is safe to reorder an instruction across preceding +// instructions in a basic block. +static bool isSafeToHoistInstr(Instruction *I, unsigned Flags) { + // Don't reorder a store over a load. + if ((Flags & SkipReadMem) && I->mayWriteToMemory()) + return false; + + // If we have seen an instruction with side effects, it's unsafe to reorder an + // instruction which reads memory or itself has side effects. + if ((Flags & SkipSideEffect) && + (I->mayReadFromMemory() || I->mayHaveSideEffects())) + return false; + + // Reordering across an instruction which does not necessarily transfer + // control to the next instruction is speculation. + if ((Flags & SkipImplicitControlFlow) && !isSafeToSpeculativelyExecute(I)) + return false; + + // Hoisting of llvm.deoptimize is only legal together with the next return + // instruction, which this pass is not always able to do. + if (auto *CB = dyn_cast<CallBase>(I)) + if (CB->getIntrinsicID() == Intrinsic::experimental_deoptimize) + return false; + + // It's also unsafe/illegal to hoist an instruction above its instruction + // operands + BasicBlock *BB = I->getParent(); + for (Value *Op : I->operands()) { + if (auto *J = dyn_cast<Instruction>(Op)) + if (J->getParent() == BB) + return false; + } + + return true; +} + static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValueMayBeModified = false); /// Given a conditional branch that goes to BB1 and BB2, hoist any common code @@ -1444,7 +1502,8 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, // instructions in the two blocks. In particular, we don't want to get into // O(M*N) situations here where M and N are the sizes of BB1 and BB2. As // such, we currently just scan for obviously identical instructions in an - // identical order. + // identical order, possibly separated by the same number of non-identical + // instructions. BasicBlock *BB1 = BI->getSuccessor(0); // The true destination. BasicBlock *BB2 = BI->getSuccessor(1); // The false destination @@ -1467,7 +1526,7 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, while (isa<DbgInfoIntrinsic>(I2)) I2 = &*BB2_Itr++; } - if (isa<PHINode>(I1) || !I1->isIdenticalToWhenDefined(I2)) + if (isa<PHINode>(I1)) return false; BasicBlock *BIParent = BI->getParent(); @@ -1493,75 +1552,104 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, // terminator. Let the loop below handle those 2 cases. } - do { + // Count how many instructions were not hoisted so far. There's a limit on how + // many instructions we skip, serving as a compilation time control as well as + // preventing excessive increase of life ranges. + unsigned NumSkipped = 0; + + // Record any skipped instuctions that may read memory, write memory or have + // side effects, or have implicit control flow. 
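
The SkipFlags machinery above lets HoistThenElseCodeToIf step over a bounded number of non-identical instructions while remembering, as a small bitmask, what was skipped; isSafeToHoistInstr then rejects candidates that would have to reorder across those properties. A toy model of that accumulate-and-check scheme, assuming a made-up ToyInstr record instead of LLVM IR, might look like this:

#include <cassert>

// Toy model (not LLVM IR): every skipped instruction contributes bits, and a
// later candidate may only be hoisted if it does not conflict with them.
enum SkipFlags { SkipReadMem = 1, SkipSideEffect = 2, SkipImplicitControlFlow = 4 };

struct ToyInstr {
  bool ReadsMem = false;
  bool WritesMem = false;
  bool HasSideEffects = false;
};

static unsigned skippedFlags(const ToyInstr &I) {
  unsigned Flags = 0;
  if (I.ReadsMem)
    Flags |= SkipReadMem;
  if (I.WritesMem || I.HasSideEffects)
    Flags |= SkipSideEffect;
  return Flags;
}

static bool safeToHoistAcross(const ToyInstr &I, unsigned Flags) {
  // Don't reorder a store over a skipped load.
  if ((Flags & SkipReadMem) && I.WritesMem)
    return false;
  // Don't reorder a memory access over a skipped side-effecting instruction.
  if ((Flags & SkipSideEffect) &&
      (I.ReadsMem || I.WritesMem || I.HasSideEffects))
    return false;
  return true;
}

int main() {
  unsigned Flags = 0;
  Flags |= skippedFlags(ToyInstr{/*ReadsMem=*/true}); // skipped over a load
  ToyInstr Store;
  Store.WritesMem = true;
  assert(!safeToHoistAcross(Store, Flags)); // a store must not cross the load
  ToyInstr PureMath; // no memory access, no side effects
  assert(safeToHoistAcross(PureMath, Flags));
  return 0;
}
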
+ unsigned SkipFlagsBB1 = 0; + unsigned SkipFlagsBB2 = 0; + + for (;;) { // If we are hoisting the terminator instruction, don't move one (making a // broken BB), instead clone it, and remove BI. - if (I1->isTerminator()) + if (I1->isTerminator() || I2->isTerminator()) { + // If any instructions remain in the block, we cannot hoist terminators. + if (NumSkipped || !I1->isIdenticalToWhenDefined(I2)) + return Changed; goto HoistTerminator; + } - // If we're going to hoist a call, make sure that the two instructions we're - // commoning/hoisting are both marked with musttail, or neither of them is - // marked as such. Otherwise, we might end up in a situation where we hoist - // from a block where the terminator is a `ret` to a block where the terminator - // is a `br`, and `musttail` calls expect to be followed by a return. - auto *C1 = dyn_cast<CallInst>(I1); - auto *C2 = dyn_cast<CallInst>(I2); - if (C1 && C2) - if (C1->isMustTailCall() != C2->isMustTailCall()) + if (I1->isIdenticalToWhenDefined(I2)) { + // Even if the instructions are identical, it may not be safe to hoist + // them if we have skipped over instructions with side effects or their + // operands weren't hoisted. + if (!isSafeToHoistInstr(I1, SkipFlagsBB1) || + !isSafeToHoistInstr(I2, SkipFlagsBB2)) return Changed; - if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2)) - return Changed; - - // If any of the two call sites has nomerge attribute, stop hoisting. - if (const auto *CB1 = dyn_cast<CallBase>(I1)) - if (CB1->cannotMerge()) - return Changed; - if (const auto *CB2 = dyn_cast<CallBase>(I2)) - if (CB2->cannotMerge()) + // If we're going to hoist a call, make sure that the two instructions + // we're commoning/hoisting are both marked with musttail, or neither of + // them is marked as such. Otherwise, we might end up in a situation where + // we hoist from a block where the terminator is a `ret` to a block where + // the terminator is a `br`, and `musttail` calls expect to be followed by + // a return. + auto *C1 = dyn_cast<CallInst>(I1); + auto *C2 = dyn_cast<CallInst>(I2); + if (C1 && C2) + if (C1->isMustTailCall() != C2->isMustTailCall()) + return Changed; + + if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2)) return Changed; - if (isa<DbgInfoIntrinsic>(I1) || isa<DbgInfoIntrinsic>(I2)) { - assert (isa<DbgInfoIntrinsic>(I1) && isa<DbgInfoIntrinsic>(I2)); - // The debug location is an integral part of a debug info intrinsic - // and can't be separated from it or replaced. Instead of attempting - // to merge locations, simply hoist both copies of the intrinsic. - BIParent->getInstList().splice(BI->getIterator(), - BB1->getInstList(), I1); - BIParent->getInstList().splice(BI->getIterator(), - BB2->getInstList(), I2); + // If any of the two call sites has nomerge attribute, stop hoisting. + if (const auto *CB1 = dyn_cast<CallBase>(I1)) + if (CB1->cannotMerge()) + return Changed; + if (const auto *CB2 = dyn_cast<CallBase>(I2)) + if (CB2->cannotMerge()) + return Changed; + + if (isa<DbgInfoIntrinsic>(I1) || isa<DbgInfoIntrinsic>(I2)) { + assert(isa<DbgInfoIntrinsic>(I1) && isa<DbgInfoIntrinsic>(I2)); + // The debug location is an integral part of a debug info intrinsic + // and can't be separated from it or replaced. Instead of attempting + // to merge locations, simply hoist both copies of the intrinsic. 
+ BIParent->splice(BI->getIterator(), BB1, I1->getIterator()); + BIParent->splice(BI->getIterator(), BB2, I2->getIterator()); + } else { + // For a normal instruction, we just move one to right before the + // branch, then replace all uses of the other with the first. Finally, + // we remove the now redundant second instruction. + BIParent->splice(BI->getIterator(), BB1, I1->getIterator()); + if (!I2->use_empty()) + I2->replaceAllUsesWith(I1); + I1->andIRFlags(I2); + unsigned KnownIDs[] = {LLVMContext::MD_tbaa, + LLVMContext::MD_range, + LLVMContext::MD_fpmath, + LLVMContext::MD_invariant_load, + LLVMContext::MD_nonnull, + LLVMContext::MD_invariant_group, + LLVMContext::MD_align, + LLVMContext::MD_dereferenceable, + LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_mem_parallel_loop_access, + LLVMContext::MD_access_group, + LLVMContext::MD_preserve_access_index}; + combineMetadata(I1, I2, KnownIDs, true); + + // I1 and I2 are being combined into a single instruction. Its debug + // location is the merged locations of the original instructions. + I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); + + I2->eraseFromParent(); + } Changed = true; + ++NumHoistCommonInstrs; } else { - // For a normal instruction, we just move one to right before the branch, - // then replace all uses of the other with the first. Finally, we remove - // the now redundant second instruction. - BIParent->getInstList().splice(BI->getIterator(), - BB1->getInstList(), I1); - if (!I2->use_empty()) - I2->replaceAllUsesWith(I1); - I1->andIRFlags(I2); - unsigned KnownIDs[] = {LLVMContext::MD_tbaa, - LLVMContext::MD_range, - LLVMContext::MD_fpmath, - LLVMContext::MD_invariant_load, - LLVMContext::MD_nonnull, - LLVMContext::MD_invariant_group, - LLVMContext::MD_align, - LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null, - LLVMContext::MD_mem_parallel_loop_access, - LLVMContext::MD_access_group, - LLVMContext::MD_preserve_access_index}; - combineMetadata(I1, I2, KnownIDs, true); - - // I1 and I2 are being combined into a single instruction. Its debug - // location is the merged locations of the original instructions. - I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); - - I2->eraseFromParent(); - Changed = true; + if (NumSkipped >= HoistCommonSkipLimit) + return Changed; + // We are about to skip over a pair of non-identical instructions. Record + // if any have characteristics that would prevent reordering instructions + // across them. + SkipFlagsBB1 |= skippedInstrFlags(I1); + SkipFlagsBB2 |= skippedInstrFlags(I2); + ++NumSkipped; } - ++NumHoistCommonInstrs; I1 = &*BB1_Itr++; I2 = &*BB2_Itr++; @@ -1574,9 +1662,9 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, while (isa<DbgInfoIntrinsic>(I2)) I2 = &*BB2_Itr++; } - } while (I1->isIdenticalToWhenDefined(I2)); + } - return true; + return Changed; HoistTerminator: // It may not be possible to hoist an invoke. @@ -1605,7 +1693,7 @@ HoistTerminator: // Okay, it is safe to hoist the terminator. Instruction *NT = I1->clone(); - BIParent->getInstList().insert(BI->getIterator(), NT); + NT->insertInto(BIParent, BI->getIterator()); if (!NT->getType()->isVoidTy()) { I1->replaceAllUsesWith(NT); I2->replaceAllUsesWith(NT); @@ -1915,9 +2003,15 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { } // Finally nuke all instructions apart from the common instruction. 
- for (auto *I : Insts) - if (I != I0) - I->eraseFromParent(); + for (auto *I : Insts) { + if (I == I0) + continue; + // The remaining uses are debug users, replace those with the common inst. + // In most (all?) cases this just introduces a use-before-def. + assert(I->user_empty() && "Inst unexpectedly still has non-dbg users"); + I->replaceAllUsesWith(I0); + I->eraseFromParent(); + } return true; } @@ -2403,7 +2497,7 @@ static void MergeCompatibleInvokesImpl(ArrayRef<InvokeInst *> Invokes, auto *MergedInvoke = cast<InvokeInst>(II0->clone()); // NOTE: all invokes have the same attributes, so no handling needed. - MergedInvokeBB->getInstList().push_back(MergedInvoke); + MergedInvoke->insertInto(MergedInvokeBB, MergedInvokeBB->end()); if (!HasNormalDest) { // This set does not have a normal destination, @@ -2551,6 +2645,34 @@ static bool MergeCompatibleInvokes(BasicBlock *BB, DomTreeUpdater *DTU) { return Changed; } +namespace { +/// Track ephemeral values, which should be ignored for cost-modelling +/// purposes. Requires walking instructions in reverse order. +class EphemeralValueTracker { + SmallPtrSet<const Instruction *, 32> EphValues; + + bool isEphemeral(const Instruction *I) { + if (isa<AssumeInst>(I)) + return true; + return !I->mayHaveSideEffects() && !I->isTerminator() && + all_of(I->users(), [&](const User *U) { + return EphValues.count(cast<Instruction>(U)); + }); + } + +public: + bool track(const Instruction *I) { + if (isEphemeral(I)) { + EphValues.insert(I); + return true; + } + return false; + } + + bool contains(const Instruction *I) const { return EphValues.contains(I); } +}; +} // namespace + /// Determine if we can hoist sink a sole store instruction out of a /// conditional block. /// @@ -2752,7 +2874,8 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // the `then` block, then avoid speculating it. if (!BI->getMetadata(LLVMContext::MD_unpredictable)) { uint64_t TWeight, FWeight; - if (BI->extractProfMetadata(TWeight, FWeight) && (TWeight + FWeight) != 0) { + if (extractBranchWeights(*BI, TWeight, FWeight) && + (TWeight + FWeight) != 0) { uint64_t EndWeight = Invert ? TWeight : FWeight; BranchProbability BIEndProb = BranchProbability::getBranchProbability(EndWeight, TWeight + FWeight); @@ -2774,13 +2897,11 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, unsigned SpeculatedInstructions = 0; Value *SpeculatedStoreValue = nullptr; StoreInst *SpeculatedStore = nullptr; - for (BasicBlock::iterator BBI = ThenBB->begin(), - BBE = std::prev(ThenBB->end()); - BBI != BBE; ++BBI) { - Instruction *I = &*BBI; + EphemeralValueTracker EphTracker; + for (Instruction &I : reverse(drop_end(*ThenBB))) { // Skip debug info. if (isa<DbgInfoIntrinsic>(I)) { - SpeculatedDbgIntrinsics.push_back(I); + SpeculatedDbgIntrinsics.push_back(&I); continue; } @@ -2792,10 +2913,14 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // the samples collected on the non-conditional path are counted towards // the conditional path. We leave it for the counts inference algorithm to // figure out a proper count for an unknown probe. - SpeculatedDbgIntrinsics.push_back(I); + SpeculatedDbgIntrinsics.push_back(&I); continue; } + // Ignore ephemeral values, they will be dropped by the transform. + if (EphTracker.track(&I)) + continue; + // Only speculatively execute a single instruction (not counting the // terminator) for now. 
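
EphemeralValueTracker, introduced above, walks a block in reverse and marks an instruction ephemeral when it is an assume, or when it has no side effects and every user has already been marked. Because users are visited before their operands, a single reverse pass suffices. A compact toy model of that walk, using an invented ToyInstr record with integer IDs instead of llvm::Instruction, is sketched below:

#include <cassert>
#include <set>
#include <vector>

// Toy model of the reverse-walk ephemeral-value tracking; not LLVM IR.
struct ToyInstr {
  int ID;
  bool IsAssume = false;
  bool HasSideEffects = false;
  std::vector<int> UserIDs; // IDs of instructions that use this one
};

static bool isEphemeral(const ToyInstr &I, const std::set<int> &Eph) {
  if (I.IsAssume)
    return true;
  if (I.HasSideEffects)
    return false;
  for (int U : I.UserIDs)
    if (!Eph.count(U)) // a non-ephemeral user keeps the value alive
      return false;
  return true;
}

int main() {
  // %cmp (ID 0) is only used by an assume (ID 1).
  ToyInstr Assume{1, /*IsAssume=*/true};
  ToyInstr Cmp{0};
  Cmp.UserIDs = {1};
  std::set<int> Eph;
  // Reverse program order: visit the assume first, then its operand.
  if (isEphemeral(Assume, Eph)) Eph.insert(Assume.ID);
  if (isEphemeral(Cmp, Eph)) Eph.insert(Cmp.ID);
  assert(Eph.count(0) && Eph.count(1));
  return 0;
}
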
++SpeculatedInstructions; @@ -2803,23 +2928,23 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, return false; // Don't hoist the instruction if it's unsafe or expensive. - if (!isSafeToSpeculativelyExecute(I) && + if (!isSafeToSpeculativelyExecute(&I) && !(HoistCondStores && (SpeculatedStoreValue = isSafeToSpeculateStore( - I, BB, ThenBB, EndBB)))) + &I, BB, ThenBB, EndBB)))) return false; if (!SpeculatedStoreValue && - computeSpeculationCost(I, TTI) > + computeSpeculationCost(&I, TTI) > PHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic) return false; // Store the store speculation candidate. if (SpeculatedStoreValue) - SpeculatedStore = cast<StoreInst>(I); + SpeculatedStore = cast<StoreInst>(&I); // Do not hoist the instruction if any of its operands are defined but not // used in BB. The transformation will prevent the operand from // being sunk into the use block. - for (Use &Op : I->operands()) { + for (Use &Op : I.operands()) { Instruction *OpI = dyn_cast<Instruction>(Op); if (!OpI || OpI->getParent() != BB || OpI->mayHaveSideEffects()) continue; // Not a candidate for sinking. @@ -2831,11 +2956,8 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Consider any sink candidates which are only used in ThenBB as costs for // speculation. Note, while we iterate over a DenseMap here, we are summing // and so iteration order isn't significant. - for (SmallDenseMap<Instruction *, unsigned, 4>::iterator - I = SinkCandidateUseCounts.begin(), - E = SinkCandidateUseCounts.end(); - I != E; ++I) - if (I->first->hasNUses(I->second)) { + for (const auto &[Inst, Count] : SinkCandidateUseCounts) + if (Inst->hasNUses(Count)) { ++SpeculatedInstructions; if (SpeculatedInstructions > 1) return false; @@ -2857,6 +2979,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Insert a select of the value of the speculated store. if (SpeculatedStoreValue) { IRBuilder<NoFolder> Builder(BI); + Value *OrigV = SpeculatedStore->getValueOperand(); Value *TrueV = SpeculatedStore->getValueOperand(); Value *FalseV = SpeculatedStoreValue; if (Invert) @@ -2866,6 +2989,35 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, SpeculatedStore->setOperand(0, S); SpeculatedStore->applyMergedLocation(BI->getDebugLoc(), SpeculatedStore->getDebugLoc()); + // The value stored is still conditional, but the store itself is now + // unconditonally executed, so we must be sure that any linked dbg.assign + // intrinsics are tracking the new stored value (the result of the + // select). If we don't, and the store were to be removed by another pass + // (e.g. DSE), then we'd eventually end up emitting a location describing + // the conditional value, unconditionally. + // + // === Before this transformation === + // pred: + // store %one, %x.dest, !DIAssignID !1 + // dbg.assign %one, "x", ..., !1, ... + // br %cond if.then + // + // if.then: + // store %two, %x.dest, !DIAssignID !2 + // dbg.assign %two, "x", ..., !2, ... + // + // === After this transformation === + // pred: + // store %one, %x.dest, !DIAssignID !1 + // dbg.assign %one, "x", ..., !1 + /// ... 
+ // %merge = select %cond, %two, %one + // store %merge, %x.dest, !DIAssignID !2 + // dbg.assign %merge, "x", ..., !2 + for (auto *DAI : at::getAssignmentMarkers(SpeculatedStore)) { + if (any_of(DAI->location_ops(), [&](Value *V) { return V == OrigV; })) + DAI->replaceVariableLocationOp(OrigV, S); + } } // Metadata can be dependent on the condition we are hoisting above. @@ -2874,15 +3026,24 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // be misleading while debugging. // Similarly strip attributes that maybe dependent on condition we are // hoisting above. - for (auto &I : *ThenBB) { - if (!SpeculatedStoreValue || &I != SpeculatedStore) - I.setDebugLoc(DebugLoc()); + for (auto &I : make_early_inc_range(*ThenBB)) { + if (!SpeculatedStoreValue || &I != SpeculatedStore) { + // Don't update the DILocation of dbg.assign intrinsics. + if (!isa<DbgAssignIntrinsic>(&I)) + I.setDebugLoc(DebugLoc()); + } I.dropUndefImplyingAttrsAndUnknownMetadata(); + + // Drop ephemeral values. + if (EphTracker.contains(&I)) { + I.replaceAllUsesWith(PoisonValue::get(I.getType())); + I.eraseFromParent(); + } } // Hoist the instructions. - BB->getInstList().splice(BI->getIterator(), ThenBB->getInstList(), - ThenBB->begin(), std::prev(ThenBB->end())); + BB->splice(BI->getIterator(), ThenBB, ThenBB->begin(), + std::prev(ThenBB->end())); // Insert selects and rewrite the PHI operands. IRBuilder<NoFolder> Builder(BI); @@ -2910,8 +3071,12 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Remove speculated dbg intrinsics. // FIXME: Is it possible to do this in a more elegant way? Moving/merging the // dbg value for the different flows and inserting it after the select. - for (Instruction *I : SpeculatedDbgIntrinsics) - I->eraseFromParent(); + for (Instruction *I : SpeculatedDbgIntrinsics) { + // We still want to know that an assignment took place so don't remove + // dbg.assign intrinsics. + if (!isa<DbgAssignIntrinsic>(I)) + I->eraseFromParent(); + } ++NumSpeculations; return true; @@ -2920,15 +3085,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, /// Return true if we can thread a branch across this block. static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { int Size = 0; - - SmallPtrSet<const Value *, 32> EphValues; - auto IsEphemeral = [&](const Instruction *I) { - if (isa<AssumeInst>(I)) - return true; - return !I->mayHaveSideEffects() && !I->isTerminator() && - all_of(I->users(), - [&](const User *U) { return EphValues.count(U); }); - }; + EphemeralValueTracker EphTracker; // Walk the loop in reverse so that we can identify ephemeral values properly // (values only feeding assumes). @@ -2939,11 +3096,9 @@ static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { return false; // Ignore ephemeral values which are deleted during codegen. - if (IsEphemeral(&I)) - EphValues.insert(&I); // We will delete Phis while threading, so Phis should not be accounted in // block's size. - else if (!isa<PHINode>(I)) { + if (!EphTracker.track(&I) && !isa<PHINode>(I)) { if (Size++ > MaxSmallBlockSize) return false; // Don't clone large BB's. } @@ -2983,7 +3138,7 @@ static ConstantInt *getKnownValueOnEdge(Value *V, BasicBlock *From, /// If we have a conditional branch on something for which we know the constant /// value in predecessors (e.g. a phi node in the current block), thread edges /// from the predecessor to their ultimate destination. 
-static Optional<bool> +static std::optional<bool> FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL, AssumptionCache *AC) { @@ -3089,7 +3244,7 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, } if (N) { // Insert the new instruction into its new home. - EdgeBB->getInstList().insert(InsertPt, N); + N->insertInto(EdgeBB, InsertPt); // Register the new instruction with the assumption cache if necessary. if (auto *Assume = dyn_cast<AssumeInst>(N)) @@ -3117,7 +3272,7 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, MergeBlockIntoPredecessor(EdgeBB, DTU); // Signal repeat, simplifying any other constants. - return None; + return std::nullopt; } return false; @@ -3127,13 +3282,13 @@ static bool FoldCondBranchOnValueKnownInPredecessor(BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL, AssumptionCache *AC) { - Optional<bool> Result; + std::optional<bool> Result; bool EverChanged = false; do { // Note that None means "we changed things, but recurse further." Result = FoldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, AC); - EverChanged |= Result == None || *Result; - } while (Result == None); + EverChanged |= Result == std::nullopt || *Result; + } while (Result == std::nullopt); return EverChanged; } @@ -3174,7 +3329,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // from the block that we know is predictably not entered. if (!DomBI->getMetadata(LLVMContext::MD_unpredictable)) { uint64_t TWeight, FWeight; - if (DomBI->extractProfMetadata(TWeight, FWeight) && + if (extractBranchWeights(*DomBI, TWeight, FWeight) && (TWeight + FWeight) != 0) { BranchProbability BITrueProb = BranchProbability::getBranchProbability(TWeight, TWeight + FWeight); @@ -3354,9 +3509,9 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI, uint64_t &SuccTrueWeight, uint64_t &SuccFalseWeight) { bool PredHasWeights = - PBI->extractProfMetadata(PredTrueWeight, PredFalseWeight); + extractBranchWeights(*PBI, PredTrueWeight, PredFalseWeight); bool SuccHasWeights = - BI->extractProfMetadata(SuccTrueWeight, SuccFalseWeight); + extractBranchWeights(*BI, SuccTrueWeight, SuccFalseWeight); if (PredHasWeights || SuccHasWeights) { if (!PredHasWeights) PredTrueWeight = PredFalseWeight = 1; @@ -3371,7 +3526,7 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI, /// Determine if the two branches share a common destination and deduce a glue /// that joins the branches' conditions to arrive at the common destination if /// that would be profitable. 
-static Optional<std::pair<Instruction::BinaryOps, bool>> +static std::optional<std::tuple<BasicBlock *, Instruction::BinaryOps, bool>> shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI, const TargetTransformInfo *TTI) { assert(BI && PBI && BI->isConditional() && PBI->isConditional() && @@ -3384,7 +3539,7 @@ shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI, uint64_t PTWeight, PFWeight; BranchProbability PBITrueProb, Likely; if (TTI && !PBI->getMetadata(LLVMContext::MD_unpredictable) && - PBI->extractProfMetadata(PTWeight, PFWeight) && + extractBranchWeights(*PBI, PTWeight, PFWeight) && (PTWeight + PFWeight) != 0) { PBITrueProb = BranchProbability::getBranchProbability(PTWeight, PTWeight + PFWeight); @@ -3394,21 +3549,21 @@ shouldFoldCondBranchesToCommonDestination(BranchInst *BI, BranchInst *PBI, if (PBI->getSuccessor(0) == BI->getSuccessor(0)) { // Speculate the 2nd condition unless the 1st is probably true. if (PBITrueProb.isUnknown() || PBITrueProb < Likely) - return {{Instruction::Or, false}}; + return {{BI->getSuccessor(0), Instruction::Or, false}}; } else if (PBI->getSuccessor(1) == BI->getSuccessor(1)) { // Speculate the 2nd condition unless the 1st is probably false. if (PBITrueProb.isUnknown() || PBITrueProb.getCompl() < Likely) - return {{Instruction::And, false}}; + return {{BI->getSuccessor(1), Instruction::And, false}}; } else if (PBI->getSuccessor(0) == BI->getSuccessor(1)) { // Speculate the 2nd condition unless the 1st is probably true. if (PBITrueProb.isUnknown() || PBITrueProb < Likely) - return {{Instruction::And, true}}; + return {{BI->getSuccessor(1), Instruction::And, true}}; } else if (PBI->getSuccessor(1) == BI->getSuccessor(0)) { // Speculate the 2nd condition unless the 1st is probably false. if (PBITrueProb.isUnknown() || PBITrueProb.getCompl() < Likely) - return {{Instruction::Or, true}}; + return {{BI->getSuccessor(0), Instruction::Or, true}}; } - return None; + return std::nullopt; } static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, @@ -3419,9 +3574,10 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, BasicBlock *PredBlock = PBI->getParent(); // Determine if the two branches share a common destination. + BasicBlock *CommonSucc; Instruction::BinaryOps Opc; bool InvertPredCond; - std::tie(Opc, InvertPredCond) = + std::tie(CommonSucc, Opc, InvertPredCond) = *shouldFoldCondBranchesToCommonDestination(BI, PBI, TTI); LLVM_DEBUG(dbgs() << "FOLDING BRANCH TO COMMON DEST:\n" << *PBI << *BB); @@ -3580,10 +3736,11 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, DomTreeUpdater *DTU, continue; // Determine if the two branches share a common destination. 
+ BasicBlock *CommonSucc; Instruction::BinaryOps Opc; bool InvertPredCond; if (auto Recipe = shouldFoldCondBranchesToCommonDestination(BI, PBI, TTI)) - std::tie(Opc, InvertPredCond) = *Recipe; + std::tie(CommonSucc, Opc, InvertPredCond) = *Recipe; else continue; @@ -3593,7 +3750,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, DomTreeUpdater *DTU, Type *Ty = BI->getCondition()->getType(); InstructionCost Cost = TTI->getArithmeticInstrCost(Opc, Ty, CostKind); if (InvertPredCond && (!PBI->getCondition()->hasOneUse() || - !isa<CmpInst>(PBI->getCondition()))) + !isa<CmpInst>(PBI->getCondition()))) Cost += TTI->getArithmeticInstrCost(Instruction::Xor, Ty, CostKind); if (Cost > BranchFoldThreshold) @@ -3632,8 +3789,8 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, DomTreeUpdater *DTU, // Account for the cost of duplicating this instruction into each // predecessor. Ignore free instructions. - if (!TTI || - TTI->getUserCost(&I, CostKind) != TargetTransformInfo::TCC_Free) { + if (!TTI || TTI->getInstructionCost(&I, CostKind) != + TargetTransformInfo::TCC_Free) { NumBonusInsts += PredCount; // Early exits once we reach the limit. @@ -3805,7 +3962,8 @@ static bool mergeConditionalStoreToAddress( return false; // Not in white-list - not worthwhile folding. // And finally, if this is a non-free instruction that we are okay // speculating, ensure that we consider the speculation budget. - Cost += TTI.getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + Cost += + TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); if (Cost > Budget) return false; // Eagerly refuse to fold as soon as we're out of budget. } @@ -4004,6 +4162,11 @@ static bool tryWidenCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, return false; if (!IfFalseBB->phis().empty()) return false; // TODO + // This helps avoid infinite loop with SimplifyCondBranchToCondBranch which + // may undo the transform done here. + // TODO: There might be a more fine-grained solution to this. + if (!llvm::succ_empty(IfFalseBB)) + return false; // Use lambda to lazily compute expensive condition after cheap ones. auto NoSideEffects = [](BasicBlock &BB) { return llvm::none_of(BB, [](const Instruction &I) { @@ -4349,7 +4512,7 @@ bool SimplifyCFGOpt::SimplifySwitchOnSelect(SwitchInst *SI, // Get weight for TrueBB and FalseBB. uint32_t TrueWeight = 0, FalseWeight = 0; SmallVector<uint64_t, 8> Weights; - bool HasWeights = HasBranchWeights(SI); + bool HasWeights = hasBranchWeightMD(*SI); if (HasWeights) { GetBranchWeights(SI, Weights); if (Weights.size() == 1 + SI->getNumCases()) { @@ -5021,7 +5184,9 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { DTU->applyUpdates(Updates); Updates.clear(); } - removeUnwindEdge(TI->getParent(), DTU); + auto *CI = cast<CallInst>(removeUnwindEdge(TI->getParent(), DTU)); + if (!CI->doesNotThrow()) + CI->setDoesNotThrow(); Changed = true; } } else if (auto *CSI = dyn_cast<CatchSwitchInst>(TI)) { @@ -5209,7 +5374,7 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI, BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest); // Update weight for the newly-created conditional branch. 
- if (HasBranchWeights(SI)) { + if (hasBranchWeightMD(*SI)) { SmallVector<uint64_t, 8> Weights; GetBranchWeights(SI, Weights); if (Weights.size() == 1 + SI->getNumCases()) { @@ -5279,7 +5444,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, SmallVector<ConstantInt *, 8> DeadCases; SmallDenseMap<BasicBlock *, int, 8> NumPerSuccessorCases; SmallVector<BasicBlock *, 8> UniqueSuccessors; - for (auto &Case : SI->cases()) { + for (const auto &Case : SI->cases()) { auto *Successor = Case.getCaseSuccessor(); if (DTU) { if (!NumPerSuccessorCases.count(Successor)) @@ -5379,7 +5544,7 @@ static bool ForwardSwitchConditionToPHI(SwitchInst *SI) { ForwardingNodesMap ForwardingNodes; BasicBlock *SwitchBlock = SI->getParent(); bool Changed = false; - for (auto &Case : SI->cases()) { + for (const auto &Case : SI->cases()) { ConstantInt *CaseValue = Case.getCaseValue(); BasicBlock *CaseDest = Case.getCaseSuccessor(); @@ -5595,7 +5760,7 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI, const DataLayout &DL, const TargetTransformInfo &TTI, uintptr_t MaxUniqueResults) { - for (auto &I : SI->cases()) { + for (const auto &I : SI->cases()) { ConstantInt *CaseVal = I.getCaseValue(); // Resulting value at phi nodes for this case value. @@ -5684,13 +5849,13 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, if (isPowerOf2_32(CaseCount)) { ConstantInt *MinCaseVal = CaseValues[0]; // Find mininal value. - for (auto Case : CaseValues) + for (auto *Case : CaseValues) if (Case->getValue().slt(MinCaseVal->getValue())) MinCaseVal = Case; // Mark the bits case number touched. APInt BitMask = APInt::getZero(MinCaseVal->getBitWidth()); - for (auto Case : CaseValues) + for (auto *Case : CaseValues) BitMask |= (Case->getValue() - MinCaseVal->getValue()); // Check if cases with the same result can cover all number @@ -5956,7 +6121,7 @@ SwitchLookupTable::SwitchLookupTable( Array->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Set the alignment to that of an array items. We will be only loading one // value out of it. - Array->setAlignment(Align(DL.getPrefTypeAlignment(ValueType))); + Array->setAlignment(DL.getPrefTypeAlign(ValueType)); Kind = ArrayKind; } @@ -6501,7 +6666,7 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, // cases such as a sequence crossing zero {-4,0,4,8} if we interpret case values // as signed. SmallVector<int64_t,4> Values; - for (auto &C : SI->cases()) + for (const auto &C : SI->cases()) Values.push_back(C.getCaseValue()->getValue().getSExtValue()); llvm::sort(Values); @@ -6856,7 +7021,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // If this basic block has dominating predecessor blocks and the dominating // blocks' conditions imply BI's condition, we know the direction of BI. - Optional<bool> Imp = isImpliedByDomCondition(BI->getCondition(), BI, DL); + std::optional<bool> Imp = isImpliedByDomCondition(BI->getCondition(), BI, DL); if (Imp) { // Turn this into a branch on constant. auto *OldCond = BI->getCondition(); @@ -7023,7 +7188,7 @@ static bool removeUndefIntroducingPredecessor(BasicBlock *BB, IRBuilder<> Builder(T); if (BranchInst *BI = dyn_cast<BranchInst>(T)) { BB->removePredecessor(Predecessor); - // Turn uncoditional branches into unreachables and remove the dead + // Turn unconditional branches into unreachables and remove the dead // destination from conditional branches. 
if (BI->isUnconditional()) Builder.CreateUnreachable(); @@ -7050,7 +7215,7 @@ static bool removeUndefIntroducingPredecessor(BasicBlock *BB, Builder.SetInsertPoint(Unreachable); // The new block contains only one instruction: Unreachable Builder.CreateUnreachable(); - for (auto &Case : SI->cases()) + for (const auto &Case : SI->cases()) if (Case.getCaseSuccessor() == BB) { BB->removePredecessor(Predecessor); Case.setSuccessor(Unreachable); diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp index 0ab79a32f526..4e83d2f6e3c6 100644 --- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -106,13 +106,8 @@ static Instruction *findCommonDominator(ArrayRef<Instruction *> Instructions, DominatorTree &DT) { Instruction *CommonDom = nullptr; for (auto *Insn : Instructions) - if (!CommonDom || DT.dominates(Insn, CommonDom)) - CommonDom = Insn; - else if (!DT.dominates(CommonDom, Insn)) - // If there is no dominance relation, use common dominator. - CommonDom = - DT.findNearestCommonDominator(CommonDom->getParent(), - Insn->getParent())->getTerminator(); + CommonDom = + CommonDom ? DT.findNearestCommonDominator(CommonDom, Insn) : Insn; assert(CommonDom && "Common dominator not found?"); return CommonDom; } @@ -195,6 +190,9 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) bool SimplifyIndvar::makeIVComparisonInvariant(ICmpInst *ICmp, Instruction *IVOperand) { + auto *Preheader = L->getLoopPreheader(); + if (!Preheader) + return false; unsigned IVOperIdx = 0; ICmpInst::Predicate Pred = ICmp->getPredicate(); if (IVOperand != ICmp->getOperand(0)) { @@ -209,51 +207,22 @@ bool SimplifyIndvar::makeIVComparisonInvariant(ICmpInst *ICmp, const Loop *ICmpLoop = LI->getLoopFor(ICmp->getParent()); const SCEV *S = SE->getSCEVAtScope(ICmp->getOperand(IVOperIdx), ICmpLoop); const SCEV *X = SE->getSCEVAtScope(ICmp->getOperand(1 - IVOperIdx), ICmpLoop); - - auto *PN = dyn_cast<PHINode>(IVOperand); - if (!PN) - return false; - auto LIP = SE->getLoopInvariantPredicate(Pred, S, X, L); + auto LIP = SE->getLoopInvariantPredicate(Pred, S, X, L, ICmp); if (!LIP) return false; ICmpInst::Predicate InvariantPredicate = LIP->Pred; const SCEV *InvariantLHS = LIP->LHS; const SCEV *InvariantRHS = LIP->RHS; - // Rewrite the comparison to a loop invariant comparison if it can be done - // cheaply, where cheaply means "we don't need to emit any new - // instructions". - - SmallDenseMap<const SCEV*, Value*> CheapExpansions; - CheapExpansions[S] = ICmp->getOperand(IVOperIdx); - CheapExpansions[X] = ICmp->getOperand(1 - IVOperIdx); - - // TODO: Support multiple entry loops? (We currently bail out of these in - // the IndVarSimplify pass) - if (auto *BB = L->getLoopPredecessor()) { - const int Idx = PN->getBasicBlockIndex(BB); - if (Idx >= 0) { - Value *Incoming = PN->getIncomingValue(Idx); - const SCEV *IncomingS = SE->getSCEV(Incoming); - CheapExpansions[IncomingS] = Incoming; - } - } - Value *NewLHS = CheapExpansions[InvariantLHS]; - Value *NewRHS = CheapExpansions[InvariantRHS]; - - if (!NewLHS) - if (auto *ConstLHS = dyn_cast<SCEVConstant>(InvariantLHS)) - NewLHS = ConstLHS->getValue(); - if (!NewRHS) - if (auto *ConstRHS = dyn_cast<SCEVConstant>(InvariantRHS)) - NewRHS = ConstRHS->getValue(); - - if (!NewLHS || !NewRHS) - // We could not find an existing value to replace either LHS or RHS. - // Generating new instructions has subtler tradeoffs, so avoid doing that - // for now. 
+ // Do not generate something ridiculous. + auto *PHTerm = Preheader->getTerminator(); + if (Rewriter.isHighCostExpansion({ InvariantLHS, InvariantRHS }, L, + 2 * SCEVCheapExpansionBudget, TTI, PHTerm)) return false; - + auto *NewLHS = + Rewriter.expandCodeFor(InvariantLHS, IVOperand->getType(), PHTerm); + auto *NewRHS = + Rewriter.expandCodeFor(InvariantRHS, IVOperand->getType(), PHTerm); LLVM_DEBUG(dbgs() << "INDVARS: Simplified comparison: " << *ICmp << '\n'); ICmp->setPredicate(InvariantPredicate); ICmp->setOperand(0, NewLHS); @@ -288,6 +257,7 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Users.push_back(cast<Instruction>(U)); const Instruction *CtxI = findCommonDominator(Users, *DT); if (auto Ev = SE->evaluatePredicateAt(Pred, S, X, CtxI)) { + SE->forgetValue(ICmp); ICmp->replaceAllUsesWith(ConstantInt::getBool(ICmp->getContext(), *Ev)); DeadInsts.emplace_back(ICmp); LLVM_DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n'); @@ -683,7 +653,7 @@ bool SimplifyIndvar::replaceFloatIVWithIntegerIV(Instruction *UseInst) { UseInst->getOpcode() != CastInst::UIToFP) return false; - Value *IVOperand = UseInst->getOperand(0); + Instruction *IVOperand = cast<Instruction>(UseInst->getOperand(0)); // Get the symbolic expression for this instruction. const SCEV *IV = SE->getSCEV(IVOperand); unsigned MaskBits; @@ -696,17 +666,35 @@ bool SimplifyIndvar::replaceFloatIVWithIntegerIV(Instruction *UseInst) { for (User *U : UseInst->users()) { // Match for fptosi/fptoui of sitofp and with same type. auto *CI = dyn_cast<CastInst>(U); - if (!CI || IVOperand->getType() != CI->getType()) + if (!CI) continue; CastInst::CastOps Opcode = CI->getOpcode(); if (Opcode != CastInst::FPToSI && Opcode != CastInst::FPToUI) continue; - CI->replaceAllUsesWith(IVOperand); + Value *Conv = nullptr; + if (IVOperand->getType() != CI->getType()) { + IRBuilder<> Builder(CI); + StringRef Name = IVOperand->getName(); + // To match InstCombine logic, we only need sext if both fptosi and + // sitofp are used. If one of them is unsigned, then we can use zext. + if (SE->getTypeSizeInBits(IVOperand->getType()) > + SE->getTypeSizeInBits(CI->getType())) { + Conv = Builder.CreateTrunc(IVOperand, CI->getType(), Name + ".trunc"); + } else if (Opcode == CastInst::FPToUI || + UseInst->getOpcode() == CastInst::UIToFP) { + Conv = Builder.CreateZExt(IVOperand, CI->getType(), Name + ".zext"); + } else { + Conv = Builder.CreateSExt(IVOperand, CI->getType(), Name + ".sext"); + } + } else + Conv = IVOperand; + + CI->replaceAllUsesWith(Conv); DeadInsts.push_back(CI); LLVM_DEBUG(dbgs() << "INDVARS: Replace IV user: " << *CI - << " with: " << *IVOperand << '\n'); + << " with: " << *Conv << '\n'); ++NumFoldedUser; Changed = true; @@ -751,6 +739,7 @@ bool SimplifyIndvar::eliminateIdentitySCEV(Instruction *UseInst, LLVM_DEBUG(dbgs() << "INDVARS: Eliminated identity: " << *UseInst << '\n'); + SE->forgetValue(UseInst); UseInst->replaceAllUsesWith(IVOperand); ++NumElimIdentity; Changed = true; @@ -1041,13 +1030,13 @@ class WidenIV { // context. DenseMap<DefUserPair, ConstantRange> PostIncRangeInfos; - Optional<ConstantRange> getPostIncRangeInfo(Value *Def, - Instruction *UseI) { + std::optional<ConstantRange> getPostIncRangeInfo(Value *Def, + Instruction *UseI) { DefUserPair Key(Def, UseI); auto It = PostIncRangeInfos.find(Key); return It == PostIncRangeInfos.end() - ? Optional<ConstantRange>(None) - : Optional<ConstantRange>(It->second); + ? 
std::optional<ConstantRange>(std::nullopt) + : std::optional<ConstantRange>(It->second); } void calculatePostIncRanges(PHINode *OrigPhi); diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 03087d8370d5..20f18322d43c 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -33,6 +33,8 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SizeOpts.h" +#include <cmath> + using namespace llvm; using namespace PatternMatch; @@ -89,10 +91,12 @@ static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr, // Fail for an invalid base (required by POSIX). return nullptr; + // Current offset into the original string to reflect in EndPtr. + size_t Offset = 0; // Strip leading whitespace. - for (unsigned i = 0; i != Str.size(); ++i) - if (!isSpace((unsigned char)Str[i])) { - Str = Str.substr(i); + for ( ; Offset != Str.size(); ++Offset) + if (!isSpace((unsigned char)Str[Offset])) { + Str = Str.substr(Offset); break; } @@ -108,6 +112,7 @@ static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr, if (Str.empty()) // Fail for a sign with nothing after it. return nullptr; + ++Offset; } // Set Max to the absolute value of the minimum (for signed), or @@ -127,6 +132,7 @@ static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr, return nullptr; Str = Str.drop_front(2); + Offset += 2; Base = 16; } else if (Base == 0) @@ -167,7 +173,7 @@ static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr, if (EndPtr) { // Store the pointer to the end. - Value *Off = B.getInt64(Str.size()); + Value *Off = B.getInt64(Offset + Str.size()); Value *StrBeg = CI->getArgOperand(0); Value *StrEnd = B.CreateInBoundsGEP(B.getInt8Ty(), StrBeg, Off, "endptr"); B.CreateStore(StrEnd, EndPtr); @@ -241,13 +247,14 @@ static void annotateNonNullNoUndefBasedOnAccess(CallInst *CI, if (!CI->paramHasAttr(ArgNo, Attribute::NoUndef)) CI->addParamAttr(ArgNo, Attribute::NoUndef); - if (CI->paramHasAttr(ArgNo, Attribute::NonNull)) - continue; - unsigned AS = CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace(); - if (llvm::NullPointerIsDefined(F, AS)) - continue; + if (!CI->paramHasAttr(ArgNo, Attribute::NonNull)) { + unsigned AS = + CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace(); + if (llvm::NullPointerIsDefined(F, AS)) + continue; + CI->addParamAttr(ArgNo, Attribute::NonNull); + } - CI->addParamAttr(ArgNo, Attribute::NonNull); annotateDereferenceableBytes(CI, ArgNo, 1); } } @@ -281,6 +288,13 @@ static Value *copyFlags(const CallInst &Old, Value *New) { return New; } +static Value *mergeAttributesAndFlags(CallInst *NewCI, const CallInst &Old) { + NewCI->setAttributes(AttributeList::get( + NewCI->getContext(), {NewCI->getAttributes(), Old.getAttributes()})); + NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); + return copyFlags(Old, NewCI); +} + // Helper to avoid truncating the length if size_t is 32-bits. static StringRef substr(StringRef Str, uint64_t Len) { return Len >= Str.size() ? Str : Str.substr(0, Len); @@ -420,14 +434,16 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) { Function *Callee = CI->getCalledFunction(); FunctionType *FT = Callee->getFunctionType(); - if (!FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32. + unsigned IntBits = TLI->getIntSize(); + if (!FT->getParamType(1)->isIntegerTy(IntBits)) // memchr needs 'int'. 
return nullptr; - return copyFlags( - *CI, - emitMemChr(SrcStr, CharVal, // include nul. - ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len), B, - DL, TLI)); + unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule()); + Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits); + return copyFlags(*CI, + emitMemChr(SrcStr, CharVal, // include nul. + ConstantInt::get(SizeTTy, Len), B, + DL, TLI)); } if (CharC->isZero()) { @@ -474,11 +490,13 @@ Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilderBase &B) { return nullptr; } + unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule()); + Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits); + // Try to expand strrchr to the memrchr nonstandard extension if it's // available, or simply fail otherwise. uint64_t NBytes = Str.size() + 1; // Include the terminating nul. - Type *IntPtrType = DL.getIntPtrType(CI->getContext()); - Value *Size = ConstantInt::get(IntPtrType, NBytes); + Value *Size = ConstantInt::get(SizeTTy, NBytes); return copyFlags(*CI, emitMemRChr(SrcStr, CharVal, Size, B, DL, TLI)); } @@ -493,7 +511,8 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) { // strcmp(x, y) -> cnst (if both x and y are constant strings) if (HasStr1 && HasStr2) - return ConstantInt::get(CI->getType(), Str1.compare(Str2)); + return ConstantInt::get(CI->getType(), + std::clamp(Str1.compare(Str2), -1, 1)); if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x return B.CreateNeg(B.CreateZExt( @@ -577,7 +596,8 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { // Avoid truncating the 64-bit Length to 32 bits in ILP32. StringRef SubStr1 = substr(Str1, Length); StringRef SubStr2 = substr(Str2, Length); - return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); + return ConstantInt::get(CI->getType(), + std::clamp(SubStr1.compare(SubStr2), -1, 1)); } if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x @@ -648,9 +668,7 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) { CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1), ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len)); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return Dst; } @@ -682,44 +700,145 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) { // We have enough information to now generate the memcpy call to do the // copy for us. Make a memcpy to copy the nul byte with align = 1. CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1), LenV); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return DstEnd; } -Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilderBase &B) { +// Optimize a call to size_t strlcpy(char*, const char*, size_t). + +Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) { + Value *Size = CI->getArgOperand(2); + if (isKnownNonZero(Size, DL)) + // Like snprintf, the function stores into the destination only when + // the size argument is nonzero. + annotateNonNullNoUndefBasedOnAccess(CI, 0); + // The function reads the source argument regardless of Size (it returns + // its length). 
+ annotateNonNullNoUndefBasedOnAccess(CI, 1); + + uint64_t NBytes; + if (ConstantInt *SizeC = dyn_cast<ConstantInt>(Size)) + NBytes = SizeC->getZExtValue(); + else + return nullptr; + + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); + if (NBytes <= 1) { + if (NBytes == 1) + // For a call to strlcpy(D, S, 1) first store a nul in *D. + B.CreateStore(B.getInt8(0), Dst); + + // Transform strlcpy(D, S, 0) to a call to strlen(S). + return copyFlags(*CI, emitStrLen(Src, B, DL, TLI)); + } + + // Try to determine the length of the source, substituting its size + // when it's not nul-terminated (as it's required to be) to avoid + // reading past its end. + StringRef Str; + if (!getConstantStringInfo(Src, Str, /*TrimAtNul=*/false)) + return nullptr; + + uint64_t SrcLen = Str.find('\0'); + // Set if the terminating nul should be copied by the call to memcpy + // below. + bool NulTerm = SrcLen < NBytes; + + if (NulTerm) + // Overwrite NBytes with the number of bytes to copy, including + // the terminating nul. + NBytes = SrcLen + 1; + else { + // Set the length of the source for the function to return to its + // size, and cap NBytes at the same. + SrcLen = std::min(SrcLen, uint64_t(Str.size())); + NBytes = std::min(NBytes - 1, SrcLen); + } + + if (SrcLen == 0) { + // Transform strlcpy(D, "", N) to (*D = '\0, 0). + B.CreateStore(B.getInt8(0), Dst); + return ConstantInt::get(CI->getType(), 0); + } + + Function *Callee = CI->getCalledFunction(); + Type *PT = Callee->getFunctionType()->getParamType(0); + // Transform strlcpy(D, S, N) to memcpy(D, S, N') where N' is the lower + // bound on strlen(S) + 1 and N, optionally followed by a nul store to + // D[N' - 1] if necessary. + CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1), + ConstantInt::get(DL.getIntPtrType(PT), NBytes)); + mergeAttributesAndFlags(NewCI, *CI); + + if (!NulTerm) { + Value *EndOff = ConstantInt::get(CI->getType(), NBytes); + Value *EndPtr = B.CreateInBoundsGEP(B.getInt8Ty(), Dst, EndOff); + B.CreateStore(B.getInt8(0), EndPtr); + } + + // Like snprintf, strlcpy returns the number of nonzero bytes that would + // have been copied if the bound had been sufficiently big (which in this + // case is strlen(Src)). + return ConstantInt::get(CI->getType(), SrcLen); +} + +// Optimize a call CI to either stpncpy when RetEnd is true, or to strncpy +// otherwise. +Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd, + IRBuilderBase &B) { Function *Callee = CI->getCalledFunction(); Value *Dst = CI->getArgOperand(0); Value *Src = CI->getArgOperand(1); Value *Size = CI->getArgOperand(2); - annotateNonNullNoUndefBasedOnAccess(CI, 0); - if (isKnownNonZero(Size, DL)) + + if (isKnownNonZero(Size, DL)) { + // Both st{p,r}ncpy(D, S, N) access the source and destination arrays + // only when N is nonzero. + annotateNonNullNoUndefBasedOnAccess(CI, 0); annotateNonNullNoUndefBasedOnAccess(CI, 1); + } - uint64_t Len; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size)) - Len = LengthArg->getZExtValue(); - else - return nullptr; + // If the "bound" argument is known set N to it. Otherwise set it to + // UINT64_MAX and handle it later. + uint64_t N = UINT64_MAX; + if (ConstantInt *SizeC = dyn_cast<ConstantInt>(Size)) + N = SizeC->getZExtValue(); - // strncpy(x, y, 0) -> x - if (Len == 0) + if (N == 0) + // Fold st{p,r}ncpy(D, S, 0) to D. return Dst; - // See if we can get the length of the input string. 
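For readers unfamiliar with the BSD interface, the strlcpy fold above leans on the library contract that strlcpy copies at most N-1 bytes, nul-terminates the destination whenever N is nonzero, and returns strlen(S) regardless of truncation. A minimal reference sketch of that contract follows (illustrative only; strlcpy_model is a made-up name, not code from this patch):

#include <cstddef>
#include <cstring>

// Reference model of the strlcpy semantics the fold above relies on.
static size_t strlcpy_model(char *Dst, const char *Src, size_t N) {
  size_t SrcLen = std::strlen(Src);             // strlcpy always reports this
  if (N != 0) {
    size_t NCopy = SrcLen < N ? SrcLen : N - 1; // copy at most N-1 bytes
    std::memcpy(Dst, Src, NCopy);
    Dst[NCopy] = '\0';                          // destination is always nul-terminated
  }
  return SrcLen;                                // the would-be length, like snprintf
}

The transformed IR mirrors this directly: a bounded memcpy, an extra nul store only when the source is truncated, and a constant strlen(S) as the return value.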
+ if (N == 1) { + Type *CharTy = B.getInt8Ty(); + Value *CharVal = B.CreateLoad(CharTy, Src, "stxncpy.char0"); + B.CreateStore(CharVal, Dst); + if (!RetEnd) + // Transform strncpy(D, S, 1) to return (*D = *S), D. + return Dst; + + // Transform stpncpy(D, S, 1) to return (*D = *S) ? D + 1 : D. + Value *ZeroChar = ConstantInt::get(CharTy, 0); + Value *Cmp = B.CreateICmpEQ(CharVal, ZeroChar, "stpncpy.char0cmp"); + + Value *Off1 = B.getInt32(1); + Value *EndPtr = B.CreateInBoundsGEP(CharTy, Dst, Off1, "stpncpy.end"); + return B.CreateSelect(Cmp, Dst, EndPtr, "stpncpy.sel"); + } + + // If the length of the input string is known set SrcLen to it. uint64_t SrcLen = GetStringLength(Src); - if (SrcLen) { + if (SrcLen) annotateDereferenceableBytes(CI, 1, SrcLen); - --SrcLen; // Unbias length. - } else { + else return nullptr; - } + + --SrcLen; // Unbias length. if (SrcLen == 0) { - // strncpy(x, "", y) -> memset(x, '\0', y) + // Transform st{p,r}ncpy(D, "", N) to memset(D, '\0', N) for any N. Align MemSetAlign = - CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne(); + CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne(); CallInst *NewCI = B.CreateMemSet(Dst, B.getInt8('\0'), Size, MemSetAlign); AttrBuilder ArgAttrs(CI->getContext(), CI->getAttributes().getParamAttrs(0)); NewCI->setAttributes(NewCI->getAttributes().addParamAttributes( @@ -728,28 +847,35 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilderBase &B) { return Dst; } - // strncpy(a, "a", 4) - > memcpy(a, "a\0\0\0", 4) - if (Len > SrcLen + 1) { - if (Len <= 128) { - StringRef Str; - if (!getConstantStringInfo(Src, Str)) - return nullptr; - std::string SrcStr = Str.str(); - SrcStr.resize(Len, '\0'); - Src = B.CreateGlobalString(SrcStr, "str"); - } else { + if (N > SrcLen + 1) { + if (N > 128) + // Bail if N is large or unknown. return nullptr; - } + + // st{p,r}ncpy(D, "a", N) -> memcpy(D, "a\0\0\0", N) for N <= 128. + StringRef Str; + if (!getConstantStringInfo(Src, Str)) + return nullptr; + std::string SrcStr = Str.str(); + // Create a bigger, nul-padded array with the same length, SrcLen, + // as the original string. + SrcStr.resize(N, '\0'); + Src = B.CreateGlobalString(SrcStr, "str"); } Type *PT = Callee->getFunctionType()->getParamType(0); - // strncpy(x, s, c) -> memcpy(align 1 x, align 1 s, c) [s and c are constant] + // st{p,r}ncpy(D, S, N) -> memcpy(align 1 D, align 1 S, N) when both + // S and N are constant. CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1), - ConstantInt::get(DL.getIntPtrType(PT), Len)); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); - return Dst; + ConstantInt::get(DL.getIntPtrType(PT), N)); + mergeAttributesAndFlags(NewCI, *CI); + if (!RetEnd) + return Dst; + + // stpncpy(D, S, N) returns the address of the first null in D if it writes + // one, otherwise D + N. + Value *Off = B.getInt64(std::min(SrcLen, N)); + return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, Off, "endptr"); } Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, @@ -803,9 +929,9 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, // strlen(s + x) to strlen(s) - x, when x is known to be in the range // [0, strlen(s)] or the string has a single null terminator '\0' at the end. // We only try to simplify strlen when the pointer s points to an array - // of i8. Otherwise, we would need to scale the offset x before doing the - // subtraction. 
This will make the optimization more complex, and it's not - // very useful because calling strlen for a pointer of other types is + // of CharSize elements. Otherwise, we would need to scale the offset x before + // doing the subtraction. This will make the optimization more complex, and + // it's not very useful because calling strlen for a pointer of other types is // very uncommon. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) { // TODO: Handle subobjects. @@ -1060,7 +1186,7 @@ Value *LibCallSimplifier::optimizeMemRChr(CallInst *CI, IRBuilderBase &B) { } StringRef Str; - if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false)) + if (!getConstantStringInfo(SrcStr, Str, /*TrimAtNul=*/false)) return nullptr; if (Str.size() == 0) @@ -1155,7 +1281,7 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { } StringRef Str; - if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false)) + if (!getConstantStringInfo(SrcStr, Str, /*TrimAtNul=*/false)) return nullptr; if (CharC) { @@ -1294,8 +1420,8 @@ static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS, return Constant::getNullValue(CI->getType()); StringRef LStr, RStr; - if (!getConstantStringInfo(LHS, LStr, 0, /*TrimAtNul=*/false) || - !getConstantStringInfo(RHS, RStr, 0, /*TrimAtNul=*/false)) + if (!getConstantStringInfo(LHS, LStr, /*TrimAtNul=*/false) || + !getConstantStringInfo(RHS, RStr, /*TrimAtNul=*/false)) return nullptr; // If the contents of both constant arrays are known, fold a call to @@ -1351,7 +1477,7 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, // to legal integers or equality comparison. See block below this. if (DL.isLegalInteger(Len * 8) && isOnlyUsedInZeroEqualityComparison(CI)) { IntegerType *IntType = IntegerType::get(CI->getContext(), Len * 8); - unsigned PrefAlignment = DL.getPrefTypeAlignment(IntType); + Align PrefAlignment = DL.getPrefTypeAlign(IntType); // First, see if we can fold either argument to a constant. Value *LHSV = nullptr; @@ -1437,9 +1563,7 @@ Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilderBase &B) { // memcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n) CallInst *NewCI = B.CreateMemCpy(CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), Size); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return CI->getArgOperand(0); } @@ -1455,8 +1579,7 @@ Value *LibCallSimplifier::optimizeMemCCpy(CallInst *CI, IRBuilderBase &B) { if (N) { if (N->isNullValue()) return Constant::getNullValue(CI->getType()); - if (!getConstantStringInfo(Src, SrcStr, /*Offset=*/0, - /*TrimAtNul=*/false) || + if (!getConstantStringInfo(Src, SrcStr, /*TrimAtNul=*/false) || // TODO: Handle zeroinitializer. !StopChar) return nullptr; @@ -1493,9 +1616,7 @@ Value *LibCallSimplifier::optimizeMemPCpy(CallInst *CI, IRBuilderBase &B) { // Propagate attributes, but memcpy has no return value, so make sure that // any return attributes are compliant. // TODO: Attach return value attributes to the 1st operand to preserve them? 
- NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, N); } @@ -1508,9 +1629,7 @@ Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilderBase &B) { // memmove(x, y, n) -> llvm.memmove(align 1 x, align 1 y, n) CallInst *NewCI = B.CreateMemMove(CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), Size); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return CI->getArgOperand(0); } @@ -1523,9 +1642,7 @@ Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilderBase &B) { // memset(p, v, n) -> llvm.memset(align 1 p, v, n) Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); CallInst *NewCI = B.CreateMemSet(CI->getArgOperand(0), Val, Size, Align(1)); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return CI->getArgOperand(0); } @@ -1741,7 +1858,6 @@ static Value *getIntToFPVal(Value *I2F, IRBuilderBase &B, unsigned DstWidth) { Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { Module *M = Pow->getModule(); Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); - AttributeList Attrs; // Attributes are only meaningful on the original call Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); bool Ignored; @@ -1766,8 +1882,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { LibFunc LibFn; Function *CalleeFn = BaseFn->getCalledFunction(); - if (CalleeFn && - TLI->getLibFunc(CalleeFn->getName(), LibFn) && + if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && isLibFuncEmittable(M, TLI, LibFn)) { StringRef ExpName; Intrinsic::ID ID; @@ -1777,14 +1892,18 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { switch (LibFn) { default: return nullptr; - case LibFunc_expf: case LibFunc_exp: case LibFunc_expl: + case LibFunc_expf: + case LibFunc_exp: + case LibFunc_expl: ExpName = TLI->getName(LibFunc_exp); ID = Intrinsic::exp; LibFnFloat = LibFunc_expf; LibFnDouble = LibFunc_exp; LibFnLongDouble = LibFunc_expl; break; - case LibFunc_exp2f: case LibFunc_exp2: case LibFunc_exp2l: + case LibFunc_exp2f: + case LibFunc_exp2: + case LibFunc_exp2l: ExpName = TLI->getName(LibFunc_exp2); ID = Intrinsic::exp2; LibFnFloat = LibFunc_exp2f; @@ -1817,6 +1936,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { if (!match(Pow->getArgOperand(0), m_APFloat(BaseF))) return nullptr; + AttributeList NoAttrs; // Attributes are only meaningful on the original call + // pow(2.0, itofp(x)) -> ldexp(1.0, x) if (match(Base, m_SpecificFP(2.0)) && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) && @@ -1825,7 +1946,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { return copyFlags(*Pow, emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), ExpoI, TLI, LibFunc_ldexp, LibFunc_ldexpf, - LibFunc_ldexpl, B, Attrs)); + LibFunc_ldexpl, B, NoAttrs)); } // pow(2.0 ** n, x) -> exp2(n * x) @@ -1849,7 +1970,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { else return copyFlags(*Pow, emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, - 
LibFunc_exp2l, B, Attrs)); + LibFunc_exp2l, B, NoAttrs)); } } @@ -1859,7 +1980,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { hasFloatFn(M, TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) return copyFlags(*Pow, emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l, - B, Attrs)); + B, NoAttrs)); // pow(x, y) -> exp2(log2(x) * y) if (Pow->hasApproxFunc() && Pow->hasNoNaNs() && BaseF->isFiniteNonZero() && @@ -1885,7 +2006,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { LibFunc_exp2l)) return copyFlags(*Pow, emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, - LibFunc_exp2l, B, Attrs)); + LibFunc_exp2l, B, NoAttrs)); } } @@ -1917,7 +2038,6 @@ static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, /// Use square root in place of pow(x, +/-0.5). Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); - AttributeList Attrs; // Attributes are only meaningful on the original call Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); @@ -1939,7 +2059,8 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { !isKnownNeverInfinity(Base, TLI)) return nullptr; - Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI); + Sqrt = getSqrtCall(Base, AttributeList(), Pow->doesNotAccessMemory(), Mod, B, + TLI); if (!Sqrt) return nullptr; @@ -2045,8 +2166,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { return nullptr; ExpoF = &ExpoI; - Sqrt = getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(), - Pow->doesNotAccessMemory(), M, B, TLI); + Sqrt = getSqrtCall(Base, AttributeList(), Pow->doesNotAccessMemory(), M, + B, TLI); if (!Sqrt) return nullptr; } @@ -2090,7 +2211,6 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); - AttributeList Attrs; // Attributes are only meaningful on the original call StringRef Name = Callee->getName(); Value *Ret = nullptr; if (UnsafeFPShrink && Name == TLI->getName(LibFunc_exp2) && @@ -2100,14 +2220,14 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { Type *Ty = CI->getType(); Value *Op = CI->getArgOperand(0); - // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= IntSize - // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < IntSize + // exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= IntSize + // exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < IntSize if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) && hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize())) return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI, - LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl, - B, Attrs); + LibFunc_ldexp, LibFunc_ldexpf, + LibFunc_ldexpl, B, AttributeList()); } return Ret; @@ -2145,7 +2265,6 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) { Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Function *LogFn = Log->getCalledFunction(); - AttributeList Attrs; // Attributes are only meaningful on the original call StringRef LogNm = LogFn->getName(); Intrinsic::ID LogID = LogFn->getIntrinsicID(); Module *Mod = Log->getModule(); @@ 
-2256,12 +2375,13 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { TLI->getLibFunc(*Arg, ArgLb); // log(pow(x,y)) -> y*log(x) + AttributeList NoAttrs; if (ArgLb == PowLb || ArgID == Intrinsic::pow) { Value *LogX = Log->doesNotAccessMemory() ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), Arg->getOperand(0), "log") - : emitUnaryFloatFnCall(Arg->getOperand(0), TLI, LogNm, B, Attrs); + : emitUnaryFloatFnCall(Arg->getOperand(0), TLI, LogNm, B, NoAttrs); Value *MulY = B.CreateFMul(Arg->getArgOperand(1), LogX, "mul"); // Since pow() may have side effects, e.g. errno, // dead code elimination may not be trusted to remove it. @@ -2284,7 +2404,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Value *LogE = Log->doesNotAccessMemory() ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), Eul, "log") - : emitUnaryFloatFnCall(Eul, TLI, LogNm, B, Attrs); + : emitUnaryFloatFnCall(Eul, TLI, LogNm, B, NoAttrs); Value *MulY = B.CreateFMul(Arg->getArgOperand(0), LogE, "mul"); // Since exp() may have side effects, e.g. errno, // dead code elimination may not be trusted to remove it. @@ -2400,8 +2520,7 @@ static bool isTrigLibCall(CallInst *CI) { // We can only hope to do anything useful if we can ignore things like errno // and floating-point exceptions. // We already checked the prototype. - return CI->hasFnAttr(Attribute::NoUnwind) && - CI->hasFnAttr(Attribute::ReadNone); + return CI->doesNotThrow() && CI->doesNotAccessMemory(); } static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, @@ -2507,9 +2626,7 @@ void LibCallSimplifier::classifyArgUse( SmallVectorImpl<CallInst *> &SinCalls, SmallVectorImpl<CallInst *> &CosCalls, SmallVectorImpl<CallInst *> &SinCosCalls) { - CallInst *CI = dyn_cast<CallInst>(Val); - Module *M = CI->getModule(); - + auto *CI = dyn_cast<CallInst>(Val); if (!CI || CI->use_empty()) return; @@ -2517,6 +2634,7 @@ void LibCallSimplifier::classifyArgUse( if (CI->getFunction() != F) return; + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); LibFunc Func; if (!Callee || !TLI->getLibFunc(*Callee, Func) || @@ -2546,21 +2664,24 @@ void LibCallSimplifier::classifyArgUse( //===----------------------------------------------------------------------===// Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilderBase &B) { - // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 + // All variants of ffs return int which need not be 32 bits wide. + // ffs{,l,ll}(x) -> x != 0 ? (int)llvm.cttz(x)+1 : 0 + Type *RetType = CI->getType(); Value *Op = CI->getArgOperand(0); Type *ArgType = Op->getType(); Function *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), Intrinsic::cttz, ArgType); Value *V = B.CreateCall(F, {Op, B.getTrue()}, "cttz"); V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); - V = B.CreateIntCast(V, B.getInt32Ty(), false); + V = B.CreateIntCast(V, RetType, false); Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType)); - return B.CreateSelect(Cond, V, B.getInt32(0)); + return B.CreateSelect(Cond, V, ConstantInt::get(RetType, 0)); } Value *LibCallSimplifier::optimizeFls(CallInst *CI, IRBuilderBase &B) { - // fls(x) -> (i32)(sizeInBits(x) - llvm.ctlz(x, false)) + // All variants of fls return int which need not be 32 bits wide. 
+ // fls{,l,ll}(x) -> (int)(sizeInBits(x) - llvm.ctlz(x, false)) Value *Op = CI->getArgOperand(0); Type *ArgType = Op->getType(); Function *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), @@ -2583,15 +2704,17 @@ Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilderBase &B) { Value *LibCallSimplifier::optimizeIsDigit(CallInst *CI, IRBuilderBase &B) { // isdigit(c) -> (c-'0') <u 10 Value *Op = CI->getArgOperand(0); - Op = B.CreateSub(Op, B.getInt32('0'), "isdigittmp"); - Op = B.CreateICmpULT(Op, B.getInt32(10), "isdigit"); + Type *ArgType = Op->getType(); + Op = B.CreateSub(Op, ConstantInt::get(ArgType, '0'), "isdigittmp"); + Op = B.CreateICmpULT(Op, ConstantInt::get(ArgType, 10), "isdigit"); return B.CreateZExt(Op, CI->getType()); } Value *LibCallSimplifier::optimizeIsAscii(CallInst *CI, IRBuilderBase &B) { // isascii(c) -> c <u 128 Value *Op = CI->getArgOperand(0); - Op = B.CreateICmpULT(Op, B.getInt32(128), "isascii"); + Type *ArgType = Op->getType(); + Op = B.CreateICmpULT(Op, ConstantInt::get(ArgType, 128), "isascii"); return B.CreateZExt(Op, CI->getType()); } @@ -2697,9 +2820,15 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) { if (!CI->use_empty()) return nullptr; + Type *IntTy = CI->getType(); // printf("x") -> putchar('x'), even for "%" and "%%". - if (FormatStr.size() == 1 || FormatStr == "%%") - return copyFlags(*CI, emitPutChar(B.getInt32(FormatStr[0]), B, TLI)); + if (FormatStr.size() == 1 || FormatStr == "%%") { + // Convert the character to unsigned char before passing it to putchar + // to avoid host-specific sign extension in the IR. Putchar converts + // it to unsigned char regardless. + Value *IntChar = ConstantInt::get(IntTy, (unsigned char)FormatStr[0]); + return copyFlags(*CI, emitPutChar(IntChar, B, TLI)); + } // Try to remove call or emit putchar/puts. if (FormatStr == "%s" && CI->arg_size() > 1) { @@ -2710,8 +2839,13 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) { if (OperandStr.empty()) return (Value *)CI; // printf("%s", "a") --> putchar('a') - if (OperandStr.size() == 1) - return copyFlags(*CI, emitPutChar(B.getInt32(OperandStr[0]), B, TLI)); + if (OperandStr.size() == 1) { + // Convert the character to unsigned char before passing it to putchar + // to avoid host-specific sign extension in the IR. Putchar converts + // it to unsigned char regardless. + Value *IntChar = ConstantInt::get(IntTy, (unsigned char)OperandStr[0]); + return copyFlags(*CI, emitPutChar(IntChar, B, TLI)); + } // printf("%s", str"\n") --> puts(str) if (OperandStr.back() == '\n') { OperandStr = OperandStr.drop_back(); @@ -2734,8 +2868,12 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) { // Optimize specific format strings. // printf("%c", chr) --> putchar(chr) if (FormatStr == "%c" && CI->arg_size() > 1 && - CI->getArgOperand(1)->getType()->isIntegerTy()) - return copyFlags(*CI, emitPutChar(CI->getArgOperand(1), B, TLI)); + CI->getArgOperand(1)->getType()->isIntegerTy()) { + // Convert the argument to the type expected by putchar, i.e., int, which + // need not be 32 bits wide but which is the same as printf's return type. 
+ Value *IntChar = B.CreateIntCast(CI->getArgOperand(1), IntTy, false); + return copyFlags(*CI, emitPutChar(IntChar, B, TLI)); + } // printf("%s\n", str) --> puts(str) if (FormatStr == "%s\n" && CI->arg_size() > 1 && @@ -2753,6 +2891,8 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) { return V; } + annotateNonNullNoUndefBasedOnAccess(CI, 0); + // printf(format, ...) -> iprintf(format, ...) if no floating point // arguments. if (isLibFuncEmittable(M, TLI, LibFunc_iprintf) && @@ -2777,7 +2917,6 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) { return New; } - annotateNonNullNoUndefBasedOnAccess(CI, 0); return nullptr; } @@ -2876,6 +3015,8 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) { return V; } + annotateNonNullNoUndefBasedOnAccess(CI, {0, 1}); + // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating // point arguments. if (isLibFuncEmittable(M, TLI, LibFunc_siprintf) && @@ -2900,10 +3041,63 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) { return New; } - annotateNonNullNoUndefBasedOnAccess(CI, {0, 1}); return nullptr; } +// Transform an snprintf call CI with the bound N to format the string Str +// either to a call to memcpy, or to a single character store, or to nothing, +// and fold the result to a constant. A nonnull StrArg refers to the string +// argument being formatted. Otherwise the call is one with N < 2 and +// the "%c" directive to format a single character. +Value *LibCallSimplifier::emitSnPrintfMemCpy(CallInst *CI, Value *StrArg, + StringRef Str, uint64_t N, + IRBuilderBase &B) { + assert(StrArg || (N < 2 && Str.size() == 1)); + + unsigned IntBits = TLI->getIntSize(); + uint64_t IntMax = maxIntN(IntBits); + if (Str.size() > IntMax) + // Bail if the string is longer than INT_MAX. POSIX requires + // implementations to set errno to EOVERFLOW in this case, in + // addition to when N is larger than that (checked by the caller). + return nullptr; + + Value *StrLen = ConstantInt::get(CI->getType(), Str.size()); + if (N == 0) + return StrLen; + + // Set to the number of bytes to copy from StrArg which is also + // the offset of the terminating nul. + uint64_t NCopy; + if (N > Str.size()) + // Copy the full string, including the terminating nul (which must + // be present regardless of the bound). + NCopy = Str.size() + 1; + else + NCopy = N - 1; + + Value *DstArg = CI->getArgOperand(0); + if (NCopy && StrArg) + // Transform the call to llvm.memcpy(dst, fmt, N). + copyFlags( + *CI, + B.CreateMemCpy( + DstArg, Align(1), StrArg, Align(1), + ConstantInt::get(DL.getIntPtrType(CI->getContext()), NCopy))); + + if (N > Str.size()) + // Return early when the whole format string, including the final nul, + // has been copied. + return StrLen; + + // Otherwise, when truncating the string append a terminating nul. + Type *Int8Ty = B.getInt8Ty(); + Value *NulOff = B.getIntN(IntBits, NCopy); + Value *DstEnd = B.CreateInBoundsGEP(Int8Ty, DstArg, NulOff, "endptr"); + B.CreateStore(ConstantInt::get(Int8Ty, 0), DstEnd); + return StrLen; +} + Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, IRBuilderBase &B) { // Check for size @@ -2912,78 +3106,66 @@ Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, return nullptr; uint64_t N = Size->getZExtValue(); + uint64_t IntMax = maxIntN(TLI->getIntSize()); + if (N > IntMax) + // Bail if the bound exceeds INT_MAX. POSIX requires implementations + // to set errno to EOVERFLOW in this case. 
+ return nullptr; + + Value *DstArg = CI->getArgOperand(0); + Value *FmtArg = CI->getArgOperand(2); + // Check for a fixed format string. StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr)) + if (!getConstantStringInfo(FmtArg, FormatStr)) return nullptr; // If we just have a format string (nothing else crazy) transform it. if (CI->arg_size() == 3) { - // Make sure there's no % in the constant array. We could try to handle - // %% -> % in the future if we cared. if (FormatStr.contains('%')) - return nullptr; // we found a format specifier, bail out. - - if (N == 0) - return ConstantInt::get(CI->getType(), FormatStr.size()); - else if (N < FormatStr.size() + 1) + // Bail if the format string contains a directive and there are + // no arguments. We could handle "%%" in the future. return nullptr; - // snprintf(dst, size, fmt) -> llvm.memcpy(align 1 dst, align 1 fmt, - // strlen(fmt)+1) - copyFlags( - *CI, - B.CreateMemCpy( - CI->getArgOperand(0), Align(1), CI->getArgOperand(2), Align(1), - ConstantInt::get(DL.getIntPtrType(CI->getContext()), - FormatStr.size() + 1))); // Copy the null byte. - return ConstantInt::get(CI->getType(), FormatStr.size()); + return emitSnPrintfMemCpy(CI, FmtArg, FormatStr, N, B); } // The remaining optimizations require the format string to be "%s" or "%c" // and have an extra operand. - if (FormatStr.size() == 2 && FormatStr[0] == '%' && CI->arg_size() == 4) { - - // Decode the second character of the format string. - if (FormatStr[1] == 'c') { - if (N == 0) - return ConstantInt::get(CI->getType(), 1); - else if (N == 1) - return nullptr; - - // snprintf(dst, size, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 - if (!CI->getArgOperand(3)->getType()->isIntegerTy()) - return nullptr; - Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char"); - Value *Ptr = castToCStr(CI->getArgOperand(0), B); - B.CreateStore(V, Ptr); - Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); - B.CreateStore(B.getInt8(0), Ptr); + if (FormatStr.size() != 2 || FormatStr[0] != '%' || CI->arg_size() != 4) + return nullptr; - return ConstantInt::get(CI->getType(), 1); + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + if (N <= 1) { + // Use an arbitrary string of length 1 to transform the call into + // either a nul store (N == 1) or a no-op (N == 0) and fold it + // to one. 
+ StringRef CharStr("*"); + return emitSnPrintfMemCpy(CI, nullptr, CharStr, N, B); } - if (FormatStr[1] == 's') { - // snprintf(dest, size, "%s", str) to llvm.memcpy(dest, str, len+1, 1) - StringRef Str; - if (!getConstantStringInfo(CI->getArgOperand(3), Str)) - return nullptr; + // snprintf(dst, size, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 + if (!CI->getArgOperand(3)->getType()->isIntegerTy()) + return nullptr; + Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char"); + Value *Ptr = castToCStr(DstArg, B); + B.CreateStore(V, Ptr); + Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); + B.CreateStore(B.getInt8(0), Ptr); + return ConstantInt::get(CI->getType(), 1); + } - if (N == 0) - return ConstantInt::get(CI->getType(), Str.size()); - else if (N < Str.size() + 1) - return nullptr; + if (FormatStr[1] != 's') + return nullptr; - copyFlags( - *CI, B.CreateMemCpy(CI->getArgOperand(0), Align(1), - CI->getArgOperand(3), Align(1), - ConstantInt::get(CI->getType(), Str.size() + 1))); + Value *StrArg = CI->getArgOperand(3); + // snprintf(dest, size, "%s", str) to llvm.memcpy(dest, str, len+1, 1) + StringRef Str; + if (!getConstantStringInfo(StrArg, Str)) + return nullptr; - // The snprintf result is the unincremented number of bytes in the string. - return ConstantInt::get(CI->getType(), Str.size()); - } - } - return nullptr; + return emitSnPrintfMemCpy(CI, StrArg, Str, N, B); } Value *LibCallSimplifier::optimizeSnPrintF(CallInst *CI, IRBuilderBase &B) { @@ -3017,10 +3199,11 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, if (FormatStr.contains('%')) return nullptr; // We found a format specifier. + unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule()); + Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits); return copyFlags( *CI, emitFWrite(CI->getArgOperand(1), - ConstantInt::get(DL.getIntPtrType(CI->getContext()), - FormatStr.size()), + ConstantInt::get(SizeTTy, FormatStr.size()), CI->getArgOperand(0), B, DL, TLI)); } @@ -3031,11 +3214,13 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, // Decode the second character of the format string. if (FormatStr[1] == 'c') { - // fprintf(F, "%c", chr) --> fputc(chr, F) + // fprintf(F, "%c", chr) --> fputc((int)chr, F) if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; - return copyFlags( - *CI, emitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, TLI)); + Type *IntTy = B.getIntNTy(TLI->getIntSize()); + Value *V = B.CreateIntCast(CI->getArgOperand(2), IntTy, /*isSigned*/ true, + "chari"); + return copyFlags(*CI, emitFPutC(V, CI->getArgOperand(0), B, TLI)); } if (FormatStr[1] == 's') { @@ -3102,7 +3287,9 @@ Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilderBase &B) { if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F) Value *Char = B.CreateLoad(B.getInt8Ty(), castToCStr(CI->getArgOperand(0), B), "char"); - Value *NewCI = emitFPutC(Char, CI->getArgOperand(3), B, TLI); + Type *IntTy = B.getIntNTy(TLI->getIntSize()); + Value *Cast = B.CreateIntCast(Char, IntTy, /*isSigned*/ true, "chari"); + Value *NewCI = emitFPutC(Cast, CI->getArgOperand(3), B, TLI); return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr; } } @@ -3131,10 +3318,12 @@ Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilderBase &B) { return nullptr; // Known to have no uses (see above). 
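Several of the hunks above switch the character operand of putchar/fputc (and the length operand passed to fwrite) from hard-coded i32/intptr types to the widths reported by TargetLibraryInfo (getIntSize, getSizeTSize). A minimal, hedged source-level sketch of the call shape being produced follows; the wrapper name writeChar is invented for illustration and is not part of the patch:

#include <cstdio>

// Illustrative only: putchar() and fputc() take a plain 'int', whose width is
// target-defined, so the folds cast the character operand to TLI->getIntSize()
// rather than assuming a 32-bit int.
int writeChar(char C, std::FILE *F) {
  // Per the updated comment in the diff: fprintf(F, "%c", chr) --> fputc((int)chr, F)
  return std::fputc(static_cast<int>(C), F);
}

On targets where 'int' is narrower than 32 bits, emitting the argument at the TLI-reported width keeps the lowered call compatible with the C library prototype.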
+ unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule()); + Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits); return copyFlags( *CI, emitFWrite(CI->getArgOperand(0), - ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len - 1), + ConstantInt::get(SizeTTy, Len - 1), CI->getArgOperand(1), B, DL, TLI)); } @@ -3146,8 +3335,12 @@ Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilderBase &B) { // Check for a constant string. // puts("") -> putchar('\n') StringRef Str; - if (getConstantStringInfo(CI->getArgOperand(0), Str) && Str.empty()) - return copyFlags(*CI, emitPutChar(B.getInt32('\n'), B, TLI)); + if (getConstantStringInfo(CI->getArgOperand(0), Str) && Str.empty()) { + // putchar takes an argument of the same type as puts returns, i.e., + // int, which need not be 32 bits wide. + Type *IntTy = CI->getType(); + return copyFlags(*CI, emitPutChar(ConstantInt::get(IntTy, '\n'), B, TLI)); + } return nullptr; } @@ -3194,8 +3387,12 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeStrCpy(CI, Builder); case LibFunc_stpcpy: return optimizeStpCpy(CI, Builder); + case LibFunc_strlcpy: + return optimizeStrLCpy(CI, Builder); + case LibFunc_stpncpy: + return optimizeStringNCpy(CI, /*RetEnd=*/true, Builder); case LibFunc_strncpy: - return optimizeStrNCpy(CI, Builder); + return optimizeStringNCpy(CI, /*RetEnd=*/false, Builder); case LibFunc_strlen: return optimizeStrLen(CI, Builder); case LibFunc_strnlen: @@ -3551,12 +3748,9 @@ void LibCallSimplifier::eraseFromParent(Instruction *I) { // Fortified Library Call Optimizations //===----------------------------------------------------------------------===// -bool -FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, - unsigned ObjSizeOp, - Optional<unsigned> SizeOp, - Optional<unsigned> StrOp, - Optional<unsigned> FlagOp) { +bool FortifiedLibCallSimplifier::isFortifiedCallFoldable( + CallInst *CI, unsigned ObjSizeOp, std::optional<unsigned> SizeOp, + std::optional<unsigned> StrOp, std::optional<unsigned> FlagOp) { // If this function takes a flag argument, the implementation may use it to // perform extra checks. Don't fold into the non-checking variant. 
if (FlagOp) { @@ -3601,9 +3795,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, CallInst *NewCI = B.CreateMemCpy(CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), CI->getArgOperand(2)); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return CI->getArgOperand(0); } return nullptr; @@ -3615,9 +3807,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, CallInst *NewCI = B.CreateMemMove(CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), CI->getArgOperand(2)); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return CI->getArgOperand(0); } return nullptr; @@ -3629,9 +3819,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); CallInst *NewCI = B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), Align(1)); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - copyFlags(*CI, NewCI); + mergeAttributesAndFlags(NewCI, *CI); return CI->getArgOperand(0); } return nullptr; @@ -3643,10 +3831,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemPCpyChk(CallInst *CI, if (isFortifiedCallFoldable(CI, 3, 2)) if (Value *Call = emitMemPCpy(CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, DL, TLI)) { - CallInst *NewCI = cast<CallInst>(Call); - NewCI->setAttributes(CI->getAttributes()); - NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType())); - return copyFlags(*CI, NewCI); + return mergeAttributesAndFlags(cast<CallInst>(Call), *CI); } return nullptr; } @@ -3669,7 +3854,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, // st[rp]cpy_chk call which may fail at runtime if the size is too long. // TODO: It might be nice to get a maximum length out of the possible // string lengths for varying. - if (isFortifiedCallFoldable(CI, 2, None, 1)) { + if (isFortifiedCallFoldable(CI, 2, std::nullopt, 1)) { if (Func == LibFunc_strcpy_chk) return copyFlags(*CI, emitStrCpy(Dst, Src, B, TLI)); else @@ -3686,11 +3871,8 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, else return nullptr; - // FIXME: There is really no guarantee that sizeof(size_t) is equal to - // sizeof(int*) for every target. So the assumption used here to derive the - // SizeTBits based on the size of an integer pointer in address space zero - // isn't always valid. 
- Type *SizeTTy = DL.getIntPtrType(CI->getContext(), /*AddressSpace=*/0); + unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule()); + Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits); Value *LenV = ConstantInt::get(SizeTTy, Len); Value *Ret = emitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI); // If the function was an __stpcpy_chk, and we were able to fold it into @@ -3703,7 +3885,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeStrLenChk(CallInst *CI, IRBuilderBase &B) { - if (isFortifiedCallFoldable(CI, 1, None, 0)) + if (isFortifiedCallFoldable(CI, 1, std::nullopt, 0)) return copyFlags(*CI, emitStrLen(CI->getArgOperand(0), B, CI->getModule()->getDataLayout(), TLI)); return nullptr; @@ -3738,7 +3920,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemCCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeSNPrintfChk(CallInst *CI, IRBuilderBase &B) { - if (isFortifiedCallFoldable(CI, 3, 1, None, 2)) { + if (isFortifiedCallFoldable(CI, 3, 1, std::nullopt, 2)) { SmallVector<Value *, 8> VariadicArgs(drop_begin(CI->args(), 5)); return copyFlags(*CI, emitSNPrintf(CI->getArgOperand(0), CI->getArgOperand(1), @@ -3750,7 +3932,7 @@ Value *FortifiedLibCallSimplifier::optimizeSNPrintfChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeSPrintfChk(CallInst *CI, IRBuilderBase &B) { - if (isFortifiedCallFoldable(CI, 2, None, None, 1)) { + if (isFortifiedCallFoldable(CI, 2, std::nullopt, std::nullopt, 1)) { SmallVector<Value *, 8> VariadicArgs(drop_begin(CI->args(), 4)); return copyFlags(*CI, emitSPrintf(CI->getArgOperand(0), CI->getArgOperand(3), @@ -3801,7 +3983,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrLCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeVSNPrintfChk(CallInst *CI, IRBuilderBase &B) { - if (isFortifiedCallFoldable(CI, 3, 1, None, 2)) + if (isFortifiedCallFoldable(CI, 3, 1, std::nullopt, 2)) return copyFlags( *CI, emitVSNPrintf(CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(4), CI->getArgOperand(5), B, TLI)); @@ -3811,7 +3993,7 @@ Value *FortifiedLibCallSimplifier::optimizeVSNPrintfChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeVSPrintfChk(CallInst *CI, IRBuilderBase &B) { - if (isFortifiedCallFoldable(CI, 2, None, None, 1)) + if (isFortifiedCallFoldable(CI, 2, std::nullopt, std::nullopt, 1)) return copyFlags(*CI, emitVSPrintf(CI->getArgOperand(0), CI->getArgOperand(3), CI->getArgOperand(4), B, TLI)); diff --git a/llvm/lib/Transforms/Utils/SplitModule.cpp b/llvm/lib/Transforms/Utils/SplitModule.cpp index 7e12bbd2851c..9c39c26d8b7a 100644 --- a/llvm/lib/Transforms/Utils/SplitModule.cpp +++ b/llvm/lib/Transforms/Utils/SplitModule.cpp @@ -74,7 +74,7 @@ static void addNonConstUser(ClusterMapType &GVtoClusterMap, // Adds all GlobalValue users of V to the same cluster as GV. 
static void addAllGlobalValueUsers(ClusterMapType &GVtoClusterMap, const GlobalValue *GV, const Value *V) { - for (auto *U : V->users()) { + for (const auto *U : V->users()) { SmallVector<const User *, 4> Worklist; Worklist.push_back(U); while (!Worklist.empty()) { diff --git a/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp b/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp index 3631733713ab..2b706858cbed 100644 --- a/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp +++ b/llvm/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp @@ -60,7 +60,7 @@ bool unifyUnreachableBlocks(Function &F) { new UnreachableInst(F.getContext(), UnreachableBlock); for (BasicBlock *BB : UnreachableBlocks) { - BB->getInstList().pop_back(); // Remove the unreachable inst. + BB->back().eraseFromParent(); // Remove the unreachable inst. BranchInst::Create(UnreachableBlock, BB); } @@ -90,7 +90,7 @@ bool unifyReturnBlocks(Function &F) { // If the function doesn't return void... add a PHI node to the block... PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(), "UnifiedRetVal"); - NewRetBlock->getInstList().push_back(PN); + PN->insertInto(NewRetBlock, NewRetBlock->end()); ReturnInst::Create(F.getContext(), PN, NewRetBlock); } @@ -102,7 +102,7 @@ bool unifyReturnBlocks(Function &F) { if (PN) PN->addIncoming(BB->getTerminator()->getOperand(0), BB); - BB->getInstList().pop_back(); // Remove the return insn + BB->back().eraseFromParent(); // Remove the return insn BranchInst::Create(NewRetBlock, BB); } diff --git a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp index 9bbfe06b9abb..3be96ebc93a2 100644 --- a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp +++ b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/InitializePasses.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -30,6 +31,11 @@ using namespace llvm; +static cl::opt<unsigned> MaxBooleansInControlFlowHub( + "max-booleans-in-control-flow-hub", cl::init(32), cl::Hidden, + cl::desc("Set the maximum number of outgoing blocks for using a boolean " + "value to record the exiting block in CreateControlFlowHub.")); + namespace { struct UnifyLoopExitsLegacyPass : public FunctionPass { static char ID; @@ -88,7 +94,7 @@ static void restoreSSA(const DominatorTree &DT, const Loop *L, using InstVector = SmallVector<Instruction *, 8>; using IIMap = MapVector<Instruction *, InstVector>; IIMap ExternalUsers; - for (auto BB : L->blocks()) { + for (auto *BB : L->blocks()) { for (auto &I : *BB) { for (auto &U : I.uses()) { auto UserInst = cast<Instruction>(U.getUser()); @@ -114,10 +120,10 @@ static void restoreSSA(const DominatorTree &DT, const Loop *L, // didn't exist in the original CFG. 
auto Def = II.first; LLVM_DEBUG(dbgs() << "externally used: " << Def->getName() << "\n"); - auto NewPhi = PHINode::Create(Def->getType(), Incoming.size(), - Def->getName() + ".moved", - LoopExitBlock->getTerminator()); - for (auto In : Incoming) { + auto NewPhi = + PHINode::Create(Def->getType(), Incoming.size(), + Def->getName() + ".moved", &LoopExitBlock->front()); + for (auto *In : Incoming) { LLVM_DEBUG(dbgs() << "predecessor " << In->getName() << ": "); if (Def->getParent() == In || DT.dominates(Def, In)) { LLVM_DEBUG(dbgs() << "dominated\n"); @@ -129,7 +135,7 @@ static void restoreSSA(const DominatorTree &DT, const Loop *L, } LLVM_DEBUG(dbgs() << "external users:"); - for (auto U : II.second) { + for (auto *U : II.second) { LLVM_DEBUG(dbgs() << " " << U->getName()); U->replaceUsesOfWith(Def, NewPhi); } @@ -149,9 +155,9 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { // We need SetVectors, but the Loop API takes a vector, so we use a temporary. SmallVector<BasicBlock *, 8> Temp; L->getExitingBlocks(Temp); - for (auto BB : Temp) { + for (auto *BB : Temp) { ExitingBlocks.insert(BB); - for (auto S : successors(BB)) { + for (auto *S : successors(BB)) { auto SL = LI.getLoopFor(S); // A successor is not an exit if it is directly or indirectly in the // current loop. @@ -181,8 +187,9 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { SmallVector<BasicBlock *, 8> GuardBlocks; DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); - auto LoopExitBlock = CreateControlFlowHub(&DTU, GuardBlocks, ExitingBlocks, - Exits, "loop.exit"); + auto LoopExitBlock = + CreateControlFlowHub(&DTU, GuardBlocks, ExitingBlocks, Exits, "loop.exit", + MaxBooleansInControlFlowHub.getValue()); restoreSSA(DT, L, ExitingBlocks, LoopExitBlock); @@ -196,7 +203,7 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { // The guard blocks were created outside the loop, so they need to become // members of the parent loop. if (auto ParentLoop = L->getParentLoop()) { - for (auto G : GuardBlocks) { + for (auto *G : GuardBlocks) { ParentLoop->addBasicBlockToLoop(G, LI); } ParentLoop->verifyLoop(); @@ -213,7 +220,7 @@ static bool runImpl(LoopInfo &LI, DominatorTree &DT) { bool Changed = false; auto Loops = LI.getLoopsInPreorder(); - for (auto L : Loops) { + for (auto *L : Loops) { LLVM_DEBUG(dbgs() << "Loop: " << L->getHeader()->getName() << " (depth: " << LI.getLoopDepth(L->getHeader()) << ")\n"); Changed |= unifyLoopExits(DT, LI, L); diff --git a/llvm/lib/Transforms/Utils/VNCoercion.cpp b/llvm/lib/Transforms/Utils/VNCoercion.cpp index 42be67f3cfc0..f295a7e312b6 100644 --- a/llvm/lib/Transforms/Utils/VNCoercion.cpp +++ b/llvm/lib/Transforms/Utils/VNCoercion.cpp @@ -28,14 +28,14 @@ bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, isFirstClassAggregateOrScalableType(StoredTy)) return false; - uint64_t StoreSize = DL.getTypeSizeInBits(StoredTy).getFixedSize(); + uint64_t StoreSize = DL.getTypeSizeInBits(StoredTy).getFixedValue(); // The store size must be byte-aligned to support future type casts. if (llvm::alignTo(StoreSize, 8) != StoreSize) return false; // The store has to be at least as big as the load. 
- if (StoreSize < DL.getTypeSizeInBits(LoadTy).getFixedSize()) + if (StoreSize < DL.getTypeSizeInBits(LoadTy).getFixedValue()) return false; bool StoredNI = DL.isNonIntegralPointerType(StoredTy->getScalarType()); @@ -57,8 +57,11 @@ bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, // The implementation below uses inttoptr for vectors of unequal size; we // can't allow this for non integral pointers. We could teach it to extract - // exact subvectors if desired. - if (StoredNI && StoreSize != DL.getTypeSizeInBits(LoadTy).getFixedSize()) + // exact subvectors if desired. + if (StoredNI && StoreSize != DL.getTypeSizeInBits(LoadTy).getFixedValue()) + return false; + + if (StoredTy->isTargetExtTy() || LoadTy->isTargetExtTy()) return false; return true; @@ -81,8 +84,8 @@ Value *coerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, // If this is already the right type, just return it. Type *StoredValTy = StoredVal->getType(); - uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy).getFixedSize(); - uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy).getFixedSize(); + uint64_t StoredValSize = DL.getTypeSizeInBits(StoredValTy).getFixedValue(); + uint64_t LoadedValSize = DL.getTypeSizeInBits(LoadedTy).getFixedValue(); // If the store and reload are the same size, we can always reuse it. if (StoredValSize == LoadedValSize) { @@ -134,8 +137,8 @@ Value *coerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, // If this is a big-endian system, we need to shift the value down to the low // bits so that a truncate will work. if (DL.isBigEndian()) { - uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy).getFixedSize() - - DL.getTypeStoreSizeInBits(LoadedTy).getFixedSize(); + uint64_t ShiftAmt = DL.getTypeStoreSizeInBits(StoredValTy).getFixedValue() - + DL.getTypeStoreSizeInBits(LoadedTy).getFixedValue(); StoredVal = Helper.CreateLShr( StoredVal, ConstantInt::get(StoredVal->getType(), ShiftAmt)); } @@ -183,7 +186,7 @@ static int analyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr, if (StoreBase != LoadBase) return -1; - uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy).getFixedSize(); + uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy).getFixedValue(); if ((WriteSizeInBits & 7) | (LoadSize & 7)) return -1; @@ -218,7 +221,7 @@ int analyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr, Value *StorePtr = DepSI->getPointerOperand(); uint64_t StoreSize = - DL.getTypeSizeInBits(DepSI->getValueOperand()->getType()).getFixedSize(); + DL.getTypeSizeInBits(DepSI->getValueOperand()->getType()).getFixedValue(); return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, StorePtr, StoreSize, DL); } @@ -321,7 +324,7 @@ int analyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, LoadInst *DepLI, return -1; Value *DepPtr = DepLI->getPointerOperand(); - uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType()).getFixedSize(); + uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType()).getFixedValue(); int R = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL); if (R != -1) return R; @@ -331,7 +334,7 @@ int analyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, LoadInst *DepLI, int64_t LoadOffs = 0; const Value *LoadBase = GetPointerBaseWithConstantOffset(LoadPtr, LoadOffs, DL); - unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedSize(); + unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue(); unsigned Size = getLoadLoadClobberFullWidthSize(LoadBase, LoadOffs, LoadSize, DepLI); @@ -356,9 +359,9 @@ int 
analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, // If this is memset, we just need to see if the offset is valid in the size // of the memset.. - if (MI->getIntrinsicID() == Intrinsic::memset) { + if (const auto *memset_inst = dyn_cast<MemSetInst>(MI)) { if (DL.isNonIntegralPointerType(LoadTy->getScalarType())) { - auto *CI = dyn_cast<ConstantInt>(cast<MemSetInst>(MI)->getValue()); + auto *CI = dyn_cast<ConstantInt>(memset_inst->getValue()); if (!CI || !CI->isZero()) return -1; } @@ -408,8 +411,8 @@ static Value *getStoreValueForLoadHelper(Value *SrcVal, unsigned Offset, } uint64_t StoreSize = - (DL.getTypeSizeInBits(SrcVal->getType()).getFixedSize() + 7) / 8; - uint64_t LoadSize = (DL.getTypeSizeInBits(LoadTy).getFixedSize() + 7) / 8; + (DL.getTypeSizeInBits(SrcVal->getType()).getFixedValue() + 7) / 8; + uint64_t LoadSize = (DL.getTypeSizeInBits(LoadTy).getFixedValue() + 7) / 8; // Compute which bits of the stored value are being used by the load. Convert // to an integer type to start with. if (SrcVal->getType()->isPtrOrPtrVectorTy()) @@ -462,8 +465,8 @@ Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy, // If Offset+LoadTy exceeds the size of SrcVal, then we must be wanting to // widen SrcVal out to a larger load. unsigned SrcValStoreSize = - DL.getTypeStoreSize(SrcVal->getType()).getFixedSize(); - unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedSize(); + DL.getTypeStoreSize(SrcVal->getType()).getFixedValue(); + unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue(); if (Offset + LoadSize > SrcValStoreSize) { assert(SrcVal->isSimple() && "Cannot widen volatile/atomic load!"); assert(SrcVal->getType()->isIntegerTy() && "Can't widen non-integer load"); @@ -507,8 +510,8 @@ Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy, Constant *getConstantLoadValueForLoad(Constant *SrcVal, unsigned Offset, Type *LoadTy, const DataLayout &DL) { unsigned SrcValStoreSize = - DL.getTypeStoreSize(SrcVal->getType()).getFixedSize(); - unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedSize(); + DL.getTypeStoreSize(SrcVal->getType()).getFixedValue(); + unsigned LoadSize = DL.getTypeStoreSize(LoadTy).getFixedValue(); if (Offset + LoadSize > SrcValStoreSize) return nullptr; return getConstantStoreValueForLoad(SrcVal, Offset, LoadTy, DL); @@ -520,7 +523,7 @@ Value *getMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, Type *LoadTy, Instruction *InsertPt, const DataLayout &DL) { LLVMContext &Ctx = LoadTy->getContext(); - uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy).getFixedSize() / 8; + uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy).getFixedValue() / 8; IRBuilder<> Builder(InsertPt); // We know that this method is only called when the mem transfer fully @@ -566,7 +569,7 @@ Value *getMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, Constant *getConstantMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, Type *LoadTy, const DataLayout &DL) { LLVMContext &Ctx = LoadTy->getContext(); - uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy).getFixedSize() / 8; + uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy).getFixedValue() / 8; // We know that this method is only called when the mem transfer fully // provides the bits for the load. 
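The VNCoercion.cpp hunks above are mostly a mechanical migration from TypeSize::getFixedSize() to TypeSize::getFixedValue(), plus a new bail-out for target extension types in canCoerceMustAliasedValueToLoad. As a hedged sketch of the size reasoning involved (the helper name storeCoversLoad and its exact checks are invented here and are not part of the patch):

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Type.h"
using namespace llvm;

// Illustrative only: a store can service a load of a different type only when
// both sizes are fixed (not scalable) and the store is at least as wide.
static bool storeCoversLoad(Type *StoredTy, Type *LoadTy, const DataLayout &DL) {
  TypeSize StoreBits = DL.getTypeSizeInBits(StoredTy);
  TypeSize LoadBits = DL.getTypeSizeInBits(LoadTy);
  if (StoreBits.isScalable() || LoadBits.isScalable())
    return false; // getFixedValue() asserts on scalable sizes.
  return StoreBits.getFixedValue() >= LoadBits.getFixedValue();
}

The real canCoerceMustAliasedValueToLoad additionally rejects first-class aggregates, stores whose size is not a whole number of bytes, mismatched non-integral pointers, and (new in this patch) target extension types, as the diff above shows.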
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp index 8947303674ee..a5edbb2acc6d 100644 --- a/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -15,8 +15,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Argument.h" @@ -181,7 +179,7 @@ private: Value *mapBlockAddress(const BlockAddress &BA); /// Map metadata that doesn't require visiting operands. - Optional<Metadata *> mapSimpleMetadata(const Metadata *MD); + std::optional<Metadata *> mapSimpleMetadata(const Metadata *MD); Metadata *mapToMetadata(const Metadata *Key, Metadata *Val); Metadata *mapToSelf(const Metadata *MD); @@ -270,9 +268,10 @@ private: /// MDNode, compute and return the mapping. If it's a distinct \a MDNode, /// return the result of \a mapDistinctNode(). /// - /// \return None if \c Op is an unmapped uniqued \a MDNode. - /// \post getMappedOp(Op) only returns None if this returns None. - Optional<Metadata *> tryToMapOperand(const Metadata *Op); + /// \return std::nullopt if \c Op is an unmapped uniqued \a MDNode. + /// \post getMappedOp(Op) only returns std::nullopt if this returns + /// std::nullopt. + std::optional<Metadata *> tryToMapOperand(const Metadata *Op); /// Map a distinct node. /// @@ -284,7 +283,7 @@ private: MDNode *mapDistinctNode(const MDNode &N); /// Get a previously mapped node. - Optional<Metadata *> getMappedOp(const Metadata *Op) const; + std::optional<Metadata *> getMappedOp(const Metadata *Op) const; /// Create a post-order traversal of an unmapped uniqued node subgraph. /// @@ -317,11 +316,10 @@ private: /// This visits all the nodes in \c G in post-order, using the identity /// mapping or creating a new node depending on \a Data::HasChanged. /// - /// \pre \a getMappedOp() returns None for nodes in \c G, but not for any of - /// their operands outside of \c G. - /// \pre \a Data::HasChanged is true for a node in \c G iff any of its - /// operands have changed. - /// \post \a getMappedOp() returns the mapped node for every node in \c G. + /// \pre \a getMappedOp() returns std::nullopt for nodes in \c G, but not for + /// any of their operands outside of \c G. \pre \a Data::HasChanged is true + /// for a node in \c G iff any of its operands have changed. \post \a + /// getMappedOp() returns the mapped node for every node in \c G. void mapNodesInPOT(UniquedGraph &G); /// Remap a node's operands using the given functor. @@ -391,8 +389,9 @@ Value *Mapper::mapValue(const Value *V) { // ensures metadata operands only reference defined SSA values. return (Flags & RF_IgnoreMissingLocals) ? 
nullptr - : MetadataAsValue::get(V->getContext(), - MDTuple::get(V->getContext(), None)); + : MetadataAsValue::get( + V->getContext(), + MDTuple::get(V->getContext(), std::nullopt)); } if (auto *AL = dyn_cast<DIArgList>(MD)) { SmallVector<ValueAsMetadata *, 4> MappedArgs; @@ -558,11 +557,11 @@ Metadata *Mapper::mapToSelf(const Metadata *MD) { return mapToMetadata(MD, const_cast<Metadata *>(MD)); } -Optional<Metadata *> MDNodeMapper::tryToMapOperand(const Metadata *Op) { +std::optional<Metadata *> MDNodeMapper::tryToMapOperand(const Metadata *Op) { if (!Op) return nullptr; - if (Optional<Metadata *> MappedOp = M.mapSimpleMetadata(Op)) { + if (std::optional<Metadata *> MappedOp = M.mapSimpleMetadata(Op)) { #ifndef NDEBUG if (auto *CMD = dyn_cast<ConstantAsMetadata>(Op)) assert((!*MappedOp || M.getVM().count(CMD->getValue()) || @@ -578,7 +577,7 @@ Optional<Metadata *> MDNodeMapper::tryToMapOperand(const Metadata *Op) { const MDNode &N = *cast<MDNode>(Op); if (N.isDistinct()) return mapDistinctNode(N); - return None; + return std::nullopt; } MDNode *MDNodeMapper::mapDistinctNode(const MDNode &N) { @@ -606,11 +605,11 @@ static ConstantAsMetadata *wrapConstantAsMetadata(const ConstantAsMetadata &CMD, return MappedV ? ConstantAsMetadata::getConstant(MappedV) : nullptr; } -Optional<Metadata *> MDNodeMapper::getMappedOp(const Metadata *Op) const { +std::optional<Metadata *> MDNodeMapper::getMappedOp(const Metadata *Op) const { if (!Op) return nullptr; - if (Optional<Metadata *> MappedOp = M.getVM().getMappedMD(Op)) + if (std::optional<Metadata *> MappedOp = M.getVM().getMappedMD(Op)) return *MappedOp; if (isa<MDString>(Op)) @@ -619,7 +618,7 @@ Optional<Metadata *> MDNodeMapper::getMappedOp(const Metadata *Op) const { if (auto *CMD = dyn_cast<ConstantAsMetadata>(Op)) return wrapConstantAsMetadata(*CMD, M.getVM().lookup(CMD->getValue())); - return None; + return std::nullopt; } Metadata &MDNodeMapper::UniquedGraph::getFwdReference(MDNode &Op) { @@ -704,7 +703,7 @@ MDNode *MDNodeMapper::visitOperands(UniquedGraph &G, MDNode::op_iterator &I, MDNode::op_iterator E, bool &HasChanged) { while (I != E) { Metadata *Op = *I++; // Increment even on early return. - if (Optional<Metadata *> MappedOp = tryToMapOperand(Op)) { + if (std::optional<Metadata *> MappedOp = tryToMapOperand(Op)) { // Check if the operand changes. HasChanged |= Op != *MappedOp; continue; @@ -757,7 +756,7 @@ void MDNodeMapper::mapNodesInPOT(UniquedGraph &G) { // Clone the uniqued node and remap the operands. TempMDNode ClonedN = D.Placeholder ? std::move(D.Placeholder) : N->clone(); remapOperands(*ClonedN, [this, &D, &G](Metadata *Old) { - if (Optional<Metadata *> MappedOp = getMappedOp(Old)) + if (std::optional<Metadata *> MappedOp = getMappedOp(Old)) return *MappedOp; (void)D; assert(G.Info[Old].ID > D.ID && "Expected a forward reference"); @@ -796,7 +795,7 @@ Metadata *MDNodeMapper::map(const MDNode &N) { N.isUniqued() ? 
mapTopLevelUniquedNode(N) : mapDistinctNode(N); while (!DistinctWorklist.empty()) remapOperands(*DistinctWorklist.pop_back_val(), [this](Metadata *Old) { - if (Optional<Metadata *> MappedOp = tryToMapOperand(Old)) + if (std::optional<Metadata *> MappedOp = tryToMapOperand(Old)) return *MappedOp; return mapTopLevelUniquedNode(*cast<MDNode>(Old)); }); @@ -825,9 +824,9 @@ Metadata *MDNodeMapper::mapTopLevelUniquedNode(const MDNode &FirstN) { return *getMappedOp(&FirstN); } -Optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) { +std::optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) { // If the value already exists in the map, use it. - if (Optional<Metadata *> NewMD = getVM().getMappedMD(MD)) + if (std::optional<Metadata *> NewMD = getVM().getMappedMD(MD)) return *NewMD; if (isa<MDString>(MD)) @@ -848,14 +847,14 @@ Optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) { assert(isa<MDNode>(MD) && "Expected a metadata node"); - return None; + return std::nullopt; } Metadata *Mapper::mapMetadata(const Metadata *MD) { assert(MD && "Expected valid metadata"); assert(!isa<LocalAsMetadata>(MD) && "Unexpected local metadata"); - if (Optional<Metadata *> NewMD = mapSimpleMetadata(MD)) + if (std::optional<Metadata *> NewMD = mapSimpleMetadata(MD)) return *NewMD; return MDNodeMapper(*this).map(*cast<MDNode>(MD)); @@ -881,7 +880,7 @@ void Mapper::flush() { AppendingInits.resize(PrefixSize); mapAppendingVariable(*E.Data.AppendingGV.GV, E.Data.AppendingGV.InitPrefix, - E.AppendingGVIsOldCtorDtor, makeArrayRef(NewInits)); + E.AppendingGVIsOldCtorDtor, ArrayRef(NewInits)); break; } case WorklistEntry::MapAliasOrIFunc: { diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index f59fc3a6dd60..0b7fc853dc1b 100644 --- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -186,8 +186,11 @@ private: SmallPtrSet<Instruction *, 16> *InstructionsProcessed); /// Check if this load/store access is misaligned accesses. + /// Returns a \p RelativeSpeed of an operation if allowed suitable to + /// compare to another result for the same \p AddressSpace and potentially + /// different \p Alignment and \p SzInBytes. bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, - Align Alignment); + Align Alignment, unsigned &RelativeSpeed); }; class LoadStoreVectorizerLegacyPass : public FunctionPass { @@ -1078,8 +1081,14 @@ bool Vectorizer::vectorizeStoreChain( InstructionsProcessed->insert(Chain.begin(), Chain.end()); // If the store is going to be misaligned, don't vectorize it. - if (accessIsMisaligned(SzInBytes, AS, Alignment)) { + unsigned RelativeSpeed; + if (accessIsMisaligned(SzInBytes, AS, Alignment, RelativeSpeed)) { if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) { + unsigned SpeedBefore; + accessIsMisaligned(EltSzInBytes, AS, Alignment, SpeedBefore); + if (SpeedBefore > RelativeSpeed) + return false; + auto Chains = splitOddVectorElts(Chain, Sz); bool Vectorized = false; Vectorized |= vectorizeStoreChain(Chains.first, InstructionsProcessed); @@ -1231,8 +1240,14 @@ bool Vectorizer::vectorizeLoadChain( InstructionsProcessed->insert(Chain.begin(), Chain.end()); // If the load is going to be misaligned, don't vectorize it. 
- if (accessIsMisaligned(SzInBytes, AS, Alignment)) { + unsigned RelativeSpeed; + if (accessIsMisaligned(SzInBytes, AS, Alignment, RelativeSpeed)) { if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) { + unsigned SpeedBefore; + accessIsMisaligned(EltSzInBytes, AS, Alignment, SpeedBefore); + if (SpeedBefore > RelativeSpeed) + return false; + auto Chains = splitOddVectorElts(Chain, Sz); bool Vectorized = false; Vectorized |= vectorizeLoadChain(Chains.first, InstructionsProcessed); @@ -1316,15 +1331,15 @@ bool Vectorizer::vectorizeLoadChain( } bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, - Align Alignment) { + Align Alignment, unsigned &RelativeSpeed) { + RelativeSpeed = 0; if (Alignment.value() % SzInBytes == 0) return false; - bool Fast = false; bool Allows = TTI.allowsMisalignedMemoryAccesses(F.getParent()->getContext(), SzInBytes * 8, AddressSpace, - Alignment, &Fast); + Alignment, &RelativeSpeed); LLVM_DEBUG(dbgs() << "LSV: Target said misaligned is allowed? " << Allows - << " and fast? " << Fast << "\n";); - return !Allows || !Fast; + << " with relative speed = " << RelativeSpeed << '\n';); + return !Allows || !RelativeSpeed; } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 183ba86abcb4..cd48c0d57eb3 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -295,7 +295,7 @@ void LoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) { Hint *Hints[] = {&Width, &Interleave, &Force, &IsVectorized, &Predicate, &Scalable}; - for (auto H : Hints) { + for (auto *H : Hints) { if (Name == H->Name) { if (H->validate(Val)) H->Value = Val; @@ -456,16 +456,27 @@ int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy, PGSOQueryType::IRPass); bool CanAddPredicate = !OptForSize; int Stride = getPtrStride(PSE, AccessTy, Ptr, TheLoop, Strides, - CanAddPredicate, false); + CanAddPredicate, false).value_or(0); if (Stride == 1 || Stride == -1) return Stride; return 0; } -bool LoopVectorizationLegality::isUniform(Value *V) { +bool LoopVectorizationLegality::isUniform(Value *V) const { return LAI->isUniform(V); } +bool LoopVectorizationLegality::isUniformMemOp(Instruction &I) const { + Value *Ptr = getLoadStorePointerOperand(&I); + if (!Ptr) + return false; + // Note: There's nothing inherent which prevents predicated loads and + // stores from being uniform. The current lowering simply doesn't handle + // it; in particular, the cost model distinguishes scatter/gather from + // scalar w/predication, and we currently rely on the scalar path. + return isUniform(Ptr) && !blockNeedsPredication(I.getParent()); +} + bool LoopVectorizationLegality::canVectorizeOuterLoop() { assert(!TheLoop->isInnermost() && "We are not vectorizing an outer loop."); // Store the result and return it at the end instead of exiting early, in case @@ -666,7 +677,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Non-header phi nodes that have outside uses can be vectorized. Add // them to the list of allowed exits. // Unsafe cyclic dependencies with header phis are identified during - // legalization for reduction, induction and first order + // legalization for reduction, induction and fixed order // recurrences. 
AllowedExit.insert(&I); continue; @@ -689,20 +700,20 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { continue; } - // TODO: Instead of recording the AllowedExit, it would be good to record the - // complementary set: NotAllowedExit. These include (but may not be - // limited to): + // TODO: Instead of recording the AllowedExit, it would be good to + // record the complementary set: NotAllowedExit. These include (but may + // not be limited to): // 1. Reduction phis as they represent the one-before-last value, which - // is not available when vectorized + // is not available when vectorized // 2. Induction phis and increment when SCEV predicates cannot be used // outside the loop - see addInductionPhi // 3. Non-Phis with outside uses when SCEV predicates cannot be used // outside the loop - see call to hasOutsideLoopUser in the non-phi // handling below - // 4. FirstOrderRecurrence phis that can possibly be handled by + // 4. FixedOrderRecurrence phis that can possibly be handled by // extraction. // By recording these, we can then reason about ways to vectorize each - // of these NotAllowedExit. + // of these NotAllowedExit. InductionDescriptor ID; if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) { addInductionPhi(Phi, ID, AllowedExit); @@ -710,10 +721,10 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { continue; } - if (RecurrenceDescriptor::isFirstOrderRecurrence(Phi, TheLoop, + if (RecurrenceDescriptor::isFixedOrderRecurrence(Phi, TheLoop, SinkAfter, DT)) { AllowedExit.insert(Phi); - FirstOrderRecurrences.insert(Phi); + FixedOrderRecurrences.insert(Phi); continue; } @@ -883,12 +894,12 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } } - // For first order recurrences, we use the previous value (incoming value from + // For fixed order recurrences, we use the previous value (incoming value from // the latch) to check if it dominates all users of the recurrence. Bail out // if we have to sink such an instruction for another recurrence, as the // dominance requirement may not hold after sinking. BasicBlock *LoopLatch = TheLoop->getLoopLatch(); - if (any_of(FirstOrderRecurrences, [LoopLatch, this](const PHINode *Phi) { + if (any_of(FixedOrderRecurrences, [LoopLatch, this](const PHINode *Phi) { Instruction *V = cast<Instruction>(Phi->getIncomingValueForBlock(LoopLatch)); return SinkAfter.find(V) != SinkAfter.end(); @@ -905,7 +916,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } bool LoopVectorizationLegality::canVectorizeMemory() { - LAI = &(*GetLAA)(*TheLoop); + LAI = &LAIs.getInfo(*TheLoop); const OptimizationRemarkAnalysis *LAR = LAI->getReport(); if (LAR) { ORE->emit([&]() { @@ -922,10 +933,13 @@ bool LoopVectorizationLegality::canVectorizeMemory() { // vectorize loop is made, runtime checks are added so as to make sure that // invariant address won't alias with any other objects. if (!LAI->getStoresToInvariantAddresses().empty()) { - // For each invariant address, check its last stored value is unconditional. + // For each invariant address, check if last stored value is unconditional + // and the address is not calculated inside the loop. 
for (StoreInst *SI : LAI->getStoresToInvariantAddresses()) { - if (isInvariantStoreOfReduction(SI) && - blockNeedsPredication(SI->getParent())) { + if (!isInvariantStoreOfReduction(SI)) + continue; + + if (blockNeedsPredication(SI->getParent())) { reportVectorizationFailure( "We don't allow storing to uniform addresses", "write of conditional recurring variant value to a loop " @@ -933,6 +947,20 @@ bool LoopVectorizationLegality::canVectorizeMemory() { "CantVectorizeStoreToLoopInvariantAddress", ORE, TheLoop); return false; } + + // Invariant address should be defined outside of loop. LICM pass usually + // makes sure it happens, but in rare cases it does not, we do not want + // to overcomplicate vectorization to support this case. + if (Instruction *Ptr = dyn_cast<Instruction>(SI->getPointerOperand())) { + if (TheLoop->contains(Ptr)) { + reportVectorizationFailure( + "Invariant address is calculated inside the loop", + "write to a loop invariant address could not " + "be vectorized", + "CantVectorizeStoreToLoopInvariantAddress", ORE, TheLoop); + return false; + } + } } if (LAI->hasDependenceInvolvingLoopInvariantAddress()) { @@ -1069,9 +1097,9 @@ bool LoopVectorizationLegality::isInductionVariable(const Value *V) const { return isInductionPhi(V) || isCastedInductionVariable(V); } -bool LoopVectorizationLegality::isFirstOrderRecurrence( +bool LoopVectorizationLegality::isFixedOrderRecurrence( const PHINode *Phi) const { - return FirstOrderRecurrences.count(Phi); + return FixedOrderRecurrences.count(Phi); } bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) const { @@ -1096,30 +1124,24 @@ bool LoopVectorizationLegality::blockCanBePredicated( if (isa<NoAliasScopeDeclInst>(&I)) continue; - // We might be able to hoist the load. - if (I.mayReadFromMemory()) { - auto *LI = dyn_cast<LoadInst>(&I); - if (!LI) - return false; - if (!SafePtrs.count(LI->getPointerOperand())) { + // Loads are handled via masking (or speculated if safe to do so.) + if (auto *LI = dyn_cast<LoadInst>(&I)) { + if (!SafePtrs.count(LI->getPointerOperand())) MaskedOp.insert(LI); - continue; - } + continue; } - if (I.mayWriteToMemory()) { - auto *SI = dyn_cast<StoreInst>(&I); - if (!SI) - return false; - // Predicated store requires some form of masking: - // 1) masked store HW instruction, - // 2) emulation via load-blend-store (only if safe and legal to do so, - // be aware on the race conditions), or - // 3) element-by-element predicate check and scalar store. + // Predicated store requires some form of masking: + // 1) masked store HW instruction, + // 2) emulation via load-blend-store (only if safe and legal to do so, + // be aware on the race conditions), or + // 3) element-by-element predicate check and scalar store. 
+ if (auto *SI = dyn_cast<StoreInst>(&I)) { MaskedOp.insert(SI); continue; } - if (I.mayThrow()) + + if (I.mayReadFromMemory() || I.mayWriteToMemory() || I.mayThrow()) return false; } @@ -1162,7 +1184,7 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { for (Instruction &I : *BB) { LoadInst *LI = dyn_cast<LoadInst>(&I); if (LI && !LI->getType()->isVectorTy() && !mustSuppressSpeculation(*LI) && - isDereferenceableAndAlignedInLoop(LI, TheLoop, SE, *DT)) + isDereferenceableAndAlignedInLoop(LI, TheLoop, SE, *DT, AC)) SafePointers.insert(LI->getPointerOperand()); } } @@ -1364,7 +1386,7 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() { SmallPtrSet<const Value *, 8> ReductionLiveOuts; - for (auto &Reduction : getReductionVars()) + for (const auto &Reduction : getReductionVars()) ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr()); // TODO: handle non-reduction outside users when tail is folded by masking. diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 2e9a9fe0640e..8990a65afdb4 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -188,6 +188,7 @@ public: struct VectorizationFactor { /// Vector width with best cost. ElementCount Width; + /// Cost of the loop with that width. InstructionCost Cost; @@ -294,9 +295,9 @@ public: : OrigLoop(L), LI(LI), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM), IAI(IAI), PSE(PSE), Hints(Hints), ORE(ORE) {} - /// Plan how to best vectorize, return the best VF and its cost, or None if - /// vectorization and interleaving should be avoided up front. - Optional<VectorizationFactor> plan(ElementCount UserVF, unsigned UserIC); + /// Plan how to best vectorize, return the best VF and its cost, or + /// std::nullopt if vectorization and interleaving should be avoided up front. + std::optional<VectorizationFactor> plan(ElementCount UserVF, unsigned UserIC); /// Use the VPlan-native path to plan how to best vectorize, return the best /// VF and its cost. diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 238b074089aa..a28099d8ba7d 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -65,8 +65,6 @@ #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" @@ -142,6 +140,7 @@ #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" #include <algorithm> #include <cassert> +#include <cmath> #include <cstdint> #include <functional> #include <iterator> @@ -362,10 +361,15 @@ cl::opt<bool> llvm::EnableLoopVectorization( "vectorize-loops", cl::init(true), cl::Hidden, cl::desc("Run the Loop vectorization passes")); -cl::opt<bool> PrintVPlansInDotFormat( - "vplan-print-in-dot-format", cl::init(false), cl::Hidden, +static cl::opt<bool> PrintVPlansInDotFormat( + "vplan-print-in-dot-format", cl::Hidden, cl::desc("Use dot format instead of plain text when dumping VPlans")); +static cl::opt<cl::boolOrDefault> ForceSafeDivisor( + "force-widen-divrem-via-safe-divisor", cl::Hidden, + cl::desc( + "Override cost based safe divisor widening for div/rem instructions")); + /// A helper function that returns true if the given type is irregular. 
The /// type is irregular if its allocated size doesn't equal the store size of an /// element of the corresponding vector type. @@ -396,8 +400,9 @@ static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) { /// 1) Returns exact trip count if it is known. /// 2) Returns expected trip count according to profile data if any. /// 3) Returns upper bound estimate if it is known. -/// 4) Returns None if all of the above failed. -static Optional<unsigned> getSmallBestKnownTC(ScalarEvolution &SE, Loop *L) { +/// 4) Returns std::nullopt if all of the above failed. +static std::optional<unsigned> getSmallBestKnownTC(ScalarEvolution &SE, + Loop *L) { // Check if exact trip count is known. if (unsigned ExpectedTC = SE.getSmallConstantTripCount(L)) return ExpectedTC; @@ -405,17 +410,19 @@ static Optional<unsigned> getSmallBestKnownTC(ScalarEvolution &SE, Loop *L) { // Check if there is an expected trip count available from profile data. if (LoopVectorizeWithBlockFrequency) if (auto EstimatedTC = getLoopEstimatedTripCount(L)) - return EstimatedTC; + return *EstimatedTC; // Check if upper bound estimate is known. if (unsigned ExpectedTC = SE.getSmallConstantMaxTripCount(L)) return ExpectedTC; - return None; + return std::nullopt; } +namespace { // Forward declare GeneratedRTChecks. class GeneratedRTChecks; +} // namespace namespace llvm { @@ -473,10 +480,6 @@ public: /// complex control flow around the loops. virtual std::pair<BasicBlock *, Value *> createVectorizedLoopSkeleton(); - /// Widen a single call instruction within the innermost loop. - void widenCallInstruction(CallInst &CI, VPValue *Def, VPUser &ArgOperands, - VPTransformState &State); - /// Fix the vectorized code, taking care of header phi's, live-outs, and more. void fixVectorizedLoop(VPTransformState &State, VPlan &Plan); @@ -493,7 +496,8 @@ public: /// and \p MaxLane, times each part between \p MinPart and \p MaxPart, /// inclusive. Uses the VPValue operands from \p RepRecipe instead of \p /// Instr's operands. - void scalarizeInstruction(Instruction *Instr, VPReplicateRecipe *RepRecipe, + void scalarizeInstruction(const Instruction *Instr, + VPReplicateRecipe *RepRecipe, const VPIteration &Instance, bool IfPredicateInstr, VPTransformState &State); @@ -529,6 +533,17 @@ public: // generated by fixReduction. PHINode *getReductionResumeValue(const RecurrenceDescriptor &RdxDesc); + /// Create a new phi node for the induction variable \p OrigPhi to resume + /// iteration count in the scalar epilogue, from where the vectorized loop + /// left off. In cases where the loop skeleton is more complicated (eg. + /// epilogue vectorization) and the resume values can come from an additional + /// bypass block, the \p AdditionalBypass pair provides information about the + /// bypass block and the end value on the edge from bypass to this loop. + PHINode *createInductionResumeValue( + PHINode *OrigPhi, const InductionDescriptor &ID, + ArrayRef<BasicBlock *> BypassBlocks, + std::pair<BasicBlock *, Value *> AdditionalBypass = {nullptr, nullptr}); + protected: friend class LoopVectorizationPlanner; @@ -552,7 +567,7 @@ protected: /// Create the exit value of first order recurrences in the middle block and /// update their users. - void fixFirstOrderRecurrence(VPFirstOrderRecurrencePHIRecipe *PhiR, + void fixFixedOrderRecurrence(VPFirstOrderRecurrencePHIRecipe *PhiR, VPTransformState &State); /// Create code for the loop exit value of the reduction. 
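The LoopVectorizationLegality.cpp and LoopVectorize.cpp hunks rename the "first order recurrence" handling to "fixed order" (FixedOrderRecurrences, isFixedOrderRecurrence, fixFixedOrderRecurrence). As a hedged, source-level illustration of the distance-one case of such a recurrence (the function and variable names below are invented for this sketch and do not appear in the patch):

// A value defined in one iteration and used in the next: the PHI feeding
// 'Prev' is a fixed-order recurrence of distance one.
int sumOfPreviousElements(const int *A, int N) {
  int Prev = 0;
  int Out = 0;
  for (int I = 0; I < N; ++I) {
    Out += Prev;  // use of last iteration's value
    Prev = A[I];  // recurrence update
  }
  return Out;
}

When such a loop is vectorized, the recurrence PHI is modelled by VPFirstOrderRecurrencePHIRecipe, and its live-out value is created in the middle block by the renamed fixFixedOrderRecurrence, per the declaration comment above.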
@@ -611,7 +626,7 @@ protected: /// Complete the loop skeleton by adding debug MDs, creating appropriate /// conditional branches in the middle block, preparing the builder and /// running the verifier. Return the preheader of the completed vector loop. - BasicBlock *completeLoopSkeleton(MDNode *OrigLoopID); + BasicBlock *completeLoopSkeleton(); /// Collect poison-generating recipes that may generate a poison value that is /// used after vectorization, even when their operands are not poison. Those @@ -643,9 +658,6 @@ protected: /// Dominator Tree. DominatorTree *DT; - /// Alias Analysis. - AAResults *AA; - /// Target Library Info. const TargetLibraryInfo *TLI; @@ -951,6 +963,27 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) { return VF.isScalable() ? B.CreateVScale(EC) : EC; } +const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE) { + const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount(); + assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count"); + + ScalarEvolution &SE = *PSE.getSE(); + + // The exit count might have the type of i64 while the phi is i32. This can + // happen if we have an induction variable that is sign extended before the + // compare. The only way that we get a backedge taken count is that the + // induction variable was signed and as such will not overflow. In such a case + // truncation is legal. + if (SE.getTypeSizeInBits(BackedgeTakenCount->getType()) > + IdxTy->getPrimitiveSizeInBits()) + BackedgeTakenCount = SE.getTruncateOrNoop(BackedgeTakenCount, IdxTy); + BackedgeTakenCount = SE.getNoopOrZeroExtend(BackedgeTakenCount, IdxTy); + + // Get the total trip count from the count by adding 1. + return SE.getAddExpr(BackedgeTakenCount, + SE.getOne(BackedgeTakenCount->getType())); +} + static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy, ElementCount VF) { assert(FTy->isFloatingPointTy() && "Expected floating point type!"); @@ -1037,27 +1070,25 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes( // Add new definitions to the worklist. for (VPValue *operand : CurRec->operands()) - if (VPDef *OpDef = operand->getDef()) - Worklist.push_back(cast<VPRecipeBase>(OpDef)); + if (VPRecipeBase *OpDef = operand->getDefiningRecipe()) + Worklist.push_back(OpDef); } }); // Traverse all the recipes in the VPlan and collect the poison-generating // recipes in the backward slice starting at the address of a VPWidenRecipe or // VPInterleaveRecipe. - auto Iter = depth_first( - VPBlockRecursiveTraversalWrapper<VPBlockBase *>(State.Plan->getEntry())); + auto Iter = vp_depth_first_deep(State.Plan->getEntry()); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { for (VPRecipeBase &Recipe : *VPBB) { if (auto *WidenRec = dyn_cast<VPWidenMemoryInstructionRecipe>(&Recipe)) { Instruction &UnderlyingInstr = WidenRec->getIngredient(); - VPDef *AddrDef = WidenRec->getAddr()->getDef(); + VPRecipeBase *AddrDef = WidenRec->getAddr()->getDefiningRecipe(); if (AddrDef && WidenRec->isConsecutive() && Legal->blockNeedsPredication(UnderlyingInstr.getParent())) - collectPoisonGeneratingInstrsInBackwardSlice( - cast<VPRecipeBase>(AddrDef)); + collectPoisonGeneratingInstrsInBackwardSlice(AddrDef); } else if (auto *InterleaveRec = dyn_cast<VPInterleaveRecipe>(&Recipe)) { - VPDef *AddrDef = InterleaveRec->getAddr()->getDef(); + VPRecipeBase *AddrDef = InterleaveRec->getAddr()->getDefiningRecipe(); if (AddrDef) { // Check if any member of the interleave group needs predication. 
const InterleaveGroup<Instruction> *InterGroup = @@ -1072,8 +1103,7 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes( } if (NeedPredication) - collectPoisonGeneratingInstrsInBackwardSlice( - cast<VPRecipeBase>(AddrDef)); + collectPoisonGeneratingInstrsInBackwardSlice(AddrDef); } } } @@ -1182,7 +1212,7 @@ public: /// If interleave count has been specified by metadata it will be returned. /// Otherwise, the interleave count is computed and returned. VF and LoopCost /// are the selected vectorization factor and the cost of the selected VF. - unsigned selectInterleaveCount(ElementCount VF, unsigned LoopCost); + unsigned selectInterleaveCount(ElementCount VF, InstructionCost LoopCost); /// Memory access instruction may be vectorized in more than one way. /// Form of instruction after vectorization depends on cost. @@ -1435,47 +1465,49 @@ public: })); } - /// Returns true if \p I is an instruction that will be scalarized with - /// predication when vectorizing \p I with vectorization factor \p VF. Such - /// instructions include conditional stores and instructions that may divide - /// by zero. - bool isScalarWithPredication(Instruction *I, ElementCount VF) const; - - // Returns true if \p I is an instruction that will be predicated either - // through scalar predication or masked load/store or masked gather/scatter. - // \p VF is the vectorization factor that will be used to vectorize \p I. - // Superset of instructions that return true for isScalarWithPredication. - bool isPredicatedInst(Instruction *I, ElementCount VF) { - // When we know the load's address is loop invariant and the instruction - // in the original scalar loop was unconditionally executed then we - // don't need to mark it as a predicated instruction. Tail folding may - // introduce additional predication, but we're guaranteed to always have - // at least one active lane. We call Legal->blockNeedsPredication here - // because it doesn't query tail-folding. - if (Legal->isUniformMemOp(*I) && isa<LoadInst>(I) && - !Legal->blockNeedsPredication(I->getParent())) + /// Given costs for both strategies, return true if the scalar predication + /// lowering should be used for div/rem. This incorporates an override + /// option so it is not simply a cost comparison. + bool isDivRemScalarWithPredication(InstructionCost ScalarCost, + InstructionCost SafeDivisorCost) const { + switch (ForceSafeDivisor) { + case cl::BOU_UNSET: + return ScalarCost < SafeDivisorCost; + case cl::BOU_TRUE: return false; - if (!blockNeedsPredicationForAnyReason(I->getParent())) - return false; - // Loads and stores that need some form of masked operation are predicated - // instructions. - if (isa<LoadInst>(I) || isa<StoreInst>(I)) - return Legal->isMaskRequired(I); - return isScalarWithPredication(I, VF); + case cl::BOU_FALSE: + return true; + }; + llvm_unreachable("impossible case value"); } + /// Returns true if \p I is an instruction which requires predication and + /// for which our chosen predication strategy is scalarization (i.e. we + /// don't have an alternate strategy such as masking available). + /// \p VF is the vectorization factor that will be used to vectorize \p I. + bool isScalarWithPredication(Instruction *I, ElementCount VF) const; + + /// Returns true if \p I is an instruction that needs to be predicated + /// at runtime. The result is independent of the predication mechanism. + /// Superset of instructions that return true for isScalarWithPredication. 
+ bool isPredicatedInst(Instruction *I) const; + + /// Return the costs for our two available strategies for lowering a + /// div/rem operation which requires speculating at least one lane. + /// First result is for scalarization (will be invalid for scalable + /// vectors); second is for the safe-divisor strategy. + std::pair<InstructionCost, InstructionCost> + getDivRemSpeculationCost(Instruction *I, + ElementCount VF) const; + /// Returns true if \p I is a memory instruction with consecutive memory /// access that can be widened. - bool - memoryInstructionCanBeWidened(Instruction *I, - ElementCount VF = ElementCount::getFixed(1)); + bool memoryInstructionCanBeWidened(Instruction *I, ElementCount VF); /// Returns true if \p I is a memory instruction in an interleaved-group /// of memory accesses that can be vectorized with wide vector loads/stores /// and shuffles. - bool - interleavedAccessCanBeWidened(Instruction *I, - ElementCount VF = ElementCount::getFixed(1)); + bool interleavedAccessCanBeWidened(Instruction *I, ElementCount VF); /// Check if \p Instr belongs to any interleaved access group. bool isAccessInterleaved(Instruction *Instr) { @@ -1567,7 +1599,7 @@ public: /// Convenience function that returns the value of vscale_range iff /// vscale_range.min == vscale_range.max or otherwise returns the value /// returned by the corresponding TLI method. - Optional<unsigned> getVScaleForTuning() const; + std::optional<unsigned> getVScaleForTuning() const; private: unsigned NumPredStores = 0; @@ -1623,7 +1655,7 @@ private: /// Return the cost of instructions in an inloop reduction pattern, if I is /// part of that pattern. - Optional<InstructionCost> + std::optional<InstructionCost> getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy, TTI::TargetCostKind CostKind); @@ -1651,8 +1683,8 @@ private: /// Estimate the overhead of scalarizing an instruction. This is a /// convenience wrapper for the type-based getScalarizationOverhead API. - InstructionCost getScalarizationOverhead(Instruction *I, - ElementCount VF) const; + InstructionCost getScalarizationOverhead(Instruction *I, ElementCount VF, + TTI::TargetCostKind CostKind) const; /// Returns true if an artificially high cost for emulated masked memrefs /// should be used. @@ -1719,8 +1751,9 @@ private: /// scalarize and their scalar costs are collected in \p ScalarCosts. A /// non-negative return value implies the expression will be scalarized. /// Currently, only single-use chains are considered for scalarization. - int computePredInstDiscount(Instruction *PredInst, ScalarCostsTy &ScalarCosts, - ElementCount VF); + InstructionCost computePredInstDiscount(Instruction *PredInst, + ScalarCostsTy &ScalarCosts, + ElementCount VF); /// Collect the instructions that are uniform after vectorization. An /// instruction is uniform if we represent it with a single scalar value in @@ -1835,6 +1868,7 @@ public: }; } // end namespace llvm +namespace { /// Helper struct to manage generating runtime checks for vectorization. 
/// /// The runtime checks are created up-front in temporary blocks to allow better @@ -1914,7 +1948,7 @@ public: if (DiffChecks) { Value *RuntimeVF = nullptr; MemRuntimeCheckCond = addDiffRuntimeChecks( - MemCheckBlock->getTerminator(), L, *DiffChecks, MemCheckExp, + MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp, [VF, &RuntimeVF](IRBuilderBase &B, unsigned Bits) { if (!RuntimeVF) RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF); @@ -2099,6 +2133,7 @@ public: return MemCheckBlock; } }; +} // namespace // Return true if \p OuterLp is an outer loop annotated with hints for explicit // vectorization. The loop needs to be annotated with #pragma omp simd @@ -2194,18 +2229,15 @@ struct LoopVectorize : public FunctionPass { auto *BFI = &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); auto *TLI = TLIP ? &TLIP->getTLI(F) : nullptr; - auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); + auto &LAIs = getAnalysis<LoopAccessLegacyAnalysis>().getLAIs(); auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - std::function<const LoopAccessInfo &(Loop &)> GetLAA = - [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; - - return Impl.runImpl(F, *SE, *LI, *TTI, *DT, *BFI, TLI, *DB, *AA, *AC, - GetLAA, *ORE, PSI).MadeAnyChange; + return Impl + .runImpl(F, *SE, *LI, *TTI, *DT, *BFI, TLI, *DB, *AC, LAIs, *ORE, PSI) + .MadeAnyChange; } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -2215,7 +2247,6 @@ struct LoopVectorize : public FunctionPass { AU.addRequired<LoopInfoWrapperPass>(); AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<LoopAccessLegacyAnalysis>(); AU.addRequired<DemandedBitsWrapperPass>(); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); @@ -2321,12 +2352,16 @@ static void buildScalarSteps(Value *ScalarIV, Value *Step, const InductionDescriptor &ID, VPValue *Def, VPTransformState &State) { IRBuilderBase &Builder = State.Builder; - // We shouldn't have to build scalar steps if we aren't vectorizing. - assert(State.VF.isVector() && "VF should be greater than one"); - // Get the value type and ensure it and the step have the same integer type. + + // Ensure step has the same type as that of scalar IV. Type *ScalarIVTy = ScalarIV->getType()->getScalarType(); - assert(ScalarIVTy == Step->getType() && - "Val and Step should have the same type"); + if (ScalarIVTy != Step->getType()) { + // TODO: Also use VPDerivedIVRecipe when only the step needs truncating, to + // avoid separate truncate here. + assert(Step->getType()->isIntegerTy() && + "Truncation requires an integer step"); + Step = State.Builder.CreateTrunc(Step, ScalarIVTy); + } // We build scalar steps for both integer and floating-point induction // variables. Here, we determine the kind of arithmetic we will perform. @@ -2343,7 +2378,6 @@ static void buildScalarSteps(Value *ScalarIV, Value *Step, // Determine the number of scalars we need to generate for each unroll // iteration. bool FirstLaneOnly = vputils::onlyFirstLaneUsed(Def); - unsigned Lanes = FirstLaneOnly ? 
1 : State.VF.getKnownMinValue(); // Compute the scalar steps and save the results in State. Type *IntStepTy = IntegerType::get(ScalarIVTy->getContext(), ScalarIVTy->getScalarSizeInBits()); @@ -2357,7 +2391,17 @@ static void buildScalarSteps(Value *ScalarIV, Value *Step, SplatIV = Builder.CreateVectorSplat(State.VF, ScalarIV); } - for (unsigned Part = 0; Part < State.UF; ++Part) { + unsigned StartPart = 0; + unsigned EndPart = State.UF; + unsigned StartLane = 0; + unsigned EndLane = FirstLaneOnly ? 1 : State.VF.getKnownMinValue(); + if (State.Instance) { + StartPart = State.Instance->Part; + EndPart = StartPart + 1; + StartLane = State.Instance->Lane.getKnownLane(); + EndLane = StartLane + 1; + } + for (unsigned Part = StartPart; Part < EndPart; ++Part) { Value *StartIdx0 = createStepForVF(Builder, IntStepTy, State.VF, Part); if (!FirstLaneOnly && State.VF.isScalable()) { @@ -2376,7 +2420,7 @@ static void buildScalarSteps(Value *ScalarIV, Value *Step, if (ScalarIVTy->isFloatingPointTy()) StartIdx0 = Builder.CreateSIToFP(StartIdx0, ScalarIVTy); - for (unsigned Lane = 0; Lane < Lanes; ++Lane) { + for (unsigned Lane = StartLane; Lane < EndLane; ++Lane) { Value *StartIdx = Builder.CreateBinOp( AddOp, StartIdx0, getSignedIntOrFpConstant(ScalarIVTy, Lane)); // The step returned by `createStepForVF` is a runtime-evaluated value @@ -2415,8 +2459,14 @@ static Value *CreateStepValue(const SCEV *Step, ScalarEvolution &SE, static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index, Value *StartValue, Value *Step, const InductionDescriptor &ID) { - assert(Index->getType()->getScalarType() == Step->getType() && - "Index scalar type does not match StepValue type"); + Type *StepTy = Step->getType(); + Value *CastedIndex = StepTy->isIntegerTy() + ? B.CreateSExtOrTrunc(Index, StepTy) + : B.CreateCast(Instruction::SIToFP, Index, StepTy); + if (CastedIndex != Index) { + CastedIndex->setName(CastedIndex->getName() + ".cast"); + Index = CastedIndex; + } // Note: the IR at this point is broken. We cannot use SE to create any new // SCEV and then expand it, hoping that SCEV's simplification will give us @@ -2682,6 +2732,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( for (unsigned Part = 0; Part < UF; Part++) { // Collect the stored vector from each member. SmallVector<Value *, 4> StoredVecs; + unsigned StoredIdx = 0; for (unsigned i = 0; i < InterleaveFactor; i++) { assert((Group->getMember(i) || MaskForGaps) && "Fail to get a member from an interleaved store group"); @@ -2694,7 +2745,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( continue; } - Value *StoredVec = State.get(StoredValues[i], Part); + Value *StoredVec = State.get(StoredValues[StoredIdx], Part); + ++StoredIdx; if (Group->isReverse()) StoredVec = Builder.CreateVectorReverse(StoredVec, "reverse"); @@ -2738,7 +2790,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( } } -void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, +void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr, VPReplicateRecipe *RepRecipe, const VPIteration &Instance, bool IfPredicateInstr, @@ -2772,11 +2824,10 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, // Replace the operands of the cloned instructions with their scalar // equivalents in the new loop. 
- for (auto &I : enumerate(RepRecipe->operands())) { + for (const auto &I : enumerate(RepRecipe->operands())) { auto InputInstance = Instance; VPValue *Operand = I.value(); - VPReplicateRecipe *OperandR = dyn_cast<VPReplicateRecipe>(Operand); - if (OperandR && OperandR->isUniform()) + if (vputils::isUniformAfterVectorization(Operand)) InputInstance.Lane = VPLane::getFirstLane(); Cloned->setOperand(I.index(), State.get(Operand, InputInstance)); } @@ -2803,33 +2854,15 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(BasicBlock *InsertBlock) { assert(InsertBlock); IRBuilder<> Builder(InsertBlock->getTerminator()); // Find the loop boundaries. - ScalarEvolution *SE = PSE.getSE(); - const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount(); - assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && - "Invalid loop count"); - Type *IdxTy = Legal->getWidestInductionType(); assert(IdxTy && "No type for induction"); - - // The exit count might have the type of i64 while the phi is i32. This can - // happen if we have an induction variable that is sign extended before the - // compare. The only way that we get a backedge taken count is that the - // induction variable was signed and as such will not overflow. In such a case - // truncation is legal. - if (SE->getTypeSizeInBits(BackedgeTakenCount->getType()) > - IdxTy->getPrimitiveSizeInBits()) - BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount, IdxTy); - BackedgeTakenCount = SE->getNoopOrZeroExtend(BackedgeTakenCount, IdxTy); - - // Get the total trip count from the count by adding 1. - const SCEV *ExitCount = SE->getAddExpr( - BackedgeTakenCount, SE->getOne(BackedgeTakenCount->getType())); + const SCEV *ExitCount = createTripCountSCEV(IdxTy, PSE); const DataLayout &DL = InsertBlock->getModule()->getDataLayout(); // Expand the trip count and place the new instructions in the preheader. // Notice that the pre-header does not change, only the loop body. - SCEVExpander Exp(*SE, DL, "induction"); + SCEVExpander Exp(*PSE.getSE(), DL, "induction"); // Count holds the overall loop count (N). TripCount = Exp.expandCodeFor(ExitCount, ExitCount->getType(), @@ -3080,7 +3113,7 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { // 1) If we know that we must execute the scalar epilogue, emit an // unconditional branch. // 2) Otherwise, we must have a single unique exit block (due to how we - // implement the multiple exit case). In this case, set up a conditonal + // implement the multiple exit case). In this case, set up a conditional // branch from the middle block to the loop scalar preheader, and the // exit block. completeLoopSkeleton will update the condition to use an // iteration check, if required to decide whether to execute the remainder. 
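The buildScalarSteps hunk above narrows the unroll-part and lane ranges to a single (Part, Lane) pair whenever State.Instance is set, so a predicated replicate region only materializes the one scalar step it actually needs instead of all parts and lanes. The following is a minimal standalone sketch of just that bounds computation; VPIterationSketch, LaneBounds and computeScalarStepBounds are simplified stand-ins for illustration, not the real VPlan types or APIs.

// Sketch only: mirrors the StartPart/EndPart/StartLane/EndLane computation
// from the buildScalarSteps change, using simplified stand-in types.
#include <optional>

struct VPIterationSketch { // stand-in for a (Part, Lane) VPIteration
  unsigned Part;
  unsigned Lane;
};

struct LaneBounds {
  unsigned StartPart, EndPart;
  unsigned StartLane, EndLane;
};

// UF is the unroll factor, MinLanes is VF.getKnownMinValue().
static LaneBounds computeScalarStepBounds(unsigned UF, unsigned MinLanes,
                                          bool FirstLaneOnly,
                                          std::optional<VPIterationSketch> Instance) {
  LaneBounds B{0, UF, 0, FirstLaneOnly ? 1u : MinLanes};
  if (Instance) {
    // Generating one predicated instance: emit exactly one part/lane pair.
    B.StartPart = Instance->Part;
    B.EndPart = B.StartPart + 1;
    B.StartLane = Instance->Lane;
    B.EndLane = B.StartLane + 1;
  }
  return B;
}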
@@ -3101,88 +3134,87 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { DT->changeImmediateDominator(LoopExitBlock, LoopMiddleBlock); } -void InnerLoopVectorizer::createInductionResumeValues( +PHINode *InnerLoopVectorizer::createInductionResumeValue( + PHINode *OrigPhi, const InductionDescriptor &II, + ArrayRef<BasicBlock *> BypassBlocks, std::pair<BasicBlock *, Value *> AdditionalBypass) { - assert(((AdditionalBypass.first && AdditionalBypass.second) || - (!AdditionalBypass.first && !AdditionalBypass.second)) && - "Inconsistent information about additional bypass."); - Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); assert(VectorTripCount && "Expected valid arguments"); - // We are going to resume the execution of the scalar loop. - // Go over all of the induction variables that we found and fix the - // PHIs that are left in the scalar version of the loop. - // The starting values of PHI nodes depend on the counter of the last - // iteration in the vectorized loop. - // If we come from a bypass edge then we need to start from the original - // start value. + Instruction *OldInduction = Legal->getPrimaryInduction(); - for (auto &InductionEntry : Legal->getInductionVars()) { - PHINode *OrigPhi = InductionEntry.first; - InductionDescriptor II = InductionEntry.second; + Value *&EndValue = IVEndValues[OrigPhi]; + Value *EndValueFromAdditionalBypass = AdditionalBypass.second; + if (OrigPhi == OldInduction) { + // We know what the end value is. + EndValue = VectorTripCount; + } else { + IRBuilder<> B(LoopVectorPreHeader->getTerminator()); - Value *&EndValue = IVEndValues[OrigPhi]; - Value *EndValueFromAdditionalBypass = AdditionalBypass.second; - if (OrigPhi == OldInduction) { - // We know what the end value is. - EndValue = VectorTripCount; - } else { - IRBuilder<> B(LoopVectorPreHeader->getTerminator()); + // Fast-math-flags propagate from the original induction instruction. + if (II.getInductionBinOp() && isa<FPMathOperator>(II.getInductionBinOp())) + B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags()); - // Fast-math-flags propagate from the original induction instruction. - if (II.getInductionBinOp() && isa<FPMathOperator>(II.getInductionBinOp())) - B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags()); + Value *Step = + CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint()); + EndValue = + emitTransformedIndex(B, VectorTripCount, II.getStartValue(), Step, II); + EndValue->setName("ind.end"); - Type *StepType = II.getStep()->getType(); - Instruction::CastOps CastOp = - CastInst::getCastOpcode(VectorTripCount, true, StepType, true); - Value *VTC = B.CreateCast(CastOp, VectorTripCount, StepType, "cast.vtc"); + // Compute the end value for the additional bypass (if applicable). + if (AdditionalBypass.first) { + B.SetInsertPoint(&(*AdditionalBypass.first->getFirstInsertionPt())); Value *Step = CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint()); - EndValue = emitTransformedIndex(B, VTC, II.getStartValue(), Step, II); - EndValue->setName("ind.end"); - - // Compute the end value for the additional bypass (if applicable). 
- if (AdditionalBypass.first) { - B.SetInsertPoint(&(*AdditionalBypass.first->getFirstInsertionPt())); - CastOp = CastInst::getCastOpcode(AdditionalBypass.second, true, - StepType, true); - Value *Step = - CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint()); - VTC = - B.CreateCast(CastOp, AdditionalBypass.second, StepType, "cast.vtc"); - EndValueFromAdditionalBypass = - emitTransformedIndex(B, VTC, II.getStartValue(), Step, II); - EndValueFromAdditionalBypass->setName("ind.end"); - } + EndValueFromAdditionalBypass = emitTransformedIndex( + B, AdditionalBypass.second, II.getStartValue(), Step, II); + EndValueFromAdditionalBypass->setName("ind.end"); } + } - // Create phi nodes to merge from the backedge-taken check block. - PHINode *BCResumeVal = - PHINode::Create(OrigPhi->getType(), 3, "bc.resume.val", - LoopScalarPreHeader->getTerminator()); - // Copy original phi DL over to the new one. - BCResumeVal->setDebugLoc(OrigPhi->getDebugLoc()); + // Create phi nodes to merge from the backedge-taken check block. + PHINode *BCResumeVal = PHINode::Create(OrigPhi->getType(), 3, "bc.resume.val", + LoopScalarPreHeader->getTerminator()); + // Copy original phi DL over to the new one. + BCResumeVal->setDebugLoc(OrigPhi->getDebugLoc()); - // The new PHI merges the original incoming value, in case of a bypass, - // or the value at the end of the vectorized loop. - BCResumeVal->addIncoming(EndValue, LoopMiddleBlock); + // The new PHI merges the original incoming value, in case of a bypass, + // or the value at the end of the vectorized loop. + BCResumeVal->addIncoming(EndValue, LoopMiddleBlock); - // Fix the scalar body counter (PHI node). - // The old induction's phi node in the scalar body needs the truncated - // value. - for (BasicBlock *BB : LoopBypassBlocks) - BCResumeVal->addIncoming(II.getStartValue(), BB); + // Fix the scalar body counter (PHI node). + // The old induction's phi node in the scalar body needs the truncated + // value. + for (BasicBlock *BB : BypassBlocks) + BCResumeVal->addIncoming(II.getStartValue(), BB); - if (AdditionalBypass.first) - BCResumeVal->setIncomingValueForBlock(AdditionalBypass.first, - EndValueFromAdditionalBypass); + if (AdditionalBypass.first) + BCResumeVal->setIncomingValueForBlock(AdditionalBypass.first, + EndValueFromAdditionalBypass); + return BCResumeVal; +} +void InnerLoopVectorizer::createInductionResumeValues( + std::pair<BasicBlock *, Value *> AdditionalBypass) { + assert(((AdditionalBypass.first && AdditionalBypass.second) || + (!AdditionalBypass.first && !AdditionalBypass.second)) && + "Inconsistent information about additional bypass."); + // We are going to resume the execution of the scalar loop. + // Go over all of the induction variables that we found and fix the + // PHIs that are left in the scalar version of the loop. + // The starting values of PHI nodes depend on the counter of the last + // iteration in the vectorized loop. + // If we come from a bypass edge then we need to start from the original + // start value. + for (const auto &InductionEntry : Legal->getInductionVars()) { + PHINode *OrigPhi = InductionEntry.first; + const InductionDescriptor &II = InductionEntry.second; + PHINode *BCResumeVal = createInductionResumeValue( + OrigPhi, II, LoopBypassBlocks, AdditionalBypass); OrigPhi->setIncomingValueForBlock(LoopScalarPreHeader, BCResumeVal); } } -BasicBlock *InnerLoopVectorizer::completeLoopSkeleton(MDNode *OrigLoopID) { +BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() { // The trip counts should be cached by now. 
Value *Count = getOrCreateTripCount(LoopVectorPreHeader); Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); @@ -3251,18 +3283,6 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton() { ... */ - // Get the metadata of the original loop before it gets modified. - MDNode *OrigLoopID = OrigLoop->getLoopID(); - - // Workaround! Compute the trip count of the original loop and cache it - // before we start modifying the CFG. This code has a systemic problem - // wherein it tries to run analysis over partially constructed IR; this is - // wrong, and not simply for SCEV. The trip count of the original loop - // simply happens to be prone to hitting this in practice. In theory, we - // can hit the same issue for any SCEV, or ValueTracking query done during - // mutation. See PR49900. - getOrCreateTripCount(OrigLoop->getLoopPreheader()); - // Create an empty vector loop, and prepare basic blocks for the runtime // checks. createVectorLoopSkeleton(""); @@ -3286,7 +3306,7 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton() { // Emit phis for the new starting index of the scalar loop. createInductionResumeValues(); - return {completeLoopSkeleton(OrigLoopID), nullptr}; + return {completeLoopSkeleton(), nullptr}; } // Fix up external users of the induction variable. At this point, we are @@ -3334,17 +3354,11 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, Value *CountMinusOne = B.CreateSub( VectorTripCount, ConstantInt::get(VectorTripCount->getType(), 1)); - Value *CMO = - !II.getStep()->getType()->isIntegerTy() - ? B.CreateCast(Instruction::SIToFP, CountMinusOne, - II.getStep()->getType()) - : B.CreateSExtOrTrunc(CountMinusOne, II.getStep()->getType()); - CMO->setName("cast.cmo"); - + CountMinusOne->setName("cmo"); Value *Step = CreateStepValue(II.getStep(), *PSE.getSE(), VectorHeader->getTerminator()); Value *Escape = - emitTransformedIndex(B, CMO, II.getStartValue(), Step, II); + emitTransformedIndex(B, CountMinusOne, II.getStartValue(), Step, II); Escape->setName("ind.escape"); MissingVals[UI] = Escape; } @@ -3429,8 +3443,9 @@ LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, ElementCount VF, // to be vectors, so we need to extract individual elements from there, // execute VF scalar calls, and then gather the result into the vector return // value. + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; InstructionCost ScalarCallCost = - TTI.getCallInstrCost(F, ScalarRetTy, ScalarTys, TTI::TCK_RecipThroughput); + TTI.getCallInstrCost(F, ScalarRetTy, ScalarTys, CostKind); if (VF.isScalar()) return ScalarCallCost; @@ -3441,7 +3456,8 @@ LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, ElementCount VF, // Compute costs of unpacking argument values for the scalar calls and // packing the return values to a vector. - InstructionCost ScalarizationCost = getScalarizationOverhead(CI, VF); + InstructionCost ScalarizationCost = + getScalarizationOverhead(CI, VF, CostKind); InstructionCost Cost = ScalarCallCost * VF.getKnownMinValue() + ScalarizationCost; @@ -3457,7 +3473,7 @@ LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, ElementCount VF, // If the corresponding vector cost is cheaper, return its cost. InstructionCost VectorCallCost = - TTI.getCallInstrCost(nullptr, RetTy, Tys, TTI::TCK_RecipThroughput); + TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind); if (VectorCallCost < Cost) { NeedToScalarize = false; Cost = VectorCallCost; @@ -3672,7 +3688,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, // edge. 
// Fix-up external users of the induction variables. - for (auto &Entry : Legal->getInductionVars()) + for (const auto &Entry : Legal->getInductionVars()) fixupIVUsers(Entry.first, Entry.second, getOrCreateVectorTripCount(VectorLoop->getLoopPreheader()), IVEndValues[Entry.first], LoopMiddleBlock, @@ -3682,7 +3698,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, // Fix LCSSA phis not already fixed earlier. Extracts may need to be generated // in the exit block, so update the builder. State.Builder.SetInsertPoint(State.CFG.ExitBB->getFirstNonPHI()); - for (auto &KV : Plan.getLiveOuts()) + for (const auto &KV : Plan.getLiveOuts()) KV.second->fixPhi(Plan, State); for (Instruction *PI : PredicatedInstructions) @@ -3722,11 +3738,11 @@ void InnerLoopVectorizer::fixCrossIterationPHIs(VPTransformState &State) { if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) fixReduction(ReductionPhi, State); else if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) - fixFirstOrderRecurrence(FOR, State); + fixFixedOrderRecurrence(FOR, State); } } -void InnerLoopVectorizer::fixFirstOrderRecurrence( +void InnerLoopVectorizer::fixFixedOrderRecurrence( VPFirstOrderRecurrencePHIRecipe *PhiR, VPTransformState &State) { // This is the second phase of vectorizing first-order recurrences. An // overview of the transformation is described below. Suppose we have the @@ -4019,7 +4035,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // We know that the loop is in LCSSA form. We need to update the PHI nodes // in the exit blocks. See comment on analogous loop in - // fixFirstOrderRecurrence for a more complete explaination of the logic. + // fixFixedOrderRecurrence for a more complete explaination of the logic. if (!Cost->requiresScalarEpilogue(VF)) for (PHINode &LCSSAPhi : LoopExitBlock->phis()) if (llvm::is_contained(LCSSAPhi.incoming_values(), LoopExitInst)) { @@ -4146,8 +4162,7 @@ void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { void InnerLoopVectorizer::fixNonInductionPHIs(VPlan &Plan, VPTransformState &State) { - auto Iter = depth_first( - VPBlockRecursiveTraversalWrapper<VPBlockBase *>(Plan.getEntry())); + auto Iter = vp_depth_first_deep(Plan.getEntry()); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { for (VPRecipeBase &P : VPBB->phis()) { VPWidenPHIRecipe *VPPhi = dyn_cast<VPWidenPHIRecipe>(&P); @@ -4170,78 +4185,6 @@ bool InnerLoopVectorizer::useOrderedReductions( return Cost->useOrderedReductions(RdxDesc); } -void InnerLoopVectorizer::widenCallInstruction(CallInst &CI, VPValue *Def, - VPUser &ArgOperands, - VPTransformState &State) { - assert(!isa<DbgInfoIntrinsic>(CI) && - "DbgInfoIntrinsic should have been dropped during VPlan construction"); - State.setDebugLocFromInst(&CI); - - SmallVector<Type *, 4> Tys; - for (Value *ArgOperand : CI.args()) - Tys.push_back(ToVectorTy(ArgOperand->getType(), VF.getKnownMinValue())); - - Intrinsic::ID ID = getVectorIntrinsicIDForCall(&CI, TLI); - - // The flag shows whether we use Intrinsic or a usual Call for vectorized - // version of the instruction. - // Is it beneficial to perform intrinsic call compared to lib call? - bool NeedToScalarize = false; - InstructionCost CallCost = Cost->getVectorCallCost(&CI, VF, NeedToScalarize); - InstructionCost IntrinsicCost = - ID ? 
Cost->getVectorIntrinsicCost(&CI, VF) : 0; - bool UseVectorIntrinsic = ID && IntrinsicCost <= CallCost; - assert((UseVectorIntrinsic || !NeedToScalarize) && - "Instruction should be scalarized elsewhere."); - assert((IntrinsicCost.isValid() || CallCost.isValid()) && - "Either the intrinsic cost or vector call cost must be valid"); - - for (unsigned Part = 0; Part < UF; ++Part) { - SmallVector<Type *, 2> TysForDecl = {CI.getType()}; - SmallVector<Value *, 4> Args; - for (auto &I : enumerate(ArgOperands.operands())) { - // Some intrinsics have a scalar argument - don't replace it with a - // vector. - Value *Arg; - if (!UseVectorIntrinsic || - !isVectorIntrinsicWithScalarOpAtArg(ID, I.index())) - Arg = State.get(I.value(), Part); - else - Arg = State.get(I.value(), VPIteration(0, 0)); - if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I.index())) - TysForDecl.push_back(Arg->getType()); - Args.push_back(Arg); - } - - Function *VectorF; - if (UseVectorIntrinsic) { - // Use vector version of the intrinsic. - if (VF.isVector()) - TysForDecl[0] = VectorType::get(CI.getType()->getScalarType(), VF); - Module *M = State.Builder.GetInsertBlock()->getModule(); - VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); - assert(VectorF && "Can't retrieve vector intrinsic."); - } else { - // Use vector version of the function call. - const VFShape Shape = VFShape::get(CI, VF, false /*HasGlobalPred*/); -#ifndef NDEBUG - assert(VFDatabase(CI).getVectorizedFunction(Shape) != nullptr && - "Can't create vector function."); -#endif - VectorF = VFDatabase(CI).getVectorizedFunction(Shape); - } - SmallVector<OperandBundleDef, 1> OpBundles; - CI.getOperandBundlesAsDefs(OpBundles); - CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); - - if (isa<FPMathOperator>(V)) - V->copyFastMathFlags(&CI); - - State.set(Def, V, Part); - State.addMetadata(V, &CI); - } -} - void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { // We should not collect Scalars more than once per VF. Right now, this // function is called from collectUniformsAndScalars(), which already does @@ -4350,8 +4293,10 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { // induction variable when the PHI user is scalarized. auto ForcedScalar = ForcedScalars.find(VF); if (ForcedScalar != ForcedScalars.end()) - for (auto *I : ForcedScalar->second) + for (auto *I : ForcedScalar->second) { + LLVM_DEBUG(dbgs() << "LV: Found (forced) scalar instruction: " << *I << "\n"); Worklist.insert(I); + } // Expand the worklist by looking through any bitcasts and getelementptr // instructions we've already identified as scalar. This is similar to the @@ -4376,7 +4321,7 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { // An induction variable will remain scalar if all users of the induction // variable and induction variable update remain scalar. - for (auto &Induction : Legal->getInductionVars()) { + for (const auto &Induction : Legal->getInductionVars()) { auto *Ind = Induction.first; auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); @@ -4429,15 +4374,16 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { bool LoopVectorizationCostModel::isScalarWithPredication( Instruction *I, ElementCount VF) const { - if (!blockNeedsPredicationForAnyReason(I->getParent())) + if (!isPredicatedInst(I)) return false; + + // Do we have a non-scalar lowering for this predicated + // instruction? No - it is scalar with predication. 
switch(I->getOpcode()) { default: - break; + return true; case Instruction::Load: case Instruction::Store: { - if (!Legal->isMaskRequired(I)) - return false; auto *Ptr = getLoadStorePointerOperand(I); auto *Ty = getLoadStoreType(I); Type *VTy = Ty; @@ -4452,12 +4398,119 @@ bool LoopVectorizationCostModel::isScalarWithPredication( case Instruction::UDiv: case Instruction::SDiv: case Instruction::SRem: + case Instruction::URem: { + // We have the option to use the safe-divisor idiom to avoid predication. + // The cost based decision here will always select safe-divisor for + // scalable vectors as scalarization isn't legal. + const auto [ScalarCost, SafeDivisorCost] = getDivRemSpeculationCost(I, VF); + return isDivRemScalarWithPredication(ScalarCost, SafeDivisorCost); + } + } +} + +bool LoopVectorizationCostModel::isPredicatedInst(Instruction *I) const { + if (!blockNeedsPredicationForAnyReason(I->getParent())) + return false; + + // Can we prove this instruction is safe to unconditionally execute? + // If not, we must use some form of predication. + switch(I->getOpcode()) { + default: + return false; + case Instruction::Load: + case Instruction::Store: { + if (!Legal->isMaskRequired(I)) + return false; + // When we know the load's address is loop invariant and the instruction + // in the original scalar loop was unconditionally executed then we + // don't need to mark it as a predicated instruction. Tail folding may + // introduce additional predication, but we're guaranteed to always have + // at least one active lane. We call Legal->blockNeedsPredication here + // because it doesn't query tail-folding. For stores, we need to prove + // both speculation safety (which follows from the same argument as loads), + // but also must prove the value being stored is correct. The easiest + // form of the later is to require that all values stored are the same. + if (Legal->isUniformMemOp(*I) && + (isa<LoadInst>(I) || + (isa<StoreInst>(I) && + TheLoop->isLoopInvariant(cast<StoreInst>(I)->getValueOperand()))) && + !Legal->blockNeedsPredication(I->getParent())) + return false; + return true; + } + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::SRem: case Instruction::URem: // TODO: We can use the loop-preheader as context point here and get // context sensitive reasoning return !isSafeToSpeculativelyExecute(I); } - return false; +} + +std::pair<InstructionCost, InstructionCost> +LoopVectorizationCostModel::getDivRemSpeculationCost(Instruction *I, + ElementCount VF) const { + assert(I->getOpcode() == Instruction::UDiv || + I->getOpcode() == Instruction::SDiv || + I->getOpcode() == Instruction::SRem || + I->getOpcode() == Instruction::URem); + assert(!isSafeToSpeculativelyExecute(I)); + + const TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + + // Scalarization isn't legal for scalable vector types + InstructionCost ScalarizationCost = InstructionCost::getInvalid(); + if (!VF.isScalable()) { + // Get the scalarization cost and scale this amount by the probability of + // executing the predicated block. If the instruction is not predicated, + // we fall through to the next case. + ScalarizationCost = 0; + + // These instructions have a non-void type, so account for the phi nodes + // that we will create. This cost is likely to be zero. The phi node + // cost, if any, should be scaled by the block probability because it + // models a copy at the end of each predicated block. 
+ ScalarizationCost += VF.getKnownMinValue() * + TTI.getCFInstrCost(Instruction::PHI, CostKind); + + // The cost of the non-predicated instruction. + ScalarizationCost += VF.getKnownMinValue() * + TTI.getArithmeticInstrCost(I->getOpcode(), I->getType(), CostKind); + + // The cost of insertelement and extractelement instructions needed for + // scalarization. + ScalarizationCost += getScalarizationOverhead(I, VF, CostKind); + + // Scale the cost by the probability of executing the predicated blocks. + // This assumes the predicated block for each vector lane is equally + // likely. + ScalarizationCost = ScalarizationCost / getReciprocalPredBlockProb(); + } + InstructionCost SafeDivisorCost = 0; + + auto *VecTy = ToVectorTy(I->getType(), VF); + + // The cost of the select guard to ensure all lanes are well defined + // after we speculate above any internal control flow. + SafeDivisorCost += TTI.getCmpSelInstrCost( + Instruction::Select, VecTy, + ToVectorTy(Type::getInt1Ty(I->getContext()), VF), + CmpInst::BAD_ICMP_PREDICATE, CostKind); + + // Certain instructions can be cheaper to vectorize if they have a constant + // second vector operand. One example of this are shifts on x86. + Value *Op2 = I->getOperand(1); + auto Op2Info = TTI.getOperandInfo(Op2); + if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && Legal->isUniform(Op2)) + Op2Info.Kind = TargetTransformInfo::OK_UniformValue; + + SmallVector<const Value *, 4> Operands(I->operand_values()); + SafeDivisorCost += TTI.getArithmeticInstrCost( + I->getOpcode(), VecTy, CostKind, + {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None}, + Op2Info, Operands, I); + return {ScalarizationCost, SafeDivisorCost}; } bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( @@ -4610,17 +4663,26 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse()) addToWorklistIfAllowed(Cmp); + // Return true if all lanes perform the same memory operation, and we can + // thus chose to execute only one. + auto isUniformMemOpUse = [&](Instruction *I) { + if (!Legal->isUniformMemOp(*I)) + return false; + if (isa<LoadInst>(I)) + // Loading the same address always produces the same result - at least + // assuming aliasing and ordering which have already been checked. + return true; + // Storing the same value on every iteration. + return TheLoop->isLoopInvariant(cast<StoreInst>(I)->getValueOperand()); + }; + auto isUniformDecision = [&](Instruction *I, ElementCount VF) { InstWidening WideningDecision = getWideningDecision(I, VF); assert(WideningDecision != CM_Unknown && "Widening decision should be ready at this moment"); - // A uniform memory op is itself uniform. We exclude uniform stores - // here as they demand the last lane, not the first one. - if (isa<LoadInst>(I) && Legal->isUniformMemOp(*I)) { - assert(WideningDecision == CM_Scalarize); + if (isUniformMemOpUse(I)) return true; - } return (WideningDecision == CM_Widen || WideningDecision == CM_Widen_Reverse || @@ -4674,9 +4736,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { if (!Ptr) continue; - // A uniform memory op is itself uniform. We exclude uniform stores - // here as they demand the last lane, not the first one. 
- if (isa<LoadInst>(I) && Legal->isUniformMemOp(I)) + if (isUniformMemOpUse(&I)) addToWorklistIfAllowed(&I); if (isUniformDecision(&I, VF)) { @@ -4707,14 +4767,14 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { while (idx != Worklist.size()) { Instruction *I = Worklist[idx++]; - for (auto OV : I->operand_values()) { + for (auto *OV : I->operand_values()) { // isOutOfScope operands cannot be uniform instructions. if (isOutOfScope(OV)) continue; // First order recurrence Phi's should typically be considered // non-uniform. auto *OP = dyn_cast<PHINode>(OV); - if (OP && Legal->isFirstOrderRecurrence(OP)) + if (OP && Legal->isFixedOrderRecurrence(OP)) continue; // If all the users of the operand are uniform, then add the // operand into the uniform worklist. @@ -4733,7 +4793,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { // nodes separately. An induction variable will remain uniform if all users // of the induction variable and induction variable update remain uniform. // The code below handles both pointer and non-pointer induction variables. - for (auto &Induction : Legal->getInductionVars()) { + for (const auto &Induction : Legal->getInductionVars()) { auto *Ind = Induction.first; auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); @@ -4846,12 +4906,12 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { return MaxScalableVF; // Limit MaxScalableVF by the maximum safe dependence distance. - Optional<unsigned> MaxVScale = TTI.getMaxVScale(); + std::optional<unsigned> MaxVScale = TTI.getMaxVScale(); if (!MaxVScale && TheFunction->hasFnAttribute(Attribute::VScaleRange)) MaxVScale = TheFunction->getFnAttribute(Attribute::VScaleRange).getVScaleRangeMax(); - MaxScalableVF = ElementCount::getScalable( - MaxVScale ? (MaxSafeElements / MaxVScale.value()) : 0); + MaxScalableVF = + ElementCount::getScalable(MaxVScale ? (MaxSafeElements / *MaxVScale) : 0); if (!MaxScalableVF) reportVectorizationInfo( "Max legal vector width too small, scalable vectorization " @@ -4991,7 +5051,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { case CM_ScalarEpilogueAllowed: return computeFeasibleMaxVF(TC, UserVF, false); case CM_ScalarEpilogueNotAllowedUsePredicate: - LLVM_FALLTHROUGH; + [[fallthrough]]; case CM_ScalarEpilogueNotNeededUsePredicate: LLVM_DEBUG( dbgs() << "LV: vector predicate hint/switch found.\n" @@ -5113,7 +5173,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( unsigned ConstTripCount, unsigned SmallestType, unsigned WidestType, ElementCount MaxSafeVF, bool FoldTailByMasking) { bool ComputeScalableMaxVF = MaxSafeVF.isScalable(); - TypeSize WidestRegister = TTI.getRegisterBitWidth( + const TypeSize WidestRegister = TTI.getRegisterBitWidth( ComputeScalableMaxVF ? TargetTransformInfo::RGK_ScalableVector : TargetTransformInfo::RGK_FixedWidthVector); @@ -5127,7 +5187,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( // Ensure MaxVF is a power of 2; the dependence distance bound may not be. // Note that both WidestRegister and WidestType may not be a powers of 2. 
auto MaxVectorElementCount = ElementCount::get( - PowerOf2Floor(WidestRegister.getKnownMinSize() / WidestType), + PowerOf2Floor(WidestRegister.getKnownMinValue() / WidestType), ComputeScalableMaxVF); MaxVectorElementCount = MinVF(MaxVectorElementCount, MaxSafeVF); LLVM_DEBUG(dbgs() << "LV: The Widest register safe to use is: " @@ -5140,9 +5200,14 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( return ElementCount::getFixed(1); } - const auto TripCountEC = ElementCount::getFixed(ConstTripCount); - if (ConstTripCount && - ElementCount::isKnownLE(TripCountEC, MaxVectorElementCount) && + unsigned WidestRegisterMinEC = MaxVectorElementCount.getKnownMinValue(); + if (MaxVectorElementCount.isScalable() && + TheFunction->hasFnAttribute(Attribute::VScaleRange)) { + auto Attr = TheFunction->getFnAttribute(Attribute::VScaleRange); + auto Min = Attr.getVScaleRangeMin(); + WidestRegisterMinEC *= Min; + } + if (ConstTripCount && ConstTripCount <= WidestRegisterMinEC && (!FoldTailByMasking || isPowerOf2_32(ConstTripCount))) { // If loop trip count (TC) is known at compile time there is no point in // choosing VF greater than TC (as done in the loop below). Select maximum @@ -5163,7 +5228,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( if (MaximizeBandwidth || (MaximizeBandwidth.getNumOccurrences() == 0 && TTI.shouldMaximizeVectorBandwidth(RegKind))) { auto MaxVectorElementCountMaxBW = ElementCount::get( - PowerOf2Floor(WidestRegister.getKnownMinSize() / SmallestType), + PowerOf2Floor(WidestRegister.getKnownMinValue() / SmallestType), ComputeScalableMaxVF); MaxVectorElementCountMaxBW = MinVF(MaxVectorElementCountMaxBW, MaxSafeVF); @@ -5208,7 +5273,7 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( return MaxVF; } -Optional<unsigned> LoopVectorizationCostModel::getVScaleForTuning() const { +std::optional<unsigned> LoopVectorizationCostModel::getVScaleForTuning() const { if (TheFunction->hasFnAttribute(Attribute::VScaleRange)) { auto Attr = TheFunction->getFnAttribute(Attribute::VScaleRange); auto Min = Attr.getVScaleRangeMin(); @@ -5244,11 +5309,11 @@ bool LoopVectorizationCostModel::isMoreProfitable( // Improve estimate for the vector width if it is scalable. unsigned EstimatedWidthA = A.Width.getKnownMinValue(); unsigned EstimatedWidthB = B.Width.getKnownMinValue(); - if (Optional<unsigned> VScale = getVScaleForTuning()) { + if (std::optional<unsigned> VScale = getVScaleForTuning()) { if (A.Width.isScalable()) - EstimatedWidthA *= VScale.value(); + EstimatedWidthA *= *VScale; if (B.Width.isScalable()) - EstimatedWidthB *= VScale.value(); + EstimatedWidthB *= *VScale; } // Assume vscale may be larger than 1 (or the value being tuned for), @@ -5294,7 +5359,7 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( #ifndef NDEBUG unsigned AssumedMinimumVscale = 1; - if (Optional<unsigned> VScale = getVScaleForTuning()) + if (std::optional<unsigned> VScale = getVScaleForTuning()) AssumedMinimumVscale = *VScale; unsigned Width = Candidate.Width.isScalable() @@ -5365,7 +5430,7 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( raw_string_ostream OS(OutString); assert(!Subset.empty() && "Unexpected empty range"); OS << "Instruction with invalid costs prevented vectorization at VF=("; - for (auto &Pair : Subset) + for (const auto &Pair : Subset) OS << (Pair.second == Subset.front().second ? 
"" : ", ") << Pair.second; OS << "):"; @@ -5403,12 +5468,12 @@ bool LoopVectorizationCostModel::isCandidateForEpilogueVectorization( // Cross iteration phis such as reductions need special handling and are // currently unsupported. if (any_of(L.getHeader()->phis(), - [&](PHINode &Phi) { return Legal->isFirstOrderRecurrence(&Phi); })) + [&](PHINode &Phi) { return Legal->isFixedOrderRecurrence(&Phi); })) return false; // Phis with uses outside of the loop require special handling and are // currently unsupported. - for (auto &Entry : Legal->getInductionVars()) { + for (const auto &Entry : Legal->getInductionVars()) { // Look for uses of the value of the induction at the last iteration. Value *PostInc = Entry.first->getIncomingValueForBlock(L.getLoopLatch()); for (User *U : PostInc->users()) @@ -5420,14 +5485,6 @@ bool LoopVectorizationCostModel::isCandidateForEpilogueVectorization( return false; } - // Induction variables that are widened require special handling that is - // currently not supported. - if (any_of(Legal->getInductionVars(), [&](auto &Entry) { - return !(this->isScalarAfterVectorization(Entry.first, VF) || - this->isProfitableToScalarize(Entry.first, VF)); - })) - return false; - // Epilogue vectorization code has not been auditted to ensure it handles // non-latch exits properly. It may be fine, but it needs auditted and // tested. @@ -5443,6 +5500,11 @@ bool LoopVectorizationCostModel::isEpilogueVectorizationProfitable( // as register pressure, code size increase and cost of extra branches into // account. For now we apply a very crude heuristic and only consider loops // with vectorization factors larger than a certain value. + + // Allow the target to opt out entirely. + if (!TTI.preferEpilogueVectorization()) + return false; + // We also consider epilogue vectorization unprofitable for targets that don't // consider interleaving beneficial (eg. MVE). if (TTI.getMaxInterleaveFactor(VF.getKnownMinValue()) <= 1) @@ -5512,7 +5574,7 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor( ElementCount EstimatedRuntimeVF = MainLoopVF; if (MainLoopVF.isScalable()) { EstimatedRuntimeVF = ElementCount::getFixed(MainLoopVF.getKnownMinValue()); - if (Optional<unsigned> VScale = getVScaleForTuning()) + if (std::optional<unsigned> VScale = getVScaleForTuning()) EstimatedRuntimeVF *= *VScale; } @@ -5542,7 +5604,7 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { // Reset MaxWidth so that we can find the smallest type used by recurrences // in the loop. MaxWidth = -1U; - for (auto &PhiDescriptorPair : Legal->getReductionVars()) { + for (const auto &PhiDescriptorPair : Legal->getReductionVars()) { const RecurrenceDescriptor &RdxDesc = PhiDescriptorPair.second; // When finding the min width used by the recurrence we need to account // for casts on the input operands of the recurrence. 
@@ -5554,9 +5616,9 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { } else { for (Type *T : ElementTypesInLoop) { MinWidth = std::min<unsigned>( - MinWidth, DL.getTypeSizeInBits(T->getScalarType()).getFixedSize()); + MinWidth, DL.getTypeSizeInBits(T->getScalarType()).getFixedValue()); MaxWidth = std::max<unsigned>( - MaxWidth, DL.getTypeSizeInBits(T->getScalarType()).getFixedSize()); + MaxWidth, DL.getTypeSizeInBits(T->getScalarType()).getFixedValue()); } } return {MinWidth, MaxWidth}; @@ -5605,8 +5667,9 @@ void LoopVectorizationCostModel::collectElementTypesForWidening() { } } -unsigned LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, - unsigned LoopCost) { +unsigned +LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, + InstructionCost LoopCost) { // -- The interleave heuristics -- // We interleave the loop in order to expose ILP and reduce the loop overhead. // There are many micro-architectural considerations that we can't predict @@ -5642,9 +5705,8 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, // If we did not calculate the cost for VF (because the user selected the VF) // then we calculate the cost of VF here. if (LoopCost == 0) { - InstructionCost C = expectedCost(VF).first; - assert(C.isValid() && "Expected to have chosen a VF with valid cost"); - LoopCost = *C.getValue(); + LoopCost = expectedCost(VF).first; + assert(LoopCost.isValid() && "Expected to have chosen a VF with valid cost"); // Loop body is free and there is no need for interleaving. if (LoopCost == 0) @@ -5772,8 +5834,8 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, // We assume that the cost overhead is 1 and we use the cost model // to estimate the cost of the loop and interleave until the cost of the // loop overhead is about 5% of the cost of the loop. - unsigned SmallIC = - std::min(IC, (unsigned)PowerOf2Floor(SmallLoopCost / LoopCost)); + unsigned SmallIC = std::min( + IC, (unsigned)PowerOf2Floor(SmallLoopCost / *LoopCost.getValue())); // Interleave until store/load ports (estimated by max interleave count) are // saturated. @@ -5888,8 +5950,9 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) { IntervalMap EndPoint; // Saves the list of instruction indices that are used in the loop. SmallPtrSet<Instruction *, 8> Ends; - // Saves the list of values that are used in the loop but are - // defined outside the loop, such as arguments and constants. + // Saves the list of values that are used in the loop but are defined outside + // the loop (not including non-instruction values such as arguments and + // constants). SmallPtrSet<Value *, 8> LoopInvariants; for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { @@ -5901,6 +5964,9 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) { auto *Instr = dyn_cast<Instruction>(U); // Ignore non-instruction values such as arguments, constants, etc. + // FIXME: Might need some motivation why these values are ignored. If + // for example an argument is used inside the loop it will increase the + // register pressure (so shouldn't we add it to LoopInvariants). if (!Instr) continue; @@ -5956,44 +6022,44 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) { // For each VF find the maximum usage of registers. for (unsigned j = 0, e = VFs.size(); j < e; ++j) { - // Count the number of live intervals. 
+ // Count the number of registers used, per register class, given all open + // intervals. + // Note that elements in this SmallMapVector will be default constructed + // as 0. So we can use "RegUsage[ClassID] += n" in the code below even if + // there is no previous entry for ClassID. SmallMapVector<unsigned, unsigned, 4> RegUsage; if (VFs[j].isScalar()) { - for (auto Inst : OpenIntervals) { - unsigned ClassID = TTI.getRegisterClassForType(false, Inst->getType()); - if (RegUsage.find(ClassID) == RegUsage.end()) - RegUsage[ClassID] = 1; - else - RegUsage[ClassID] += 1; + for (auto *Inst : OpenIntervals) { + unsigned ClassID = + TTI.getRegisterClassForType(false, Inst->getType()); + // FIXME: The target might use more than one register for the type + // even in the scalar case. + RegUsage[ClassID] += 1; } } else { collectUniformsAndScalars(VFs[j]); - for (auto Inst : OpenIntervals) { + for (auto *Inst : OpenIntervals) { // Skip ignored values for VF > 1. if (VecValuesToIgnore.count(Inst)) continue; if (isScalarAfterVectorization(Inst, VFs[j])) { - unsigned ClassID = TTI.getRegisterClassForType(false, Inst->getType()); - if (RegUsage.find(ClassID) == RegUsage.end()) - RegUsage[ClassID] = 1; - else - RegUsage[ClassID] += 1; + unsigned ClassID = + TTI.getRegisterClassForType(false, Inst->getType()); + // FIXME: The target might use more than one register for the type + // even in the scalar case. + RegUsage[ClassID] += 1; } else { - unsigned ClassID = TTI.getRegisterClassForType(true, Inst->getType()); - if (RegUsage.find(ClassID) == RegUsage.end()) - RegUsage[ClassID] = GetRegUsage(Inst->getType(), VFs[j]); - else - RegUsage[ClassID] += GetRegUsage(Inst->getType(), VFs[j]); + unsigned ClassID = + TTI.getRegisterClassForType(true, Inst->getType()); + RegUsage[ClassID] += GetRegUsage(Inst->getType(), VFs[j]); } } } for (auto& pair : RegUsage) { - if (MaxUsages[j].find(pair.first) != MaxUsages[j].end()) - MaxUsages[j][pair.first] = std::max(MaxUsages[j][pair.first], pair.second); - else - MaxUsages[j][pair.first] = pair.second; + auto &Entry = MaxUsages[j][pair.first]; + Entry = std::max(Entry, pair.second); } } @@ -6005,17 +6071,19 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) { } for (unsigned i = 0, e = VFs.size(); i < e; ++i) { + // Note that elements in this SmallMapVector will be default constructed + // as 0. So we can use "Invariant[ClassID] += n" in the code below even if + // there is no previous entry for ClassID. SmallMapVector<unsigned, unsigned, 4> Invariant; - for (auto Inst : LoopInvariants) { + for (auto *Inst : LoopInvariants) { + // FIXME: The target might use more than one register for the type + // even in the scalar case. unsigned Usage = VFs[i].isScalar() ? 1 : GetRegUsage(Inst->getType(), VFs[i]); unsigned ClassID = TTI.getRegisterClassForType(VFs[i].isVector(), Inst->getType()); - if (Invariant.find(ClassID) == Invariant.end()) - Invariant[ClassID] = Usage; - else - Invariant[ClassID] += Usage; + Invariant[ClassID] += Usage; } LLVM_DEBUG({ @@ -6054,7 +6122,7 @@ bool LoopVectorizationCostModel::useEmulatedMaskMemRefHack(Instruction *I, // from moving "masked load/store" check from legality to cost model. // Masked Load/Gather emulation was previously never allowed. // Limited number of Masked Store/Scatter emulation was allowed. 
- assert((isPredicatedInst(I, VF) || Legal->isUniformMemOp(*I)) && + assert((isPredicatedInst(I)) && "Expecting a scalar emulated instruction"); return isa<LoadInst>(I) || (isa<StoreInst>(I) && @@ -6099,7 +6167,7 @@ void LoopVectorizationCostModel::collectInstsToScalarize(ElementCount VF) { } } -int LoopVectorizationCostModel::computePredInstDiscount( +InstructionCost LoopVectorizationCostModel::computePredInstDiscount( Instruction *PredInst, ScalarCostsTy &ScalarCosts, ElementCount VF) { assert(!isUniformAfterVectorization(PredInst, VF) && "Instruction marked uniform-after-vectorization will be predicated"); @@ -6173,13 +6241,14 @@ int LoopVectorizationCostModel::computePredInstDiscount( // Compute the scalarization overhead of needed insertelement instructions // and phi nodes. + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; if (isScalarWithPredication(I, VF) && !I->getType()->isVoidTy()) { ScalarCost += TTI.getScalarizationOverhead( cast<VectorType>(ToVectorTy(I->getType(), VF)), - APInt::getAllOnes(VF.getFixedValue()), true, false); + APInt::getAllOnes(VF.getFixedValue()), /*Insert*/ true, + /*Extract*/ false, CostKind); ScalarCost += - VF.getFixedValue() * - TTI.getCFInstrCost(Instruction::PHI, TTI::TCK_RecipThroughput); + VF.getFixedValue() * TTI.getCFInstrCost(Instruction::PHI, CostKind); } // Compute the scalarization overhead of needed extractelement @@ -6195,7 +6264,8 @@ int LoopVectorizationCostModel::computePredInstDiscount( else if (needsExtract(J, VF)) { ScalarCost += TTI.getScalarizationOverhead( cast<VectorType>(ToVectorTy(J->getType(), VF)), - APInt::getAllOnes(VF.getFixedValue()), false, true); + APInt::getAllOnes(VF.getFixedValue()), /*Insert*/ false, + /*Extract*/ true, CostKind); } } @@ -6208,7 +6278,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( ScalarCosts[I] = ScalarCost; } - return *Discount.getValue(); + return Discount; } LoopVectorizationCostModel::VectorizationCostTy @@ -6324,19 +6394,20 @@ LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, // Don't pass *I here, since it is scalar but will actually be part of a // vectorized loop where the user of it is a vectorized instruction. + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; const Align Alignment = getLoadStoreAlignment(I); - Cost += VF.getKnownMinValue() * - TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment, - AS, TTI::TCK_RecipThroughput); + Cost += VF.getKnownMinValue() * TTI.getMemoryOpCost(I->getOpcode(), + ValTy->getScalarType(), + Alignment, AS, CostKind); // Get the overhead of the extractelement and insertelement instructions // we might create due to scalarization. - Cost += getScalarizationOverhead(I, VF); + Cost += getScalarizationOverhead(I, VF, CostKind); // If we have a predicated load/store, it will need extra i1 extracts and // conditional branches, but may not be executed for each vector lane. Scale // the cost by the probability of executing the predicated block. 
- if (isPredicatedInst(I, VF)) { + if (isPredicatedInst(I)) { Cost /= getReciprocalPredBlockProb(); // Add the cost of an i1 extract and a branch @@ -6344,8 +6415,8 @@ LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, VectorType::get(IntegerType::getInt1Ty(ValTy->getContext()), VF); Cost += TTI.getScalarizationOverhead( Vec_i1Ty, APInt::getAllOnes(VF.getKnownMinValue()), - /*Insert=*/false, /*Extract=*/true); - Cost += TTI.getCFInstrCost(Instruction::Br, TTI::TCK_RecipThroughput); + /*Insert=*/false, /*Extract=*/true, CostKind); + Cost += TTI.getCFInstrCost(Instruction::Br, CostKind); if (useEmulatedMaskMemRefHack(I, VF)) // Artificially setting to a high enough value to practically disable @@ -6370,17 +6441,19 @@ LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, "Stride should be 1 or -1 for consecutive memory access"); const Align Alignment = getLoadStoreAlignment(I); InstructionCost Cost = 0; - if (Legal->isMaskRequired(I)) + if (Legal->isMaskRequired(I)) { Cost += TTI.getMaskedMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS, CostKind); - else + } else { + TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(I->getOperand(0)); Cost += TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS, - CostKind, I); + CostKind, OpInfo, I); + } bool Reverse = ConsecutiveStride < 0; if (Reverse) - Cost += - TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, None, 0); + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, + std::nullopt, CostKind, 0); return Cost; } @@ -6409,7 +6482,7 @@ LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I, (isLoopInvariantStoreValue ? 0 : TTI.getVectorInstrCost(Instruction::ExtractElement, VectorTy, - VF.getKnownMinValue() - 1)); + CostKind, VF.getKnownMinValue() - 1)); } InstructionCost @@ -6437,6 +6510,7 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, Type *ValTy = getLoadStoreType(I); auto *VectorTy = cast<VectorType>(ToVectorTy(ValTy, VF)); unsigned AS = getLoadStoreAddressSpace(I); + enum TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; auto Group = getInterleavedAccessGroup(I); assert(Group && "Fail to get an interleaved access group."); @@ -6456,25 +6530,26 @@ LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, (isa<StoreInst>(I) && (Group->getNumMembers() < Group->getFactor())); InstructionCost Cost = TTI.getInterleavedMemoryOpCost( I->getOpcode(), WideVecTy, Group->getFactor(), Indices, Group->getAlign(), - AS, TTI::TCK_RecipThroughput, Legal->isMaskRequired(I), UseMaskForGaps); + AS, CostKind, Legal->isMaskRequired(I), UseMaskForGaps); if (Group->isReverse()) { // TODO: Add support for reversed masked interleaved access. 
assert(!Legal->isMaskRequired(I) && "Reverse masked interleaved access not supported."); - Cost += - Group->getNumMembers() * - TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, None, 0); + Cost += Group->getNumMembers() * + TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, + std::nullopt, CostKind, 0); } return Cost; } -Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( +std::optional<InstructionCost> +LoopVectorizationCostModel::getReductionPatternCost( Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) { using namespace llvm::PatternMatch; // Early exit for no inloop reductions if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty)) - return None; + return std::nullopt; auto *VectorTy = cast<VectorType>(Ty); // We are looking for a pattern of, and finding the minimal acceptable cost: @@ -6492,20 +6567,19 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( Instruction *RetI = I; if (match(RetI, m_ZExtOrSExt(m_Value()))) { if (!RetI->hasOneUser()) - return None; + return std::nullopt; RetI = RetI->user_back(); } - if (match(RetI, m_Mul(m_Value(), m_Value())) && + + if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) && RetI->user_back()->getOpcode() == Instruction::Add) { - if (!RetI->hasOneUser()) - return None; RetI = RetI->user_back(); } // Test if the found instruction is a reduction, and if not return an invalid // cost specifying the parent to use the original cost modelling. if (!InLoopReductionImmediateChains.count(RetI)) - return None; + return std::nullopt; // Find the reduction this chain is a part of and calculate the basic cost of // the reduction on its own. @@ -6541,7 +6615,7 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy); Instruction *Op0, *Op1; - if (RedOp && + if (RedOp && RdxDesc.getOpcode() == Instruction::Add && match(RedOp, m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) && match(Op0, m_ZExtOrSExt(m_Value())) && @@ -6550,7 +6624,7 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1) && (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) { - // Matched reduce(ext(mul(ext(A), ext(B))) + // Matched reduce.add(ext(mul(ext(A), ext(B))) // Note that the extend opcodes need to all match, or if A==B they will have // been converted to zext(mul(sext(A), sext(A))) as it is known positive, // which is equally fine. 
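The hunks on either side of this point restrict getReductionPatternCost to reductions whose opcode is Add before matching reduce.add(ext(mul(ext(A), ext(B)))), and, in the following hunk, switch the costing of that pattern to TTI.getMulAccReductionCost. For reference, this shape typically comes from a widening dot-product style loop; the function below is a hypothetical example, not code from the patch.

// Hypothetical source loop whose vectorized form contains an add reduction
// over mul(sext(A[i]), sext(B[i])) with an i32 accumulator, i.e. the
// reduce.add(ext(mul(ext(A), ext(B)))) shape matched above.
#include <cstdint>

int32_t dot_i8(const int8_t *A, const int8_t *B, int N) {
  int32_t Sum = 0;
  for (int I = 0; I < N; ++I)
    Sum += (int32_t)A[I] * (int32_t)B[I]; // each i8 operand is sign-extended
  return Sum;
}

Targets with fused multiply-accumulate reductions can report a much lower cost for the whole chain than the sum of the separate extend, multiply and add costs, which is what the getMulAccReductionCost-based costing is meant to capture.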
@@ -6567,9 +6641,8 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( TTI.getCastInstrCost(RedOp->getOpcode(), VectorTy, MulType, TTI::CastContextHint::None, CostKind, RedOp); - InstructionCost RedCost = TTI.getExtendedAddReductionCost( - /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, - CostKind); + InstructionCost RedCost = TTI.getMulAccReductionCost( + IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind); if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost) @@ -6579,16 +6652,16 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( // Matched reduce(ext(A)) bool IsUnsigned = isa<ZExtInst>(RedOp); auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy); - InstructionCost RedCost = TTI.getExtendedAddReductionCost( - /*IsMLA=*/false, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, - CostKind); + InstructionCost RedCost = TTI.getExtendedReductionCost( + RdxDesc.getOpcode(), IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, + RdxDesc.getFastMathFlags(), CostKind); InstructionCost ExtCost = TTI.getCastInstrCost(RedOp->getOpcode(), VectorTy, ExtType, TTI::CastContextHint::None, CostKind, RedOp); if (RedCost.isValid() && RedCost < BaseCost + ExtCost) return I == RetI ? RedCost : 0; - } else if (RedOp && + } else if (RedOp && RdxDesc.getOpcode() == Instruction::Add && match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) { if (match(Op0, m_ZExtOrSExt(m_Value())) && Op0->getOpcode() == Op1->getOpcode() && @@ -6601,7 +6674,7 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( : Op0Ty; auto *ExtType = VectorType::get(LargestOpTy, VectorTy); - // Matched reduce(mul(ext(A), ext(B))), where the two ext may be of + // Matched reduce.add(mul(ext(A), ext(B))), where the two ext may be of // different sizes. We take the largest type as the ext to reduce, and add // the remaining cost as, for example reduce(mul(ext(ext(A)), ext(B))). InstructionCost ExtCost0 = TTI.getCastInstrCost( @@ -6613,9 +6686,8 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( InstructionCost MulCost = TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); - InstructionCost RedCost = TTI.getExtendedAddReductionCost( - /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, - CostKind); + InstructionCost RedCost = TTI.getMulAccReductionCost( + IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind); InstructionCost ExtraExtCost = 0; if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) { Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1; @@ -6629,20 +6701,19 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost( (RedCost + ExtraExtCost) < (ExtCost0 + ExtCost1 + MulCost + BaseCost)) return I == RetI ? RedCost : 0; } else if (!match(I, m_ZExtOrSExt(m_Value()))) { - // Matched reduce(mul()) + // Matched reduce.add(mul()) InstructionCost MulCost = TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); - InstructionCost RedCost = TTI.getExtendedAddReductionCost( - /*IsMLA=*/true, true, RdxDesc.getRecurrenceType(), VectorTy, - CostKind); + InstructionCost RedCost = TTI.getMulAccReductionCost( + true, RdxDesc.getRecurrenceType(), VectorTy, CostKind); if (RedCost.isValid() && RedCost < MulCost + BaseCost) return I == RetI ? RedCost : 0; } } - return I == RetI ? Optional<InstructionCost>(BaseCost) : None; + return I == RetI ? 
std::optional<InstructionCost>(BaseCost) : std::nullopt; } InstructionCost @@ -6655,9 +6726,10 @@ LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I, const Align Alignment = getLoadStoreAlignment(I); unsigned AS = getLoadStoreAddressSpace(I); + TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(I->getOperand(0)); return TTI.getAddressComputationCost(ValTy) + TTI.getMemoryOpCost(I->getOpcode(), ValTy, Alignment, AS, - TTI::TCK_RecipThroughput, I); + TTI::TCK_RecipThroughput, OpInfo, I); } return getWideningCost(I, VF); } @@ -6705,9 +6777,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, return VectorizationCostTy(C, TypeNotScalarized); } -InstructionCost -LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I, - ElementCount VF) const { +InstructionCost LoopVectorizationCostModel::getScalarizationOverhead( + Instruction *I, ElementCount VF, TTI::TargetCostKind CostKind) const { // There is no mechanism yet to create a scalable scalarization loop, // so this is currently Invalid. @@ -6722,8 +6793,9 @@ LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I, if (!RetTy->isVoidTy() && (!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore())) Cost += TTI.getScalarizationOverhead( - cast<VectorType>(RetTy), APInt::getAllOnes(VF.getKnownMinValue()), true, - false); + cast<VectorType>(RetTy), APInt::getAllOnes(VF.getKnownMinValue()), + /*Insert*/ true, + /*Extract*/ false, CostKind); // Some targets keep addresses scalar. if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing()) @@ -6743,7 +6815,7 @@ LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I, for (auto *V : filterExtractingOperands(Ops, VF)) Tys.push_back(MaybeVectorizeType(V->getType(), VF)); return Cost + TTI.getOperandsScalarizationOverhead( - filterExtractingOperands(Ops, VF), Tys); + filterExtractingOperands(Ops, VF), Tys, CostKind); } void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) { @@ -6765,29 +6837,47 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) { NumPredStores++; if (Legal->isUniformMemOp(I)) { - // Lowering story for uniform memory ops is currently a bit complicated. - // Scalarization works for everything which isn't a store with scalable - // VF. Fixed len VFs just scalarize and then DCE later; scalarization - // knows how to handle uniform-per-part values (i.e. the first lane - // in each unrolled VF) and can thus handle scalable loads too. For - // scalable stores, we use a scatter if legal. If not, we have no way - // to lower (currently) and thus have to abort vectorization. - if (isa<StoreInst>(&I) && VF.isScalable()) { - if (isLegalGatherOrScatter(&I, VF)) - setWideningDecision(&I, VF, CM_GatherScatter, - getGatherScatterCost(&I, VF)); - else - // Error case, abort vectorization - setWideningDecision(&I, VF, CM_Scalarize, - InstructionCost::getInvalid()); - continue; - } + auto isLegalToScalarize = [&]() { + if (!VF.isScalable()) + // Scalarization of fixed length vectors "just works". + return true; + + // We have dedicated lowering for unpredicated uniform loads and + // stores. Note that even with tail folding we know that at least + // one lane is active (i.e. generalized predication is not possible + // here), and the logic below depends on this fact. + if (!foldTailByMasking()) + return true; + + // For scalable vectors, a uniform memop load is always + // uniform-by-parts and we know how to scalarize that. 
+ if (isa<LoadInst>(I)) + return true; + + // A uniform store isn't necessarily uniform-by-part + // and we can't assume scalarization. + auto &SI = cast<StoreInst>(I); + return TheLoop->isLoopInvariant(SI.getValueOperand()); + }; + + const InstructionCost GatherScatterCost = + isLegalGatherOrScatter(&I, VF) ? + getGatherScatterCost(&I, VF) : InstructionCost::getInvalid(); + // Load: Scalar load + broadcast // Store: Scalar store + isLoopInvariantStoreValue ? 0 : extract - // TODO: Avoid replicating loads and stores instead of relying on - // instcombine to remove them. - setWideningDecision(&I, VF, CM_Scalarize, - getUniformMemOpCost(&I, VF)); + // FIXME: This cost is a significant under-estimate for tail folded + // memory ops. + const InstructionCost ScalarizationCost = isLegalToScalarize() ? + getUniformMemOpCost(&I, VF) : InstructionCost::getInvalid(); + + // Choose the better solution for the current VF. Note that Invalid + // costs compare as maximally large. If both are invalid, we get a + // scalable invalid cost, which signals a failure and a vectorization abort. + if (GatherScatterCost < ScalarizationCost) + setWideningDecision(&I, VF, CM_GatherScatter, GatherScatterCost); + else + setWideningDecision(&I, VF, CM_Scalarize, ScalarizationCost); continue; } @@ -6982,7 +7072,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, VectorType::get(IntegerType::getInt1Ty(RetTy->getContext()), VF); return ( TTI.getScalarizationOverhead( - Vec_i1Ty, APInt::getAllOnes(VF.getFixedValue()), false, true) + + Vec_i1Ty, APInt::getAllOnes(VF.getFixedValue()), + /*Insert*/ false, /*Extract*/ true, CostKind) + (TTI.getCFInstrCost(Instruction::Br, CostKind) * VF.getFixedValue())); } else if (I->getParent() == TheLoop->getLoopLatch() || VF.isScalar()) // The back-edge branch will remain, as will all scalar branches. @@ -6998,11 +7089,13 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, auto *Phi = cast<PHINode>(I); // First-order recurrences are replaced by vector shuffles inside the loop. - // NOTE: Don't use ToVectorTy as SK_ExtractSubvector expects a vector type. - if (VF.isVector() && Legal->isFirstOrderRecurrence(Phi)) - return TTI.getShuffleCost( - TargetTransformInfo::SK_ExtractSubvector, cast<VectorType>(VectorTy), - None, VF.getKnownMinValue() - 1, FixedVectorType::get(RetTy, 1)); + if (VF.isVector() && Legal->isFixedOrderRecurrence(Phi)) { + SmallVector<int> Mask(VF.getKnownMinValue()); + std::iota(Mask.begin(), Mask.end(), VF.getKnownMinValue() - 1); + return TTI.getShuffleCost(TargetTransformInfo::SK_Splice, + cast<VectorType>(VectorTy), Mask, CostKind, + VF.getKnownMinValue() - 1); + } // Phi nodes in non-header blocks (not inductions, reductions, etc.) are // converted into select instructions. We require N - 1 selects per phi @@ -7020,34 +7113,13 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, case Instruction::SDiv: case Instruction::URem: case Instruction::SRem: - // If we have a predicated instruction, it may not be executed for each - // vector lane. Get the scalarization cost and scale this amount by the - // probability of executing the predicated block. If the instruction is not - // predicated, we fall through to the next case. - if (VF.isVector() && isScalarWithPredication(I, VF)) { - InstructionCost Cost = 0; - - // These instructions have a non-void type, so account for the phi nodes - // that we will create. This cost is likely to be zero.
The phi node - // cost, if any, should be scaled by the block probability because it - // models a copy at the end of each predicated block. - Cost += VF.getKnownMinValue() * - TTI.getCFInstrCost(Instruction::PHI, CostKind); - - // The cost of the non-predicated instruction. - Cost += VF.getKnownMinValue() * - TTI.getArithmeticInstrCost(I->getOpcode(), RetTy, CostKind); - - // The cost of insertelement and extractelement instructions needed for - // scalarization. - Cost += getScalarizationOverhead(I, VF); - - // Scale the cost by the probability of executing the predicated blocks. - // This assumes the predicated block for each vector lane is equally - // likely. - return Cost / getReciprocalPredBlockProb(); + if (VF.isVector() && isPredicatedInst(I)) { + const auto [ScalarCost, SafeDivisorCost] = getDivRemSpeculationCost(I, VF); + return isDivRemScalarWithPredication(ScalarCost, SafeDivisorCost) ? + ScalarCost : SafeDivisorCost; } - LLVM_FALLTHROUGH; + // We've proven all lanes safe to speculate, fall through. + [[fallthrough]]; case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -7073,22 +7145,22 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, // Certain instructions can be cheaper to vectorize if they have a constant // second vector operand. One example of this are shifts on x86. Value *Op2 = I->getOperand(1); - TargetTransformInfo::OperandValueProperties Op2VP; - TargetTransformInfo::OperandValueKind Op2VK = - TTI.getOperandInfo(Op2, Op2VP); - if (Op2VK == TargetTransformInfo::OK_AnyValue && Legal->isUniform(Op2)) - Op2VK = TargetTransformInfo::OK_UniformValue; + auto Op2Info = TTI.getOperandInfo(Op2); + if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && Legal->isUniform(Op2)) + Op2Info.Kind = TargetTransformInfo::OK_UniformValue; SmallVector<const Value *, 4> Operands(I->operand_values()); return TTI.getArithmeticInstrCost( - I->getOpcode(), VectorTy, CostKind, TargetTransformInfo::OK_AnyValue, - Op2VK, TargetTransformInfo::OP_None, Op2VP, Operands, I); + I->getOpcode(), VectorTy, CostKind, + {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None}, + Op2Info, Operands, I); } case Instruction::FNeg: { return TTI.getArithmeticInstrCost( - I->getOpcode(), VectorTy, CostKind, TargetTransformInfo::OK_AnyValue, - TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None, - TargetTransformInfo::OP_None, I->getOperand(0), I); + I->getOpcode(), VectorTy, CostKind, + {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None}, + {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None}, + I->getOperand(0), I); } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); @@ -7101,17 +7173,15 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, match(I, m_LogicalOr(m_Value(Op0), m_Value(Op1))))) { // select x, y, false --> x & y // select x, true, y --> x | y - TTI::OperandValueProperties Op1VP = TTI::OP_None; - TTI::OperandValueProperties Op2VP = TTI::OP_None; - TTI::OperandValueKind Op1VK = TTI::getOperandInfo(Op0, Op1VP); - TTI::OperandValueKind Op2VK = TTI::getOperandInfo(Op1, Op2VP); + const auto [Op1VK, Op1VP] = TTI::getOperandInfo(Op0); + const auto [Op2VK, Op2VP] = TTI::getOperandInfo(Op1); assert(Op0->getType()->getScalarSizeInBits() == 1 && Op1->getType()->getScalarSizeInBits() == 1); SmallVector<const Value *, 2> Operands{Op0, Op1}; return TTI.getArithmeticInstrCost( match(I, m_LogicalOr()) ? 
Instruction::Or : Instruction::And, VectorTy, - CostKind, Op1VK, Op2VK, Op1VP, Op2VP, Operands, I); + CostKind, {Op1VK, Op1VP}, {Op2VK, Op2VP}, Operands, I); } Type *CondTy = SI->getCondition()->getType(); @@ -7153,7 +7223,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, case Instruction::BitCast: if (I->getType()->isPointerTy()) return 0; - LLVM_FALLTHROUGH; + [[fallthrough]]; case Instruction::ZExt: case Instruction::SExt: case Instruction::FPToUI: @@ -7262,7 +7332,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, // the result would need to be a vector of pointers. if (VF.isScalable()) return InstructionCost::getInvalid(); - LLVM_FALLTHROUGH; + [[fallthrough]]; default: // This opcode is unknown. Assume that it is the same as 'mul'. return TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); @@ -7276,7 +7346,6 @@ static const char lv_name[] = "Loop Vectorization"; INITIALIZE_PASS_BEGIN(LoopVectorize, LV_NAME, lv_name, false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) @@ -7317,14 +7386,14 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { // Ignore type-promoting instructions we identified during reduction // detection. - for (auto &Reduction : Legal->getReductionVars()) { + for (const auto &Reduction : Legal->getReductionVars()) { const RecurrenceDescriptor &RedDes = Reduction.second; const SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts(); VecValuesToIgnore.insert(Casts.begin(), Casts.end()); } // Ignore type-casting instructions we identified during induction // detection. - for (auto &Induction : Legal->getInductionVars()) { + for (const auto &Induction : Legal->getInductionVars()) { const InductionDescriptor &IndDes = Induction.second; const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts(); VecValuesToIgnore.insert(Casts.begin(), Casts.end()); @@ -7332,7 +7401,7 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { } void LoopVectorizationCostModel::collectInLoopReductions() { - for (auto &Reduction : Legal->getReductionVars()) { + for (const auto &Reduction : Legal->getReductionVars()) { PHINode *Phi = Reduction.first; const RecurrenceDescriptor &RdxDesc = Reduction.second; @@ -7394,7 +7463,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { if (UserVF.isZero()) { VF = ElementCount::getFixed(determineVPlanVF( TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) - .getFixedSize(), + .getFixedValue(), CM)); LLVM_DEBUG(dbgs() << "LV: VPlan computed VF " << VF << ".\n"); @@ -7425,12 +7494,12 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { return VectorizationFactor::Disabled(); } -Optional<VectorizationFactor> +std::optional<VectorizationFactor> LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { assert(OrigLoop->isInnermost() && "Inner loop expected."); FixedScalableVFPair MaxFactors = CM.computeMaxVF(UserVF, UserIC); if (!MaxFactors) // Cases that should not to be vectorized nor interleaved. - return None; + return std::nullopt; // Invalidate interleave groups if all blocks of loop will be predicated. 
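Throughout these hunks llvm::Optional is replaced by std::optional; a minimal standalone sketch of the idiom (hypothetical names, only the C++17 standard library assumed):

    #include <optional>

    // Bail out with std::nullopt where None was returned before, and
    // dereference with * instead of .value() once presence has been checked.
    std::optional<unsigned> pickVF(bool CanVectorize) {
      if (!CanVectorize)
        return std::nullopt; // was: return None;
      return 4u;
    }

    unsigned chooseVF(bool CanVectorize) {
      if (std::optional<unsigned> VF = pickVF(CanVectorize))
        return *VF;          // was: VF.value()
      return 1;
    }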
if (CM.blockNeedsPredicationForAnyReason(OrigLoop->getHeader()) && @@ -7550,9 +7619,26 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, InnerLoopVectorizer &ILV, DominatorTree *DT, bool IsEpilogueVectorization) { + assert(BestVPlan.hasVF(BestVF) && + "Trying to execute plan with unsupported VF"); + assert(BestVPlan.hasUF(BestUF) && + "Trying to execute plan with unsupported UF"); + LLVM_DEBUG(dbgs() << "Executing best plan with VF=" << BestVF << ", UF=" << BestUF << '\n'); + // Workaround! Compute the trip count of the original loop and cache it + // before we start modifying the CFG. This code has a systemic problem + // wherein it tries to run analysis over partially constructed IR; this is + // wrong, and not simply for SCEV. The trip count of the original loop + // simply happens to be prone to hitting this in practice. In theory, we + // can hit the same issue for any SCEV, or ValueTracking query done during + // mutation. See PR49900. + ILV.getOrCreateTripCount(OrigLoop->getLoopPreheader()); + + if (!IsEpilogueVectorization) + VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE); + // Perform the actual loop transformation. // 1. Set up the skeleton for vectorization, including vector pre-header and @@ -7602,7 +7688,7 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, // replace the vectorizer-specific hints below). MDNode *OrigLoopID = OrigLoop->getLoopID(); - Optional<MDNode *> VectorizedLoopID = + std::optional<MDNode *> VectorizedLoopID = makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, LLVMLoopVectorizeFollowupVectorized}); @@ -7610,7 +7696,7 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, BestVPlan.getVectorLoopRegion()->getEntryBasicBlock(); Loop *L = LI->getLoopFor(State.CFG.VPBB2IRBB[HeaderVPBB]); if (VectorizedLoopID) - L->setLoopID(VectorizedLoopID.value()); + L->setLoopID(*VectorizedLoopID); else { // Keep all loop hints from the original loop on the vector loop (we'll // replace the vectorizer-specific hints below). @@ -7620,9 +7706,7 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, LoopVectorizeHints Hints(L, true, *ORE); Hints.setAlreadyVectorized(); } - // Disable runtime unrolling when vectorizing the epilogue loop. - if (CanonicalIVStartValue) - AddRuntimeUnrollDisableMetaData(L); + AddRuntimeUnrollDisableMetaData(L); // 3. Fix the vectorized code: take care of header phi's, live-outs, // predication, updating analyses. @@ -7651,16 +7735,6 @@ Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; } /// depicted in https://llvm.org/docs/Vectorizers.html#epilogue-vectorization. std::pair<BasicBlock *, Value *> EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton() { - MDNode *OrigLoopID = OrigLoop->getLoopID(); - - // Workaround! Compute the trip count of the original loop and cache it - // before we start modifying the CFG. This code has a systemic problem - // wherein it tries to run analysis over partially constructed IR; this is - // wrong, and not simply for SCEV. The trip count of the original loop - // simply happens to be prone to hitting this in practice. In theory, we - // can hit the same issue for any SCEV, or ValueTracking query done during - // mutation. See PR49900. 
- getOrCreateTripCount(OrigLoop->getLoopPreheader()); createVectorLoopSkeleton(""); // Generate the code to check the minimum iteration count of the vector @@ -7691,11 +7765,11 @@ EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton() { EPI.VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); // Skip induction resume value creation here because they will be created in - // the second pass. If we created them here, they wouldn't be used anyway, - // because the vplan in the second pass still contains the inductions from the - // original loop. + // the second pass for the scalar loop. The induction resume values for the + // inductions in the epilogue loop are created before executing the plan for + // the epilogue loop. - return {completeLoopSkeleton(OrigLoopID), nullptr}; + return {completeLoopSkeleton(), nullptr}; } void EpilogueVectorizerMainLoop::printDebugTracesAtStart() { @@ -7779,7 +7853,6 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, /// depicted in https://llvm.org/docs/Vectorizers.html#epilogue-vectorization. std::pair<BasicBlock *, Value *> EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { - MDNode *OrigLoopID = OrigLoop->getLoopID(); createVectorLoopSkeleton("vec.epilog."); // Now, compare the remaining count and if there aren't enough iterations to @@ -7825,31 +7898,40 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { DT->changeImmediateDominator(LoopExitBlock, EPI.EpilogueIterationCountCheck); - // Keep track of bypass blocks, as they feed start values to the induction - // phis in the scalar loop preheader. + // Keep track of bypass blocks, as they feed start values to the induction and + // reduction phis in the scalar loop preheader. if (EPI.SCEVSafetyCheck) LoopBypassBlocks.push_back(EPI.SCEVSafetyCheck); if (EPI.MemSafetyCheck) LoopBypassBlocks.push_back(EPI.MemSafetyCheck); LoopBypassBlocks.push_back(EPI.EpilogueIterationCountCheck); - // The vec.epilog.iter.check block may contain Phi nodes from reductions which - // merge control-flow from the latch block and the middle block. Update the - // incoming values here and move the Phi into the preheader. + // The vec.epilog.iter.check block may contain Phi nodes from inductions or + // reductions which merge control-flow from the latch block and the middle + // block. Update the incoming values here and move the Phi into the preheader. SmallVector<PHINode *, 4> PhisInBlock; for (PHINode &Phi : VecEpilogueIterationCountCheck->phis()) PhisInBlock.push_back(&Phi); for (PHINode *Phi : PhisInBlock) { + Phi->moveBefore(LoopVectorPreHeader->getFirstNonPHI()); Phi->replaceIncomingBlockWith( VecEpilogueIterationCountCheck->getSinglePredecessor(), VecEpilogueIterationCountCheck); + + // If the phi doesn't have an incoming value from the + // EpilogueIterationCountCheck, we are done. Otherwise remove the incoming + // value and also those from other check blocks. This is needed for + // reduction phis only. 
+ if (none_of(Phi->blocks(), [&](BasicBlock *IncB) { + return EPI.EpilogueIterationCountCheck == IncB; + })) + continue; Phi->removeIncomingValue(EPI.EpilogueIterationCountCheck); if (EPI.SCEVSafetyCheck) Phi->removeIncomingValue(EPI.SCEVSafetyCheck); if (EPI.MemSafetyCheck) Phi->removeIncomingValue(EPI.MemSafetyCheck); - Phi->moveBefore(LoopVectorPreHeader->getFirstNonPHI()); } // Generate a resume induction for the vector epilogue and put it in the @@ -7871,7 +7953,7 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { createInductionResumeValues({VecEpilogueIterationCountCheck, EPI.VectorTripCount} /* AdditionalBypass */); - return {completeLoopSkeleton(OrigLoopID), EPResumeVal}; + return {completeLoopSkeleton(), EPResumeVal}; } BasicBlock * @@ -8149,9 +8231,18 @@ VPRecipeBase *VPRecipeBuilder::tryToOptimizeInductionPHI( *PSE.getSE(), *OrigLoop, Range); // Check if this is pointer induction. If so, build the recipe for it. - if (auto *II = Legal->getPointerInductionDescriptor(Phi)) - return new VPWidenPointerInductionRecipe(Phi, Operands[0], *II, - *PSE.getSE()); + if (auto *II = Legal->getPointerInductionDescriptor(Phi)) { + VPValue *Step = vputils::getOrCreateVPValueForSCEVExpr(Plan, II->getStep(), + *PSE.getSE()); + assert(isa<SCEVConstant>(II->getStep())); + return new VPWidenPointerInductionRecipe( + Phi, Operands[0], Step, *II, + LoopVectorizationPlanner::getDecisionAndClampRange( + [&](ElementCount VF) { + return CM.isScalarAfterVectorization(Phi, VF); + }, + Range)); + } return nullptr; } @@ -8188,12 +8279,8 @@ VPRecipeOrVPValueTy VPRecipeBuilder::tryToBlend(PHINode *Phi, VPlanPtr &Plan) { // If all incoming values are equal, the incoming VPValue can be used directly // instead of creating a new VPBlendRecipe. - VPValue *FirstIncoming = Operands[0]; - if (all_of(Operands, [FirstIncoming](const VPValue *Inc) { - return FirstIncoming == Inc; - })) { + if (llvm::all_equal(Operands)) return Operands[0]; - } unsigned NumIncoming = Phi->getNumIncomingValues(); // For in-loop reductions, we do not need to create an additional select. @@ -8252,24 +8339,42 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, ID == Intrinsic::experimental_noalias_scope_decl)) return nullptr; - auto willWiden = [&](ElementCount VF) -> bool { - Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - // The following case may be scalarized depending on the VF. - // The flag shows whether we use Intrinsic or a usual Call for vectorized - // version of the instruction. - // Is it beneficial to perform intrinsic call compared to lib call? - bool NeedToScalarize = false; - InstructionCost CallCost = CM.getVectorCallCost(CI, VF, NeedToScalarize); - InstructionCost IntrinsicCost = ID ? CM.getVectorIntrinsicCost(CI, VF) : 0; - bool UseVectorIntrinsic = ID && IntrinsicCost <= CallCost; - return UseVectorIntrinsic || !NeedToScalarize; - }; + ArrayRef<VPValue *> Ops = Operands.take_front(CI->arg_size()); - if (!LoopVectorizationPlanner::getDecisionAndClampRange(willWiden, Range)) - return nullptr; + // Is it beneficial to perform intrinsic call compared to lib call? + bool ShouldUseVectorIntrinsic = + ID && LoopVectorizationPlanner::getDecisionAndClampRange( + [&](ElementCount VF) -> bool { + bool NeedToScalarize = false; + // Is it beneficial to perform intrinsic call compared to lib + // call? 
+ InstructionCost CallCost = + CM.getVectorCallCost(CI, VF, NeedToScalarize); + InstructionCost IntrinsicCost = + CM.getVectorIntrinsicCost(CI, VF); + return IntrinsicCost <= CallCost; + }, + Range); + if (ShouldUseVectorIntrinsic) + return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), ID); + + // Is it better to call a vectorized version of the function than to scalarize + // the call? + auto ShouldUseVectorCall = LoopVectorizationPlanner::getDecisionAndClampRange( + [&](ElementCount VF) -> bool { + // The following case may be scalarized depending on the VF. + // The flag shows whether we can use a usual Call for vectorized + // version of the instruction. + bool NeedToScalarize = false; + CM.getVectorCallCost(CI, VF, NeedToScalarize); + return !NeedToScalarize; + }, + Range); + if (ShouldUseVectorCall) + return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), + Intrinsic::not_intrinsic); - ArrayRef<VPValue *> Ops = Operands.take_front(CI->arg_size()); - return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end())); + return nullptr; } bool VPRecipeBuilder::shouldWiden(Instruction *I, VFRange &Range) const { @@ -8286,55 +8391,65 @@ bool VPRecipeBuilder::shouldWiden(Instruction *I, VFRange &Range) const { Range); } -VPWidenRecipe *VPRecipeBuilder::tryToWiden(Instruction *I, - ArrayRef<VPValue *> Operands) const { - auto IsVectorizableOpcode = [](unsigned Opcode) { - switch (Opcode) { - case Instruction::Add: - case Instruction::And: - case Instruction::AShr: - case Instruction::BitCast: - case Instruction::FAdd: - case Instruction::FCmp: - case Instruction::FDiv: - case Instruction::FMul: - case Instruction::FNeg: - case Instruction::FPExt: - case Instruction::FPToSI: - case Instruction::FPToUI: - case Instruction::FPTrunc: - case Instruction::FRem: - case Instruction::FSub: - case Instruction::ICmp: - case Instruction::IntToPtr: - case Instruction::LShr: - case Instruction::Mul: - case Instruction::Or: - case Instruction::PtrToInt: - case Instruction::SDiv: - case Instruction::Select: - case Instruction::SExt: - case Instruction::Shl: - case Instruction::SIToFP: - case Instruction::SRem: - case Instruction::Sub: - case Instruction::Trunc: - case Instruction::UDiv: - case Instruction::UIToFP: - case Instruction::URem: - case Instruction::Xor: - case Instruction::ZExt: - case Instruction::Freeze: - return true; +VPRecipeBase *VPRecipeBuilder::tryToWiden(Instruction *I, + ArrayRef<VPValue *> Operands, + VPBasicBlock *VPBB, VPlanPtr &Plan) { + switch (I->getOpcode()) { + default: + return nullptr; + case Instruction::SDiv: + case Instruction::UDiv: + case Instruction::SRem: + case Instruction::URem: { + // If not provably safe, use a select to form a safe divisor before widening the + // div/rem operation itself. Otherwise fall through to general handling below.
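A scalar model of the safe-divisor idea described in the comment above (illustrative only; the real transform builds a VPInstruction select, as the following hunk shows):

    #include <cstdint>

    // In lanes where the mask is off, the divisor may be zero or poison, so a
    // select substitutes 1 before the division; masked-off lanes discard the
    // result, so the substituted quotient is never observed.
    int32_t safeDiv(int32_t Num, int32_t Den, bool LaneActive) {
      int32_t SafeDen = LaneActive ? Den : 1; // select(mask, divisor, 1)
      int32_t Quot = Num / SafeDen;           // now safe to execute in every lane
      return LaneActive ? Quot : 0;           // inactive lanes ignore Quot
    }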
+ if (CM.isPredicatedInst(I)) { + SmallVector<VPValue *> Ops(Operands.begin(), Operands.end()); + VPValue *Mask = createBlockInMask(I->getParent(), Plan); + VPValue *One = + Plan->getOrAddExternalDef(ConstantInt::get(I->getType(), 1u, false)); + auto *SafeRHS = + new VPInstruction(Instruction::Select, {Mask, Ops[1], One}, + I->getDebugLoc()); + VPBB->appendRecipe(SafeRHS); + Ops[1] = SafeRHS; + return new VPWidenRecipe(*I, make_range(Ops.begin(), Ops.end())); } - return false; + LLVM_FALLTHROUGH; + } + case Instruction::Add: + case Instruction::And: + case Instruction::AShr: + case Instruction::BitCast: + case Instruction::FAdd: + case Instruction::FCmp: + case Instruction::FDiv: + case Instruction::FMul: + case Instruction::FNeg: + case Instruction::FPExt: + case Instruction::FPToSI: + case Instruction::FPToUI: + case Instruction::FPTrunc: + case Instruction::FRem: + case Instruction::FSub: + case Instruction::ICmp: + case Instruction::IntToPtr: + case Instruction::LShr: + case Instruction::Mul: + case Instruction::Or: + case Instruction::PtrToInt: + case Instruction::Select: + case Instruction::SExt: + case Instruction::Shl: + case Instruction::SIToFP: + case Instruction::Sub: + case Instruction::Trunc: + case Instruction::UIToFP: + case Instruction::Xor: + case Instruction::ZExt: + case Instruction::Freeze: + return new VPWidenRecipe(*I, make_range(Operands.begin(), Operands.end())); }; - - if (!IsVectorizableOpcode(I->getOpcode())) - return nullptr; - - // Success: widen this instruction. - return new VPWidenRecipe(*I, make_range(Operands.begin(), Operands.end())); } void VPRecipeBuilder::fixHeaderPhis() { @@ -8354,9 +8469,7 @@ VPBasicBlock *VPRecipeBuilder::handleReplication( [&](ElementCount VF) { return CM.isUniformAfterVectorization(I, VF); }, Range); - bool IsPredicated = LoopVectorizationPlanner::getDecisionAndClampRange( - [&](ElementCount VF) { return CM.isPredicatedInst(I, VF); }, - Range); + bool IsPredicated = CM.isPredicatedInst(I); // Even if the instruction is not marked as uniform, there are certain // intrinsic calls that can be effectively treated as such, so we check for @@ -8396,11 +8509,12 @@ VPBasicBlock *VPRecipeBuilder::handleReplication( // value. Avoid hoisting the insert-element which packs the scalar value into // a vector value, as that happens iff all users use the vector value. for (VPValue *Op : Recipe->operands()) { - auto *PredR = dyn_cast_or_null<VPPredInstPHIRecipe>(Op->getDef()); + auto *PredR = + dyn_cast_or_null<VPPredInstPHIRecipe>(Op->getDefiningRecipe()); if (!PredR) continue; - auto *RepR = - cast_or_null<VPReplicateRecipe>(PredR->getOperand(0)->getDef()); + auto *RepR = cast<VPReplicateRecipe>( + PredR->getOperand(0)->getDefiningRecipe()); assert(RepR->isPredicated() && "expected Replicate recipe to be predicated"); RepR->setAlsoPack(false); @@ -8469,20 +8583,26 @@ VPRecipeBuilder::createReplicateRegion(VPReplicateRecipe *PredRecipe, VPRecipeOrVPValueTy VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, ArrayRef<VPValue *> Operands, - VFRange &Range, VPlanPtr &Plan) { + VFRange &Range, VPBasicBlock *VPBB, + VPlanPtr &Plan) { // First, check for specific widening recipes that deal with inductions, Phi // nodes, calls and memory operations. VPRecipeBase *Recipe; if (auto Phi = dyn_cast<PHINode>(Instr)) { if (Phi->getParent() != OrigLoop->getHeader()) return tryToBlend(Phi, Operands, Plan); + + // Always record recipes for header phis. Later first-order recurrence phis + // can have earlier phis as incoming values. 
+ recordRecipeOf(Phi); + if ((Recipe = tryToOptimizeInductionPHI(Phi, Operands, *Plan, Range))) return toVPRecipeResult(Recipe); VPHeaderPHIRecipe *PhiRecipe = nullptr; assert((Legal->isReductionVariable(Phi) || - Legal->isFirstOrderRecurrence(Phi)) && - "can only widen reductions and first-order recurrences here"); + Legal->isFixedOrderRecurrence(Phi)) && + "can only widen reductions and fixed-order recurrences here"); VPValue *StartV = Operands[0]; if (Legal->isReductionVariable(Phi)) { const RecurrenceDescriptor &RdxDesc = @@ -8493,13 +8613,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, CM.isInLoopReduction(Phi), CM.useOrderedReductions(RdxDesc)); } else { + // TODO: Currently fixed-order recurrences are modeled as chains of + // first-order recurrences. If there are no users of the intermediate + // recurrences in the chain, the fixed order recurrence should be modeled + // directly, enabling more efficient codegen. PhiRecipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *StartV); } // Record the incoming value from the backedge, so we can add the incoming // value from the backedge after all recipes have been created. - recordRecipeOf(cast<Instruction>( - Phi->getIncomingValueForBlock(OrigLoop->getLoopLatch()))); + auto *Inc = cast<Instruction>( + Phi->getIncomingValueForBlock(OrigLoop->getLoopLatch())); + auto RecipeIter = Ingredient2Recipe.find(Inc); + if (RecipeIter == Ingredient2Recipe.end()) + recordRecipeOf(Inc); + PhisToFix.push_back(PhiRecipe); return toVPRecipeResult(PhiRecipe); } @@ -8534,7 +8662,7 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, *SI, make_range(Operands.begin(), Operands.end()), InvariantCond)); } - return toVPRecipeResult(tryToWiden(Instr, Operands)); + return toVPRecipeResult(tryToWiden(Instr, Operands, VPBB, Plan)); } void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, @@ -8564,7 +8692,7 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, assert( SinkTarget != FirstInst && "Must find a live instruction (at least the one feeding the " - "first-order recurrence PHI) before reaching beginning of the block"); + "fixed-order recurrence PHI) before reaching beginning of the block"); SinkTarget = SinkTarget->getPrevNode(); assert(SinkTarget != P.first && "sink source equals target, no sinking required"); @@ -8696,18 +8824,18 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // Mark instructions we'll need to sink later and their targets as // ingredients whose recipe we'll need to record. - for (auto &Entry : SinkAfter) { + for (const auto &Entry : SinkAfter) { RecipeBuilder.recordRecipeOf(Entry.first); RecipeBuilder.recordRecipeOf(Entry.second); } - for (auto &Reduction : CM.getInLoopReductionChains()) { + for (const auto &Reduction : CM.getInLoopReductionChains()) { PHINode *Phi = Reduction.first; RecurKind Kind = Legal->getReductionVars().find(Phi)->second.getRecurrenceKind(); const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second; RecipeBuilder.recordRecipeOf(Phi); - for (auto &R : ReductionOperations) { + for (const auto &R : ReductionOperations) { RecipeBuilder.recordRecipeOf(R); // For min/max reductions, where we have a pair of icmp/select, we also // need to record the ICmp recipe, so it can be removed later. 
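For reference, a hypothetical source loop containing the kind of fixed-order (first-order) recurrence these recipes model; when vectorized, the "previous iteration" value is produced by splicing the last lane of the prior vector iteration onto the current one:

    #include <cstddef>

    // Out[I] depends on the value loaded in iteration I-1, which becomes a
    // header phi (VPFirstOrderRecurrencePHIRecipe) plus a splice shuffle in
    // the vector loop.
    void recurrence(const int *A, int *Out, std::size_t N) {
      int Prev = 0;
      for (std::size_t I = 0; I < N; ++I) {
        int Cur = A[I];
        Out[I] = Cur - Prev; // uses the previous iteration's value
        Prev = Cur;
      }
    }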
@@ -8805,14 +8933,14 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( continue; if (auto RecipeOrValue = RecipeBuilder.tryToCreateWidenRecipe( - Instr, Operands, Range, Plan)) { + Instr, Operands, Range, VPBB, Plan)) { // If Instr can be simplified to an existing VPValue, use it. if (RecipeOrValue.is<VPValue *>()) { auto *VPV = RecipeOrValue.get<VPValue *>(); Plan->addVPValue(Instr, VPV); // If the re-used value is a recipe, register the recipe for the // instruction, in case the recipe for Instr needs to be recorded. - if (auto *R = dyn_cast_or_null<VPRecipeBase>(VPV->getDef())) + if (VPRecipeBase *R = VPV->getDefiningRecipe()) RecipeBuilder.setRecipe(Instr, R); continue; } @@ -8854,11 +8982,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor()); } - HeaderVPBB->setName("vector.body"); - - // Fold the last, empty block into its predecessor. - VPBB = VPBlockUtils::tryToMergeBlockIntoPredecessor(VPBB); - assert(VPBB && "expected to fold last (empty) block"); // After here, VPBB should not be used. VPBB = nullptr; @@ -8888,7 +9011,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( } return nullptr; }; - for (auto &Entry : SinkAfter) { + for (const auto &Entry : SinkAfter) { VPRecipeBase *Sink = RecipeBuilder.getRecipe(Entry.first); VPRecipeBase *Target = RecipeBuilder.getRecipe(Entry.second); @@ -8949,14 +9072,19 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( RecipeBuilder, Range.Start); // Introduce a recipe to combine the incoming and previous values of a - // first-order recurrence. + // fixed-order recurrence. for (VPRecipeBase &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) { auto *RecurPhi = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R); if (!RecurPhi) continue; - VPRecipeBase *PrevRecipe = RecurPhi->getBackedgeRecipe(); + VPRecipeBase *PrevRecipe = &RecurPhi->getBackedgeRecipe(); + // Fixed-order recurrences do not contain cycles, so this loop is guaranteed + // to terminate. + while (auto *PrevPhi = + dyn_cast<VPFirstOrderRecurrencePHIRecipe>(PrevRecipe)) + PrevRecipe = &PrevPhi->getBackedgeRecipe(); VPBasicBlock *InsertBlock = PrevRecipe->getParent(); auto *Region = GetReplicateRegion(PrevRecipe); if (Region) @@ -8983,7 +9111,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // Interleave memory: for each Interleave Group we marked earlier as relevant // for this VPlan, replace the Recipes widening its memory instructions with a // single VPInterleaveRecipe at its insertion point. - for (auto IG : InterleaveGroups) { + for (const auto *IG : InterleaveGroups) { auto *Recipe = cast<VPWidenMemoryInstructionRecipe>( RecipeBuilder.getRecipe(IG->getInsertPos())); SmallVector<VPValue *, 4> StoredValues; @@ -9011,33 +9139,28 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( } } - std::string PlanName; - raw_string_ostream RSO(PlanName); - ElementCount VF = Range.Start; - Plan->addVF(VF); - RSO << "Initial VPlan for VF={" << VF; - for (VF *= 2; ElementCount::isKnownLT(VF, Range.End); VF *= 2) { + for (ElementCount VF = Range.Start; ElementCount::isKnownLT(VF, Range.End); + VF *= 2) Plan->addVF(VF); - RSO << "," << VF; - } - RSO << "},UF>=1"; - RSO.flush(); - Plan->setName(PlanName); + Plan->setName("Initial VPlan"); // From this point onwards, VPlan-to-VPlan transformations may change the plan // in ways that accessing values using original IR values is incorrect. 
Plan->disableValue2VPValue(); VPlanTransforms::optimizeInductions(*Plan, *PSE.getSE()); - VPlanTransforms::sinkScalarOperands(*Plan); VPlanTransforms::removeDeadRecipes(*Plan); - VPlanTransforms::mergeReplicateRegions(*Plan); - VPlanTransforms::removeRedundantExpandSCEVRecipes(*Plan); - // Fold Exit block into its predecessor if possible. - // TODO: Fold block earlier once all VPlan transforms properly maintain a - // VPBasicBlock as exit. - VPBlockUtils::tryToMergeBlockIntoPredecessor(TopRegion->getExiting()); + bool ShouldSimplify = true; + while (ShouldSimplify) { + ShouldSimplify = VPlanTransforms::sinkScalarOperands(*Plan); + ShouldSimplify |= + VPlanTransforms::mergeReplicateRegionsIntoSuccessors(*Plan); + ShouldSimplify |= VPlanTransforms::mergeBlocksIntoPredecessors(*Plan); + } + + VPlanTransforms::removeRedundantExpandSCEVRecipes(*Plan); + VPlanTransforms::mergeBlocksIntoPredecessors(*Plan); assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid"); return Plan; @@ -9066,7 +9189,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { VPlanTransforms::VPInstructionsToVPRecipes( OrigLoop, Plan, [this](PHINode *P) { return Legal->getIntOrFpInductionDescriptor(P); }, - DeadInstructions, *PSE.getSE()); + DeadInstructions, *PSE.getSE(), *TLI); // Remove the existing terminator of the exiting block of the top-most region. // A BranchOnCount will be added instead when adding the canonical IV recipes. @@ -9087,7 +9210,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { void LoopVectorizationPlanner::adjustRecipesForReductions( VPBasicBlock *LatchVPBB, VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) { - for (auto &Reduction : CM.getInLoopReductionChains()) { + for (const auto &Reduction : CM.getInLoopReductionChains()) { PHINode *Phi = Reduction.first; const RecurrenceDescriptor &RdxDesc = Legal->getReductionVars().find(Phi)->second; @@ -9127,9 +9250,13 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( R->getOperand(FirstOpId) == Chain ? FirstOpId + 1 : FirstOpId; VPValue *VecOp = Plan->getVPValue(R->getOperand(VecOpId)); - auto *CondOp = CM.blockNeedsPredicationForAnyReason(R->getParent()) - ? 
RecipeBuilder.createBlockInMask(R->getParent(), Plan) - : nullptr; + VPValue *CondOp = nullptr; + if (CM.blockNeedsPredicationForAnyReason(R->getParent())) { + VPBuilder::InsertPointGuard Guard(Builder); + Builder.setInsertPoint(WidenRecipe->getParent(), + WidenRecipe->getIterator()); + CondOp = RecipeBuilder.createBlockInMask(R->getParent(), Plan); + } if (IsFMulAdd) { // If the instruction is a call to the llvm.fmuladd intrinsic then we @@ -9179,7 +9306,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( VPValue *Cond = RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), Plan); VPValue *Red = PhiR->getBackedgeValue(); - assert(cast<VPRecipeBase>(Red->getDef())->getParent() != LatchVPBB && + assert(Red->getDefiningRecipe()->getParent() != LatchVPBB && "reduction recipe must be defined before latch"); Builder.createNaryOp(Instruction::Select, {Cond, Red, PhiR}); } @@ -9217,11 +9344,6 @@ void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent, } #endif -void VPWidenCallRecipe::execute(VPTransformState &State) { - State.ILV->widenCallInstruction(*cast<CallInst>(getUnderlyingInstr()), this, - *this, State); -} - void VPWidenIntOrFpInductionRecipe::execute(VPTransformState &State) { assert(!State.Instance && "Int or FP induction being replicated."); @@ -9353,8 +9475,7 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { PartStart, ConstantInt::get(PtrInd->getType(), Lane)); Value *GlobalIdx = State.Builder.CreateAdd(PtrInd, Idx); - Value *Step = CreateStepValue(IndDesc.getStep(), SE, - State.CFG.PrevBB->getTerminator()); + Value *Step = State.get(getOperand(1), VPIteration(0, Part)); Value *SclrGep = emitTransformedIndex( State.Builder, GlobalIdx, IndDesc.getStartValue(), Step, IndDesc); SclrGep->setName("next.gep"); @@ -9378,12 +9499,9 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { NewPointerPhi->addIncoming(ScalarStartValue, VectorPH); // A pointer induction, performed by using a gep - const DataLayout &DL = NewPointerPhi->getModule()->getDataLayout(); Instruction *InductionLoc = &*State.Builder.GetInsertPoint(); - const SCEV *ScalarStep = IndDesc.getStep(); - SCEVExpander Exp(SE, DL, "induction"); - Value *ScalarStepValue = Exp.expandCodeFor(ScalarStep, PhiType, InductionLoc); + Value *ScalarStepValue = State.get(getOperand(1), VPIteration(0, 0)); Value *RuntimeVF = getRuntimeVF(State.Builder, PhiType, State.VF); Value *NumUnrolledElems = State.Builder.CreateMul(RuntimeVF, ConstantInt::get(PhiType, State.UF)); @@ -9411,6 +9529,8 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { StartOffset = State.Builder.CreateAdd( StartOffset, State.Builder.CreateStepVector(VecPhiType)); + assert(ScalarStepValue == State.get(getOperand(1), VPIteration(0, Part)) && + "scalar step must be the same across all parts"); Value *GEP = State.Builder.CreateGEP( IndDesc.getElementType(), NewPointerPhi, State.Builder.CreateMul( @@ -9421,8 +9541,8 @@ void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { } } -void VPScalarIVStepsRecipe::execute(VPTransformState &State) { - assert(!State.Instance && "VPScalarIVStepsRecipe being replicated."); +void VPDerivedIVRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "VPDerivedIVRecipe being replicated."); // Fast-math-flags propagate from the original induction instruction. 
IRBuilder<>::FastMathFlagGuard FMFG(State.Builder); @@ -9432,52 +9552,33 @@ void VPScalarIVStepsRecipe::execute(VPTransformState &State) { IndDesc.getInductionBinOp()->getFastMathFlags()); Value *Step = State.get(getStepValue(), VPIteration(0, 0)); - auto CreateScalarIV = [&](Value *&Step) -> Value * { - Value *ScalarIV = State.get(getCanonicalIV(), VPIteration(0, 0)); - auto *CanonicalIV = State.get(getParent()->getPlan()->getCanonicalIV(), 0); - if (!isCanonical() || CanonicalIV->getType() != Ty) { - ScalarIV = - Ty->isIntegerTy() - ? State.Builder.CreateSExtOrTrunc(ScalarIV, Ty) - : State.Builder.CreateCast(Instruction::SIToFP, ScalarIV, Ty); - ScalarIV = emitTransformedIndex(State.Builder, ScalarIV, - getStartValue()->getLiveInIRValue(), Step, - IndDesc); - ScalarIV->setName("offset.idx"); - } - if (TruncToTy) { - assert(Step->getType()->isIntegerTy() && - "Truncation requires an integer step"); - ScalarIV = State.Builder.CreateTrunc(ScalarIV, TruncToTy); - Step = State.Builder.CreateTrunc(Step, TruncToTy); - } - return ScalarIV; - }; - - Value *ScalarIV = CreateScalarIV(Step); - if (State.VF.isVector()) { - buildScalarSteps(ScalarIV, Step, IndDesc, this, State); - return; + Value *CanonicalIV = State.get(getCanonicalIV(), VPIteration(0, 0)); + Value *DerivedIV = + emitTransformedIndex(State.Builder, CanonicalIV, + getStartValue()->getLiveInIRValue(), Step, IndDesc); + DerivedIV->setName("offset.idx"); + if (ResultTy != DerivedIV->getType()) { + assert(Step->getType()->isIntegerTy() && + "Truncation requires an integer step"); + DerivedIV = State.Builder.CreateTrunc(DerivedIV, ResultTy); } + assert(DerivedIV != CanonicalIV && "IV didn't need transforming?"); - for (unsigned Part = 0; Part < State.UF; ++Part) { - assert(!State.VF.isScalable() && "scalable vectors not yet supported."); - Value *EntryPart; - if (Step->getType()->isFloatingPointTy()) { - Value *StartIdx = - getRuntimeVFAsFloat(State.Builder, Step->getType(), State.VF * Part); - // Floating-point operations inherit FMF via the builder's flags. - Value *MulOp = State.Builder.CreateFMul(StartIdx, Step); - EntryPart = State.Builder.CreateBinOp(IndDesc.getInductionOpcode(), - ScalarIV, MulOp); - } else { - Value *StartIdx = - getRuntimeVF(State.Builder, Step->getType(), State.VF * Part); - EntryPart = State.Builder.CreateAdd( - ScalarIV, State.Builder.CreateMul(StartIdx, Step), "induction"); - } - State.set(this, EntryPart, Part); - } + State.set(this, DerivedIV, VPIteration(0, 0)); +} + +void VPScalarIVStepsRecipe::execute(VPTransformState &State) { + // Fast-math-flags propagate from the original induction instruction. + IRBuilder<>::FastMathFlagGuard FMFG(State.Builder); + if (IndDesc.getInductionBinOp() && + isa<FPMathOperator>(IndDesc.getInductionBinOp())) + State.Builder.setFastMathFlags( + IndDesc.getInductionBinOp()->getFastMathFlags()); + + Value *BaseIV = State.get(getOperand(0), VPIteration(0, 0)); + Value *Step = State.get(getStepValue(), VPIteration(0, 0)); + + buildScalarSteps(BaseIV, Step, IndDesc, this, State); } void VPInterleaveRecipe::execute(VPTransformState &State) { @@ -9536,9 +9637,10 @@ void VPReductionRecipe::execute(VPTransformState &State) { } void VPReplicateRecipe::execute(VPTransformState &State) { + Instruction *UI = getUnderlyingInstr(); if (State.Instance) { // Generate a single instance. 
assert(!State.VF.isScalable() && "Can't scalarize a scalable vector"); - State.ILV->scalarizeInstruction(getUnderlyingInstr(), this, *State.Instance, + State.ILV->scalarizeInstruction(UI, this, *State.Instance, IsPredicated, State); // Insert scalar instance packing it into a vector. if (AlsoPack && State.VF.isVector()) { @@ -9546,7 +9648,7 @@ void VPReplicateRecipe::execute(VPTransformState &State) { if (State.Instance->Lane.isFirstLane()) { assert(!State.VF.isScalable() && "VF is assumed to be non scalable."); Value *Poison = PoisonValue::get( - VectorType::get(getUnderlyingValue()->getType(), State.VF)); + VectorType::get(UI->getType(), State.VF)); State.set(this, Poison, State.Instance->Part); } State.ILV->packScalarIntoVectorValue(this, *State.Instance, State); @@ -9555,12 +9657,36 @@ void VPReplicateRecipe::execute(VPTransformState &State) { } if (IsUniform) { + // If the recipe is uniform across all parts (instead of just per VF), only + // generate a single instance. + if ((isa<LoadInst>(UI) || isa<StoreInst>(UI)) && + all_of(operands(), [](VPValue *Op) { + return Op->isDefinedOutsideVectorRegions(); + })) { + State.ILV->scalarizeInstruction(UI, this, VPIteration(0, 0), IsPredicated, + State); + if (user_begin() != user_end()) { + for (unsigned Part = 1; Part < State.UF; ++Part) + State.set(this, State.get(this, VPIteration(0, 0)), + VPIteration(Part, 0)); + } + return; + } + + // Uniform within VL means we need to generate lane 0 only for each + // unrolled copy. for (unsigned Part = 0; Part < State.UF; ++Part) - State.ILV->scalarizeInstruction(getUnderlyingInstr(), this, - VPIteration(Part, 0), IsPredicated, - State); + State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, 0), + IsPredicated, State); + return; + } + + // A store of a loop varying value to a loop invariant address only + // needs the last copy of the store. + if (isa<StoreInst>(UI) && !getOperand(1)->hasDefiningRecipe()) { + auto Lane = VPLane::getLastLaneForVF(State.VF); + State.ILV->scalarizeInstruction(UI, this, VPIteration(State.UF - 1, Lane), IsPredicated, + State); return; } @@ -9569,9 +9695,8 @@ void VPReplicateRecipe::execute(VPTransformState &State) { const unsigned EndLane = State.VF.getKnownMinValue(); for (unsigned Part = 0; Part < State.UF; ++Part) for (unsigned Lane = 0; Lane < EndLane; ++Lane) - State.ILV->scalarizeInstruction(getUnderlyingInstr(), this, - VPIteration(Part, Lane), IsPredicated, - State); + State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, Lane), + IsPredicated, State); } void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { @@ -9709,7 +9834,7 @@ static ScalarEpilogueLowering getScalarEpilogueLowering( Function *F, Loop *L, LoopVectorizeHints &Hints, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI, TargetTransformInfo *TTI, TargetLibraryInfo *TLI, AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, - LoopVectorizationLegality &LVL) { + LoopVectorizationLegality &LVL, InterleavedAccessInfo *IAI) { // 1) OptSize takes precedence over all other options, i.e. if this is set, // don't look at hints or options, and don't request a scalar epilogue. // (For PGSO, as shouldOptimizeForSize isn't currently accessible from @@ -9744,7 +9869,7 @@ static ScalarEpilogueLowering getScalarEpilogueLowering( }; // 4) if the TTI hook indicates this is profitable, request predication.
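An illustrative (hypothetical) source loop for the replicate-recipe case above, where a loop-varying value is stored to a loop-invariant address and only the final copy is observable:

    #include <cstddef>

    // Only the store from the last iteration survives, so the recipe emits
    // just the copy for the last unrolled part and lane.
    void keepLast(const int *A, int *Last, std::size_t N) {
      for (std::size_t I = 0; I < N; ++I)
        *Last = A[I]; // invariant address, varying value
    }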
- if (TTI->preferPredicateOverEpilogue(L, LI, *SE, *AC, TLI, DT, &LVL)) + if (TTI->preferPredicateOverEpilogue(L, LI, *SE, *AC, TLI, DT, &LVL, IAI)) return CM_ScalarEpilogueNotNeededUsePredicate; return CM_ScalarEpilogueAllowed; @@ -9770,15 +9895,14 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part) { return ScalarValue; } - auto *RepR = dyn_cast<VPReplicateRecipe>(Def); - bool IsUniform = RepR && RepR->isUniform(); + bool IsUniform = vputils::isUniformAfterVectorization(Def); unsigned LastLane = IsUniform ? 0 : VF.getKnownMinValue() - 1; // Check if there is a scalar value for the selected lane. if (!hasScalarValue(Def, {Part, LastLane})) { - // At the moment, VPWidenIntOrFpInductionRecipes can also be uniform. - assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDef()) || - isa<VPScalarIVStepsRecipe>(Def->getDef())) && + // At the moment, VPWidenIntOrFpInductionRecipes and VPScalarIVStepsRecipes can also be uniform. + assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDefiningRecipe()) || + isa<VPScalarIVStepsRecipe>(Def->getDefiningRecipe())) && "unexpected recipe found to be invariant"); IsUniform = true; LastLane = 0; @@ -9839,7 +9963,7 @@ static bool processLoopInVPlanNativePath( InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL->getLAI()); ScalarEpilogueLowering SEL = getScalarEpilogueLowering( - F, L, Hints, PSI, BFI, TTI, TLI, AC, LI, PSE.getSE(), DT, *LVL); + F, L, Hints, PSI, BFI, TTI, TLI, AC, LI, PSE.getSE(), DT, *LVL, &IAI); LoopVectorizationCostModel CM(SEL, L, PSE, LI, LVL, *TTI, TLI, DB, AC, ORE, F, &Hints, IAI); @@ -9927,7 +10051,7 @@ static void checkMixedPrecision(Loop *L, OptimizationRemarkEmitter *ORE) { static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks, VectorizationFactor &VF, - Optional<unsigned> VScale, Loop *L, + std::optional<unsigned> VScale, Loop *L, ScalarEvolution &SE) { InstructionCost CheckCost = Checks.getCost(); if (!CheckCost.isValid()) @@ -10075,7 +10199,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements; - LoopVectorizationLegality LVL(L, PSE, DT, TTI, TLI, AA, F, GetLAA, LI, ORE, + LoopVectorizationLegality LVL(L, PSE, DT, TTI, TLI, F, *LAIs, LI, ORE, &Requirements, &Hints, DB, AC, BFI, PSI); if (!LVL.canVectorize(EnableVPlanNativePath)) { LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); @@ -10083,11 +10207,6 @@ bool LoopVectorizePass::processLoop(Loop *L) { return false; } - // Check the function attributes and profiles to find out if this function - // should be optimized for size. - ScalarEpilogueLowering SEL = getScalarEpilogueLowering( - F, L, Hints, PSI, BFI, TTI, TLI, AC, LI, PSE.getSE(), DT, LVL); - // Entrance to the VPlan-native vectorization path. Outer loops are processed // here. They may require CFG and instruction level transformations before // even evaluating whether vectorization is profitable. Since we cannot modify @@ -10099,6 +10218,22 @@ bool LoopVectorizePass::processLoop(Loop *L) { assert(L->isInnermost() && "Inner loop expected."); + InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL.getLAI()); + bool UseInterleaved = TTI->enableInterleavedAccessVectorization(); + + // If an override option has been passed in for interleaved accesses, use it. + if (EnableInterleavedMemAccesses.getNumOccurrences() > 0) + UseInterleaved = EnableInterleavedMemAccesses; + + // Analyze interleaved memory accesses. 
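A hypothetical loop with interleaved (stride-2) accesses of the kind analyzeInterleaving groups, shown purely for illustration:

    #include <cstddef>

    // The two loads of A form one interleave group and can be lowered as a
    // single wide load followed by de-interleaving shuffles.
    void sumPairs(const int *A, int *Out, std::size_t N) {
      for (std::size_t I = 0; I < N; ++I)
        Out[I] = A[2 * I] + A[2 * I + 1];
    }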
+ if (UseInterleaved) + IAI.analyzeInterleaving(useMaskedInterleavedAccesses(*TTI)); + + // Check the function attributes and profiles to find out if this function + // should be optimized for size. + ScalarEpilogueLowering SEL = getScalarEpilogueLowering( + F, L, Hints, PSI, BFI, TTI, TLI, AC, LI, PSE.getSE(), DT, LVL, &IAI); + // Check the loop for a trip count threshold: vectorize loops with a tiny trip // count by optimizing for size, to minimize overheads. auto ExpectedTC = getSmallBestKnownTC(*SE, L); @@ -10109,15 +10244,24 @@ bool LoopVectorizePass::processLoop(Loop *L) { if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) LLVM_DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); else { - LLVM_DEBUG(dbgs() << "\n"); - SEL = CM_ScalarEpilogueNotAllowedLowTripLoop; + if (*ExpectedTC > TTI->getMinTripCountTailFoldingThreshold()) { + LLVM_DEBUG(dbgs() << "\n"); + SEL = CM_ScalarEpilogueNotAllowedLowTripLoop; + } else { + LLVM_DEBUG(dbgs() << " But the target considers the trip count too " + "small to consider vectorizing.\n"); + reportVectorizationFailure( + "The trip count is below the minial threshold value.", + "loop trip count is too low, avoiding vectorization", + "LowTripCount", ORE, L); + Hints.emitRemarkWithHints(); + return false; + } } } - // Check the function attributes to see if implicit floats are allowed. - // FIXME: This check doesn't seem possibly correct -- what if the loop is - // an integer loop and the vector instructions selected are purely integer - // vector instructions? + // Check the function attributes to see if implicit floats or vectors are + // allowed. if (F->hasFnAttribute(Attribute::NoImplicitFloat)) { reportVectorizationFailure( "Can't vectorize when the NoImplicitFloat attribute is used", @@ -10162,18 +10306,6 @@ bool LoopVectorizePass::processLoop(Loop *L) { return false; } - bool UseInterleaved = TTI->enableInterleavedAccessVectorization(); - InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL.getLAI()); - - // If an override option has been passed in for interleaved accesses, use it. - if (EnableInterleavedMemAccesses.getNumOccurrences() > 0) - UseInterleaved = EnableInterleavedMemAccesses; - - // Analyze interleaved memory accesses. - if (UseInterleaved) { - IAI.analyzeInterleaving(useMaskedInterleavedAccesses(*TTI)); - } - // Use the cost model. LoopVectorizationCostModel CM(SEL, L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, &Hints, IAI); @@ -10188,7 +10320,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { unsigned UserIC = Hints.getInterleave(); // Plan how to best vectorize, return the best VF and its cost. - Optional<VectorizationFactor> MaybeVF = LVP.plan(UserVF, UserIC); + std::optional<VectorizationFactor> MaybeVF = LVP.plan(UserVF, UserIC); VectorizationFactor VF = VectorizationFactor::Disabled(); unsigned IC = 1; @@ -10198,7 +10330,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { if (MaybeVF) { VF = *MaybeVF; // Select the interleave count. - IC = CM.selectInterleaveCount(VF.Width, *VF.Cost.getValue()); + IC = CM.selectInterleaveCount(VF.Width, VF.Cost); unsigned SelectedIC = std::max(IC, UserIC); // Optimistically generate runtime checks if they are needed. Drop them if @@ -10360,16 +10492,39 @@ bool LoopVectorizePass::processLoop(Loop *L) { VPBasicBlock *Header = VectorLoop->getEntryBasicBlock(); Header->setName("vec.epilog.vector.body"); - // Ensure that the start values for any VPReductionPHIRecipes are - // updated before vectorising the epilogue loop. 
+ // Ensure that the start values for any VPWidenIntOrFpInductionRecipe, + // VPWidenPointerInductionRecipe and VPReductionPHIRecipes are updated + // before vectorizing the epilogue loop. for (VPRecipeBase &R : Header->phis()) { + if (isa<VPCanonicalIVPHIRecipe>(&R)) + continue; + + Value *ResumeV = nullptr; + // TODO: Move setting of resume values to prepareToExecute. if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) { - if (auto *Resume = MainILV.getReductionResumeValue( - ReductionPhi->getRecurrenceDescriptor())) { - VPValue *StartVal = BestEpiPlan.getOrAddExternalDef(Resume); - ReductionPhi->setOperand(0, StartVal); + ResumeV = MainILV.getReductionResumeValue( + ReductionPhi->getRecurrenceDescriptor()); + } else { + // Create induction resume values for both widened pointer and + // integer/fp inductions and update the start value of the induction + // recipes to use the resume value. + PHINode *IndPhi = nullptr; + const InductionDescriptor *ID; + if (auto *Ind = dyn_cast<VPWidenPointerInductionRecipe>(&R)) { + IndPhi = cast<PHINode>(Ind->getUnderlyingValue()); + ID = &Ind->getInductionDescriptor(); + } else { + auto *WidenInd = cast<VPWidenIntOrFpInductionRecipe>(&R); + IndPhi = WidenInd->getPHINode(); + ID = &WidenInd->getInductionDescriptor(); } + + ResumeV = MainILV.createInductionResumeValue( + IndPhi, *ID, {EPI.MainLoopIterationCountCheck}); } + assert(ResumeV && "Must have a resume value"); + VPValue *StartVal = BestEpiPlan.getOrAddExternalDef(ResumeV); + cast<VPHeaderPHIRecipe>(&R)->setStartValue(StartVal); } LVP.executePlan(EPI.EpilogueVF, EPI.EpilogueUF, BestEpiPlan, EpilogILV, @@ -10407,11 +10562,11 @@ bool LoopVectorizePass::processLoop(Loop *L) { checkMixedPrecision(L, ORE); } - Optional<MDNode *> RemainderLoopID = + std::optional<MDNode *> RemainderLoopID = makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, LLVMLoopVectorizeFollowupEpilogue}); if (RemainderLoopID) { - L->setLoopID(RemainderLoopID.value()); + L->setLoopID(*RemainderLoopID); } else { if (DisableRuntimeUnroll) AddRuntimeUnrollDisableMetaData(L); @@ -10427,8 +10582,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { LoopVectorizeResult LoopVectorizePass::runImpl( Function &F, ScalarEvolution &SE_, LoopInfo &LI_, TargetTransformInfo &TTI_, DominatorTree &DT_, BlockFrequencyInfo &BFI_, TargetLibraryInfo *TLI_, - DemandedBits &DB_, AAResults &AA_, AssumptionCache &AC_, - std::function<const LoopAccessInfo &(Loop &)> &GetLAA_, + DemandedBits &DB_, AssumptionCache &AC_, LoopAccessInfoManager &LAIs_, OptimizationRemarkEmitter &ORE_, ProfileSummaryInfo *PSI_) { SE = &SE_; LI = &LI_; @@ -10436,9 +10590,8 @@ LoopVectorizeResult LoopVectorizePass::runImpl( DT = &DT_; BFI = &BFI_; TLI = TLI_; - AA = &AA_; AC = &AC_; - GetLAA = &GetLAA_; + LAIs = &LAIs_; DB = &DB_; ORE = &ORE_; PSI = PSI_; @@ -10461,7 +10614,7 @@ LoopVectorizeResult LoopVectorizePass::runImpl( // legality and profitability checks. This means running the loop vectorizer // will simplify all loops, regardless of whether anything end up being // vectorized. - for (auto &L : *LI) + for (const auto &L : *LI) Changed |= CFGChanged |= simplifyLoop(L, DT, LI, SE, AC, nullptr, false /* PreserveLCSSA */); @@ -10484,6 +10637,9 @@ LoopVectorizeResult LoopVectorizePass::runImpl( Changed |= formLCSSARecursively(*L, *DT, LI, SE); Changed |= CFGChanged |= processLoop(L); + + if (Changed) + LAIs->clear(); } // Process each loop nest in the function. 
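The start-value fixups above are needed because the epilogue loop's header phis must resume from whatever the main vector loop already produced, rather than from the original start values. A minimal standalone sketch of the resulting loop structure (plain C++; MainVF, EpilogueVF, and the function name are illustrative choices, not what the cost model would pick):

#include <cstddef>

float sumWithEpilogue(const float *A, size_t N) {
  const size_t MainVF = 8, EpilogueVF = 4;
  float Acc = 0.0f;
  size_t I = 0;

  // Main vector loop: covers N rounded down to a multiple of MainVF.
  for (; I + MainVF <= N; I += MainVF)
    for (size_t L = 0; L < MainVF; ++L) // stands in for one <8 x float> step
      Acc += A[I + L];

  // Epilogue vector loop: its induction and reduction phis do not restart
  // from 0; they resume from the I and Acc left behind by the main loop.
  for (; I + EpilogueVF <= N; I += EpilogueVF)
    for (size_t L = 0; L < EpilogueVF; ++L)
      Acc += A[I + L];

  // Scalar remainder handles whatever is left.
  for (; I < N; ++I)
    Acc += A[I];
  return Acc;
}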
@@ -10502,23 +10658,16 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &BFI = AM.getResult<BlockFrequencyAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - auto &AA = AM.getResult<AAManager>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DB = AM.getResult<DemandedBitsAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); - auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); - std::function<const LoopAccessInfo &(Loop &)> GetLAA = - [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, - TLI, TTI, nullptr, nullptr, nullptr}; - return LAM.getResult<LoopAccessAnalysis>(L, AR); - }; + LoopAccessInfoManager &LAIs = AM.getResult<LoopAccessAnalysis>(F); auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); ProfileSummaryInfo *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); LoopVectorizeResult Result = - runImpl(F, SE, LI, TTI, DT, BFI, &TLI, DB, AA, AC, GetLAA, ORE, PSI); + runImpl(F, SE, LI, TTI, DT, BFI, &TLI, DB, AC, LAIs, ORE, PSI); if (!Result.MadeAnyChange) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index d69d1e3d19f3..e3eb6b1804e7 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -19,7 +19,6 @@ #include "llvm/Transforms/Vectorize/SLPVectorizer.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/PriorityQueue.h" #include "llvm/ADT/STLExtras.h" @@ -94,6 +93,7 @@ #include <cstdint> #include <iterator> #include <memory> +#include <optional> #include <set> #include <string> #include <tuple> @@ -205,7 +205,7 @@ static bool isValidElementType(Type *Ty) { /// \returns True if the value is a constant (but not globals/constant /// expressions). static bool isConstant(Value *V) { - return isa<Constant>(V) && !isa<ConstantExpr>(V) && !isa<GlobalValue>(V); + return isa<Constant>(V) && !isa<ConstantExpr, GlobalValue>(V); } /// Checks if \p V is one of vector-like instructions, i.e. undef, @@ -284,24 +284,124 @@ static bool isCommutative(Instruction *I) { return false; } +/// \returns inserting index of InsertElement or InsertValue instruction, +/// using Offset as base offset for index. 
+static std::optional<unsigned> getInsertIndex(const Value *InsertInst, + unsigned Offset = 0) { + int Index = Offset; + if (const auto *IE = dyn_cast<InsertElementInst>(InsertInst)) { + const auto *VT = dyn_cast<FixedVectorType>(IE->getType()); + if (!VT) + return std::nullopt; + const auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2)); + if (!CI) + return std::nullopt; + if (CI->getValue().uge(VT->getNumElements())) + return std::nullopt; + Index *= VT->getNumElements(); + Index += CI->getZExtValue(); + return Index; + } + + const auto *IV = cast<InsertValueInst>(InsertInst); + Type *CurrentType = IV->getType(); + for (unsigned I : IV->indices()) { + if (const auto *ST = dyn_cast<StructType>(CurrentType)) { + Index *= ST->getNumElements(); + CurrentType = ST->getElementType(I); + } else if (const auto *AT = dyn_cast<ArrayType>(CurrentType)) { + Index *= AT->getNumElements(); + CurrentType = AT->getElementType(); + } else { + return std::nullopt; + } + Index += I; + } + return Index; +} + +namespace { +/// Specifies the way the mask should be analyzed for undefs/poisonous elements +/// in the shuffle mask. +enum class UseMask { + FirstArg, ///< The mask is expected to be for permutation of 1-2 vectors, + ///< check for the mask elements for the first argument (mask + ///< indices are in range [0:VF)). + SecondArg, ///< The mask is expected to be for permutation of 2 vectors, check + ///< for the mask elements for the second argument (mask indices + ///< are in range [VF:2*VF)) + UndefsAsMask ///< Consider undef mask elements (-1) as placeholders for + ///< future shuffle elements and mark them as ones as being used + ///< in future. Non-undef elements are considered as unused since + ///< they're already marked as used in the mask. +}; +} // namespace + +/// Prepares a use bitset for the given mask either for the first argument or +/// for the second. +static SmallBitVector buildUseMask(int VF, ArrayRef<int> Mask, + UseMask MaskArg) { + SmallBitVector UseMask(VF, true); + for (auto P : enumerate(Mask)) { + if (P.value() == UndefMaskElem) { + if (MaskArg == UseMask::UndefsAsMask) + UseMask.reset(P.index()); + continue; + } + if (MaskArg == UseMask::FirstArg && P.value() < VF) + UseMask.reset(P.value()); + else if (MaskArg == UseMask::SecondArg && P.value() >= VF) + UseMask.reset(P.value() - VF); + } + return UseMask; +} + /// Checks if the given value is actually an undefined constant vector. -static bool isUndefVector(const Value *V) { - if (isa<UndefValue>(V)) - return true; - auto *C = dyn_cast<Constant>(V); - if (!C) - return false; - if (!C->containsUndefOrPoisonElement()) - return false; - auto *VecTy = dyn_cast<FixedVectorType>(C->getType()); +/// Also, if the \p UseMask is not empty, tries to check if the non-masked +/// elements actually mask the insertelement buildvector, if any. +template <bool IsPoisonOnly = false> +static SmallBitVector isUndefVector(const Value *V, + const SmallBitVector &UseMask = {}) { + SmallBitVector Res(UseMask.empty() ? 
1 : UseMask.size(), true); + using T = std::conditional_t<IsPoisonOnly, PoisonValue, UndefValue>; + if (isa<T>(V)) + return Res; + auto *VecTy = dyn_cast<FixedVectorType>(V->getType()); if (!VecTy) - return false; + return Res.reset(); + auto *C = dyn_cast<Constant>(V); + if (!C) { + if (!UseMask.empty()) { + const Value *Base = V; + while (auto *II = dyn_cast<InsertElementInst>(Base)) { + if (isa<T>(II->getOperand(1))) + continue; + Base = II->getOperand(0); + std::optional<unsigned> Idx = getInsertIndex(II); + if (!Idx) + continue; + if (*Idx < UseMask.size() && !UseMask.test(*Idx)) + Res.reset(*Idx); + } + // TODO: Add analysis for shuffles here too. + if (V == Base) { + Res.reset(); + } else { + SmallBitVector SubMask(UseMask.size(), false); + Res &= isUndefVector<IsPoisonOnly>(Base, SubMask); + } + } else { + Res.reset(); + } + return Res; + } for (unsigned I = 0, E = VecTy->getNumElements(); I != E; ++I) { if (Constant *Elem = C->getAggregateElement(I)) - if (!isa<UndefValue>(Elem)) - return false; + if (!isa<T>(Elem) && + (UseMask.empty() || (I < UseMask.size() && !UseMask.test(I)))) + Res.reset(I); } - return true; + return Res; } /// Checks if the vector of instructions can be represented as a shuffle, like: @@ -345,16 +445,16 @@ static bool isUndefVector(const Value *V) { /// InstCombiner transforms this into a shuffle and vector mul /// Mask will return the Shuffle Mask equivalent to the extracted elements. /// TODO: Can we split off and reuse the shuffle mask detection from -/// TargetTransformInfo::getInstructionThroughput? -static Optional<TargetTransformInfo::ShuffleKind> +/// ShuffleVectorInst/getShuffleCost? +static std::optional<TargetTransformInfo::ShuffleKind> isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { const auto *It = find_if(VL, [](Value *V) { return isa<ExtractElementInst>(V); }); if (It == VL.end()) - return None; + return std::nullopt; auto *EI0 = cast<ExtractElementInst>(*It); if (isa<ScalableVectorType>(EI0->getVectorOperandType())) - return None; + return std::nullopt; unsigned Size = cast<FixedVectorType>(EI0->getVectorOperandType())->getNumElements(); Value *Vec1 = nullptr; @@ -368,19 +468,19 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { continue; auto *EI = cast<ExtractElementInst>(VL[I]); if (isa<ScalableVectorType>(EI->getVectorOperandType())) - return None; + return std::nullopt; auto *Vec = EI->getVectorOperand(); // We can extractelement from undef or poison vector. - if (isUndefVector(Vec)) + if (isUndefVector(Vec).all()) continue; // All vector operands must have the same number of vector elements. if (cast<FixedVectorType>(Vec->getType())->getNumElements() != Size) - return None; + return std::nullopt; if (isa<UndefValue>(EI->getIndexOperand())) continue; auto *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand()); if (!Idx) - return None; + return std::nullopt; // Undefined behavior if Idx is negative or >= Size. if (Idx->getValue().uge(Size)) continue; @@ -394,7 +494,7 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { Vec2 = Vec; Mask[I] += Size; } else { - return None; + return std::nullopt; } if (CommonShuffleMode == Permute) continue; @@ -415,6 +515,24 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { : TargetTransformInfo::SK_PermuteSingleSrc; } +/// \returns True if Extract{Value,Element} instruction extracts element Idx. 
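buildUseMask and the mask-aware isUndefVector above restrict the undef/poison check to the lanes that a shuffle mask actually reads from an operand. A minimal standalone sketch of the underlying idea (plain STL types; note the in-tree helper tracks the complementary set by clearing the bit of each referenced lane, and also covers the second operand and the undefs-as-mask mode):

#include <cassert>
#include <vector>

// Which lanes of the first operand (indices [0, VF)) does this mask read?
std::vector<bool> firstArgUseMask(int VF, const std::vector<int> &Mask) {
  std::vector<bool> Used(VF, false);
  for (int M : Mask) {
    if (M < 0)
      continue;      // -1: undefined mask element, reads nothing
    if (M < VF)
      Used[M] = true; // indices [0, VF) select from the first operand
    // indices [VF, 2*VF) select from the second operand; ignored here
  }
  return Used;
}

int main() {
  // Mask <0, 5, 2, 7> over two 4-wide operands reads only lanes 0 and 2 of
  // the first operand, so lanes 1 and 3 may be undef without changing the
  // shuffle's result.
  std::vector<bool> Used = firstArgUseMask(4, {0, 5, 2, 7});
  assert(Used == std::vector<bool>({true, false, true, false}));
  return 0;
}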
+static std::optional<unsigned> getExtractIndex(Instruction *E) { + unsigned Opcode = E->getOpcode(); + assert((Opcode == Instruction::ExtractElement || + Opcode == Instruction::ExtractValue) && + "Expected extractelement or extractvalue instruction."); + if (Opcode == Instruction::ExtractElement) { + auto *CI = dyn_cast<ConstantInt>(E->getOperand(1)); + if (!CI) + return std::nullopt; + return CI->getZExtValue(); + } + auto *EI = cast<ExtractValueInst>(E); + if (EI->getNumIndices() != 1) + return std::nullopt; + return *EI->idx_begin(); +} + namespace { /// Main data required for vectorization of instructions. @@ -473,24 +591,49 @@ static bool isValidForAlternation(unsigned Opcode) { } static InstructionsState getSameOpcode(ArrayRef<Value *> VL, + const TargetLibraryInfo &TLI, unsigned BaseIndex = 0); /// Checks if the provided operands of 2 cmp instructions are compatible, i.e. /// compatible instructions or constants, or just some other regular values. static bool areCompatibleCmpOps(Value *BaseOp0, Value *BaseOp1, Value *Op0, - Value *Op1) { + Value *Op1, const TargetLibraryInfo &TLI) { return (isConstant(BaseOp0) && isConstant(Op0)) || (isConstant(BaseOp1) && isConstant(Op1)) || (!isa<Instruction>(BaseOp0) && !isa<Instruction>(Op0) && !isa<Instruction>(BaseOp1) && !isa<Instruction>(Op1)) || - getSameOpcode({BaseOp0, Op0}).getOpcode() || - getSameOpcode({BaseOp1, Op1}).getOpcode(); + BaseOp0 == Op0 || BaseOp1 == Op1 || + getSameOpcode({BaseOp0, Op0}, TLI).getOpcode() || + getSameOpcode({BaseOp1, Op1}, TLI).getOpcode(); +} + +/// \returns true if a compare instruction \p CI has similar "look" and +/// same predicate as \p BaseCI, "as is" or with its operands and predicate +/// swapped, false otherwise. +static bool isCmpSameOrSwapped(const CmpInst *BaseCI, const CmpInst *CI, + const TargetLibraryInfo &TLI) { + assert(BaseCI->getOperand(0)->getType() == CI->getOperand(0)->getType() && + "Assessing comparisons of different types?"); + CmpInst::Predicate BasePred = BaseCI->getPredicate(); + CmpInst::Predicate Pred = CI->getPredicate(); + CmpInst::Predicate SwappedPred = CmpInst::getSwappedPredicate(Pred); + + Value *BaseOp0 = BaseCI->getOperand(0); + Value *BaseOp1 = BaseCI->getOperand(1); + Value *Op0 = CI->getOperand(0); + Value *Op1 = CI->getOperand(1); + + return (BasePred == Pred && + areCompatibleCmpOps(BaseOp0, BaseOp1, Op0, Op1, TLI)) || + (BasePred == SwappedPred && + areCompatibleCmpOps(BaseOp0, BaseOp1, Op1, Op0, TLI)); } /// \returns analysis of the Instructions in \p VL described in /// InstructionsState, the Opcode that we suppose the whole list /// could be vectorized even if its structure is diverse. static InstructionsState getSameOpcode(ArrayRef<Value *> VL, + const TargetLibraryInfo &TLI, unsigned BaseIndex) { // Make sure these are all Instructions. if (llvm::any_of(VL, [](Value *V) { return !isa<Instruction>(V); })) @@ -508,9 +651,19 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, // Check for one alternate opcode from another BinaryOperator. // TODO - generalize to support all operators (types, calls etc.). 
+ auto *IBase = cast<Instruction>(VL[BaseIndex]); + Intrinsic::ID BaseID = 0; + SmallVector<VFInfo> BaseMappings; + if (auto *CallBase = dyn_cast<CallInst>(IBase)) { + BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI); + BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase); + if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty()) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + } for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) { - unsigned InstOpcode = cast<Instruction>(VL[Cnt])->getOpcode(); - if (IsBinOp && isa<BinaryOperator>(VL[Cnt])) { + auto *I = cast<Instruction>(VL[Cnt]); + unsigned InstOpcode = I->getOpcode(); + if (IsBinOp && isa<BinaryOperator>(I)) { if (InstOpcode == Opcode || InstOpcode == AltOpcode) continue; if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) && @@ -519,9 +672,11 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, AltIndex = Cnt; continue; } - } else if (IsCastOp && isa<CastInst>(VL[Cnt])) { - Type *Ty0 = cast<Instruction>(VL[BaseIndex])->getOperand(0)->getType(); - Type *Ty1 = cast<Instruction>(VL[Cnt])->getOperand(0)->getType(); + } else if (IsCastOp && isa<CastInst>(I)) { + Value *Op0 = IBase->getOperand(0); + Type *Ty0 = Op0->getType(); + Value *Op1 = I->getOperand(0); + Type *Ty1 = Op1->getType(); if (Ty0 == Ty1) { if (InstOpcode == Opcode || InstOpcode == AltOpcode) continue; @@ -534,59 +689,79 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, continue; } } - } else if (IsCmpOp && isa<CmpInst>(VL[Cnt])) { - auto *BaseInst = cast<Instruction>(VL[BaseIndex]); - auto *Inst = cast<Instruction>(VL[Cnt]); + } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) { + auto *BaseInst = cast<CmpInst>(VL[BaseIndex]); Type *Ty0 = BaseInst->getOperand(0)->getType(); Type *Ty1 = Inst->getOperand(0)->getType(); if (Ty0 == Ty1) { - Value *BaseOp0 = BaseInst->getOperand(0); - Value *BaseOp1 = BaseInst->getOperand(1); - Value *Op0 = Inst->getOperand(0); - Value *Op1 = Inst->getOperand(1); - CmpInst::Predicate CurrentPred = - cast<CmpInst>(VL[Cnt])->getPredicate(); - CmpInst::Predicate SwappedCurrentPred = - CmpInst::getSwappedPredicate(CurrentPred); + assert(InstOpcode == Opcode && "Expected same CmpInst opcode."); // Check for compatible operands. If the corresponding operands are not // compatible - need to perform alternate vectorization. - if (InstOpcode == Opcode) { - if (BasePred == CurrentPred && - areCompatibleCmpOps(BaseOp0, BaseOp1, Op0, Op1)) - continue; - if (BasePred == SwappedCurrentPred && - areCompatibleCmpOps(BaseOp0, BaseOp1, Op1, Op0)) - continue; - if (E == 2 && - (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) - continue; - auto *AltInst = cast<CmpInst>(VL[AltIndex]); - CmpInst::Predicate AltPred = AltInst->getPredicate(); - Value *AltOp0 = AltInst->getOperand(0); - Value *AltOp1 = AltInst->getOperand(1); - // Check if operands are compatible with alternate operands. 
- if (AltPred == CurrentPred && - areCompatibleCmpOps(AltOp0, AltOp1, Op0, Op1)) - continue; - if (AltPred == SwappedCurrentPred && - areCompatibleCmpOps(AltOp0, AltOp1, Op1, Op0)) + CmpInst::Predicate CurrentPred = Inst->getPredicate(); + CmpInst::Predicate SwappedCurrentPred = + CmpInst::getSwappedPredicate(CurrentPred); + + if (E == 2 && + (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) + continue; + + if (isCmpSameOrSwapped(BaseInst, Inst, TLI)) + continue; + auto *AltInst = cast<CmpInst>(VL[AltIndex]); + if (AltIndex != BaseIndex) { + if (isCmpSameOrSwapped(AltInst, Inst, TLI)) continue; - } - if (BaseIndex == AltIndex && BasePred != CurrentPred) { - assert(isValidForAlternation(Opcode) && - isValidForAlternation(InstOpcode) && - "Cast isn't safe for alternation, logic needs to be updated!"); + } else if (BasePred != CurrentPred) { + assert( + isValidForAlternation(InstOpcode) && + "CmpInst isn't safe for alternation, logic needs to be updated!"); AltIndex = Cnt; continue; } - auto *AltInst = cast<CmpInst>(VL[AltIndex]); CmpInst::Predicate AltPred = AltInst->getPredicate(); if (BasePred == CurrentPred || BasePred == SwappedCurrentPred || AltPred == CurrentPred || AltPred == SwappedCurrentPred) continue; } - } else if (InstOpcode == Opcode || InstOpcode == AltOpcode) + } else if (InstOpcode == Opcode || InstOpcode == AltOpcode) { + if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) { + if (Gep->getNumOperands() != 2 || + Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType()) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + } else if (auto *EI = dyn_cast<ExtractElementInst>(I)) { + if (!isVectorLikeInstWithConstOps(EI)) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + } else if (auto *LI = dyn_cast<LoadInst>(I)) { + auto *BaseLI = cast<LoadInst>(IBase); + if (!LI->isSimple() || !BaseLI->isSimple()) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + } else if (auto *Call = dyn_cast<CallInst>(I)) { + auto *CallBase = cast<CallInst>(IBase); + if (Call->getCalledFunction() != CallBase->getCalledFunction()) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + if (Call->hasOperandBundles() && + !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(), + Call->op_begin() + Call->getBundleOperandsEndIndex(), + CallBase->op_begin() + + CallBase->getBundleOperandsStartIndex())) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(Call, &TLI); + if (ID != BaseID) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + if (!ID) { + SmallVector<VFInfo> Mappings = VFDatabase(*Call).getMappings(*Call); + if (Mappings.size() != BaseMappings.size() || + Mappings.front().ISA != BaseMappings.front().ISA || + Mappings.front().ScalarName != BaseMappings.front().ScalarName || + Mappings.front().VectorName != BaseMappings.front().VectorName || + Mappings.front().Shape.VF != BaseMappings.front().Shape.VF || + Mappings.front().Shape.Parameters != + BaseMappings.front().Shape.Parameters) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + } + } continue; + } return InstructionsState(VL[BaseIndex], nullptr, nullptr); } @@ -605,24 +780,6 @@ static bool allSameType(ArrayRef<Value *> VL) { return true; } -/// \returns True if Extract{Value,Element} instruction extracts element Idx. 
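The compare handling above (isCmpSameOrSwapped and the alternate-predicate path) relies on the fact that a comparison keeps its value when both its predicate and its operands are swapped. A tiny standalone sketch of that equivalence, with a made-up predicate enum standing in for CmpInst::Predicate:

#include <cassert>

enum Pred { LT, GT, LE, GE, EQ, NE };

// cmp(P, a, b) computes the same value as cmp(swapped(P), b, a).
Pred swapped(Pred P) {
  switch (P) {
  case LT: return GT;
  case GT: return LT;
  case LE: return GE;
  case GE: return LE;
  default: return P; // EQ and NE are symmetric
  }
}

struct Cmp { Pred P; int LHS, RHS; }; // operands identified by id for brevity

// Bundle-compatible: same predicate and operand order, or the swapped
// predicate with the operands exchanged.
bool sameOrSwapped(const Cmp &A, const Cmp &B) {
  return (A.P == B.P && A.LHS == B.LHS && A.RHS == B.RHS) ||
         (A.P == swapped(B.P) && A.LHS == B.RHS && A.RHS == B.LHS);
}

int main() {
  assert(sameOrSwapped({LT, 1, 2}, {GT, 2, 1}));  // a < b  is  b > a
  assert(!sameOrSwapped({LT, 1, 2}, {LE, 1, 2})); // different predicates
  return 0;
}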
-static Optional<unsigned> getExtractIndex(Instruction *E) { - unsigned Opcode = E->getOpcode(); - assert((Opcode == Instruction::ExtractElement || - Opcode == Instruction::ExtractValue) && - "Expected extractelement or extractvalue instruction."); - if (Opcode == Instruction::ExtractElement) { - auto *CI = dyn_cast<ConstantInt>(E->getOperand(1)); - if (!CI) - return None; - return CI->getZExtValue(); - } - ExtractValueInst *EI = cast<ExtractValueInst>(E); - if (EI->getNumIndices() != 1) - return None; - return *EI->idx_begin(); -} - /// \returns True if in-tree use also needs extract. This refers to /// possible scalar operand in vectorized instruction. static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, @@ -644,7 +801,7 @@ static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, if (isVectorIntrinsicWithScalarOpAtArg(ID, i)) return (CI->getArgOperand(i) == Scalar); } - LLVM_FALLTHROUGH; + [[fallthrough]]; } default: return false; @@ -735,40 +892,6 @@ static void inversePermutation(ArrayRef<unsigned> Indices, Mask[Indices[I]] = I; } -/// \returns inserting index of InsertElement or InsertValue instruction, -/// using Offset as base offset for index. -static Optional<unsigned> getInsertIndex(const Value *InsertInst, - unsigned Offset = 0) { - int Index = Offset; - if (const auto *IE = dyn_cast<InsertElementInst>(InsertInst)) { - if (const auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2))) { - auto *VT = cast<FixedVectorType>(IE->getType()); - if (CI->getValue().uge(VT->getNumElements())) - return None; - Index *= VT->getNumElements(); - Index += CI->getZExtValue(); - return Index; - } - return None; - } - - const auto *IV = cast<InsertValueInst>(InsertInst); - Type *CurrentType = IV->getType(); - for (unsigned I : IV->indices()) { - if (const auto *ST = dyn_cast<StructType>(CurrentType)) { - Index *= ST->getNumElements(); - CurrentType = ST->getElementType(I); - } else if (const auto *AT = dyn_cast<ArrayType>(CurrentType)) { - Index *= AT->getNumElements(); - CurrentType = AT->getElementType(); - } else { - return None; - } - Index += I; - } - return Index; -} - /// Reorders the list of scalars in accordance with the given \p Mask. static void reorderScalars(SmallVectorImpl<Value *> &Scalars, ArrayRef<int> Mask) { @@ -839,6 +962,7 @@ namespace slpvectorizer { class BoUpSLP { struct TreeEntry; struct ScheduleData; + class ShuffleInstructionBuilder; public: using ValueList = SmallVector<Value *, 8>; @@ -867,7 +991,7 @@ public: else MaxVecRegSize = TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) - .getFixedSize(); + .getFixedValue(); if (MinVectorRegSizeOption.getNumOccurrences()) MinVecRegSize = MinVectorRegSizeOption; @@ -882,7 +1006,8 @@ public: /// Vectorize the tree but with the list of externally used values \p /// ExternallyUsedValues. Values in this MapVector can be replaced but the /// generated extractvalue instructions. - Value *vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues); + Value *vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, + Instruction *ReductionRoot = nullptr); /// \returns the cost incurred by unwanted spills and fills, caused by /// holding live values over call sites. @@ -890,7 +1015,7 @@ public: /// \returns the vectorization cost of the subtree that starts at \p VL. /// A negative number means that this is profitable. 
- InstructionCost getTreeCost(ArrayRef<Value *> VectorizedVals = None); + InstructionCost getTreeCost(ArrayRef<Value *> VectorizedVals = std::nullopt); /// Construct a vectorizable tree that starts at \p Roots, ignoring users for /// the purpose of scheduling and extraction in the \p UserIgnoreLst. @@ -900,6 +1025,24 @@ public: /// Construct a vectorizable tree that starts at \p Roots. void buildTree(ArrayRef<Value *> Roots); + /// Checks if the very first tree node is going to be vectorized. + bool isVectorizedFirstNode() const { + return !VectorizableTree.empty() && + VectorizableTree.front()->State == TreeEntry::Vectorize; + } + + /// Returns the main instruction for the very first node. + Instruction *getFirstNodeMainOp() const { + assert(!VectorizableTree.empty() && "No tree to get the first node from"); + return VectorizableTree.front()->getMainOp(); + } + + /// Returns whether the root node has in-tree uses. + bool doesRootHaveInTreeUses() const { + return !VectorizableTree.empty() && + !VectorizableTree.front()->UserTreeIndices.empty(); + } + /// Builds external uses of the vectorized scalars, i.e. the list of /// vectorized scalars to be extracted, their lanes and their scalar users. \p /// ExternallyUsedValues contains additional list of external uses to handle @@ -912,6 +1055,7 @@ public: VectorizableTree.clear(); ScalarToTreeEntry.clear(); MustGather.clear(); + EntryToLastInstruction.clear(); ExternalUses.clear(); for (auto &Iter : BlocksSchedules) { BlockScheduling *BS = Iter.second.get(); @@ -931,17 +1075,17 @@ public: /// shuffled vector entry + (possibly) permutation with other gathers. It /// implements the checks only for possibly ordered scalars (Loads, /// ExtractElement, ExtractValue), which can be part of the graph. - Optional<OrdersType> findReusedOrderedScalars(const TreeEntry &TE); + std::optional<OrdersType> findReusedOrderedScalars(const TreeEntry &TE); /// Sort loads into increasing pointers offsets to allow greater clustering. - Optional<OrdersType> findPartiallyOrderedLoads(const TreeEntry &TE); + std::optional<OrdersType> findPartiallyOrderedLoads(const TreeEntry &TE); /// Gets reordering data for the given tree entry. If the entry is vectorized /// - just return ReorderIndices, otherwise check if the scalars can be /// reordered and return the most optimal order. /// \param TopToBottom If true, include the order of vectorized stores and /// insertelement nodes, otherwise skip them. - Optional<OrdersType> getReorderingData(const TreeEntry &TE, bool TopToBottom); + std::optional<OrdersType> getReorderingData(const TreeEntry &TE, bool TopToBottom); /// Reorders the current graph to the most profitable order starting from the /// root node to the leaf nodes. The best order is chosen only from the nodes @@ -1052,6 +1196,7 @@ public: /// A helper class used for scoring candidates for two consecutive lanes. class LookAheadHeuristics { + const TargetLibraryInfo &TLI; const DataLayout &DL; ScalarEvolution &SE; const BoUpSLP &R; @@ -1059,9 +1204,11 @@ public: int MaxLevel; // The maximum recursion depth for accumulating score. 
public: - LookAheadHeuristics(const DataLayout &DL, ScalarEvolution &SE, - const BoUpSLP &R, int NumLanes, int MaxLevel) - : DL(DL), SE(SE), R(R), NumLanes(NumLanes), MaxLevel(MaxLevel) {} + LookAheadHeuristics(const TargetLibraryInfo &TLI, const DataLayout &DL, + ScalarEvolution &SE, const BoUpSLP &R, int NumLanes, + int MaxLevel) + : TLI(TLI), DL(DL), SE(SE), R(R), NumLanes(NumLanes), + MaxLevel(MaxLevel) {} // The hard-coded scores listed here are not very important, though it shall // be higher for better matches to improve the resulting cost. When @@ -1083,6 +1230,8 @@ public: static const int ScoreSplatLoads = 3; /// Loads from reversed memory addresses, e.g. load(A[i+1]), load(A[i]). static const int ScoreReversedLoads = 3; + /// A load candidate for masked gather. + static const int ScoreMaskedGatherCandidate = 1; /// ExtractElementInst from same vector and consecutive indexes. static const int ScoreConsecutiveExtracts = 4; /// ExtractElementInst from same vector and reversed indices. @@ -1108,6 +1257,10 @@ public: /// MainAltOps. int getShallowScore(Value *V1, Value *V2, Instruction *U1, Instruction *U2, ArrayRef<Value *> MainAltOps) const { + if (!isValidElementType(V1->getType()) || + !isValidElementType(V2->getType())) + return LookAheadHeuristics::ScoreFail; + if (V1 == V2) { if (isa<LoadInst>(V1)) { // Retruns true if the users of V1 and V2 won't need to be extracted. @@ -1137,18 +1290,26 @@ public: auto *LI1 = dyn_cast<LoadInst>(V1); auto *LI2 = dyn_cast<LoadInst>(V2); if (LI1 && LI2) { - if (LI1->getParent() != LI2->getParent()) + if (LI1->getParent() != LI2->getParent() || !LI1->isSimple() || + !LI2->isSimple()) return LookAheadHeuristics::ScoreFail; - Optional<int> Dist = getPointersDiff( + std::optional<int> Dist = getPointersDiff( LI1->getType(), LI1->getPointerOperand(), LI2->getType(), LI2->getPointerOperand(), DL, SE, /*StrictCheck=*/true); - if (!Dist || *Dist == 0) + if (!Dist || *Dist == 0) { + if (getUnderlyingObject(LI1->getPointerOperand()) == + getUnderlyingObject(LI2->getPointerOperand()) && + R.TTI->isLegalMaskedGather( + FixedVectorType::get(LI1->getType(), NumLanes), + LI1->getAlign())) + return LookAheadHeuristics::ScoreMaskedGatherCandidate; return LookAheadHeuristics::ScoreFail; + } // The distance is too large - still may be profitable to use masked // loads/gathers. if (std::abs(*Dist) > NumLanes / 2) - return LookAheadHeuristics::ScoreAltOpcodes; + return LookAheadHeuristics::ScoreMaskedGatherCandidate; // This still will detect consecutive loads, but we might have "holes" // in some cases. It is ok for non-power-2 vectorization and may produce // better results. It should not affect current vectorization. @@ -1177,7 +1338,7 @@ public: // Undefs are always profitable for extractelements. if (!Ex2Idx) return LookAheadHeuristics::ScoreConsecutiveExtracts; - if (isUndefVector(EV2) && EV2->getType() == EV1->getType()) + if (isUndefVector(EV2).all() && EV2->getType() == EV1->getType()) return LookAheadHeuristics::ScoreConsecutiveExtracts; if (EV2 == EV1) { int Idx1 = Ex1Idx->getZExtValue(); @@ -1205,7 +1366,7 @@ public: SmallVector<Value *, 4> Ops(MainAltOps.begin(), MainAltOps.end()); Ops.push_back(I1); Ops.push_back(I2); - InstructionsState S = getSameOpcode(Ops); + InstructionsState S = getSameOpcode(Ops, TLI); // Note: Only consider instructions with <= 2 operands to avoid // complexity explosion. 
if (S.getOpcode() && @@ -1300,7 +1461,7 @@ public: // Recursively calculate the cost at each level int TmpScore = getScoreAtLevelRec(I1->getOperand(OpIdx1), I2->getOperand(OpIdx2), - I1, I2, CurrLevel + 1, None); + I1, I2, CurrLevel + 1, std::nullopt); // Look for the best score. if (TmpScore > LookAheadHeuristics::ScoreFail && TmpScore > MaxTmpScore) { @@ -1381,6 +1542,7 @@ public: /// A vector of operand vectors. SmallVector<OperandDataVec, 4> OpsVec; + const TargetLibraryInfo &TLI; const DataLayout &DL; ScalarEvolution &SE; const BoUpSLP &R; @@ -1464,7 +1626,7 @@ public: auto *IdxLaneI = dyn_cast<Instruction>(IdxLaneV); if (!IdxLaneI || !isa<Instruction>(OpIdxLaneV)) return 0; - return R.areAllUsersVectorized(IdxLaneI, None) + return R.areAllUsersVectorized(IdxLaneI, std::nullopt) ? LookAheadHeuristics::ScoreAllUserVectorized : 0; } @@ -1482,7 +1644,7 @@ public: int getLookAheadScore(Value *LHS, Value *RHS, ArrayRef<Value *> MainAltOps, int Lane, unsigned OpIdx, unsigned Idx, bool &IsUsed) { - LookAheadHeuristics LookAhead(DL, SE, R, getNumLanes(), + LookAheadHeuristics LookAhead(TLI, DL, SE, R, getNumLanes(), LookAheadMaxDepth); // Keep track of the instruction stack as we recurse into the operands // during the look-ahead score exploration. @@ -1520,8 +1682,8 @@ public: // Search all operands in Ops[*][Lane] for the one that matches best // Ops[OpIdx][LastLane] and return its opreand index. - // If no good match can be found, return None. - Optional<unsigned> getBestOperand(unsigned OpIdx, int Lane, int LastLane, + // If no good match can be found, return std::nullopt. + std::optional<unsigned> getBestOperand(unsigned OpIdx, int Lane, int LastLane, ArrayRef<ReorderingMode> ReorderingModes, ArrayRef<Value *> MainAltOps) { unsigned NumOperands = getNumOperands(); @@ -1532,7 +1694,7 @@ public: // Our strategy mode for OpIdx. ReorderingMode RMode = ReorderingModes[OpIdx]; if (RMode == ReorderingMode::Failed) - return None; + return std::nullopt; // The linearized opcode of the operand at OpIdx, Lane. bool OpIdxAPO = getData(OpIdx, Lane).APO; @@ -1541,7 +1703,7 @@ public: // Sometimes we have more than one option (e.g., Opcode and Undefs), so we // are using the score to differentiate between the two. struct BestOpData { - Optional<unsigned> Idx = None; + std::optional<unsigned> Idx; unsigned Score = 0; } BestOp; BestOp.Score = @@ -1600,8 +1762,8 @@ public: getData(*BestOp.Idx, Lane).IsUsed = IsUsed; return BestOp.Idx; } - // If we could not find a good match return None. - return None; + // If we could not find a good match return std::nullopt. + return std::nullopt; } /// Helper for reorderOperandVecs. @@ -1704,7 +1866,7 @@ public: // Use Boyer-Moore majority voting for finding the majority opcode and // the number of times it occurs. if (auto *I = dyn_cast<Instruction>(OpData.V)) { - if (!OpcodeI || !getSameOpcode({OpcodeI, I}).getOpcode() || + if (!OpcodeI || !getSameOpcode({OpcodeI, I}, TLI).getOpcode() || I->getParent() != Parent) { if (NumOpsWithSameOpcodeParent == 0) { NumOpsWithSameOpcodeParent = 1; @@ -1806,9 +1968,9 @@ public: public: /// Initialize with all the operands of the instruction vector \p RootVL. - VLOperands(ArrayRef<Value *> RootVL, const DataLayout &DL, - ScalarEvolution &SE, const BoUpSLP &R) - : DL(DL), SE(SE), R(R) { + VLOperands(ArrayRef<Value *> RootVL, const TargetLibraryInfo &TLI, + const DataLayout &DL, ScalarEvolution &SE, const BoUpSLP &R) + : TLI(TLI), DL(DL), SE(SE), R(R) { // Append all the operands of RootVL. 
appendOperandsOfVL(RootVL); } @@ -1930,7 +2092,7 @@ public: // Look for a good match for each operand. for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) { // Search for the operand that matches SortedOps[OpIdx][Lane-1]. - Optional<unsigned> BestIdx = getBestOperand( + std::optional<unsigned> BestIdx = getBestOperand( OpIdx, Lane, LastLane, ReorderingModes, MainAltOps[OpIdx]); // By not selecting a value, we allow the operands that follow to // select a better matching value. We will get a non-null value in @@ -1949,7 +2111,7 @@ public: if (MainAltOps[OpIdx].size() != 2) { OperandData &AltOp = getData(OpIdx, Lane); InstructionsState OpS = - getSameOpcode({MainAltOps[OpIdx].front(), AltOp.V}); + getSameOpcode({MainAltOps[OpIdx].front(), AltOp.V}, TLI); if (OpS.getOpcode() && OpS.isAltShuffle()) MainAltOps[OpIdx].push_back(AltOp.V); } @@ -2018,21 +2180,21 @@ public: /// Evaluate each pair in \p Candidates and return index into \p Candidates /// for a pair which have highest score deemed to have best chance to form - /// root of profitable tree to vectorize. Return None if no candidate scored - /// above the LookAheadHeuristics::ScoreFail. - /// \param Limit Lower limit of the cost, considered to be good enough score. - Optional<int> + /// root of profitable tree to vectorize. Return std::nullopt if no candidate + /// scored above the LookAheadHeuristics::ScoreFail. \param Limit Lower limit + /// of the cost, considered to be good enough score. + std::optional<int> findBestRootPair(ArrayRef<std::pair<Value *, Value *>> Candidates, int Limit = LookAheadHeuristics::ScoreFail) { - LookAheadHeuristics LookAhead(*DL, *SE, *this, /*NumLanes=*/2, + LookAheadHeuristics LookAhead(*TLI, *DL, *SE, *this, /*NumLanes=*/2, RootLookAheadMaxDepth); int BestScore = Limit; - Optional<int> Index = None; + std::optional<int> Index; for (int I : seq<int>(0, Candidates.size())) { int Score = LookAhead.getScoreAtLevelRec(Candidates[I].first, Candidates[I].second, /*U1=*/nullptr, /*U2=*/nullptr, - /*Level=*/1, None); + /*Level=*/1, std::nullopt); if (Score > BestScore) { BestScore = Score; Index = I; @@ -2063,7 +2225,7 @@ public: } /// Checks if the provided list of reduced values was checked already for /// vectorization. - bool areAnalyzedReductionVals(ArrayRef<Value *> VL) { + bool areAnalyzedReductionVals(ArrayRef<Value *> VL) const { return AnalyzedReductionVals.contains(hash_value(VL)); } /// Adds the list of reduced values to list of already checked values for the @@ -2081,6 +2243,9 @@ public: return any_of(MustGather, [&](Value *V) { return Vals.contains(V); }); } + /// Check if the value is vectorized in the tree. + bool isVectorized(Value *V) const { return getTreeEntry(V); } + ~BoUpSLP(); private: @@ -2097,6 +2262,10 @@ private: ArrayRef<TreeEntry *> ReorderableGathers, SmallVectorImpl<TreeEntry *> &GatherOps); + /// Checks if the given \p TE is a gather node with clustered reused scalars + /// and reorders it per given \p Mask. + void reorderNodeWithReuses(TreeEntry &TE, ArrayRef<int> Mask) const; + /// Returns vectorized operand \p OpIdx of the node \p UserTE from the graph, /// if any. If it is not vectorized (gather node), returns nullptr. TreeEntry *getVectorizedOperand(TreeEntry *UserTE, unsigned OpIdx) { @@ -2123,6 +2292,11 @@ private: bool areAllUsersVectorized(Instruction *I, ArrayRef<Value *> VectorizedVals) const; + /// Return information about the vector formed for the specified index + /// of a vector of (the same) instruction. 
+ TargetTransformInfo::OperandValueInfo getOperandInfo(ArrayRef<Value *> VL, + unsigned OpIdx); + /// \returns the cost of the vectorizable entry. InstructionCost getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals); @@ -2142,13 +2316,14 @@ private: /// Vectorize a single entry in the tree. Value *vectorizeTree(TreeEntry *E); - /// Vectorize a single entry in the tree, starting in \p VL. - Value *vectorizeTree(ArrayRef<Value *> VL); + /// Vectorize a single entry in the tree, the \p Idx-th operand of the entry + /// \p E. + Value *vectorizeOperand(TreeEntry *E, unsigned NodeIdx); /// Create a new vector from a list of scalar values. Produces a sequence /// which exploits values reused across lanes, and arranges the inserts /// for ease of later optimization. - Value *createBuildVector(ArrayRef<Value *> VL); + Value *createBuildVector(const TreeEntry *E); /// \returns the scalarization cost for this type. Scalarization in this /// context means the creation of vectors from a group of scalars. If \p @@ -2158,12 +2333,22 @@ private: const APInt &ShuffledIndices, bool NeedToShuffle) const; + /// Returns the instruction in the bundle, which can be used as a base point + /// for scheduling. Usually it is the last instruction in the bundle, except + /// for the case when all operands are external (in this case, it is the first + /// instruction in the list). + Instruction &getLastInstructionInBundle(const TreeEntry *E); + /// Checks if the gathered \p VL can be represented as shuffle(s) of previous /// tree entries. + /// \param TE Tree entry checked for permutation. + /// \param VL List of scalars (a subset of the TE scalar), checked for + /// permutations. /// \returns ShuffleKind, if gathered values can be represented as shuffles of /// previous tree entries. \p Mask is filled with the shuffle mask. - Optional<TargetTransformInfo::ShuffleKind> - isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, + std::optional<TargetTransformInfo::ShuffleKind> + isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, + SmallVectorImpl<int> &Mask, SmallVectorImpl<const TreeEntry *> &Entries); /// \returns the scalarization cost for this list of values. Assuming that @@ -2184,12 +2369,10 @@ private: /// Reorder commutative or alt operands to get better probability of /// generating vectorized code. - static void reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, - SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right, - const DataLayout &DL, - ScalarEvolution &SE, - const BoUpSLP &R); + static void reorderInputsAccordingToOpcode( + ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, + SmallVectorImpl<Value *> &Right, const TargetLibraryInfo &TLI, + const DataLayout &DL, ScalarEvolution &SE, const BoUpSLP &R); /// Helper for `findExternalStoreUsersReorderIndices()`. It iterates over the /// users of \p TE and collects the stores. It returns the map from the store @@ -2198,10 +2381,10 @@ private: collectUserStores(const BoUpSLP::TreeEntry *TE) const; /// Helper for `findExternalStoreUsersReorderIndices()`. It checks if the - /// stores in \p StoresVec can for a vector instruction. If so it returns true + /// stores in \p StoresVec can form a vector instruction. If so it returns true /// and populates \p ReorderIndices with the shuffle indices of the the stores /// when compared to the sorted vector. 
- bool CanFormVector(const SmallVector<StoreInst *, 4> &StoresVec, + bool canFormVector(const SmallVector<StoreInst *, 4> &StoresVec, OrdersType &ReorderIndices) const; /// Iterates through the users of \p TE, looking for scalar stores that can be @@ -2247,6 +2430,12 @@ private: return IsSame(Scalars, ReuseShuffleIndices); } + bool isOperandGatherNode(const EdgeInfo &UserEI) const { + return State == TreeEntry::NeedToGather && + UserTreeIndices.front().EdgeIdx == UserEI.EdgeIdx && + UserTreeIndices.front().UserTE == UserEI.UserTE; + } + /// \returns true if current entry has same operands as \p TE. bool hasEqualOperands(const TreeEntry &TE) const { if (TE.getNumOperands() != getNumOperands()) @@ -2508,11 +2697,11 @@ private: #endif /// Create a new VectorizableTree entry. - TreeEntry *newTreeEntry(ArrayRef<Value *> VL, Optional<ScheduleData *> Bundle, + TreeEntry *newTreeEntry(ArrayRef<Value *> VL, std::optional<ScheduleData *> Bundle, const InstructionsState &S, const EdgeInfo &UserTreeIdx, - ArrayRef<int> ReuseShuffleIndices = None, - ArrayRef<unsigned> ReorderIndices = None) { + ArrayRef<int> ReuseShuffleIndices = std::nullopt, + ArrayRef<unsigned> ReorderIndices = std::nullopt) { TreeEntry::EntryState EntryState = Bundle ? TreeEntry::Vectorize : TreeEntry::NeedToGather; return newTreeEntry(VL, EntryState, Bundle, S, UserTreeIdx, @@ -2521,11 +2710,11 @@ private: TreeEntry *newTreeEntry(ArrayRef<Value *> VL, TreeEntry::EntryState EntryState, - Optional<ScheduleData *> Bundle, + std::optional<ScheduleData *> Bundle, const InstructionsState &S, const EdgeInfo &UserTreeIdx, - ArrayRef<int> ReuseShuffleIndices = None, - ArrayRef<unsigned> ReorderIndices = None) { + ArrayRef<int> ReuseShuffleIndices = std::nullopt, + ArrayRef<unsigned> ReorderIndices = std::nullopt) { assert(((!Bundle && EntryState == TreeEntry::NeedToGather) || (Bundle && EntryState != TreeEntry::NeedToGather)) && "Need to vectorize gather entry?"); @@ -2547,7 +2736,7 @@ private: return UndefValue::get(VL.front()->getType()); return VL[Idx]; }); - InstructionsState S = getSameOpcode(Last->Scalars); + InstructionsState S = getSameOpcode(Last->Scalars, *TLI); Last->setOperations(S); Last->ReorderIndices.append(ReorderIndices.begin(), ReorderIndices.end()); } @@ -2611,6 +2800,14 @@ private: /// A list of scalars that we found that we need to keep as scalars. ValueSet MustGather; + /// A map between the vectorized entries and the last instructions in the + /// bundles. The bundles are built in use order, not in the def order of the + /// instructions. So, we cannot rely directly on the last instruction in the + /// bundle being the last instruction in the program order during + /// vectorization process since the basic blocks are affected, need to + /// pre-gather them before. + DenseMap<const TreeEntry *, Instruction *> EntryToLastInstruction; + /// This POD struct describes one external user in the vectorized tree. struct ExternalUser { ExternalUser(Value *S, llvm::User *U, int L) @@ -2635,9 +2832,9 @@ private: Instruction *Inst2) { // First check if the result is already in the cache. AliasCacheKey key = std::make_pair(Inst1, Inst2); - Optional<bool> &result = AliasCache[key]; + std::optional<bool> &result = AliasCache[key]; if (result) { - return result.value(); + return *result; } bool aliased = true; if (Loc1.Ptr && isSimple(Inst1)) @@ -2651,7 +2848,7 @@ private: /// Cache for alias results. /// TODO: consider moving this to the AliasAnalysis itself. 
- DenseMap<AliasCacheKey, Optional<bool>> AliasCache; + DenseMap<AliasCacheKey, std::optional<bool>> AliasCache; // Cache for pointerMayBeCaptured calls inside AA. This is preserved // globally through SLP because we don't perform any action which @@ -2680,8 +2877,9 @@ private: /// Values used only by @llvm.assume calls. SmallPtrSet<const Value *, 32> EphValues; - /// Holds all of the instructions that we gathered. - SetVector<Instruction *> GatherShuffleSeq; + /// Holds all of the instructions that we gathered, shuffle instructions and + /// extractelements. + SetVector<Instruction *> GatherShuffleExtractSeq; /// A list of blocks that we are going to CSE. SetVector<BasicBlock *> CSEBlocks; @@ -2994,7 +3192,7 @@ private: // okay. auto *In = BundleMember->Inst; assert(In && - (isa<ExtractValueInst>(In) || isa<ExtractElementInst>(In) || + (isa<ExtractValueInst, ExtractElementInst>(In) || In->getNumOperands() == TE->getNumOperands()) && "Missed TreeEntry operands?"); (void)In; // fake use to avoid build failure when assertions disabled @@ -3102,9 +3300,9 @@ private: /// Checks if a bundle of instructions can be scheduled, i.e. has no /// cyclic dependencies. This is only a dry-run, no instructions are /// actually moved at this stage. - /// \returns the scheduling bundle. The returned Optional value is non-None - /// if \p VL is allowed to be scheduled. - Optional<ScheduleData *> + /// \returns the scheduling bundle. The returned Optional value is not + /// std::nullopt if \p VL is allowed to be scheduled. + std::optional<ScheduleData *> tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, const InstructionsState &S); @@ -3319,9 +3517,10 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { std::string getNodeLabel(const TreeEntry *Entry, const BoUpSLP *R) { std::string Str; raw_string_ostream OS(Str); + OS << Entry->Idx << ".\n"; if (isSplat(Entry->Scalars)) OS << "<splat> "; - for (auto V : Entry->Scalars) { + for (auto *V : Entry->Scalars) { OS << *V; if (llvm::any_of(R->ExternalUses, [&](const BoUpSLP::ExternalUser &EU) { return EU.Scalar == V; @@ -3336,6 +3535,8 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { const BoUpSLP *) { if (Entry->State == TreeEntry::NeedToGather) return "color=red"; + if (Entry->State == TreeEntry::ScatterVectorize) + return "color=blue"; return ""; } }; @@ -3407,7 +3608,7 @@ static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask) { fixupOrderingIndices(Order); } -Optional<BoUpSLP::OrdersType> +std::optional<BoUpSLP::OrdersType> BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { assert(TE.State == TreeEntry::NeedToGather && "Expected gather node only."); unsigned NumScalars = TE.Scalars.size(); @@ -3427,11 +3628,11 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { STE = LocalSTE; else if (STE != LocalSTE) // Take the order only from the single vector node. 
- return None; + return std::nullopt; unsigned Lane = std::distance(STE->Scalars.begin(), find(STE->Scalars, V)); if (Lane >= NumScalars) - return None; + return std::nullopt; if (CurrentOrder[Lane] != NumScalars) { if (Lane != I) continue; @@ -3470,7 +3671,7 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { } return CurrentOrder; } - return None; + return std::nullopt; } namespace { @@ -3478,12 +3679,31 @@ namespace { enum class LoadsState { Gather, Vectorize, ScatterVectorize }; } // anonymous namespace +static bool arePointersCompatible(Value *Ptr1, Value *Ptr2, + const TargetLibraryInfo &TLI, + bool CompareOpcodes = true) { + if (getUnderlyingObject(Ptr1) != getUnderlyingObject(Ptr2)) + return false; + auto *GEP1 = dyn_cast<GetElementPtrInst>(Ptr1); + if (!GEP1) + return false; + auto *GEP2 = dyn_cast<GetElementPtrInst>(Ptr2); + if (!GEP2) + return false; + return GEP1->getNumOperands() == 2 && GEP2->getNumOperands() == 2 && + ((isConstant(GEP1->getOperand(1)) && + isConstant(GEP2->getOperand(1))) || + !CompareOpcodes || + getSameOpcode({GEP1->getOperand(1), GEP2->getOperand(1)}, TLI) + .getOpcode()); +} + /// Checks if the given array of loads can be represented as a vectorized, /// scatter or just simple gather. static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, const TargetTransformInfo &TTI, const DataLayout &DL, ScalarEvolution &SE, - LoopInfo &LI, + LoopInfo &LI, const TargetLibraryInfo &TLI, SmallVectorImpl<unsigned> &Order, SmallVectorImpl<Value *> &PointerOps) { // Check that a vectorized load would load the same memory as a scalar @@ -3513,18 +3733,8 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, Order.clear(); // Check the order of pointer operands or that all pointers are the same. bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order); - if (IsSorted || all_of(PointerOps, [&PointerOps](Value *P) { - if (getUnderlyingObject(P) != getUnderlyingObject(PointerOps.front())) - return false; - auto *GEP = dyn_cast<GetElementPtrInst>(P); - if (!GEP) - return false; - auto *GEP0 = cast<GetElementPtrInst>(PointerOps.front()); - return GEP->getNumOperands() == 2 && - ((isConstant(GEP->getOperand(1)) && - isConstant(GEP0->getOperand(1))) || - getSameOpcode({GEP->getOperand(1), GEP0->getOperand(1)}) - .getOpcode()); + if (IsSorted || all_of(PointerOps, [&](Value *P) { + return arePointersCompatible(P, PointerOps.front(), TLI); })) { if (IsSorted) { Value *Ptr0; @@ -3536,7 +3746,7 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, Ptr0 = PointerOps[Order.front()]; PtrN = PointerOps[Order.back()]; } - Optional<int> Diff = + std::optional<int> Diff = getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, DL, SE); // Check that the sorted loads are consecutive. 
if (static_cast<unsigned>(*Diff) == VL.size() - 1) @@ -3584,7 +3794,7 @@ bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy, unsigned Cnt = 1; for (Value *Ptr : VL.drop_front()) { bool Found = any_of(Bases, [&](auto &Base) { - Optional<int> Diff = + std::optional<int> Diff = getPointersDiff(ElemTy, Base.first, ElemTy, Ptr, DL, SE, /*StrictCheck=*/true); if (!Diff) @@ -3636,7 +3846,7 @@ bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy, return true; } -Optional<BoUpSLP::OrdersType> +std::optional<BoUpSLP::OrdersType> BoUpSLP::findPartiallyOrderedLoads(const BoUpSLP::TreeEntry &TE) { assert(TE.State == TreeEntry::NeedToGather && "Expected gather node only."); Type *ScalarTy = TE.Scalars[0]->getType(); @@ -3646,27 +3856,176 @@ BoUpSLP::findPartiallyOrderedLoads(const BoUpSLP::TreeEntry &TE) { for (Value *V : TE.Scalars) { auto *L = dyn_cast<LoadInst>(V); if (!L || !L->isSimple()) - return None; + return std::nullopt; Ptrs.push_back(L->getPointerOperand()); } BoUpSLP::OrdersType Order; if (clusterSortPtrAccesses(Ptrs, ScalarTy, *DL, *SE, Order)) return Order; - return None; + return std::nullopt; +} + +/// Check if two insertelement instructions are from the same buildvector. +static bool areTwoInsertFromSameBuildVector( + InsertElementInst *VU, InsertElementInst *V, + function_ref<Value *(InsertElementInst *)> GetBaseOperand) { + // Instructions must be from the same basic blocks. + if (VU->getParent() != V->getParent()) + return false; + // Checks if 2 insertelements are from the same buildvector. + if (VU->getType() != V->getType()) + return false; + // Multiple used inserts are separate nodes. + if (!VU->hasOneUse() && !V->hasOneUse()) + return false; + auto *IE1 = VU; + auto *IE2 = V; + std::optional<unsigned> Idx1 = getInsertIndex(IE1); + std::optional<unsigned> Idx2 = getInsertIndex(IE2); + if (Idx1 == std::nullopt || Idx2 == std::nullopt) + return false; + // Go through the vector operand of insertelement instructions trying to find + // either VU as the original vector for IE2 or V as the original vector for + // IE1. + do { + if (IE2 == VU) + return VU->hasOneUse(); + if (IE1 == V) + return V->hasOneUse(); + if (IE1) { + if ((IE1 != VU && !IE1->hasOneUse()) || + getInsertIndex(IE1).value_or(*Idx2) == *Idx2) + IE1 = nullptr; + else + IE1 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE1)); + } + if (IE2) { + if ((IE2 != V && !IE2->hasOneUse()) || + getInsertIndex(IE2).value_or(*Idx1) == *Idx1) + IE2 = nullptr; + else + IE2 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE2)); + } + } while (IE1 || IE2); + return false; } -Optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &TE, +std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { // No need to reorder if need to shuffle reuses, still need to shuffle the // node. - if (!TE.ReuseShuffleIndices.empty()) - return None; + if (!TE.ReuseShuffleIndices.empty()) { + // Check if reuse shuffle indices can be improved by reordering. + // For this, check that reuse mask is "clustered", i.e. each scalar values + // is used once in each submask of size <number_of_scalars>. + // Example: 4 scalar values. + // ReuseShuffleIndices mask: 0, 1, 2, 3, 3, 2, 0, 1 - clustered. + // 0, 1, 2, 3, 3, 3, 1, 0 - not clustered, because + // element 3 is used twice in the second submask. 
+ unsigned Sz = TE.Scalars.size(); + if (!ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices, + Sz)) + return std::nullopt; + unsigned VF = TE.getVectorFactor(); + // Try build correct order for extractelement instructions. + SmallVector<int> ReusedMask(TE.ReuseShuffleIndices.begin(), + TE.ReuseShuffleIndices.end()); + if (TE.getOpcode() == Instruction::ExtractElement && !TE.isAltShuffle() && + all_of(TE.Scalars, [Sz](Value *V) { + std::optional<unsigned> Idx = getExtractIndex(cast<Instruction>(V)); + return Idx && *Idx < Sz; + })) { + SmallVector<int> ReorderMask(Sz, UndefMaskElem); + if (TE.ReorderIndices.empty()) + std::iota(ReorderMask.begin(), ReorderMask.end(), 0); + else + inversePermutation(TE.ReorderIndices, ReorderMask); + for (unsigned I = 0; I < VF; ++I) { + int &Idx = ReusedMask[I]; + if (Idx == UndefMaskElem) + continue; + Value *V = TE.Scalars[ReorderMask[Idx]]; + std::optional<unsigned> EI = getExtractIndex(cast<Instruction>(V)); + Idx = std::distance(ReorderMask.begin(), find(ReorderMask, *EI)); + } + } + // Build the order of the VF size, need to reorder reuses shuffles, they are + // always of VF size. + OrdersType ResOrder(VF); + std::iota(ResOrder.begin(), ResOrder.end(), 0); + auto *It = ResOrder.begin(); + for (unsigned K = 0; K < VF; K += Sz) { + OrdersType CurrentOrder(TE.ReorderIndices); + SmallVector<int> SubMask{ArrayRef(ReusedMask).slice(K, Sz)}; + if (SubMask.front() == UndefMaskElem) + std::iota(SubMask.begin(), SubMask.end(), 0); + reorderOrder(CurrentOrder, SubMask); + transform(CurrentOrder, It, [K](unsigned Pos) { return Pos + K; }); + std::advance(It, Sz); + } + if (all_of(enumerate(ResOrder), + [](const auto &Data) { return Data.index() == Data.value(); })) + return {}; // Use identity order. + return ResOrder; + } if (TE.State == TreeEntry::Vectorize && (isa<LoadInst, ExtractElementInst, ExtractValueInst>(TE.getMainOp()) || (TopToBottom && isa<StoreInst, InsertElementInst>(TE.getMainOp()))) && !TE.isAltShuffle()) return TE.ReorderIndices; + if (TE.State == TreeEntry::Vectorize && TE.getOpcode() == Instruction::PHI) { + auto PHICompare = [](llvm::Value *V1, llvm::Value *V2) { + if (!V1->hasOneUse() || !V2->hasOneUse()) + return false; + auto *FirstUserOfPhi1 = cast<Instruction>(*V1->user_begin()); + auto *FirstUserOfPhi2 = cast<Instruction>(*V2->user_begin()); + if (auto *IE1 = dyn_cast<InsertElementInst>(FirstUserOfPhi1)) + if (auto *IE2 = dyn_cast<InsertElementInst>(FirstUserOfPhi2)) { + if (!areTwoInsertFromSameBuildVector( + IE1, IE2, + [](InsertElementInst *II) { return II->getOperand(0); })) + return false; + std::optional<unsigned> Idx1 = getInsertIndex(IE1); + std::optional<unsigned> Idx2 = getInsertIndex(IE2); + if (Idx1 == std::nullopt || Idx2 == std::nullopt) + return false; + return *Idx1 < *Idx2; + } + if (auto *EE1 = dyn_cast<ExtractElementInst>(FirstUserOfPhi1)) + if (auto *EE2 = dyn_cast<ExtractElementInst>(FirstUserOfPhi2)) { + if (EE1->getOperand(0) != EE2->getOperand(0)) + return false; + std::optional<unsigned> Idx1 = getExtractIndex(EE1); + std::optional<unsigned> Idx2 = getExtractIndex(EE2); + if (Idx1 == std::nullopt || Idx2 == std::nullopt) + return false; + return *Idx1 < *Idx2; + } + return false; + }; + auto IsIdentityOrder = [](const OrdersType &Order) { + for (unsigned Idx : seq<unsigned>(0, Order.size())) + if (Idx != Order[Idx]) + return false; + return true; + }; + if (!TE.ReorderIndices.empty()) + return TE.ReorderIndices; + DenseMap<Value *, unsigned> PhiToId; + SmallVector<Value *, 4> Phis; + OrdersType 
ResOrder(TE.Scalars.size()); + for (unsigned Id = 0, Sz = TE.Scalars.size(); Id < Sz; ++Id) { + PhiToId[TE.Scalars[Id]] = Id; + Phis.push_back(TE.Scalars[Id]); + } + llvm::stable_sort(Phis, PHICompare); + for (unsigned Id = 0, Sz = Phis.size(); Id < Sz; ++Id) + ResOrder[Id] = PhiToId[Phis[Id]]; + if (IsIdentityOrder(ResOrder)) + return {}; + return ResOrder; + } if (TE.State == TreeEntry::NeedToGather) { // TODO: add analysis of other gather nodes with extractelement // instructions and other values/instructions, not only undefs. @@ -3694,13 +4053,55 @@ Optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &TE, return CurrentOrder; } } - if (Optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE)) + if (std::optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE)) return CurrentOrder; if (TE.Scalars.size() >= 4) - if (Optional<OrdersType> Order = findPartiallyOrderedLoads(TE)) + if (std::optional<OrdersType> Order = findPartiallyOrderedLoads(TE)) return Order; } - return None; + return std::nullopt; +} + +/// Checks if the given mask is a "clustered" mask with the same clusters of +/// size \p Sz, which are not identity submasks. +static bool isRepeatedNonIdentityClusteredMask(ArrayRef<int> Mask, + unsigned Sz) { + ArrayRef<int> FirstCluster = Mask.slice(0, Sz); + if (ShuffleVectorInst::isIdentityMask(FirstCluster)) + return false; + for (unsigned I = Sz, E = Mask.size(); I < E; I += Sz) { + ArrayRef<int> Cluster = Mask.slice(I, Sz); + if (Cluster != FirstCluster) + return false; + } + return true; +} + +void BoUpSLP::reorderNodeWithReuses(TreeEntry &TE, ArrayRef<int> Mask) const { + // Reorder reuses mask. + reorderReuses(TE.ReuseShuffleIndices, Mask); + const unsigned Sz = TE.Scalars.size(); + // For vectorized and non-clustered reused no need to do anything else. + if (TE.State != TreeEntry::NeedToGather || + !ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices, + Sz) || + !isRepeatedNonIdentityClusteredMask(TE.ReuseShuffleIndices, Sz)) + return; + SmallVector<int> NewMask; + inversePermutation(TE.ReorderIndices, NewMask); + addMask(NewMask, TE.ReuseShuffleIndices); + // Clear reorder since it is going to be applied to the new mask. + TE.ReorderIndices.clear(); + // Try to improve gathered nodes with clustered reuses, if possible. + ArrayRef<int> Slice = ArrayRef(NewMask).slice(0, Sz); + SmallVector<unsigned> NewOrder(Slice.begin(), Slice.end()); + inversePermutation(NewOrder, NewMask); + reorderScalars(TE.Scalars, NewMask); + // Fill the reuses mask with the identity submasks. + for (auto *It = TE.ReuseShuffleIndices.begin(), + *End = TE.ReuseShuffleIndices.end(); + It != End; std::advance(It, Sz)) + std::iota(It, std::next(It, Sz), 0); } void BoUpSLP::reorderTopToBottom() { @@ -3710,6 +4111,9 @@ void BoUpSLP::reorderTopToBottom() { // their ordering. DenseMap<const TreeEntry *, OrdersType> GathersToOrders; + // Phi nodes can have preferred ordering based on their result users + DenseMap<const TreeEntry *, OrdersType> PhisToOrders; + // AltShuffles can also have a preferred ordering that leads to fewer // instructions, e.g., the addsub instruction in x86. DenseMap<const TreeEntry *, OrdersType> AltShufflesToOrders; @@ -3724,13 +4128,13 @@ void BoUpSLP::reorderTopToBottom() { // extracts. 
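The clustered-reuse check introduced above (isRepeatedNonIdentityClusteredMask) is easier to see on plain data. Below is a minimal standalone sketch of the same test using std::vector instead of ArrayRef; the function name and signature are illustrative only, and undef (-1) lanes get no special treatment here, so this is a simplification rather than the in-tree helper.

#include <algorithm>
#include <cstddef>
#include <vector>

// A mask such as <1,0,3,2, 1,0,3,2> with cluster size 4 qualifies: every
// cluster repeats the first one and the first cluster is not the identity.
// <0,1,2,3, 0,1,2,3> is rejected because its clusters are identity submasks.
bool isRepeatedNonIdentityCluster(const std::vector<int> &Mask, unsigned Sz) {
  if (Sz == 0 || Mask.size() % Sz != 0)
    return false;
  bool Identity = true;
  for (unsigned I = 0; I < Sz; ++I)
    Identity &= (Mask[I] == static_cast<int>(I));
  if (Identity)
    return false;
  for (std::size_t I = Sz, E = Mask.size(); I < E; I += Sz)
    if (!std::equal(Mask.begin() + I, Mask.begin() + I + Sz, Mask.begin()))
      return false;
  return true;
}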
for_each(VectorizableTree, [this, &TTIRef, &VFToOrderedEntries, &GathersToOrders, &ExternalUserReorderMap, - &AltShufflesToOrders]( + &AltShufflesToOrders, &PhisToOrders]( const std::unique_ptr<TreeEntry> &TE) { // Look for external users that will probably be vectorized. SmallVector<OrdersType, 1> ExternalUserReorderIndices = findExternalStoreUsersReorderIndices(TE.get()); if (!ExternalUserReorderIndices.empty()) { - VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); + VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get()); ExternalUserReorderMap.try_emplace(TE.get(), std::move(ExternalUserReorderIndices)); } @@ -3750,13 +4154,13 @@ void BoUpSLP::reorderTopToBottom() { OpcodeMask.set(Lane); // If this pattern is supported by the target then we consider the order. if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) { - VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); + VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get()); AltShufflesToOrders.try_emplace(TE.get(), OrdersType()); } // TODO: Check the reverse order too. } - if (Optional<OrdersType> CurrentOrder = + if (std::optional<OrdersType> CurrentOrder = getReorderingData(*TE, /*TopToBottom=*/true)) { // Do not include ordering for nodes used in the alt opcode vectorization, // better to reorder them during bottom-to-top stage. If follow the order @@ -3778,14 +4182,17 @@ void BoUpSLP::reorderTopToBottom() { UserTE = UserTE->UserTreeIndices.back().UserTE; ++Cnt; } - VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); - if (TE->State != TreeEntry::Vectorize) + VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get()); + if (TE->State != TreeEntry::Vectorize || !TE->ReuseShuffleIndices.empty()) GathersToOrders.try_emplace(TE.get(), *CurrentOrder); + if (TE->State == TreeEntry::Vectorize && + TE->getOpcode() == Instruction::PHI) + PhisToOrders.try_emplace(TE.get(), *CurrentOrder); } }); // Reorder the graph nodes according to their vectorization factor. - for (unsigned VF = VectorizableTree.front()->Scalars.size(); VF > 1; + for (unsigned VF = VectorizableTree.front()->getVectorFactor(); VF > 1; VF /= 2) { auto It = VFToOrderedEntries.find(VF); if (It == VFToOrderedEntries.end()) @@ -3803,12 +4210,13 @@ void BoUpSLP::reorderTopToBottom() { for (const TreeEntry *OpTE : OrderedEntries) { // No need to reorder this nodes, still need to extend and to use shuffle, // just need to merge reordering shuffle and the reuse shuffle. - if (!OpTE->ReuseShuffleIndices.empty()) + if (!OpTE->ReuseShuffleIndices.empty() && !GathersToOrders.count(OpTE)) continue; // Count number of orders uses. - const auto &Order = [OpTE, &GathersToOrders, - &AltShufflesToOrders]() -> const OrdersType & { - if (OpTE->State == TreeEntry::NeedToGather) { + const auto &Order = [OpTE, &GathersToOrders, &AltShufflesToOrders, + &PhisToOrders]() -> const OrdersType & { + if (OpTE->State == TreeEntry::NeedToGather || + !OpTE->ReuseShuffleIndices.empty()) { auto It = GathersToOrders.find(OpTE); if (It != GathersToOrders.end()) return It->second; @@ -3818,14 +4226,28 @@ void BoUpSLP::reorderTopToBottom() { if (It != AltShufflesToOrders.end()) return It->second; } + if (OpTE->State == TreeEntry::Vectorize && + OpTE->getOpcode() == Instruction::PHI) { + auto It = PhisToOrders.find(OpTE); + if (It != PhisToOrders.end()) + return It->second; + } return OpTE->ReorderIndices; }(); // First consider the order of the external scalar users. 
auto It = ExternalUserReorderMap.find(OpTE); if (It != ExternalUserReorderMap.end()) { const auto &ExternalUserReorderIndices = It->second; - for (const OrdersType &ExtOrder : ExternalUserReorderIndices) - ++OrdersUses.insert(std::make_pair(ExtOrder, 0)).first->second; + // If the OpTE vector factor != number of scalars - use natural order, + // it is an attempt to reorder node with reused scalars but with + // external uses. + if (OpTE->getVectorFactor() != OpTE->Scalars.size()) { + OrdersUses.insert(std::make_pair(OrdersType(), 0)).first->second += + ExternalUserReorderIndices.size(); + } else { + for (const OrdersType &ExtOrder : ExternalUserReorderIndices) + ++OrdersUses.insert(std::make_pair(ExtOrder, 0)).first->second; + } // No other useful reorder data in this entry. if (Order.empty()) continue; @@ -3885,7 +4307,7 @@ void BoUpSLP::reorderTopToBottom() { "All users must be of VF size."); // Update ordering of the operands with the smaller VF than the given // one. - reorderReuses(TE->ReuseShuffleIndices, Mask); + reorderNodeWithReuses(*TE, Mask); } continue; } @@ -3982,10 +4404,10 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { const std::unique_ptr<TreeEntry> &TE) { if (TE->State != TreeEntry::Vectorize) NonVectorized.push_back(TE.get()); - if (Optional<OrdersType> CurrentOrder = + if (std::optional<OrdersType> CurrentOrder = getReorderingData(*TE, /*TopToBottom=*/false)) { OrderedEntries.insert(TE.get()); - if (TE->State != TreeEntry::Vectorize) + if (TE->State != TreeEntry::Vectorize || !TE->ReuseShuffleIndices.empty()) GathersToOrders.try_emplace(TE.get(), *CurrentOrder); } }); @@ -4057,10 +4479,11 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { TreeEntry *OpTE = Op.second; if (!VisitedOps.insert(OpTE).second) continue; - if (!OpTE->ReuseShuffleIndices.empty()) + if (!OpTE->ReuseShuffleIndices.empty() && !GathersToOrders.count(OpTE)) continue; const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & { - if (OpTE->State == TreeEntry::NeedToGather) + if (OpTE->State == TreeEntry::NeedToGather || + !OpTE->ReuseShuffleIndices.empty()) return GathersToOrders.find(OpTE)->second; return OpTE->ReorderIndices; }(); @@ -4166,8 +4589,7 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { if (!VisitedOps.insert(TE).second) continue; if (TE->ReuseShuffleIndices.size() == BestOrder.size()) { - // Just reorder reuses indices. - reorderReuses(TE->ReuseShuffleIndices, Mask); + reorderNodeWithReuses(*TE, Mask); continue; } // Gathers are processed separately. @@ -4322,7 +4744,7 @@ BoUpSLP::collectUserStores(const BoUpSLP::TreeEntry *TE) const { return PtrToStoresMap; } -bool BoUpSLP::CanFormVector(const SmallVector<StoreInst *, 4> &StoresVec, +bool BoUpSLP::canFormVector(const SmallVector<StoreInst *, 4> &StoresVec, OrdersType &ReorderIndices) const { // We check whether the stores in StoreVec can form a vector by sorting them // and checking whether they are consecutive. @@ -4336,7 +4758,7 @@ bool BoUpSLP::CanFormVector(const SmallVector<StoreInst *, 4> &StoresVec, Value *S0Ptr = S0->getPointerOperand(); for (unsigned Idx : seq<unsigned>(1, StoresVec.size())) { StoreInst *SI = StoresVec[Idx]; - Optional<int> Diff = + std::optional<int> Diff = getPointersDiff(S0Ty, S0Ptr, SI->getValueOperand()->getType(), SI->getPointerOperand(), *DL, *SE, /*StrictCheck=*/true); @@ -4416,7 +4838,7 @@ BoUpSLP::findExternalStoreUsersReorderIndices(TreeEntry *TE) const { // If the stores are not consecutive then abandon this StoresVec. 
OrdersType ReorderIndices; - if (!CanFormVector(StoresVec, ReorderIndices)) + if (!canFormVector(StoresVec, ReorderIndices)) continue; // We now know that the scalars in StoresVec can form a vector instruction, @@ -4472,24 +4894,24 @@ static std::pair<size_t, size_t> generateKeySubkey( hash_code SubKey = hash_value(0); // Sort the loads by the distance between the pointers. if (auto *LI = dyn_cast<LoadInst>(V)) { - Key = hash_combine(hash_value(Instruction::Load), Key); + Key = hash_combine(LI->getType(), hash_value(Instruction::Load), Key); if (LI->isSimple()) SubKey = hash_value(LoadsSubkeyGenerator(Key, LI)); else - SubKey = hash_value(LI); + Key = SubKey = hash_value(LI); } else if (isVectorLikeInstWithConstOps(V)) { // Sort extracts by the vector operands. if (isa<ExtractElementInst, UndefValue>(V)) Key = hash_value(Value::UndefValueVal + 1); if (auto *EI = dyn_cast<ExtractElementInst>(V)) { - if (!isUndefVector(EI->getVectorOperand()) && + if (!isUndefVector(EI->getVectorOperand()).all() && !isa<UndefValue>(EI->getIndexOperand())) SubKey = hash_value(EI->getVectorOperand()); } } else if (auto *I = dyn_cast<Instruction>(V)) { // Sort other instructions just by the opcodes except for CMPInst. // For CMP also sort by the predicate kind. - if ((isa<BinaryOperator>(I) || isa<CastInst>(I)) && + if ((isa<BinaryOperator, CastInst>(I)) && isValidForAlternation(I->getOpcode())) { if (AllowAlternate) Key = hash_value(isa<BinaryOperator>(I) ? 1 : 0); @@ -4504,7 +4926,7 @@ static std::pair<size_t, size_t> generateKeySubkey( if (isa<CastInst>(I)) { std::pair<size_t, size_t> OpVals = generateKeySubkey(I->getOperand(0), TLI, LoadsSubkeyGenerator, - /*=AllowAlternate*/ true); + /*AllowAlternate=*/true); Key = hash_combine(OpVals.first, Key); SubKey = hash_combine(OpVals.first, SubKey); } @@ -4547,6 +4969,13 @@ static std::pair<size_t, size_t> generateKeySubkey( return std::make_pair(Key, SubKey); } +/// Checks if the specified instruction \p I is an alternate operation for +/// the given \p MainOp and \p AltOp instructions. +static bool isAlternateInstruction(const Instruction *I, + const Instruction *MainOp, + const Instruction *AltOp, + const TargetLibraryInfo &TLI); + void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, const EdgeInfo &UserTreeIdx) { assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); @@ -4557,7 +4986,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, &UserTreeIdx, this](const InstructionsState &S) { // Check that every instruction appears once in this bundle. 
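The "every instruction appears once" step below builds ReuseShuffleIndicies by remembering, for each lane, which unique scalar it maps to. Here is a small self-contained sketch of that bookkeeping, with hypothetical names and ints standing in for the scalars, assuming C++17 for try_emplace and structured bindings; it is not the in-tree code, only the shape of it.

#include <unordered_map>
#include <vector>

struct DeduplicatedBundle {
  std::vector<int> UniqueValues; // first occurrence of each scalar
  std::vector<int> ReuseIndices; // lane -> position in UniqueValues
};

DeduplicatedBundle deduplicate(const std::vector<int> &Bundle) {
  DeduplicatedBundle Res;
  std::unordered_map<int, int> FirstPos;
  for (int V : Bundle) {
    auto [It, Inserted] =
        FirstPos.try_emplace(V, static_cast<int>(Res.UniqueValues.size()));
    if (Inserted)
      Res.UniqueValues.push_back(V);
    Res.ReuseIndices.push_back(It->second);
  }
  return Res;
}

// E.g. {7, 3, 7, 3} yields UniqueValues {7, 3} and ReuseIndices {0, 1, 0, 1}:
// the node is vectorized from the unique values and the reuse indices act as
// the shuffle that rebuilds the original lane order.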
- DenseMap<Value *, unsigned> UniquePositions; + DenseMap<Value *, unsigned> UniquePositions(VL.size()); for (Value *V : VL) { if (isConstant(V)) { ReuseShuffleIndicies.emplace_back( @@ -4583,7 +5012,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, })) || !llvm::isPowerOf2_32(NumUniqueScalarValues)) { LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return false; } VL = UniqueValues; @@ -4591,7 +5020,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return true; }; - InstructionsState S = getSameOpcode(VL); + InstructionsState S = getSameOpcode(VL, *TLI); // Gather if we hit the RecursionMaxDepth, unless this is a load (or z/sext of // a load), in which case peek through to include it in the tree, without @@ -4607,7 +5036,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, })))) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); if (TryToFindDuplicates(S)) - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -4618,7 +5047,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, cast<ExtractElementInst>(S.OpValue)->getVectorOperandType())) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to scalable vector type.\n"); if (TryToFindDuplicates(S)) - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -4627,14 +5056,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (S.OpValue->getType()->isVectorTy() && !isa<InsertElementInst>(S.OpValue)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return; } if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue)) if (SI->getValueOperand()->getType()->isVectorTy()) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return; } @@ -4696,10 +5125,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, }; SmallVector<unsigned> SortedIndices; BasicBlock *BB = nullptr; + bool IsScatterVectorizeUserTE = + UserTreeIdx.UserTE && + UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize; bool AreAllSameInsts = (S.getOpcode() && allSameBlock(VL)) || - (S.OpValue->getType()->isPointerTy() && UserTreeIdx.UserTE && - UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize && + (S.OpValue->getType()->isPointerTy() && IsScatterVectorizeUserTE && VL.size() > 2 && all_of(VL, [&BB](Value *V) { @@ -4713,14 +5144,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, BB && sortPtrAccesses(VL, UserTreeIdx.UserTE->getMainOp()->getType(), *DL, *SE, SortedIndices)); - if (allConstant(VL) || isSplat(VL) || !AreAllSameInsts || + if (!AreAllSameInsts || allConstant(VL) || isSplat(VL) || (isa<InsertElementInst, ExtractValueInst, ExtractElementInst>( S.OpValue) && !all_of(VL, isVectorLikeInstWithConstOps)) || NotProfitableForVectorization(VL)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O, small shuffle. 
\n"); if (TryToFindDuplicates(S)) - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -4734,7 +5165,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (EphValues.count(V)) { LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V << ") is ephemeral.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return; } } @@ -4746,7 +5177,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!E->isSame(VL)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); if (TryToFindDuplicates(S)) - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -4760,14 +5191,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Check that none of the instructions in the bundle are already in the tree. for (Value *V : VL) { - auto *I = dyn_cast<Instruction>(V); - if (!I) + if (!IsScatterVectorizeUserTE && !isa<Instruction>(V)) continue; - if (getTreeEntry(I)) { + if (getTreeEntry(V)) { LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V << ") is already in tree.\n"); if (TryToFindDuplicates(S)) - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -4779,7 +5209,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (UserIgnoreList && UserIgnoreList->contains(V)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); if (TryToFindDuplicates(S)) - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -4788,9 +5218,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Special processing for sorted pointers for ScatterVectorize node with // constant indeces only. - if (AreAllSameInsts && !(S.getOpcode() && allSameBlock(VL)) && - UserTreeIdx.UserTE && - UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize) { + if (AreAllSameInsts && UserTreeIdx.UserTE && + UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize && + !(S.getOpcode() && allSameBlock(VL))) { assert(S.OpValue->getType()->isPointerTy() && count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >= 2 && @@ -4798,7 +5228,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Reset S to make it GetElementPtr kind of node. const auto *It = find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }); assert(It != VL.end() && "Expected at least one GEP."); - S = getSameOpcode(*It); + S = getSameOpcode(*It, *TLI); } // Check that all of the users of the scalars that we want to vectorize are @@ -4810,7 +5240,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Don't go into unreachable blocks. They may contain instructions with // dependency cycles which confuse the final scheduling. LLVM_DEBUG(dbgs() << "SLP: bundle in unreachable block.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return; } @@ -4819,7 +5249,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // place to insert a shuffle if we need to, so just avoid that issue. 
if (isa<CatchSwitchInst>(BB->getTerminator())) { LLVM_DEBUG(dbgs() << "SLP: bundle in catchswitch block.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return; } @@ -4833,7 +5263,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, BlockScheduling &BS = *BSRef; - Optional<ScheduleData *> Bundle = BS.tryScheduleBundle(VL, this, S); + std::optional<ScheduleData *> Bundle = BS.tryScheduleBundle(VL, this, S); #ifdef EXPENSIVE_CHECKS // Make sure we didn't break any internal invariants BS.verify(); @@ -4843,7 +5273,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, assert((!BS.getScheduleData(VL0) || !BS.getScheduleData(VL0)->isPartOfBundle()) && "tryScheduleBundle should cancelScheduling on failure"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -4863,7 +5293,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (terminator use).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -4930,7 +5360,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); BS.cancelScheduling(VL, VL0); return; @@ -4943,7 +5373,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, ValueSet SourceVectors; for (Value *V : VL) { SourceVectors.insert(cast<Instruction>(V)->getOperand(0)); - assert(getInsertIndex(V) != None && "Non-constant or undef index?"); + assert(getInsertIndex(V) != std::nullopt && + "Non-constant or undef index?"); } if (count_if(VL, [&SourceVectors](Value *V) { @@ -4952,7 +5383,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Found 2nd source vector - cancel. LLVM_DEBUG(dbgs() << "SLP: Gather of insertelement vectors with " "different source vectors.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); BS.cancelScheduling(VL, VL0); return; } @@ -4978,7 +5409,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (IsIdentity) CurrentOrder.clear(); TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - None, CurrentOrder); + std::nullopt, CurrentOrder); LLVM_DEBUG(dbgs() << "SLP: added inserts bundle.\n"); constexpr int NumOps = 2; @@ -5002,8 +5433,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, SmallVector<Value *> PointerOps; OrdersType CurrentOrder; TreeEntry *TE = nullptr; - switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, CurrentOrder, - PointerOps)) { + switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI, + CurrentOrder, PointerOps)) { case LoadsState::Vectorize: if (CurrentOrder.empty()) { // Original loads are consecutive and does not require reordering. 
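The load and store bundles in the surrounding hunks stand or fall on the same consecutiveness test: compute pointer distances (getPointersDiff in the patch), sort the accesses, and require the offsets to advance by exactly one element per lane. The following is a stand-alone sketch of that test on raw byte offsets; the helper name and the byte-offset representation are assumptions made for the example, while the in-tree helper computes the distance between two pointer values instead.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Accesses are consecutive when, after sorting, their byte offsets are
// Base, Base + EltSize, ..., Base + (N-1) * EltSize (no gaps, no duplicates).
bool areConsecutive(std::vector<int64_t> ByteOffsets, uint64_t EltSizeInBytes) {
  if (ByteOffsets.empty() || EltSizeInBytes == 0)
    return false;
  std::sort(ByteOffsets.begin(), ByteOffsets.end());
  for (std::size_t I = 0, E = ByteOffsets.size(); I != E; ++I)
    if (ByteOffsets[I] !=
        ByteOffsets.front() + static_cast<int64_t>(I * EltSizeInBytes))
      return false;
  return true;
}

// With 4-byte elements, offsets {8, 0, 12, 4} sort to {0, 4, 8, 12} and pass;
// {0, 4, 4, 12} fails, matching the gather fallback in the surrounding hunks.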
@@ -5029,7 +5460,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, break; case LoadsState::Gather: BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); #ifndef NDEBUG Type *ScalarTy = VL0->getType(); @@ -5064,7 +5495,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Type *Ty = cast<Instruction>(V)->getOperand(0)->getType(); if (Ty != SrcTy || !isValidElementType(Ty)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering casts with different src types.\n"); @@ -5097,7 +5528,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if ((Cmp->getPredicate() != P0 && Cmp->getPredicate() != SwapP0) || Cmp->getOperand(0)->getType() != ComparedTy) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n"); @@ -5114,7 +5545,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Commutative predicate - collect + sort operands of the instructions // so that each side is more likely to have the same opcode. assert(P0 == SwapP0 && "Commutative Predicate mismatch"); - reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this); + reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, *this); } else { // Collect operands - commute if it uses the swapped predicate. for (Value *V : VL) { @@ -5161,7 +5592,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // have the same opcode. if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) { ValueList Left, Right; - reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this); + reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, *this); TE->setOperand(0, Left); TE->setOperand(1, Right); buildTree_rec(Left, Depth + 1, {TE, 0}); @@ -5189,7 +5620,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (I->getNumOperands() != 2) { LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -5207,15 +5638,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } } - bool IsScatterUser = - UserTreeIdx.UserTE && - UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize; // We don't combine GEPs with non-constant indexes. 
Type *Ty1 = VL0->getOperand(1)->getType(); for (Value *V : VL) { @@ -5223,16 +5651,16 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!I) continue; auto *Op = I->getOperand(1); - if ((!IsScatterUser && !isa<ConstantInt>(Op)) || + if ((!IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) || (Op->getType() != Ty1 && - ((IsScatterUser && !isa<ConstantInt>(Op)) || + ((IsScatterVectorizeUserTE && !isa<ConstantInt>(Op)) || Op->getType()->getScalarSizeInBits() > DL->getIndexSizeInBits( V->getType()->getPointerAddressSpace())))) { LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); return; } @@ -5300,7 +5728,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (DL->getTypeSizeInBits(ScalarTy) != DL->getTypeAllocSizeInBits(ScalarTy)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering stores of non-packed type.\n"); return; @@ -5315,7 +5743,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, auto *SI = cast<StoreInst>(V); if (!SI->isSimple()) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple stores.\n"); return; @@ -5338,7 +5766,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Ptr0 = PointerOps[CurrentOrder.front()]; PtrN = PointerOps[CurrentOrder.back()]; } - Optional<int> Dist = + std::optional<int> Dist = getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE); // Check that the sorted pointer operands are consecutive. 
if (static_cast<unsigned>(*Dist) == VL.size() - 1) { @@ -5363,7 +5791,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); return; @@ -5381,7 +5809,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!VecFunc && !isTriviallyVectorizable(ID)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); return; @@ -5400,7 +5828,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, VecFunc != VFDatabase(*CI2).getVectorizedFunction(Shape)) || !CI->hasIdenticalOperandBundleSchema(*CI2)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *V << "\n"); @@ -5413,7 +5841,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Value *A1J = CI2->getArgOperand(j); if (ScalarArgs[j] != A1J) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI << " argument " << ScalarArgs[j] << "!=" << A1J @@ -5428,7 +5856,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, CI->op_begin() + CI->getBundleOperandsEndIndex(), CI2->op_begin() + CI2->getBundleOperandsStartIndex())) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" << *CI << "!=" << *V << '\n'); @@ -5459,7 +5887,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // then do not vectorize this instruction. if (!S.isAltShuffle()) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); return; @@ -5475,31 +5903,28 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!CI || all_of(VL, [](Value *V) { return cast<CmpInst>(V)->isCommutative(); })) { - reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this); + reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, + *this); } else { - CmpInst::Predicate P0 = CI->getPredicate(); - CmpInst::Predicate AltP0 = cast<CmpInst>(S.AltOp)->getPredicate(); - assert(P0 != AltP0 && + auto *MainCI = cast<CmpInst>(S.MainOp); + auto *AltCI = cast<CmpInst>(S.AltOp); + CmpInst::Predicate MainP = MainCI->getPredicate(); + CmpInst::Predicate AltP = AltCI->getPredicate(); + assert(MainP != AltP && "Expected different main/alternate predicates."); - CmpInst::Predicate AltP0Swapped = CmpInst::getSwappedPredicate(AltP0); - Value *BaseOp0 = VL0->getOperand(0); - Value *BaseOp1 = VL0->getOperand(1); // Collect operands - commute if it uses the swapped predicate or // alternate operation. 
for (Value *V : VL) { auto *Cmp = cast<CmpInst>(V); Value *LHS = Cmp->getOperand(0); Value *RHS = Cmp->getOperand(1); - CmpInst::Predicate CurrentPred = Cmp->getPredicate(); - if (P0 == AltP0Swapped) { - if (CI != Cmp && S.AltOp != Cmp && - ((P0 == CurrentPred && - !areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)) || - (AltP0 == CurrentPred && - areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)))) + + if (isAlternateInstruction(Cmp, MainCI, AltCI, *TLI)) { + if (AltP == CmpInst::getSwappedPredicate(Cmp->getPredicate())) + std::swap(LHS, RHS); + } else { + if (MainP == CmpInst::getSwappedPredicate(Cmp->getPredicate())) std::swap(LHS, RHS); - } else if (P0 != CurrentPred && AltP0 != CurrentPred) { - std::swap(LHS, RHS); } Left.push_back(LHS); Right.push_back(RHS); @@ -5525,7 +5950,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } default: BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n"); return; @@ -5536,8 +5961,7 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { unsigned N = 1; Type *EltTy = T; - while (isa<StructType>(EltTy) || isa<ArrayType>(EltTy) || - isa<VectorType>(EltTy)) { + while (isa<StructType, ArrayType, VectorType>(EltTy)) { if (auto *ST = dyn_cast<StructType>(EltTy)) { // Check that struct is homogeneous. for (const auto *Ty : ST->elements()) @@ -5619,7 +6043,7 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, if (auto *EE = dyn_cast<ExtractElementInst>(Inst)) if (isa<UndefValue>(EE->getIndexOperand())) continue; - Optional<unsigned> Idx = getExtractIndex(Inst); + std::optional<unsigned> Idx = getExtractIndex(Inst); if (!Idx) break; const unsigned ExtIdx = *Idx; @@ -5787,32 +6211,388 @@ buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, } } -/// Checks if the specified instruction \p I is an alternate operation for the -/// given \p MainOp and \p AltOp instructions. 
static bool isAlternateInstruction(const Instruction *I, const Instruction *MainOp, - const Instruction *AltOp) { - if (auto *CI0 = dyn_cast<CmpInst>(MainOp)) { - auto *AltCI0 = cast<CmpInst>(AltOp); + const Instruction *AltOp, + const TargetLibraryInfo &TLI) { + if (auto *MainCI = dyn_cast<CmpInst>(MainOp)) { + auto *AltCI = cast<CmpInst>(AltOp); + CmpInst::Predicate MainP = MainCI->getPredicate(); + CmpInst::Predicate AltP = AltCI->getPredicate(); + assert(MainP != AltP && "Expected different main/alternate predicates."); auto *CI = cast<CmpInst>(I); - CmpInst::Predicate P0 = CI0->getPredicate(); - CmpInst::Predicate AltP0 = AltCI0->getPredicate(); - assert(P0 != AltP0 && "Expected different main/alternate predicates."); - CmpInst::Predicate AltP0Swapped = CmpInst::getSwappedPredicate(AltP0); - CmpInst::Predicate CurrentPred = CI->getPredicate(); - if (P0 == AltP0Swapped) - return I == AltCI0 || - (I != MainOp && - !areCompatibleCmpOps(CI0->getOperand(0), CI0->getOperand(1), - CI->getOperand(0), CI->getOperand(1))); - return AltP0 == CurrentPred || AltP0Swapped == CurrentPred; + if (isCmpSameOrSwapped(MainCI, CI, TLI)) + return false; + if (isCmpSameOrSwapped(AltCI, CI, TLI)) + return true; + CmpInst::Predicate P = CI->getPredicate(); + CmpInst::Predicate SwappedP = CmpInst::getSwappedPredicate(P); + + assert((MainP == P || AltP == P || MainP == SwappedP || AltP == SwappedP) && + "CmpInst expected to match either main or alternate predicate or " + "their swap."); + (void)AltP; + return MainP != P && MainP != SwappedP; } return I->getOpcode() == AltOp->getOpcode(); } +TTI::OperandValueInfo BoUpSLP::getOperandInfo(ArrayRef<Value *> VL, + unsigned OpIdx) { + assert(!VL.empty()); + const auto *I0 = cast<Instruction>(*find_if(VL, Instruction::classof)); + const auto *Op0 = I0->getOperand(OpIdx); + + const bool IsConstant = all_of(VL, [&](Value *V) { + // TODO: We should allow undef elements here + const auto *I = dyn_cast<Instruction>(V); + if (!I) + return true; + auto *Op = I->getOperand(OpIdx); + return isConstant(Op) && !isa<UndefValue>(Op); + }); + const bool IsUniform = all_of(VL, [&](Value *V) { + // TODO: We should allow undef elements here + const auto *I = dyn_cast<Instruction>(V); + if (!I) + return false; + return I->getOperand(OpIdx) == Op0; + }); + const bool IsPowerOfTwo = all_of(VL, [&](Value *V) { + // TODO: We should allow undef elements here + const auto *I = dyn_cast<Instruction>(V); + if (!I) { + assert((isa<UndefValue>(V) || + I0->getOpcode() == Instruction::GetElementPtr) && + "Expected undef or GEP."); + return true; + } + auto *Op = I->getOperand(OpIdx); + if (auto *CI = dyn_cast<ConstantInt>(Op)) + return CI->getValue().isPowerOf2(); + return false; + }); + const bool IsNegatedPowerOfTwo = all_of(VL, [&](Value *V) { + // TODO: We should allow undef elements here + const auto *I = dyn_cast<Instruction>(V); + if (!I) { + assert((isa<UndefValue>(V) || + I0->getOpcode() == Instruction::GetElementPtr) && + "Expected undef or GEP."); + return true; + } + const auto *Op = I->getOperand(OpIdx); + if (auto *CI = dyn_cast<ConstantInt>(Op)) + return CI->getValue().isNegatedPowerOf2(); + return false; + }); + + TTI::OperandValueKind VK = TTI::OK_AnyValue; + if (IsConstant && IsUniform) + VK = TTI::OK_UniformConstantValue; + else if (IsConstant) + VK = TTI::OK_NonUniformConstantValue; + else if (IsUniform) + VK = TTI::OK_UniformValue; + + TTI::OperandValueProperties VP = TTI::OP_None; + VP = IsPowerOfTwo ? TTI::OP_PowerOf2 : VP; + VP = IsNegatedPowerOfTwo ? 
TTI::OP_NegatedPowerOf2 : VP; + + return {VK, VP}; +} + +namespace { +/// The base class for shuffle instruction emission and shuffle cost estimation. +class BaseShuffleAnalysis { +protected: + /// Checks if the mask is an identity mask. + /// \param IsStrict if is true the function returns false if mask size does + /// not match vector size. + static bool isIdentityMask(ArrayRef<int> Mask, const FixedVectorType *VecTy, + bool IsStrict) { + int Limit = Mask.size(); + int VF = VecTy->getNumElements(); + return (VF == Limit || !IsStrict) && + all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) && + ShuffleVectorInst::isIdentityMask(Mask); + } + + /// Tries to combine 2 different masks into single one. + /// \param LocalVF Vector length of the permuted input vector. \p Mask may + /// change the size of the vector, \p LocalVF is the original size of the + /// shuffled vector. + static void combineMasks(unsigned LocalVF, SmallVectorImpl<int> &Mask, + ArrayRef<int> ExtMask) { + unsigned VF = Mask.size(); + SmallVector<int> NewMask(ExtMask.size(), UndefMaskElem); + for (int I = 0, Sz = ExtMask.size(); I < Sz; ++I) { + if (ExtMask[I] == UndefMaskElem) + continue; + int MaskedIdx = Mask[ExtMask[I] % VF]; + NewMask[I] = + MaskedIdx == UndefMaskElem ? UndefMaskElem : MaskedIdx % LocalVF; + } + Mask.swap(NewMask); + } + + /// Looks through shuffles trying to reduce final number of shuffles in the + /// code. The function looks through the previously emitted shuffle + /// instructions and properly mark indices in mask as undef. + /// For example, given the code + /// \code + /// %s1 = shufflevector <2 x ty> %0, poison, <1, 0> + /// %s2 = shufflevector <2 x ty> %1, poison, <1, 0> + /// \endcode + /// and if need to emit shuffle of %s1 and %s2 with mask <1, 0, 3, 2>, it will + /// look through %s1 and %s2 and select vectors %0 and %1 with mask + /// <0, 1, 2, 3> for the shuffle. + /// If 2 operands are of different size, the smallest one will be resized and + /// the mask recalculated properly. + /// For example, given the code + /// \code + /// %s1 = shufflevector <2 x ty> %0, poison, <1, 0, 1, 0> + /// %s2 = shufflevector <2 x ty> %1, poison, <1, 0, 1, 0> + /// \endcode + /// and if need to emit shuffle of %s1 and %s2 with mask <1, 0, 5, 4>, it will + /// look through %s1 and %s2 and select vectors %0 and %1 with mask + /// <0, 1, 2, 3> for the shuffle. + /// So, it tries to transform permutations to simple vector merge, if + /// possible. + /// \param V The input vector which must be shuffled using the given \p Mask. + /// If the better candidate is found, \p V is set to this best candidate + /// vector. + /// \param Mask The input mask for the shuffle. If the best candidate is found + /// during looking-through-shuffles attempt, it is updated accordingly. + /// \param SinglePermute true if the shuffle operation is originally a + /// single-value-permutation. In this case the look-through-shuffles procedure + /// may look for resizing shuffles as the best candidates. + /// \return true if the shuffle results in the non-resizing identity shuffle + /// (and thus can be ignored), false - otherwise. + static bool peekThroughShuffles(Value *&V, SmallVectorImpl<int> &Mask, + bool SinglePermute) { + Value *Op = V; + ShuffleVectorInst *IdentityOp = nullptr; + SmallVector<int> IdentityMask; + while (auto *SV = dyn_cast<ShuffleVectorInst>(Op)) { + // Exit if not a fixed vector type or changing size shuffle. 
+ auto *SVTy = dyn_cast<FixedVectorType>(SV->getType()); + if (!SVTy) + break; + // Remember the identity or broadcast mask, if it is not a resizing + // shuffle. If no better candidates are found, this Op and Mask will be + // used in the final shuffle. + if (isIdentityMask(Mask, SVTy, /*IsStrict=*/false)) { + if (!IdentityOp || !SinglePermute || + (isIdentityMask(Mask, SVTy, /*IsStrict=*/true) && + !ShuffleVectorInst::isZeroEltSplatMask(IdentityMask))) { + IdentityOp = SV; + // Store current mask in the IdentityMask so later we did not lost + // this info if IdentityOp is selected as the best candidate for the + // permutation. + IdentityMask.assign(Mask); + } + } + // Remember the broadcast mask. If no better candidates are found, this Op + // and Mask will be used in the final shuffle. + // Zero splat can be used as identity too, since it might be used with + // mask <0, 1, 2, ...>, i.e. identity mask without extra reshuffling. + // E.g. if need to shuffle the vector with the mask <3, 1, 2, 0>, which is + // expensive, the analysis founds out, that the source vector is just a + // broadcast, this original mask can be transformed to identity mask <0, + // 1, 2, 3>. + // \code + // %0 = shuffle %v, poison, zeroinitalizer + // %res = shuffle %0, poison, <3, 1, 2, 0> + // \endcode + // may be transformed to + // \code + // %0 = shuffle %v, poison, zeroinitalizer + // %res = shuffle %0, poison, <0, 1, 2, 3> + // \endcode + if (SV->isZeroEltSplat()) { + IdentityOp = SV; + IdentityMask.assign(Mask); + } + int LocalVF = Mask.size(); + if (auto *SVOpTy = + dyn_cast<FixedVectorType>(SV->getOperand(0)->getType())) + LocalVF = SVOpTy->getNumElements(); + SmallVector<int> ExtMask(Mask.size(), UndefMaskElem); + for (auto [Idx, I] : enumerate(Mask)) { + if (I == UndefMaskElem) + continue; + ExtMask[Idx] = SV->getMaskValue(I); + } + bool IsOp1Undef = + isUndefVector(SV->getOperand(0), + buildUseMask(LocalVF, ExtMask, UseMask::FirstArg)) + .all(); + bool IsOp2Undef = + isUndefVector(SV->getOperand(1), + buildUseMask(LocalVF, ExtMask, UseMask::SecondArg)) + .all(); + if (!IsOp1Undef && !IsOp2Undef) { + // Update mask and mark undef elems. + for (int &I : Mask) { + if (I == UndefMaskElem) + continue; + if (SV->getMaskValue(I % SV->getShuffleMask().size()) == + UndefMaskElem) + I = UndefMaskElem; + } + break; + } + SmallVector<int> ShuffleMask(SV->getShuffleMask().begin(), + SV->getShuffleMask().end()); + combineMasks(LocalVF, ShuffleMask, Mask); + Mask.swap(ShuffleMask); + if (IsOp2Undef) + Op = SV->getOperand(0); + else + Op = SV->getOperand(1); + } + if (auto *OpTy = dyn_cast<FixedVectorType>(Op->getType()); + !OpTy || !isIdentityMask(Mask, OpTy, SinglePermute)) { + if (IdentityOp) { + V = IdentityOp; + assert(Mask.size() == IdentityMask.size() && + "Expected masks of same sizes."); + // Clear known poison elements. + for (auto [I, Idx] : enumerate(Mask)) + if (Idx == UndefMaskElem) + IdentityMask[I] = UndefMaskElem; + Mask.swap(IdentityMask); + auto *Shuffle = dyn_cast<ShuffleVectorInst>(V); + return SinglePermute && + (isIdentityMask(Mask, cast<FixedVectorType>(V->getType()), + /*IsStrict=*/true) || + (Shuffle && Mask.size() == Shuffle->getShuffleMask().size() && + Shuffle->isZeroEltSplat() && + ShuffleVectorInst::isZeroEltSplatMask(Mask))); + } + V = Op; + return false; + } + V = Op; + return true; + } + + /// Smart shuffle instruction emission, walks through shuffles trees and + /// tries to find the best matching vector for the actual shuffle + /// instruction. 
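combineMasks above and the peekThroughShuffles walk both rely on the fact that two shuffles compose into a single one. The core of that composition, stripped of the size-changing (LocalVF) handling the patch performs, can be sketched as follows; the name and the fixed-size assumption are illustrative only.

#include <cstddef>
#include <vector>

// If W = shuffle(V, Mask) and X = shuffle(W, ExtMask), then
// X = shuffle(V, Composed) with Composed[I] = Mask[ExtMask[I]].
// -1 marks an undef lane and is propagated from either mask. Assumes every
// defined ExtMask element indexes into Mask (same-sized vectors); the
// in-patch combineMasks additionally remaps indices when the shuffles change
// the vector length.
std::vector<int> composeMasks(const std::vector<int> &Mask,
                              const std::vector<int> &ExtMask) {
  std::vector<int> Composed(ExtMask.size(), -1);
  for (std::size_t I = 0, E = ExtMask.size(); I < E; ++I) {
    if (ExtMask[I] == -1)
      continue;
    Composed[I] = Mask[ExtMask[I]]; // may itself be -1 (undef stays undef)
  }
  return Composed;
}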
+ template <typename ShuffleBuilderTy> + static Value *createShuffle(Value *V1, Value *V2, ArrayRef<int> Mask, + ShuffleBuilderTy &Builder) { + assert(V1 && "Expected at least one vector value."); + int VF = Mask.size(); + if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType())) + VF = FTy->getNumElements(); + if (V2 && + !isUndefVector(V2, buildUseMask(VF, Mask, UseMask::SecondArg)).all()) { + // Peek through shuffles. + Value *Op1 = V1; + Value *Op2 = V2; + int VF = + cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue(); + SmallVector<int> CombinedMask1(Mask.size(), UndefMaskElem); + SmallVector<int> CombinedMask2(Mask.size(), UndefMaskElem); + for (int I = 0, E = Mask.size(); I < E; ++I) { + if (Mask[I] < VF) + CombinedMask1[I] = Mask[I]; + else + CombinedMask2[I] = Mask[I] - VF; + } + Value *PrevOp1; + Value *PrevOp2; + do { + PrevOp1 = Op1; + PrevOp2 = Op2; + (void)peekThroughShuffles(Op1, CombinedMask1, /*SinglePermute=*/false); + (void)peekThroughShuffles(Op2, CombinedMask2, /*SinglePermute=*/false); + // Check if we have 2 resizing shuffles - need to peek through operands + // again. + if (auto *SV1 = dyn_cast<ShuffleVectorInst>(Op1)) + if (auto *SV2 = dyn_cast<ShuffleVectorInst>(Op2)) { + SmallVector<int> ExtMask1(Mask.size(), UndefMaskElem); + for (auto [Idx, I] : enumerate(CombinedMask1)) { + if (I == UndefMaskElem) + continue; + ExtMask1[Idx] = SV1->getMaskValue(I); + } + SmallBitVector UseMask1 = buildUseMask( + cast<FixedVectorType>(SV1->getOperand(1)->getType()) + ->getNumElements(), + ExtMask1, UseMask::SecondArg); + SmallVector<int> ExtMask2(CombinedMask2.size(), UndefMaskElem); + for (auto [Idx, I] : enumerate(CombinedMask2)) { + if (I == UndefMaskElem) + continue; + ExtMask2[Idx] = SV2->getMaskValue(I); + } + SmallBitVector UseMask2 = buildUseMask( + cast<FixedVectorType>(SV2->getOperand(1)->getType()) + ->getNumElements(), + ExtMask2, UseMask::SecondArg); + if (SV1->getOperand(0)->getType() == + SV2->getOperand(0)->getType() && + SV1->getOperand(0)->getType() != SV1->getType() && + isUndefVector(SV1->getOperand(1), UseMask1).all() && + isUndefVector(SV2->getOperand(1), UseMask2).all()) { + Op1 = SV1->getOperand(0); + Op2 = SV2->getOperand(0); + SmallVector<int> ShuffleMask1(SV1->getShuffleMask().begin(), + SV1->getShuffleMask().end()); + int LocalVF = ShuffleMask1.size(); + if (auto *FTy = dyn_cast<FixedVectorType>(Op1->getType())) + LocalVF = FTy->getNumElements(); + combineMasks(LocalVF, ShuffleMask1, CombinedMask1); + CombinedMask1.swap(ShuffleMask1); + SmallVector<int> ShuffleMask2(SV2->getShuffleMask().begin(), + SV2->getShuffleMask().end()); + LocalVF = ShuffleMask2.size(); + if (auto *FTy = dyn_cast<FixedVectorType>(Op2->getType())) + LocalVF = FTy->getNumElements(); + combineMasks(LocalVF, ShuffleMask2, CombinedMask2); + CombinedMask2.swap(ShuffleMask2); + } + } + } while (PrevOp1 != Op1 || PrevOp2 != Op2); + Builder.resizeToMatch(Op1, Op2); + VF = std::max(cast<VectorType>(Op1->getType()) + ->getElementCount() + .getKnownMinValue(), + cast<VectorType>(Op2->getType()) + ->getElementCount() + .getKnownMinValue()); + for (int I = 0, E = Mask.size(); I < E; ++I) { + if (CombinedMask2[I] != UndefMaskElem) { + assert(CombinedMask1[I] == UndefMaskElem && + "Expected undefined mask element"); + CombinedMask1[I] = CombinedMask2[I] + (Op1 == Op2 ? 0 : VF); + } + } + return Builder.createShuffleVector( + Op1, Op1 == Op2 ? 
PoisonValue::get(Op1->getType()) : Op2, + CombinedMask1); + } + if (isa<PoisonValue>(V1)) + return PoisonValue::get(FixedVectorType::get( + cast<VectorType>(V1->getType())->getElementType(), Mask.size())); + SmallVector<int> NewMask(Mask.begin(), Mask.end()); + bool IsIdentity = peekThroughShuffles(V1, NewMask, /*SinglePermute=*/true); + assert(V1 && "Expected non-null value after looking through shuffles."); + + if (!IsIdentity) + return Builder.createShuffleVector(V1, NewMask); + return V1; + } +}; +} // namespace + InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals) { - ArrayRef<Value*> VL = E->Scalars; + ArrayRef<Value *> VL = E->Scalars; Type *ScalarTy = VL[0]->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) @@ -5834,9 +6614,12 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); // FIXME: it tries to fix a problem with MSVC buildbots. - TargetTransformInfo &TTIRef = *TTI; - auto &&AdjustExtractsCost = [this, &TTIRef, CostKind, VL, VecTy, - VectorizedVals, E](InstructionCost &Cost) { + TargetTransformInfo *TTI = this->TTI; + auto AdjustExtractsCost = [=](InstructionCost &Cost) { + // If the resulting type is scalarized, do not adjust the cost. + unsigned VecNumParts = TTI->getNumberOfParts(VecTy); + if (VecNumParts == VecTy->getNumElements()) + return; DenseMap<Value *, int> ExtractVectorsTys; SmallPtrSet<Value *, 4> CheckedExtracts; for (auto *V : VL) { @@ -5854,12 +6637,11 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, (VE && VE != E)) continue; auto *EE = cast<ExtractElementInst>(V); - Optional<unsigned> EEIdx = getExtractIndex(EE); + std::optional<unsigned> EEIdx = getExtractIndex(EE); if (!EEIdx) continue; unsigned Idx = *EEIdx; - if (TTIRef.getNumberOfParts(VecTy) != - TTIRef.getNumberOfParts(EE->getVectorOperandType())) { + if (VecNumParts != TTI->getNumberOfParts(EE->getVectorOperandType())) { auto It = ExtractVectorsTys.try_emplace(EE->getVectorOperand(), Idx).first; It->getSecond() = std::min<int>(It->second, Idx); @@ -5867,23 +6649,23 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // Take credit for instruction that will become dead. if (EE->hasOneUse()) { Instruction *Ext = EE->user_back(); - if ((isa<SExtInst>(Ext) || isa<ZExtInst>(Ext)) && - all_of(Ext->users(), - [](User *U) { return isa<GetElementPtrInst>(U); })) { + if (isa<SExtInst, ZExtInst>(Ext) && all_of(Ext->users(), [](User *U) { + return isa<GetElementPtrInst>(U); + })) { // Use getExtractWithExtendCost() to calculate the cost of // extractelement/ext pair. Cost -= - TTIRef.getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(), - EE->getVectorOperandType(), Idx); + TTI->getExtractWithExtendCost(Ext->getOpcode(), Ext->getType(), + EE->getVectorOperandType(), Idx); // Add back the cost of s|zext which is subtracted separately. - Cost += TTIRef.getCastInstrCost( + Cost += TTI->getCastInstrCost( Ext->getOpcode(), Ext->getType(), EE->getType(), TTI::getCastContextHint(Ext), CostKind, Ext); continue; } } - Cost -= TTIRef.getVectorInstrCost(Instruction::ExtractElement, - EE->getVectorOperandType(), Idx); + Cost -= TTI->getVectorInstrCost(*EE, EE->getVectorOperandType(), CostKind, + Idx); } // Add a cost for subvector extracts/inserts if required. 
for (const auto &Data : ExtractVectorsTys) { @@ -5891,13 +6673,13 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, unsigned NumElts = VecTy->getNumElements(); if (Data.second % NumElts == 0) continue; - if (TTIRef.getNumberOfParts(EEVTy) > TTIRef.getNumberOfParts(VecTy)) { + if (TTI->getNumberOfParts(EEVTy) > VecNumParts) { unsigned Idx = (Data.second / NumElts) * NumElts; unsigned EENumElts = EEVTy->getNumElements(); if (Idx + NumElts <= EENumElts) { Cost += - TTIRef.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, - EEVTy, None, Idx, VecTy); + TTI->getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, + EEVTy, std::nullopt, CostKind, Idx, VecTy); } else { // Need to round up the subvector type vectorization factor to avoid a // crash in cost model functions. Make SubVT so that Idx + VF of SubVT @@ -5905,12 +6687,12 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, auto *SubVT = FixedVectorType::get(VecTy->getElementType(), EENumElts - Idx); Cost += - TTIRef.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, - EEVTy, None, Idx, SubVT); + TTI->getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, + EEVTy, std::nullopt, CostKind, Idx, SubVT); } } else { - Cost += TTIRef.getShuffleCost(TargetTransformInfo::SK_InsertSubvector, - VecTy, None, 0, EEVTy); + Cost += TTI->getShuffleCost(TargetTransformInfo::SK_InsertSubvector, + VecTy, std::nullopt, CostKind, 0, EEVTy); } } }; @@ -5919,13 +6701,36 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, return 0; if (isa<InsertElementInst>(VL[0])) return InstructionCost::getInvalid(); + SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end()); + // Build a mask out of the reorder indices and reorder scalars per this + // mask. + SmallVector<int> ReorderMask; + inversePermutation(E->ReorderIndices, ReorderMask); + if (!ReorderMask.empty()) + reorderScalars(GatheredScalars, ReorderMask); SmallVector<int> Mask; + std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle; SmallVector<const TreeEntry *> Entries; - Optional<TargetTransformInfo::ShuffleKind> Shuffle = - isGatherShuffledEntry(E, Mask, Entries); - if (Shuffle) { + // Do not try to look for reshuffled loads for gathered loads (they will be + // handled later), for vectorized scalars, and cases, which are definitely + // not profitable (splats and small gather nodes.) + if (E->getOpcode() != Instruction::Load || E->isAltShuffle() || + all_of(E->Scalars, [this](Value *V) { return getTreeEntry(V); }) || + isSplat(E->Scalars) || + (E->Scalars != GatheredScalars && GatheredScalars.size() <= 2)) + GatherShuffle = isGatherShuffledEntry(E, GatheredScalars, Mask, Entries); + if (GatherShuffle) { + // Remove shuffled elements from list of gathers. + for (int I = 0, Sz = Mask.size(); I < Sz; ++I) { + if (Mask[I] != UndefMaskElem) + GatheredScalars[I] = PoisonValue::get(ScalarTy); + } + assert((Entries.size() == 1 || Entries.size() == 2) && + "Expected shuffle of 1 or 2 entries."); InstructionCost GatherCost = 0; - if (ShuffleVectorInst::isIdentityMask(Mask)) { + int Limit = Mask.size() * 2; + if (all_of(Mask, [=](int Idx) { return Idx < Limit; }) && + ShuffleVectorInst::isIdentityMask(Mask)) { // Perfect match in the graph, will reuse the previously vectorized // node. Cost is 0. LLVM_DEBUG( @@ -5944,8 +6749,10 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // previously vectorized nodes. Add the cost of the permutation rather // than gather. 
::addMask(Mask, E->ReuseShuffleIndices); - GatherCost = TTI->getShuffleCost(*Shuffle, FinalVecTy, Mask); + GatherCost = TTI->getShuffleCost(*GatherShuffle, FinalVecTy, Mask); } + if (!all_of(GatheredScalars, UndefValue::classof)) + GatherCost += getGatherCost(GatheredScalars); return GatherCost; } if ((E->getOpcode() == Instruction::ExtractElement || @@ -5957,7 +6764,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // Check that gather of extractelements can be represented as just a // shuffle of a single/two vectors the scalars are extracted from. SmallVector<int> Mask; - Optional<TargetTransformInfo::ShuffleKind> ShuffleKind = + std::optional<TargetTransformInfo::ShuffleKind> ShuffleKind = isFixedVectorShuffle(VL, Mask); if (ShuffleKind) { // Found the bunch of extractelement instructions that must be gathered @@ -5977,9 +6784,24 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // broadcast. assert(VecTy == FinalVecTy && "No reused scalars expected for broadcast."); - return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, - /*Mask=*/None, /*Index=*/0, - /*SubTp=*/nullptr, /*Args=*/VL[0]); + const auto *It = + find_if(VL, [](Value *V) { return !isa<UndefValue>(V); }); + // If all values are undefs - consider cost free. + if (It == VL.end()) + return TTI::TCC_Free; + // Add broadcast for non-identity shuffle only. + bool NeedShuffle = + VL.front() != *It || !all_of(VL.drop_front(), UndefValue::classof); + InstructionCost InsertCost = + TTI->getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, + /*Index=*/0, PoisonValue::get(VecTy), *It); + return InsertCost + (NeedShuffle + ? TTI->getShuffleCost( + TargetTransformInfo::SK_Broadcast, VecTy, + /*Mask=*/std::nullopt, CostKind, + /*Index=*/0, + /*SubTp=*/nullptr, /*Args=*/VL[0]) + : TTI::TCC_Free); } InstructionCost ReuseShuffleCost = 0; if (NeedToShuffleReuses) @@ -6005,7 +6827,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, OrdersType CurrentOrder; LoadsState LS = canVectorizeLoads(Slice, Slice.front(), *TTI, *DL, *SE, *LI, - CurrentOrder, PointerOps); + *TLI, CurrentOrder, PointerOps); switch (LS) { case LoadsState::Vectorize: case LoadsState::ScatterVectorize: @@ -6048,9 +6870,10 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, InstructionCost ScalarsCost = 0; for (Value *V : VectorizedLoads) { auto *LI = cast<LoadInst>(V); - ScalarsCost += TTI->getMemoryOpCost( - Instruction::Load, LI->getType(), LI->getAlign(), - LI->getPointerAddressSpace(), CostKind, LI); + ScalarsCost += + TTI->getMemoryOpCost(Instruction::Load, LI->getType(), + LI->getAlign(), LI->getPointerAddressSpace(), + CostKind, TTI::OperandValueInfo(), LI); } auto *LI = cast<LoadInst>(E->getMainOp()); auto *LoadTy = FixedVectorType::get(LI->getType(), VF); @@ -6058,7 +6881,8 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, GatherCost += VectorizedCnt * TTI->getMemoryOpCost(Instruction::Load, LoadTy, Alignment, - LI->getPointerAddressSpace(), CostKind, LI); + LI->getPointerAddressSpace(), CostKind, + TTI::OperandValueInfo(), LI); GatherCost += ScatterVectorizeCnt * TTI->getGatherScatterOpCost( Instruction::Load, LoadTy, LI->getPointerOperand(), @@ -6066,8 +6890,9 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, if (NeedInsertSubvectorAnalysis) { // Add the cost for the subvectors insert. 
for (int I = VF, E = VL.size(); I < E; I += VF) - GatherCost += TTI->getShuffleCost(TTI::SK_InsertSubvector, VecTy, - None, I, LoadTy); + GatherCost += + TTI->getShuffleCost(TTI::SK_InsertSubvector, VecTy, + std::nullopt, CostKind, I, LoadTy); } return ReuseShuffleCost + GatherCost - ScalarsCost; } @@ -6103,240 +6928,306 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, Instruction *VL0 = E->getMainOp(); unsigned ShuffleOrOp = E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode(); + const unsigned Sz = VL.size(); + auto GetCostDiff = + [=](function_ref<InstructionCost(unsigned)> ScalarEltCost, + function_ref<InstructionCost(InstructionCost)> VectorCost) { + // Calculate the cost of this instruction. + InstructionCost ScalarCost = 0; + if (isa<CastInst, CmpInst, SelectInst, CallInst>(VL0)) { + // For some of the instructions no need to calculate cost for each + // particular instruction, we can use the cost of the single + // instruction x total number of scalar instructions. + ScalarCost = Sz * ScalarEltCost(0); + } else { + for (unsigned I = 0; I < Sz; ++I) + ScalarCost += ScalarEltCost(I); + } + + InstructionCost VecCost = VectorCost(CommonCost); + LLVM_DEBUG( + dumpTreeCosts(E, CommonCost, VecCost - CommonCost, ScalarCost)); + // Disable warnings for `this` and `E` are unused. Required for + // `dumpTreeCosts`. + (void)this; + (void)E; + return VecCost - ScalarCost; + }; + // Calculate cost difference from vectorizing set of GEPs. + // Negative value means vectorizing is profitable. + auto GetGEPCostDiff = [=](ArrayRef<Value *> Ptrs, Value *BasePtr) { + InstructionCost CostSavings = 0; + for (Value *V : Ptrs) { + if (V == BasePtr) + continue; + auto *Ptr = dyn_cast<GetElementPtrInst>(V); + // GEPs may contain just addresses without instructions, considered free. + // GEPs with all constant indices also considered to have zero cost. + if (!Ptr || Ptr->hasAllConstantIndices()) + continue; + + // Here we differentiate two cases: when GEPs represent a regular + // vectorization tree node (and hence vectorized) and when the set is + // arguments of a set of loads or stores being vectorized. In the former + // case all the scalar GEPs will be removed as a result of vectorization. + // For any external uses of some lanes extract element instructions will + // be generated (which cost is estimated separately). For the latter case + // since the set of GEPs itself is not vectorized those used more than + // once will remain staying in vectorized code as well. So we should not + // count them as savings. + if (!Ptr->hasOneUse() && isa<LoadInst, StoreInst>(VL0)) + continue; + + // TODO: it is target dependent, so need to implement and then use a TTI + // interface. + CostSavings += TTI->getArithmeticInstrCost(Instruction::Add, + Ptr->getType(), CostKind); + } + LLVM_DEBUG(dbgs() << "SLP: Calculated GEPs cost savings or Tree:\n"; + E->dump()); + LLVM_DEBUG(dbgs() << "SLP: GEP cost saving = " << CostSavings << "\n"); + return InstructionCost() - CostSavings; + }; + switch (ShuffleOrOp) { - case Instruction::PHI: - return 0; + case Instruction::PHI: { + // Count reused scalars. + InstructionCost ScalarCost = 0; + SmallPtrSet<const TreeEntry *, 4> CountedOps; + for (Value *V : VL) { + auto *PHI = dyn_cast<PHINode>(V); + if (!PHI) + continue; - case Instruction::ExtractValue: - case Instruction::ExtractElement: { - // The common cost of removal ExtractElement/ExtractValue instructions + - // the cost of shuffles, if required to resuffle the original vector. 
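[Editorial note] Much of the rewrite above funnels every entry kind through a scalar-versus-vector cost-difference helper: per-lane scalar costs are summed (or multiplied out for uniform cases such as casts, compares, selects and calls), the vector cost absorbs the common shuffle cost, and a negative difference means vectorization pays off. A loose standalone analogue of that shape (plain C++; the real lambdas query TTI, the names below are hypothetical):

#include <functional>
#include <iostream>

using Cost = long long;

Cost getCostDiff(unsigned NumScalars,
                 const std::function<Cost(unsigned)> &ScalarEltCost,
                 const std::function<Cost(Cost)> &VectorCost,
                 Cost CommonCost, bool UniformScalarCost) {
  Cost ScalarCost = 0;
  if (UniformScalarCost) {
    // One representative scalar costed once and multiplied out.
    ScalarCost = static_cast<Cost>(NumScalars) * ScalarEltCost(0);
  } else {
    for (unsigned I = 0; I < NumScalars; ++I)
      ScalarCost += ScalarEltCost(I);
  }
  Cost VecCost = VectorCost(CommonCost);
  return VecCost - ScalarCost; // negative => profitable to vectorize
}

int main() {
  // Four scalar adds at cost 1 each vs. one vector add at cost 1 plus a
  // common shuffle overhead of 1: net saving of 2.
  Cost Diff = getCostDiff(
      4, [](unsigned) { return 1; }, [](Cost Common) { return Common + 1; },
      /*CommonCost=*/1, /*UniformScalarCost=*/false);
  std::cout << Diff << '\n'; // -2
}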
- if (NeedToShuffleReuses) { - unsigned Idx = 0; - for (unsigned I : E->ReuseShuffleIndices) { - if (ShuffleOrOp == Instruction::ExtractElement) { - auto *EE = cast<ExtractElementInst>(VL[I]); - CommonCost -= TTI->getVectorInstrCost(Instruction::ExtractElement, - EE->getVectorOperandType(), - *getExtractIndex(EE)); - } else { - CommonCost -= TTI->getVectorInstrCost(Instruction::ExtractElement, - VecTy, Idx); - ++Idx; - } - } - Idx = EntryVF; - for (Value *V : VL) { - if (ShuffleOrOp == Instruction::ExtractElement) { - auto *EE = cast<ExtractElementInst>(V); - CommonCost += TTI->getVectorInstrCost(Instruction::ExtractElement, - EE->getVectorOperandType(), - *getExtractIndex(EE)); - } else { - --Idx; - CommonCost += TTI->getVectorInstrCost(Instruction::ExtractElement, - VecTy, Idx); - } - } - } - if (ShuffleOrOp == Instruction::ExtractValue) { - for (unsigned I = 0, E = VL.size(); I < E; ++I) { - auto *EI = cast<Instruction>(VL[I]); - // Take credit for instruction that will become dead. - if (EI->hasOneUse()) { - Instruction *Ext = EI->user_back(); - if ((isa<SExtInst>(Ext) || isa<ZExtInst>(Ext)) && - all_of(Ext->users(), - [](User *U) { return isa<GetElementPtrInst>(U); })) { - // Use getExtractWithExtendCost() to calculate the cost of - // extractelement/ext pair. - CommonCost -= TTI->getExtractWithExtendCost( - Ext->getOpcode(), Ext->getType(), VecTy, I); - // Add back the cost of s|zext which is subtracted separately. - CommonCost += TTI->getCastInstrCost( - Ext->getOpcode(), Ext->getType(), EI->getType(), - TTI::getCastContextHint(Ext), CostKind, Ext); - continue; - } - } - CommonCost -= - TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, I); - } - } else { - AdjustExtractsCost(CommonCost); + ValueList Operands(PHI->getNumIncomingValues(), nullptr); + for (unsigned I = 0, N = PHI->getNumIncomingValues(); I < N; ++I) { + Value *Op = PHI->getIncomingValue(I); + Operands[I] = Op; } - return CommonCost; + if (const TreeEntry *OpTE = getTreeEntry(Operands.front())) + if (OpTE->isSame(Operands) && CountedOps.insert(OpTE).second) + if (!OpTE->ReuseShuffleIndices.empty()) + ScalarCost += TTI::TCC_Basic * (OpTE->ReuseShuffleIndices.size() - + OpTE->Scalars.size()); } - case Instruction::InsertElement: { - assert(E->ReuseShuffleIndices.empty() && - "Unique insertelements only are expected."); - auto *SrcVecTy = cast<FixedVectorType>(VL0->getType()); - unsigned const NumElts = SrcVecTy->getNumElements(); - unsigned const NumScalars = VL.size(); - - unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy); - - unsigned OffsetBeg = *getInsertIndex(VL.front()); - unsigned OffsetEnd = OffsetBeg; - for (Value *V : VL.drop_front()) { - unsigned Idx = *getInsertIndex(V); - if (OffsetBeg > Idx) - OffsetBeg = Idx; - else if (OffsetEnd < Idx) - OffsetEnd = Idx; - } - unsigned VecScalarsSz = PowerOf2Ceil(NumElts); - if (NumOfParts > 0) - VecScalarsSz = PowerOf2Ceil((NumElts + NumOfParts - 1) / NumOfParts); - unsigned VecSz = - (1 + OffsetEnd / VecScalarsSz - OffsetBeg / VecScalarsSz) * - VecScalarsSz; - unsigned Offset = VecScalarsSz * (OffsetBeg / VecScalarsSz); - unsigned InsertVecSz = std::min<unsigned>( - PowerOf2Ceil(OffsetEnd - OffsetBeg + 1), - ((OffsetEnd - OffsetBeg + VecScalarsSz) / VecScalarsSz) * - VecScalarsSz); - bool IsWholeSubvector = - OffsetBeg == Offset && ((OffsetEnd + 1) % VecScalarsSz == 0); - // Check if we can safely insert a subvector. If it is not possible, just - // generate a whole-sized vector and shuffle the source vector and the new - // subvector. 
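[Editorial note] For PHIs, the new code credits back the lanes that an operand node's reuse shuffle already paid for, and it counts each operand node only once. A toy standalone version of that bookkeeping (the structs here are invented for illustration):

#include <iostream>
#include <set>
#include <vector>

struct OperandNode {
  int Id;
  unsigned NumScalars;
  unsigned NumReuseIndices; // 0 when there is no reuse shuffle
};

int main() {
  std::vector<OperandNode> Operands = {{0, 2, 4}, {0, 2, 4}, {1, 4, 0}};
  std::set<int> Counted;
  long long ScalarCost = 0, TCCBasic = 1;
  for (const OperandNode &Op : Operands)
    if (Op.NumReuseIndices && Counted.insert(Op.Id).second)
      ScalarCost += TCCBasic * (Op.NumReuseIndices - Op.NumScalars);
  std::cout << ScalarCost << '\n'; // 2: node 0 counted once, node 1 has no reuse
}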
- if (OffsetBeg + InsertVecSz > VecSz) { - // Align OffsetBeg to generate correct mask. - OffsetBeg = alignDown(OffsetBeg, VecSz, Offset); - InsertVecSz = VecSz; - } - - APInt DemandedElts = APInt::getZero(NumElts); - // TODO: Add support for Instruction::InsertValue. - SmallVector<int> Mask; - if (!E->ReorderIndices.empty()) { - inversePermutation(E->ReorderIndices, Mask); - Mask.append(InsertVecSz - Mask.size(), UndefMaskElem); + + return CommonCost - ScalarCost; + } + case Instruction::ExtractValue: + case Instruction::ExtractElement: { + auto GetScalarCost = [=](unsigned Idx) { + auto *I = cast<Instruction>(VL[Idx]); + VectorType *SrcVecTy; + if (ShuffleOrOp == Instruction::ExtractElement) { + auto *EE = cast<ExtractElementInst>(I); + SrcVecTy = EE->getVectorOperandType(); } else { - Mask.assign(VecSz, UndefMaskElem); - std::iota(Mask.begin(), std::next(Mask.begin(), InsertVecSz), 0); - } - bool IsIdentity = true; - SmallVector<int> PrevMask(InsertVecSz, UndefMaskElem); - Mask.swap(PrevMask); - for (unsigned I = 0; I < NumScalars; ++I) { - unsigned InsertIdx = *getInsertIndex(VL[PrevMask[I]]); - DemandedElts.setBit(InsertIdx); - IsIdentity &= InsertIdx - OffsetBeg == I; - Mask[InsertIdx - OffsetBeg] = I; + auto *EV = cast<ExtractValueInst>(I); + Type *AggregateTy = EV->getAggregateOperand()->getType(); + unsigned NumElts; + if (auto *ATy = dyn_cast<ArrayType>(AggregateTy)) + NumElts = ATy->getNumElements(); + else + NumElts = AggregateTy->getStructNumElements(); + SrcVecTy = FixedVectorType::get(ScalarTy, NumElts); } - assert(Offset < NumElts && "Failed to find vector index offset"); - - InstructionCost Cost = 0; - Cost -= TTI->getScalarizationOverhead(SrcVecTy, DemandedElts, - /*Insert*/ true, /*Extract*/ false); - - // First cost - resize to actual vector size if not identity shuffle or - // need to shift the vector. - // Do not calculate the cost if the actual size is the register size and - // we can merge this shuffle with the following SK_Select. - auto *InsertVecTy = - FixedVectorType::get(SrcVecTy->getElementType(), InsertVecSz); - if (!IsIdentity) - Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, - InsertVecTy, Mask); - auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) { - return !is_contained(E->Scalars, cast<Instruction>(V)->getOperand(0)); - })); - // Second cost - permutation with subvector, if some elements are from the - // initial vector or inserting a subvector. - // TODO: Implement the analysis of the FirstInsert->getOperand(0) - // subvector of ActualVecTy. - if (!isUndefVector(FirstInsert->getOperand(0)) && NumScalars != NumElts && - !IsWholeSubvector) { - if (InsertVecSz != VecSz) { - auto *ActualVecTy = - FixedVectorType::get(SrcVecTy->getElementType(), VecSz); - Cost += TTI->getShuffleCost(TTI::SK_InsertSubvector, ActualVecTy, - None, OffsetBeg - Offset, InsertVecTy); - } else { - for (unsigned I = 0, End = OffsetBeg - Offset; I < End; ++I) - Mask[I] = I; - for (unsigned I = OffsetBeg - Offset, End = OffsetEnd - Offset; - I <= End; ++I) - if (Mask[I] != UndefMaskElem) - Mask[I] = I + VecSz; - for (unsigned I = OffsetEnd + 1 - Offset; I < VecSz; ++I) - Mask[I] = I; - Cost += TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, InsertVecTy, Mask); + if (I->hasOneUse()) { + Instruction *Ext = I->user_back(); + if ((isa<SExtInst>(Ext) || isa<ZExtInst>(Ext)) && + all_of(Ext->users(), + [](User *U) { return isa<GetElementPtrInst>(U); })) { + // Use getExtractWithExtendCost() to calculate the cost of + // extractelement/ext pair. 
+ InstructionCost Cost = TTI->getExtractWithExtendCost( + Ext->getOpcode(), Ext->getType(), SrcVecTy, *getExtractIndex(I)); + // Subtract the cost of s|zext which is subtracted separately. + Cost -= TTI->getCastInstrCost( + Ext->getOpcode(), Ext->getType(), I->getType(), + TTI::getCastContextHint(Ext), CostKind, Ext); + return Cost; } } - return Cost; + return TTI->getVectorInstrCost(Instruction::ExtractElement, SrcVecTy, + CostKind, *getExtractIndex(I)); + }; + auto GetVectorCost = [](InstructionCost CommonCost) { return CommonCost; }; + return GetCostDiff(GetScalarCost, GetVectorCost); + } + case Instruction::InsertElement: { + assert(E->ReuseShuffleIndices.empty() && + "Unique insertelements only are expected."); + auto *SrcVecTy = cast<FixedVectorType>(VL0->getType()); + unsigned const NumElts = SrcVecTy->getNumElements(); + unsigned const NumScalars = VL.size(); + + unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy); + + SmallVector<int> InsertMask(NumElts, UndefMaskElem); + unsigned OffsetBeg = *getInsertIndex(VL.front()); + unsigned OffsetEnd = OffsetBeg; + InsertMask[OffsetBeg] = 0; + for (auto [I, V] : enumerate(VL.drop_front())) { + unsigned Idx = *getInsertIndex(V); + if (OffsetBeg > Idx) + OffsetBeg = Idx; + else if (OffsetEnd < Idx) + OffsetEnd = Idx; + InsertMask[Idx] = I + 1; + } + unsigned VecScalarsSz = PowerOf2Ceil(NumElts); + if (NumOfParts > 0) + VecScalarsSz = PowerOf2Ceil((NumElts + NumOfParts - 1) / NumOfParts); + unsigned VecSz = (1 + OffsetEnd / VecScalarsSz - OffsetBeg / VecScalarsSz) * + VecScalarsSz; + unsigned Offset = VecScalarsSz * (OffsetBeg / VecScalarsSz); + unsigned InsertVecSz = std::min<unsigned>( + PowerOf2Ceil(OffsetEnd - OffsetBeg + 1), + ((OffsetEnd - OffsetBeg + VecScalarsSz) / VecScalarsSz) * VecScalarsSz); + bool IsWholeSubvector = + OffsetBeg == Offset && ((OffsetEnd + 1) % VecScalarsSz == 0); + // Check if we can safely insert a subvector. If it is not possible, just + // generate a whole-sized vector and shuffle the source vector and the new + // subvector. + if (OffsetBeg + InsertVecSz > VecSz) { + // Align OffsetBeg to generate correct mask. + OffsetBeg = alignDown(OffsetBeg, VecSz, Offset); + InsertVecSz = VecSz; + } + + APInt DemandedElts = APInt::getZero(NumElts); + // TODO: Add support for Instruction::InsertValue. + SmallVector<int> Mask; + if (!E->ReorderIndices.empty()) { + inversePermutation(E->ReorderIndices, Mask); + Mask.append(InsertVecSz - Mask.size(), UndefMaskElem); + } else { + Mask.assign(VecSz, UndefMaskElem); + std::iota(Mask.begin(), std::next(Mask.begin(), InsertVecSz), 0); + } + bool IsIdentity = true; + SmallVector<int> PrevMask(InsertVecSz, UndefMaskElem); + Mask.swap(PrevMask); + for (unsigned I = 0; I < NumScalars; ++I) { + unsigned InsertIdx = *getInsertIndex(VL[PrevMask[I]]); + DemandedElts.setBit(InsertIdx); + IsIdentity &= InsertIdx - OffsetBeg == I; + Mask[InsertIdx - OffsetBeg] = I; + } + assert(Offset < NumElts && "Failed to find vector index offset"); + + InstructionCost Cost = 0; + Cost -= TTI->getScalarizationOverhead(SrcVecTy, DemandedElts, + /*Insert*/ true, /*Extract*/ false, + CostKind); + + // First cost - resize to actual vector size if not identity shuffle or + // need to shift the vector. + // Do not calculate the cost if the actual size is the register size and + // we can merge this shuffle with the following SK_Select. 
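[Editorial note] The insertelement costing above rounds the touched index range out to register-part boundaries before pricing the shuffles. A small standalone reproduction of just that arithmetic, with made-up numbers:

#include <algorithm>
#include <cstdint>
#include <iostream>

uint64_t powerOf2Ceil(uint64_t V) {
  uint64_t P = 1;
  while (P < V)
    P <<= 1;
  return P;
}

int main() {
  unsigned NumElts = 16, NumOfParts = 2; // e.g. a type split over 2 registers
  unsigned OffsetBeg = 5, OffsetEnd = 9; // lanes actually written by inserts

  unsigned VecScalarsSz = powerOf2Ceil((NumElts + NumOfParts - 1) / NumOfParts);
  unsigned VecSz =
      (1 + OffsetEnd / VecScalarsSz - OffsetBeg / VecScalarsSz) * VecScalarsSz;
  unsigned Offset = VecScalarsSz * (OffsetBeg / VecScalarsSz);
  unsigned InsertVecSz = std::min<unsigned>(
      powerOf2Ceil(OffsetEnd - OffsetBeg + 1),
      ((OffsetEnd - OffsetBeg + VecScalarsSz) / VecScalarsSz) * VecScalarsSz);
  bool IsWholeSubvector =
      OffsetBeg == Offset && ((OffsetEnd + 1) % VecScalarsSz == 0);

  std::cout << "part size " << VecScalarsSz << ", cost vector " << VecSz
            << ", base offset " << Offset << ", insert vector " << InsertVecSz
            << ", whole subvector " << IsWholeSubvector << '\n';
  // part size 8, cost vector 16, base offset 0, insert vector 8, whole subvector 0
}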
+ auto *InsertVecTy = + FixedVectorType::get(SrcVecTy->getElementType(), InsertVecSz); + if (!IsIdentity) + Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, + InsertVecTy, Mask); + auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) { + return !is_contained(E->Scalars, cast<Instruction>(V)->getOperand(0)); + })); + // Second cost - permutation with subvector, if some elements are from the + // initial vector or inserting a subvector. + // TODO: Implement the analysis of the FirstInsert->getOperand(0) + // subvector of ActualVecTy. + SmallBitVector InMask = + isUndefVector(FirstInsert->getOperand(0), + buildUseMask(NumElts, InsertMask, UseMask::UndefsAsMask)); + if (!InMask.all() && NumScalars != NumElts && !IsWholeSubvector) { + if (InsertVecSz != VecSz) { + auto *ActualVecTy = + FixedVectorType::get(SrcVecTy->getElementType(), VecSz); + Cost += TTI->getShuffleCost(TTI::SK_InsertSubvector, ActualVecTy, + std::nullopt, CostKind, OffsetBeg - Offset, + InsertVecTy); + } else { + for (unsigned I = 0, End = OffsetBeg - Offset; I < End; ++I) + Mask[I] = InMask.test(I) ? UndefMaskElem : I; + for (unsigned I = OffsetBeg - Offset, End = OffsetEnd - Offset; + I <= End; ++I) + if (Mask[I] != UndefMaskElem) + Mask[I] = I + VecSz; + for (unsigned I = OffsetEnd + 1 - Offset; I < VecSz; ++I) + Mask[I] = + ((I >= InMask.size()) || InMask.test(I)) ? UndefMaskElem : I; + Cost += TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, InsertVecTy, Mask); + } } - case Instruction::ZExt: - case Instruction::SExt: - case Instruction::FPToUI: - case Instruction::FPToSI: - case Instruction::FPExt: - case Instruction::PtrToInt: - case Instruction::IntToPtr: - case Instruction::SIToFP: - case Instruction::UIToFP: - case Instruction::Trunc: - case Instruction::FPTrunc: - case Instruction::BitCast: { + return Cost; + } + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::Trunc: + case Instruction::FPTrunc: + case Instruction::BitCast: { + auto GetScalarCost = [=](unsigned Idx) { + auto *VI = cast<Instruction>(VL[Idx]); + return TTI->getCastInstrCost(E->getOpcode(), ScalarTy, + VI->getOperand(0)->getType(), + TTI::getCastContextHint(VI), CostKind, VI); + }; + auto GetVectorCost = [=](InstructionCost CommonCost) { Type *SrcTy = VL0->getOperand(0)->getType(); - InstructionCost ScalarEltCost = - TTI->getCastInstrCost(E->getOpcode(), ScalarTy, SrcTy, - TTI::getCastContextHint(VL0), CostKind, VL0); - if (NeedToShuffleReuses) { - CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; - } - - // Calculate the cost of this instruction. - InstructionCost ScalarCost = VL.size() * ScalarEltCost; - auto *SrcVecTy = FixedVectorType::get(SrcTy, VL.size()); - InstructionCost VecCost = 0; + InstructionCost VecCost = CommonCost; // Check if the values are candidates to demote. - if (!MinBWs.count(VL0) || VecTy != SrcVecTy) { - VecCost = CommonCost + TTI->getCastInstrCost( - E->getOpcode(), VecTy, SrcVecTy, - TTI::getCastContextHint(VL0), CostKind, VL0); - } - LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost, ScalarCost)); - return VecCost - ScalarCost; - } - case Instruction::FCmp: - case Instruction::ICmp: - case Instruction::Select: { - // Calculate the cost of this instruction. 
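[Editorial note] When the inserts cover only part of a non-undef destination, the code above prices a two-source permute whose mask mixes original lanes with lanes of the freshly built subvector. A simplified sketch of such a mask (the real code additionally honors per-lane undef information from the use mask):

#include <iostream>
#include <vector>

constexpr int UndefElem = -1;

// Lanes outside [Beg, End] come from the original vector (source 0), lanes
// inside come from the new subvector (source 1, offset by VecSz), and lanes
// that were undef in the original stay undef.
std::vector<int> buildTwoSrcMask(unsigned VecSz, unsigned Beg, unsigned End,
                                 const std::vector<bool> &BaseLaneUndef) {
  std::vector<int> Mask(VecSz);
  for (unsigned I = 0; I < VecSz; ++I) {
    bool FromInsert = I >= Beg && I <= End;
    if (FromInsert)
      Mask[I] = static_cast<int>(I + VecSz); // lane of the new subvector
    else
      Mask[I] = BaseLaneUndef[I] ? UndefElem : static_cast<int>(I);
  }
  return Mask;
}

int main() {
  // 8-wide vector, inserts cover lanes 2..4, lane 7 of the base is undef.
  std::vector<bool> BaseUndef = {false, false, false, false,
                                 false, false, false, true};
  for (int M : buildTwoSrcMask(8, 2, 4, BaseUndef))
    std::cout << M << ' ';
  std::cout << '\n'; // 0 1 10 11 12 5 6 -1
}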
- InstructionCost ScalarEltCost = - TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy, Builder.getInt1Ty(), - CmpInst::BAD_ICMP_PREDICATE, CostKind, VL0); - if (NeedToShuffleReuses) { - CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; - } + if (!MinBWs.count(VL0) || VecTy != SrcVecTy) + VecCost += + TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy, + TTI::getCastContextHint(VL0), CostKind, VL0); + return VecCost; + }; + return GetCostDiff(GetScalarCost, GetVectorCost); + } + case Instruction::FCmp: + case Instruction::ICmp: + case Instruction::Select: { + CmpInst::Predicate VecPred, SwappedVecPred; + auto MatchCmp = m_Cmp(VecPred, m_Value(), m_Value()); + if (match(VL0, m_Select(MatchCmp, m_Value(), m_Value())) || + match(VL0, MatchCmp)) + SwappedVecPred = CmpInst::getSwappedPredicate(VecPred); + else + SwappedVecPred = VecPred = ScalarTy->isFloatingPointTy() + ? CmpInst::BAD_FCMP_PREDICATE + : CmpInst::BAD_ICMP_PREDICATE; + auto GetScalarCost = [&](unsigned Idx) { + auto *VI = cast<Instruction>(VL[Idx]); + CmpInst::Predicate CurrentPred = ScalarTy->isFloatingPointTy() + ? CmpInst::BAD_FCMP_PREDICATE + : CmpInst::BAD_ICMP_PREDICATE; + auto MatchCmp = m_Cmp(CurrentPred, m_Value(), m_Value()); + if ((!match(VI, m_Select(MatchCmp, m_Value(), m_Value())) && + !match(VI, MatchCmp)) || + (CurrentPred != VecPred && CurrentPred != SwappedVecPred)) + VecPred = SwappedVecPred = ScalarTy->isFloatingPointTy() + ? CmpInst::BAD_FCMP_PREDICATE + : CmpInst::BAD_ICMP_PREDICATE; + + return TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy, + Builder.getInt1Ty(), CurrentPred, CostKind, + VI); + }; + auto GetVectorCost = [&](InstructionCost CommonCost) { auto *MaskTy = FixedVectorType::get(Builder.getInt1Ty(), VL.size()); - InstructionCost ScalarCost = VecTy->getNumElements() * ScalarEltCost; - - // Check if all entries in VL are either compares or selects with compares - // as condition that have the same predicates. - CmpInst::Predicate VecPred = CmpInst::BAD_ICMP_PREDICATE; - bool First = true; - for (auto *V : VL) { - CmpInst::Predicate CurrentPred; - auto MatchCmp = m_Cmp(CurrentPred, m_Value(), m_Value()); - if ((!match(V, m_Select(MatchCmp, m_Value(), m_Value())) && - !match(V, MatchCmp)) || - (!First && VecPred != CurrentPred)) { - VecPred = CmpInst::BAD_ICMP_PREDICATE; - break; - } - First = false; - VecPred = CurrentPred; - } InstructionCost VecCost = TTI->getCmpSelInstrCost( E->getOpcode(), VecTy, MaskTy, VecPred, CostKind, VL0); - // Check if it is possible and profitable to use min/max for selects in - // VL. + // Check if it is possible and profitable to use min/max for selects + // in VL. // auto IntrinsicAndUse = canConvertToMinOrMaxIntrinsic(VL); if (IntrinsicAndUse.first != Intrinsic::not_intrinsic) { @@ -6344,216 +7235,181 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, {VecTy, VecTy}); InstructionCost IntrinsicCost = TTI->getIntrinsicInstrCost(CostAttrs, CostKind); - // If the selects are the only uses of the compares, they will be dead - // and we can adjust the cost by removing their cost. + // If the selects are the only uses of the compares, they will be + // dead and we can adjust the cost by removing their cost. 
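[Editorial note] The cmp/select case now tolerates scalars that use either the first predicate or its swapped form before falling back to an unknown-predicate cost query. A standalone illustration with a toy predicate enum:

#include <iostream>
#include <vector>

enum Pred { LT, GT, LE, GE, EQ, NE, BAD };

Pred swapped(Pred P) {
  switch (P) {
  case LT: return GT;
  case GT: return LT;
  case LE: return GE;
  case GE: return LE;
  default: return P; // EQ/NE/BAD are symmetric or unknown
  }
}

// The first scalar fixes the expected predicate; other scalars may use either
// it or its swapped form (a < b vs. b > a); anything else is conservative.
Pred commonPredicate(const std::vector<Pred> &Preds) {
  Pred VecPred = Preds.front();
  Pred SwappedVecPred = swapped(VecPred);
  for (Pred P : Preds)
    if (P != VecPred && P != SwappedVecPred)
      return BAD;
  return VecPred;
}

int main() {
  std::cout << commonPredicate({LT, GT, LT}) << '\n'; // 0 (LT, swapped form ok)
  std::cout << commonPredicate({LT, EQ}) << '\n';     // 6 (BAD)
}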
if (IntrinsicAndUse.second) IntrinsicCost -= TTI->getCmpSelInstrCost(Instruction::ICmp, VecTy, MaskTy, VecPred, CostKind); VecCost = std::min(VecCost, IntrinsicCost); } - LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost, ScalarCost)); - return CommonCost + VecCost - ScalarCost; - } - case Instruction::FNeg: - case Instruction::Add: - case Instruction::FAdd: - case Instruction::Sub: - case Instruction::FSub: - case Instruction::Mul: - case Instruction::FMul: - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::FDiv: - case Instruction::URem: - case Instruction::SRem: - case Instruction::FRem: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: { - // Certain instructions can be cheaper to vectorize if they have a - // constant second vector operand. - TargetTransformInfo::OperandValueKind Op1VK = - TargetTransformInfo::OK_AnyValue; - TargetTransformInfo::OperandValueKind Op2VK = - TargetTransformInfo::OK_UniformConstantValue; - TargetTransformInfo::OperandValueProperties Op1VP = - TargetTransformInfo::OP_None; - TargetTransformInfo::OperandValueProperties Op2VP = - TargetTransformInfo::OP_PowerOf2; - - // If all operands are exactly the same ConstantInt then set the - // operand kind to OK_UniformConstantValue. - // If instead not all operands are constants, then set the operand kind - // to OK_AnyValue. If all operands are constants but not the same, - // then set the operand kind to OK_NonUniformConstantValue. - ConstantInt *CInt0 = nullptr; - for (unsigned i = 0, e = VL.size(); i < e; ++i) { - const Instruction *I = cast<Instruction>(VL[i]); - unsigned OpIdx = isa<BinaryOperator>(I) ? 1 : 0; - ConstantInt *CInt = dyn_cast<ConstantInt>(I->getOperand(OpIdx)); - if (!CInt) { - Op2VK = TargetTransformInfo::OK_AnyValue; - Op2VP = TargetTransformInfo::OP_None; - break; - } - if (Op2VP == TargetTransformInfo::OP_PowerOf2 && - !CInt->getValue().isPowerOf2()) - Op2VP = TargetTransformInfo::OP_None; - if (i == 0) { - CInt0 = CInt; - continue; - } - if (CInt0 != CInt) - Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; - } - - SmallVector<const Value *, 4> Operands(VL0->operand_values()); - InstructionCost ScalarEltCost = - TTI->getArithmeticInstrCost(E->getOpcode(), ScalarTy, CostKind, Op1VK, - Op2VK, Op1VP, Op2VP, Operands, VL0); - if (NeedToShuffleReuses) { - CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; - } - InstructionCost ScalarCost = VecTy->getNumElements() * ScalarEltCost; - InstructionCost VecCost = - TTI->getArithmeticInstrCost(E->getOpcode(), VecTy, CostKind, Op1VK, - Op2VK, Op1VP, Op2VP, Operands, VL0); - LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost, ScalarCost)); - return CommonCost + VecCost - ScalarCost; - } - case Instruction::GetElementPtr: { - TargetTransformInfo::OperandValueKind Op1VK = - TargetTransformInfo::OK_AnyValue; - TargetTransformInfo::OperandValueKind Op2VK = - any_of(VL, - [](Value *V) { - return isa<GetElementPtrInst>(V) && - !isConstant( - cast<GetElementPtrInst>(V)->getOperand(1)); - }) - ? 
TargetTransformInfo::OK_AnyValue - : TargetTransformInfo::OK_UniformConstantValue; - - InstructionCost ScalarEltCost = TTI->getArithmeticInstrCost( - Instruction::Add, ScalarTy, CostKind, Op1VK, Op2VK); - if (NeedToShuffleReuses) { - CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; - } - InstructionCost ScalarCost = VecTy->getNumElements() * ScalarEltCost; - InstructionCost VecCost = TTI->getArithmeticInstrCost( - Instruction::Add, VecTy, CostKind, Op1VK, Op2VK); - LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost, ScalarCost)); - return CommonCost + VecCost - ScalarCost; - } - case Instruction::Load: { - // Cost of wide load - cost of scalar loads. - Align Alignment = cast<LoadInst>(VL0)->getAlign(); - InstructionCost ScalarEltCost = TTI->getMemoryOpCost( - Instruction::Load, ScalarTy, Alignment, 0, CostKind, VL0); - if (NeedToShuffleReuses) { - CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; - } - InstructionCost ScalarLdCost = VecTy->getNumElements() * ScalarEltCost; + return VecCost + CommonCost; + }; + return GetCostDiff(GetScalarCost, GetVectorCost); + } + case Instruction::FNeg: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + auto GetScalarCost = [=](unsigned Idx) { + auto *VI = cast<Instruction>(VL[Idx]); + unsigned OpIdx = isa<UnaryOperator>(VI) ? 0 : 1; + TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(VI->getOperand(0)); + TTI::OperandValueInfo Op2Info = + TTI::getOperandInfo(VI->getOperand(OpIdx)); + SmallVector<const Value *> Operands(VI->operand_values()); + return TTI->getArithmeticInstrCost(ShuffleOrOp, ScalarTy, CostKind, + Op1Info, Op2Info, Operands, VI); + }; + auto GetVectorCost = [=](InstructionCost CommonCost) { + unsigned OpIdx = isa<UnaryOperator>(VL0) ? 
0 : 1; + TTI::OperandValueInfo Op1Info = getOperandInfo(VL, 0); + TTI::OperandValueInfo Op2Info = getOperandInfo(VL, OpIdx); + return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info, + Op2Info) + + CommonCost; + }; + return GetCostDiff(GetScalarCost, GetVectorCost); + } + case Instruction::GetElementPtr: { + return CommonCost + GetGEPCostDiff(VL, VL0); + } + case Instruction::Load: { + auto GetScalarCost = [=](unsigned Idx) { + auto *VI = cast<LoadInst>(VL[Idx]); + return TTI->getMemoryOpCost(Instruction::Load, ScalarTy, VI->getAlign(), + VI->getPointerAddressSpace(), CostKind, + TTI::OperandValueInfo(), VI); + }; + auto *LI0 = cast<LoadInst>(VL0); + auto GetVectorCost = [=](InstructionCost CommonCost) { InstructionCost VecLdCost; if (E->State == TreeEntry::Vectorize) { - VecLdCost = TTI->getMemoryOpCost(Instruction::Load, VecTy, Alignment, 0, - CostKind, VL0); + VecLdCost = TTI->getMemoryOpCost( + Instruction::Load, VecTy, LI0->getAlign(), + LI0->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo()); } else { assert(E->State == TreeEntry::ScatterVectorize && "Unknown EntryState"); - Align CommonAlignment = Alignment; + Align CommonAlignment = LI0->getAlign(); for (Value *V : VL) CommonAlignment = std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); VecLdCost = TTI->getGatherScatterOpCost( - Instruction::Load, VecTy, cast<LoadInst>(VL0)->getPointerOperand(), - /*VariableMask=*/false, CommonAlignment, CostKind, VL0); + Instruction::Load, VecTy, LI0->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind); } - LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecLdCost, ScalarLdCost)); - return CommonCost + VecLdCost - ScalarLdCost; - } - case Instruction::Store: { + return VecLdCost + CommonCost; + }; + + InstructionCost Cost = GetCostDiff(GetScalarCost, GetVectorCost); + // If this node generates masked gather load then it is not a terminal node. + // Hence address operand cost is estimated separately. + if (E->State == TreeEntry::ScatterVectorize) + return Cost; + + // Estimate cost of GEPs since this tree node is a terminator. + SmallVector<Value *> PointerOps(VL.size()); + for (auto [I, V] : enumerate(VL)) + PointerOps[I] = cast<LoadInst>(V)->getPointerOperand(); + return Cost + GetGEPCostDiff(PointerOps, LI0->getPointerOperand()); + } + case Instruction::Store: { + bool IsReorder = !E->ReorderIndices.empty(); + auto GetScalarCost = [=](unsigned Idx) { + auto *VI = cast<StoreInst>(VL[Idx]); + TTI::OperandValueInfo OpInfo = getOperandInfo(VI, 0); + return TTI->getMemoryOpCost(Instruction::Store, ScalarTy, VI->getAlign(), + VI->getPointerAddressSpace(), CostKind, + OpInfo, VI); + }; + auto *BaseSI = + cast<StoreInst>(IsReorder ? VL[E->ReorderIndices.front()] : VL0); + auto GetVectorCost = [=](InstructionCost CommonCost) { // We know that we can merge the stores. Calculate the cost. - bool IsReorder = !E->ReorderIndices.empty(); - auto *SI = - cast<StoreInst>(IsReorder ? 
VL[E->ReorderIndices.front()] : VL0); - Align Alignment = SI->getAlign(); - InstructionCost ScalarEltCost = TTI->getMemoryOpCost( - Instruction::Store, ScalarTy, Alignment, 0, CostKind, VL0); - InstructionCost ScalarStCost = VecTy->getNumElements() * ScalarEltCost; - InstructionCost VecStCost = TTI->getMemoryOpCost( - Instruction::Store, VecTy, Alignment, 0, CostKind, VL0); - LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecStCost, ScalarStCost)); - return CommonCost + VecStCost - ScalarStCost; + TTI::OperandValueInfo OpInfo = getOperandInfo(VL, 0); + return TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(), + BaseSI->getPointerAddressSpace(), CostKind, + OpInfo) + + CommonCost; + }; + SmallVector<Value *> PointerOps(VL.size()); + for (auto [I, V] : enumerate(VL)) { + unsigned Idx = IsReorder ? E->ReorderIndices[I] : I; + PointerOps[Idx] = cast<StoreInst>(V)->getPointerOperand(); } - case Instruction::Call: { - CallInst *CI = cast<CallInst>(VL0); - Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - // Calculate the cost of the scalar and vector calls. - IntrinsicCostAttributes CostAttrs(ID, *CI, 1); - InstructionCost ScalarEltCost = - TTI->getIntrinsicInstrCost(CostAttrs, CostKind); - if (NeedToShuffleReuses) { - CommonCost -= (EntryVF - VL.size()) * ScalarEltCost; + return GetCostDiff(GetScalarCost, GetVectorCost) + + GetGEPCostDiff(PointerOps, BaseSI->getPointerOperand()); + } + case Instruction::Call: { + auto GetScalarCost = [=](unsigned Idx) { + auto *CI = cast<CallInst>(VL[Idx]); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + if (ID != Intrinsic::not_intrinsic) { + IntrinsicCostAttributes CostAttrs(ID, *CI, 1); + return TTI->getIntrinsicInstrCost(CostAttrs, CostKind); } - InstructionCost ScalarCallCost = VecTy->getNumElements() * ScalarEltCost; - + return TTI->getCallInstrCost(CI->getCalledFunction(), + CI->getFunctionType()->getReturnType(), + CI->getFunctionType()->params(), CostKind); + }; + auto GetVectorCost = [=](InstructionCost CommonCost) { + auto *CI = cast<CallInst>(VL0); auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI); - InstructionCost VecCallCost = - std::min(VecCallCosts.first, VecCallCosts.second); - - LLVM_DEBUG(dbgs() << "SLP: Call cost " << VecCallCost - ScalarCallCost - << " (" << VecCallCost << "-" << ScalarCallCost << ")" - << " for " << *CI << "\n"); - - return CommonCost + VecCallCost - ScalarCallCost; - } - case Instruction::ShuffleVector: { - assert(E->isAltShuffle() && - ((Instruction::isBinaryOp(E->getOpcode()) && - Instruction::isBinaryOp(E->getAltOpcode())) || - (Instruction::isCast(E->getOpcode()) && - Instruction::isCast(E->getAltOpcode())) || - (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) && - "Invalid Shuffle Vector Operand"); - InstructionCost ScalarCost = 0; - if (NeedToShuffleReuses) { - for (unsigned Idx : E->ReuseShuffleIndices) { - Instruction *I = cast<Instruction>(VL[Idx]); - CommonCost -= TTI->getInstructionCost(I, CostKind); - } - for (Value *V : VL) { - Instruction *I = cast<Instruction>(V); - CommonCost += TTI->getInstructionCost(I, CostKind); - } - } - for (Value *V : VL) { - Instruction *I = cast<Instruction>(V); - assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); - ScalarCost += TTI->getInstructionCost(I, CostKind); + return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost; + }; + return GetCostDiff(GetScalarCost, GetVectorCost); + } + case Instruction::ShuffleVector: { + assert(E->isAltShuffle() && + ((Instruction::isBinaryOp(E->getOpcode()) && + 
Instruction::isBinaryOp(E->getAltOpcode())) || + (Instruction::isCast(E->getOpcode()) && + Instruction::isCast(E->getAltOpcode())) || + (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) && + "Invalid Shuffle Vector Operand"); + // Try to find the previous shuffle node with the same operands and same + // main/alternate ops. + auto TryFindNodeWithEqualOperands = [=]() { + for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree) { + if (TE.get() == E) + break; + if (TE->isAltShuffle() && + ((TE->getOpcode() == E->getOpcode() && + TE->getAltOpcode() == E->getAltOpcode()) || + (TE->getOpcode() == E->getAltOpcode() && + TE->getAltOpcode() == E->getOpcode())) && + TE->hasEqualOperands(*E)) + return true; } + return false; + }; + auto GetScalarCost = [=](unsigned Idx) { + auto *VI = cast<Instruction>(VL[Idx]); + assert(E->isOpcodeOrAlt(VI) && "Unexpected main/alternate opcode"); + (void)E; + return TTI->getInstructionCost(VI, CostKind); + }; + // Need to clear CommonCost since the final shuffle cost is included into + // vector cost. + auto GetVectorCost = [&](InstructionCost) { // VecCost is equal to sum of the cost of creating 2 vectors // and the cost of creating shuffle. InstructionCost VecCost = 0; - // Try to find the previous shuffle node with the same operands and same - // main/alternate ops. - auto &&TryFindNodeWithEqualOperands = [this, E]() { - for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree) { - if (TE.get() == E) - break; - if (TE->isAltShuffle() && - ((TE->getOpcode() == E->getOpcode() && - TE->getAltOpcode() == E->getAltOpcode()) || - (TE->getOpcode() == E->getAltOpcode() && - TE->getAltOpcode() == E->getOpcode())) && - TE->hasEqualOperands(*E)) - return true; - } - return false; - }; if (TryFindNodeWithEqualOperands()) { LLVM_DEBUG({ dbgs() << "SLP: diamond match for alternate node found.\n"; @@ -6563,8 +7419,8 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // same main/alternate vector ops, just do different shuffling. 
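[Editorial note] An alternate-opcode bundle is priced above as two full-width operations plus a select-style shuffle between them, unless a "diamond match" node with equal operands already exists. A minimal sketch of that select mask (illustrative only, not the patch's mask builder): element I picks lane I of the main-opcode vector, or lane I + VF of the alternate-opcode vector.

#include <iostream>
#include <string>
#include <vector>

std::vector<int> buildAltOpMask(const std::vector<std::string> &Opcodes,
                                const std::string &AltOpcode) {
  unsigned VF = Opcodes.size();
  std::vector<int> Mask(VF);
  for (unsigned I = 0; I < VF; ++I)
    Mask[I] = static_cast<int>(Opcodes[I] == AltOpcode ? I + VF : I);
  return Mask;
}

int main() {
  for (int M : buildAltOpMask({"add", "sub", "add", "sub"}, "sub"))
    std::cout << M << ' ';
  std::cout << '\n'; // 0 5 2 7
}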
} else if (Instruction::isBinaryOp(E->getOpcode())) { VecCost = TTI->getArithmeticInstrCost(E->getOpcode(), VecTy, CostKind); - VecCost += TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy, - CostKind); + VecCost += + TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy, CostKind); } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) { VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy, Builder.getInt1Ty(), @@ -6583,9 +7439,8 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, VecCost += TTI->getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty, TTI::CastContextHint::None, CostKind); } - if (E->ReuseShuffleIndices.empty()) { - CommonCost = + VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, FinalVecTy); } else { SmallVector<int> Mask; @@ -6596,14 +7451,15 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, return I->getOpcode() == E->getAltOpcode(); }, Mask); - CommonCost = TTI->getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, - FinalVecTy, Mask); + VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, + FinalVecTy, Mask); } - LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost, ScalarCost)); - return CommonCost + VecCost - ScalarCost; - } - default: - llvm_unreachable("Unknown instruction"); + return VecCost; + }; + return GetCostDiff(GetScalarCost, GetVectorCost); + } + default: + llvm_unreachable("Unknown instruction"); } } @@ -6819,9 +7675,30 @@ InstructionCost BoUpSLP::getSpillCost() const { continue; } + auto NoCallIntrinsic = [this](Instruction *I) { + if (auto *II = dyn_cast<IntrinsicInst>(I)) { + if (II->isAssumeLikeIntrinsic()) + return true; + FastMathFlags FMF; + SmallVector<Type *, 4> Tys; + for (auto &ArgOp : II->args()) + Tys.push_back(ArgOp->getType()); + if (auto *FPMO = dyn_cast<FPMathOperator>(II)) + FMF = FPMO->getFastMathFlags(); + IntrinsicCostAttributes ICA(II->getIntrinsicID(), II->getType(), Tys, + FMF); + InstructionCost IntrCost = + TTI->getIntrinsicInstrCost(ICA, TTI::TCK_RecipThroughput); + InstructionCost CallCost = TTI->getCallInstrCost( + nullptr, II->getType(), Tys, TTI::TCK_RecipThroughput); + if (IntrCost < CallCost) + return true; + } + return false; + }; + // Debug information does not impact spill cost. - if ((isa<CallInst>(&*PrevInstIt) && - !isa<DbgInfoIntrinsic>(&*PrevInstIt)) && + if (isa<CallInst>(&*PrevInstIt) && !NoCallIntrinsic(&*PrevInstIt) && &*PrevInstIt != PrevInst) NumCalls++; @@ -6845,48 +7722,6 @@ InstructionCost BoUpSLP::getSpillCost() const { return Cost; } -/// Check if two insertelement instructions are from the same buildvector. -static bool areTwoInsertFromSameBuildVector(InsertElementInst *VU, - InsertElementInst *V) { - // Instructions must be from the same basic blocks. - if (VU->getParent() != V->getParent()) - return false; - // Checks if 2 insertelements are from the same buildvector. - if (VU->getType() != V->getType()) - return false; - // Multiple used inserts are separate nodes. - if (!VU->hasOneUse() && !V->hasOneUse()) - return false; - auto *IE1 = VU; - auto *IE2 = V; - unsigned Idx1 = *getInsertIndex(IE1); - unsigned Idx2 = *getInsertIndex(IE2); - // Go through the vector operand of insertelement instructions trying to find - // either VU as the original vector for IE2 or V as the original vector for - // IE1. 
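[Editorial note] getSpillCost now skips intrinsics that are assume-like or that are cheaper than a real call when counting the calls that may force vector registers to be spilled. A toy standalone version of that filter (the Inst struct is invented for illustration):

#include <iostream>
#include <string>
#include <vector>

struct Inst {
  std::string Name;
  bool IsCall;           // call or intrinsic call
  bool IsCheapIntrinsic; // assume-like, or intrinsic cost < call cost
};

// Count only the calls between two vectorized values that behave like real
// calls; cheap intrinsics lower to plain instructions and do not spill.
unsigned countSpillRelevantCalls(const std::vector<Inst> &Window) {
  unsigned NumCalls = 0;
  for (const Inst &I : Window)
    if (I.IsCall && !I.IsCheapIntrinsic)
      ++NumCalls;
  return NumCalls;
}

int main() {
  std::vector<Inst> Window = {{"llvm.assume", true, true},
                              {"llvm.fabs", true, true},
                              {"printf", true, false},
                              {"add", false, false}};
  std::cout << countSpillRelevantCalls(Window) << '\n'; // 1
}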
- do { - if (IE2 == VU) - return VU->hasOneUse(); - if (IE1 == V) - return V->hasOneUse(); - if (IE1) { - if ((IE1 != VU && !IE1->hasOneUse()) || - getInsertIndex(IE1).value_or(Idx2) == Idx2) - IE1 = nullptr; - else - IE1 = dyn_cast<InsertElementInst>(IE1->getOperand(0)); - } - if (IE2) { - if ((IE2 != V && !IE2->hasOneUse()) || - getInsertIndex(IE2).value_or(Idx1) == Idx1) - IE2 = nullptr; - else - IE2 = dyn_cast<InsertElementInst>(IE2->getOperand(0)); - } - } while (IE1 || IE2); - return false; -} - /// Checks if the \p IE1 instructions is followed by \p IE2 instruction in the /// buildvector sequence. static bool isFirstInsertElement(const InsertElementInst *IE1, @@ -6921,13 +7756,11 @@ namespace { /// value, otherwise. struct ValueSelect { template <typename U> - static typename std::enable_if<std::is_same<Value *, U>::value, Value *>::type - get(Value *V) { + static std::enable_if_t<std::is_same_v<Value *, U>, Value *> get(Value *V) { return V; } template <typename U> - static typename std::enable_if<!std::is_same<Value *, U>::value, U>::type - get(Value *) { + static std::enable_if_t<!std::is_same_v<Value *, U>, U> get(Value *) { return U(); } }; @@ -6949,19 +7782,23 @@ template <typename T> static T *performExtractsShuffleAction( MutableArrayRef<std::pair<T *, SmallVector<int>>> ShuffleMask, Value *Base, function_ref<unsigned(T *)> GetVF, - function_ref<std::pair<T *, bool>(T *, ArrayRef<int>)> ResizeAction, + function_ref<std::pair<T *, bool>(T *, ArrayRef<int>, bool)> ResizeAction, function_ref<T *(ArrayRef<int>, ArrayRef<T *>)> Action) { assert(!ShuffleMask.empty() && "Empty list of shuffles for inserts."); SmallVector<int> Mask(ShuffleMask.begin()->second); auto VMIt = std::next(ShuffleMask.begin()); T *Prev = nullptr; - bool IsBaseNotUndef = !isUndefVector(Base); - if (IsBaseNotUndef) { + SmallBitVector UseMask = + buildUseMask(Mask.size(), Mask, UseMask::UndefsAsMask); + SmallBitVector IsBaseUndef = isUndefVector(Base, UseMask); + if (!IsBaseUndef.all()) { // Base is not undef, need to combine it with the next subvectors. - std::pair<T *, bool> Res = ResizeAction(ShuffleMask.begin()->first, Mask); + std::pair<T *, bool> Res = + ResizeAction(ShuffleMask.begin()->first, Mask, /*ForSingleMask=*/false); + SmallBitVector IsBasePoison = isUndefVector<true>(Base, UseMask); for (unsigned Idx = 0, VF = Mask.size(); Idx < VF; ++Idx) { if (Mask[Idx] == UndefMaskElem) - Mask[Idx] = Idx; + Mask[Idx] = IsBasePoison.test(Idx) ? UndefMaskElem : Idx; else Mask[Idx] = (Res.second ? Idx : Mask[Idx]) + VF; } @@ -6973,7 +7810,8 @@ static T *performExtractsShuffleAction( } else if (ShuffleMask.size() == 1) { // Base is undef and only 1 vector is shuffled - perform the action only for // single vector, if the mask is not the identity mask. - std::pair<T *, bool> Res = ResizeAction(ShuffleMask.begin()->first, Mask); + std::pair<T *, bool> Res = ResizeAction(ShuffleMask.begin()->first, Mask, + /*ForSingleMask=*/true); if (Res.second) // Identity mask is found. Prev = Res.first; @@ -6997,9 +7835,10 @@ static T *performExtractsShuffleAction( Prev = Action(Mask, {ShuffleMask.begin()->first, VMIt->first}); } else { // Vectors of different sizes - resize and reshuffle. 
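[Editorial note] The ValueSelect helper above is modernized from std::enable_if/std::is_same to the _t/_v aliases. A self-contained example of the same overload-selection idiom, using std::string* in place of Value* (purely illustrative):

#include <iostream>
#include <string>
#include <type_traits>

struct ValueSelect {
  // Fires when the requested type is the pointer type itself.
  template <typename U>
  static std::enable_if_t<std::is_same_v<std::string *, U>, std::string *>
  get(std::string *V) {
    return V;
  }
  // Fires for any other type and returns a default-constructed result.
  template <typename U>
  static std::enable_if_t<!std::is_same_v<std::string *, U>, U>
  get(std::string *) {
    return U();
  }
};

int main() {
  std::string S = "vectorized";
  std::cout << *ValueSelect::get<std::string *>(&S) << '\n'; // vectorized
  std::cout << ValueSelect::get<int>(&S) << '\n';            // 0
}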
- std::pair<T *, bool> Res1 = - ResizeAction(ShuffleMask.begin()->first, Mask); - std::pair<T *, bool> Res2 = ResizeAction(VMIt->first, VMIt->second); + std::pair<T *, bool> Res1 = ResizeAction(ShuffleMask.begin()->first, Mask, + /*ForSingleMask=*/false); + std::pair<T *, bool> Res2 = + ResizeAction(VMIt->first, VMIt->second, /*ForSingleMask=*/false); ArrayRef<int> SecMask = VMIt->second; for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) { if (Mask[I] != UndefMaskElem) { @@ -7015,10 +7854,13 @@ static T *performExtractsShuffleAction( } VMIt = std::next(VMIt); } + bool IsBaseNotUndef = !IsBaseUndef.all(); + (void)IsBaseNotUndef; // Perform requested actions for the remaining masks/vectors. for (auto E = ShuffleMask.end(); VMIt != E; ++VMIt) { // Shuffle other input vectors, if any. - std::pair<T *, bool> Res = ResizeAction(VMIt->first, VMIt->second); + std::pair<T *, bool> Res = + ResizeAction(VMIt->first, VMIt->second, /*ForSingleMask=*/false); ArrayRef<int> SecMask = VMIt->second; for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) { if (SecMask[I] != UndefMaskElem) { @@ -7043,6 +7885,18 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) { TreeEntry &TE = *VectorizableTree[I]; + if (TE.State == TreeEntry::NeedToGather) { + if (const TreeEntry *E = getTreeEntry(TE.getMainOp()); + E && E->getVectorFactor() == TE.getVectorFactor() && + E->isSame(TE.Scalars)) { + // Some gather nodes might be absolutely the same as some vectorizable + // nodes after reordering, need to handle it. + LLVM_DEBUG(dbgs() << "SLP: Adding cost 0 for bundle that starts with " + << *TE.Scalars[0] << ".\n" + << "SLP: Current total cost = " << Cost << "\n"); + continue; + } + } InstructionCost C = getEntryCost(&TE, VectorizedVals); Cost += C; @@ -7073,24 +7927,25 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { if (isa<FixedVectorType>(EU.Scalar->getType())) continue; - // Already counted the cost for external uses when tried to adjust the cost - // for extractelements, no need to add it again. - if (isa<ExtractElementInst>(EU.Scalar)) - continue; - // If found user is an insertelement, do not calculate extract cost but try // to detect it as a final shuffled/identity match. if (auto *VU = dyn_cast_or_null<InsertElementInst>(EU.User)) { if (auto *FTy = dyn_cast<FixedVectorType>(VU->getType())) { - Optional<unsigned> InsertIdx = getInsertIndex(VU); + std::optional<unsigned> InsertIdx = getInsertIndex(VU); if (InsertIdx) { const TreeEntry *ScalarTE = getTreeEntry(EU.Scalar); - auto *It = - find_if(FirstUsers, - [VU](const std::pair<Value *, const TreeEntry *> &Pair) { - return areTwoInsertFromSameBuildVector( - VU, cast<InsertElementInst>(Pair.first)); - }); + auto *It = find_if( + FirstUsers, + [this, VU](const std::pair<Value *, const TreeEntry *> &Pair) { + return areTwoInsertFromSameBuildVector( + VU, cast<InsertElementInst>(Pair.first), + [this](InsertElementInst *II) -> Value * { + Value *Op0 = II->getOperand(0); + if (getTreeEntry(II) && !getTreeEntry(Op0)) + return nullptr; + return Op0; + }); + }); int VecId = -1; if (It == FirstUsers.end()) { (void)ShuffleMasks.emplace_back(); @@ -7142,6 +7997,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { // extend the extracted value back to the original type. Here, we account // for the extract and the added cost of the sign extend if needed. 
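[Editorial note] getTreeCost now treats a gather node whose scalars match an already-vectorized node of the same width as cost 0, since the vectorized value can simply be reused. A toy model of that check (the Node struct is invented):

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::vector<std::string> Scalars;
  bool IsGather;
};

// A gather is free if some vectorized node already holds exactly its scalars.
bool isFreeGather(const Node &Gather, const std::vector<Node> &Tree) {
  for (const Node &N : Tree)
    if (!N.IsGather && N.Scalars == Gather.Scalars)
      return true;
  return false;
}

int main() {
  std::vector<Node> Tree = {{{"a", "b", "c", "d"}, false}};
  Node Gather{{"a", "b", "c", "d"}, true};
  std::cout << isFreeGather(Gather, Tree) << '\n'; // 1
}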
auto *VecTy = FixedVectorType::get(EU.Scalar->getType(), BundleWidth); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; auto *ScalarRoot = VectorizableTree[0]->Scalars[0]; if (MinBWs.count(ScalarRoot)) { auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); @@ -7151,14 +8007,15 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { ExtractCost += TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(), VecTy, EU.Lane); } else { - ExtractCost += - TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane); + ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, + CostKind, EU.Lane); } } InstructionCost SpillCost = getSpillCost(); Cost += SpillCost + ExtractCost; - auto &&ResizeToVF = [this, &Cost](const TreeEntry *TE, ArrayRef<int> Mask) { + auto &&ResizeToVF = [this, &Cost](const TreeEntry *TE, ArrayRef<int> Mask, + bool) { InstructionCost C = 0; unsigned VF = Mask.size(); unsigned VecVF = TE->getVectorFactor(); @@ -7220,12 +8077,12 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { return TEs.back(); }; (void)performExtractsShuffleAction<const TreeEntry>( - makeMutableArrayRef(Vector.data(), Vector.size()), Base, + MutableArrayRef(Vector.data(), Vector.size()), Base, [](const TreeEntry *E) { return E->getVectorFactor(); }, ResizeToVF, EstimateShufflesCost); InstructionCost InsertCost = TTI->getScalarizationOverhead( cast<FixedVectorType>(FirstUsers[I].first->getType()), DemandedElts[I], - /*Insert*/ true, /*Extract*/ false); + /*Insert*/ true, /*Extract*/ false, TTI::TCK_RecipThroughput); Cost -= InsertCost; } @@ -7245,22 +8102,89 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { return Cost; } -Optional<TargetTransformInfo::ShuffleKind> -BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, +std::optional<TargetTransformInfo::ShuffleKind> +BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, ArrayRef<Value *> VL, + SmallVectorImpl<int> &Mask, SmallVectorImpl<const TreeEntry *> &Entries) { + Entries.clear(); + // No need to check for the topmost gather node. + if (TE == VectorizableTree.front().get()) + return std::nullopt; + Mask.assign(VL.size(), UndefMaskElem); + assert(TE->UserTreeIndices.size() == 1 && + "Expected only single user of the gather node."); // TODO: currently checking only for Scalars in the tree entry, need to count // reused elements too for better cost estimation. - Mask.assign(TE->Scalars.size(), UndefMaskElem); - Entries.clear(); + Instruction &UserInst = + getLastInstructionInBundle(TE->UserTreeIndices.front().UserTE); + auto *PHI = dyn_cast<PHINode>(&UserInst); + auto *NodeUI = DT->getNode( + PHI ? PHI->getIncomingBlock(TE->UserTreeIndices.front().EdgeIdx) + : UserInst.getParent()); + assert(NodeUI && "Should only process reachable instructions"); + SmallPtrSet<Value *, 4> GatheredScalars(VL.begin(), VL.end()); + auto CheckOrdering = [&](Instruction *LastEI) { + // Check if the user node of the TE comes after user node of EntryPtr, + // otherwise EntryPtr depends on TE. + // Gather nodes usually are not scheduled and inserted before their first + // user node. So, instead of checking dependency between the gather nodes + // themselves, we check the dependency between their user nodes. + // If one user node comes before the second one, we cannot use the second + // gather node as the source vector for the first gather node, because in + // the list of instructions it will be emitted later. 
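[Editorial note] isGatherShuffledEntry now orders candidate gather nodes by their user instructions, using dominator-tree nodes and their DFS numbers instead of scanning instruction lists. A standalone refresher on why DFS in/out numbering makes such ancestry queries cheap (a hand-built toy tree, not LLVM's DominatorTree):

#include <iostream>
#include <vector>

struct Node {
  std::vector<int> Kids;
  int In = 0, Out = 0;
};

void numberDFS(std::vector<Node> &Tree, int N, int &Clock) {
  Tree[N].In = Clock++;
  for (int K : Tree[N].Kids)
    numberDFS(Tree, K, Clock);
  Tree[N].Out = Clock++;
}

// "A dominates B" becomes two integer comparisons once DFS numbers exist;
// equal in-numbers identify the same node.
bool dominates(const std::vector<Node> &Tree, int A, int B) {
  return Tree[A].In <= Tree[B].In && Tree[B].Out <= Tree[A].Out;
}

int main() {
  // Dominator tree: 0 -> {1, 2}, 2 -> {3}
  std::vector<Node> Tree(4);
  Tree[0].Kids = {1, 2};
  Tree[2].Kids = {3};
  int Clock = 0;
  numberDFS(Tree, 0, Clock);
  std::cout << dominates(Tree, 0, 3) << dominates(Tree, 1, 3) << '\n'; // 10
}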
+ auto *EntryParent = LastEI->getParent(); + auto *NodeEUI = DT->getNode(EntryParent); + if (!NodeEUI) + return false; + assert((NodeUI == NodeEUI) == + (NodeUI->getDFSNumIn() == NodeEUI->getDFSNumIn()) && + "Different nodes should have different DFS numbers"); + // Check the order of the gather nodes users. + if (UserInst.getParent() != EntryParent && + (DT->dominates(NodeUI, NodeEUI) || !DT->dominates(NodeEUI, NodeUI))) + return false; + if (UserInst.getParent() == EntryParent && UserInst.comesBefore(LastEI)) + return false; + return true; + }; // Build a lists of values to tree entries. DenseMap<Value *, SmallPtrSet<const TreeEntry *, 4>> ValueToTEs; for (const std::unique_ptr<TreeEntry> &EntryPtr : VectorizableTree) { if (EntryPtr.get() == TE) - break; + continue; if (EntryPtr->State != TreeEntry::NeedToGather) continue; + if (!any_of(EntryPtr->Scalars, [&GatheredScalars](Value *V) { + return GatheredScalars.contains(V); + })) + continue; + assert(EntryPtr->UserTreeIndices.size() == 1 && + "Expected only single user of the gather node."); + Instruction &EntryUserInst = + getLastInstructionInBundle(EntryPtr->UserTreeIndices.front().UserTE); + if (&UserInst == &EntryUserInst) { + // If 2 gathers are operands of the same entry, compare operands indices, + // use the earlier one as the base. + if (TE->UserTreeIndices.front().UserTE == + EntryPtr->UserTreeIndices.front().UserTE && + TE->UserTreeIndices.front().EdgeIdx < + EntryPtr->UserTreeIndices.front().EdgeIdx) + continue; + } + // Check if the user node of the TE comes after user node of EntryPtr, + // otherwise EntryPtr depends on TE. + auto *EntryPHI = dyn_cast<PHINode>(&EntryUserInst); + auto *EntryI = + EntryPHI + ? EntryPHI + ->getIncomingBlock(EntryPtr->UserTreeIndices.front().EdgeIdx) + ->getTerminator() + : &EntryUserInst; + if (!CheckOrdering(EntryI)) + continue; for (Value *V : EntryPtr->Scalars) - ValueToTEs.try_emplace(V).first->getSecond().insert(EntryPtr.get()); + if (!isConstant(V)) + ValueToTEs.try_emplace(V).first->getSecond().insert(EntryPtr.get()); } // Find all tree entries used by the gathered values. If no common entries // found - not a shuffle. @@ -7272,7 +8196,7 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, SmallVector<SmallPtrSet<const TreeEntry *, 4>> UsedTEs; DenseMap<Value *, int> UsedValuesEntry; for (Value *V : TE->Scalars) { - if (isa<UndefValue>(V)) + if (isConstant(V)) continue; // Build a list of tree entries where V is used. SmallPtrSet<const TreeEntry *, 4> VToTEs; @@ -7282,10 +8206,11 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, if (const TreeEntry *VTE = getTreeEntry(V)) VToTEs.insert(VTE); if (VToTEs.empty()) - return None; + continue; if (UsedTEs.empty()) { // The first iteration, just insert the list of nodes to vector. UsedTEs.push_back(VToTEs); + UsedValuesEntry.try_emplace(V, 0); } else { // Need to check if there are any previously used tree nodes which use V. // If there are no such nodes, consider that we have another one input @@ -7310,8 +8235,9 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, if (Idx == UsedTEs.size()) { // If the number of input vectors is greater than 2 - not a permutation, // fallback to the regular gather. + // TODO: support multiple reshuffled nodes. 
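[Editorial note] The loop above maps every non-constant gathered scalar to the tree entries that already contain it; the gather can be re-expressed as a shuffle only when at most two source entries cover all scalars. A toy standalone version of that grouping:

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // Scalar -> id of the existing tree entry holding it.
  std::map<std::string, int> ValueToEntry = {
      {"a", 0}, {"b", 0}, {"c", 1}, {"d", 1}};
  std::vector<std::string> Gather = {"a", "c", "b", "d", "undef"};

  std::set<int> UsedEntries;
  for (const std::string &V : Gather) {
    auto It = ValueToEntry.find(V);
    if (It != ValueToEntry.end())
      UsedEntries.insert(It->second); // scalar can be taken from this entry
  }
  if (!UsedEntries.empty() && UsedEntries.size() <= 2)
    std::cout << "rebuild as shuffle of " << UsedEntries.size() << " vectors\n";
  else
    std::cout << "fall back to a plain gather\n";
}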
if (UsedTEs.size() == 2) - return None; + continue; UsedTEs.push_back(SavedVToTEs); Idx = UsedTEs.size() - 1; } @@ -7319,32 +8245,55 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, } } - if (UsedTEs.empty()) { - assert(all_of(TE->Scalars, UndefValue::classof) && - "Expected vector of undefs only."); - return None; - } + if (UsedTEs.empty()) + return std::nullopt; unsigned VF = 0; if (UsedTEs.size() == 1) { + // Keep the order to avoid non-determinism. + SmallVector<const TreeEntry *> FirstEntries(UsedTEs.front().begin(), + UsedTEs.front().end()); + sort(FirstEntries, [](const TreeEntry *TE1, const TreeEntry *TE2) { + return TE1->Idx < TE2->Idx; + }); // Try to find the perfect match in another gather node at first. - auto It = find_if(UsedTEs.front(), [TE](const TreeEntry *EntryPtr) { - return EntryPtr->isSame(TE->Scalars); + auto *It = find_if(FirstEntries, [=](const TreeEntry *EntryPtr) { + return EntryPtr->isSame(VL) || EntryPtr->isSame(TE->Scalars); }); - if (It != UsedTEs.front().end()) { + if (It != FirstEntries.end()) { Entries.push_back(*It); std::iota(Mask.begin(), Mask.end(), 0); + // Clear undef scalars. + for (int I = 0, Sz = VL.size(); I < Sz; ++I) + if (isa<PoisonValue>(TE->Scalars[I])) + Mask[I] = UndefMaskElem; return TargetTransformInfo::SK_PermuteSingleSrc; } - // No perfect match, just shuffle, so choose the first tree node. - Entries.push_back(*UsedTEs.front().begin()); + // No perfect match, just shuffle, so choose the first tree node from the + // tree. + Entries.push_back(FirstEntries.front()); } else { // Try to find nodes with the same vector factor. assert(UsedTEs.size() == 2 && "Expected at max 2 permuted entries."); + // Keep the order of tree nodes to avoid non-determinism. DenseMap<int, const TreeEntry *> VFToTE; - for (const TreeEntry *TE : UsedTEs.front()) - VFToTE.try_emplace(TE->getVectorFactor(), TE); - for (const TreeEntry *TE : UsedTEs.back()) { + for (const TreeEntry *TE : UsedTEs.front()) { + unsigned VF = TE->getVectorFactor(); + auto It = VFToTE.find(VF); + if (It != VFToTE.end()) { + if (It->second->Idx > TE->Idx) + It->getSecond() = TE; + continue; + } + VFToTE.try_emplace(VF, TE); + } + // Same, keep the order to avoid non-determinism. + SmallVector<const TreeEntry *> SecondEntries(UsedTEs.back().begin(), + UsedTEs.back().end()); + sort(SecondEntries, [](const TreeEntry *TE1, const TreeEntry *TE2) { + return TE1->Idx < TE2->Idx; + }); + for (const TreeEntry *TE : SecondEntries) { auto It = VFToTE.find(TE->getVectorFactor()); if (It != VFToTE.end()) { VF = It->first; @@ -7356,40 +8305,135 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, // No 2 source vectors with the same vector factor - give up and do regular // gather. if (Entries.empty()) - return None; - } - + return std::nullopt; + } + + bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, UndefValue::classof); + // Checks if the 2 PHIs are compatible in terms of high possibility to be + // vectorized. + auto AreCompatiblePHIs = [&](Value *V, Value *V1) { + auto *PHI = cast<PHINode>(V); + auto *PHI1 = cast<PHINode>(V1); + // Check that all incoming values are compatible/from same parent (if they + // are instructions). + // The incoming values are compatible if they all are constants, or + // instruction with the same/alternate opcodes from the same basic block. 
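[Editorial note] When two source candidates are needed, the code above buckets them by vector factor and, to stay deterministic, keeps the entry with the lowest tree index per factor. A small standalone equivalent (invented Entry struct):

#include <iostream>
#include <map>
#include <vector>

struct Entry {
  int Idx;     // creation order in the tree
  unsigned VF; // vector factor
};

int main() {
  std::vector<Entry> Candidates = {{7, 4}, {3, 4}, {5, 8}};
  std::map<unsigned, Entry> VFToEntry;
  for (const Entry &E : Candidates) {
    auto [It, Inserted] = VFToEntry.try_emplace(E.VF, E);
    // On a collision keep the earlier (lower-index) entry.
    if (!Inserted && It->second.Idx > E.Idx)
      It->second = E;
  }
  for (const auto &[VF, E] : VFToEntry)
    std::cout << "VF " << VF << " -> entry " << E.Idx << '\n'; // 4->3, 8->5
}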
+ for (int I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) { + Value *In = PHI->getIncomingValue(I); + Value *In1 = PHI1->getIncomingValue(I); + if (isConstant(In) && isConstant(In1)) + continue; + if (!getSameOpcode({In, In1}, *TLI).getOpcode()) + return false; + if (cast<Instruction>(In)->getParent() != + cast<Instruction>(In1)->getParent()) + return false; + } + return true; + }; + // Check if the value can be ignored during analysis for shuffled gathers. + // We suppose it is better to ignore instruction, which do not form splats, + // are not vectorized/not extractelements (these instructions will be handled + // by extractelements processing) or may form vector node in future. + auto MightBeIgnored = [=](Value *V) { + auto *I = dyn_cast<Instruction>(V); + SmallVector<Value *> IgnoredVals; + if (UserIgnoreList) + IgnoredVals.assign(UserIgnoreList->begin(), UserIgnoreList->end()); + return I && !IsSplatOrUndefs && !ScalarToTreeEntry.count(I) && + !isVectorLikeInstWithConstOps(I) && + !areAllUsersVectorized(I, IgnoredVals) && isSimple(I); + }; + // Check that the neighbor instruction may form a full vector node with the + // current instruction V. It is possible, if they have same/alternate opcode + // and same parent basic block. + auto NeighborMightBeIgnored = [&](Value *V, int Idx) { + Value *V1 = VL[Idx]; + bool UsedInSameVTE = false; + auto It = UsedValuesEntry.find(V1); + if (It != UsedValuesEntry.end()) + UsedInSameVTE = It->second == UsedValuesEntry.find(V)->second; + return V != V1 && MightBeIgnored(V1) && !UsedInSameVTE && + getSameOpcode({V, V1}, *TLI).getOpcode() && + cast<Instruction>(V)->getParent() == + cast<Instruction>(V1)->getParent() && + (!isa<PHINode>(V1) || AreCompatiblePHIs(V, V1)); + }; // Build a shuffle mask for better cost estimation and vector emission. - for (int I = 0, E = TE->Scalars.size(); I < E; ++I) { - Value *V = TE->Scalars[I]; - if (isa<UndefValue>(V)) + SmallBitVector UsedIdxs(Entries.size()); + SmallVector<std::pair<unsigned, int>> EntryLanes; + for (int I = 0, E = VL.size(); I < E; ++I) { + Value *V = VL[I]; + auto It = UsedValuesEntry.find(V); + if (It == UsedValuesEntry.end()) continue; - unsigned Idx = UsedValuesEntry.lookup(V); - const TreeEntry *VTE = Entries[Idx]; - int FoundLane = VTE->findLaneForValue(V); - Mask[I] = Idx * VF + FoundLane; - // Extra check required by isSingleSourceMaskImpl function (called by - // ShuffleVectorInst::isSingleSourceMask). - if (Mask[I] >= 2 * E) - return None; + // Do not try to shuffle scalars, if they are constants, or instructions + // that can be vectorized as a result of the following vector build + // vectorization. + if (isConstant(V) || (MightBeIgnored(V) && + ((I > 0 && NeighborMightBeIgnored(V, I - 1)) || + (I != E - 1 && NeighborMightBeIgnored(V, I + 1))))) + continue; + unsigned Idx = It->second; + EntryLanes.emplace_back(Idx, I); + UsedIdxs.set(Idx); + } + // Iterate through all shuffled scalars and select entries, which can be used + // for final shuffle. + SmallVector<const TreeEntry *> TempEntries; + for (unsigned I = 0, Sz = Entries.size(); I < Sz; ++I) { + if (!UsedIdxs.test(I)) + continue; + // Fix the entry number for the given scalar. If it is the first entry, set + // Pair.first to 0, otherwise to 1 (currently select at max 2 nodes). + // These indices are used when calculating final shuffle mask as the vector + // offset. 
+ for (std::pair<unsigned, int> &Pair : EntryLanes) + if (Pair.first == I) + Pair.first = TempEntries.size(); + TempEntries.push_back(Entries[I]); + } + Entries.swap(TempEntries); + if (EntryLanes.size() == Entries.size() && !VL.equals(TE->Scalars)) { + // We may have here 1 or 2 entries only. If the number of scalars is equal + // to the number of entries, no need to do the analysis, it is not very + // profitable. Since VL is not the same as TE->Scalars, it means we already + // have some shuffles before. Cut off not profitable case. + Entries.clear(); + return std::nullopt; + } + // Build the final mask, check for the identity shuffle, if possible. + bool IsIdentity = Entries.size() == 1; + // Pair.first is the offset to the vector, while Pair.second is the index of + // scalar in the list. + for (const std::pair<unsigned, int> &Pair : EntryLanes) { + Mask[Pair.second] = Pair.first * VF + + Entries[Pair.first]->findLaneForValue(VL[Pair.second]); + IsIdentity &= Mask[Pair.second] == Pair.second; } switch (Entries.size()) { case 1: - return TargetTransformInfo::SK_PermuteSingleSrc; + if (IsIdentity || EntryLanes.size() > 1 || VL.size() <= 2) + return TargetTransformInfo::SK_PermuteSingleSrc; + break; case 2: - return TargetTransformInfo::SK_PermuteTwoSrc; + if (EntryLanes.size() > 2 || VL.size() <= 2) + return TargetTransformInfo::SK_PermuteTwoSrc; + break; default: break; } - return None; + Entries.clear(); + return std::nullopt; } InstructionCost BoUpSLP::getGatherCost(FixedVectorType *Ty, const APInt &ShuffledIndices, bool NeedToShuffle) const { + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; InstructionCost Cost = TTI->getScalarizationOverhead(Ty, ~ShuffledIndices, /*Insert*/ true, - /*Extract*/ false); + /*Extract*/ false, CostKind); if (NeedToShuffle) Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, Ty); return Cost; @@ -7425,22 +8469,20 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const { // Perform operand reordering on the instructions in VL and return the reordered // operands in Left and Right. -void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, - SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right, - const DataLayout &DL, - ScalarEvolution &SE, - const BoUpSLP &R) { +void BoUpSLP::reorderInputsAccordingToOpcode( + ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, + SmallVectorImpl<Value *> &Right, const TargetLibraryInfo &TLI, + const DataLayout &DL, ScalarEvolution &SE, const BoUpSLP &R) { if (VL.empty()) return; - VLOperands Ops(VL, DL, SE, R); + VLOperands Ops(VL, TLI, DL, SE, R); // Reorder the operands in place. Ops.reorder(); Left = Ops.getVL(0); Right = Ops.getVL(1); } -void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { +Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { // Get the basic block this bundle is in. All instructions in the bundle // should be in this block (except for extractelement-like instructions with // constant indeces). 
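
// The final mask built in isGatherShuffledEntry above addresses the
// concatenation of the (at most two) selected source vectors: lanes taken from
// the second source are offset by the vector factor, and a single-source mask
// mapping every lane to itself is recognised as an identity shuffle. A minimal
// standalone sketch, assuming both sources share the vector factor VF and -1
// stands in for UndefMaskElem:
#include <vector>

struct SourceLane {
  unsigned Source; // 0 or 1: which selected source vector
  int DstLane;     // lane in the result
  int SrcLane;     // lane inside the chosen source
};

std::vector<int> buildTwoSourceMask(const std::vector<SourceLane> &Picks,
                                    unsigned VF, unsigned NumLanes,
                                    bool &IsIdentity) {
  std::vector<int> Mask(NumLanes, -1);
  IsIdentity = true;
  for (const SourceLane &P : Picks) {
    Mask[P.DstLane] = (int)(P.Source * VF) + P.SrcLane;
    // Identity only if every lane comes from source 0 at the same position.
    IsIdentity = IsIdentity && Mask[P.DstLane] == P.DstLane;
  }
  return Mask;
}
// E.g. with VF = 4, taking destination lane 0 from source 0 lane 0 and
// destination lane 1 from source 1 lane 0 gives <0, 4, -1, -1>, a two-source
// permute; taking both from source 0 lanes 0 and 1 gives <0, 1, -1, -1>.
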
@@ -7489,13 +8531,34 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { return LastInst; }; - auto &&FindFirstInst = [E, Front]() { + auto &&FindFirstInst = [E, Front, this]() { Instruction *FirstInst = Front; for (Value *V : E->Scalars) { auto *I = dyn_cast<Instruction>(V); if (!I) continue; - if (I->comesBefore(FirstInst)) + if (FirstInst->getParent() == I->getParent()) { + if (I->comesBefore(FirstInst)) + FirstInst = I; + continue; + } + assert(isVectorLikeInstWithConstOps(FirstInst) && + isVectorLikeInstWithConstOps(I) && + "Expected vector-like insts only."); + if (!DT->isReachableFromEntry(FirstInst->getParent())) { + FirstInst = I; + continue; + } + if (!DT->isReachableFromEntry(I->getParent())) + continue; + auto *NodeA = DT->getNode(FirstInst->getParent()); + auto *NodeB = DT->getNode(I->getParent()); + assert(NodeA && "Should only process reachable instructions"); + assert(NodeB && "Should only process reachable instructions"); + assert((NodeA == NodeB) == + (NodeA->getDFSNumIn() == NodeB->getDFSNumIn()) && + "Different nodes should have different DFS numbers"); + if (NodeA->getDFSNumIn() > NodeB->getDFSNumIn()) FirstInst = I; } return FirstInst; @@ -7504,19 +8567,16 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { // Set the insert point to the beginning of the basic block if the entry // should not be scheduled. if (E->State != TreeEntry::NeedToGather && - doesNotNeedToSchedule(E->Scalars)) { + (doesNotNeedToSchedule(E->Scalars) || + all_of(E->Scalars, isVectorLikeInstWithConstOps))) { Instruction *InsertInst; - if (all_of(E->Scalars, isUsedOutsideBlock)) + if (all_of(E->Scalars, [](Value *V) { + return !isVectorLikeInstWithConstOps(V) && isUsedOutsideBlock(V); + })) InsertInst = FindLastInst(); else InsertInst = FindFirstInst(); - // If the instruction is PHI, set the insert point after all the PHIs. - if (isa<PHINode>(InsertInst)) - InsertInst = BB->getFirstNonPHI(); - BasicBlock::iterator InsertPt = InsertInst->getIterator(); - Builder.SetInsertPoint(BB, InsertPt); - Builder.SetCurrentDebugLocation(Front->getDebugLoc()); - return; + return *InsertInst; } // The last instruction in the bundle in program order. @@ -7555,17 +8615,29 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { // not ideal. However, this should be exceedingly rare since it requires that // we both exit early from buildTree_rec and that the bundle be out-of-order // (causing us to iterate all the way to the end of the block). - if (!LastInst) { + if (!LastInst) LastInst = FindLastInst(); - // If the instruction is PHI, set the insert point after all the PHIs. - if (isa<PHINode>(LastInst)) - LastInst = BB->getFirstNonPHI()->getPrevNode(); - } assert(LastInst && "Failed to find last instruction in bundle"); + return *LastInst; +} - // Set the insertion point after the last instruction in the bundle. Set the - // debug location to Front. - Builder.SetInsertPoint(BB, std::next(LastInst->getIterator())); +void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { + auto *Front = E->getMainOp(); + Instruction *LastInst = EntryToLastInstruction.lookup(E); + assert(LastInst && "Failed to find last instruction in bundle"); + // If the instruction is PHI, set the insert point after all the PHIs. 
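
// FindFirstInst above compares instructions from different blocks by the
// DFS-in numbers of their dominator-tree nodes: a smaller number means the
// block is visited earlier in the dominator-tree preorder, so its instruction
// is treated as the first one. A toy standalone sketch of that ordering (not
// the LLVM DominatorTree API; a hand-built tree of indices stands in):
#include <cstddef>
#include <vector>

struct TreeNode {
  std::vector<std::size_t> Children;
  unsigned DFSIn = 0;
};

// Assign preorder (DFS-in) numbers starting from the root.
void numberDFS(std::vector<TreeNode> &Nodes, std::size_t Root,
               unsigned &Counter) {
  Nodes[Root].DFSIn = Counter++;
  for (std::size_t C : Nodes[Root].Children)
    numberDFS(Nodes, C, Counter);
}

// Of two reachable nodes, pick the one the preorder visits first.
std::size_t earlierByDFS(const std::vector<TreeNode> &Nodes, std::size_t A,
                         std::size_t B) {
  return Nodes[A].DFSIn <= Nodes[B].DFSIn ? A : B;
}
// Usage: for a root 0 with children {1, 2} and node 2 with child {3}, the
// preorder numbers are 0, 1, 2, 3, so earlierByDFS(Nodes, 3, 1) returns 1.
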
+ bool IsPHI = isa<PHINode>(LastInst); + if (IsPHI) + LastInst = LastInst->getParent()->getFirstNonPHI(); + if (IsPHI || (E->State != TreeEntry::NeedToGather && + doesNotNeedToSchedule(E->Scalars))) { + Builder.SetInsertPoint(LastInst); + } else { + // Set the insertion point after the last instruction in the bundle. Set the + // debug location to Front. + Builder.SetInsertPoint(LastInst->getParent(), + std::next(LastInst->getIterator())); + } Builder.SetCurrentDebugLocation(Front->getDebugLoc()); } @@ -7596,7 +8668,7 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL) { auto *InsElt = dyn_cast<InsertElementInst>(Vec); if (!InsElt) return Vec; - GatherShuffleSeq.insert(InsElt); + GatherShuffleExtractSeq.insert(InsElt); CSEBlocks.insert(InsElt->getParent()); // Add to our 'need-to-extract' list. if (TreeEntry *Entry = getTreeEntry(V)) { @@ -7632,196 +8704,452 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL) { return Vec; } -namespace { -/// Merges shuffle masks and emits final shuffle instruction, if required. -class ShuffleInstructionBuilder { - IRBuilderBase &Builder; - const unsigned VF = 0; +/// Merges shuffle masks and emits final shuffle instruction, if required. It +/// supports shuffling of 2 input vectors. It implements lazy shuffles emission, +/// when the actual shuffle instruction is generated only if this is actually +/// required. Otherwise, the shuffle instruction emission is delayed till the +/// end of the process, to reduce the number of emitted instructions and further +/// analysis/transformations. +/// The class also will look through the previously emitted shuffle instructions +/// and properly mark indices in mask as undef. +/// For example, given the code +/// \code +/// %s1 = shufflevector <2 x ty> %0, poison, <1, 0> +/// %s2 = shufflevector <2 x ty> %1, poison, <1, 0> +/// \endcode +/// and if need to emit shuffle of %s1 and %s2 with mask <1, 0, 3, 2>, it will +/// look through %s1 and %s2 and emit +/// \code +/// %res = shufflevector <2 x ty> %0, %1, <0, 1, 2, 3> +/// \endcode +/// instead. +/// If 2 operands are of different size, the smallest one will be resized and +/// the mask recalculated properly. +/// For example, given the code +/// \code +/// %s1 = shufflevector <2 x ty> %0, poison, <1, 0, 1, 0> +/// %s2 = shufflevector <2 x ty> %1, poison, <1, 0, 1, 0> +/// \endcode +/// and if need to emit shuffle of %s1 and %s2 with mask <1, 0, 5, 4>, it will +/// look through %s1 and %s2 and emit +/// \code +/// %res = shufflevector <2 x ty> %0, %1, <0, 1, 2, 3> +/// \endcode +/// instead. +class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { bool IsFinalized = false; - SmallVector<int, 4> Mask; - /// Holds all of the instructions that we gathered. - SetVector<Instruction *> &GatherShuffleSeq; - /// A list of blocks that we are going to CSE. - SetVector<BasicBlock *> &CSEBlocks; + /// Combined mask for all applied operands and masks. It is built during + /// analysis and actual emission of shuffle vector instructions. + SmallVector<int> CommonMask; + /// List of operands for the shuffle vector instruction. It hold at max 2 + /// operands, if the 3rd is going to be added, the first 2 are combined into + /// shuffle with \p CommonMask mask, the first operand sets to be the + /// resulting shuffle and the second operand sets to be the newly added + /// operand. The \p CommonMask is transformed in the proper way after that. 
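
// The look-through described in the comment above is plain mask composition:
// if an operand is itself "shuffle(Src, InnerMask)", shuffling it with
// OuterMask is the same as shuffling Src with Inner[Outer[i]]. A minimal
// standalone sketch for a single operand, with -1 standing in for
// UndefMaskElem:
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<int> composeMasks(const std::vector<int> &InnerMask,
                              const std::vector<int> &OuterMask) {
  std::vector<int> Composed(OuterMask.size(), -1);
  for (std::size_t I = 0; I < OuterMask.size(); ++I) {
    if (OuterMask[I] < 0)
      continue; // an undef lane stays undef
    Composed[I] = InnerMask[OuterMask[I]];
  }
  return Composed;
}

int main() {
  // Mirrors one operand of the example above: %s1 = shufflevector %0, <1, 0>.
  // Asking for lanes <1, 0> of %s1 is just lanes <0, 1> of %0, so the emitted
  // shuffle can use %0 directly.
  std::vector<int> Composed = composeMasks({1, 0}, {1, 0});
  assert(Composed[0] == 0 && Composed[1] == 1);
  return 0;
}
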
+ SmallVector<Value *, 2> InVectors; + IRBuilderBase &Builder; + BoUpSLP &R; -public: - ShuffleInstructionBuilder(IRBuilderBase &Builder, unsigned VF, - SetVector<Instruction *> &GatherShuffleSeq, - SetVector<BasicBlock *> &CSEBlocks) - : Builder(Builder), VF(VF), GatherShuffleSeq(GatherShuffleSeq), - CSEBlocks(CSEBlocks) {} - - /// Adds a mask, inverting it before applying. - void addInversedMask(ArrayRef<unsigned> SubMask) { - if (SubMask.empty()) - return; - SmallVector<int, 4> NewMask; - inversePermutation(SubMask, NewMask); - addMask(NewMask); - } + class ShuffleIRBuilder { + IRBuilderBase &Builder; + /// Holds all of the instructions that we gathered. + SetVector<Instruction *> &GatherShuffleExtractSeq; + /// A list of blocks that we are going to CSE. + SetVector<BasicBlock *> &CSEBlocks; + + public: + ShuffleIRBuilder(IRBuilderBase &Builder, + SetVector<Instruction *> &GatherShuffleExtractSeq, + SetVector<BasicBlock *> &CSEBlocks) + : Builder(Builder), GatherShuffleExtractSeq(GatherShuffleExtractSeq), + CSEBlocks(CSEBlocks) {} + ~ShuffleIRBuilder() = default; + /// Creates shufflevector for the 2 operands with the given mask. + Value *createShuffleVector(Value *V1, Value *V2, ArrayRef<int> Mask) { + Value *Vec = Builder.CreateShuffleVector(V1, V2, Mask); + if (auto *I = dyn_cast<Instruction>(Vec)) { + GatherShuffleExtractSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + return Vec; + } + /// Creates permutation of the single vector operand with the given mask, if + /// it is not identity mask. + Value *createShuffleVector(Value *V1, ArrayRef<int> Mask) { + if (Mask.empty()) + return V1; + unsigned VF = Mask.size(); + unsigned LocalVF = cast<FixedVectorType>(V1->getType())->getNumElements(); + if (VF == LocalVF && ShuffleVectorInst::isIdentityMask(Mask)) + return V1; + Value *Vec = Builder.CreateShuffleVector(V1, Mask); + if (auto *I = dyn_cast<Instruction>(Vec)) { + GatherShuffleExtractSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + return Vec; + } + /// Resizes 2 input vector to match the sizes, if the they are not equal + /// yet. The smallest vector is resized to the size of the larger vector. + void resizeToMatch(Value *&V1, Value *&V2) { + if (V1->getType() == V2->getType()) + return; + int V1VF = cast<FixedVectorType>(V1->getType())->getNumElements(); + int V2VF = cast<FixedVectorType>(V2->getType())->getNumElements(); + int VF = std::max(V1VF, V2VF); + int MinVF = std::min(V1VF, V2VF); + SmallVector<int> IdentityMask(VF, UndefMaskElem); + std::iota(IdentityMask.begin(), std::next(IdentityMask.begin(), MinVF), + 0); + Value *&Op = MinVF == V1VF ? V1 : V2; + Op = Builder.CreateShuffleVector(Op, IdentityMask); + if (auto *I = dyn_cast<Instruction>(Op)) { + GatherShuffleExtractSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + if (MinVF == V1VF) + V1 = Op; + else + V2 = Op; + } + }; - /// Functions adds masks, merging them into single one. - void addMask(ArrayRef<unsigned> SubMask) { - SmallVector<int, 4> NewMask(SubMask.begin(), SubMask.end()); - addMask(NewMask); + /// Smart shuffle instruction emission, walks through shuffles trees and + /// tries to find the best matching vector for the actual shuffle + /// instruction. 
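
// resizeToMatch above widens the smaller operand to the size of the larger one
// with a shuffle whose mask is an identity prefix followed by undef lanes, so
// that both operands can feed a single two-source shuffle. A minimal
// standalone sketch of just the mask construction, with -1 for UndefMaskElem:
#include <numeric>
#include <vector>

std::vector<int> buildResizeMask(unsigned SmallVF, unsigned LargeVF) {
  std::vector<int> Mask(LargeVF, -1);                 // tail lanes stay undef
  std::iota(Mask.begin(), Mask.begin() + SmallVF, 0); // identity prefix
  return Mask;
}
// E.g. buildResizeMask(2, 4) is <0, 1, -1, -1>: the 2-wide operand is copied
// into the low lanes of a 4-wide value before the real shuffle is emitted.
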
+ Value *createShuffle(Value *V1, Value *V2, ArrayRef<int> Mask) { + assert(V1 && "Expected at least one vector value."); + ShuffleIRBuilder ShuffleBuilder(Builder, R.GatherShuffleExtractSeq, + R.CSEBlocks); + return BaseShuffleAnalysis::createShuffle(V1, V2, Mask, ShuffleBuilder); } - void addMask(ArrayRef<int> SubMask) { ::addMask(Mask, SubMask); } + /// Transforms mask \p CommonMask per given \p Mask to make proper set after + /// shuffle emission. + static void transformMaskAfterShuffle(MutableArrayRef<int> CommonMask, + ArrayRef<int> Mask) { + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (Mask[Idx] != UndefMaskElem) + CommonMask[Idx] = Idx; + } - Value *finalize(Value *V) { +public: + ShuffleInstructionBuilder(IRBuilderBase &Builder, BoUpSLP &R) + : Builder(Builder), R(R) {} + + /// Adds 2 input vectors and the mask for their shuffling. + void add(Value *V1, Value *V2, ArrayRef<int> Mask) { + assert(V1 && V2 && !Mask.empty() && "Expected non-empty input vectors."); + if (InVectors.empty()) { + InVectors.push_back(V1); + InVectors.push_back(V2); + CommonMask.assign(Mask.begin(), Mask.end()); + return; + } + Value *Vec = InVectors.front(); + if (InVectors.size() == 2) { + Vec = createShuffle(Vec, InVectors.back(), CommonMask); + transformMaskAfterShuffle(CommonMask, Mask); + } else if (cast<FixedVectorType>(Vec->getType())->getNumElements() != + Mask.size()) { + Vec = createShuffle(Vec, nullptr, CommonMask); + transformMaskAfterShuffle(CommonMask, Mask); + } + V1 = createShuffle(V1, V2, Mask); + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (Mask[Idx] != UndefMaskElem) + CommonMask[Idx] = Idx + Sz; + InVectors.front() = Vec; + if (InVectors.size() == 2) + InVectors.back() = V1; + else + InVectors.push_back(V1); + } + /// Adds another one input vector and the mask for the shuffling. + void add(Value *V1, ArrayRef<int> Mask) { + if (InVectors.empty()) { + if (!isa<FixedVectorType>(V1->getType())) { + V1 = createShuffle(V1, nullptr, CommonMask); + CommonMask.assign(Mask.size(), UndefMaskElem); + transformMaskAfterShuffle(CommonMask, Mask); + } + InVectors.push_back(V1); + CommonMask.assign(Mask.begin(), Mask.end()); + return; + } + const auto *It = find(InVectors, V1); + if (It == InVectors.end()) { + if (InVectors.size() == 2 || + InVectors.front()->getType() != V1->getType() || + !isa<FixedVectorType>(V1->getType())) { + Value *V = InVectors.front(); + if (InVectors.size() == 2) { + V = createShuffle(InVectors.front(), InVectors.back(), CommonMask); + transformMaskAfterShuffle(CommonMask, CommonMask); + } else if (cast<FixedVectorType>(V->getType())->getNumElements() != + CommonMask.size()) { + V = createShuffle(InVectors.front(), nullptr, CommonMask); + transformMaskAfterShuffle(CommonMask, CommonMask); + } + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (CommonMask[Idx] == UndefMaskElem && Mask[Idx] != UndefMaskElem) + CommonMask[Idx] = + V->getType() != V1->getType() + ? Idx + Sz + : Mask[Idx] + cast<FixedVectorType>(V1->getType()) + ->getNumElements(); + if (V->getType() != V1->getType()) + V1 = createShuffle(V1, nullptr, Mask); + InVectors.front() = V; + if (InVectors.size() == 2) + InVectors.back() = V1; + else + InVectors.push_back(V1); + return; + } + // Check if second vector is required if the used elements are already + // used from the first one. 
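
// Once the two accumulated operands are folded into one real shufflevector,
// the running CommonMask has to be rewritten: every lane that was defined now
// simply selects that lane of the freshly created value, and lanes taken from
// the operand added next are addressed with an offset of one vector width.
// A minimal standalone sketch of those two mask updates (not the class above;
// -1 stands in for UndefMaskElem):
#include <cstddef>
#include <vector>

// After emitting "Merged = shuffle(Op0, Op1, CommonMask)", lane I of the
// running value is lane I of Merged wherever it was defined.
void remapAfterMerge(std::vector<int> &CommonMask) {
  for (std::size_t I = 0; I < CommonMask.size(); ++I)
    if (CommonMask[I] >= 0)
      CommonMask[I] = (int)I;
}

// Lanes taken from a newly added operand live behind the existing value in the
// concatenated index space, so they are recorded with an offset of its width.
void addNextOperand(std::vector<int> &CommonMask,
                    const std::vector<int> &NewMask, unsigned ExistingWidth) {
  for (std::size_t I = 0; I < CommonMask.size(); ++I)
    if (NewMask[I] >= 0)
      CommonMask[I] = NewMask[I] + (int)ExistingWidth;
}
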
+ for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (Mask[Idx] != UndefMaskElem && CommonMask[Idx] == UndefMaskElem) { + InVectors.push_back(V1); + break; + } + } + int VF = CommonMask.size(); + if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType())) + VF = FTy->getNumElements(); + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (Mask[Idx] != UndefMaskElem && CommonMask[Idx] == UndefMaskElem) + CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF); + } + /// Adds another one input vector and the mask for the shuffling. + void addOrdered(Value *V1, ArrayRef<unsigned> Order) { + SmallVector<int> NewMask; + inversePermutation(Order, NewMask); + add(V1, NewMask); + } + /// Finalize emission of the shuffles. + Value * + finalize(ArrayRef<int> ExtMask = std::nullopt) { IsFinalized = true; - unsigned ValueVF = cast<FixedVectorType>(V->getType())->getNumElements(); - if (VF == ValueVF && Mask.empty()) - return V; - SmallVector<int, 4> NormalizedMask(VF, UndefMaskElem); - std::iota(NormalizedMask.begin(), NormalizedMask.end(), 0); - addMask(NormalizedMask); - - if (VF == ValueVF && ShuffleVectorInst::isIdentityMask(Mask)) - return V; - Value *Vec = Builder.CreateShuffleVector(V, Mask, "shuffle"); - if (auto *I = dyn_cast<Instruction>(Vec)) { - GatherShuffleSeq.insert(I); - CSEBlocks.insert(I->getParent()); + if (!ExtMask.empty()) { + if (CommonMask.empty()) { + CommonMask.assign(ExtMask.begin(), ExtMask.end()); + } else { + SmallVector<int> NewMask(ExtMask.size(), UndefMaskElem); + for (int I = 0, Sz = ExtMask.size(); I < Sz; ++I) { + if (ExtMask[I] == UndefMaskElem) + continue; + NewMask[I] = CommonMask[ExtMask[I]]; + } + CommonMask.swap(NewMask); + } } - return Vec; + if (CommonMask.empty()) { + assert(InVectors.size() == 1 && "Expected only one vector with no mask"); + return InVectors.front(); + } + if (InVectors.size() == 2) + return createShuffle(InVectors.front(), InVectors.back(), CommonMask); + return createShuffle(InVectors.front(), nullptr, CommonMask); } ~ShuffleInstructionBuilder() { - assert((IsFinalized || Mask.empty()) && + assert((IsFinalized || CommonMask.empty()) && "Shuffle construction must be finalized."); } }; -} // namespace -Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { +Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx) { + ArrayRef<Value *> VL = E->getOperand(NodeIdx); const unsigned VF = VL.size(); - InstructionsState S = getSameOpcode(VL); + InstructionsState S = getSameOpcode(VL, *TLI); // Special processing for GEPs bundle, which may include non-gep values. if (!S.getOpcode() && VL.front()->getType()->isPointerTy()) { const auto *It = find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }); if (It != VL.end()) - S = getSameOpcode(*It); + S = getSameOpcode(*It, *TLI); } if (S.getOpcode()) { - if (TreeEntry *E = getTreeEntry(S.OpValue)) - if (E->isSame(VL)) { - Value *V = vectorizeTree(E); - if (VF != cast<FixedVectorType>(V->getType())->getNumElements()) { - if (!E->ReuseShuffleIndices.empty()) { - // Reshuffle to get only unique values. - // If some of the scalars are duplicated in the vectorization tree - // entry, we do not vectorize them but instead generate a mask for - // the reuses. But if there are several users of the same entry, - // they may have different vectorization factors. This is especially - // important for PHI nodes. 
In this case, we need to adapt the - // resulting instruction for the user vectorization factor and have - // to reshuffle it again to take only unique elements of the vector. - // Without this code the function incorrectly returns reduced vector - // instruction with the same elements, not with the unique ones. - - // block: - // %phi = phi <2 x > { .., %entry} {%shuffle, %block} - // %2 = shuffle <2 x > %phi, poison, <4 x > <1, 1, 0, 0> - // ... (use %2) - // %shuffle = shuffle <2 x> %2, poison, <2 x> {2, 0} - // br %block - SmallVector<int> UniqueIdxs(VF, UndefMaskElem); - SmallSet<int, 4> UsedIdxs; - int Pos = 0; - int Sz = VL.size(); - for (int Idx : E->ReuseShuffleIndices) { - if (Idx != Sz && Idx != UndefMaskElem && - UsedIdxs.insert(Idx).second) - UniqueIdxs[Idx] = Pos; - ++Pos; - } - assert(VF >= UsedIdxs.size() && "Expected vectorization factor " - "less than original vector size."); - UniqueIdxs.append(VF - UsedIdxs.size(), UndefMaskElem); - V = Builder.CreateShuffleVector(V, UniqueIdxs, "shrink.shuffle"); - } else { - assert(VF < cast<FixedVectorType>(V->getType())->getNumElements() && - "Expected vectorization factor less " - "than original vector size."); - SmallVector<int> UniformMask(VF, 0); - std::iota(UniformMask.begin(), UniformMask.end(), 0); - V = Builder.CreateShuffleVector(V, UniformMask, "shrink.shuffle"); - } - if (auto *I = dyn_cast<Instruction>(V)) { - GatherShuffleSeq.insert(I); - CSEBlocks.insert(I->getParent()); + if (TreeEntry *VE = getTreeEntry(S.OpValue); + VE && VE->isSame(VL) && + (any_of(VE->UserTreeIndices, + [E, NodeIdx](const EdgeInfo &EI) { + return EI.UserTE == E && EI.EdgeIdx == NodeIdx; + }) || + any_of(VectorizableTree, + [E, NodeIdx, VE](const std::unique_ptr<TreeEntry> &TE) { + return TE->isOperandGatherNode({E, NodeIdx}) && + VE->isSame(TE->Scalars); + }))) { + auto FinalShuffle = [&](Value *V, ArrayRef<int> Mask) { + ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); + ShuffleBuilder.add(V, Mask); + return ShuffleBuilder.finalize(std::nullopt); + }; + Value *V = vectorizeTree(VE); + if (VF != cast<FixedVectorType>(V->getType())->getNumElements()) { + if (!VE->ReuseShuffleIndices.empty()) { + // Reshuffle to get only unique values. + // If some of the scalars are duplicated in the vectorization + // tree entry, we do not vectorize them but instead generate a + // mask for the reuses. But if there are several users of the + // same entry, they may have different vectorization factors. + // This is especially important for PHI nodes. In this case, we + // need to adapt the resulting instruction for the user + // vectorization factor and have to reshuffle it again to take + // only unique elements of the vector. Without this code the + // function incorrectly returns reduced vector instruction with + // the same elements, not with the unique ones. + + // block: + // %phi = phi <2 x > { .., %entry} {%shuffle, %block} + // %2 = shuffle <2 x > %phi, poison, <4 x > <1, 1, 0, 0> + // ... 
(use %2) + // %shuffle = shuffle <2 x> %2, poison, <2 x> {2, 0} + // br %block + SmallVector<int> UniqueIdxs(VF, UndefMaskElem); + SmallSet<int, 4> UsedIdxs; + int Pos = 0; + for (int Idx : VE->ReuseShuffleIndices) { + if (Idx != static_cast<int>(VF) && Idx != UndefMaskElem && + UsedIdxs.insert(Idx).second) + UniqueIdxs[Idx] = Pos; + ++Pos; } + assert(VF >= UsedIdxs.size() && "Expected vectorization factor " + "less than original vector size."); + UniqueIdxs.append(VF - UsedIdxs.size(), UndefMaskElem); + V = FinalShuffle(V, UniqueIdxs); + } else { + assert(VF < cast<FixedVectorType>(V->getType())->getNumElements() && + "Expected vectorization factor less " + "than original vector size."); + SmallVector<int> UniformMask(VF, 0); + std::iota(UniformMask.begin(), UniformMask.end(), 0); + V = FinalShuffle(V, UniformMask); } - return V; } + return V; + } } - // Can't vectorize this, so simply build a new vector with each lane - // corresponding to the requested value. - return createBuildVector(VL); + // Find the corresponding gather entry and vectorize it. + // Allows to be more accurate with tree/graph transformations, checks for the + // correctness of the transformations in many cases. + auto *I = find_if(VectorizableTree, + [E, NodeIdx](const std::unique_ptr<TreeEntry> &TE) { + return TE->isOperandGatherNode({E, NodeIdx}); + }); + assert(I != VectorizableTree.end() && "Gather node is not in the graph."); + assert(I->get()->UserTreeIndices.size() == 1 && + "Expected only single user for the gather node."); + assert(I->get()->isSame(VL) && "Expected same list of scalars."); + IRBuilder<>::InsertPointGuard Guard(Builder); + if (E->getOpcode() != Instruction::InsertElement && + E->getOpcode() != Instruction::PHI) { + Instruction *LastInst = EntryToLastInstruction.lookup(E); + assert(LastInst && "Failed to find last instruction in bundle"); + Builder.SetInsertPoint(LastInst); + } + return vectorizeTree(I->get()); } -Value *BoUpSLP::createBuildVector(ArrayRef<Value *> VL) { - assert(any_of(VectorizableTree, - [VL](const std::unique_ptr<TreeEntry> &TE) { - return TE->State == TreeEntry::NeedToGather && TE->isSame(VL); - }) && - "Non-matching gather node."); - unsigned VF = VL.size(); - // Exploit possible reuse of values across lanes. - SmallVector<int> ReuseShuffleIndicies; - SmallVector<Value *> UniqueValues; - if (VL.size() > 2) { + +Value *BoUpSLP::createBuildVector(const TreeEntry *E) { + assert(E->State == TreeEntry::NeedToGather && "Expected gather node."); + unsigned VF = E->getVectorFactor(); + + ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); + SmallVector<Value *> Gathered( + VF, PoisonValue::get(E->Scalars.front()->getType())); + bool NeedFreeze = false; + SmallVector<Value *> VL(E->Scalars.begin(), E->Scalars.end()); + // Build a mask out of the redorder indices and reorder scalars per this mask. + SmallVector<int> ReorderMask; + inversePermutation(E->ReorderIndices, ReorderMask); + if (!ReorderMask.empty()) + reorderScalars(VL, ReorderMask); + SmallVector<int> ReuseMask(VF, UndefMaskElem); + if (!allConstant(VL)) { + // For splats with can emit broadcasts instead of gathers, so try to find + // such sequences. 
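
// Both the re-shuffle above and the gather below rely on the same idea: only
// the first occurrence of a value needs a real lane, and every repeat is
// replayed by a mask entry pointing back at that lane. A minimal standalone
// sketch of the first-occurrence mask (not from this patch; int stands in for
// Value*):
#include <cstddef>
#include <unordered_map>
#include <vector>

// Returns a mask such that gathering only UniqueScalars and shuffling with the
// mask reproduces the original, possibly repeating, sequence.
std::vector<int> buildReuseMask(const std::vector<int> &Scalars,
                                std::vector<int> &UniqueScalars) {
  std::vector<int> Mask(Scalars.size(), -1);
  std::unordered_map<int, int> FirstSlot;
  for (std::size_t I = 0; I < Scalars.size(); ++I) {
    auto [It, Inserted] =
        FirstSlot.try_emplace(Scalars[I], (int)UniqueScalars.size());
    if (Inserted)
      UniqueScalars.push_back(Scalars[I]);
    Mask[I] = It->second;
  }
  return Mask;
}
// E.g. {7, 7, 3, 7} gathers only {7, 3} and shuffles with <0, 0, 1, 0>.
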
+ bool IsSplat = isSplat(VL) && (VL.size() > 2 || VL.front() == VL.back()); + SmallVector<int> UndefPos; DenseMap<Value *, unsigned> UniquePositions; - unsigned NumValues = - std::distance(VL.begin(), find_if(reverse(VL), [](Value *V) { - return !isa<UndefValue>(V); - }).base()); - VF = std::max<unsigned>(VF, PowerOf2Ceil(NumValues)); - int UniqueVals = 0; - for (Value *V : VL.drop_back(VL.size() - VF)) { + // Gather unique non-const values and all constant values. + // For repeated values, just shuffle them. + for (auto [I, V] : enumerate(VL)) { if (isa<UndefValue>(V)) { - ReuseShuffleIndicies.emplace_back(UndefMaskElem); + if (!isa<PoisonValue>(V)) { + Gathered[I] = V; + ReuseMask[I] = I; + UndefPos.push_back(I); + } continue; } if (isConstant(V)) { - ReuseShuffleIndicies.emplace_back(UniqueValues.size()); - UniqueValues.emplace_back(V); + Gathered[I] = V; + ReuseMask[I] = I; continue; } - auto Res = UniquePositions.try_emplace(V, UniqueValues.size()); - ReuseShuffleIndicies.emplace_back(Res.first->second); - if (Res.second) { - UniqueValues.emplace_back(V); - ++UniqueVals; - } - } - if (UniqueVals == 1 && UniqueValues.size() == 1) { - // Emit pure splat vector. - ReuseShuffleIndicies.append(VF - ReuseShuffleIndicies.size(), - UndefMaskElem); - } else if (UniqueValues.size() >= VF - 1 || UniqueValues.size() <= 1) { - if (UniqueValues.empty()) { - assert(all_of(VL, UndefValue::classof) && "Expected list of undefs."); - NumValues = VF; + if (IsSplat) { + Gathered.front() = V; + ReuseMask[I] = 0; + } else { + const auto Res = UniquePositions.try_emplace(V, I); + Gathered[Res.first->second] = V; + ReuseMask[I] = Res.first->second; + } + } + if (!UndefPos.empty() && IsSplat) { + // For undef values, try to replace them with the simple broadcast. + // We can do it if the broadcasted value is guaranteed to be + // non-poisonous, or by freezing the incoming scalar value first. + auto *It = find_if(Gathered, [this, E](Value *V) { + return !isa<UndefValue>(V) && + (getTreeEntry(V) || isGuaranteedNotToBePoison(V) || + any_of(V->uses(), [E](const Use &U) { + // Check if the value already used in the same operation in + // one of the nodes already. + return E->UserTreeIndices.size() == 1 && + is_contained( + E->UserTreeIndices.front().UserTE->Scalars, + U.getUser()) && + E->UserTreeIndices.front().EdgeIdx != U.getOperandNo(); + })); + }); + if (It != Gathered.end()) { + // Replace undefs by the non-poisoned scalars and emit broadcast. + int Pos = std::distance(Gathered.begin(), It); + for_each(UndefPos, [&](int I) { + // Set the undef position to the non-poisoned scalar. + ReuseMask[I] = Pos; + // Replace the undef by the poison, in the mask it is replaced by non-poisoned scalar already. + if (I != Pos) + Gathered[I] = PoisonValue::get(Gathered[I]->getType()); + }); + } else { + // Replace undefs by the poisons, emit broadcast and then emit + // freeze. 
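
// For a splat the code above routes every defined lane to one gathered
// element; an undef lane is either redirected to a lane that is known not to
// be poison or, just below, left undef with the whole vector frozen
// afterwards. A minimal standalone sketch of that decision, with -1 for an
// undef mask lane:
#include <cstddef>
#include <vector>

// IsUndefLane[I] says lane I of the request was undef; HasNonPoisonLane says
// some gathered lane is known non-poisonous and may safely fill the undefs.
std::vector<int> buildSplatMask(std::size_t NumLanes,
                                const std::vector<bool> &IsUndefLane,
                                bool HasNonPoisonLane, int NonPoisonLane,
                                bool &NeedFreeze) {
  std::vector<int> Mask(NumLanes, 0); // broadcast gathered lane 0 everywhere
  NeedFreeze = false;
  for (std::size_t I = 0; I < NumLanes; ++I) {
    if (!IsUndefLane[I])
      continue;
    if (HasNonPoisonLane) {
      Mask[I] = NonPoisonLane; // reuse a lane that cannot be poison
    } else {
      Mask[I] = -1;            // keep the lane undef ...
      NeedFreeze = true;       // ... and freeze the final vector once
    }
  }
  return Mask;
}
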
+ for_each(UndefPos, [&](int I) { + ReuseMask[I] = UndefMaskElem; + if (isa<UndefValue>(Gathered[I])) + Gathered[I] = PoisonValue::get(Gathered[I]->getType()); + }); + NeedFreeze = true; } - ReuseShuffleIndicies.clear(); - UniqueValues.clear(); - UniqueValues.append(VL.begin(), std::next(VL.begin(), NumValues)); } - UniqueValues.append(VF - UniqueValues.size(), - PoisonValue::get(VL[0]->getType())); - VL = UniqueValues; - } - - ShuffleInstructionBuilder ShuffleBuilder(Builder, VF, GatherShuffleSeq, - CSEBlocks); - Value *Vec = gather(VL); - if (!ReuseShuffleIndicies.empty()) { - ShuffleBuilder.addMask(ReuseShuffleIndicies); - Vec = ShuffleBuilder.finalize(Vec); - } + } else { + ReuseMask.clear(); + copy(VL, Gathered.begin()); + } + // Gather unique scalars and all constants. + Value *Vec = gather(Gathered); + ShuffleBuilder.add(Vec, ReuseMask); + Vec = ShuffleBuilder.finalize(E->ReuseShuffleIndices); + if (NeedFreeze) + Vec = Builder.CreateFreeze(Vec); return Vec; } @@ -7833,34 +9161,55 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return E->VectorizedValue; } - bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); - unsigned VF = E->getVectorFactor(); - ShuffleInstructionBuilder ShuffleBuilder(Builder, VF, GatherShuffleSeq, - CSEBlocks); + auto FinalShuffle = [&](Value *V, const TreeEntry *E) { + ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); + if (E->State != TreeEntry::NeedToGather && + E->getOpcode() == Instruction::Store) { + ArrayRef<int> Mask = + ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()), + E->ReorderIndices.size()); + ShuffleBuilder.add(V, Mask); + } else { + ShuffleBuilder.addOrdered(V, E->ReorderIndices); + } + return ShuffleBuilder.finalize(E->ReuseShuffleIndices); + }; + if (E->State == TreeEntry::NeedToGather) { + if (E->Idx > 0) { + // We are in the middle of a vectorizable chain. We need to gather the + // scalars from the users. + Value *Vec = createBuildVector(E); + E->VectorizedValue = Vec; + return Vec; + } if (E->getMainOp()) setInsertPointAfterBundle(E); + SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end()); + // Build a mask out of the reorder indices and reorder scalars per this + // mask. 
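
// Turning the reorder indices into a mask and reordering the scalar list with
// it (done just below, and earlier in createBuildVector) is an inverse
// permutation. A minimal standalone sketch (not the LLVM helpers), assuming
// Order[I] is the final position of element I and masks use shufflevector
// gather semantics with -1 for an undef lane:
#include <cstddef>
#include <vector>

std::vector<int> invertPermutation(const std::vector<unsigned> &Order) {
  std::vector<int> Mask(Order.size(), -1);
  for (std::size_t I = 0; I < Order.size(); ++I)
    Mask[Order[I]] = (int)I;
  return Mask;
}

// Position I of the result is element Mask[I] of the input.
std::vector<int> reorderByMask(const std::vector<int> &Scalars,
                               const std::vector<int> &Mask) {
  std::vector<int> Out(Scalars.size(), 0);
  for (std::size_t I = 0; I < Mask.size(); ++I)
    if (Mask[I] >= 0)
      Out[I] = Scalars[Mask[I]];
  return Out;
}
// E.g. Order = {2, 0, 1} gives the mask <1, 2, 0>, and {10, 20, 30} reordered
// by that mask is {20, 30, 10}: element 0 really does land at position 2.
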
+ SmallVector<int> ReorderMask; + inversePermutation(E->ReorderIndices, ReorderMask); + if (!ReorderMask.empty()) + reorderScalars(GatheredScalars, ReorderMask); Value *Vec; SmallVector<int> Mask; SmallVector<const TreeEntry *> Entries; - Optional<TargetTransformInfo::ShuffleKind> Shuffle = - isGatherShuffledEntry(E, Mask, Entries); + std::optional<TargetTransformInfo::ShuffleKind> Shuffle = + isGatherShuffledEntry(E, GatheredScalars, Mask, Entries); if (Shuffle) { assert((Entries.size() == 1 || Entries.size() == 2) && "Expected shuffle of 1 or 2 entries."); Vec = Builder.CreateShuffleVector(Entries.front()->VectorizedValue, Entries.back()->VectorizedValue, Mask); if (auto *I = dyn_cast<Instruction>(Vec)) { - GatherShuffleSeq.insert(I); + GatherShuffleExtractSeq.insert(I); CSEBlocks.insert(I->getParent()); } } else { Vec = gather(E->Scalars); } - if (NeedToShuffleReuses) { - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - Vec = ShuffleBuilder.finalize(Vec); - } + Vec = FinalShuffle(Vec, E); E->VectorizedValue = Vec; return Vec; } @@ -7893,9 +9242,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Builder.SetInsertPoint(&*PH->getParent()->getFirstInsertionPt()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; @@ -7907,6 +9254,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ValueList Operands; BasicBlock *IBB = PH->getIncomingBlock(i); + // Stop emission if all incoming values are generated. + if (NewPhi->getNumIncomingValues() == PH->getNumIncomingValues()) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return V; + } + if (!VisitedBBs.insert(IBB).second) { NewPhi->addIncoming(NewPhi->getIncomingValueForBlock(IBB), IBB); continue; @@ -7914,7 +9267,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Builder.SetInsertPoint(IBB->getTerminator()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); - Value *Vec = vectorizeTree(E->getOperand(i)); + Value *Vec = vectorizeOperand(E, i); NewPhi->addIncoming(Vec, IBB); } @@ -7925,10 +9278,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::ExtractElement: { Value *V = E->getSingleOperand(0); - Builder.SetInsertPoint(VL0); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + setInsertPointAfterBundle(E); + V = FinalShuffle(V, E); E->VectorizedValue = V; return V; } @@ -7939,16 +9290,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *Ptr = Builder.CreateBitCast(LI->getOperand(0), PtrTy); LoadInst *V = Builder.CreateAlignedLoad(VecTy, Ptr, LI->getAlign()); Value *NewV = propagateMetadata(V, E->Scalars); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - NewV = ShuffleBuilder.finalize(NewV); + NewV = FinalShuffle(NewV, E); E->VectorizedValue = NewV; return NewV; } case Instruction::InsertElement: { assert(E->ReuseShuffleIndices.empty() && "All inserts should be unique"); Builder.SetInsertPoint(cast<Instruction>(E->Scalars.back())); - Value *V = vectorizeTree(E->getOperand(1)); + Value *V = vectorizeOperand(E, 1); // Create InsertVector shuffle if necessary auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) { @@ -7983,27 +9332,58 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (!IsIdentity || NumElts != NumScalars) { V = 
Builder.CreateShuffleVector(V, Mask); if (auto *I = dyn_cast<Instruction>(V)) { - GatherShuffleSeq.insert(I); + GatherShuffleExtractSeq.insert(I); CSEBlocks.insert(I->getParent()); } } - if ((!IsIdentity || Offset != 0 || - !isUndefVector(FirstInsert->getOperand(0))) && + SmallVector<int> InsertMask(NumElts, UndefMaskElem); + for (unsigned I = 0; I < NumElts; I++) { + if (Mask[I] != UndefMaskElem) + InsertMask[Offset + I] = I; + } + SmallBitVector UseMask = + buildUseMask(NumElts, InsertMask, UseMask::UndefsAsMask); + SmallBitVector IsFirstUndef = + isUndefVector(FirstInsert->getOperand(0), UseMask); + if ((!IsIdentity || Offset != 0 || !IsFirstUndef.all()) && NumElts != NumScalars) { - SmallVector<int> InsertMask(NumElts); - std::iota(InsertMask.begin(), InsertMask.end(), 0); - for (unsigned I = 0; I < NumElts; I++) { - if (Mask[I] != UndefMaskElem) - InsertMask[Offset + I] = NumElts + I; - } - - V = Builder.CreateShuffleVector( - FirstInsert->getOperand(0), V, InsertMask, - cast<Instruction>(E->Scalars.back())->getName()); - if (auto *I = dyn_cast<Instruction>(V)) { - GatherShuffleSeq.insert(I); - CSEBlocks.insert(I->getParent()); + if (IsFirstUndef.all()) { + if (!ShuffleVectorInst::isIdentityMask(InsertMask)) { + SmallBitVector IsFirstPoison = + isUndefVector<true>(FirstInsert->getOperand(0), UseMask); + if (!IsFirstPoison.all()) { + for (unsigned I = 0; I < NumElts; I++) { + if (InsertMask[I] == UndefMaskElem && !IsFirstPoison.test(I)) + InsertMask[I] = I + NumElts; + } + } + V = Builder.CreateShuffleVector( + V, + IsFirstPoison.all() ? PoisonValue::get(V->getType()) + : FirstInsert->getOperand(0), + InsertMask, cast<Instruction>(E->Scalars.back())->getName()); + if (auto *I = dyn_cast<Instruction>(V)) { + GatherShuffleExtractSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + } + } else { + SmallBitVector IsFirstPoison = + isUndefVector<true>(FirstInsert->getOperand(0), UseMask); + for (unsigned I = 0; I < NumElts; I++) { + if (InsertMask[I] == UndefMaskElem) + InsertMask[I] = IsFirstPoison.test(I) ? 
UndefMaskElem : I; + else + InsertMask[I] += NumElts; + } + V = Builder.CreateShuffleVector( + FirstInsert->getOperand(0), V, InsertMask, + cast<Instruction>(E->Scalars.back())->getName()); + if (auto *I = dyn_cast<Instruction>(V)) { + GatherShuffleExtractSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } } } @@ -8025,8 +9405,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::BitCast: { setInsertPointAfterBundle(E); - Value *InVec = vectorizeTree(E->getOperand(0)); - + Value *InVec = vectorizeOperand(E, 0); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; @@ -8034,9 +9413,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { auto *CI = cast<CastInst>(VL0); Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; ++NumVectorInstructions; @@ -8046,9 +9423,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::ICmp: { setInsertPointAfterBundle(E); - Value *L = vectorizeTree(E->getOperand(0)); - Value *R = vectorizeTree(E->getOperand(1)); - + Value *L = vectorizeOperand(E, 0); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } + Value *R = vectorizeOperand(E, 1); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; @@ -8057,9 +9437,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); Value *V = Builder.CreateCmp(P0, L, R); propagateIRFlags(V, E->Scalars, VL0); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; ++NumVectorInstructions; @@ -8068,19 +9446,24 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Select: { setInsertPointAfterBundle(E); - Value *Cond = vectorizeTree(E->getOperand(0)); - Value *True = vectorizeTree(E->getOperand(1)); - Value *False = vectorizeTree(E->getOperand(2)); - + Value *Cond = vectorizeOperand(E, 0); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } + Value *True = vectorizeOperand(E, 1); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } + Value *False = vectorizeOperand(E, 2); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } Value *V = Builder.CreateSelect(Cond, True, False); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; ++NumVectorInstructions; @@ -8089,7 +9472,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::FNeg: { setInsertPointAfterBundle(E); - Value *Op = vectorizeTree(E->getOperand(0)); + Value *Op = vectorizeOperand(E, 0); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); @@ -8102,9 +9485,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - 
ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; ++NumVectorInstructions; @@ -8131,9 +9512,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Xor: { setInsertPointAfterBundle(E); - Value *LHS = vectorizeTree(E->getOperand(0)); - Value *RHS = vectorizeTree(E->getOperand(1)); - + Value *LHS = vectorizeOperand(E, 0); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } + Value *RHS = vectorizeOperand(E, 1); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; @@ -8146,9 +9530,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; ++NumVectorInstructions; @@ -8179,7 +9561,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } } else { assert(E->State == TreeEntry::ScatterVectorize && "Unhandled state"); - Value *VecPtr = vectorizeTree(E->getOperand(0)); + Value *VecPtr = vectorizeOperand(E, 0); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } // Use the minimum alignment of the gathered loads. Align CommonAlignment = LI->getAlign(); for (Value *V : E->Scalars) @@ -8189,9 +9575,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *V = propagateMetadata(NewLI, E->Scalars); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; ++NumVectorInstructions; return V; @@ -8202,9 +9586,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { setInsertPointAfterBundle(E); - Value *VecValue = vectorizeTree(E->getOperand(0)); - ShuffleBuilder.addMask(E->ReorderIndices); - VecValue = ShuffleBuilder.finalize(VecValue); + Value *VecValue = vectorizeOperand(E, 0); + VecValue = FinalShuffle(VecValue, E); Value *ScalarPtr = SI->getPointerOperand(); Value *VecPtr = Builder.CreateBitCast( @@ -8233,11 +9616,19 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { auto *GEP0 = cast<GetElementPtrInst>(VL0); setInsertPointAfterBundle(E); - Value *Op0 = vectorizeTree(E->getOperand(0)); + Value *Op0 = vectorizeOperand(E, 0); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } SmallVector<Value *> OpVecs; for (int J = 1, N = GEP0->getNumOperands(); J < N; ++J) { - Value *OpVec = vectorizeTree(E->getOperand(J)); + Value *OpVec = vectorizeOperand(E, J); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } OpVecs.push_back(OpVec); } @@ -8251,9 +9642,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { V = propagateMetadata(I, GEPs); } - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; ++NumVectorInstructions; @@ -8291,7 +9680,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { continue; } - Value *OpVec = vectorizeTree(E->getOperand(j)); + Value *OpVec = vectorizeOperand(E, j); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: 
Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } LLVM_DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n"); OpVecs.push_back(OpVec); if (isVectorIntrinsicWithOverloadTypeAtArg(IID, j)) @@ -8326,9 +9719,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } propagateIRFlags(V, E->Scalars, VL0); - ShuffleBuilder.addInversedMask(E->ReorderIndices); - ShuffleBuilder.addMask(E->ReuseShuffleIndices); - V = ShuffleBuilder.finalize(V); + V = FinalShuffle(V, E); E->VectorizedValue = V; ++NumVectorInstructions; @@ -8346,13 +9737,16 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *LHS = nullptr, *RHS = nullptr; if (Instruction::isBinaryOp(E->getOpcode()) || isa<CmpInst>(VL0)) { setInsertPointAfterBundle(E); - LHS = vectorizeTree(E->getOperand(0)); - RHS = vectorizeTree(E->getOperand(1)); + LHS = vectorizeOperand(E, 0); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } + RHS = vectorizeOperand(E, 1); } else { setInsertPointAfterBundle(E); - LHS = vectorizeTree(E->getOperand(0)); + LHS = vectorizeOperand(E, 0); } - if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; @@ -8379,7 +9773,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // instruction, if any. for (Value *V : {V0, V1}) { if (auto *I = dyn_cast<Instruction>(V)) { - GatherShuffleSeq.insert(I); + GatherShuffleExtractSeq.insert(I); CSEBlocks.insert(I->getParent()); } } @@ -8391,9 +9785,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { SmallVector<int> Mask; buildShuffleEntryMask( E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices, - [E](Instruction *I) { + [E, this](Instruction *I) { assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); - return isAlternateInstruction(I, E->getMainOp(), E->getAltOp()); + return isAlternateInstruction(I, E->getMainOp(), E->getAltOp(), + *TLI); }, Mask, &OpScalars, &AltScalars); @@ -8403,10 +9798,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *V = Builder.CreateShuffleVector(V0, V1, Mask); if (auto *I = dyn_cast<Instruction>(V)) { V = propagateMetadata(I, E->Scalars); - GatherShuffleSeq.insert(I); + GatherShuffleExtractSeq.insert(I); CSEBlocks.insert(I->getParent()); } - V = ShuffleBuilder.finalize(V); E->VectorizedValue = V; ++NumVectorInstructions; @@ -8435,14 +9829,27 @@ struct ShuffledInsertData { }; } // namespace -Value * -BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { +Value *BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues, + Instruction *ReductionRoot) { // All blocks must be scheduled before any instructions are inserted. for (auto &BSIter : BlocksSchedules) { scheduleBlock(BSIter.second.get()); } - Builder.SetInsertPoint(&F->getEntryBlock().front()); + // Pre-gather last instructions. + for (const std::unique_ptr<TreeEntry> &E : VectorizableTree) { + if ((E->State == TreeEntry::NeedToGather && + (!E->getMainOp() || E->Idx > 0)) || + (E->State != TreeEntry::NeedToGather && + E->getOpcode() == Instruction::ExtractValue) || + E->getOpcode() == Instruction::InsertElement) + continue; + Instruction *LastInst = &getLastInstructionInBundle(E.get()); + EntryToLastInstruction.try_emplace(E.get(), LastInst); + } + + Builder.SetInsertPoint(ReductionRoot ? 
ReductionRoot + : &F->getEntryBlock().front()); auto *VectorRoot = vectorizeTree(VectorizableTree[0].get()); // If the vectorized tree can be rewritten in a smaller type, we truncate the @@ -8471,6 +9878,9 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { SmallVector<ShuffledInsertData> ShuffledInserts; // Maps vector instruction to original insertelement instruction DenseMap<Value *, InsertElementInst *> VectorToInsertElement; + // Maps extract Scalar to the corresponding extractelement instruction in the + // basic block. Only one extractelement per block should be emitted. + DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs; // Extract all of the elements with the external uses. for (const auto &ExternalUse : ExternalUses) { Value *Scalar = ExternalUse.Scalar; @@ -8495,13 +9905,36 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { Value *Lane = Builder.getInt32(ExternalUse.Lane); auto ExtractAndExtendIfNeeded = [&](Value *Vec) { if (Scalar->getType() != Vec->getType()) { - Value *Ex; - // "Reuse" the existing extract to improve final codegen. - if (auto *ES = dyn_cast<ExtractElementInst>(Scalar)) { - Ex = Builder.CreateExtractElement(ES->getOperand(0), - ES->getOperand(1)); - } else { - Ex = Builder.CreateExtractElement(Vec, Lane); + Value *Ex = nullptr; + auto It = ScalarToEEs.find(Scalar); + if (It != ScalarToEEs.end()) { + // No need to emit many extracts, just move the only one in the + // current block. + auto EEIt = It->second.find(Builder.GetInsertBlock()); + if (EEIt != It->second.end()) { + Instruction *I = EEIt->second; + if (Builder.GetInsertPoint() != Builder.GetInsertBlock()->end() && + Builder.GetInsertPoint()->comesBefore(I)) + I->moveBefore(&*Builder.GetInsertPoint()); + Ex = I; + } + } + if (!Ex) { + // "Reuse" the existing extract to improve final codegen. + if (auto *ES = dyn_cast<ExtractElementInst>(Scalar)) { + Ex = Builder.CreateExtractElement(ES->getOperand(0), + ES->getOperand(1)); + } else { + Ex = Builder.CreateExtractElement(Vec, Lane); + } + if (auto *I = dyn_cast<Instruction>(Ex)) + ScalarToEEs[Scalar].try_emplace(Builder.GetInsertBlock(), I); + } + // The then branch of the previous if may produce constants, since 0 + // operand might be a constant. + if (auto *ExI = dyn_cast<Instruction>(Ex)) { + GatherShuffleExtractSeq.insert(ExI); + CSEBlocks.insert(ExI->getParent()); } // If necessary, sign-extend or zero-extend ScalarRoot // to the larger type. @@ -8526,13 +9959,15 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { "Scalar with nullptr as an external user must be registered in " "ExternallyUsedValues map"); if (auto *VecI = dyn_cast<Instruction>(Vec)) { - Builder.SetInsertPoint(VecI->getParent(), - std::next(VecI->getIterator())); + if (auto *PHI = dyn_cast<PHINode>(VecI)) + Builder.SetInsertPoint(PHI->getParent()->getFirstNonPHI()); + else + Builder.SetInsertPoint(VecI->getParent(), + std::next(VecI->getIterator())); } else { Builder.SetInsertPoint(&F->getEntryBlock().front()); } Value *NewInst = ExtractAndExtendIfNeeded(Vec); - CSEBlocks.insert(cast<Instruction>(Scalar)->getParent()); auto &NewInstLocs = ExternallyUsedValues[NewInst]; auto It = ExternallyUsedValues.find(Scalar); assert(It != ExternallyUsedValues.end() && @@ -8548,7 +9983,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { // Skip if the scalar is another vector op or Vec is not an instruction. 
if (!Scalar->getType()->isVectorTy() && isa<Instruction>(Vec)) { if (auto *FTy = dyn_cast<FixedVectorType>(User->getType())) { - Optional<unsigned> InsertIdx = getInsertIndex(VU); + std::optional<unsigned> InsertIdx = getInsertIndex(VU); if (InsertIdx) { // Need to use original vector, if the root is truncated. if (MinBWs.count(Scalar) && @@ -8558,7 +9993,9 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { find_if(ShuffledInserts, [VU](const ShuffledInsertData &Data) { // Checks if 2 insertelements are from the same buildvector. InsertElementInst *VecInsert = Data.InsertElements.front(); - return areTwoInsertFromSameBuildVector(VU, VecInsert); + return areTwoInsertFromSameBuildVector( + VU, VecInsert, + [](InsertElementInst *II) { return II->getOperand(0); }); }); unsigned Idx = *InsertIdx; if (It == ShuffledInserts.end()) { @@ -8622,157 +10059,42 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator()); } Value *NewInst = ExtractAndExtendIfNeeded(Vec); - CSEBlocks.insert(PH->getIncomingBlock(i)); PH->setOperand(i, NewInst); } } } else { Builder.SetInsertPoint(cast<Instruction>(User)); Value *NewInst = ExtractAndExtendIfNeeded(Vec); - CSEBlocks.insert(cast<Instruction>(User)->getParent()); User->replaceUsesOfWith(Scalar, NewInst); } } else { Builder.SetInsertPoint(&F->getEntryBlock().front()); Value *NewInst = ExtractAndExtendIfNeeded(Vec); - CSEBlocks.insert(&F->getEntryBlock()); User->replaceUsesOfWith(Scalar, NewInst); } LLVM_DEBUG(dbgs() << "SLP: Replaced:" << *User << ".\n"); } - // Checks if the mask is an identity mask. - auto &&IsIdentityMask = [](ArrayRef<int> Mask, FixedVectorType *VecTy) { - int Limit = Mask.size(); - return VecTy->getNumElements() == Mask.size() && - all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) && - ShuffleVectorInst::isIdentityMask(Mask); - }; - // Tries to combine 2 different masks into single one. - auto &&CombineMasks = [](SmallVectorImpl<int> &Mask, ArrayRef<int> ExtMask) { - SmallVector<int> NewMask(ExtMask.size(), UndefMaskElem); - for (int I = 0, Sz = ExtMask.size(); I < Sz; ++I) { - if (ExtMask[I] == UndefMaskElem) - continue; - NewMask[I] = Mask[ExtMask[I]]; - } - Mask.swap(NewMask); - }; - // Peek through shuffles, trying to simplify the final shuffle code. - auto &&PeekThroughShuffles = - [&IsIdentityMask, &CombineMasks](Value *&V, SmallVectorImpl<int> &Mask, - bool CheckForLengthChange = false) { - while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) { - // Exit if not a fixed vector type or changing size shuffle. - if (!isa<FixedVectorType>(SV->getType()) || - (CheckForLengthChange && SV->changesLength())) - break; - // Exit if the identity or broadcast mask is found. - if (IsIdentityMask(Mask, cast<FixedVectorType>(SV->getType())) || - SV->isZeroEltSplat()) - break; - bool IsOp1Undef = isUndefVector(SV->getOperand(0)); - bool IsOp2Undef = isUndefVector(SV->getOperand(1)); - if (!IsOp1Undef && !IsOp2Undef) - break; - SmallVector<int> ShuffleMask(SV->getShuffleMask().begin(), - SV->getShuffleMask().end()); - CombineMasks(ShuffleMask, Mask); - Mask.swap(ShuffleMask); - if (IsOp2Undef) - V = SV->getOperand(0); - else - V = SV->getOperand(1); - } - }; - // Smart shuffle instruction emission, walks through shuffles trees and - // tries to find the best matching vector for the actual shuffle - // instruction. 
- auto &&CreateShuffle = [this, &IsIdentityMask, &PeekThroughShuffles, - &CombineMasks](Value *V1, Value *V2, - ArrayRef<int> Mask) -> Value * { - assert(V1 && "Expected at least one vector value."); - if (V2 && !isUndefVector(V2)) { - // Peek through shuffles. - Value *Op1 = V1; - Value *Op2 = V2; - int VF = - cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue(); - SmallVector<int> CombinedMask1(Mask.size(), UndefMaskElem); - SmallVector<int> CombinedMask2(Mask.size(), UndefMaskElem); - for (int I = 0, E = Mask.size(); I < E; ++I) { - if (Mask[I] < VF) - CombinedMask1[I] = Mask[I]; - else - CombinedMask2[I] = Mask[I] - VF; - } - Value *PrevOp1; - Value *PrevOp2; - do { - PrevOp1 = Op1; - PrevOp2 = Op2; - PeekThroughShuffles(Op1, CombinedMask1, /*CheckForLengthChange=*/true); - PeekThroughShuffles(Op2, CombinedMask2, /*CheckForLengthChange=*/true); - // Check if we have 2 resizing shuffles - need to peek through operands - // again. - if (auto *SV1 = dyn_cast<ShuffleVectorInst>(Op1)) - if (auto *SV2 = dyn_cast<ShuffleVectorInst>(Op2)) - if (SV1->getOperand(0)->getType() == - SV2->getOperand(0)->getType() && - SV1->getOperand(0)->getType() != SV1->getType() && - isUndefVector(SV1->getOperand(1)) && - isUndefVector(SV2->getOperand(1))) { - Op1 = SV1->getOperand(0); - Op2 = SV2->getOperand(0); - SmallVector<int> ShuffleMask1(SV1->getShuffleMask().begin(), - SV1->getShuffleMask().end()); - CombineMasks(ShuffleMask1, CombinedMask1); - CombinedMask1.swap(ShuffleMask1); - SmallVector<int> ShuffleMask2(SV2->getShuffleMask().begin(), - SV2->getShuffleMask().end()); - CombineMasks(ShuffleMask2, CombinedMask2); - CombinedMask2.swap(ShuffleMask2); - } - } while (PrevOp1 != Op1 || PrevOp2 != Op2); - VF = cast<VectorType>(Op1->getType()) - ->getElementCount() - .getKnownMinValue(); - for (int I = 0, E = Mask.size(); I < E; ++I) { - if (CombinedMask2[I] != UndefMaskElem) { - assert(CombinedMask1[I] == UndefMaskElem && - "Expected undefined mask element"); - CombinedMask1[I] = CombinedMask2[I] + (Op1 == Op2 ? 0 : VF); - } - } - Value *Vec = Builder.CreateShuffleVector( - Op1, Op1 == Op2 ? 
PoisonValue::get(Op1->getType()) : Op2, - CombinedMask1); - if (auto *I = dyn_cast<Instruction>(Vec)) { - GatherShuffleSeq.insert(I); - CSEBlocks.insert(I->getParent()); - } - return Vec; - } - if (isa<PoisonValue>(V1)) - return PoisonValue::get(FixedVectorType::get( - cast<VectorType>(V1->getType())->getElementType(), Mask.size())); - Value *Op = V1; - SmallVector<int> CombinedMask(Mask.begin(), Mask.end()); - PeekThroughShuffles(Op, CombinedMask); - if (!isa<FixedVectorType>(Op->getType()) || - !IsIdentityMask(CombinedMask, cast<FixedVectorType>(Op->getType()))) { - Value *Vec = Builder.CreateShuffleVector(Op, CombinedMask); - if (auto *I = dyn_cast<Instruction>(Vec)) { - GatherShuffleSeq.insert(I); - CSEBlocks.insert(I->getParent()); - } - return Vec; + auto CreateShuffle = [&](Value *V1, Value *V2, ArrayRef<int> Mask) { + SmallVector<int> CombinedMask1(Mask.size(), UndefMaskElem); + SmallVector<int> CombinedMask2(Mask.size(), UndefMaskElem); + int VF = cast<FixedVectorType>(V1->getType())->getNumElements(); + for (int I = 0, E = Mask.size(); I < E; ++I) { + if (Mask[I] < VF) + CombinedMask1[I] = Mask[I]; + else + CombinedMask2[I] = Mask[I] - VF; } - return Op; + ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); + ShuffleBuilder.add(V1, CombinedMask1); + if (V2) + ShuffleBuilder.add(V2, CombinedMask2); + return ShuffleBuilder.finalize(std::nullopt); }; - auto &&ResizeToVF = [&CreateShuffle](Value *Vec, ArrayRef<int> Mask) { + auto &&ResizeToVF = [&CreateShuffle](Value *Vec, ArrayRef<int> Mask, + bool ForSingleMask) { unsigned VF = Mask.size(); unsigned VecVF = cast<FixedVectorType>(Vec->getType())->getNumElements(); if (VF != VecVF) { @@ -8780,12 +10102,14 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { Vec = CreateShuffle(Vec, nullptr, Mask); return std::make_pair(Vec, true); } - SmallVector<int> ResizeMask(VF, UndefMaskElem); - for (unsigned I = 0; I < VF; ++I) { - if (Mask[I] != UndefMaskElem) - ResizeMask[Mask[I]] = Mask[I]; + if (!ForSingleMask) { + SmallVector<int> ResizeMask(VF, UndefMaskElem); + for (unsigned I = 0; I < VF; ++I) { + if (Mask[I] != UndefMaskElem) + ResizeMask[Mask[I]] = Mask[I]; + } + Vec = CreateShuffle(Vec, nullptr, ResizeMask); } - Vec = CreateShuffle(Vec, nullptr, ResizeMask); } return std::make_pair(Vec, false); @@ -8800,7 +10124,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { Builder.SetInsertPoint(LastInsert); auto Vector = ShuffledInserts[I].ValueMasks.takeVector(); Value *NewInst = performExtractsShuffleAction<Value>( - makeMutableArrayRef(Vector.data(), Vector.size()), + MutableArrayRef(Vector.data(), Vector.size()), FirstInsert->getOperand(0), [](Value *Vec) { return cast<VectorType>(Vec->getType()) @@ -8857,6 +10181,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { CSEBlocks.insert(LastInsert->getParent()); } + SmallVector<Instruction *> RemovedInsts; // For each vectorized value: for (auto &TEPtr : VectorizableTree) { TreeEntry *Entry = TEPtr.get(); @@ -8891,9 +10216,18 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { #endif LLVM_DEBUG(dbgs() << "SLP: \tErasing scalar:" << *Scalar << ".\n"); eraseInstruction(cast<Instruction>(Scalar)); + // Retain to-be-deleted instructions for some debug-info + // bookkeeping. NOTE: eraseInstruction only marks the instruction for + // deletion - instructions are not deleted until later. 
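
// The replacement CreateShuffle lambda above first splits the incoming
// two-source mask into one mask per operand: indices below the first operand's
// width select from it, larger indices select from the second operand after
// subtracting that width. A minimal standalone sketch of the split, with -1
// for UndefMaskElem:
#include <cstddef>
#include <utility>
#include <vector>

std::pair<std::vector<int>, std::vector<int>>
splitTwoSourceMask(const std::vector<int> &Mask, int FirstVF) {
  std::vector<int> Mask1(Mask.size(), -1), Mask2(Mask.size(), -1);
  for (std::size_t I = 0; I < Mask.size(); ++I) {
    if (Mask[I] < 0)
      continue;                     // undef lane belongs to neither operand
    if (Mask[I] < FirstVF)
      Mask1[I] = Mask[I];           // lane taken from the first operand
    else
      Mask2[I] = Mask[I] - FirstVF; // lane taken from the second operand
  }
  return {Mask1, Mask2};
}
// E.g. <0, 5, 1, 7> with FirstVF = 4 splits into <0, -1, 1, -1> and
// <-1, 1, -1, 3>, which the builder can then simplify per operand.
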
+ RemovedInsts.push_back(cast<Instruction>(Scalar)); } } + // Merge the DIAssignIDs from the about-to-be-deleted instructions into the + // new vector instruction. + if (auto *V = dyn_cast<Instruction>(VectorizableTree[0]->VectorizedValue)) + V->mergeDIAssignID(RemovedInsts); + Builder.ClearInsertionPoint(); InstrElementSize.clear(); @@ -8901,10 +10235,10 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { } void BoUpSLP::optimizeGatherSequence() { - LLVM_DEBUG(dbgs() << "SLP: Optimizing " << GatherShuffleSeq.size() + LLVM_DEBUG(dbgs() << "SLP: Optimizing " << GatherShuffleExtractSeq.size() << " gather sequences instructions.\n"); // LICM InsertElementInst sequences. - for (Instruction *I : GatherShuffleSeq) { + for (Instruction *I : GatherShuffleExtractSeq) { if (isDeleted(I)) continue; @@ -8929,6 +10263,7 @@ void BoUpSLP::optimizeGatherSequence() { // We can hoist this instruction. Move it to the pre-header. I->moveBefore(PreHeader->getTerminator()); + CSEBlocks.insert(PreHeader); } // Make a list of all reachable blocks in our CSE queue. @@ -9004,8 +10339,8 @@ void BoUpSLP::optimizeGatherSequence() { for (Instruction &In : llvm::make_early_inc_range(*BB)) { if (isDeleted(&In)) continue; - if (!isa<InsertElementInst>(&In) && !isa<ExtractElementInst>(&In) && - !isa<ShuffleVectorInst>(&In) && !GatherShuffleSeq.contains(&In)) + if (!isa<InsertElementInst, ExtractElementInst, ShuffleVectorInst>(&In) && + !GatherShuffleExtractSeq.contains(&In)) continue; // Check if we can replace this instruction with any of the @@ -9024,7 +10359,7 @@ void BoUpSLP::optimizeGatherSequence() { break; } if (isa<ShuffleVectorInst>(In) && isa<ShuffleVectorInst>(V) && - GatherShuffleSeq.contains(V) && + GatherShuffleExtractSeq.contains(V) && IsIdenticalOrLessDefined(V, &In, NewMask) && DT->dominates(In.getParent(), V->getParent())) { In.moveAfter(V); @@ -9045,7 +10380,7 @@ void BoUpSLP::optimizeGatherSequence() { } } CSEBlocks.clear(); - GatherShuffleSeq.clear(); + GatherShuffleExtractSeq.clear(); } BoUpSLP::ScheduleData * @@ -9077,7 +10412,7 @@ BoUpSLP::BlockScheduling::buildBundle(ArrayRef<Value *> VL) { // Groups the instructions to a bundle (which is then a single scheduling entity) // and schedules instructions until the bundle gets ready. -Optional<BoUpSLP::ScheduleData *> +std::optional<BoUpSLP::ScheduleData *> BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, const InstructionsState &S) { // No need to schedule PHIs, insertelement, extractelement and extractvalue @@ -9139,7 +10474,7 @@ BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, // dependencies and emit instruction in the wrong order at the actual // scheduling. TryScheduleBundleImpl(/*ReSchedule=*/false, nullptr); - return None; + return std::nullopt; } } @@ -9169,7 +10504,7 @@ BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, TryScheduleBundleImpl(ReSchedule, Bundle); if (!Bundle->isReady()) { cancelScheduling(VL, S.OpValue); - return None; + return std::nullopt; } return Bundle; } @@ -9397,13 +10732,13 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, WorkList.push_back(DestBundle); }; - // Any instruction which isn't safe to speculate at the begining of the + // Any instruction which isn't safe to speculate at the beginning of the // block is control dependend on any early exit or non-willreturn call // which proceeds it. 
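The comment above states the control-dependence rule that the following scan enforces: once an instruction may fail to transfer execution to its successor (an early exit or a call that may not return), later instructions that are unsafe to speculate must not be scheduled above it. A rough standalone sketch of that forward scan, using hypothetical types:

#include <cstddef>
#include <vector>

struct Inst {
  bool TransfersExecution;            // false for early exits / possibly non-returning calls
  bool SafeToSpeculate;               // true if hoisting above the call is harmless
  std::vector<std::size_t> ControlDeps;
};

// Pin every later non-speculatable instruction below a possibly non-returning one.
static void addControlDependencies(std::vector<Inst> &Block, std::size_t Idx) {
  if (Block[Idx].TransfersExecution)
    return;
  for (std::size_t I = Idx + 1, E = Block.size(); I < E; ++I) {
    if (Block[I].SafeToSpeculate)
      continue;                           // free to move: no dependency needed
    Block[I].ControlDeps.push_back(Idx);  // otherwise it must stay below Idx
  }
}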
if (!isGuaranteedToTransferExecutionToSuccessor(BundleMember->Inst)) { for (Instruction *I = BundleMember->Inst->getNextNode(); I != ScheduleEnd; I = I->getNextNode()) { - if (isSafeToSpeculativelyExecute(I, &*BB->begin())) + if (isSafeToSpeculativelyExecute(I, &*BB->begin(), SLP->AC)) continue; // Add the dependency @@ -9438,9 +10773,12 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, } // In addition to the cases handle just above, we need to prevent - // allocas from moving below a stacksave. The stackrestore case - // is currently thought to be conservatism. - if (isa<AllocaInst>(BundleMember->Inst)) { + // allocas and loads/stores from moving below a stacksave or a + // stackrestore. Avoiding moving allocas below stackrestore is currently + // thought to be conservatism. Moving loads/stores below a stackrestore + // can lead to incorrect code. + if (isa<AllocaInst>(BundleMember->Inst) || + BundleMember->Inst->mayReadOrWriteMemory()) { for (Instruction *I = BundleMember->Inst->getNextNode(); I != ScheduleEnd; I = I->getNextNode()) { if (!match(I, m_Intrinsic<Intrinsic::stacksave>()) && @@ -9663,17 +11001,15 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { // If the current instruction is a load, update MaxWidth to reflect the // width of the loaded value. - if (isa<LoadInst>(I) || isa<ExtractElementInst>(I) || - isa<ExtractValueInst>(I)) + if (isa<LoadInst, ExtractElementInst, ExtractValueInst>(I)) Width = std::max<unsigned>(Width, DL->getTypeSizeInBits(Ty)); // Otherwise, we need to visit the operands of the instruction. We only // handle the interesting cases from buildTree here. If an operand is an // instruction we haven't yet visited and from the same basic block as the // user or the use is a PHI node, we add it to the worklist. - else if (isa<PHINode>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) || - isa<CmpInst>(I) || isa<SelectInst>(I) || isa<BinaryOperator>(I) || - isa<UnaryOperator>(I)) { + else if (isa<PHINode, CastInst, GetElementPtrInst, CmpInst, SelectInst, + BinaryOperator, UnaryOperator>(I)) { for (Use &U : I->operands()) if (auto *J = dyn_cast<Instruction>(U.get())) if (Visited.insert(J).second && @@ -9726,8 +11062,7 @@ static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr, break; case Instruction::ZExt: case Instruction::SExt: - if (isa<ExtractElementInst>(I->getOperand(0)) || - isa<InsertElementInst>(I->getOperand(0))) + if (isa<ExtractElementInst, InsertElementInst>(I->getOperand(0))) return false; break; @@ -10028,7 +11363,7 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, DT->updateDFSNumbers(); // Scan the blocks in the function in post order. - for (auto BB : post_order(&F.getEntryBlock())) { + for (auto *BB : post_order(&F.getEntryBlock())) { // Start new block - clear the list of reduction roots. 
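The widened check above now pins memory-touching bundle members, not just allocas, so they cannot be scheduled past a stacksave or stackrestore that may release or reuse their stack memory. A hypothetical sketch of locating that barrier:

#include <vector>

enum class Kind { Other, Alloca, MemAccess, StackSave, StackRestore };

// Return the index of the first stack save/restore that the instruction at
// Idx must not be scheduled past, or -1 if there is none. Only allocas and
// memory-touching instructions are pinned this way.
static int stackBarrierAfter(const std::vector<Kind> &Insts, int Idx) {
  if (Insts[Idx] != Kind::Alloca && Insts[Idx] != Kind::MemAccess)
    return -1;
  for (int I = Idx + 1, E = (int)Insts.size(); I < E; ++I)
    if (Insts[I] == Kind::StackSave || Insts[I] == Kind::StackRestore)
      return I;
  return -1;
}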
R.clearReductionData(); collectSeedInstructions(BB); @@ -10086,7 +11421,7 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, InstructionCost Cost = R.getTreeCost(); - LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for VF =" << VF << "\n"); + LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for VF=" << VF << "\n"); if (Cost < -SLPCostThreshold) { LLVM_DEBUG(dbgs() << "SLP: Decided to vectorize cost = " << Cost << "\n"); @@ -10130,7 +11465,7 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, ++IterCnt; CheckedPairs[Idx].set(K); CheckedPairs[K].set(Idx); - Optional<int> Diff = getPointersDiff( + std::optional<int> Diff = getPointersDiff( Stores[K]->getValueOperand()->getType(), Stores[K]->getPointerOperand(), Stores[Idx]->getValueOperand()->getType(), Stores[Idx]->getPointerOperand(), *DL, *SE, /*StrictCheck=*/true); @@ -10213,12 +11548,17 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, unsigned MinVF = TTI->getStoreMinimumVF( R.getMinVF(DL->getTypeSizeInBits(ValueTy)), StoreTy, ValueTy); + if (MaxVF <= MinVF) { + LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF << ") <= " + << "MinVF (" << MinVF << ")\n"); + } + // FIXME: Is division-by-2 the correct step? Should we assert that the // register size is a power-of-2? unsigned StartIdx = 0; for (unsigned Size = MaxVF; Size >= MinVF; Size /= 2) { for (unsigned Cnt = StartIdx, E = Operands.size(); Cnt + Size <= E;) { - ArrayRef<Value *> Slice = makeArrayRef(Operands).slice(Cnt, Size); + ArrayRef<Value *> Slice = ArrayRef(Operands).slice(Cnt, Size); if (!VectorizedStores.count(Slice.front()) && !VectorizedStores.count(Slice.back()) && vectorizeStoreChain(Slice, R, Cnt, MinVF)) { @@ -10297,7 +11637,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, // Check that all of the parts are instructions of the same type, // we permit an alternate opcode via InstructionsState. - InstructionsState S = getSameOpcode(VL); + InstructionsState S = getSameOpcode(VL, *TLI); if (!S.getOpcode()) return false; @@ -10379,7 +11719,9 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, if (R.isTreeTinyAndNotFullyVectorizable()) continue; R.reorderTopToBottom(); - R.reorderBottomToTop(!isa<InsertElementInst>(Ops.front())); + R.reorderBottomToTop( + /*IgnoreReorder=*/!isa<InsertElementInst>(Ops.front()) && + !R.doesRootHaveInTreeUses()); R.buildExternalUses(); R.computeMinimumValueSizes(); @@ -10387,6 +11729,8 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, CandidateFound = true; MinCost = std::min(MinCost, Cost); + LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost + << " for VF=" << OpsWidth << "\n"); if (Cost < -SLPCostThreshold) { LLVM_DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n"); R.getORE()->emit(OptimizationRemark(SV_NAME, "VectorizedList", @@ -10425,8 +11769,7 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { if (!I) return false; - if ((!isa<BinaryOperator>(I) && !isa<CmpInst>(I)) || - isa<VectorType>(I->getType())) + if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType())) return false; Value *P = I->getParent(); @@ -10466,7 +11809,7 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { return tryToVectorizePair(Op0, Op1, R); // We have multiple options. Try to pick the single best. 
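vectorizeStores above walks each store chain with power-of-two slice sizes, starting at MaxVF and halving down to MinVF, and never revisits a prefix that a wider slice already covered. A simplified standalone sketch of that walk, where the cost decision is only a stand-in:

#include <cstddef>
#include <functional>

static unsigned countVectorizedSlices(
    std::size_t ChainLen, unsigned MaxVF, unsigned MinVF,
    const std::function<bool(std::size_t, unsigned)> &TrySlice) {
  unsigned Slices = 0;
  std::size_t StartIdx = 0;
  for (unsigned Size = MaxVF; Size >= MinVF; Size /= 2) {
    for (std::size_t Cnt = StartIdx; Cnt + Size <= ChainLen;) {
      if (TrySlice(Cnt, Size)) {  // stand-in for vectorizeStoreChain's cost check
        ++Slices;
        if (Cnt == StartIdx)
          StartIdx += Size;       // a covered prefix is never revisited
        Cnt += Size;
        continue;
      }
      ++Cnt;                      // otherwise slide the window by one store
    }
  }
  return Slices;
}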
- Optional<int> BestCandidate = R.findBestRootPair(Candidates); + std::optional<int> BestCandidate = R.findBestRootPair(Candidates); if (!BestCandidate) return false; return tryToVectorizePair(Candidates[*BestCandidate].first, @@ -10524,8 +11867,8 @@ class HorizontalReduction { // select x, y, false // select x, true, y static bool isBoolLogicOp(Instruction *I) { - return match(I, m_LogicalAnd(m_Value(), m_Value())) || - match(I, m_LogicalOr(m_Value(), m_Value())); + return isa<SelectInst>(I) && + (match(I, m_LogicalAnd()) || match(I, m_LogicalOr())); } /// Checks if instruction is associative and can be vectorized. @@ -10751,7 +12094,7 @@ class HorizontalReduction { /// Checks if the instruction is in basic block \p BB. /// For a cmp+sel min/max reduction check that both ops are in \p BB. static bool hasSameParent(Instruction *I, BasicBlock *BB) { - if (isCmpSelMinMax(I) || (isBoolLogicOp(I) && isa<SelectInst>(I))) { + if (isCmpSelMinMax(I) || isBoolLogicOp(I)) { auto *Sel = cast<SelectInst>(I); auto *Cmp = dyn_cast<Instruction>(Sel->getCondition()); return Sel->getParent() == BB && Cmp && Cmp->getParent() == BB; @@ -10802,6 +12145,13 @@ class HorizontalReduction { return I->getOperand(getFirstOperandIndex(I) + 1); } + static bool isGoodForReduction(ArrayRef<Value *> Data) { + int Sz = Data.size(); + auto *I = dyn_cast<Instruction>(Data.front()); + return Sz > 1 || isConstant(Data.front()) || + (I && !isa<LoadInst>(I) && isValidForAlternation(I->getOpcode())); + } + public: HorizontalReduction() = default; @@ -10897,6 +12247,9 @@ public: MapVector<size_t, MapVector<size_t, MapVector<Value *, unsigned>>> PossibleReducedVals; initReductionOps(Inst); + DenseMap<Value *, SmallVector<LoadInst *>> LoadsMap; + SmallSet<size_t, 2> LoadKeyUsed; + SmallPtrSet<Value *, 4> DoNotReverseVals; while (!Worklist.empty()) { Instruction *TreeN = Worklist.pop_back_val(); SmallVector<Value *> Args; @@ -10918,18 +12271,36 @@ public: size_t Key, Idx; std::tie(Key, Idx) = generateKeySubkey( V, &TLI, - [&PossibleReducedVals, &DL, &SE](size_t Key, LoadInst *LI) { - auto It = PossibleReducedVals.find(Key); - if (It != PossibleReducedVals.end()) { - for (const auto &LoadData : It->second) { - auto *RLI = cast<LoadInst>(LoadData.second.front().first); - if (getPointersDiff(RLI->getType(), - RLI->getPointerOperand(), LI->getType(), - LI->getPointerOperand(), DL, SE, - /*StrictCheck=*/true)) - return hash_value(RLI->getPointerOperand()); + [&](size_t Key, LoadInst *LI) { + Value *Ptr = getUnderlyingObject(LI->getPointerOperand()); + if (LoadKeyUsed.contains(Key)) { + auto LIt = LoadsMap.find(Ptr); + if (LIt != LoadsMap.end()) { + for (LoadInst *RLI: LIt->second) { + if (getPointersDiff( + RLI->getType(), RLI->getPointerOperand(), + LI->getType(), LI->getPointerOperand(), DL, SE, + /*StrictCheck=*/true)) + return hash_value(RLI->getPointerOperand()); + } + for (LoadInst *RLI : LIt->second) { + if (arePointersCompatible(RLI->getPointerOperand(), + LI->getPointerOperand(), TLI)) { + hash_code SubKey = hash_value(RLI->getPointerOperand()); + DoNotReverseVals.insert(RLI); + return SubKey; + } + } + if (LIt->second.size() > 2) { + hash_code SubKey = + hash_value(LIt->second.back()->getPointerOperand()); + DoNotReverseVals.insert(LIt->second.back()); + return SubKey; + } } } + LoadKeyUsed.insert(Key); + LoadsMap.try_emplace(Ptr).first->second.push_back(LI); return hash_value(LI->getPointerOperand()); }, /*AllowAlternate=*/false); @@ -10943,17 +12314,35 @@ public: size_t Key, Idx; std::tie(Key, Idx) = generateKeySubkey( TreeN, 
&TLI, - [&PossibleReducedVals, &DL, &SE](size_t Key, LoadInst *LI) { - auto It = PossibleReducedVals.find(Key); - if (It != PossibleReducedVals.end()) { - for (const auto &LoadData : It->second) { - auto *RLI = cast<LoadInst>(LoadData.second.front().first); - if (getPointersDiff(RLI->getType(), RLI->getPointerOperand(), - LI->getType(), LI->getPointerOperand(), - DL, SE, /*StrictCheck=*/true)) - return hash_value(RLI->getPointerOperand()); + [&](size_t Key, LoadInst *LI) { + Value *Ptr = getUnderlyingObject(LI->getPointerOperand()); + if (LoadKeyUsed.contains(Key)) { + auto LIt = LoadsMap.find(Ptr); + if (LIt != LoadsMap.end()) { + for (LoadInst *RLI: LIt->second) { + if (getPointersDiff(RLI->getType(), + RLI->getPointerOperand(), LI->getType(), + LI->getPointerOperand(), DL, SE, + /*StrictCheck=*/true)) + return hash_value(RLI->getPointerOperand()); + } + for (LoadInst *RLI : LIt->second) { + if (arePointersCompatible(RLI->getPointerOperand(), + LI->getPointerOperand(), TLI)) { + hash_code SubKey = hash_value(RLI->getPointerOperand()); + DoNotReverseVals.insert(RLI); + return SubKey; + } + } + if (LIt->second.size() > 2) { + hash_code SubKey = hash_value(LIt->second.back()->getPointerOperand()); + DoNotReverseVals.insert(LIt->second.back()); + return SubKey; + } } } + LoadKeyUsed.insert(Key); + LoadsMap.try_emplace(Ptr).first->second.push_back(LI); return hash_value(LI->getPointerOperand()); }, /*AllowAlternate=*/false); @@ -10979,9 +12368,27 @@ public: stable_sort(PossibleRedValsVect, [](const auto &P1, const auto &P2) { return P1.size() > P2.size(); }); - ReducedVals.emplace_back(); - for (ArrayRef<Value *> Data : PossibleRedValsVect) - ReducedVals.back().append(Data.rbegin(), Data.rend()); + int NewIdx = -1; + for (ArrayRef<Value *> Data : PossibleRedValsVect) { + if (isGoodForReduction(Data) || + (isa<LoadInst>(Data.front()) && NewIdx >= 0 && + isa<LoadInst>(ReducedVals[NewIdx].front()) && + getUnderlyingObject( + cast<LoadInst>(Data.front())->getPointerOperand()) == + getUnderlyingObject(cast<LoadInst>(ReducedVals[NewIdx].front()) + ->getPointerOperand()))) { + if (NewIdx < 0) { + NewIdx = ReducedVals.size(); + ReducedVals.emplace_back(); + } + if (DoNotReverseVals.contains(Data.front())) + ReducedVals[NewIdx].append(Data.begin(), Data.end()); + else + ReducedVals[NewIdx].append(Data.rbegin(), Data.rend()); + } else { + ReducedVals.emplace_back().append(Data.rbegin(), Data.rend()); + } + } } // Sort the reduced values by number of same/alternate opcode and/or pointer // operand. @@ -10992,25 +12399,36 @@ public: } /// Attempt to vectorize the tree found by matchAssociativeReduction. - Value *tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI) { + Value *tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI, + const TargetLibraryInfo &TLI) { constexpr int ReductionLimit = 4; constexpr unsigned RegMaxNumber = 4; constexpr unsigned RedValsMaxNumber = 128; // If there are a sufficient number of reduction values, reduce // to a nearby power-of-2. We can safely generate oversized // vectors and rely on the backend to split them to legal sizes. 
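Both generateKeySubkey lambdas above now bucket candidate loads by the underlying object of their pointer: a new load reuses an existing bucket's hash when an earlier load in that bucket is at a known or compatible pointer distance, so related loads end up in the same group of possible reduction values. A greatly simplified sketch of that bucketing, with plain integers standing in for the pointer analysis:

#include <unordered_map>
#include <vector>

struct LoadDesc {
  int ObjectId;   // stand-in for getUnderlyingObject(pointer operand)
  int GroupHash;  // stand-in for hash_value(pointer operand)
};

// Return the subkey for LI: reuse the first recorded load's hash for the same
// underlying object, otherwise start a new bucket keyed by LI itself.
static int loadSubKey(std::unordered_map<int, std::vector<LoadDesc>> &LoadsMap,
                      const LoadDesc &LI) {
  auto &Bucket = LoadsMap[LI.ObjectId];
  if (!Bucket.empty())
    return Bucket.front().GroupHash;  // the real code also checks getPointersDiff here
  Bucket.push_back(LI);
  return LI.GroupHash;
}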
- unsigned NumReducedVals = std::accumulate( - ReducedVals.begin(), ReducedVals.end(), 0, - [](int Num, ArrayRef<Value *> Vals) { return Num + Vals.size(); }); - if (NumReducedVals < ReductionLimit) + size_t NumReducedVals = + std::accumulate(ReducedVals.begin(), ReducedVals.end(), 0, + [](size_t Num, ArrayRef<Value *> Vals) { + if (!isGoodForReduction(Vals)) + return Num; + return Num + Vals.size(); + }); + if (NumReducedVals < ReductionLimit) { + for (ReductionOpsType &RdxOps : ReductionOps) + for (Value *RdxOp : RdxOps) + V.analyzedReductionRoot(cast<Instruction>(RdxOp)); return nullptr; + } IRBuilder<> Builder(cast<Instruction>(ReductionRoot)); // Track the reduced values in case if they are replaced by extractelement // because of the vectorization. - DenseMap<Value *, WeakTrackingVH> TrackedVals; + DenseMap<Value *, WeakTrackingVH> TrackedVals( + ReducedVals.size() * ReducedVals.front().size() + ExtraArgs.size()); BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues; + ExternallyUsedValues.reserve(ExtraArgs.size() + 1); // The same extra argument may be used several times, so log each attempt // to use it. for (const std::pair<Instruction *, Value *> &Pair : ExtraArgs) { @@ -11033,7 +12451,8 @@ public: // The reduction root is used as the insertion point for new instructions, // so set it as externally used to prevent it from being deleted. ExternallyUsedValues[ReductionRoot]; - SmallDenseSet<Value *> IgnoreList; + SmallDenseSet<Value *> IgnoreList(ReductionOps.size() * + ReductionOps.front().size()); for (ReductionOpsType &RdxOps : ReductionOps) for (Value *RdxOp : RdxOps) { if (!RdxOp) @@ -11048,15 +12467,19 @@ public: for (Value *V : Candidates) TrackedVals.try_emplace(V, V); - DenseMap<Value *, unsigned> VectorizedVals; + DenseMap<Value *, unsigned> VectorizedVals(ReducedVals.size()); + // List of the values that were reduced in other trees as part of gather + // nodes and thus requiring extract if fully vectorized in other trees. + SmallPtrSet<Value *, 4> RequiredExtract; Value *VectorizedTree = nullptr; bool CheckForReusedReductionOps = false; // Try to vectorize elements based on their type. for (unsigned I = 0, E = ReducedVals.size(); I < E; ++I) { ArrayRef<Value *> OrigReducedVals = ReducedVals[I]; - InstructionsState S = getSameOpcode(OrigReducedVals); + InstructionsState S = getSameOpcode(OrigReducedVals, TLI); SmallVector<Value *> Candidates; - DenseMap<Value *, Value *> TrackedToOrig; + Candidates.reserve(2 * OrigReducedVals.size()); + DenseMap<Value *, Value *> TrackedToOrig(2 * OrigReducedVals.size()); for (unsigned Cnt = 0, Sz = OrigReducedVals.size(); Cnt < Sz; ++Cnt) { Value *RdxVal = TrackedVals.find(OrigReducedVals[Cnt])->second; // Check if the reduction value was not overriden by the extractelement @@ -11073,7 +12496,7 @@ public: // Try to handle shuffled extractelements. if (S.getOpcode() == Instruction::ExtractElement && !S.isAltShuffle() && I + 1 < E) { - InstructionsState NextS = getSameOpcode(ReducedVals[I + 1]); + InstructionsState NextS = getSameOpcode(ReducedVals[I + 1], TLI); if (NextS.getOpcode() == Instruction::ExtractElement && !NextS.isAltShuffle()) { SmallVector<Value *> CommonCandidates(Candidates); @@ -11181,37 +12604,49 @@ public: }); } // Number of uses of the candidates in the vector of values. 
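tryToReduce above no longer counts every collected value against the ReductionLimit; only groups that isGoodForReduction() considers worthwhile contribute, so singleton or load-only groups cannot trigger a vectorization attempt on their own. A sketch of that filtered accumulation, with a deliberately simplified goodness test:

#include <cstddef>
#include <numeric>
#include <vector>

// Simplified stand-in for isGoodForReduction: only groups with more than one
// element are considered worth reducing here.
static bool goodForReduction(const std::vector<int> &Group) {
  return Group.size() > 1;
}

static std::size_t countReducibleValues(const std::vector<std::vector<int>> &Groups) {
  return std::accumulate(Groups.begin(), Groups.end(), std::size_t(0),
                         [](std::size_t Num, const std::vector<int> &Group) {
                           return goodForReduction(Group) ? Num + Group.size() : Num;
                         });
}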
- SmallDenseMap<Value *, unsigned> NumUses; + SmallDenseMap<Value *, unsigned> NumUses(Candidates.size()); for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) { Value *V = Candidates[Cnt]; - if (NumUses.count(V) > 0) - continue; - NumUses[V] = std::count(VL.begin(), VL.end(), V); + ++NumUses.try_emplace(V, 0).first->getSecond(); } for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) { Value *V = Candidates[Cnt]; - if (NumUses.count(V) > 0) - continue; - NumUses[V] = std::count(VL.begin(), VL.end(), V); + ++NumUses.try_emplace(V, 0).first->getSecond(); } + SmallPtrSet<Value *, 4> VLScalars(VL.begin(), VL.end()); // Gather externally used values. SmallPtrSet<Value *, 4> Visited; for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) { - Value *V = Candidates[Cnt]; - if (!Visited.insert(V).second) + Value *RdxVal = Candidates[Cnt]; + if (!Visited.insert(RdxVal).second) continue; - unsigned NumOps = VectorizedVals.lookup(V) + NumUses[V]; - if (NumOps != ReducedValsToOps.find(V)->second.size()) - LocalExternallyUsedValues[V]; + // Check if the scalar was vectorized as part of the vectorization + // tree but not the top node. + if (!VLScalars.contains(RdxVal) && V.isVectorized(RdxVal)) { + LocalExternallyUsedValues[RdxVal]; + continue; + } + unsigned NumOps = VectorizedVals.lookup(RdxVal) + NumUses[RdxVal]; + if (NumOps != ReducedValsToOps.find(RdxVal)->second.size()) + LocalExternallyUsedValues[RdxVal]; } for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) { - Value *V = Candidates[Cnt]; - if (!Visited.insert(V).second) + Value *RdxVal = Candidates[Cnt]; + if (!Visited.insert(RdxVal).second) continue; - unsigned NumOps = VectorizedVals.lookup(V) + NumUses[V]; - if (NumOps != ReducedValsToOps.find(V)->second.size()) - LocalExternallyUsedValues[V]; + // Check if the scalar was vectorized as part of the vectorization + // tree but not the top node. + if (!VLScalars.contains(RdxVal) && V.isVectorized(RdxVal)) { + LocalExternallyUsedValues[RdxVal]; + continue; + } + unsigned NumOps = VectorizedVals.lookup(RdxVal) + NumUses[RdxVal]; + if (NumOps != ReducedValsToOps.find(RdxVal)->second.size()) + LocalExternallyUsedValues[RdxVal]; } + for (Value *RdxVal : VL) + if (RequiredExtract.contains(RdxVal)) + LocalExternallyUsedValues[RdxVal]; V.buildExternalUses(LocalExternallyUsedValues); V.computeMinimumValueSizes(); @@ -11226,11 +12661,25 @@ public: InstructionCost TreeCost = V.getTreeCost(VL); InstructionCost ReductionCost = getReductionCost(TTI, VL, ReduxWidth, RdxFMF); + if (V.isVectorizedFirstNode() && isa<LoadInst>(VL.front())) { + Instruction *MainOp = V.getFirstNodeMainOp(); + for (Value *V : VL) { + auto *VI = dyn_cast<LoadInst>(V); + // Add the costs of scalar GEP pointers, to be removed from the + // code. + if (!VI || VI == MainOp) + continue; + auto *Ptr = dyn_cast<GetElementPtrInst>(VI->getPointerOperand()); + if (!Ptr || !Ptr->hasOneUse() || Ptr->hasAllConstantIndices()) + continue; + TreeCost -= TTI->getArithmeticInstrCost( + Instruction::Add, Ptr->getType(), TTI::TCK_RecipThroughput); + } + } InstructionCost Cost = TreeCost + ReductionCost; - if (!Cost.isValid()) { - LLVM_DEBUG(dbgs() << "Encountered invalid baseline cost.\n"); + LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n"); + if (!Cost.isValid()) return nullptr; - } if (Cost >= -SLPCostThreshold) { V.getORE()->emit([&]() { return OptimizationRemarkMissed( @@ -11259,21 +12708,23 @@ public: Builder.setFastMathFlags(RdxFMF); - // Vectorize a tree. 
- Value *VectorizedRoot = V.vectorizeTree(LocalExternallyUsedValues); - // Emit a reduction. If the root is a select (min/max idiom), the insert // point is the compare condition of that select. Instruction *RdxRootInst = cast<Instruction>(ReductionRoot); + Instruction *InsertPt = RdxRootInst; if (IsCmpSelMinMax) - Builder.SetInsertPoint(GetCmpForMinMaxReduction(RdxRootInst)); - else - Builder.SetInsertPoint(RdxRootInst); + InsertPt = GetCmpForMinMaxReduction(RdxRootInst); + + // Vectorize a tree. + Value *VectorizedRoot = + V.vectorizeTree(LocalExternallyUsedValues, InsertPt); + + Builder.SetInsertPoint(InsertPt); // To prevent poison from leaking across what used to be sequential, // safe, scalar boolean logic operations, the reduction operand must be // frozen. - if (isa<SelectInst>(RdxRootInst) && isBoolLogicOp(RdxRootInst)) + if (isBoolLogicOp(RdxRootInst)) VectorizedRoot = Builder.CreateFreeze(VectorizedRoot); Value *ReducedSubTree = @@ -11290,23 +12741,59 @@ public: ReducedSubTree, "op.rdx", ReductionOps); } // Count vectorized reduced values to exclude them from final reduction. - for (Value *V : VL) - ++VectorizedVals.try_emplace(TrackedToOrig.find(V)->second, 0) + for (Value *RdxVal : VL) { + ++VectorizedVals.try_emplace(TrackedToOrig.find(RdxVal)->second, 0) .first->getSecond(); + if (!V.isVectorized(RdxVal)) + RequiredExtract.insert(RdxVal); + } Pos += ReduxWidth; Start = Pos; ReduxWidth = PowerOf2Floor(NumReducedVals - Pos); } } if (VectorizedTree) { + // Reorder operands of bool logical op in the natural order to avoid + // possible problem with poison propagation. If not possible to reorder + // (both operands are originally RHS), emit an extra freeze instruction + // for the LHS operand. + //I.e., if we have original code like this: + // RedOp1 = select i1 ?, i1 LHS, i1 false + // RedOp2 = select i1 RHS, i1 ?, i1 false + + // Then, we swap LHS/RHS to create a new op that matches the poison + // semantics of the original code. + + // If we have original code like this and both values could be poison: + // RedOp1 = select i1 ?, i1 LHS, i1 false + // RedOp2 = select i1 ?, i1 RHS, i1 false + + // Then, we must freeze LHS in the new op. + auto &&FixBoolLogicalOps = + [&Builder, VectorizedTree](Value *&LHS, Value *&RHS, + Instruction *RedOp1, Instruction *RedOp2) { + if (!isBoolLogicOp(RedOp1)) + return; + if (LHS == VectorizedTree || getRdxOperand(RedOp1, 0) == LHS || + isGuaranteedNotToBePoison(LHS)) + return; + if (!isBoolLogicOp(RedOp2)) + return; + if (RHS == VectorizedTree || getRdxOperand(RedOp2, 0) == RHS || + isGuaranteedNotToBePoison(RHS)) { + std::swap(LHS, RHS); + return; + } + LHS = Builder.CreateFreeze(LHS); + }; // Finish the reduction. // Need to add extra arguments and not vectorized possible reduction // values. // Try to avoid dependencies between the scalar remainders after // reductions. auto &&FinalGen = - [this, &Builder, - &TrackedVals](ArrayRef<std::pair<Instruction *, Value *>> InstVals) { + [this, &Builder, &TrackedVals, &FixBoolLogicalOps]( + ArrayRef<std::pair<Instruction *, Value *>> InstVals) { unsigned Sz = InstVals.size(); SmallVector<std::pair<Instruction *, Value *>> ExtraReds(Sz / 2 + Sz % 2); @@ -11323,6 +12810,11 @@ public: auto It2 = TrackedVals.find(RdxVal2); if (It2 != TrackedVals.end()) StableRdxVal2 = It2->second; + // To prevent poison from leaking across what used to be + // sequential, safe, scalar boolean logic operations, the + // reduction operand must be frozen. 
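The comments above explain why the final scalar reductions over boolean logic need care: a value may only sit in the short-circuiting condition position if the original select already guaranteed it was evaluated, otherwise poison could leak. A compact sketch of the decision that FixBoolLogicalOps makes, using hypothetical flags rather than real IR queries:

enum class BoolFix { None, SwapOperands, FreezeLHS };

// Decide how to combine two leftover boolean reduction values. A value may
// keep (or take) the condition slot of the new select only if it was the
// condition of its original select or is known not to be poison; otherwise
// the left operand must be frozen before it is used.
static BoolFix fixBoolLogicalOps(bool LhsWasCondition, bool LhsNotPoison,
                                 bool RhsWasCondition, bool RhsNotPoison) {
  if (LhsWasCondition || LhsNotPoison)
    return BoolFix::None;
  if (RhsWasCondition || RhsNotPoison)
    return BoolFix::SwapOperands;
  return BoolFix::FreezeLHS;
}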
+ FixBoolLogicalOps(StableRdxVal1, StableRdxVal2, InstVals[I].first, + RedOp); Value *ExtraRed = createOp(Builder, RdxKind, StableRdxVal1, StableRdxVal2, "op.rdx", ReductionOps); ExtraReds[I / 2] = std::make_pair(InstVals[I].first, ExtraRed); @@ -11332,6 +12824,8 @@ public: return ExtraReds; }; SmallVector<std::pair<Instruction *, Value *>> ExtraReductions; + ExtraReductions.emplace_back(cast<Instruction>(ReductionRoot), + VectorizedTree); SmallPtrSet<Value *, 8> Visited; for (ArrayRef<Value *> Candidates : ReducedVals) { for (Value *RdxVal : Candidates) { @@ -11339,7 +12833,7 @@ public: continue; unsigned NumOps = VectorizedVals.lookup(RdxVal); for (Instruction *RedOp : - makeArrayRef(ReducedValsToOps.find(RdxVal)->second) + ArrayRef(ReducedValsToOps.find(RdxVal)->second) .drop_back(NumOps)) ExtraReductions.emplace_back(RedOp, RdxVal); } @@ -11351,22 +12845,12 @@ public: } // Iterate through all not-vectorized reduction values/extra arguments. while (ExtraReductions.size() > 1) { + VectorizedTree = ExtraReductions.front().second; SmallVector<std::pair<Instruction *, Value *>> NewReds = FinalGen(ExtraReductions); ExtraReductions.swap(NewReds); } - // Final reduction. - if (ExtraReductions.size() == 1) { - Instruction *RedOp = ExtraReductions.back().first; - Builder.SetCurrentDebugLocation(RedOp->getDebugLoc()); - Value *RdxVal = ExtraReductions.back().second; - Value *StableRdxVal = RdxVal; - auto It = TrackedVals.find(RdxVal); - if (It != TrackedVals.end()) - StableRdxVal = It->second; - VectorizedTree = createOp(Builder, RdxKind, VectorizedTree, - StableRdxVal, "op.rdx", ReductionOps); - } + VectorizedTree = ExtraReductions.front().second; ReductionRoot->replaceAllUsesWith(VectorizedTree); @@ -11497,7 +12981,7 @@ private: } // end anonymous namespace -static Optional<unsigned> getAggregateSize(Instruction *InsertInst) { +static std::optional<unsigned> getAggregateSize(Instruction *InsertInst) { if (auto *IE = dyn_cast<InsertElementInst>(InsertInst)) return cast<FixedVectorType>(IE->getType())->getNumElements(); @@ -11508,7 +12992,7 @@ static Optional<unsigned> getAggregateSize(Instruction *InsertInst) { if (auto *ST = dyn_cast<StructType>(CurrentType)) { for (auto *Elt : ST->elements()) if (Elt != ST->getElementType(0)) // check homogeneity - return None; + return std::nullopt; AggregateSize *= ST->getNumElements(); CurrentType = ST->getElementType(0); } else if (auto *AT = dyn_cast<ArrayType>(CurrentType)) { @@ -11520,7 +13004,7 @@ static Optional<unsigned> getAggregateSize(Instruction *InsertInst) { } else if (CurrentType->isSingleValueType()) { return AggregateSize; } else { - return None; + return std::nullopt; } } while (true); } @@ -11532,12 +13016,11 @@ static void findBuildAggregate_rec(Instruction *LastInsertInst, unsigned OperandOffset) { do { Value *InsertedOperand = LastInsertInst->getOperand(1); - Optional<unsigned> OperandIndex = + std::optional<unsigned> OperandIndex = getInsertIndex(LastInsertInst, OperandOffset); if (!OperandIndex) return; - if (isa<InsertElementInst>(InsertedOperand) || - isa<InsertValueInst>(InsertedOperand)) { + if (isa<InsertElementInst, InsertValueInst>(InsertedOperand)) { findBuildAggregate_rec(cast<Instruction>(InsertedOperand), TTI, BuildVectorOpds, InsertElts, *OperandIndex); @@ -11547,8 +13030,7 @@ static void findBuildAggregate_rec(Instruction *LastInsertInst, } LastInsertInst = dyn_cast<Instruction>(LastInsertInst->getOperand(0)); } while (LastInsertInst != nullptr && - (isa<InsertValueInst>(LastInsertInst) || - 
isa<InsertElementInst>(LastInsertInst)) && + isa<InsertValueInst, InsertElementInst>(LastInsertInst) && LastInsertInst->hasOneUse()); } @@ -11578,7 +13060,7 @@ static bool findBuildAggregate(Instruction *LastInsertInst, assert((BuildVectorOpds.empty() && InsertElts.empty()) && "Expected empty result vectors!"); - Optional<unsigned> AggregateSize = getAggregateSize(LastInsertInst); + std::optional<unsigned> AggregateSize = getAggregateSize(LastInsertInst); if (!AggregateSize) return false; BuildVectorOpds.resize(*AggregateSize); @@ -11662,28 +13144,19 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) { return false; } -/// Attempt to reduce a horizontal reduction. -/// If it is legal to match a horizontal reduction feeding the phi node \a P -/// with reduction operators \a Root (or one of its operands) in a basic block -/// \a BB, then check if it can be done. If horizontal reduction is not found -/// and root instruction is a binary operation, vectorization of the operands is -/// attempted. -/// \returns true if a horizontal reduction was matched and reduced or operands -/// of one of the binary instruction were vectorized. -/// \returns false if a horizontal reduction was not matched (or not possible) -/// or no vectorization of any binary operation feeding \a Root instruction was -/// performed. -static bool tryToVectorizeHorReductionOrInstOperands( - PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, - TargetTransformInfo *TTI, ScalarEvolution &SE, const DataLayout &DL, - const TargetLibraryInfo &TLI, - const function_ref<bool(Instruction *, BoUpSLP &)> Vectorize) { +bool SLPVectorizerPass::vectorizeHorReduction( + PHINode *P, Value *V, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI, + SmallVectorImpl<WeakTrackingVH> &PostponedInsts) { if (!ShouldVectorizeHor) return false; + auto *Root = dyn_cast_or_null<Instruction>(V); if (!Root) return false; + if (!isa<BinaryOperator>(Root)) + P = nullptr; + if (Root->getParent() != BB || isa<PHINode>(Root)) return false; // Start analysis starting from Root instruction. If horizontal reduction is @@ -11695,25 +13168,22 @@ static bool tryToVectorizeHorReductionOrInstOperands( // horizontal reduction. // Interrupt the process if the Root instruction itself was vectorized or all // sub-trees not higher that RecursionMaxDepth were analyzed/vectorized. - // Skip the analysis of CmpInsts. Compiler implements postanalysis of the - // CmpInsts so we can skip extra attempts in - // tryToVectorizeHorReductionOrInstOperands and save compile time. + // If a horizintal reduction was not matched or vectorized we collect + // instructions for possible later attempts for vectorization. 
std::queue<std::pair<Instruction *, unsigned>> Stack; Stack.emplace(Root, 0); SmallPtrSet<Value *, 8> VisitedInstrs; - SmallVector<WeakTrackingVH> PostponedInsts; bool Res = false; - auto &&TryToReduce = [TTI, &SE, &DL, &P, &R, &TLI](Instruction *Inst, - Value *&B0, - Value *&B1) -> Value * { + auto &&TryToReduce = [this, TTI, &P, &R](Instruction *Inst, Value *&B0, + Value *&B1) -> Value * { if (R.isAnalyzedReductionRoot(Inst)) return nullptr; bool IsBinop = matchRdxBop(Inst, B0, B1); bool IsSelect = match(Inst, m_Select(m_Value(), m_Value(), m_Value())); if (IsBinop || IsSelect) { HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, Inst, SE, DL, TLI)) - return HorRdx.tryToReduce(R, TTI); + if (HorRdx.matchAssociativeReduction(P, Inst, *SE, *DL, *TLI)) + return HorRdx.tryToReduce(R, TTI, *TLI); } return nullptr; }; @@ -11754,9 +13224,8 @@ static bool tryToVectorizeHorReductionOrInstOperands( // Set P to nullptr to avoid re-analysis of phi node in // matchAssociativeReduction function unless this is the root node. P = nullptr; - // Do not try to vectorize CmpInst operands, this is done separately. - // Final attempt for binop args vectorization should happen after the loop - // to try to find reductions. + // Do not collect CmpInst or InsertElementInst/InsertValueInst as their + // analysis is done separately. if (!isa<CmpInst, InsertElementInst, InsertValueInst>(Inst)) PostponedInsts.push_back(Inst); } @@ -11774,29 +13243,25 @@ static bool tryToVectorizeHorReductionOrInstOperands( !R.isDeleted(I) && I->getParent() == BB) Stack.emplace(I, Level); } - // Try to vectorized binops where reductions were not found. - for (Value *V : PostponedInsts) - if (auto *Inst = dyn_cast<Instruction>(V)) - if (!R.isDeleted(Inst)) - Res |= Vectorize(Inst, R); return Res; } bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI) { - auto *I = dyn_cast_or_null<Instruction>(V); - if (!I) - return false; + SmallVector<WeakTrackingVH> PostponedInsts; + bool Res = vectorizeHorReduction(P, V, BB, R, TTI, PostponedInsts); + Res |= tryToVectorize(PostponedInsts, R); + return Res; +} - if (!isa<BinaryOperator>(I)) - P = nullptr; - // Try to match and vectorize a horizontal reduction. - auto &&ExtraVectorization = [this](Instruction *I, BoUpSLP &R) -> bool { - return tryToVectorize(I, R); - }; - return tryToVectorizeHorReductionOrInstOperands(P, I, BB, R, TTI, *SE, *DL, - *TLI, ExtraVectorization); +bool SLPVectorizerPass::tryToVectorize(ArrayRef<WeakTrackingVH> Insts, + BoUpSLP &R) { + bool Res = false; + for (Value *V : Insts) + if (auto *Inst = dyn_cast<Instruction>(V); Inst && !R.isDeleted(Inst)) + Res |= tryToVectorize(Inst, R); + return Res; } bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI, @@ -11866,7 +13331,7 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming, // same/alternate ops only, this may result in some extra final // vectorization. if (NumElts > 1 && - TryToVectorizeHelper(makeArrayRef(IncIt, NumElts), LimitForRegisterSize)) { + TryToVectorizeHelper(ArrayRef(IncIt, NumElts), LimitForRegisterSize)) { // Success start over because instructions might have been changed. 
Changed = true; } else if (NumElts < Limit(*IncIt) && @@ -11888,8 +13353,9 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming, while (SameTypeIt != End && AreCompatible(*SameTypeIt, *It)) ++SameTypeIt; unsigned NumElts = (SameTypeIt - It); - if (NumElts > 1 && TryToVectorizeHelper(makeArrayRef(It, NumElts), - /*LimitForRegisterSize=*/false)) + if (NumElts > 1 && + TryToVectorizeHelper(ArrayRef(It, NumElts), + /*LimitForRegisterSize=*/false)) Changed = true; It = SameTypeIt; } @@ -11911,7 +13377,7 @@ tryToVectorizeSequence(SmallVectorImpl<T *> &Incoming, /// predicate of the second or the operands IDs are less than the operands IDs /// of the second cmp instruction. template <bool IsCompatibility> -static bool compareCmp(Value *V, Value *V2, +static bool compareCmp(Value *V, Value *V2, TargetLibraryInfo &TLI, function_ref<bool(Instruction *)> IsDeleted) { auto *CI1 = cast<CmpInst>(V); auto *CI2 = cast<CmpInst>(V2); @@ -11947,7 +13413,7 @@ static bool compareCmp(Value *V, Value *V2, if (auto *I2 = dyn_cast<Instruction>(Op2)) { if (I1->getParent() != I2->getParent()) return false; - InstructionsState S = getSameOpcode({I1, I2}); + InstructionsState S = getSameOpcode({I1, I2}, TLI); if (S.getOpcode()) continue; return false; @@ -11956,25 +13422,35 @@ static bool compareCmp(Value *V, Value *V2, return IsCompatibility; } -bool SLPVectorizerPass::vectorizeSimpleInstructions( - SmallVectorImpl<Instruction *> &Instructions, BasicBlock *BB, BoUpSLP &R, - bool AtTerminator) { +bool SLPVectorizerPass::vectorizeSimpleInstructions(InstSetVector &Instructions, + BasicBlock *BB, BoUpSLP &R, + bool AtTerminator) { bool OpsChanged = false; SmallVector<Instruction *, 4> PostponedCmps; + SmallVector<WeakTrackingVH> PostponedInsts; + // pass1 - try to vectorize reductions only for (auto *I : reverse(Instructions)) { if (R.isDeleted(I)) continue; + if (isa<CmpInst>(I)) { + PostponedCmps.push_back(I); + continue; + } + OpsChanged |= vectorizeHorReduction(nullptr, I, BB, R, TTI, PostponedInsts); + } + // pass2 - try to match and vectorize a buildvector sequence. + for (auto *I : reverse(Instructions)) { + if (R.isDeleted(I) || isa<CmpInst>(I)) + continue; if (auto *LastInsertValue = dyn_cast<InsertValueInst>(I)) { OpsChanged |= vectorizeInsertValueInst(LastInsertValue, BB, R); } else if (auto *LastInsertElem = dyn_cast<InsertElementInst>(I)) { OpsChanged |= vectorizeInsertElementInst(LastInsertElem, BB, R); - } else if (isa<CmpInst>(I)) { - PostponedCmps.push_back(I); - continue; } - // Try to find reductions in buildvector sequnces. - OpsChanged |= vectorizeRootInstruction(nullptr, I, BB, R, TTI); } + // Now try to vectorize postponed instructions. + OpsChanged |= tryToVectorize(PostponedInsts, R); + if (AtTerminator) { // Try to find reductions first. for (Instruction *I : PostponedCmps) { @@ -11991,15 +13467,15 @@ bool SLPVectorizerPass::vectorizeSimpleInstructions( } // Try to vectorize list of compares. // Sort by type, compare predicate, etc. 
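tryToVectorizeSequence above and the compare handling that follows rely on the same pattern: the candidate list is sorted so that compatible instructions are adjacent, then each maximal run of mutually compatible elements is handed to the vectorizer. A standalone sketch of that run-grouping walk, with integers in place of instructions:

#include <cstddef>
#include <functional>
#include <vector>

static bool vectorizeCompatibleRuns(
    const std::vector<int> &Sorted,
    const std::function<bool(int, int)> &AreCompatible,
    const std::function<bool(const int *, std::size_t)> &TryToVectorize) {
  bool Changed = false;
  for (std::size_t It = 0, End = Sorted.size(); It != End;) {
    std::size_t SameTypeIt = It + 1;
    while (SameTypeIt != End && AreCompatible(Sorted[SameTypeIt], Sorted[It]))
      ++SameTypeIt;
    std::size_t NumElts = SameTypeIt - It;
    if (NumElts > 1 && TryToVectorize(&Sorted[It], NumElts))
      Changed = true;       // a run of at least two compatible candidates
    It = SameTypeIt;        // continue after the run, vectorized or not
  }
  return Changed;
}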
- auto &&CompareSorter = [&R](Value *V, Value *V2) { - return compareCmp<false>(V, V2, + auto CompareSorter = [&](Value *V, Value *V2) { + return compareCmp<false>(V, V2, *TLI, [&R](Instruction *I) { return R.isDeleted(I); }); }; - auto &&AreCompatibleCompares = [&R](Value *V1, Value *V2) { + auto AreCompatibleCompares = [&](Value *V1, Value *V2) { if (V1 == V2) return true; - return compareCmp<true>(V1, V2, + return compareCmp<true>(V1, V2, *TLI, [&R](Instruction *I) { return R.isDeleted(I); }); }; auto Limit = [&R](Value *V) { @@ -12027,9 +13503,10 @@ bool SLPVectorizerPass::vectorizeSimpleInstructions( /*LimitForRegisterSize=*/true); Instructions.clear(); } else { + Instructions.clear(); // Insert in reverse order since the PostponedCmps vector was filled in // reverse order. - Instructions.assign(PostponedCmps.rbegin(), PostponedCmps.rend()); + Instructions.insert(PostponedCmps.rbegin(), PostponedCmps.rend()); } return OpsChanged; } @@ -12058,7 +13535,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { return true; if (Opcodes1.size() > Opcodes2.size()) return false; - Optional<bool> ConstOrder; + std::optional<bool> ConstOrder; for (int I = 0, E = Opcodes1.size(); I < E; ++I) { // Undefs are compatible with any other value. if (isa<UndefValue>(Opcodes1[I]) || isa<UndefValue>(Opcodes2[I])) { @@ -12080,7 +13557,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { "Different nodes should have different DFS numbers"); if (NodeI1 != NodeI2) return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn(); - InstructionsState S = getSameOpcode({I1, I2}); + InstructionsState S = getSameOpcode({I1, I2}, *TLI); if (S.getOpcode()) continue; return I1->getOpcode() < I2->getOpcode(); @@ -12097,7 +13574,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { } return ConstOrder && *ConstOrder; }; - auto AreCompatiblePHIs = [&PHIToOpcodes](Value *V1, Value *V2) { + auto AreCompatiblePHIs = [&PHIToOpcodes, this](Value *V1, Value *V2) { if (V1 == V2) return true; if (V1->getType() != V2->getType()) @@ -12114,7 +13591,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) { if (I1->getParent() != I2->getParent()) return false; - InstructionsState S = getSameOpcode({I1, I2}); + InstructionsState S = getSameOpcode({I1, I2}, *TLI); if (S.getOpcode()) continue; return false; @@ -12182,7 +13659,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { VisitedInstrs.clear(); - SmallVector<Instruction *, 8> PostProcessInstructions; + InstSetVector PostProcessInstructions; SmallDenseSet<Instruction *, 4> KeyNodes; for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { // Skip instructions with scalable type. The num of elements is unknown at @@ -12234,8 +13711,12 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { !DT->isReachableFromEntry(P->getIncomingBlock(I))) continue; - Changed |= vectorizeRootInstruction(nullptr, P->getIncomingValue(I), - P->getIncomingBlock(I), R, TTI); + // Postponed instructions should not be vectorized here, delay their + // vectorization. 
+ if (auto *PI = dyn_cast<Instruction>(P->getIncomingValue(I)); + PI && !PostProcessInstructions.contains(PI)) + Changed |= vectorizeRootInstruction(nullptr, P->getIncomingValue(I), + P->getIncomingBlock(I), R, TTI); } continue; } @@ -12243,14 +13724,31 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // Ran into an instruction without users, like terminator, or function call // with ignored return value, store. Ignore unused instructions (basing on // instruction type, except for CallInst and InvokeInst). - if (it->use_empty() && (it->getType()->isVoidTy() || isa<CallInst>(it) || - isa<InvokeInst>(it))) { + if (it->use_empty() && + (it->getType()->isVoidTy() || isa<CallInst, InvokeInst>(it))) { KeyNodes.insert(&*it); bool OpsChanged = false; - if (ShouldStartVectorizeHorAtStore || !isa<StoreInst>(it)) { + auto *SI = dyn_cast<StoreInst>(it); + bool TryToVectorizeRoot = ShouldStartVectorizeHorAtStore || !SI; + if (SI) { + auto I = Stores.find(getUnderlyingObject(SI->getPointerOperand())); + // Try to vectorize chain in store, if this is the only store to the + // address in the block. + // TODO: This is just a temporarily solution to save compile time. Need + // to investigate if we can safely turn on slp-vectorize-hor-store + // instead to allow lookup for reduction chains in all non-vectorized + // stores (need to check side effects and compile time). + TryToVectorizeRoot = (I == Stores.end() || I->second.size() == 1) && + SI->getValueOperand()->hasOneUse(); + } + if (TryToVectorizeRoot) { for (auto *V : it->operand_values()) { - // Try to match and vectorize a horizontal reduction. - OpsChanged |= vectorizeRootInstruction(nullptr, V, BB, R, TTI); + // Postponed instructions should not be vectorized here, delay their + // vectorization. + if (auto *VI = dyn_cast<Instruction>(V); + VI && !PostProcessInstructions.contains(VI)) + // Try to match and vectorize a horizontal reduction. 
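The compile-time guard added above only treats a store as a reduction-root candidate when it is the sole store in the block to its underlying address and its stored value has a single use; otherwise the reduction lookup is skipped. A small sketch of that check with hypothetical types:

#include <map>
#include <vector>

struct StoreDesc {
  const void *UnderlyingObject;  // stand-in for getUnderlyingObject(pointer operand)
  bool ValueHasOneUse;           // stand-in for getValueOperand()->hasOneUse()
};

static bool tryRootAtStore(
    const std::map<const void *, std::vector<const StoreDesc *>> &StoresInBlock,
    const StoreDesc &SI) {
  auto It = StoresInBlock.find(SI.UnderlyingObject);
  bool SoleStore = It == StoresInBlock.end() || It->second.size() == 1;
  return SoleStore && SI.ValueHasOneUse;
}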
+ OpsChanged |= vectorizeRootInstruction(nullptr, V, BB, R, TTI); } } // Start vectorization of post-process list of instructions from the @@ -12268,9 +13766,8 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { } } - if (isa<InsertElementInst>(it) || isa<CmpInst>(it) || - isa<InsertValueInst>(it)) - PostProcessInstructions.push_back(&*it); + if (isa<CmpInst, InsertElementInst, InsertValueInst>(it)) + PostProcessInstructions.insert(&*it); } return Changed; @@ -12397,7 +13894,7 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { "Different nodes should have different DFS numbers"); if (NodeI1 != NodeI2) return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn(); - InstructionsState S = getSameOpcode({I1, I2}); + InstructionsState S = getSameOpcode({I1, I2}, *TLI); if (S.getOpcode()) return false; return I1->getOpcode() < I2->getOpcode(); @@ -12409,7 +13906,7 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { V2->getValueOperand()->getValueID(); }; - auto &&AreCompatibleStores = [](StoreInst *V1, StoreInst *V2) { + auto &&AreCompatibleStores = [this](StoreInst *V1, StoreInst *V2) { if (V1 == V2) return true; if (V1->getPointerOperandType() != V2->getPointerOperandType()) @@ -12422,7 +13919,7 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { if (auto *I2 = dyn_cast<Instruction>(V2->getValueOperand())) { if (I1->getParent() != I2->getParent()) return false; - InstructionsState S = getSameOpcode({I1, I2}); + InstructionsState S = getSameOpcode({I1, I2}, *TLI); return S.getOpcode() > 0; } if (isa<Constant>(V1->getValueOperand()) && diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h index 07d3fa56020b..733d2e1c667b 100644 --- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h +++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -100,7 +100,8 @@ class VPRecipeBuilder { /// Check if \p I has an opcode that can be widened and return a VPWidenRecipe /// if it can. The function should only be called if the cost-model indicates /// that widening should be performed. - VPWidenRecipe *tryToWiden(Instruction *I, ArrayRef<VPValue *> Operands) const; + VPRecipeBase *tryToWiden(Instruction *I, ArrayRef<VPValue *> Operands, + VPBasicBlock *VPBB, VPlanPtr &Plan); /// Return a VPRecipeOrValueTy with VPRecipeBase * being set. This can be used to force the use as VPRecipeBase* for recipe sub-types that also inherit from VPValue. VPRecipeOrVPValueTy toVPRecipeResult(VPRecipeBase *R) const { return R; } @@ -119,7 +120,8 @@ public: /// VPRecipeOrVPValueTy with nullptr. VPRecipeOrVPValueTy tryToCreateWidenRecipe(Instruction *Instr, ArrayRef<VPValue *> Operands, - VFRange &Range, VPlanPtr &Plan); + VFRange &Range, VPBasicBlock *VPBB, + VPlanPtr &Plan); /// Set the recipe created for given ingredient. This operation is a no-op for /// ingredients that were not marked using a nullptr entry in the map. 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 30032dda7f60..d554f438c804 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -17,6 +17,7 @@ //===----------------------------------------------------------------------===// #include "VPlan.h" +#include "VPlanCFG.h" #include "VPlanDominatorTree.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" @@ -109,6 +110,14 @@ void VPDef::dump() const { } #endif +VPRecipeBase *VPValue::getDefiningRecipe() { + return cast_or_null<VPRecipeBase>(Def); +} + +const VPRecipeBase *VPValue::getDefiningRecipe() const { + return cast_or_null<VPRecipeBase>(Def); +} + // Get the top-most entry block of \p Start. This is the entry block of the // containing VPlan. This function is templated to support both const and non-const blocks template <typename T> static T *getPlanEntry(T *Start) { @@ -188,9 +197,7 @@ VPBlockBase *VPBlockBase::getEnclosingBlockWithPredecessors() { } void VPBlockBase::deleteCFG(VPBlockBase *Entry) { - SmallVector<VPBlockBase *, 8> Blocks(depth_first(Entry)); - - for (VPBlockBase *Block : Blocks) + for (VPBlockBase *Block : to_vector(vp_depth_first_shallow(Entry))) delete Block; } @@ -202,7 +209,7 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() { } Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) { - if (!Def->getDef()) + if (!Def->hasDefiningRecipe()) return Def->getLiveInIRValue(); if (hasScalarValue(Def, Instance)) { @@ -257,7 +264,7 @@ void VPTransformState::setDebugLocFromInst(const Value *V) { const DILocation *DIL = Inst->getDebugLoc(); // When a FSDiscriminator is enabled, we don't need to add the multiply // factors to the discriminators. - if (DIL && Inst->getFunction()->isDebugInfoForProfiling() && + if (DIL && Inst->getFunction()->shouldEmitDebugInfoForProfiling() && !isa<DbgInfoIntrinsic>(Inst) && !EnableFSDiscriminator) { // FIXME: For scalable vectors, assume vscale=1. auto NewDIL = @@ -497,14 +504,15 @@ void VPBasicBlock::print(raw_ostream &O, const Twine &Indent, #endif void VPRegionBlock::dropAllReferences(VPValue *NewValue) { - for (VPBlockBase *Block : depth_first(Entry)) + for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) // Drop all references in VPBasicBlocks and replace all uses with // DummyValue. Block->dropAllReferences(NewValue); } void VPRegionBlock::execute(VPTransformState *State) { - ReversePostOrderTraversal<VPBlockBase *> RPOT(Entry); + ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> + RPOT(Entry); if (!isReplicator()) { // Create and register the new vector loop. @@ -558,7 +566,7 @@ void VPRegionBlock::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << (isReplicator() ? 
"<xVFxUF> " : "<x1> ") << getName() << ": {"; auto NewIndent = Indent + " "; - for (auto *BlockBase : depth_first(Entry)) { + for (auto *BlockBase : vp_depth_first_shallow(Entry)) { O << '\n'; BlockBase->print(O, NewIndent, SlotTracker); } @@ -568,6 +576,26 @@ void VPRegionBlock::print(raw_ostream &O, const Twine &Indent, } #endif +VPlan::~VPlan() { + clearLiveOuts(); + + if (Entry) { + VPValue DummyValue; + for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) + Block->dropAllReferences(&DummyValue); + + VPBlockBase::deleteCFG(Entry); + } + for (VPValue *VPV : VPValuesToFree) + delete VPV; + if (TripCount) + delete TripCount; + if (BackedgeTakenCount) + delete BackedgeTakenCount; + for (auto &P : VPExternalDefs) + delete P.second; +} + VPActiveLaneMaskPHIRecipe *VPlan::getActiveLaneMaskPhi() { VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock(); for (VPRecipeBase &R : Header->phis()) { @@ -577,45 +605,11 @@ VPActiveLaneMaskPHIRecipe *VPlan::getActiveLaneMaskPhi() { return nullptr; } -static bool canSimplifyBranchOnCond(VPInstruction *Term) { - VPInstruction *Not = dyn_cast<VPInstruction>(Term->getOperand(0)); - if (!Not || Not->getOpcode() != VPInstruction::Not) - return false; - - VPInstruction *ALM = dyn_cast<VPInstruction>(Not->getOperand(0)); - return ALM && ALM->getOpcode() == VPInstruction::ActiveLaneMask; -} - void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, Value *CanonicalIVStartValue, VPTransformState &State, bool IsEpilogueVectorization) { - VPBasicBlock *ExitingVPBB = getVectorLoopRegion()->getExitingBasicBlock(); - auto *Term = dyn_cast<VPInstruction>(&ExitingVPBB->back()); - // Try to simplify the branch condition if TC <= VF * UF when preparing to - // execute the plan for the main vector loop. We only do this if the - // terminator is: - // 1. BranchOnCount, or - // 2. BranchOnCond where the input is Not(ActiveLaneMask). - if (!IsEpilogueVectorization && Term && isa<ConstantInt>(TripCountV) && - (Term->getOpcode() == VPInstruction::BranchOnCount || - (Term->getOpcode() == VPInstruction::BranchOnCond && - canSimplifyBranchOnCond(Term)))) { - ConstantInt *C = cast<ConstantInt>(TripCountV); - uint64_t TCVal = C->getZExtValue(); - if (TCVal && TCVal <= State.VF.getKnownMinValue() * State.UF) { - auto *BOC = - new VPInstruction(VPInstruction::BranchOnCond, - {getOrAddExternalDef(State.Builder.getTrue())}); - Term->eraseFromParent(); - ExitingVPBB->appendRecipe(BOC); - // TODO: Further simplifications are possible - // 1. Replace inductions with constants. - // 2. Replace vector loop region with VPBasicBlock. - } - } - // Check if the trip count is needed, and if so build it. if (TripCount && TripCount->getNumUsers()) { for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) @@ -640,12 +634,14 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, // When vectorizing the epilogue loop, the canonical induction start value // needs to be changed from zero to the value after the main vector loop. + // FIXME: Improve modeling for canonical IV start values in the epilogue loop. 
if (CanonicalIVStartValue) { VPValue *VPV = getOrAddExternalDef(CanonicalIVStartValue); auto *IV = getCanonicalIV(); assert(all_of(IV->users(), [](const VPUser *U) { - if (isa<VPScalarIVStepsRecipe>(U)) + if (isa<VPScalarIVStepsRecipe>(U) || + isa<VPDerivedIVRecipe>(U)) return true; auto *VPI = cast<VPInstruction>(U); return VPI->getOpcode() == @@ -675,7 +671,7 @@ void VPlan::execute(VPTransformState *State) { State->Builder.SetInsertPoint(VectorPreHeader->getTerminator()); // Generate code in the loop pre-header and body. - for (VPBlockBase *Block : depth_first(Entry)) + for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) Block->execute(State); VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock(); @@ -747,7 +743,7 @@ LLVM_DUMP_METHOD void VPlan::print(raw_ostream &O) const { VPSlotTracker SlotTracker(this); - O << "VPlan '" << Name << "' {"; + O << "VPlan '" << getName() << "' {"; if (VectorTripCount.getNumUsers() > 0) { O << "\nLive-in "; @@ -761,14 +757,14 @@ void VPlan::print(raw_ostream &O) const { O << " = backedge-taken count\n"; } - for (const VPBlockBase *Block : depth_first(getEntry())) { + for (const VPBlockBase *Block : vp_depth_first_shallow(getEntry())) { O << '\n'; Block->print(O, "", SlotTracker); } if (!LiveOuts.empty()) O << "\n"; - for (auto &KV : LiveOuts) { + for (const auto &KV : LiveOuts) { O << "Live-out "; KV.second->getPhi()->printAsOperand(O); O << " = "; @@ -779,6 +775,29 @@ void VPlan::print(raw_ostream &O) const { O << "}\n"; } +std::string VPlan::getName() const { + std::string Out; + raw_string_ostream RSO(Out); + RSO << Name << " for "; + if (!VFs.empty()) { + RSO << "VF={" << VFs[0]; + for (ElementCount VF : drop_begin(VFs)) + RSO << "," << VF; + RSO << "},"; + } + + if (UFs.empty()) { + RSO << "UF>=1"; + } else { + RSO << "UF={" << UFs[0]; + for (unsigned UF : drop_begin(UFs)) + RSO << "," << UF; + RSO << "}"; + } + + return Out; +} + LLVM_DUMP_METHOD void VPlan::printDOT(raw_ostream &O) const { VPlanPrinter Printer(O, *this); @@ -863,7 +882,7 @@ void VPlanPrinter::dump() { OS << "edge [fontname=Courier, fontsize=30]\n"; OS << "compound=true\n"; - for (const VPBlockBase *Block : depth_first(Plan.getEntry())) + for (const VPBlockBase *Block : vp_depth_first_shallow(Plan.getEntry())) dumpBlock(Block); OS << "}\n"; @@ -948,7 +967,7 @@ void VPlanPrinter::dumpRegion(const VPRegionBlock *Region) { << DOT::EscapeString(Region->getName()) << "\"\n"; // Dump the blocks of the region. 
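VPlan::getName() just above now folds the candidate vectorization and unroll factors into the printed plan name, e.g. "... for VF={4,8},UF>=1". A rough standalone rendering of that string building, with plain integers standing in for ElementCount:

#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

static std::string planName(const std::string &Base, const std::vector<int> &VFs,
                            const std::vector<int> &UFs) {
  std::ostringstream OS;
  OS << Base << " for ";
  if (!VFs.empty()) {
    OS << "VF={" << VFs[0];
    for (std::size_t I = 1; I < VFs.size(); ++I)
      OS << "," << VFs[I];
    OS << "},";
  }
  if (UFs.empty()) {
    OS << "UF>=1";              // unroll factors not decided yet
  } else {
    OS << "UF={" << UFs[0];
    for (std::size_t I = 1; I < UFs.size(); ++I)
      OS << "," << UFs[I];
    OS << "}";
  }
  return OS.str();
}
// planName("initial", {4, 8}, {}) yields "initial for VF={4,8},UF>=1".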
assert(Region->getEntry() && "Region contains no inner blocks."); - for (const VPBlockBase *Block : depth_first(Region->getEntry())) + for (const VPBlockBase *Block : vp_depth_first_shallow(Region->getEntry())) dumpBlock(Block); bumpIndent(-1); OS << Indent << "}\n"; @@ -1017,7 +1036,8 @@ void VPUser::printOperands(raw_ostream &O, VPSlotTracker &SlotTracker) const { void VPInterleavedAccessInfo::visitRegion(VPRegionBlock *Region, Old2NewTy &Old2New, InterleavedAccessInfo &IAI) { - ReversePostOrderTraversal<VPBlockBase *> RPOT(Region->getEntry()); + ReversePostOrderTraversal<VPBlockShallowTraversalWrapper<VPBlockBase *>> + RPOT(Region->getEntry()); for (VPBlockBase *Base : RPOT) { visitBlock(Base, Old2New, IAI); } @@ -1079,10 +1099,8 @@ void VPSlotTracker::assignSlots(const VPlan &Plan) { if (Plan.BackedgeTakenCount) assignSlot(Plan.BackedgeTakenCount); - ReversePostOrderTraversal< - VPBlockRecursiveTraversalWrapper<const VPBlockBase *>> - RPOT(VPBlockRecursiveTraversalWrapper<const VPBlockBase *>( - Plan.getEntry())); + ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<const VPBlockBase *>> + RPOT(VPBlockDeepTraversalWrapper<const VPBlockBase *>(Plan.getEntry())); for (const VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<const VPBasicBlock>(RPOT)) for (const VPRecipeBase &Recipe : *VPBB) @@ -1103,7 +1121,7 @@ VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, return Plan.getOrAddExternalDef(E->getValue()); VPBasicBlock *Preheader = Plan.getEntry()->getEntryBasicBlock(); - VPValue *Step = new VPExpandSCEVRecipe(Expr, SE); - Preheader->appendRecipe(cast<VPRecipeBase>(Step->getDef())); + VPExpandSCEVRecipe *Step = new VPExpandSCEVRecipe(Expr, SE); + Preheader->appendRecipe(Step); return Step; } diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index f009a7ee6b4b..986faaf99664 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -10,14 +10,12 @@ /// This file contains the declarations of the Vectorization Plan base classes: /// 1. VPBasicBlock and VPRegionBlock that inherit from a common pure virtual /// VPBlockBase, together implementing a Hierarchical CFG; -/// 2. Specializations of GraphTraits that allow VPBlockBase graphs to be -/// treated as proper graphs for generic algorithms; -/// 3. Pure virtual VPRecipeBase serving as the base class for recipes contained +/// 2. Pure virtual VPRecipeBase serving as the base class for recipes contained /// within VPBasicBlocks; -/// 4. VPInstruction, a concrete Recipe and VPUser modeling a single planned +/// 3. VPInstruction, a concrete Recipe and VPUser modeling a single planned /// instruction; -/// 5. The VPlan class holding a candidate for vectorization; -/// 6. The VPlanPrinter class providing a way to print a plan in dot format; +/// 4. The VPlan class holding a candidate for vectorization; +/// 5. The VPlanPrinter class providing a way to print a plan in dot format; /// These are documented in docs/VectorizationPlan.rst. 
// //===----------------------------------------------------------------------===// @@ -28,9 +26,7 @@ #include "VPlanValue.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -55,14 +51,21 @@ class InductionDescriptor; class InnerLoopVectorizer; class IRBuilderBase; class LoopInfo; +class PredicateScalarEvolution; class raw_ostream; class RecurrenceDescriptor; -class Value; +class SCEV; +class Type; class VPBasicBlock; class VPRegionBlock; class VPlan; class VPReplicateRecipe; class VPlanSlp; +class Value; + +namespace Intrinsic { +typedef unsigned ID; +} /// Returns a calculation for the total number of elements for a given \p VF. /// For fixed width vectors this value is a constant, whereas for scalable @@ -73,6 +76,8 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF); Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF, int64_t Step); +const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE); + /// A range of powers-of-2 vectorization factors with fixed start and /// adjustable end. The range includes start and excludes end, e.g.,: /// [1, 9) = {1, 2, 4, 8} @@ -212,7 +217,7 @@ struct VPTransformState { /// Hold the indices to generate specific scalar instructions. Null indicates /// that all instances are to be generated, using either scalar or vector /// instructions. - Optional<VPIteration> Instance; + std::optional<VPIteration> Instance; struct DataState { /// A type for vectorized values in the new loop. Each value from the @@ -580,7 +585,7 @@ public: /// The method which generates the output IR that correspond to this /// VPBlockBase, thereby "executing" the VPlan. - virtual void execute(struct VPTransformState *State) = 0; + virtual void execute(VPTransformState *State) = 0; /// Delete all blocks reachable from a given VPBlockBase, inclusive. static void deleteCFG(VPBlockBase *Entry); @@ -680,7 +685,7 @@ public: /// The method which generates the output IR instructions that correspond to /// this VPRecipe, thereby "executing" the VPlan. - virtual void execute(struct VPTransformState &State) = 0; + virtual void execute(VPTransformState &State) = 0; /// Insert an unlinked recipe into a basic block immediately before /// the specified recipe. @@ -750,19 +755,22 @@ public: } }; -inline bool VPUser::classof(const VPDef *Def) { - return Def->getVPDefID() == VPRecipeBase::VPInstructionSC || - Def->getVPDefID() == VPRecipeBase::VPWidenSC || - Def->getVPDefID() == VPRecipeBase::VPWidenCallSC || - Def->getVPDefID() == VPRecipeBase::VPWidenSelectSC || - Def->getVPDefID() == VPRecipeBase::VPWidenGEPSC || - Def->getVPDefID() == VPRecipeBase::VPBlendSC || - Def->getVPDefID() == VPRecipeBase::VPInterleaveSC || - Def->getVPDefID() == VPRecipeBase::VPReplicateSC || - Def->getVPDefID() == VPRecipeBase::VPReductionSC || - Def->getVPDefID() == VPRecipeBase::VPBranchOnMaskSC || - Def->getVPDefID() == VPRecipeBase::VPWidenMemoryInstructionSC; -} +// Helper macro to define common classof implementations for recipes. 
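// Minimal sketch of the llvm::Optional -> std::optional switch for
// VPTransformState::Instance above; access patterns are unchanged, only the
// "empty" spelling now comes from the standard library. The helper name is
// hypothetical.
static void skipUnlessSingleInstance(VPTransformState &State) {
  if (!State.Instance)                    // no specific instance: all lanes
    return;
  unsigned Part = State.Instance->Part;   // dereference exactly as before
  (void)Part;
}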
+#define VP_CLASSOF_IMPL(VPDefID) \ + static inline bool classof(const VPDef *D) { \ + return D->getVPDefID() == VPDefID; \ + } \ + static inline bool classof(const VPValue *V) { \ + auto *R = V->getDefiningRecipe(); \ + return R && R->getVPDefID() == VPDefID; \ + } \ + static inline bool classof(const VPUser *U) { \ + auto *R = dyn_cast<VPRecipeBase>(U); \ + return R && R->getVPDefID() == VPDefID; \ + } \ + static inline bool classof(const VPRecipeBase *R) { \ + return R->getVPDefID() == VPDefID; \ + } /// This is a concrete Recipe that models a single VPlan-level instruction. /// While as any Recipe it may generate a sequence of IR instructions when @@ -811,39 +819,20 @@ protected: public: VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL, const Twine &Name = "") - : VPRecipeBase(VPRecipeBase::VPInstructionSC, Operands), - VPValue(VPValue::VPVInstructionSC, nullptr, this), Opcode(Opcode), - DL(DL), Name(Name.str()) {} + : VPRecipeBase(VPDef::VPInstructionSC, Operands), VPValue(this), + Opcode(Opcode), DL(DL), Name(Name.str()) {} VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands, DebugLoc DL = {}, const Twine &Name = "") : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands), DL, Name) {} - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVInstructionSC; - } + VP_CLASSOF_IMPL(VPDef::VPInstructionSC) VPInstruction *clone() const { SmallVector<VPValue *, 2> Operands(operands()); return new VPInstruction(Opcode, Operands, DL, Name); } - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *R) { - return R->getVPDefID() == VPRecipeBase::VPInstructionSC; - } - - /// Extra classof implementations to allow directly casting from VPUser -> - /// VPInstruction. - static inline bool classof(const VPUser *U) { - auto *R = dyn_cast<VPRecipeBase>(U); - return R && R->getVPDefID() == VPRecipeBase::VPInstructionSC; - } - static inline bool classof(const VPRecipeBase *R) { - return R->getVPDefID() == VPRecipeBase::VPInstructionSC; - } - unsigned getOpcode() const { return Opcode; } /// Generate the instruction. @@ -921,18 +910,11 @@ class VPWidenRecipe : public VPRecipeBase, public VPValue { public: template <typename IterT> VPWidenRecipe(Instruction &I, iterator_range<IterT> Operands) - : VPRecipeBase(VPRecipeBase::VPWidenSC, Operands), - VPValue(VPValue::VPVWidenSC, &I, this) {} + : VPRecipeBase(VPDef::VPWidenSC, Operands), VPValue(this, &I) {} ~VPWidenRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPWidenSC; - } - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVWidenSC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenSC) /// Produce widened copies of all Ingredients. void execute(VPTransformState &State) override; @@ -946,19 +928,20 @@ public: /// A recipe for widening Call instructions. class VPWidenCallRecipe : public VPRecipeBase, public VPValue { + /// ID of the vector intrinsic to call when widening the call. If set the + /// Intrinsic::not_intrinsic, a library call will be used instead. 
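// Rough illustration of what the VP_CLASSOF_IMPL macro above generates for a
// recipe: one VPDef ID test exposed through every base class, so isa/dyn_cast
// work from whichever handle the caller holds. The helper is hypothetical.
static void inspect(VPValue *V, VPUser *U, VPRecipeBase *R) {
  if (auto *W = dyn_cast<VPWidenRecipe>(V))     // matched via getDefiningRecipe()
    (void)W->getUnderlyingInstr();
  if (auto *VPI = dyn_cast<VPInstruction>(U))   // matched via dyn_cast to VPRecipeBase
    (void)VPI->getOpcode();
  bool IsCall = isa<VPWidenCallRecipe>(R);      // plain VPDef ID comparison
  (void)IsCall;
}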
+ Intrinsic::ID VectorIntrinsicID; public: template <typename IterT> - VPWidenCallRecipe(CallInst &I, iterator_range<IterT> CallArguments) - : VPRecipeBase(VPRecipeBase::VPWidenCallSC, CallArguments), - VPValue(VPValue::VPVWidenCallSC, &I, this) {} + VPWidenCallRecipe(CallInst &I, iterator_range<IterT> CallArguments, + Intrinsic::ID VectorIntrinsicID) + : VPRecipeBase(VPDef::VPWidenCallSC, CallArguments), VPValue(this, &I), + VectorIntrinsicID(VectorIntrinsicID) {} ~VPWidenCallRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPWidenCallSC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenCallSC) /// Produce a widened version of the call instruction. void execute(VPTransformState &State) override; @@ -980,16 +963,12 @@ public: template <typename IterT> VPWidenSelectRecipe(SelectInst &I, iterator_range<IterT> Operands, bool InvariantCond) - : VPRecipeBase(VPRecipeBase::VPWidenSelectSC, Operands), - VPValue(VPValue::VPVWidenSelectSC, &I, this), + : VPRecipeBase(VPDef::VPWidenSelectSC, Operands), VPValue(this, &I), InvariantCond(InvariantCond) {} ~VPWidenSelectRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPWidenSelectSC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenSelectSC) /// Produce a widened version of the select instruction. void execute(VPTransformState &State) override; @@ -1009,15 +988,13 @@ class VPWidenGEPRecipe : public VPRecipeBase, public VPValue { public: template <typename IterT> VPWidenGEPRecipe(GetElementPtrInst *GEP, iterator_range<IterT> Operands) - : VPRecipeBase(VPRecipeBase::VPWidenGEPSC, Operands), - VPValue(VPWidenGEPSC, GEP, this), + : VPRecipeBase(VPDef::VPWidenGEPSC, Operands), VPValue(this, GEP), IsIndexLoopInvariant(GEP->getNumIndices(), false) {} template <typename IterT> VPWidenGEPRecipe(GetElementPtrInst *GEP, iterator_range<IterT> Operands, Loop *OrigLoop) - : VPRecipeBase(VPRecipeBase::VPWidenGEPSC, Operands), - VPValue(VPValue::VPVWidenGEPSC, GEP, this), + : VPRecipeBase(VPDef::VPWidenGEPSC, Operands), VPValue(this, GEP), IsIndexLoopInvariant(GEP->getNumIndices(), false) { IsPtrLoopInvariant = OrigLoop->isLoopInvariant(GEP->getPointerOperand()); for (auto Index : enumerate(GEP->indices())) @@ -1026,10 +1003,7 @@ public: } ~VPWidenGEPRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPWidenGEPSC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenGEPSC) /// Generate the gep nodes. 
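// Hedged construction sketch for the widened-call recipe declared above: the
// caller now decides up front whether to widen to a vector intrinsic or to a
// vectorized library call. CI (a CallInst) and Args (the VPValue operands)
// are assumed; fmuladd is only an example ID.
auto *WidenedCall = new VPWidenCallRecipe(
    CI, make_range(Args.begin(), Args.end()), Intrinsic::fmuladd);
// Passing Intrinsic::not_intrinsic instead selects the library-call path.
(void)WidenedCall;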
void execute(VPTransformState &State) override; @@ -1052,23 +1026,20 @@ public: VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step, const InductionDescriptor &IndDesc, bool NeedsVectorIV) - : VPRecipeBase(VPWidenIntOrFpInductionSC, {Start, Step}), - VPValue(IV, this), IV(IV), IndDesc(IndDesc), + : VPRecipeBase(VPDef::VPWidenIntOrFpInductionSC, {Start, Step}), + VPValue(this, IV), IV(IV), IndDesc(IndDesc), NeedsVectorIV(NeedsVectorIV) {} VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step, const InductionDescriptor &IndDesc, TruncInst *Trunc, bool NeedsVectorIV) - : VPRecipeBase(VPWidenIntOrFpInductionSC, {Start, Step}), - VPValue(Trunc, this), IV(IV), IndDesc(IndDesc), + : VPRecipeBase(VPDef::VPWidenIntOrFpInductionSC, {Start, Step}), + VPValue(this, Trunc), IV(IV), IndDesc(IndDesc), NeedsVectorIV(NeedsVectorIV) {} ~VPWidenIntOrFpInductionRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPWidenIntOrFpInductionSC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenIntOrFpInductionSC) /// Generate the vectorized and scalarized versions of the phi node as /// needed by their users. @@ -1120,11 +1091,30 @@ public: /// phis for first order recurrences, pointer inductions and reductions. The /// start value is the first operand of the recipe and the incoming value from /// the backedge is the second operand. +/// +/// Inductions are modeled using the following sub-classes: +/// * VPCanonicalIVPHIRecipe: Canonical scalar induction of the vector loop, +/// starting at a specified value (zero for the main vector loop, the resume +/// value for the epilogue vector loop) and stepping by 1. The induction +/// controls exiting of the vector loop by comparing against the vector trip +/// count. Produces a single scalar PHI for the induction value per +/// iteration. +/// * VPWidenIntOrFpInductionRecipe: Generates vector values for integer and +/// floating point inductions with arbitrary start and step values. Produces +/// a vector PHI per-part. +/// * VPDerivedIVRecipe: Converts the canonical IV value to the corresponding +/// value of an IV with different start and step values. Produces a single +/// scalar value per iteration +/// * VPScalarIVStepsRecipe: Generates scalar values per-lane based on a +/// canonical or derived induction. +/// * VPWidenPointerInductionRecipe: Generate vector and scalar values for a +/// pointer induction. Produces either a vector PHI per-part or scalar values +/// per-lane based on the canonical induction. class VPHeaderPHIRecipe : public VPRecipeBase, public VPValue { protected: - VPHeaderPHIRecipe(unsigned char VPVID, unsigned char VPDefID, PHINode *Phi, + VPHeaderPHIRecipe(unsigned char VPDefID, PHINode *Phi, VPValue *Start = nullptr) - : VPRecipeBase(VPDefID, {}), VPValue(VPVID, Phi, this) { + : VPRecipeBase(VPDefID, {}), VPValue(this, Phi) { if (Start) addOperand(Start); } @@ -1134,20 +1124,13 @@ public: /// Method to support type inquiry through isa, cast, and dyn_cast. 
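// Worked illustration (schematic, hedged) of how the induction recipes listed
// in the comment above compose for an IV with start 7 and step 3: the
// canonical IV counts 0, 1, 2, ... per vector iteration; VPDerivedIVRecipe
// rescales it to 7 + civ*3; VPScalarIVStepsRecipe then yields, for lane l of
// part p, roughly 7 + (civ + p*VF + l)*3, while a
// VPWidenIntOrFpInductionRecipe would instead keep a widened per-part vector
// PHI carrying those values.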
static inline bool classof(const VPRecipeBase *B) { - return B->getVPDefID() == VPRecipeBase::VPCanonicalIVPHISC || - B->getVPDefID() == VPRecipeBase::VPActiveLaneMaskPHISC || - B->getVPDefID() == VPRecipeBase::VPFirstOrderRecurrencePHISC || - B->getVPDefID() == VPRecipeBase::VPReductionPHISC || - B->getVPDefID() == VPRecipeBase::VPWidenIntOrFpInductionSC || - B->getVPDefID() == VPRecipeBase::VPWidenPHISC; + return B->getVPDefID() >= VPDef::VPFirstHeaderPHISC && + B->getVPDefID() <= VPDef::VPLastPHISC; } static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVCanonicalIVPHISC || - V->getVPValueID() == VPValue::VPVActiveLaneMaskPHISC || - V->getVPValueID() == VPValue::VPVFirstOrderRecurrencePHISC || - V->getVPValueID() == VPValue::VPVReductionPHISC || - V->getVPValueID() == VPValue::VPVWidenIntOrFpInductionSC || - V->getVPValueID() == VPValue::VPVWidenPHISC; + auto *B = V->getDefiningRecipe(); + return B && B->getVPDefID() >= VPRecipeBase::VPFirstHeaderPHISC && + B->getVPDefID() <= VPRecipeBase::VPLastPHISC; } /// Generate the phi nodes. @@ -1167,6 +1150,9 @@ public: return getNumOperands() == 0 ? nullptr : getOperand(0); } + /// Update the start value of the recipe. + void setStartValue(VPValue *V) { setOperand(0, V); } + /// Returns the incoming value from the loop backedge. VPValue *getBackedgeValue() { return getOperand(1); @@ -1174,43 +1160,32 @@ public: /// Returns the backedge value as a recipe. The backedge value is guaranteed /// to be a recipe. - VPRecipeBase *getBackedgeRecipe() { - return cast<VPRecipeBase>(getBackedgeValue()->getDef()); + VPRecipeBase &getBackedgeRecipe() { + return *getBackedgeValue()->getDefiningRecipe(); } }; class VPWidenPointerInductionRecipe : public VPHeaderPHIRecipe { const InductionDescriptor &IndDesc; - /// SCEV used to expand step. - /// FIXME: move expansion of step to the pre-header, once it is modeled - /// explicitly. - ScalarEvolution &SE; + bool IsScalarAfterVectorization; public: /// Create a new VPWidenPointerInductionRecipe for \p Phi with start value \p /// Start. - VPWidenPointerInductionRecipe(PHINode *Phi, VPValue *Start, + VPWidenPointerInductionRecipe(PHINode *Phi, VPValue *Start, VPValue *Step, const InductionDescriptor &IndDesc, - ScalarEvolution &SE) - : VPHeaderPHIRecipe(VPVWidenPointerInductionSC, VPWidenPointerInductionSC, - Phi), - IndDesc(IndDesc), SE(SE) { + bool IsScalarAfterVectorization) + : VPHeaderPHIRecipe(VPDef::VPWidenPointerInductionSC, Phi), + IndDesc(IndDesc), + IsScalarAfterVectorization(IsScalarAfterVectorization) { addOperand(Start); + addOperand(Step); } ~VPWidenPointerInductionRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPRecipeBase *B) { - return B->getVPDefID() == VPRecipeBase::VPWidenPointerInductionSC; - } - static inline bool classof(const VPHeaderPHIRecipe *R) { - return R->getVPDefID() == VPRecipeBase::VPWidenPointerInductionSC; - } - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVWidenPointerInductionSC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenPointerInductionSC) /// Generate vector values for the pointer induction. void execute(VPTransformState &State) override; @@ -1218,6 +1193,9 @@ public: /// Returns true if only scalar values will be generated. bool onlyScalarsGenerated(ElementCount VF); + /// Returns the induction descriptor for the recipe. 
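// Small sketch of what the ID-range classof above enables: one range check
// now answers "is this any of the header phi recipes?" instead of listing
// every subclass explicitly. The helper name is hypothetical.
static bool isHeaderPhi(const VPRecipeBase &R) {
  return isa<VPHeaderPHIRecipe>(&R); // VPFirstHeaderPHISC <= ID <= VPLastPHISC
}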
+ const InductionDescriptor &getInductionDescriptor() const { return IndDesc; } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print the recipe. void print(raw_ostream &O, const Twine &Indent, @@ -1235,23 +1213,14 @@ class VPWidenPHIRecipe : public VPHeaderPHIRecipe { public: /// Create a new VPWidenPHIRecipe for \p Phi with start value \p Start. VPWidenPHIRecipe(PHINode *Phi, VPValue *Start = nullptr) - : VPHeaderPHIRecipe(VPVWidenPHISC, VPWidenPHISC, Phi) { + : VPHeaderPHIRecipe(VPDef::VPWidenPHISC, Phi) { if (Start) addOperand(Start); } ~VPWidenPHIRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPRecipeBase *B) { - return B->getVPDefID() == VPRecipeBase::VPWidenPHISC; - } - static inline bool classof(const VPHeaderPHIRecipe *R) { - return R->getVPDefID() == VPRecipeBase::VPWidenPHISC; - } - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVWidenPHISC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenPHISC) /// Generate the phi/select nodes. void execute(VPTransformState &State) override; @@ -1280,18 +1249,12 @@ public: /// second operand. struct VPFirstOrderRecurrencePHIRecipe : public VPHeaderPHIRecipe { VPFirstOrderRecurrencePHIRecipe(PHINode *Phi, VPValue &Start) - : VPHeaderPHIRecipe(VPVFirstOrderRecurrencePHISC, - VPFirstOrderRecurrencePHISC, Phi, &Start) {} + : VPHeaderPHIRecipe(VPDef::VPFirstOrderRecurrencePHISC, Phi, &Start) {} + + VP_CLASSOF_IMPL(VPDef::VPFirstOrderRecurrencePHISC) - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPRecipeBase *R) { - return R->getVPDefID() == VPRecipeBase::VPFirstOrderRecurrencePHISC; - } static inline bool classof(const VPHeaderPHIRecipe *R) { - return R->getVPDefID() == VPRecipeBase::VPFirstOrderRecurrencePHISC; - } - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVFirstOrderRecurrencePHISC; + return R->getVPDefID() == VPDef::VPFirstOrderRecurrencePHISC; } void execute(VPTransformState &State) override; @@ -1322,22 +1285,17 @@ public: VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc, VPValue &Start, bool IsInLoop = false, bool IsOrdered = false) - : VPHeaderPHIRecipe(VPVReductionPHISC, VPReductionPHISC, Phi, &Start), + : VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start), RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) { assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop"); } ~VPReductionPHIRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPRecipeBase *R) { - return R->getVPDefID() == VPRecipeBase::VPReductionPHISC; - } + VP_CLASSOF_IMPL(VPDef::VPReductionPHISC) + static inline bool classof(const VPHeaderPHIRecipe *R) { - return R->getVPDefID() == VPRecipeBase::VPReductionPHISC; - } - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVReductionPHISC; + return R->getVPDefID() == VPDef::VPReductionPHISC; } /// Generate the phi/select nodes. @@ -1370,18 +1328,14 @@ public: /// respective masks, ordered [I0, M0, I1, M1, ...]. Note that a single value /// might be incoming with a full mask for which there is no VPValue. 
VPBlendRecipe(PHINode *Phi, ArrayRef<VPValue *> Operands) - : VPRecipeBase(VPBlendSC, Operands), - VPValue(VPValue::VPVBlendSC, Phi, this), Phi(Phi) { + : VPRecipeBase(VPDef::VPBlendSC, Operands), VPValue(this, Phi), Phi(Phi) { assert(Operands.size() > 0 && ((Operands.size() == 1) || (Operands.size() % 2 == 0)) && "Expected either a single incoming value or a positive even number " "of operands"); } - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPBlendSC; - } + VP_CLASSOF_IMPL(VPDef::VPBlendSC) /// Return the number of incoming values, taking into account that a single /// incoming value has no mask. @@ -1425,7 +1379,7 @@ class VPInterleaveRecipe : public VPRecipeBase { public: VPInterleaveRecipe(const InterleaveGroup<Instruction> *IG, VPValue *Addr, ArrayRef<VPValue *> StoredValues, VPValue *Mask) - : VPRecipeBase(VPInterleaveSC, {Addr}), IG(IG) { + : VPRecipeBase(VPDef::VPInterleaveSC, {Addr}), IG(IG) { for (unsigned i = 0; i < IG->getFactor(); ++i) if (Instruction *I = IG->getMember(i)) { if (I->getType()->isVoidTy()) @@ -1442,10 +1396,7 @@ public: } ~VPInterleaveRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPInterleaveSC; - } + VP_CLASSOF_IMPL(VPDef::VPInterleaveSC) /// Return the address accessed by this recipe. VPValue *getAddr() const { @@ -1489,9 +1440,7 @@ public: bool onlyFirstLaneUsed(const VPValue *Op) const override { assert(is_contained(operands(), Op) && "Op must be an operand of the recipe"); - return Op == getAddr() && all_of(getStoredValues(), [Op](VPValue *StoredV) { - return Op != StoredV; - }); + return Op == getAddr() && !llvm::is_contained(getStoredValues(), Op); } }; @@ -1508,18 +1457,15 @@ public: VPReductionRecipe(const RecurrenceDescriptor *R, Instruction *I, VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp, const TargetTransformInfo *TTI) - : VPRecipeBase(VPRecipeBase::VPReductionSC, {ChainOp, VecOp}), - VPValue(VPValue::VPVReductionSC, I, this), RdxDesc(R), TTI(TTI) { + : VPRecipeBase(VPDef::VPReductionSC, {ChainOp, VecOp}), VPValue(this, I), + RdxDesc(R), TTI(TTI) { if (CondOp) addOperand(CondOp); } ~VPReductionRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVReductionSC; - } + VP_CLASSOF_IMPL(VPDef::VPReductionSC) /// Generate the reduction in the loop void execute(VPTransformState &State) override; @@ -1558,7 +1504,7 @@ public: template <typename IterT> VPReplicateRecipe(Instruction *I, iterator_range<IterT> Operands, bool IsUniform, bool IsPredicated = false) - : VPRecipeBase(VPReplicateSC, Operands), VPValue(VPVReplicateSC, I, this), + : VPRecipeBase(VPDef::VPReplicateSC, Operands), VPValue(this, I), IsUniform(IsUniform), IsPredicated(IsPredicated) { // Retain the previous behavior of predicateInstructions(), where an // insert-element of a predicated instruction got hoisted into the @@ -1570,14 +1516,7 @@ public: ~VPReplicateRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. 
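// Equivalence sketch for the onlyFirstLaneUsed() change above: "Op is none of
// the stored values" is just a containment test. The wrapper is hypothetical.
static bool opNotStored(ArrayRef<VPValue *> StoredValues, VPValue *Op) {
  bool New = !llvm::is_contained(StoredValues, Op);
  bool Old =
      llvm::all_of(StoredValues, [Op](VPValue *V) { return V != Op; });
  assert(New == Old && "the two spellings agree");
  return New;
}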
- static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPReplicateSC; - } - - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVReplicateSC; - } + VP_CLASSOF_IMPL(VPDef::VPReplicateSC) /// Generate replicas of the desired Ingredient. Replicas will be generated /// for all parts and lanes unless a specific part and lane are specified in @@ -1617,15 +1556,12 @@ public: class VPBranchOnMaskRecipe : public VPRecipeBase { public: VPBranchOnMaskRecipe(VPValue *BlockInMask) - : VPRecipeBase(VPBranchOnMaskSC, {}) { + : VPRecipeBase(VPDef::VPBranchOnMaskSC, {}) { if (BlockInMask) // nullptr means all-one mask. addOperand(BlockInMask); } - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPBranchOnMaskSC; - } + VP_CLASSOF_IMPL(VPDef::VPBranchOnMaskSC) /// Generate the extraction of the appropriate bit from the block mask and the /// conditional branch. @@ -1669,14 +1605,10 @@ public: /// Construct a VPPredInstPHIRecipe given \p PredInst whose value needs a phi /// nodes after merging back from a Branch-on-Mask. VPPredInstPHIRecipe(VPValue *PredV) - : VPRecipeBase(VPPredInstPHISC, PredV), - VPValue(VPValue::VPVPredInstPHI, nullptr, this) {} + : VPRecipeBase(VPDef::VPPredInstPHISC, PredV), VPValue(this) {} ~VPPredInstPHIRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPPredInstPHISC; - } + VP_CLASSOF_IMPL(VPDef::VPPredInstPHISC) /// Generates phi nodes for live-outs as needed to retain SSA form. void execute(VPTransformState &State) override; @@ -1723,26 +1655,23 @@ class VPWidenMemoryInstructionRecipe : public VPRecipeBase { public: VPWidenMemoryInstructionRecipe(LoadInst &Load, VPValue *Addr, VPValue *Mask, bool Consecutive, bool Reverse) - : VPRecipeBase(VPWidenMemoryInstructionSC, {Addr}), Ingredient(Load), - Consecutive(Consecutive), Reverse(Reverse) { + : VPRecipeBase(VPDef::VPWidenMemoryInstructionSC, {Addr}), + Ingredient(Load), Consecutive(Consecutive), Reverse(Reverse) { assert((Consecutive || !Reverse) && "Reverse implies consecutive"); - new VPValue(VPValue::VPVMemoryInstructionSC, &Load, this); + new VPValue(this, &Load); setMask(Mask); } VPWidenMemoryInstructionRecipe(StoreInst &Store, VPValue *Addr, VPValue *StoredValue, VPValue *Mask, bool Consecutive, bool Reverse) - : VPRecipeBase(VPWidenMemoryInstructionSC, {Addr, StoredValue}), + : VPRecipeBase(VPDef::VPWidenMemoryInstructionSC, {Addr, StoredValue}), Ingredient(Store), Consecutive(Consecutive), Reverse(Reverse) { assert((Consecutive || !Reverse) && "Reverse implies consecutive"); setMask(Mask); } - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPWidenMemoryInstructionSC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenMemoryInstructionSC) /// Return the address accessed by this recipe. 
VPValue *getAddr() const { @@ -1803,15 +1732,12 @@ class VPExpandSCEVRecipe : public VPRecipeBase, public VPValue { public: VPExpandSCEVRecipe(const SCEV *Expr, ScalarEvolution &SE) - : VPRecipeBase(VPExpandSCEVSC, {}), VPValue(nullptr, this), Expr(Expr), + : VPRecipeBase(VPDef::VPExpandSCEVSC, {}), VPValue(this), Expr(Expr), SE(SE) {} ~VPExpandSCEVRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPExpandSCEVSC; - } + VP_CLASSOF_IMPL(VPDef::VPExpandSCEVSC) /// Generate a canonical vector induction variable of the vector loop, with void execute(VPTransformState &State) override; @@ -1834,21 +1760,14 @@ class VPCanonicalIVPHIRecipe : public VPHeaderPHIRecipe { public: VPCanonicalIVPHIRecipe(VPValue *StartV, DebugLoc DL) - : VPHeaderPHIRecipe(VPValue::VPVCanonicalIVPHISC, VPCanonicalIVPHISC, - nullptr, StartV), - DL(DL) {} + : VPHeaderPHIRecipe(VPDef::VPCanonicalIVPHISC, nullptr, StartV), DL(DL) {} ~VPCanonicalIVPHIRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPCanonicalIVPHISC; - } + VP_CLASSOF_IMPL(VPDef::VPCanonicalIVPHISC) + static inline bool classof(const VPHeaderPHIRecipe *D) { - return D->getVPDefID() == VPCanonicalIVPHISC; - } - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVCanonicalIVPHISC; + return D->getVPDefID() == VPDef::VPCanonicalIVPHISC; } /// Generate the canonical scalar induction phi of the vector loop. @@ -1871,6 +1790,10 @@ public: "Op must be an operand of the recipe"); return true; } + + /// Check if the induction described by \p ID is canonical, i.e. has the same + /// start, step (of 1), and type as the canonical IV. + bool isCanonical(const InductionDescriptor &ID, Type *Ty) const; }; /// A recipe for generating the active lane mask for the vector loop that is @@ -1882,21 +1805,15 @@ class VPActiveLaneMaskPHIRecipe : public VPHeaderPHIRecipe { public: VPActiveLaneMaskPHIRecipe(VPValue *StartMask, DebugLoc DL) - : VPHeaderPHIRecipe(VPValue::VPVActiveLaneMaskPHISC, - VPActiveLaneMaskPHISC, nullptr, StartMask), + : VPHeaderPHIRecipe(VPDef::VPActiveLaneMaskPHISC, nullptr, StartMask), DL(DL) {} ~VPActiveLaneMaskPHIRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPActiveLaneMaskPHISC; - } + VP_CLASSOF_IMPL(VPDef::VPActiveLaneMaskPHISC) + static inline bool classof(const VPHeaderPHIRecipe *D) { - return D->getVPDefID() == VPActiveLaneMaskPHISC; - } - static inline bool classof(const VPValue *V) { - return V->getVPValueID() == VPValue::VPVActiveLaneMaskPHISC; + return D->getVPDefID() == VPDef::VPActiveLaneMaskPHISC; } /// Generate the active lane mask phi of the vector loop. @@ -1913,25 +1830,12 @@ public: class VPWidenCanonicalIVRecipe : public VPRecipeBase, public VPValue { public: VPWidenCanonicalIVRecipe(VPCanonicalIVPHIRecipe *CanonicalIV) - : VPRecipeBase(VPWidenCanonicalIVSC, {CanonicalIV}), - VPValue(VPValue::VPVWidenCanonicalIVSC, nullptr, this) {} + : VPRecipeBase(VPDef::VPWidenCanonicalIVSC, {CanonicalIV}), + VPValue(this) {} ~VPWidenCanonicalIVRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. 
- static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPWidenCanonicalIVSC; - } - - /// Extra classof implementations to allow directly casting from VPUser -> - /// VPWidenCanonicalIVRecipe. - static inline bool classof(const VPUser *U) { - auto *R = dyn_cast<VPRecipeBase>(U); - return R && R->getVPDefID() == VPRecipeBase::VPWidenCanonicalIVSC; - } - static inline bool classof(const VPRecipeBase *R) { - return R->getVPDefID() == VPRecipeBase::VPWidenCanonicalIVSC; - } + VP_CLASSOF_IMPL(VPDef::VPWidenCanonicalIVSC) /// Generate a canonical vector induction variable of the vector loop, with /// start = {<Part*VF, Part*VF+1, ..., Part*VF+VF-1> for 0 <= Part < UF}, and @@ -1946,43 +1850,69 @@ public: /// Returns the scalar type of the induction. const Type *getScalarType() const { - return cast<VPCanonicalIVPHIRecipe>(getOperand(0)->getDef()) + return cast<VPCanonicalIVPHIRecipe>(getOperand(0)->getDefiningRecipe()) ->getScalarType(); } }; +/// A recipe for converting the canonical IV value to the corresponding value of +/// an IV with different start and step values, using Start + CanonicalIV * +/// Step. +class VPDerivedIVRecipe : public VPRecipeBase, public VPValue { + /// The type of the result value. It may be smaller than the type of the + /// induction and in this case it will get truncated to ResultTy. + Type *ResultTy; + + /// Induction descriptor for the induction the canonical IV is transformed to. + const InductionDescriptor &IndDesc; + +public: + VPDerivedIVRecipe(const InductionDescriptor &IndDesc, VPValue *Start, + VPCanonicalIVPHIRecipe *CanonicalIV, VPValue *Step, + Type *ResultTy) + : VPRecipeBase(VPDef::VPDerivedIVSC, {Start, CanonicalIV, Step}), + VPValue(this), ResultTy(ResultTy), IndDesc(IndDesc) {} + + ~VPDerivedIVRecipe() override = default; + + VP_CLASSOF_IMPL(VPDef::VPDerivedIVSC) + + /// Generate the transformed value of the induction at offset StartValue (1. + /// operand) + IV (2. operand) * StepValue (3, operand). + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + VPValue *getStartValue() const { return getOperand(0); } + VPValue *getCanonicalIV() const { return getOperand(1); } + VPValue *getStepValue() const { return getOperand(2); } + + /// Returns true if the recipe only uses the first lane of operand \p Op. + bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } +}; + /// A recipe for handling phi nodes of integer and floating-point inductions, /// producing their scalar values. class VPScalarIVStepsRecipe : public VPRecipeBase, public VPValue { - /// Scalar type to use for the generated values. - Type *Ty; - /// If not nullptr, truncate the generated values to TruncToTy. 
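// Hedged construction sketch for the new VPDerivedIVRecipe above; IndDesc,
// StartV, CanIV (the canonical IV phi), StepV and TruncTy are assumed to
// already exist in the plan.
auto *DerivedIV =
    new VPDerivedIVRecipe(IndDesc, StartV, CanIV, StepV, TruncTy);
// Per its doc-comment, execute() emits StartV + CanIV * StepV and truncates
// the result to TruncTy when that type is narrower than the IV type.
(void)DerivedIV;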
- Type *TruncToTy; const InductionDescriptor &IndDesc; public: - VPScalarIVStepsRecipe(Type *Ty, const InductionDescriptor &IndDesc, - VPValue *CanonicalIV, VPValue *Start, VPValue *Step, - Type *TruncToTy) - : VPRecipeBase(VPScalarIVStepsSC, {CanonicalIV, Start, Step}), - VPValue(nullptr, this), Ty(Ty), TruncToTy(TruncToTy), IndDesc(IndDesc) { - } + VPScalarIVStepsRecipe(const InductionDescriptor &IndDesc, VPValue *IV, + VPValue *Step) + : VPRecipeBase(VPDef::VPScalarIVStepsSC, {IV, Step}), VPValue(this), + IndDesc(IndDesc) {} ~VPScalarIVStepsRecipe() override = default; - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *D) { - return D->getVPDefID() == VPRecipeBase::VPScalarIVStepsSC; - } - /// Extra classof implementations to allow directly casting from VPUser -> - /// VPScalarIVStepsRecipe. - static inline bool classof(const VPUser *U) { - auto *R = dyn_cast<VPRecipeBase>(U); - return R && R->getVPDefID() == VPRecipeBase::VPScalarIVStepsSC; - } - static inline bool classof(const VPRecipeBase *R) { - return R->getVPDefID() == VPRecipeBase::VPScalarIVStepsSC; - } + VP_CLASSOF_IMPL(VPDef::VPScalarIVStepsSC) /// Generate the scalarized versions of the phi node as needed by their users. void execute(VPTransformState &State) override; @@ -1993,13 +1923,7 @@ public: VPSlotTracker &SlotTracker) const override; #endif - /// Returns true if the induction is canonical, i.e. starting at 0 and - /// incremented by UF * VF (= the original IV is incremented by 1). - bool isCanonical() const; - - VPCanonicalIVPHIRecipe *getCanonicalIV() const; - VPValue *getStartValue() const { return getOperand(1); } - VPValue *getStepValue() const { return getOperand(2); } + VPValue *getStepValue() const { return getOperand(1); } /// Returns true if the recipe only uses the first lane of operand \p Op. bool onlyFirstLaneUsed(const VPValue *Op) const override { @@ -2084,7 +2008,7 @@ public: /// The method which generates the output IR instructions that correspond to /// this VPBasicBlock, thereby "executing" the VPlan. - void execute(struct VPTransformState *State) override; + void execute(VPTransformState *State) override; /// Return the position of the first non-phi node recipe in the block. iterator getFirstNonPhi(); @@ -2187,12 +2111,6 @@ public: EntryBlock->setParent(this); } - // FIXME: DominatorTreeBase is doing 'A->getParent()->front()'. 'front' is a - // specific interface of llvm::Function, instead of using - // GraphTraints::getEntryNode. We should add a new template parameter to - // DominatorTreeBase representing the Graph type. - VPBlockBase &front() const { return *Entry; } - const VPBlockBase *getExiting() const { return Exiting; } VPBlockBase *getExiting() { return Exiting; } @@ -2217,7 +2135,7 @@ public: /// The method which generates the output IR instructions that correspond to /// this VPRegionBlock, thereby "executing" the VPlan. - void execute(struct VPTransformState *State) override; + void execute(VPTransformState *State) override; void dropAllReferences(VPValue *NewValue) override; @@ -2234,258 +2152,6 @@ public: #endif }; -//===----------------------------------------------------------------------===// -// GraphTraits specializations for VPlan Hierarchical Control-Flow Graphs // -//===----------------------------------------------------------------------===// - -// The following set of template specializations implement GraphTraits to treat -// any VPBlockBase as a node in a graph of VPBlockBases. 
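// Follow-on sketch for the slimmed-down VPScalarIVStepsRecipe above: it now
// takes just the (possibly derived) IV and the step, so a VPDerivedIVRecipe
// result plugs in directly. DerivedIV, IndDesc and StepV are assumed.
auto *Steps = new VPScalarIVStepsRecipe(IndDesc, DerivedIV, StepV);
VPValue *StepOp = Steps->getStepValue();   // operand 1, as declared above
(void)StepOp;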
It's important to note -// that VPBlockBase traits don't recurse into VPRegioBlocks, i.e., if the -// VPBlockBase is a VPRegionBlock, this specialization provides access to its -// successors/predecessors but not to the blocks inside the region. - -template <> struct GraphTraits<VPBlockBase *> { - using NodeRef = VPBlockBase *; - using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::iterator; - - static NodeRef getEntryNode(NodeRef N) { return N; } - - static inline ChildIteratorType child_begin(NodeRef N) { - return N->getSuccessors().begin(); - } - - static inline ChildIteratorType child_end(NodeRef N) { - return N->getSuccessors().end(); - } -}; - -template <> struct GraphTraits<const VPBlockBase *> { - using NodeRef = const VPBlockBase *; - using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::const_iterator; - - static NodeRef getEntryNode(NodeRef N) { return N; } - - static inline ChildIteratorType child_begin(NodeRef N) { - return N->getSuccessors().begin(); - } - - static inline ChildIteratorType child_end(NodeRef N) { - return N->getSuccessors().end(); - } -}; - -// Inverse order specialization for VPBasicBlocks. Predecessors are used instead -// of successors for the inverse traversal. -template <> struct GraphTraits<Inverse<VPBlockBase *>> { - using NodeRef = VPBlockBase *; - using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::iterator; - - static NodeRef getEntryNode(Inverse<NodeRef> B) { return B.Graph; } - - static inline ChildIteratorType child_begin(NodeRef N) { - return N->getPredecessors().begin(); - } - - static inline ChildIteratorType child_end(NodeRef N) { - return N->getPredecessors().end(); - } -}; - -// The following set of template specializations implement GraphTraits to -// treat VPRegionBlock as a graph and recurse inside its nodes. It's important -// to note that the blocks inside the VPRegionBlock are treated as VPBlockBases -// (i.e., no dyn_cast is performed, VPBlockBases specialization is used), so -// there won't be automatic recursion into other VPBlockBases that turn to be -// VPRegionBlocks. - -template <> -struct GraphTraits<VPRegionBlock *> : public GraphTraits<VPBlockBase *> { - using GraphRef = VPRegionBlock *; - using nodes_iterator = df_iterator<NodeRef>; - - static NodeRef getEntryNode(GraphRef N) { return N->getEntry(); } - - static nodes_iterator nodes_begin(GraphRef N) { - return nodes_iterator::begin(N->getEntry()); - } - - static nodes_iterator nodes_end(GraphRef N) { - // df_iterator::end() returns an empty iterator so the node used doesn't - // matter. - return nodes_iterator::end(N); - } -}; - -template <> -struct GraphTraits<const VPRegionBlock *> - : public GraphTraits<const VPBlockBase *> { - using GraphRef = const VPRegionBlock *; - using nodes_iterator = df_iterator<NodeRef>; - - static NodeRef getEntryNode(GraphRef N) { return N->getEntry(); } - - static nodes_iterator nodes_begin(GraphRef N) { - return nodes_iterator::begin(N->getEntry()); - } - - static nodes_iterator nodes_end(GraphRef N) { - // df_iterator::end() returns an empty iterator so the node used doesn't - // matter. 
- return nodes_iterator::end(N); - } -}; - -template <> -struct GraphTraits<Inverse<VPRegionBlock *>> - : public GraphTraits<Inverse<VPBlockBase *>> { - using GraphRef = VPRegionBlock *; - using nodes_iterator = df_iterator<NodeRef>; - - static NodeRef getEntryNode(Inverse<GraphRef> N) { - return N.Graph->getExiting(); - } - - static nodes_iterator nodes_begin(GraphRef N) { - return nodes_iterator::begin(N->getExiting()); - } - - static nodes_iterator nodes_end(GraphRef N) { - // df_iterator::end() returns an empty iterator so the node used doesn't - // matter. - return nodes_iterator::end(N); - } -}; - -/// Iterator to traverse all successors of a VPBlockBase node. This includes the -/// entry node of VPRegionBlocks. Exit blocks of a region implicitly have their -/// parent region's successors. This ensures all blocks in a region are visited -/// before any blocks in a successor region when doing a reverse post-order -// traversal of the graph. -template <typename BlockPtrTy> -class VPAllSuccessorsIterator - : public iterator_facade_base<VPAllSuccessorsIterator<BlockPtrTy>, - std::forward_iterator_tag, VPBlockBase> { - BlockPtrTy Block; - /// Index of the current successor. For VPBasicBlock nodes, this simply is the - /// index for the successor array. For VPRegionBlock, SuccessorIdx == 0 is - /// used for the region's entry block, and SuccessorIdx - 1 are the indices - /// for the successor array. - size_t SuccessorIdx; - - static BlockPtrTy getBlockWithSuccs(BlockPtrTy Current) { - while (Current && Current->getNumSuccessors() == 0) - Current = Current->getParent(); - return Current; - } - - /// Templated helper to dereference successor \p SuccIdx of \p Block. Used by - /// both the const and non-const operator* implementations. - template <typename T1> static T1 deref(T1 Block, unsigned SuccIdx) { - if (auto *R = dyn_cast<VPRegionBlock>(Block)) { - if (SuccIdx == 0) - return R->getEntry(); - SuccIdx--; - } - - // For exit blocks, use the next parent region with successors. - return getBlockWithSuccs(Block)->getSuccessors()[SuccIdx]; - } - -public: - VPAllSuccessorsIterator(BlockPtrTy Block, size_t Idx = 0) - : Block(Block), SuccessorIdx(Idx) {} - VPAllSuccessorsIterator(const VPAllSuccessorsIterator &Other) - : Block(Other.Block), SuccessorIdx(Other.SuccessorIdx) {} - - VPAllSuccessorsIterator &operator=(const VPAllSuccessorsIterator &R) { - Block = R.Block; - SuccessorIdx = R.SuccessorIdx; - return *this; - } - - static VPAllSuccessorsIterator end(BlockPtrTy Block) { - BlockPtrTy ParentWithSuccs = getBlockWithSuccs(Block); - unsigned NumSuccessors = ParentWithSuccs - ? ParentWithSuccs->getNumSuccessors() - : Block->getNumSuccessors(); - - if (auto *R = dyn_cast<VPRegionBlock>(Block)) - return {R, NumSuccessors + 1}; - return {Block, NumSuccessors}; - } - - bool operator==(const VPAllSuccessorsIterator &R) const { - return Block == R.Block && SuccessorIdx == R.SuccessorIdx; - } - - const VPBlockBase *operator*() const { return deref(Block, SuccessorIdx); } - - BlockPtrTy operator*() { return deref(Block, SuccessorIdx); } - - VPAllSuccessorsIterator &operator++() { - SuccessorIdx++; - return *this; - } - - VPAllSuccessorsIterator operator++(int X) { - VPAllSuccessorsIterator Orig = *this; - SuccessorIdx++; - return Orig; - } -}; - -/// Helper for GraphTraits specialization that traverses through VPRegionBlocks. 
-template <typename BlockTy> class VPBlockRecursiveTraversalWrapper { - BlockTy Entry; - -public: - VPBlockRecursiveTraversalWrapper(BlockTy Entry) : Entry(Entry) {} - BlockTy getEntry() { return Entry; } -}; - -/// GraphTraits specialization to recursively traverse VPBlockBase nodes, -/// including traversing through VPRegionBlocks. Exit blocks of a region -/// implicitly have their parent region's successors. This ensures all blocks in -/// a region are visited before any blocks in a successor region when doing a -/// reverse post-order traversal of the graph. -template <> -struct GraphTraits<VPBlockRecursiveTraversalWrapper<VPBlockBase *>> { - using NodeRef = VPBlockBase *; - using ChildIteratorType = VPAllSuccessorsIterator<VPBlockBase *>; - - static NodeRef - getEntryNode(VPBlockRecursiveTraversalWrapper<VPBlockBase *> N) { - return N.getEntry(); - } - - static inline ChildIteratorType child_begin(NodeRef N) { - return ChildIteratorType(N); - } - - static inline ChildIteratorType child_end(NodeRef N) { - return ChildIteratorType::end(N); - } -}; - -template <> -struct GraphTraits<VPBlockRecursiveTraversalWrapper<const VPBlockBase *>> { - using NodeRef = const VPBlockBase *; - using ChildIteratorType = VPAllSuccessorsIterator<const VPBlockBase *>; - - static NodeRef - getEntryNode(VPBlockRecursiveTraversalWrapper<const VPBlockBase *> N) { - return N.getEntry(); - } - - static inline ChildIteratorType child_begin(NodeRef N) { - return ChildIteratorType(N); - } - - static inline ChildIteratorType child_end(NodeRef N) { - return ChildIteratorType::end(N); - } -}; - /// VPlan models a candidate for vectorization, encoding various decisions take /// to produce efficient output IR, including which branches, basic-blocks and /// output IR instructions to generate, and their cost. VPlan holds a @@ -2501,6 +2167,10 @@ class VPlan { /// Holds the VFs applicable to this VPlan. SmallSetVector<ElementCount, 2> VFs; + /// Holds the UFs applicable to this VPlan. If empty, the VPlan is valid for + /// any UF. + SmallSetVector<unsigned, 2> UFs; + /// Holds the name of the VPlan, for printing. std::string Name; @@ -2540,25 +2210,7 @@ public: Entry->setPlan(this); } - ~VPlan() { - clearLiveOuts(); - - if (Entry) { - VPValue DummyValue; - for (VPBlockBase *Block : depth_first(Entry)) - Block->dropAllReferences(&DummyValue); - - VPBlockBase::deleteCFG(Entry); - } - for (VPValue *VPV : VPValuesToFree) - delete VPV; - if (TripCount) - delete TripCount; - if (BackedgeTakenCount) - delete BackedgeTakenCount; - for (auto &P : VPExternalDefs) - delete P.second; - } + ~VPlan(); /// Prepare the plan for execution, setting up the required live-in values. void prepareToExecute(Value *TripCount, Value *VectorTripCount, @@ -2566,7 +2218,7 @@ public: bool IsEpilogueVectorization); /// Generate the IR code for this VPlan. 
- void execute(struct VPTransformState *State); + void execute(VPTransformState *State); VPBlockBase *getEntry() { return Entry; } const VPBlockBase *getEntry() const { return Entry; } @@ -2600,9 +2252,26 @@ public: void addVF(ElementCount VF) { VFs.insert(VF); } + void setVF(ElementCount VF) { + assert(hasVF(VF) && "Cannot set VF not already in plan"); + VFs.clear(); + VFs.insert(VF); + } + bool hasVF(ElementCount VF) { return VFs.count(VF); } - const std::string &getName() const { return Name; } + bool hasScalarVFOnly() const { return VFs.size() == 1 && VFs[0].isScalar(); } + + bool hasUF(unsigned UF) const { return UFs.empty() || UFs.contains(UF); } + + void setUF(unsigned UF) { + assert(hasUF(UF) && "Cannot set the UF not already in plan"); + UFs.clear(); + UFs.insert(UF); + } + + /// Return a string with the name of the plan and the applicable VFs and UFs. + std::string getName() const; void setName(const Twine &newName) { Name = newName.str(); } @@ -2680,12 +2349,6 @@ public: return map_range(Operands, Fn); } - /// Returns true if \p VPV is uniform after vectorization. - bool isUniformAfterVectorization(VPValue *VPV) const { - auto RepR = dyn_cast_or_null<VPReplicateRecipe>(VPV->getDef()); - return !VPV->getDef() || (RepR && RepR->isUniform()); - } - /// Returns the VPRegionBlock of the vector loop. VPRegionBlock *getVectorLoopRegion() { return cast<VPRegionBlock>(getEntry()->getSingleSuccessor()); @@ -2869,39 +2532,13 @@ public: To->removePredecessor(From); } - /// Try to merge \p Block into its single predecessor, if \p Block is a - /// VPBasicBlock and its predecessor has a single successor. Returns a pointer - /// to the predecessor \p Block was merged into or nullptr otherwise. - static VPBasicBlock *tryToMergeBlockIntoPredecessor(VPBlockBase *Block) { - auto *VPBB = dyn_cast<VPBasicBlock>(Block); - auto *PredVPBB = - dyn_cast_or_null<VPBasicBlock>(Block->getSinglePredecessor()); - if (!VPBB || !PredVPBB || PredVPBB->getNumSuccessors() != 1) - return nullptr; - - for (VPRecipeBase &R : make_early_inc_range(*VPBB)) - R.moveBefore(*PredVPBB, PredVPBB->end()); - VPBlockUtils::disconnectBlocks(PredVPBB, VPBB); - auto *ParentRegion = cast<VPRegionBlock>(Block->getParent()); - if (ParentRegion->getExiting() == Block) - ParentRegion->setExiting(PredVPBB); - SmallVector<VPBlockBase *> Successors(Block->successors()); - for (auto *Succ : Successors) { - VPBlockUtils::disconnectBlocks(Block, Succ); - VPBlockUtils::connectBlocks(PredVPBB, Succ); - } - delete Block; - return PredVPBB; - } - /// Return an iterator range over \p Range which only includes \p BlockTy /// blocks. The accesses are casted to \p BlockTy. template <typename BlockTy, typename T> static auto blocksOnly(const T &Range) { // Create BaseTy with correct const-ness based on BlockTy. - using BaseTy = - typename std::conditional<std::is_const<BlockTy>::value, - const VPBlockBase, VPBlockBase>::type; + using BaseTy = std::conditional_t<std::is_const<BlockTy>::value, + const VPBlockBase, VPBlockBase>; // We need to first create an iterator range over (const) BlocktTy & instead // of (const) BlockTy * for filter_range to work properly. @@ -3061,6 +2698,19 @@ bool onlyFirstLaneUsed(VPValue *Def); /// create a new one. VPValue *getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, ScalarEvolution &SE); + +/// Returns true if \p VPV is uniform after vectorization. 
+inline bool isUniformAfterVectorization(VPValue *VPV) { + // A value defined outside the vector region must be uniform after + // vectorization inside a vector region. + if (VPV->isDefinedOutsideVectorRegions()) + return true; + VPRecipeBase *Def = VPV->getDefiningRecipe(); + assert(Def && "Must have definition for value defined inside vector region"); + if (auto Rep = dyn_cast<VPReplicateRecipe>(Def)) + return Rep->isUniform(); + return false; +} } // end namespace vputils } // end namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/VPlanCFG.h b/llvm/lib/Transforms/Vectorize/VPlanCFG.h new file mode 100644 index 000000000000..f790f7e73e11 --- /dev/null +++ b/llvm/lib/Transforms/Vectorize/VPlanCFG.h @@ -0,0 +1,310 @@ +//===- VPlanCFG.h - GraphTraits for VP blocks -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// Specializations of GraphTraits that allow VPBlockBase graphs to be +/// treated as proper graphs for generic algorithms; +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANCFG_H +#define LLVM_TRANSFORMS_VECTORIZE_VPLANCFG_H + +#include "VPlan.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/SmallVector.h" + +namespace llvm { + +//===----------------------------------------------------------------------===// +// GraphTraits specializations for VPlan Hierarchical Control-Flow Graphs // +//===----------------------------------------------------------------------===// + +/// Iterator to traverse all successors of a VPBlockBase node. This includes the +/// entry node of VPRegionBlocks. Exit blocks of a region implicitly have their +/// parent region's successors. This ensures all blocks in a region are visited +/// before any blocks in a successor region when doing a reverse post-order +// traversal of the graph. Region blocks themselves traverse only their entries +// directly and not their successors. Those will be traversed when a region's +// exiting block is traversed +template <typename BlockPtrTy> +class VPAllSuccessorsIterator + : public iterator_facade_base<VPAllSuccessorsIterator<BlockPtrTy>, + std::bidirectional_iterator_tag, + VPBlockBase> { + BlockPtrTy Block; + /// Index of the current successor. For VPBasicBlock nodes, this simply is the + /// index for the successor array. For VPRegionBlock, SuccessorIdx == 0 is + /// used for the region's entry block, and SuccessorIdx - 1 are the indices + /// for the successor array. + size_t SuccessorIdx; + + static BlockPtrTy getBlockWithSuccs(BlockPtrTy Current) { + while (Current && Current->getNumSuccessors() == 0) + Current = Current->getParent(); + return Current; + } + + /// Templated helper to dereference successor \p SuccIdx of \p Block. Used by + /// both the const and non-const operator* implementations. + template <typename T1> static T1 deref(T1 Block, unsigned SuccIdx) { + if (auto *R = dyn_cast<VPRegionBlock>(Block)) { + assert(SuccIdx == 0); + return R->getEntry(); + } + + // For exit blocks, use the next parent region with successors. + return getBlockWithSuccs(Block)->getSuccessors()[SuccIdx]; + } + +public: + /// Used by iterator_facade_base with bidirectional_iterator_tag. 
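// Usage sketch for the vputils helper above, mirroring how live-out phis are
// fixed up later in this patch: a uniform exit value only needs lane 0 of the
// last unrolled part. State and ExitValue are assumed.
VPLane Lane = VPLane::getLastLaneForVF(State.VF);
if (vputils::isUniformAfterVectorization(ExitValue))
  Lane = VPLane::getFirstLane();
Value *Incoming = State.get(ExitValue, VPIteration(State.UF - 1, Lane));
(void)Incoming;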
+ using reference = BlockPtrTy; + + VPAllSuccessorsIterator(BlockPtrTy Block, size_t Idx = 0) + : Block(Block), SuccessorIdx(Idx) {} + VPAllSuccessorsIterator(const VPAllSuccessorsIterator &Other) + : Block(Other.Block), SuccessorIdx(Other.SuccessorIdx) {} + + VPAllSuccessorsIterator &operator=(const VPAllSuccessorsIterator &R) { + Block = R.Block; + SuccessorIdx = R.SuccessorIdx; + return *this; + } + + static VPAllSuccessorsIterator end(BlockPtrTy Block) { + if (auto *R = dyn_cast<VPRegionBlock>(Block)) { + // Traverse through the region's entry node. + return {R, 1}; + } + BlockPtrTy ParentWithSuccs = getBlockWithSuccs(Block); + unsigned NumSuccessors = + ParentWithSuccs ? ParentWithSuccs->getNumSuccessors() : 0; + return {Block, NumSuccessors}; + } + + bool operator==(const VPAllSuccessorsIterator &R) const { + return Block == R.Block && SuccessorIdx == R.SuccessorIdx; + } + + const VPBlockBase *operator*() const { return deref(Block, SuccessorIdx); } + + BlockPtrTy operator*() { return deref(Block, SuccessorIdx); } + + VPAllSuccessorsIterator &operator++() { + SuccessorIdx++; + return *this; + } + + VPAllSuccessorsIterator &operator--() { + SuccessorIdx--; + return *this; + } + + VPAllSuccessorsIterator operator++(int X) { + VPAllSuccessorsIterator Orig = *this; + SuccessorIdx++; + return Orig; + } +}; + +/// Helper for GraphTraits specialization that traverses through VPRegionBlocks. +template <typename BlockTy> class VPBlockDeepTraversalWrapper { + BlockTy Entry; + +public: + VPBlockDeepTraversalWrapper(BlockTy Entry) : Entry(Entry) {} + BlockTy getEntry() { return Entry; } +}; + +/// GraphTraits specialization to recursively traverse VPBlockBase nodes, +/// including traversing through VPRegionBlocks. Exit blocks of a region +/// implicitly have their parent region's successors. This ensures all blocks in +/// a region are visited before any blocks in a successor region when doing a +/// reverse post-order traversal of the graph. +template <> struct GraphTraits<VPBlockDeepTraversalWrapper<VPBlockBase *>> { + using NodeRef = VPBlockBase *; + using ChildIteratorType = VPAllSuccessorsIterator<VPBlockBase *>; + + static NodeRef getEntryNode(VPBlockDeepTraversalWrapper<VPBlockBase *> N) { + return N.getEntry(); + } + + static inline ChildIteratorType child_begin(NodeRef N) { + return ChildIteratorType(N); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return ChildIteratorType::end(N); + } +}; + +template <> +struct GraphTraits<VPBlockDeepTraversalWrapper<const VPBlockBase *>> { + using NodeRef = const VPBlockBase *; + using ChildIteratorType = VPAllSuccessorsIterator<const VPBlockBase *>; + + static NodeRef + getEntryNode(VPBlockDeepTraversalWrapper<const VPBlockBase *> N) { + return N.getEntry(); + } + + static inline ChildIteratorType child_begin(NodeRef N) { + return ChildIteratorType(N); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return ChildIteratorType::end(N); + } +}; + +/// Helper for GraphTraits specialization that does not traverses through +/// VPRegionBlocks. 
+template <typename BlockTy> class VPBlockShallowTraversalWrapper { + BlockTy Entry; + +public: + VPBlockShallowTraversalWrapper(BlockTy Entry) : Entry(Entry) {} + BlockTy getEntry() { return Entry; } +}; + +template <> struct GraphTraits<VPBlockShallowTraversalWrapper<VPBlockBase *>> { + using NodeRef = VPBlockBase *; + using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::iterator; + + static NodeRef getEntryNode(VPBlockShallowTraversalWrapper<VPBlockBase *> N) { + return N.getEntry(); + } + + static inline ChildIteratorType child_begin(NodeRef N) { + return N->getSuccessors().begin(); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return N->getSuccessors().end(); + } +}; + +template <> +struct GraphTraits<VPBlockShallowTraversalWrapper<const VPBlockBase *>> { + using NodeRef = const VPBlockBase *; + using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::const_iterator; + + static NodeRef + getEntryNode(VPBlockShallowTraversalWrapper<const VPBlockBase *> N) { + return N.getEntry(); + } + + static inline ChildIteratorType child_begin(NodeRef N) { + return N->getSuccessors().begin(); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return N->getSuccessors().end(); + } +}; + +/// Returns an iterator range to traverse the graph starting at \p G in +/// depth-first order. The iterator won't traverse through region blocks. +inline iterator_range< + df_iterator<VPBlockShallowTraversalWrapper<VPBlockBase *>>> +vp_depth_first_shallow(VPBlockBase *G) { + return depth_first(VPBlockShallowTraversalWrapper<VPBlockBase *>(G)); +} +inline iterator_range< + df_iterator<VPBlockShallowTraversalWrapper<const VPBlockBase *>>> +vp_depth_first_shallow(const VPBlockBase *G) { + return depth_first(VPBlockShallowTraversalWrapper<const VPBlockBase *>(G)); +} + +/// Returns an iterator range to traverse the graph starting at \p G in +/// depth-first order while traversing through region blocks. +inline iterator_range<df_iterator<VPBlockDeepTraversalWrapper<VPBlockBase *>>> +vp_depth_first_deep(VPBlockBase *G) { + return depth_first(VPBlockDeepTraversalWrapper<VPBlockBase *>(G)); +} +inline iterator_range< + df_iterator<VPBlockDeepTraversalWrapper<const VPBlockBase *>>> +vp_depth_first_deep(const VPBlockBase *G) { + return depth_first(VPBlockDeepTraversalWrapper<const VPBlockBase *>(G)); +} + +// The following set of template specializations implement GraphTraits to treat +// any VPBlockBase as a node in a graph of VPBlockBases. It's important to note +// that VPBlockBase traits don't recurse into VPRegioBlocks, i.e., if the +// VPBlockBase is a VPRegionBlock, this specialization provides access to its +// successors/predecessors but not to the blocks inside the region. 
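// Contrast sketch for the two traversal helpers defined above; Plan is
// assumed. Shallow iteration treats each VPRegionBlock as a single opaque
// node, while the deep variant also walks the blocks inside regions.
unsigned TopLevel = 0, All = 0;
for (VPBlockBase *VPB : vp_depth_first_shallow(Plan.getEntry())) {
  (void)VPB;
  ++TopLevel;
}
for (VPBlockBase *VPB : vp_depth_first_deep(Plan.getEntry())) {
  (void)VPB;
  ++All;
}
assert(All >= TopLevel && "deep traversal visits at least as many blocks");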
+ +template <> struct GraphTraits<VPBlockBase *> { + using NodeRef = VPBlockBase *; + using ChildIteratorType = VPAllSuccessorsIterator<VPBlockBase *>; + + static NodeRef getEntryNode(NodeRef N) { return N; } + + static inline ChildIteratorType child_begin(NodeRef N) { + return ChildIteratorType(N); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return ChildIteratorType::end(N); + } +}; + +template <> struct GraphTraits<const VPBlockBase *> { + using NodeRef = const VPBlockBase *; + using ChildIteratorType = VPAllSuccessorsIterator<const VPBlockBase *>; + + static NodeRef getEntryNode(NodeRef N) { return N; } + + static inline ChildIteratorType child_begin(NodeRef N) { + return ChildIteratorType(N); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return ChildIteratorType::end(N); + } +}; + +/// Inverse graph traits are not implemented yet. +/// TODO: Implement a version of VPBlockNonRecursiveTraversalWrapper to traverse +/// predecessors recursively through regions. +template <> struct GraphTraits<Inverse<VPBlockBase *>> { + using NodeRef = VPBlockBase *; + using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::iterator; + + static NodeRef getEntryNode(Inverse<NodeRef> B) { + llvm_unreachable("not implemented"); + } + + static inline ChildIteratorType child_begin(NodeRef N) { + llvm_unreachable("not implemented"); + } + + static inline ChildIteratorType child_end(NodeRef N) { + llvm_unreachable("not implemented"); + } +}; + +template <> struct GraphTraits<VPlan *> { + using GraphRef = VPlan *; + using NodeRef = VPBlockBase *; + using nodes_iterator = df_iterator<NodeRef>; + + static NodeRef getEntryNode(GraphRef N) { return N->getEntry(); } + + static nodes_iterator nodes_begin(GraphRef N) { + return nodes_iterator::begin(N->getEntry()); + } + + static nodes_iterator nodes_end(GraphRef N) { + // df_iterator::end() returns an empty iterator so the node used doesn't + // matter. + return nodes_iterator::end(N->getEntry()); + } +}; + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANCFG_H diff --git a/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h b/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h index a42ebc9ee955..fc4cf709a371 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h +++ b/llvm/lib/Transforms/Vectorize/VPlanDominatorTree.h @@ -16,11 +16,23 @@ #define LLVM_TRANSFORMS_VECTORIZE_VPLANDOMINATORTREE_H #include "VPlan.h" +#include "VPlanCFG.h" #include "llvm/ADT/GraphTraits.h" #include "llvm/IR/Dominators.h" +#include "llvm/Support/GenericDomTree.h" namespace llvm { +template <> struct DomTreeNodeTraits<VPBlockBase> { + using NodeType = VPBlockBase; + using NodePtr = VPBlockBase *; + using ParentPtr = VPlan *; + + static NodePtr getEntryNode(ParentPtr Parent) { return Parent->getEntry(); } + static ParentPtr getParent(NodePtr B) { return B->getPlan(); } +}; + +/// /// Template specialization of the standard LLVM dominator tree utility for /// VPBlockBases. 
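The DomTreeNodeTraits specialization above is what lets the generic domtree machinery be instantiated over VPlan blocks (the VPDominatorTree alias follows immediately below). A hedged usage sketch, assuming `Plan` is a fully connected VPlan and `A`/`B` are blocks reachable from its entry:

#include "VPlanDominatorTree.h"

static bool blockDominates(llvm::VPlan &Plan, llvm::VPBlockBase *A,
                           llvm::VPBlockBase *B) {
  llvm::VPDominatorTree VPDT;
  VPDT.recalculate(Plan);      // ParentPtr is VPlan*, per the traits above
  return VPDT.dominates(A, B); // generic DomTreeBase query
}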
using VPDominatorTree = DomTreeBase<VPBlockBase>; diff --git a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp index 84b0dac862b6..952ce72e36c1 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -243,7 +243,7 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB, for (Value *Op : Inst->operands()) VPOperands.push_back(getOrCreateVPOperand(Op)); - // Build VPInstruction for any arbitraty Instruction without specific + // Build VPInstruction for any arbitrary Instruction without specific // representation in VPlan. NewVPV = cast<VPInstruction>( VPIRBuilder.createNaryOp(Inst->getOpcode(), VPOperands, Inst)); @@ -391,7 +391,7 @@ void VPlanHCFGBuilder::buildHierarchicalCFG() { Verifier.verifyHierarchicalCFG(TopRegion); // Compute plain CFG dom tree for VPLInfo. - VPDomTree.recalculate(*TopRegion); + VPDomTree.recalculate(Plan); LLVM_DEBUG(dbgs() << "Dominator Tree after building the plain CFG.\n"; VPDomTree.print(dbgs())); } diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index cb7507264667..4e9be35001ad 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -49,6 +49,7 @@ bool VPRecipeBase::mayWriteToMemory() const { return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) ->mayWriteToMemory(); case VPBranchOnMaskSC: + case VPScalarIVStepsSC: return false; case VPWidenIntOrFpInductionSC: case VPWidenCanonicalIVSC: @@ -80,6 +81,7 @@ bool VPRecipeBase::mayReadFromMemory() const { return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) ->mayReadFromMemory(); case VPBranchOnMaskSC: + case VPScalarIVStepsSC: return false; case VPWidenIntOrFpInductionSC: case VPWidenCanonicalIVSC: @@ -103,6 +105,9 @@ bool VPRecipeBase::mayReadFromMemory() const { bool VPRecipeBase::mayHaveSideEffects() const { switch (getVPDefID()) { + case VPDerivedIVSC: + case VPPredInstPHISC: + return false; case VPWidenIntOrFpInductionSC: case VPWidenPointerInductionSC: case VPWidenCanonicalIVSC: @@ -132,7 +137,7 @@ bool VPRecipeBase::mayHaveSideEffects() const { void VPLiveOut::fixPhi(VPlan &Plan, VPTransformState &State) { auto Lane = VPLane::getLastLaneForVF(State.VF); VPValue *ExitValue = getOperand(0); - if (Plan.isUniformAfterVectorization(ExitValue)) + if (vputils::isUniformAfterVectorization(ExitValue)) Lane = VPLane::getFirstLane(); Phi->addIncoming(State.get(ExitValue, VPIteration(State.UF - 1, Lane)), State.Builder.GetInsertBlock()); @@ -432,6 +437,64 @@ void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) { FMF = FMFNew; } +void VPWidenCallRecipe::execute(VPTransformState &State) { + auto &CI = *cast<CallInst>(getUnderlyingInstr()); + assert(!isa<DbgInfoIntrinsic>(CI) && + "DbgInfoIntrinsic should have been dropped during VPlan construction"); + State.setDebugLocFromInst(&CI); + + SmallVector<Type *, 4> Tys; + for (Value *ArgOperand : CI.args()) + Tys.push_back( + ToVectorTy(ArgOperand->getType(), State.VF.getKnownMinValue())); + + for (unsigned Part = 0; Part < State.UF; ++Part) { + SmallVector<Type *, 2> TysForDecl = {CI.getType()}; + SmallVector<Value *, 4> Args; + for (const auto &I : enumerate(operands())) { + // Some intrinsics have a scalar argument - don't replace it with a + // vector. 
+ Value *Arg; + if (VectorIntrinsicID == Intrinsic::not_intrinsic || + !isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index())) + Arg = State.get(I.value(), Part); + else + Arg = State.get(I.value(), VPIteration(0, 0)); + if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index())) + TysForDecl.push_back(Arg->getType()); + Args.push_back(Arg); + } + + Function *VectorF; + if (VectorIntrinsicID != Intrinsic::not_intrinsic) { + // Use vector version of the intrinsic. + if (State.VF.isVector()) + TysForDecl[0] = + VectorType::get(CI.getType()->getScalarType(), State.VF); + Module *M = State.Builder.GetInsertBlock()->getModule(); + VectorF = Intrinsic::getDeclaration(M, VectorIntrinsicID, TysForDecl); + assert(VectorF && "Can't retrieve vector intrinsic."); + } else { + // Use vector version of the function call. + const VFShape Shape = VFShape::get(CI, State.VF, false /*HasGlobalPred*/); +#ifndef NDEBUG + assert(VFDatabase(CI).getVectorizedFunction(Shape) != nullptr && + "Can't create vector function."); +#endif + VectorF = VFDatabase(CI).getVectorizedFunction(Shape); + } + SmallVector<OperandBundleDef, 1> OpBundles; + CI.getOperandBundlesAsDefs(OpBundles); + CallInst *V = State.Builder.CreateCall(VectorF, Args, OpBundles); + + if (isa<FPMathOperator>(V)) + V->copyFastMathFlags(&CI); + + State.set(this, V, Part); + State.addMetadata(V, &CI); + } +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { @@ -448,6 +511,11 @@ void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent, O << "call @" << CI->getCalledFunction()->getName() << "("; printOperands(O, SlotTracker); O << ")"; + + if (VectorIntrinsicID) + O << " (using vector intrinsic)"; + else + O << " (using library function)"; } void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent, @@ -618,7 +686,10 @@ void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "WIDEN "; printAsOperand(O, SlotTracker); - O << " = " << getUnderlyingInstr()->getOpcodeName() << " "; + const Instruction *UI = getUnderlyingInstr(); + O << " = " << UI->getOpcodeName() << " "; + if (auto *Cmp = dyn_cast<CmpInst>(UI)) + O << CmpInst::getPredicateName(Cmp->getPredicate()) << " "; printOperands(O, SlotTracker); } @@ -644,22 +715,22 @@ bool VPWidenIntOrFpInductionRecipe::isCanonical() const { return StartC && StartC->isZero() && StepC && StepC->isOne(); } -VPCanonicalIVPHIRecipe *VPScalarIVStepsRecipe::getCanonicalIV() const { - return cast<VPCanonicalIVPHIRecipe>(getOperand(0)); -} +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPDerivedIVRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent; + printAsOperand(O, SlotTracker); + O << Indent << "= DERIVED-IV "; + getStartValue()->printAsOperand(O, SlotTracker); + O << " + "; + getCanonicalIV()->printAsOperand(O, SlotTracker); + O << " * "; + getStepValue()->printAsOperand(O, SlotTracker); -bool VPScalarIVStepsRecipe::isCanonical() const { - auto *CanIV = getCanonicalIV(); - // The start value of the steps-recipe must match the start value of the - // canonical induction and it must step by 1. 
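The DERIVED-IV printing above spells out the value the recipe computes per iteration: start + canonical IV * step, optionally truncated to the result type. A scalar model of that arithmetic, with purely illustrative constants:

#include <cstdint>

// Models the arithmetic only, not the recipe itself.
static int64_t derivedIVAt(int64_t Start, int64_t Step, int64_t CanonicalIV) {
  return Start + CanonicalIV * Step;
}
// e.g. Start = 10, Step = 3: canonical IV 0, 1, 2, ... gives 10, 13, 16, ...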
- if (CanIV->getStartValue() != getStartValue()) - return false; - auto *StepVPV = getStepValue(); - if (StepVPV->getDef()) - return false; - auto *StepC = dyn_cast_or_null<ConstantInt>(StepVPV->getLiveInIRValue()); - return StepC && StepC->isOne(); + if (IndDesc.getStep()->getType() != ResultTy) + O << " (truncated to " << *ResultTy << ")"; } +#endif #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPScalarIVStepsRecipe::print(raw_ostream &O, const Twine &Indent, @@ -982,11 +1053,25 @@ void VPCanonicalIVPHIRecipe::print(raw_ostream &O, const Twine &Indent, } #endif +bool VPCanonicalIVPHIRecipe::isCanonical(const InductionDescriptor &ID, + Type *Ty) const { + if (Ty != getScalarType()) + return false; + // The start value of ID must match the start value of this canonical + // induction. + if (getStartValue()->getLiveInIRValue() != ID.getStartValue()) + return false; + + ConstantInt *Step = ID.getConstIntStepValue(); + // ID must also be incremented by one. IK_IntInduction always increment the + // induction by Step, but the binary op may not be set. + return ID.getKind() == InductionDescriptor::IK_IntInduction && Step && + Step->isOne(); +} + bool VPWidenPointerInductionRecipe::onlyScalarsGenerated(ElementCount VF) { - bool IsUniform = vputils::onlyFirstLaneUsed(this); - return all_of(users(), - [&](const VPUser *U) { return U->usesScalars(this); }) && - (IsUniform || !VF.isScalable()); + return IsScalarAfterVectorization && + (!VF.isScalable() || vputils::onlyFirstLaneUsed(this)); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) diff --git a/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp b/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp index 3a7e77fd9efd..fbcadba33e67 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp @@ -29,6 +29,7 @@ #include "llvm/Support/raw_ostream.h" #include <algorithm> #include <cassert> +#include <optional> #include <utility> using namespace llvm; @@ -187,12 +188,12 @@ getOperands(ArrayRef<VPValue *> Values) { } /// Returns the opcode of Values or ~0 if they do not all agree. 
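The helper directly below now returns std::optional<unsigned> instead of llvm::Optional, so callers test before dereferencing; the "~0" mentioned in the comment only appears when the optional is empty. A small caller-side sketch of that pattern:

#include <optional>

static unsigned opcodeOrSentinel(std::optional<unsigned> Opc) {
  return Opc.value_or(~0u); // "~0 if they do not all agree"
}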
-static Optional<unsigned> getOpcode(ArrayRef<VPValue *> Values) { +static std::optional<unsigned> getOpcode(ArrayRef<VPValue *> Values) { unsigned Opcode = cast<VPInstruction>(Values[0])->getOpcode(); if (any_of(Values, [Opcode](VPValue *V) { return cast<VPInstruction>(V)->getOpcode() != Opcode; })) - return None; + return std::nullopt; return {Opcode}; } @@ -343,7 +344,7 @@ SmallVector<VPlanSlp::MultiNodeOpTy, 4> VPlanSlp::reorderMultiNodeOps() { #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPlanSlp::dumpBundle(ArrayRef<VPValue *> Values) { dbgs() << " Ops: "; - for (auto Op : Values) { + for (auto *Op : Values) { if (auto *VPInstr = cast_or_null<VPInstruction>(Op)) if (auto *Instr = VPInstr->getUnderlyingInstr()) { dbgs() << *Instr << " | "; diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index cca484e13bf1..cbf111b00e3d 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -12,9 +12,12 @@ //===----------------------------------------------------------------------===// #include "VPlanTransforms.h" +#include "VPlanCFG.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/IVDescriptors.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Intrinsics.h" using namespace llvm; @@ -22,10 +25,11 @@ void VPlanTransforms::VPInstructionsToVPRecipes( Loop *OrigLoop, VPlanPtr &Plan, function_ref<const InductionDescriptor *(PHINode *)> GetIntOrFpInductionDescriptor, - SmallPtrSetImpl<Instruction *> &DeadInstructions, ScalarEvolution &SE) { + SmallPtrSetImpl<Instruction *> &DeadInstructions, ScalarEvolution &SE, + const TargetLibraryInfo &TLI) { - ReversePostOrderTraversal<VPBlockRecursiveTraversalWrapper<VPBlockBase *>> - RPOT(Plan->getEntry()); + ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( + Plan->getEntry()); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) { VPRecipeBase *Term = VPBB->getTerminator(); auto EndIter = Term ? Term->getIterator() : VPBB->end(); @@ -74,7 +78,8 @@ void VPlanTransforms::VPInstructionsToVPRecipes( GEP, Plan->mapToVPValues(GEP->operands()), OrigLoop); } else if (CallInst *CI = dyn_cast<CallInst>(Inst)) { NewRecipe = - new VPWidenCallRecipe(*CI, Plan->mapToVPValues(CI->args())); + new VPWidenCallRecipe(*CI, Plan->mapToVPValues(CI->args()), + getVectorIntrinsicIDForCall(CI, &TLI)); } else if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) { bool InvariantCond = SE.isLoopInvariant(SE.getSCEV(SI->getOperand(0)), OrigLoop); @@ -102,40 +107,46 @@ void VPlanTransforms::VPInstructionsToVPRecipes( } bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) { - auto Iter = depth_first( - VPBlockRecursiveTraversalWrapper<VPBlockBase *>(Plan.getEntry())); + auto Iter = vp_depth_first_deep(Plan.getEntry()); bool Changed = false; - // First, collect the operands of all predicated replicate recipes as seeds - // for sinking. - SetVector<std::pair<VPBasicBlock *, VPValue *>> WorkList; - for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { + // First, collect the operands of all recipes in replicate blocks as seeds for + // sinking. 
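The rewritten sinking loop below walks its worklist by index rather than popping entries, because new (block, recipe) pairs are appended while earlier ones are still being processed, and SetVector silently drops duplicates. A generic sketch of that pattern, where `expand` is a hypothetical stand-in for "collect a recipe's operand definitions":

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical expansion step, purely for illustration.
static llvm::SmallVector<int, 2> expand(int N) { return {N + 1, N + 2}; }

static unsigned countReachable(int Seed, int Limit) {
  llvm::SetVector<int> Worklist;
  Worklist.insert(Seed);
  // Index-based walk: the worklist may grow while it is processed, and
  // SetVector::insert is a no-op for items that were already enqueued.
  for (unsigned I = 0; I != Worklist.size(); ++I)
    for (int Next : expand(Worklist[I]))
      if (Next < Limit)
        Worklist.insert(Next);
  return Worklist.size();
}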
+ SetVector<std::pair<VPBasicBlock *, VPRecipeBase *>> WorkList; + for (VPRegionBlock *VPR : VPBlockUtils::blocksOnly<VPRegionBlock>(Iter)) { + VPBasicBlock *EntryVPBB = VPR->getEntryBasicBlock(); + if (!VPR->isReplicator() || EntryVPBB->getSuccessors().size() != 2) + continue; + VPBasicBlock *VPBB = dyn_cast<VPBasicBlock>(EntryVPBB->getSuccessors()[0]); + if (!VPBB || VPBB->getSingleSuccessor() != VPR->getExitingBasicBlock()) + continue; for (auto &Recipe : *VPBB) { - auto *RepR = dyn_cast<VPReplicateRecipe>(&Recipe); - if (!RepR || !RepR->isPredicated()) - continue; - for (VPValue *Op : RepR->operands()) - WorkList.insert(std::make_pair(RepR->getParent(), Op)); + for (VPValue *Op : Recipe.operands()) + if (auto *Def = Op->getDefiningRecipe()) + WorkList.insert(std::make_pair(VPBB, Def)); } } - // Try to sink each replicate recipe in the worklist. - while (!WorkList.empty()) { + bool ScalarVFOnly = Plan.hasScalarVFOnly(); + // Try to sink each replicate or scalar IV steps recipe in the worklist. + for (unsigned I = 0; I != WorkList.size(); ++I) { VPBasicBlock *SinkTo; - VPValue *C; - std::tie(SinkTo, C) = WorkList.pop_back_val(); - auto *SinkCandidate = dyn_cast_or_null<VPReplicateRecipe>(C->Def); - if (!SinkCandidate || SinkCandidate->isUniform() || - SinkCandidate->getParent() == SinkTo || + VPRecipeBase *SinkCandidate; + std::tie(SinkTo, SinkCandidate) = WorkList[I]; + if (SinkCandidate->getParent() == SinkTo || SinkCandidate->mayHaveSideEffects() || SinkCandidate->mayReadOrWriteMemory()) continue; + if (auto *RepR = dyn_cast<VPReplicateRecipe>(SinkCandidate)) { + if (!ScalarVFOnly && RepR->isUniform()) + continue; + } else if (!isa<VPScalarIVStepsRecipe>(SinkCandidate)) + continue; bool NeedsDuplicating = false; // All recipe users of the sink candidate must be in the same block SinkTo // or all users outside of SinkTo must be uniform-after-vectorization ( // i.e., only first lane is used) . In the latter case, we need to duplicate - // SinkCandidate. At the moment, we identify such UAV's by looking for the - // address operands of widened memory recipes. + // SinkCandidate. auto CanSinkWithUser = [SinkTo, &NeedsDuplicating, SinkCandidate](VPUser *U) { auto *UI = dyn_cast<VPRecipeBase>(U); @@ -143,31 +154,31 @@ bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) { return false; if (UI->getParent() == SinkTo) return true; - auto *WidenI = dyn_cast<VPWidenMemoryInstructionRecipe>(UI); - if (WidenI && WidenI->getAddr() == SinkCandidate) { - NeedsDuplicating = true; - return true; - } - return false; + NeedsDuplicating = + UI->onlyFirstLaneUsed(SinkCandidate->getVPSingleValue()); + // We only know how to duplicate VPRecipeRecipes for now. + return NeedsDuplicating && isa<VPReplicateRecipe>(SinkCandidate); }; - if (!all_of(SinkCandidate->users(), CanSinkWithUser)) + if (!all_of(SinkCandidate->getVPSingleValue()->users(), CanSinkWithUser)) continue; if (NeedsDuplicating) { - Instruction *I = cast<Instruction>(SinkCandidate->getUnderlyingValue()); + if (ScalarVFOnly) + continue; + Instruction *I = cast<Instruction>( + cast<VPReplicateRecipe>(SinkCandidate)->getUnderlyingValue()); auto *Clone = new VPReplicateRecipe(I, SinkCandidate->operands(), true, false); // TODO: add ".cloned" suffix to name of Clone's VPValue. 
Clone->insertBefore(SinkCandidate); - SmallVector<VPUser *, 4> Users(SinkCandidate->users()); - for (auto *U : Users) { + for (auto *U : to_vector(SinkCandidate->getVPSingleValue()->users())) { auto *UI = cast<VPRecipeBase>(U); if (UI->getParent() == SinkTo) continue; for (unsigned Idx = 0; Idx != UI->getNumOperands(); Idx++) { - if (UI->getOperand(Idx) != SinkCandidate) + if (UI->getOperand(Idx) != SinkCandidate->getVPSingleValue()) continue; UI->setOperand(Idx, Clone); } @@ -175,7 +186,8 @@ bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) { } SinkCandidate->moveBefore(*SinkTo, SinkTo->getFirstNonPhi()); for (VPValue *Op : SinkCandidate->operands()) - WorkList.insert(std::make_pair(SinkTo, Op)); + if (auto *Def = Op->getDefiningRecipe()) + WorkList.insert(std::make_pair(SinkTo, Def)); Changed = true; } return Changed; @@ -212,21 +224,16 @@ static VPBasicBlock *getPredicatedThenBlock(VPRegionBlock *R) { return nullptr; } -bool VPlanTransforms::mergeReplicateRegions(VPlan &Plan) { +bool VPlanTransforms::mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { SetVector<VPRegionBlock *> DeletedRegions; - bool Changed = false; - - // Collect region blocks to process up-front, to avoid iterator invalidation - // issues while merging regions. - SmallVector<VPRegionBlock *, 8> CandidateRegions( - VPBlockUtils::blocksOnly<VPRegionBlock>(depth_first( - VPBlockRecursiveTraversalWrapper<VPBlockBase *>(Plan.getEntry())))); - // Check if Base is a predicated triangle, followed by an empty block, - // followed by another predicate triangle. If that's the case, move the - // recipes from the first to the second triangle. - for (VPRegionBlock *Region1 : CandidateRegions) { - if (DeletedRegions.contains(Region1)) + // Collect replicate regions followed by an empty block, followed by another + // replicate region with matching masks to process front. This is to avoid + // iterator invalidation issues while merging regions. + SmallVector<VPRegionBlock *, 8> WorkList; + for (VPRegionBlock *Region1 : VPBlockUtils::blocksOnly<VPRegionBlock>( + vp_depth_first_deep(Plan.getEntry()))) { + if (!Region1->isReplicator()) continue; auto *MiddleBasicBlock = dyn_cast_or_null<VPBasicBlock>(Region1->getSingleSuccessor()); @@ -235,20 +242,30 @@ bool VPlanTransforms::mergeReplicateRegions(VPlan &Plan) { auto *Region2 = dyn_cast_or_null<VPRegionBlock>(MiddleBasicBlock->getSingleSuccessor()); - if (!Region2) + if (!Region2 || !Region2->isReplicator()) continue; VPValue *Mask1 = getPredicatedMask(Region1); VPValue *Mask2 = getPredicatedMask(Region2); if (!Mask1 || Mask1 != Mask2) continue; + + assert(Mask1 && Mask2 && "both region must have conditions"); + WorkList.push_back(Region1); + } + + // Move recipes from Region1 to its successor region, if both are triangles. + for (VPRegionBlock *Region1 : WorkList) { + if (DeletedRegions.contains(Region1)) + continue; + auto *MiddleBasicBlock = cast<VPBasicBlock>(Region1->getSingleSuccessor()); + auto *Region2 = cast<VPRegionBlock>(MiddleBasicBlock->getSingleSuccessor()); + VPBasicBlock *Then1 = getPredicatedThenBlock(Region1); VPBasicBlock *Then2 = getPredicatedThenBlock(Region2); if (!Then1 || !Then2) continue; - assert(Mask1 && Mask2 && "both region must have conditions"); - // Note: No fusion-preventing memory dependencies are expected in either // region. Such dependencies should be rejected during earlier dependence // checks, which guarantee accesses can be re-ordered for vectorization. 
@@ -267,8 +284,7 @@ bool VPlanTransforms::mergeReplicateRegions(VPlan &Plan) { VPValue *PredInst1 = cast<VPPredInstPHIRecipe>(&Phi1ToMove)->getOperand(0); VPValue *Phi1ToMoveV = Phi1ToMove.getVPSingleValue(); - SmallVector<VPUser *> Users(Phi1ToMoveV->users()); - for (VPUser *U : Users) { + for (VPUser *U : to_vector(Phi1ToMoveV->users())) { auto *UI = dyn_cast<VPRecipeBase>(U); if (!UI || UI->getParent() != Then2) continue; @@ -293,7 +309,34 @@ bool VPlanTransforms::mergeReplicateRegions(VPlan &Plan) { for (VPRegionBlock *ToDelete : DeletedRegions) delete ToDelete; - return Changed; + return !DeletedRegions.empty(); +} + +bool VPlanTransforms::mergeBlocksIntoPredecessors(VPlan &Plan) { + SmallVector<VPBasicBlock *> WorkList; + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( + vp_depth_first_deep(Plan.getEntry()))) { + auto *PredVPBB = + dyn_cast_or_null<VPBasicBlock>(VPBB->getSinglePredecessor()); + if (PredVPBB && PredVPBB->getNumSuccessors() == 1) + WorkList.push_back(VPBB); + } + + for (VPBasicBlock *VPBB : WorkList) { + VPBasicBlock *PredVPBB = cast<VPBasicBlock>(VPBB->getSinglePredecessor()); + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) + R.moveBefore(*PredVPBB, PredVPBB->end()); + VPBlockUtils::disconnectBlocks(PredVPBB, VPBB); + auto *ParentRegion = cast_or_null<VPRegionBlock>(VPBB->getParent()); + if (ParentRegion && ParentRegion->getExiting() == VPBB) + ParentRegion->setExiting(PredVPBB); + for (auto *Succ : to_vector(VPBB->successors())) { + VPBlockUtils::disconnectBlocks(VPBB, Succ); + VPBlockUtils::connectBlocks(PredVPBB, Succ); + } + delete VPBB; + } + return !WorkList.empty(); } void VPlanTransforms::removeRedundantInductionCasts(VPlan &Plan) { @@ -362,8 +405,8 @@ void VPlanTransforms::removeRedundantCanonicalIVs(VPlan &Plan) { } void VPlanTransforms::removeDeadRecipes(VPlan &Plan) { - ReversePostOrderTraversal<VPBlockRecursiveTraversalWrapper<VPBlockBase *>> - RPOT(Plan.getEntry()); + ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( + Plan.getEntry()); for (VPBasicBlock *VPBB : reverse(VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT))) { // The recipes in the block are processed in reverse order, to catch chains @@ -383,30 +426,40 @@ void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) { VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); bool HasOnlyVectorVFs = !Plan.hasVF(ElementCount::getFixed(1)); for (VPRecipeBase &Phi : HeaderVPBB->phis()) { - auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi); - if (!IV) + auto *WideIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi); + if (!WideIV) continue; - if (HasOnlyVectorVFs && - none_of(IV->users(), [IV](VPUser *U) { return U->usesScalars(IV); })) + if (HasOnlyVectorVFs && none_of(WideIV->users(), [WideIV](VPUser *U) { + return U->usesScalars(WideIV); + })) continue; - const InductionDescriptor &ID = IV->getInductionDescriptor(); + auto IP = HeaderVPBB->getFirstNonPhi(); + VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV(); + Type *ResultTy = WideIV->getPHINode()->getType(); + if (Instruction *TruncI = WideIV->getTruncInst()) + ResultTy = TruncI->getType(); + const InductionDescriptor &ID = WideIV->getInductionDescriptor(); VPValue *Step = vputils::getOrCreateVPValueForSCEVExpr(Plan, ID.getStep(), SE); - Instruction *TruncI = IV->getTruncInst(); - VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe( - IV->getPHINode()->getType(), ID, Plan.getCanonicalIV(), - IV->getStartValue(), Step, TruncI ? 
TruncI->getType() : nullptr); - HeaderVPBB->insert(Steps, HeaderVPBB->getFirstNonPhi()); + VPValue *BaseIV = CanonicalIV; + if (!CanonicalIV->isCanonical(ID, ResultTy)) { + BaseIV = new VPDerivedIVRecipe(ID, WideIV->getStartValue(), CanonicalIV, + Step, ResultTy); + HeaderVPBB->insert(BaseIV->getDefiningRecipe(), IP); + } + + VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe(ID, BaseIV, Step); + HeaderVPBB->insert(Steps, IP); // Update scalar users of IV to use Step instead. Use SetVector to ensure // the list of users doesn't contain duplicates. - SetVector<VPUser *> Users(IV->user_begin(), IV->user_end()); + SetVector<VPUser *> Users(WideIV->user_begin(), WideIV->user_end()); for (VPUser *U : Users) { - if (HasOnlyVectorVFs && !U->usesScalars(IV)) + if (HasOnlyVectorVFs && !U->usesScalars(WideIV)) continue; for (unsigned I = 0, E = U->getNumOperands(); I != E; I++) { - if (U->getOperand(I) != IV) + if (U->getOperand(I) != WideIV) continue; U->setOperand(I, Steps); } @@ -430,3 +483,53 @@ void VPlanTransforms::removeRedundantExpandSCEVRecipes(VPlan &Plan) { ExpR->eraseFromParent(); } } + +static bool canSimplifyBranchOnCond(VPInstruction *Term) { + VPInstruction *Not = dyn_cast<VPInstruction>(Term->getOperand(0)); + if (!Not || Not->getOpcode() != VPInstruction::Not) + return false; + + VPInstruction *ALM = dyn_cast<VPInstruction>(Not->getOperand(0)); + return ALM && ALM->getOpcode() == VPInstruction::ActiveLaneMask; +} + +void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, + unsigned BestUF, + PredicatedScalarEvolution &PSE) { + assert(Plan.hasVF(BestVF) && "BestVF is not available in Plan"); + assert(Plan.hasUF(BestUF) && "BestUF is not available in Plan"); + VPBasicBlock *ExitingVPBB = + Plan.getVectorLoopRegion()->getExitingBasicBlock(); + auto *Term = dyn_cast<VPInstruction>(&ExitingVPBB->back()); + // Try to simplify the branch condition if TC <= VF * UF when preparing to + // execute the plan for the main vector loop. We only do this if the + // terminator is: + // 1. BranchOnCount, or + // 2. BranchOnCond where the input is Not(ActiveLaneMask). + if (!Term || (Term->getOpcode() != VPInstruction::BranchOnCount && + (Term->getOpcode() != VPInstruction::BranchOnCond || + !canSimplifyBranchOnCond(Term)))) + return; + + Type *IdxTy = + Plan.getCanonicalIV()->getStartValue()->getLiveInIRValue()->getType(); + const SCEV *TripCount = createTripCountSCEV(IdxTy, PSE); + ScalarEvolution &SE = *PSE.getSE(); + const SCEV *C = + SE.getConstant(TripCount->getType(), BestVF.getKnownMinValue() * BestUF); + if (TripCount->isZero() || + !SE.isKnownPredicate(CmpInst::ICMP_ULE, TripCount, C)) + return; + + LLVMContext &Ctx = SE.getContext(); + auto *BOC = + new VPInstruction(VPInstruction::BranchOnCond, + {Plan.getOrAddExternalDef(ConstantInt::getTrue(Ctx))}); + Term->eraseFromParent(); + ExitingVPBB->appendRecipe(BOC); + Plan.setVF(BestVF); + Plan.setUF(BestUF); + // TODO: Further simplifications are possible + // 1. Replace inductions with constants. + // 2. Replace vector loop region with VPBasicBlock. 
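optimizeForVFAndUF only fires when SCEV can prove the trip count never exceeds VF * UF, at which point the latch test is degenerate and is replaced by BranchOnCond(true). A constant-trip-count model of the guard checked above, with illustrative numbers:

#include <cstdint>

static bool backedgeNeverTaken(uint64_t TripCount, unsigned VF, unsigned UF) {
  if (TripCount == 0)
    return false; // the transform also gives up on a zero trip count
  return TripCount <= uint64_t(VF) * UF;
}
// e.g. TripCount = 8, VF = 8, UF = 1  ->  true: one vector iteration covers
// the loop, so BranchOnCount can be replaced by BranchOnCond(true).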
+} diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index 3372e255dff7..be0d8e76d809 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -23,6 +23,8 @@ class Instruction; class PHINode; class ScalarEvolution; class Loop; +class PredicatedScalarEvolution; +class TargetLibraryInfo; struct VPlanTransforms { /// Replaces the VPInstructions in \p Plan with corresponding @@ -32,11 +34,18 @@ struct VPlanTransforms { function_ref<const InductionDescriptor *(PHINode *)> GetIntOrFpInductionDescriptor, SmallPtrSetImpl<Instruction *> &DeadInstructions, - ScalarEvolution &SE); + ScalarEvolution &SE, const TargetLibraryInfo &TLI); static bool sinkScalarOperands(VPlan &Plan); - static bool mergeReplicateRegions(VPlan &Plan); + /// Merge replicate regions in their successor region, if a replicate region + /// is connected to a successor replicate region with the same predicate by a + /// single, empty VPBasicBlock. + static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan); + + /// Remove redundant VPBasicBlocks by merging them into their predecessor if + /// the predecessor has a single successor. + static bool mergeBlocksIntoPredecessors(VPlan &Plan); /// Remove redundant casts of inductions. /// @@ -61,6 +70,12 @@ struct VPlanTransforms { /// Remove redundant EpxandSCEVRecipes in \p Plan's entry block by replacing /// them with already existing recipes expanding the same SCEV expression. static void removeRedundantExpandSCEVRecipes(VPlan &Plan); + + /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the + /// resulting plan to \p BestVF and \p BestUF. + static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, + unsigned BestUF, + PredicatedScalarEvolution &PSE); }; } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h index c99fae1b2ab4..62ec65cbfe5d 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanValue.h +++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h @@ -85,35 +85,19 @@ public: const Value *getUnderlyingValue() const { return UnderlyingVal; } /// An enumeration for keeping track of the concrete subclass of VPValue that - /// are actually instantiated. Values of this enumeration are kept in the - /// SubclassID field of the VPValue objects. They are used for concrete - /// type identification. + /// are actually instantiated. enum { - VPValueSC, - VPVInstructionSC, - VPVMemoryInstructionSC, - VPVReductionSC, - VPVReplicateSC, - VPVWidenSC, - VPVWidenCallSC, - VPVWidenCanonicalIVSC, - VPVWidenGEPSC, - VPVWidenSelectSC, - - // Phi-like VPValues. Need to be kept together. - VPVBlendSC, - VPVCanonicalIVPHISC, - VPVActiveLaneMaskPHISC, - VPVFirstOrderRecurrencePHISC, - VPVWidenPHISC, - VPVWidenIntOrFpInductionSC, - VPVWidenPointerInductionSC, - VPVPredInstPHI, - VPVReductionPHISC, + VPValueSC, /// A generic VPValue, like live-in values or defined by a recipe + /// that defines multiple values. + VPVRecipeSC /// A VPValue sub-class that is a VPRecipeBase. }; - VPValue(Value *UV = nullptr, VPDef *Def = nullptr) - : VPValue(VPValueSC, UV, Def) {} + /// Create a live-in VPValue. + VPValue(Value *UV = nullptr) : VPValue(VPValueSC, UV, nullptr) {} + /// Create a VPValue for a \p Def which is a subclass of VPValue. + VPValue(VPDef *Def, Value *UV = nullptr) : VPValue(VPVRecipeSC, UV, Def) {} + /// Create a VPValue for a \p Def which defines multiple values. 
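The constructors above separate live-in VPValues from recipe-defined ones, and the getDefiningRecipe()/hasDefiningRecipe() accessors introduced just below make that distinction explicit at query sites. A hedged sketch of the typical query, assuming the VPlan headers of this patch are available:

static bool isLiveIn(const llvm::VPValue &V) {
  // Live-ins have no defining recipe; for current plans this matches
  // isDefinedOutsideVectorRegions() (pre-header recipes are still a TODO).
  return !V.hasDefiningRecipe();
}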
+ VPValue(Value *UV, VPDef *Def) : VPValue(VPValueSC, UV, Def) {} VPValue(const VPValue &) = delete; VPValue &operator=(const VPValue &) = delete; @@ -179,22 +163,32 @@ public: void replaceAllUsesWith(VPValue *New); - VPDef *getDef() { return Def; } - const VPDef *getDef() const { return Def; } + /// Returns the recipe defining this VPValue or nullptr if it is not defined + /// by a recipe, i.e. is a live-in. + VPRecipeBase *getDefiningRecipe(); + const VPRecipeBase *getDefiningRecipe() const; + + /// Returns true if this VPValue is defined by a recipe. + bool hasDefiningRecipe() const { return getDefiningRecipe(); } /// Returns the underlying IR value, if this VPValue is defined outside the /// scope of VPlan. Returns nullptr if the VPValue is defined by a VPDef /// inside a VPlan. Value *getLiveInIRValue() { - assert(!getDef() && + assert(!hasDefiningRecipe() && "VPValue is not a live-in; it is defined by a VPDef inside a VPlan"); return getUnderlyingValue(); } const Value *getLiveInIRValue() const { - assert(!getDef() && + assert(!hasDefiningRecipe() && "VPValue is not a live-in; it is defined by a VPDef inside a VPlan"); return getUnderlyingValue(); } + + /// Returns true if the VPValue is defined outside any vector regions, i.e. it + /// is a live-in value. + /// TODO: Also handle recipes defined in pre-header blocks. + bool isDefinedOutsideVectorRegions() const { return !hasDefiningRecipe(); } }; typedef DenseMap<Value *, VPValue *> Value2VPValueTy; @@ -284,9 +278,6 @@ public: return const_operand_range(op_begin(), op_end()); } - /// Method to support type inquiry through isa, cast, and dyn_cast. - static inline bool classof(const VPDef *Recipe); - /// Returns true if the VPUser uses scalars of operand \p Op. Conservatively /// returns if only first (scalar) lane is used, as default. virtual bool usesScalars(const VPValue *Op) const { @@ -320,7 +311,7 @@ class VPDef { /// Add \p V as a defined value by this VPDef. void addDefinedValue(VPValue *V) { - assert(V->getDef() == this && + assert(V->Def == this && "can only add VPValue already linked with this VPDef"); DefinedValues.push_back(V); } @@ -328,8 +319,7 @@ class VPDef { /// Remove \p V from the values defined by this VPDef. \p V must be a defined /// value of this VPDef. void removeDefinedValue(VPValue *V) { - assert(V->getDef() == this && - "can only remove VPValue linked with this VPDef"); + assert(V->Def == this && "can only remove VPValue linked with this VPDef"); assert(is_contained(DefinedValues, V) && "VPValue to remove must be in DefinedValues"); erase_value(DefinedValues, V); @@ -343,6 +333,7 @@ public: /// type identification. using VPRecipeTy = enum { VPBranchOnMaskSC, + VPDerivedIVSC, VPExpandSCEVSC, VPInstructionSC, VPInterleaveSC, @@ -358,15 +349,17 @@ public: // Phi-like recipes. Need to be kept together. VPBlendSC, + VPPredInstPHISC, + // Header-phi recipes. Need to be kept together. 
VPCanonicalIVPHISC, VPActiveLaneMaskPHISC, VPFirstOrderRecurrencePHISC, VPWidenPHISC, VPWidenIntOrFpInductionSC, VPWidenPointerInductionSC, - VPPredInstPHISC, VPReductionPHISC, VPFirstPHISC = VPBlendSC, + VPFirstHeaderPHISC = VPCanonicalIVPHISC, VPLastPHISC = VPReductionPHISC, }; diff --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp index 43e0a40fedb9..18125cebed33 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp @@ -14,6 +14,7 @@ #include "VPlanVerifier.h" #include "VPlan.h" +#include "VPlanCFG.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/Support/CommandLine.h" @@ -43,9 +44,7 @@ static bool hasDuplicates(const SmallVectorImpl<VPBlockBase *> &VPBlockVec) { /// \p Region. Checks in this function are generic for VPBlockBases. They are /// not specific for VPBasicBlocks or VPRegionBlocks. static void verifyBlocksInRegion(const VPRegionBlock *Region) { - for (const VPBlockBase *VPB : make_range( - df_iterator<const VPBlockBase *>::begin(Region->getEntry()), - df_iterator<const VPBlockBase *>::end(Region->getExiting()))) { + for (const VPBlockBase *VPB : vp_depth_first_shallow(Region->getEntry())) { // Check block's parent. assert(VPB->getParent() == Region && "VPBlockBase has wrong parent"); @@ -133,17 +132,38 @@ void VPlanVerifier::verifyHierarchicalCFG( verifyRegionRec(TopRegion); } -static bool -verifyVPBasicBlock(const VPBasicBlock *VPBB, - DenseMap<const VPBlockBase *, unsigned> &BlockNumbering) { - // Verify that phi-like recipes are at the beginning of the block, with no - // other recipes in between. +// Verify that phi-like recipes are at the beginning of \p VPBB, with no +// other recipes in between. Also check that only header blocks contain +// VPHeaderPHIRecipes. +static bool verifyPhiRecipes(const VPBasicBlock *VPBB) { auto RecipeI = VPBB->begin(); auto End = VPBB->end(); unsigned NumActiveLaneMaskPhiRecipes = 0; + const VPRegionBlock *ParentR = VPBB->getParent(); + bool IsHeaderVPBB = ParentR && !ParentR->isReplicator() && + ParentR->getEntryBasicBlock() == VPBB; while (RecipeI != End && RecipeI->isPhi()) { if (isa<VPActiveLaneMaskPHIRecipe>(RecipeI)) NumActiveLaneMaskPhiRecipes++; + + if (IsHeaderVPBB && !isa<VPHeaderPHIRecipe>(*RecipeI)) { + errs() << "Found non-header PHI recipe in header VPBB"; +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + errs() << ": "; + RecipeI->dump(); +#endif + return false; + } + + if (!IsHeaderVPBB && isa<VPHeaderPHIRecipe>(*RecipeI)) { + errs() << "Found header PHI recipe in non-header VPBB"; +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + errs() << ": "; + RecipeI->dump(); +#endif + return false; + } + RecipeI++; } @@ -166,6 +186,14 @@ verifyVPBasicBlock(const VPBasicBlock *VPBB, } RecipeI++; } + return true; +} + +static bool +verifyVPBasicBlock(const VPBasicBlock *VPBB, + DenseMap<const VPBlockBase *, unsigned> &BlockNumbering) { + if (!verifyPhiRecipes(VPBB)) + return false; // Verify that defs in VPBB dominate all their uses. The current // implementation is still incomplete. 
@@ -224,8 +252,7 @@ verifyVPBasicBlock(const VPBasicBlock *VPBB, bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { DenseMap<const VPBlockBase *, unsigned> BlockNumbering; unsigned Cnt = 0; - auto Iter = depth_first( - VPBlockRecursiveTraversalWrapper<const VPBlockBase *>(Plan.getEntry())); + auto Iter = vp_depth_first_deep(Plan.getEntry()); for (const VPBlockBase *VPB : Iter) { BlockNumbering[VPB] = Cnt++; auto *VPBB = dyn_cast<VPBasicBlock>(VPB); @@ -270,8 +297,7 @@ bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { for (const VPRegionBlock *Region : VPBlockUtils::blocksOnly<const VPRegionBlock>( - depth_first(VPBlockRecursiveTraversalWrapper<const VPBlockBase *>( - Plan.getEntry())))) { + vp_depth_first_deep(Plan.getEntry()))) { if (Region->getEntry()->getNumPredecessors() != 0) { errs() << "region entry block has predecessors\n"; return false; @@ -282,7 +308,7 @@ bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { } } - for (auto &KV : Plan.getLiveOuts()) + for (const auto &KV : Plan.getLiveOuts()) if (KV.second->getNumOperands() != 1) { errs() << "live outs must have a single operand\n"; return false; diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index a38936644bd3..2e489757ebc1 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -30,6 +30,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Vectorize.h" +#include <numeric> #define DEBUG_TYPE "vector-combine" #include "llvm/Transforms/Utils/InstructionWorklist.h" @@ -64,9 +65,9 @@ class VectorCombine { public: VectorCombine(Function &F, const TargetTransformInfo &TTI, const DominatorTree &DT, AAResults &AA, AssumptionCache &AC, - bool ScalarizationOnly) + bool TryEarlyFoldsOnly) : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), - ScalarizationOnly(ScalarizationOnly) {} + TryEarlyFoldsOnly(TryEarlyFoldsOnly) {} bool run(); @@ -78,13 +79,17 @@ private: AAResults &AA; AssumptionCache &AC; - /// If true only perform scalarization combines and do not introduce new + /// If true, only perform beneficial early IR transforms. Do not introduce new /// vector operations. - bool ScalarizationOnly; + bool TryEarlyFoldsOnly; InstructionWorklist Worklist; + // TODO: Direct calls from the top-level "run" loop use a plain "Instruction" + // parameter. That should be updated to specific sub-classes because the + // run loop was changed to dispatch on opcode. bool vectorizeLoadInsert(Instruction &I); + bool widenSubvectorLoad(Instruction &I); ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0, ExtractElementInst *Ext1, unsigned PreferredExtractIndex) const; @@ -97,6 +102,7 @@ private: void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1, Instruction &I); bool foldExtractExtract(Instruction &I); + bool foldInsExtFNeg(Instruction &I); bool foldBitcastShuf(Instruction &I); bool scalarizeBinopOrCmp(Instruction &I); bool foldExtractedCmps(Instruction &I); @@ -125,12 +131,32 @@ private: }; } // namespace +static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) { + // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan. + // The widened load may load data from dirty regions or create data races + // non-existent in the source. 
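canWidenLoad (its checks continue just below) rejects atomic/volatile and multi-use loads, sanitized functions, and element types that do not tile the target's minimum vector register. The arithmetic part of that last constraint, shown as a stand-alone sketch with illustrative numbers:

#include <cstdint>

// Whole bytes only, and the scalar width must evenly divide the minimum
// vector register width reported by TTI.
static bool sizesAllowWidening(uint64_t ScalarBits, unsigned MinVecBits) {
  return ScalarBits != 0 && MinVecBits != 0 &&
         MinVecBits % ScalarBits == 0 && ScalarBits % 8 == 0;
}
// e.g. a 32-bit scalar with a 128-bit minimum register widens to 4 lanes;
// a 1-bit scalar is rejected by the ScalarBits % 8 check.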
+ if (!Load || !Load->isSimple() || !Load->hasOneUse() || + Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) || + mustSuppressSpeculation(*Load)) + return false; + + // We are potentially transforming byte-sized (8-bit) memory accesses, so make + // sure we have all of our type-based constraints in place for this target. + Type *ScalarTy = Load->getType()->getScalarType(); + uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits(); + unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth(); + if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 || + ScalarSize % 8 != 0) + return false; + + return true; +} + bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // Match insert into fixed vector of scalar value. // TODO: Handle non-zero insert index. - auto *Ty = dyn_cast<FixedVectorType>(I.getType()); Value *Scalar; - if (!Ty || !match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) || + if (!match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) || !Scalar->hasOneUse()) return false; @@ -140,40 +166,28 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { if (!HasExtract) X = Scalar; - // Match source value as load of scalar or vector. - // Do not vectorize scalar load (widening) if atomic/volatile or under - // asan/hwasan/memtag/tsan. The widened load may load data from dirty regions - // or create data races non-existent in the source. auto *Load = dyn_cast<LoadInst>(X); - if (!Load || !Load->isSimple() || !Load->hasOneUse() || - Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) || - mustSuppressSpeculation(*Load)) + if (!canWidenLoad(Load, TTI)) return false; - const DataLayout &DL = I.getModule()->getDataLayout(); - Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); - assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); - - unsigned AS = Load->getPointerAddressSpace(); - - // We are potentially transforming byte-sized (8-bit) memory accesses, so make - // sure we have all of our type-based constraints in place for this target. Type *ScalarTy = Scalar->getType(); uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits(); unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth(); - if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 || - ScalarSize % 8 != 0) - return false; // Check safety of replacing the scalar load with a larger vector load. // We use minimal alignment (maximum flexibility) because we only care about // the dereferenceable region. When calculating cost and creating a new op, // we may use a larger value based on alignment attributes. + const DataLayout &DL = I.getModule()->getDataLayout(); + Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); + assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); + unsigned MinVecNumElts = MinVectorSize / ScalarSize; auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false); unsigned OffsetEltIndex = 0; Align Alignment = Load->getAlign(); - if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &DT)) { + if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &AC, + &DT)) { // It is not safe to load directly from the pointer, but we can still peek // through gep offsets and check if it safe to load from a base address with // updated alignment. 
If it is, we can shuffle the element(s) into place @@ -198,7 +212,8 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { if (OffsetEltIndex >= MinVecNumElts) return false; - if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &DT)) + if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &AC, + &DT)) return false; // Update alignment with offset value. Note that the offset could be negated @@ -211,11 +226,14 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // Use the greater of the alignment on the load or its source pointer. Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment); Type *LoadTy = Load->getType(); + unsigned AS = Load->getPointerAddressSpace(); InstructionCost OldCost = TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS); APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0); - OldCost += TTI.getScalarizationOverhead(MinVecTy, DemandedElts, - /* Insert */ true, HasExtract); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + OldCost += + TTI.getScalarizationOverhead(MinVecTy, DemandedElts, + /* Insert */ true, HasExtract, CostKind); // New pattern: load VecPtr InstructionCost NewCost = @@ -227,6 +245,7 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // We assume this operation has no cost in codegen if there was no offset. // Note that we could use freeze to avoid poison problems, but then we might // still need a shuffle to change the vector size. + auto *Ty = cast<FixedVectorType>(I.getType()); unsigned OutputNumElts = Ty->getNumElements(); SmallVector<int, 16> Mask(OutputNumElts, UndefMaskElem); assert(OffsetEltIndex < MinVecNumElts && "Address offset too big"); @@ -252,6 +271,66 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { return true; } +/// If we are loading a vector and then inserting it into a larger vector with +/// undefined elements, try to load the larger vector and eliminate the insert. +/// This removes a shuffle in IR and may allow combining of other loaded values. +bool VectorCombine::widenSubvectorLoad(Instruction &I) { + // Match subvector insert of fixed vector. + auto *Shuf = cast<ShuffleVectorInst>(&I); + if (!Shuf->isIdentityWithPadding()) + return false; + + // Allow a non-canonical shuffle mask that is choosing elements from op1. + unsigned NumOpElts = + cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements(); + unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) { + return M >= (int)(NumOpElts); + }); + + auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex)); + if (!canWidenLoad(Load, TTI)) + return false; + + // We use minimal alignment (maximum flexibility) because we only care about + // the dereferenceable region. When calculating cost and creating a new op, + // we may use a larger value based on alignment attributes. + auto *Ty = cast<FixedVectorType>(I.getType()); + const DataLayout &DL = I.getModule()->getDataLayout(); + Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); + assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); + Align Alignment = Load->getAlign(); + if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), DL, Load, &AC, &DT)) + return false; + + Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment); + Type *LoadTy = Load->getType(); + unsigned AS = Load->getPointerAddressSpace(); + + // Original pattern: insert_subvector (load PtrOp) + // This conservatively assumes that the cost of a subvector insert into an + // undef value is 0. 
We could add that cost if the cost model accurately + // reflects the real cost of that operation. + InstructionCost OldCost = + TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS); + + // New pattern: load PtrOp + InstructionCost NewCost = + TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS); + + // We can aggressively convert to the vector form because the backend can + // invert this transform if it does not result in a performance win. + if (OldCost < NewCost || !NewCost.isValid()) + return false; + + IRBuilder<> Builder(Load); + Value *CastedPtr = + Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Ty->getPointerTo(AS)); + Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment); + replaceValue(I, *VecLd); + ++NumVecLoad; + return true; +} + /// Determine which, if any, of the inputs should be replaced by a shuffle /// followed by extract from a different index. ExtractElementInst *VectorCombine::getShuffleExtract( @@ -269,11 +348,12 @@ ExtractElementInst *VectorCombine::getShuffleExtract( return nullptr; Type *VecTy = Ext0->getVectorOperand()->getType(); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types"); InstructionCost Cost0 = - TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0); + TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0); InstructionCost Cost1 = - TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1); + TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1); // If both costs are invalid no shuffle is needed if (!Cost0.isValid() && !Cost1.isValid()) @@ -336,11 +416,12 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0, // both sequences. unsigned Ext0Index = Ext0IndexC->getZExtValue(); unsigned Ext1Index = Ext1IndexC->getZExtValue(); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; InstructionCost Extract0Cost = - TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index); + TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index); InstructionCost Extract1Cost = - TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index); + TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index); // A more expensive extract will always be replaced by a splat shuffle. // For example, if Ext0 is more expensive: @@ -533,6 +614,69 @@ bool VectorCombine::foldExtractExtract(Instruction &I) { return true; } +/// Try to replace an extract + scalar fneg + insert with a vector fneg + +/// shuffle. +bool VectorCombine::foldInsExtFNeg(Instruction &I) { + // Match an insert (op (extract)) pattern. + Value *DestVec; + uint64_t Index; + Instruction *FNeg; + if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)), + m_ConstantInt(Index)))) + return false; + + // Note: This handles the canonical fneg instruction and "fsub -0.0, X". + Value *SrcVec; + Instruction *Extract; + if (!match(FNeg, m_FNeg(m_CombineAnd( + m_Instruction(Extract), + m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index)))))) + return false; + + // TODO: We could handle this with a length-changing shuffle. + auto *VecTy = cast<FixedVectorType>(I.getType()); + if (SrcVec->getType() != VecTy) + return false; + + // Ignore bogus insert/extract index. + unsigned NumElts = VecTy->getNumElements(); + if (Index >= NumElts) + return false; + + // We are inserting the negated element into the same lane that we extracted + // from. This is equivalent to a select-shuffle that chooses all but the + // negated element from the destination vector. 
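The mask built immediately below is an identity mask with the negated lane redirected to the second shuffle operand. A worked instance of the same construction with NumElts = 4 and Index = 2:

#include "llvm/ADT/SmallVector.h"
#include <numeric>

static llvm::SmallVector<int, 4> exampleFNegSelectMask() {
  constexpr unsigned NumElts = 4, Index = 2;
  llvm::SmallVector<int, 4> Mask(NumElts);
  std::iota(Mask.begin(), Mask.end(), 0); // {0, 1, 2, 3}
  Mask[Index] = Index + NumElts;          // {0, 1, 6, 3}: lane 2 now selects
  return Mask;                            // the negated source vector
}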
+ SmallVector<int> Mask(NumElts); + std::iota(Mask.begin(), Mask.end(), 0); + Mask[Index] = Index + NumElts; + + Type *ScalarTy = VecTy->getScalarType(); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + InstructionCost OldCost = + TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy) + + TTI.getVectorInstrCost(I, VecTy, CostKind, Index); + + // If the extract has one use, it will be eliminated, so count it in the + // original cost. If it has more than one use, ignore the cost because it will + // be the same before/after. + if (Extract->hasOneUse()) + OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index); + + InstructionCost NewCost = + TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy) + + TTI.getShuffleCost(TargetTransformInfo::SK_Select, VecTy, Mask); + + if (NewCost > OldCost) + return false; + + // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index --> + // shuffle DestVec, (fneg SrcVec), Mask + Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg); + Value *Shuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask); + replaceValue(I, *Shuf); + return true; +} + /// If this is a bitcast of a shuffle, try to bitcast the source vector to the /// destination type followed by shuffle. This can enable further transforms by /// moving bitcasts or shuffles together. @@ -548,11 +692,11 @@ bool VectorCombine::foldBitcastShuf(Instruction &I) { // mask for scalable type is a splat or not. // 2) Disallow non-vector casts and length-changing shuffles. // TODO: We could allow any shuffle. - auto *DestTy = dyn_cast<FixedVectorType>(I.getType()); auto *SrcTy = dyn_cast<FixedVectorType>(V->getType()); - if (!SrcTy || !DestTy || I.getOperand(0)->getType() != SrcTy) + if (!SrcTy || I.getOperand(0)->getType() != SrcTy) return false; + auto *DestTy = cast<FixedVectorType>(I.getType()); unsigned DestNumElts = DestTy->getNumElements(); unsigned SrcNumElts = SrcTy->getNumElements(); SmallVector<int, 16> NewMask; @@ -664,8 +808,9 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { // Get cost estimate for the insert element. This cost will factor into // both sequences. - InstructionCost InsertCost = - TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + InstructionCost InsertCost = TTI.getVectorInstrCost( + Instruction::InsertElement, VecTy, CostKind, Index); InstructionCost OldCost = (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 
0 : InsertCost) + VectorOpCost; InstructionCost NewCost = ScalarOpCost + InsertCost + @@ -754,9 +899,10 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) { if (!VecTy) return false; + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; InstructionCost OldCost = - TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0); - OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1); + TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0); + OldCost += TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1); OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred) * @@ -776,7 +922,7 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) { NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy, ShufMask); NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy); - NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex); + NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex); // Aggressively form vector ops if the cost is equal because the transform // may enable further optimization. @@ -811,6 +957,7 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin, }); } +namespace { /// Helper class to indicate whether a vector index can be safely scalarized and /// if a freeze needs to be inserted. class ScalarizationResult { @@ -865,6 +1012,7 @@ public: ToFreeze = nullptr; } }; +} // namespace /// Check if it is legal to scalarize a memory access to \p VecTy at index \p /// Idx. \p Idx must access a valid vector element. @@ -928,8 +1076,8 @@ static Align computeAlignmentAfterScalarization(Align VectorAlignment, // %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1 // store i32 %b, i32* %1 bool VectorCombine::foldSingleElementStore(Instruction &I) { - StoreInst *SI = dyn_cast<StoreInst>(&I); - if (!SI || !SI->isSimple() || + auto *SI = cast<StoreInst>(&I); + if (!SI->isSimple() || !isa<FixedVectorType>(SI->getValueOperand()->getType())) return false; @@ -985,17 +1133,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { if (!match(&I, m_Load(m_Value(Ptr)))) return false; + auto *FixedVT = cast<FixedVectorType>(I.getType()); auto *LI = cast<LoadInst>(&I); const DataLayout &DL = I.getModule()->getDataLayout(); - if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(LI->getType())) - return false; - - auto *FixedVT = dyn_cast<FixedVectorType>(LI->getType()); - if (!FixedVT) + if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(FixedVT)) return false; InstructionCost OriginalCost = - TTI.getMemoryOpCost(Instruction::Load, LI->getType(), LI->getAlign(), + TTI.getMemoryOpCost(Instruction::Load, FixedVT, LI->getAlign(), LI->getPointerAddressSpace()); InstructionCost ScalarizedCost = 0; @@ -1034,8 +1179,9 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { } auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1)); + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; OriginalCost += - TTI.getVectorInstrCost(Instruction::ExtractElement, LI->getType(), + TTI.getVectorInstrCost(Instruction::ExtractElement, FixedVT, CostKind, Index ? Index->getZExtValue() : -1); ScalarizedCost += TTI.getMemoryOpCost(Instruction::Load, FixedVT->getElementType(), @@ -1070,10 +1216,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { /// Try to convert "shuffle (binop), (binop)" with a shared binop operand into /// "binop (shuffle), (shuffle)". 
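Every fold in this file, including foldShuffleOfBinops below, is gated on a TTI cost comparison between the old and new instruction sequences. A miniature sketch of that recurring gate (individual folds differ in whether a tie is accepted, so treat this as illustrative rather than the exact predicate used everywhere):

#include "llvm/Support/InstructionCost.h"

static bool profitable(llvm::InstructionCost OldCost,
                       llvm::InstructionCost NewCost) {
  // Rewrite only when the target reports the new sequence as valid and no
  // more expensive than the old one.
  return NewCost.isValid() && NewCost <= OldCost;
}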
bool VectorCombine::foldShuffleOfBinops(Instruction &I) { - auto *VecTy = dyn_cast<FixedVectorType>(I.getType()); - if (!VecTy) - return false; - + auto *VecTy = cast<FixedVectorType>(I.getType()); BinaryOperator *B0, *B1; ArrayRef<int> Mask; if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)), @@ -1244,15 +1387,14 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) { /// architectures with no obvious "select" shuffle, this can reduce the total /// number of operations if the target reports them as cheaper. bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { - auto *SVI = dyn_cast<ShuffleVectorInst>(&I); - auto *VT = dyn_cast<FixedVectorType>(I.getType()); - if (!SVI || !VT) - return false; + auto *SVI = cast<ShuffleVectorInst>(&I); + auto *VT = cast<FixedVectorType>(I.getType()); auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0)); auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1)); if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() || VT != Op0->getType()) return false; + auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0)); auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1)); auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0)); @@ -1300,7 +1442,7 @@ bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { // cost calculations. if (!FromReduction) { for (ShuffleVectorInst *SV : Shuffles) { - for (auto U : SV->users()) { + for (auto *U : SV->users()) { ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U); if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT) Shuffles.push_back(SSV); @@ -1569,19 +1711,78 @@ bool VectorCombine::run() { bool MadeChange = false; auto FoldInst = [this, &MadeChange](Instruction &I) { Builder.SetInsertPoint(&I); - if (!ScalarizationOnly) { - MadeChange |= vectorizeLoadInsert(I); - MadeChange |= foldExtractExtract(I); - MadeChange |= foldBitcastShuf(I); - MadeChange |= foldExtractedCmps(I); - MadeChange |= foldShuffleOfBinops(I); - MadeChange |= foldShuffleFromReductions(I); - MadeChange |= foldSelectShuffle(I); + bool IsFixedVectorType = isa<FixedVectorType>(I.getType()); + auto Opcode = I.getOpcode(); + + // These folds should be beneficial regardless of when this pass is run + // in the optimization pipeline. + // The type checking is for run-time efficiency. We can avoid wasting time + // dispatching to folding functions if there's no chance of matching. + if (IsFixedVectorType) { + switch (Opcode) { + case Instruction::InsertElement: + MadeChange |= vectorizeLoadInsert(I); + break; + case Instruction::ShuffleVector: + MadeChange |= widenSubvectorLoad(I); + break; + case Instruction::Load: + MadeChange |= scalarizeLoadExtract(I); + break; + default: + break; + } + } + + // This transform works with scalable and fixed vectors + // TODO: Identify and allow other scalable transforms + if (isa<VectorType>(I.getType())) + MadeChange |= scalarizeBinopOrCmp(I); + + if (Opcode == Instruction::Store) + MadeChange |= foldSingleElementStore(I); + + + // If this is an early pipeline invocation of this pass, we are done. + if (TryEarlyFoldsOnly) + return; + + // Otherwise, try folds that improve codegen but may interfere with + // early IR canonicalizations. + // The type checking is for run-time efficiency. We can avoid wasting time + // dispatching to folding functions if there's no chance of matching. 
+ if (IsFixedVectorType) { + switch (Opcode) { + case Instruction::InsertElement: + MadeChange |= foldInsExtFNeg(I); + break; + case Instruction::ShuffleVector: + MadeChange |= foldShuffleOfBinops(I); + MadeChange |= foldSelectShuffle(I); + break; + case Instruction::BitCast: + MadeChange |= foldBitcastShuf(I); + break; + } + } else { + switch (Opcode) { + case Instruction::Call: + MadeChange |= foldShuffleFromReductions(I); + break; + case Instruction::ICmp: + case Instruction::FCmp: + MadeChange |= foldExtractExtract(I); + break; + default: + if (Instruction::isBinaryOp(Opcode)) { + MadeChange |= foldExtractExtract(I); + MadeChange |= foldExtractedCmps(I); + } + break; + } } - MadeChange |= scalarizeBinopOrCmp(I); - MadeChange |= scalarizeLoadExtract(I); - MadeChange |= foldSingleElementStore(I); }; + for (BasicBlock &BB : F) { // Ignore unreachable basic blocks. if (!DT.isReachableFromEntry(&BB)) @@ -1664,7 +1865,7 @@ PreservedAnalyses VectorCombinePass::run(Function &F, TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); AAResults &AA = FAM.getResult<AAManager>(F); - VectorCombine Combiner(F, TTI, DT, AA, AC, ScalarizationOnly); + VectorCombine Combiner(F, TTI, DT, AA, AC, TryEarlyFoldsOnly); if (!Combiner.run()) return PreservedAnalyses::all(); PreservedAnalyses PA; |
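For readers skimming the patch, here is a rough IR-level sketch of the insert/extract fneg fold (foldInsExtFNeg) added above. It assumes a <4 x float> vector, Index == 2, and a cost model that favors the vector form; the value names are illustrative and are not taken from the patch or its tests:

  ; before: negate one extracted lane and reinsert it at the same index
  %e = extractelement <4 x float> %src, i64 2
  %n = fneg float %e
  %r = insertelement <4 x float> %dest, float %n, i64 2

  ; after: negate the whole source vector and blend lane 2 into the destination
  ; (identity mask with Mask[2] = 2 + NumElts = 6, i.e. the SK_Select shuffle costed above)
  %negsrc = fneg <4 x float> %src
  %r = shufflevector <4 x float> %dest, <4 x float> %negsrc, <4 x i32> <i32 0, i32 1, i32 6, i32 3>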